完善spark支持

This commit is contained in:
junleea 2025-03-25 15:51:05 +08:00
parent 0b508ccfcc
commit 7de1a26c8e
2 changed files with 5 additions and 5 deletions

View File

@ -102,7 +102,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
if model.Type == "ollama" {
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model)
} else if model.Type == "spark" {
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam)
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam, &model)
} else {
return errors.New("model not exist"), 0
}
@ -168,9 +168,9 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes
return resErr, resID
}
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam) (error, uint) {
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, model *dao.Model) (error, uint) {
var resErr error
var resID uint
Spark(modelParam, msg.Msg, channel, sessionID, userID)
Spark(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID))
return resErr, resID
}

View File

@ -106,7 +106,7 @@ type SparkResponse struct {
/***************** Spark响应结构 ***********************/
func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID int) {
func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID, modelID int) {
log.Println("spark model param:", modelParam)
log.Println("spark question:", question)
//建立websocket连接
@ -170,7 +170,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use
conn.Close()
//将消息存入数据库
err, _ = dao.CreateMessage(SessionID, 3, UserID, 2, 1, answer)
err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.UserAndModelMsgType, proto.MsgHasRead, answer)
if err != nil {
log.Println("spark create message error:", err)
}