From 142463a092eb96a2a6b219a92b4d527c137e1eb6 Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Sun, 23 Mar 2025 13:17:03 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BC=9A=E8=AF=9D?= =?UTF-8?q?=E4=B8=8A=E4=B8=8B=E6=96=87=E4=BF=9D=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dao/im.go | 11 ++++++----- handler/im.go | 13 +++++++++++++ handler/session.go | 2 +- service/imService.go | 16 ++++++++-------- 4 files changed, 28 insertions(+), 14 deletions(-) diff --git a/dao/im.go b/dao/im.go index e2b51fa..5112d34 100644 --- a/dao/im.go +++ b/dao/im.go @@ -7,8 +7,9 @@ import ( type Session struct { gorm.Model - UserID int `gorm:"column:user_id"` //只能由用户创建 - Name string `gorm:"column:name"` + UserID int `gorm:"column:user_id"` //只能由用户创建 + Name string `gorm:"column:name"` + Context []int `gorm:"column:context"` //会话上下文 } type Message struct { @@ -54,12 +55,12 @@ func FindSessionByUserID(userID int) []Session { } // 更新会话的名字 -func UpdateSessionByID(id int, userId int, name string) error { +func UpdateSessionByID(id int, userId int, name string, context []int) error { var res *gorm.DB if proto.Config.SERVER_SQL_LOG { - res = DB.Debug().Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name) + res = DB.Debug().Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: context}) } else { - res = DB.Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name) + res = DB.Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: context}) } return res.Error } diff --git a/handler/im.go b/handler/im.go index 925a9fb..ec2aa0f 100644 --- a/handler/im.go +++ b/handler/im.go @@ -176,6 +176,18 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) { for { select { case <-done: + //结束会话将context保存 + ctx2, ok := service.SessionContextMap[int(sessionID)] + if !ok { + log.Println("session context not found") + return + } + err := service.UpdateSessionByID(int(sessionID), userId, "", ctx2) + if err != nil { + log.Println("update session context error:", err) + } + //保存的会话 + delete(service.SessionContextMap, int(sessionID)) //客户端断开连接 return //case t := <-ticker.C: @@ -188,4 +200,5 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) { //} } } + } diff --git a/handler/session.go b/handler/session.go index 769fe8d..f070a92 100644 --- a/handler/session.go +++ b/handler/session.go @@ -70,7 +70,7 @@ func UpdateSession(c *gin.Context) { var req Session if err := c.ShouldBind(&req); err == nil { // 修改会话 - err2 := service.UpdateSessionByID(req.ID, userID, req.Name) + err2 := service.UpdateSessionByID(req.ID, userID, req.Name, nil) if err2 == nil { c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success"}) } else { diff --git a/service/imService.go b/service/imService.go index 607a62c..e393321 100644 --- a/service/imService.go +++ b/service/imService.go @@ -43,8 +43,8 @@ func DeleteSessionByID(id, userID int) error { } // 更新会话名字 -func UpdateSessionByID(id int, userId int, name string) error { - err := dao.UpdateSessionByID(id, userId, name) +func UpdateSessionByID(id int, userId int, name string, context []int) error { + err := dao.UpdateSessionByID(id, userId, name, context) return err } @@ -77,7 +77,7 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin } var client *api.Client -var actxMap map[int][]int +var SessionContextMap map[int][]int func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) { var resErr error @@ -92,7 +92,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS log.Println("get ollama client error:", err) } prompt := msg.Msg - actx, ok := actxMap[sessionID] + actx, ok := SessionContextMap[sessionID] if ok { log.Println("actx has get will use:", len(actx)) } @@ -111,11 +111,11 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS robotMsg += resp.Response if resp.Done { //该消息完成 actx = resp.Context - if actxMap == nil { - actxMap = make(map[int][]int) + if SessionContextMap == nil { + SessionContextMap = make(map[int][]int) } - actxMap[sessionID] = actx - log.Println("actx has save:", len(actxMap[sessionID])) + SessionContextMap[sessionID] = actx + log.Println("actx has save:", len(SessionContextMap[sessionID])) //创建消息 err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1 if err2 != nil { From 4688e1a916967467e38c40ff179fada716b67581 Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Sun, 23 Mar 2025 16:03:17 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E6=B7=BB=E5=8A=A0spark=E9=80=9A=E7=94=A8?= =?UTF-8?q?=E8=B0=83=E7=94=A8=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- service/imService.go | 118 +++++++++++-------- worker/spark.go | 274 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 340 insertions(+), 52 deletions(-) create mode 100644 worker/spark.go diff --git a/service/imService.go b/service/imService.go index e393321..2f01766 100644 --- a/service/imService.go +++ b/service/imService.go @@ -84,58 +84,72 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS var resID uint //处理消息 if msg.Type == "ollama" { - var err error - if client == nil { - client, err = api.ClientFromEnvironment() - } - if err != nil { - log.Println("get ollama client error:", err) - } - prompt := msg.Msg - actx, ok := SessionContextMap[sessionID] - if ok { - log.Println("actx has get will use:", len(actx)) - } - req := &api.GenerateRequest{ - Model: "qwen2.5:0.5b", - Context: actx, - Prompt: prompt, - } - ctx := context.Background() - robotMsg := "" - var aiMsg proto.AIQueueMessage - var data proto.WSMessage - respFunc := func(resp api.GenerateResponse) error { - aiMsg.Type = "ollama" - aiMsg.Msg = resp - robotMsg += resp.Response - if resp.Done { //该消息完成 - actx = resp.Context - if SessionContextMap == nil { - SessionContextMap = make(map[int][]int) - } - SessionContextMap[sessionID] = actx - log.Println("actx has save:", len(SessionContextMap[sessionID])) - //创建消息 - err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1 - if err2 != nil { - log.Println("create robot message error:", err2) - } - aiMsg.ID = msgID - resID = msgID - } - //发送消息 - data.Msg = aiMsg - data.SessionID = sessionID - data.ToID = userID - aiMsgStr, _ := json.Marshal(data) - worker.Publish(channel, string(aiMsgStr), time.Second*60) - return nil - } - err = client.Generate(ctx, req, respFunc) - if err != nil { - log.Fatal(err) - } + resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg) + } else if msg.Type == "spark" { + } return resErr, resID } + +func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) { + var resErr error + var resID uint + var err error + if client == nil { + client, err = api.ClientFromEnvironment() + } + if err != nil { + log.Println("get ollama client error:", err) + } + prompt := msg.Msg + actx, ok := SessionContextMap[sessionID] + if ok { + log.Println("actx has get will use:", len(actx)) + } + req := &api.GenerateRequest{ + Model: "qwen2.5:0.5b", + Context: actx, + Prompt: prompt, + } + ctx := context.Background() + robotMsg := "" + var aiMsg proto.AIQueueMessage + var data proto.WSMessage + respFunc := func(resp api.GenerateResponse) error { + aiMsg.Type = "ollama" + aiMsg.Msg = resp + robotMsg += resp.Response + if resp.Done { //该消息完成 + actx = resp.Context + if SessionContextMap == nil { + SessionContextMap = make(map[int][]int) + } + SessionContextMap[sessionID] = actx + log.Println("actx has save:", len(SessionContextMap[sessionID])) + //创建消息 + err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1 + if err2 != nil { + log.Println("create robot message error:", err2) + } + aiMsg.ID = msgID + resID = msgID + } + //发送消息 + data.Msg = aiMsg + data.SessionID = sessionID + data.ToID = userID + aiMsgStr, _ := json.Marshal(data) + worker.Publish(channel, string(aiMsgStr), time.Second*60) + return nil + } + err = client.Generate(ctx, req, respFunc) + if err != nil { + log.Fatal(err) + } + return resErr, resID +} + +func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) { + + return nil, 0 +} diff --git a/worker/spark.go b/worker/spark.go new file mode 100644 index 0000000..614bd5a --- /dev/null +++ b/worker/spark.go @@ -0,0 +1,274 @@ +package worker + +import ( + "StuAcaWorksAI/dao" + "StuAcaWorksAI/proto" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "github.com/gorilla/websocket" + "github.com/ollama/ollama/api" + "io" + "log" + "net/http" + "net/url" + "strings" + "time" +) + +/***************** Spark请求结构 ***********************/ +type SparkRequestHeader struct { + AppID string `json:"app_id"` + UID string `json:"uid"` +} + +type ChatParameter struct { + Domain string `json:"domain"` + Temperature float64 `json:"temperature"` + MaxTokens int `json:"max_tokens"` +} + +type Parameter struct { + Chat ChatParameter `json:"chat"` +} + +type MessageContent struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type SparkMessage struct { + Text []MessageContent `json:"text"` +} + +type SparkRequestPayload struct { + Message SparkMessage `json:"message"` +} + +type SparkRequest struct { + Header SparkRequestHeader `json:"header"` + Parameter Parameter `json:"parameter"` + Payload SparkRequestPayload `json:"payload"` +} + +/***************** Spark请求结构 ***********************/ + +/***************** Spark响应结构 ***********************/ + +type SparkResponseHeader struct { + Code int `json:"code"` + Message string `json:"message"` + SID string `json:"sid"` + Status int `json:"status"` //会话状态,取值为[0,1,2];0代表首次结果;1代表中间结果;2代表最后一个结果 +} + +// ChoiceText 结构体对应 JSON 中 "choices.text[].text" 部分 +type ChoiceText struct { + Content string `json:"content"` + Role string `json:"role"` + Index int `json:"index"` +} + +// Choices 结构体对应 JSON 中 "choices" 部分 +type Choices struct { + Status int `json:"status"` + Seq int `json:"seq"` + Text []ChoiceText `json:"text"` +} + +// UsageText 结构体对应 JSON 中 "usage.text" 部分 +type UsageText struct { + QuestionTokens int `json:"question_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// Usage 结构体对应 JSON 中 "usage" 部分 +type Usage struct { + Text UsageText `json:"text"` +} + +// Payload 结构体对应 JSON 中 "payload" 部分 +type SparkResponsePayload struct { + Choices Choices `json:"choices"` + Usage Usage `json:"usage"` +} + +// Request 结构体对应整个 JSON 数据 +type SparkResponse struct { + Header SparkResponseHeader `json:"header"` + Payload SparkResponsePayload `json:"payload"` +} + +/***************** Spark响应结构 ***********************/ + +func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel string, SessionID, UserID int) { + //建立websocket连接 + dialer := websocket.Dialer{ + HandshakeTimeout: 5 * time.Second, + } + //建立连接 + conn, resp, err := dialer.Dial(assembleAuthUrl(hostUrl, apiKey, apiSecret), nil) + if err != nil { + panic(readResp(resp) + err.Error()) + return + } else if resp.StatusCode != 101 { + panic(readResp(resp) + err.Error()) + } + + //发送消息 + go func() { + data := genSparkParams(appid, question) + //将数据转换为json + //dataByte, err := json.Marshal(data) + if err != nil { + fmt.Println("Error parsing JSON:", err) + return + } + err2 := conn.WriteJSON(data) + if err != nil { + fmt.Println("write message error:", err2) + return + } + }() + + answer := "" + + //接收消息 + for { + _, msg, err2 := conn.ReadMessage() + if err2 != nil { + log.Println("read message error:", err2) + break + } + var data SparkResponse + err = json.Unmarshal(msg, &data) + if err != nil { + log.Println("Error parsing JSON:", err) + return + } + + if data.Header.Code != 0 { + log.Println("Error response:", data.Header.Message) + break + } + if data.Header.Status == 2 { + //最后一个结果 + answer += data.Payload.Choices.Text[0].Content + } else { + answer += data.Payload.Choices.Text[0].Content + } + SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID) + } + + conn.Close() + //将消息存入数据库 + err, _ = dao.CreateMessage(SessionID, 3, UserID, 2, 1, answer) + if err != nil { + log.Println("spark create message error:", err) + } +} + +func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponsem, channel string, SessionID int, UserID int) { + var aiMsg proto.AIQueueMessage + var wsMsg proto.WSMessage + var apiMsg api.GenerateResponse + apiMsg.Model = "spark" + apiMsg.Response = data.Payload.Choices.Text[0].Content + if data.Header.Status == 2 { + apiMsg.Done = true + } else { + apiMsg.Done = false + } + //消息队列部分 + aiMsg.Type = "spark" + aiMsg.Msg = apiMsg + //ws发送消息部分 + wsMsg.Msg = aiMsg + wsMsg.SessionID = SessionID + wsMsg.ToID = UserID + wsMsg.Type = "spark" + + //发送消息 + wsMsgStr, _ := json.Marshal(wsMsg) + Publish(channel, string(wsMsgStr), time.Second*60) +} + +// 生成参数 +func genSparkParams(question, appid string) SparkRequest { + messages := []MessageContent{ + {Role: "user", Content: question}, + } + + data := SparkRequest{ + Header: SparkRequestHeader{ + AppID: appid, + }, + Parameter: Parameter{ + Chat: ChatParameter{ + Domain: "4.0Ultra", + Temperature: 0.8, + MaxTokens: 2048, + }, + }, + Payload: SparkRequestPayload{ + Message: SparkMessage{ + Text: messages, + }, + }, + } + return data +} + +// 创建鉴权url apikey 即 hmac username +func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string { + ul, err := url.Parse(hostUrl) + if err != nil { + fmt.Println(err) + } + //签名时间 + date := time.Now().UTC().Format(time.RFC1123) + //date = "Tue, 28 May 2019 09:10:42 MST" + //参与签名的字段 host ,date, request-line + signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} + //拼接签名字符串 + sgin := strings.Join(signString, "\n") + // fmt.Println(sgin) + //签名结果 + sha := HmacWithShaTobase64("hmac-sha256", sgin, apiSecret) + // fmt.Println(sha) + //构建请求参数 此时不需要urlencoding + authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, + "hmac-sha256", "host date request-line", sha) + //将请求参数使用base64编码 + authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) + + v := url.Values{} + v.Add("host", ul.Host) + v.Add("date", date) + v.Add("authorization", authorization) + //将编码后的字符串url encode后添加到url后面 + callurl := hostUrl + "?" + v.Encode() + return callurl +} + +func HmacWithShaTobase64(algorithm, data, key string) string { + mac := hmac.New(sha256.New, []byte(key)) + mac.Write([]byte(data)) + encodeData := mac.Sum(nil) + return base64.StdEncoding.EncodeToString(encodeData) +} + +func readResp(resp *http.Response) string { + if resp == nil { + return "" + } + b, err := io.ReadAll(resp.Body) + if err != nil { + panic(err) + } + return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b)) +} From e2f1b39803835da221f33a1ee2a76c1945a7c1c3 Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Mon, 24 Mar 2025 13:36:42 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E6=B7=BB=E5=8A=A0spark=E9=83=A8=E5=88=86?= =?UTF-8?q?=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- proto/im.go | 1 + service/imService.go | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/proto/im.go b/proto/im.go index 0524738..a21c301 100644 --- a/proto/im.go +++ b/proto/im.go @@ -21,6 +21,7 @@ type ModelParam struct { type WSMessageReq struct { Type string `json:"type"` //接收及发送消息类型 + Function string `json:"function"` //功能名称 Msg string `json:"msg"` //消息内容,只进行转发,不做处理 SessionID int `json:"session_id"` //应用层会话id ToID int `json:"to_id"` //接收者id diff --git a/service/imService.go b/service/imService.go index 2f01766..c1fcee2 100644 --- a/service/imService.go +++ b/service/imService.go @@ -86,6 +86,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS if msg.Type == "ollama" { resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg) } else if msg.Type == "spark" { + resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg) } return resErr, resID @@ -150,6 +151,9 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes } func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) { + var resErr error + var resID uint + //查找模型 - return nil, 0 + return resErr, resID }