package service import ( "StuAcaWorksAI/dao" "StuAcaWorksAI/proto" "StuAcaWorksAI/worker" "context" "encoding/json" "errors" "github.com/ollama/ollama/api" "log" "time" ) // 创建会话服务 func CreateSession(userID int, name string) (error, uint) { err, id := dao.CreateSession(userID, name) return err, id } // 查找用户的会话列表 func FindSessionByUserID(userID int) []dao.Session { sessions := dao.FindSessionByUserID(userID) return sessions } func FindSessionByID(id, userID int) (error, []dao.Session) { session := dao.FindSessionByID(id) var res []dao.Session if session.UserID != userID { return errors.New("session not belong to user"), res } res = append(res, session) return nil, res } func DeleteSessionByID(id, userID int) error { session := dao.FindSessionByID(id) if session.UserID != userID { return errors.New("session not belong to user") } err := dao.DeleteSessionByID(id) return err } // 更新会话名字 func UpdateSessionByID(id int, userId int, name string) error { err := dao.UpdateSessionByID(id, userId, name) return err } // 创建消息服务 func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int) (error, uint) { //sessionID, fromID, toID, msgType, status int, msg string err, id := dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg) return err, id } // 通过会话id查询消息 func FindMessageBySessionID(sessionID, userID int) (error, []dao.Message) { //查看会话是否属于该用户 session := dao.FindSessionByID(sessionID) if session.UserID != userID { return errors.New("session not belong to user"), nil } msgs := dao.FindMessageBySessionID(sessionID) return nil, msgs } func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg string) (error, uint) { //查看会话是否属于该用户 session := dao.FindSessionByID(sessionID) if session.UserID != fromID { return errors.New("session not belong to user"), 0 } err, mid := CreateMessage(msgType, sessionID, fromID, toID, msg, 0) return err, mid } var client *api.Client var actxMap map[int][]int func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) { var resErr error var resID uint //处理消息 if msg.Type == "ollama" { var err error if client == nil { client, err = api.ClientFromEnvironment() } if err != nil { log.Println("get ollama client error:", err) } prompt := msg.Msg actx, ok := actxMap[sessionID] if ok { log.Println("actx has get will use:", len(actx)) } req := &api.GenerateRequest{ Model: "qwen2.5:0.5b", Context: actx, Prompt: prompt, } ctx := context.Background() robotMsg := "" var aiMsg proto.AIQueueMessage var data proto.WSMessage respFunc := func(resp api.GenerateResponse) error { aiMsg.Type = "ollama" aiMsg.Msg = resp robotMsg += resp.Response if resp.Done { //该消息完成 actx = resp.Context if actxMap == nil { actxMap = make(map[int][]int) } actxMap[sessionID] = actx log.Println("actx has save:", len(actxMap[sessionID])) //创建消息 err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1 if err2 != nil { log.Println("create robot message error:", err2) } aiMsg.ID = msgID resID = msgID } //发送消息 data.Msg = aiMsg data.SessionID = sessionID data.ToID = userID aiMsgStr, _ := json.Marshal(data) worker.Publish(channel, string(aiMsgStr), time.Second*60) return nil } err = client.Generate(ctx, req, respFunc) if err != nil { log.Fatal(err) } } return resErr, resID }