From 4688e1a916967467e38c40ff179fada716b67581 Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Sun, 23 Mar 2025 16:03:17 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0spark=E9=80=9A=E7=94=A8?= =?UTF-8?q?=E8=B0=83=E7=94=A8=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- service/imService.go | 118 +++++++++++-------- worker/spark.go | 274 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 340 insertions(+), 52 deletions(-) create mode 100644 worker/spark.go diff --git a/service/imService.go b/service/imService.go index e393321..2f01766 100644 --- a/service/imService.go +++ b/service/imService.go @@ -84,58 +84,72 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS var resID uint //处理消息 if msg.Type == "ollama" { - var err error - if client == nil { - client, err = api.ClientFromEnvironment() - } - if err != nil { - log.Println("get ollama client error:", err) - } - prompt := msg.Msg - actx, ok := SessionContextMap[sessionID] - if ok { - log.Println("actx has get will use:", len(actx)) - } - req := &api.GenerateRequest{ - Model: "qwen2.5:0.5b", - Context: actx, - Prompt: prompt, - } - ctx := context.Background() - robotMsg := "" - var aiMsg proto.AIQueueMessage - var data proto.WSMessage - respFunc := func(resp api.GenerateResponse) error { - aiMsg.Type = "ollama" - aiMsg.Msg = resp - robotMsg += resp.Response - if resp.Done { //该消息完成 - actx = resp.Context - if SessionContextMap == nil { - SessionContextMap = make(map[int][]int) - } - SessionContextMap[sessionID] = actx - log.Println("actx has save:", len(SessionContextMap[sessionID])) - //创建消息 - err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1 - if err2 != nil { - log.Println("create robot message error:", err2) - } - aiMsg.ID = msgID - resID = msgID - } - //发送消息 - data.Msg = aiMsg - data.SessionID = sessionID - data.ToID = userID - aiMsgStr, _ := json.Marshal(data) - worker.Publish(channel, string(aiMsgStr), time.Second*60) - return nil - } - err = client.Generate(ctx, req, respFunc) - if err != nil { - log.Fatal(err) - } + resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg) + } else if msg.Type == "spark" { + } return resErr, resID } + +func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) { + var resErr error + var resID uint + var err error + if client == nil { + client, err = api.ClientFromEnvironment() + } + if err != nil { + log.Println("get ollama client error:", err) + } + prompt := msg.Msg + actx, ok := SessionContextMap[sessionID] + if ok { + log.Println("actx has get will use:", len(actx)) + } + req := &api.GenerateRequest{ + Model: "qwen2.5:0.5b", + Context: actx, + Prompt: prompt, + } + ctx := context.Background() + robotMsg := "" + var aiMsg proto.AIQueueMessage + var data proto.WSMessage + respFunc := func(resp api.GenerateResponse) error { + aiMsg.Type = "ollama" + aiMsg.Msg = resp + robotMsg += resp.Response + if resp.Done { //该消息完成 + actx = resp.Context + if SessionContextMap == nil { + SessionContextMap = make(map[int][]int) + } + SessionContextMap[sessionID] = actx + log.Println("actx has save:", len(SessionContextMap[sessionID])) + //创建消息 + err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1 + if err2 != nil { + log.Println("create robot message error:", err2) + } + aiMsg.ID = msgID + resID = msgID + } + //发送消息 + data.Msg = aiMsg + data.SessionID = sessionID + data.ToID = userID + aiMsgStr, _ := json.Marshal(data) + worker.Publish(channel, string(aiMsgStr), time.Second*60) + return nil + } + err = client.Generate(ctx, req, respFunc) + if err != nil { + log.Fatal(err) + } + return resErr, resID +} + +func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) { + + return nil, 0 +} diff --git a/worker/spark.go b/worker/spark.go new file mode 100644 index 0000000..614bd5a --- /dev/null +++ b/worker/spark.go @@ -0,0 +1,274 @@ +package worker + +import ( + "StuAcaWorksAI/dao" + "StuAcaWorksAI/proto" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "github.com/gorilla/websocket" + "github.com/ollama/ollama/api" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" +) + +/***************** Spark请求结构 ***********************/ +type SparkRequestHeader struct { + AppID string `json:"app_id"` + UID string `json:"uid"` +} + +type ChatParameter struct { + Domain string `json:"domain"` + Temperature float64 `json:"temperature"` + MaxTokens int `json:"max_tokens"` +} + +type Parameter struct { + Chat ChatParameter `json:"chat"` +} + +type MessageContent struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type SparkMessage struct { + Text []MessageContent `json:"text"` +} + +type SparkRequestPayload struct { + Message SparkMessage `json:"message"` +} + +type SparkRequest struct { + Header SparkRequestHeader `json:"header"` + Parameter Parameter `json:"parameter"` + Payload SparkRequestPayload `json:"payload"` +} + +/***************** Spark请求结构 ***********************/ + +/***************** Spark响应结构 ***********************/ + +type SparkResponseHeader struct { + Code int `json:"code"` + Message string `json:"message"` + SID string `json:"sid"` + Status int `json:"status"` //会话状态,取值为[0,1,2];0代表首次结果;1代表中间结果;2代表最后一个结果 +} + +// ChoiceText 结构体对应 JSON 中 "choices.text[].text" 部分 +type ChoiceText struct { + Content string `json:"content"` + Role string `json:"role"` + Index int `json:"index"` +} + +// Choices 结构体对应 JSON 中 "choices" 部分 +type Choices struct { + Status int `json:"status"` + Seq int `json:"seq"` + Text []ChoiceText `json:"text"` +} + +// UsageText 结构体对应 JSON 中 "usage.text" 部分 +type UsageText struct { + QuestionTokens int `json:"question_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// Usage 结构体对应 JSON 中 "usage" 部分 +type Usage struct { + Text UsageText `json:"text"` +} + +// Payload 结构体对应 JSON 中 "payload" 部分 +type SparkResponsePayload struct { + Choices Choices `json:"choices"` + Usage Usage `json:"usage"` +} + +// Request 结构体对应整个 JSON 数据 +type SparkResponse struct { + Header SparkResponseHeader `json:"header"` + Payload SparkResponsePayload `json:"payload"` +} + +/***************** Spark响应结构 ***********************/ + +func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel string, SessionID, UserID int) { + //建立websocket连接 + dialer := websocket.Dialer{ + HandshakeTimeout: 5 * time.Second, + } + //建立连接 + conn, resp, err := dialer.Dial(assembleAuthUrl(hostUrl, apiKey, apiSecret), nil) + if err != nil { + panic(readResp(resp) + err.Error()) + return + } else if resp.StatusCode != 101 { + panic(readResp(resp) + err.Error()) + } + + //发送消息 + go func() { + data := genSparkParams(appid, question) + //将数据转换为json + //dataByte, err := json.Marshal(data) + if err != nil { + fmt.Println("Error parsing JSON:", err) + return + } + err2 := conn.WriteJSON(data) + if err != nil { + fmt.Println("write message error:", err2) + return + } + }() + + answer := "" + + //接收消息 + for { + _, msg, err2 := conn.ReadMessage() + if err2 != nil { + log.Println("read message error:", err2) + break + } + var data SparkResponse + err = json.Unmarshal(msg, &data) + if err != nil { + log.Println("Error parsing JSON:", err) + return + } + + if data.Header.Code != 0 { + log.Println("Error response:", data.Header.Message) + break + } + if data.Header.Status == 2 { + //最后一个结果 + answer += data.Payload.Choices.Text[0].Content + } else { + answer += data.Payload.Choices.Text[0].Content + } + SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID) + } + + conn.Close() + //将消息存入数据库 + err, _ = dao.CreateMessage(SessionID, 3, UserID, 2, 1, answer) + if err != nil { + log.Println("spark create message error:", err) + } +} + +func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponsem, channel string, SessionID int, UserID int) { + var aiMsg proto.AIQueueMessage + var wsMsg proto.WSMessage + var apiMsg api.GenerateResponse + apiMsg.Model = "spark" + apiMsg.Response = data.Payload.Choices.Text[0].Content + if data.Header.Status == 2 { + apiMsg.Done = true + } else { + apiMsg.Done = false + } + //消息队列部分 + aiMsg.Type = "spark" + aiMsg.Msg = apiMsg + //ws发送消息部分 + wsMsg.Msg = aiMsg + wsMsg.SessionID = SessionID + wsMsg.ToID = UserID + wsMsg.Type = "spark" + + //发送消息 + wsMsgStr, _ := json.Marshal(wsMsg) + Publish(channel, string(wsMsgStr), time.Second*60) +} + +// 生成参数 +func genSparkParams(question, appid string) SparkRequest { + messages := []MessageContent{ + {Role: "user", Content: question}, + } + + data := SparkRequest{ + Header: SparkRequestHeader{ + AppID: appid, + }, + Parameter: Parameter{ + Chat: ChatParameter{ + Domain: "4.0Ultra", + Temperature: 0.8, + MaxTokens: 2048, + }, + }, + Payload: SparkRequestPayload{ + Message: SparkMessage{ + Text: messages, + }, + }, + } + return data +} + +// 创建鉴权url apikey 即 hmac username +func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string { + ul, err := url.Parse(hostUrl) + if err != nil { + fmt.Println(err) + } + //签名时间 + date := time.Now().UTC().Format(time.RFC1123) + //date = "Tue, 28 May 2019 09:10:42 MST" + //参与签名的字段 host ,date, request-line + signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} + //拼接签名字符串 + sgin := strings.Join(signString, "\n") + // fmt.Println(sgin) + //签名结果 + sha := HmacWithShaTobase64("hmac-sha256", sgin, apiSecret) + // fmt.Println(sha) + //构建请求参数 此时不需要urlencoding + authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, + "hmac-sha256", "host date request-line", sha) + //将请求参数使用base64编码 + authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) + + v := url.Values{} + v.Add("host", ul.Host) + v.Add("date", date) + v.Add("authorization", authorization) + //将编码后的字符串url encode后添加到url后面 + callurl := hostUrl + "?" + v.Encode() + return callurl +} + +func HmacWithShaTobase64(algorithm, data, key string) string { + mac := hmac.New(sha256.New, []byte(key)) + mac.Write([]byte(data)) + encodeData := mac.Sum(nil) + return base64.StdEncoding.EncodeToString(encodeData) +} + +func readResp(resp *http.Response) string { + if resp == nil { + return "" + } + b, err := io.ReadAll(resp.Body) + if err != nil { + panic(err) + } + return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b)) +}