diff --git a/proto/im.go b/proto/im.go index 7542a9e..6751861 100644 --- a/proto/im.go +++ b/proto/im.go @@ -12,3 +12,7 @@ type WSMessage struct { SessionID int `json:"session_id"` //应用层会话id ToID int `json:"to_id"` //接收者id } + +type ModelParam struct { + Model string `json:"model"` //模型名称 +} diff --git a/service/imService.go b/service/imService.go index 4036353..55a8dd4 100644 --- a/service/imService.go +++ b/service/imService.go @@ -77,6 +77,7 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin } var client *api.Client +var actxMap map[int][]int func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessage) (error, uint) { var resErr error @@ -91,9 +92,9 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS log.Println("get ollama client error:", err) } prompt := msg.Msg - var actx []int + actx, _ := actxMap[sessionID] req := &api.GenerateRequest{ - Model: "qwen2.5-coder:latest", + Model: "qwen2.5:0.5b", Context: actx, Prompt: prompt, } @@ -107,6 +108,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS robotMsg += resp.Response if resp.Done { //该消息完成 actx = resp.Context + actxMap[sessionID] = actx //创建消息 err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1 if err2 != nil {