From 7de1a26c8e9bd13cf63af0fa344fcdef3bb808de Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Tue, 25 Mar 2025 15:51:05 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84spark=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- service/imService.go | 6 +++--- service/spark.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/service/imService.go b/service/imService.go index 594dab4..fb5514d 100644 --- a/service/imService.go +++ b/service/imService.go @@ -102,7 +102,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS if model.Type == "ollama" { resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model) } else if model.Type == "spark" { - resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam) + resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam, &model) } else { return errors.New("model not exist"), 0 } @@ -168,9 +168,9 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes return resErr, resID } -func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam) (error, uint) { +func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, model *dao.Model) (error, uint) { var resErr error var resID uint - Spark(modelParam, msg.Msg, channel, sessionID, userID) + Spark(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID)) return resErr, resID } diff --git a/service/spark.go b/service/spark.go index c4343f4..6b835da 100644 --- a/service/spark.go +++ b/service/spark.go @@ -106,7 +106,7 @@ type SparkResponse struct { /***************** Spark响应结构 ***********************/ -func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID int) { +func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID, modelID int) { log.Println("spark model param:", modelParam) log.Println("spark question:", question) //建立websocket连接 @@ -170,7 +170,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use conn.Close() //将消息存入数据库 - err, _ = dao.CreateMessage(SessionID, 3, UserID, 2, 1, answer) + err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.UserAndModelMsgType, proto.MsgHasRead, answer) if err != nil { log.Println("spark create message error:", err) }