完善spark支持

This commit is contained in:
junleea 2025-03-25 15:51:05 +08:00
parent 0b508ccfcc
commit 7de1a26c8e
2 changed files with 5 additions and 5 deletions

View File

@ -102,7 +102,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
if model.Type == "ollama" { if model.Type == "ollama" {
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model) resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model)
} else if model.Type == "spark" { } else if model.Type == "spark" {
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam) resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam, &model)
} else { } else {
return errors.New("model not exist"), 0 return errors.New("model not exist"), 0
} }
@ -168,9 +168,9 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes
return resErr, resID return resErr, resID
} }
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam) (error, uint) { func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, model *dao.Model) (error, uint) {
var resErr error var resErr error
var resID uint var resID uint
Spark(modelParam, msg.Msg, channel, sessionID, userID) Spark(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID))
return resErr, resID return resErr, resID
} }

View File

@ -106,7 +106,7 @@ type SparkResponse struct {
/***************** Spark响应结构 ***********************/ /***************** Spark响应结构 ***********************/
func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID int) { func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID, modelID int) {
log.Println("spark model param:", modelParam) log.Println("spark model param:", modelParam)
log.Println("spark question:", question) log.Println("spark question:", question)
//建立websocket连接 //建立websocket连接
@ -170,7 +170,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use
conn.Close() conn.Close()
//将消息存入数据库 //将消息存入数据库
err, _ = dao.CreateMessage(SessionID, 3, UserID, 2, 1, answer) err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.UserAndModelMsgType, proto.MsgHasRead, answer)
if err != nil { if err != nil {
log.Println("spark create message error:", err) log.Println("spark create message error:", err)
} }