package service import ( "StuAcaWorksAI/dao" "StuAcaWorksAI/proto" "StuAcaWorksAI/worker" "context" "encoding/json" "errors" "github.com/ollama/ollama/api" "log" "time" ) // 创建会话服务 func CreateSession(userID int, name string) (error, uint) { err, id := dao.CreateSession(userID, name) return err, id } // 查找用户的会话列表 func FindSessionByUserID(userID int) []proto.SessionResponse { sessions := dao.FindSessionByUserID(userID) sessionsMsgCounts := dao.FindSessionMessageCountByUserID(userID) smap := make(map[int]int) for _, v := range sessionsMsgCounts { smap[v.SessionID] = v.Count } var res []proto.SessionResponse for _, v := range sessions { var session proto.SessionResponse session.ID = v.ID session.CreatedAt = v.CreatedAt session.UpdatedAt = v.UpdatedAt session.Name = v.Name session.UserID = v.UserID session.MsgCount = smap[int(v.ID)] res = append(res, session) } return res } func FindSessionByID(id, userID int) (error, []dao.Session) { session := dao.FindSessionByID(id) var res []dao.Session if session.UserID != userID { return errors.New("session not belong to user"), res } res = append(res, session) return nil, res } func DeleteSessionByID(id, userID int) error { session := dao.FindSessionByID(id) if session.UserID != userID { return errors.New("session not belong to user") } err := dao.DeleteSessionByID(id) return err } // 更新会话名字 func UpdateSessionByID(id int, userId int, name string, context []int) error { err := dao.UpdateSessionByID(id, userId, name, context) return err } // 创建消息服务 func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int) (error, uint) { //sessionID, fromID, toID, msgType, status int, msg string err, id := dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg) return err, id } // 通过会话id查询消息 func FindMessageBySessionID(sessionID, userID int) (error, []dao.Message) { //查看会话是否属于该用户 session := dao.FindSessionByID(sessionID) if session.UserID != userID { return errors.New("session not belong to user"), nil } msgs := dao.FindMessageBySessionID(sessionID) return nil, msgs } func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg string) (error, uint) { //查看会话是否属于该用户 session := dao.FindSessionByID(sessionID) if session.UserID != fromID { return errors.New("session not belong to user"), 0 } err, mid := CreateMessage(msgType, sessionID, fromID, toID, msg, 0) return err, mid } var client *api.Client var SessionContextMap map[int][]int func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq, model dao.Model, modelParam proto.ModelParam) (error, uint) { var resErr error var resID uint //处理消息 if model.Type == "ollama" { resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model) } else if model.Type == "spark" { resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam, &model) } else if model.Type == "doubao" { doubao(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID)) } else { return errors.New("model not exist"), 0 } return resErr, resID } func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, m *dao.Model) (error, uint) { var resErr error var resID uint var err error if client == nil { client, err = api.ClientFromEnvironment() } if err != nil { log.Println("get ollama client error:", err) } prompt := msg.Msg actx, ok := SessionContextMap[sessionID] if ok { log.Println("actx has get will use:", len(actx)) } req := &api.GenerateRequest{ Model: modelParam.Model, Context: actx, Prompt: prompt, } ctx := context.Background() robotMsg := "" var aiMsg proto.AIQueueMessage var data proto.WSMessage respFunc := func(resp api.GenerateResponse) error { aiMsg.Type = "ollama" aiMsg.Msg = resp robotMsg += resp.Response if resp.Done { //该消息完成 actx = resp.Context if SessionContextMap == nil { SessionContextMap = make(map[int][]int) } SessionContextMap[sessionID] = actx log.Println("actx has save:", len(SessionContextMap[sessionID])) //创建消息 err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1 if err2 != nil { log.Println("create robot message error:", err2) } aiMsg.ID = msgID resID = msgID } //发送消息 data.Msg = aiMsg data.SessionID = sessionID data.ToID = userID aiMsgStr, _ := json.Marshal(data) worker.Publish(channel, string(aiMsgStr), time.Second*60) go CreateAIStreamMsg(userID, int(m.ID), sessionID, resp, modelParam.Model) return nil } err = client.Generate(ctx, req, respFunc) if err != nil { log.Fatal(err) } return resErr, resID } func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, model *dao.Model) (error, uint) { var resErr error var resID uint Spark(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID)) return resErr, resID }