完善spark支持
This commit is contained in:
parent
0b508ccfcc
commit
7de1a26c8e
|
|
@ -102,7 +102,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
|
||||||
if model.Type == "ollama" {
|
if model.Type == "ollama" {
|
||||||
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model)
|
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model)
|
||||||
} else if model.Type == "spark" {
|
} else if model.Type == "spark" {
|
||||||
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam)
|
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam, &model)
|
||||||
} else {
|
} else {
|
||||||
return errors.New("model not exist"), 0
|
return errors.New("model not exist"), 0
|
||||||
}
|
}
|
||||||
|
|
@ -168,9 +168,9 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes
|
||||||
return resErr, resID
|
return resErr, resID
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam) (error, uint) {
|
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, model *dao.Model) (error, uint) {
|
||||||
var resErr error
|
var resErr error
|
||||||
var resID uint
|
var resID uint
|
||||||
Spark(modelParam, msg.Msg, channel, sessionID, userID)
|
Spark(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID))
|
||||||
return resErr, resID
|
return resErr, resID
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,7 @@ type SparkResponse struct {
|
||||||
|
|
||||||
/***************** Spark响应结构 ***********************/
|
/***************** Spark响应结构 ***********************/
|
||||||
|
|
||||||
func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID int) {
|
func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID, modelID int) {
|
||||||
log.Println("spark model param:", modelParam)
|
log.Println("spark model param:", modelParam)
|
||||||
log.Println("spark question:", question)
|
log.Println("spark question:", question)
|
||||||
//建立websocket连接
|
//建立websocket连接
|
||||||
|
|
@ -170,7 +170,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use
|
||||||
|
|
||||||
conn.Close()
|
conn.Close()
|
||||||
//将消息存入数据库
|
//将消息存入数据库
|
||||||
err, _ = dao.CreateMessage(SessionID, 3, UserID, 2, 1, answer)
|
err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.UserAndModelMsgType, proto.MsgHasRead, answer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("spark create message error:", err)
|
log.Println("spark create message error:", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue