package service

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"strings"
	"sync"
	"time"

	"StuAcaWorksAI/dao"
	"StuAcaWorksAI/proto"
	"StuAcaWorksAI/worker"

	"github.com/ollama/ollama/api"
)

// CreateSession creates a session of the given type and name owned by userID.
// It returns the dao error (if any) and the new session's ID.
func CreateSession(userID, Type int, name string) (error, uint) {
	return dao.CreateSession(userID, Type, name)
}

// FindSessionByUserID returns userID's sessions of sessionType. Each session is
// annotated with its message count and total token usage; a session with no
// recorded messages or tokens gets the zero value for that field.
func FindSessionByUserID(userID, sessionType int) []proto.SessionResponse {
	sessions := dao.FindSessionByUserID(userID, sessionType)
	msgCounts := dao.FindSessionMessageCountByUserID(userID)
	tokenUsages := dao.FindModelTotalTokensBySessionID(userID)

	// Index the per-session aggregates so each session is a single map lookup.
	countBySession := make(map[int]int)
	for _, v := range msgCounts {
		countBySession[v.SessionID] = v.Count
	}
	tokensBySession := make(map[int]uint)
	for _, v := range tokenUsages {
		tokensBySession[int(v.SessionID)] = v.Token
	}

	var res []proto.SessionResponse
	for _, v := range sessions {
		var session proto.SessionResponse
		session.ID = v.ID
		session.CreatedAt = v.CreatedAt
		session.UpdatedAt = v.UpdatedAt
		session.Name = v.Name
		session.UserID = v.UserID
		session.MsgCount = countBySession[int(v.ID)]
		session.TokenUsage = tokensBySession[int(v.ID)]
		res = append(res, session)
	}
	return res
}

// requireSessionOwner loads session sessionID and returns an error when it is
// not owned by userID.
func requireSessionOwner(sessionID, userID int) (dao.Session, error) {
	session := dao.FindSessionByID(sessionID)
	if session.UserID != userID {
		return session, errors.New("session not belong to user")
	}
	return session, nil
}

// FindSessionByID returns session id wrapped in a one-element slice when it
// belongs to userID, or an error otherwise. sessionType is currently unused.
func FindSessionByID(id, userID, sessionType int) (error, []dao.Session) {
	var res []dao.Session
	session, err := requireSessionOwner(id, userID)
	if err != nil {
		return err, res
	}
	return nil, append(res, session)
}

// DeleteSessionByID deletes session id after verifying it belongs to userID.
func DeleteSessionByID(id, userID int) error {
	if _, err := requireSessionOwner(id, userID); err != nil {
		return err
	}
	return dao.DeleteSessionByID(id)
}

// UpdateSessionByID updates the name (and stored context) of session id for userId.
func UpdateSessionByID(id int, userId int, name string, context []int) error {
	return dao.UpdateSessionByID(id, userId, name, context)
}

// CreateMessage persists a message and returns the dao error (if any) and the
// new message ID. Note that dao.CreateMessage takes its arguments in a
// different order: sessionID, fromID, toID, msgType, status, msg, functionId.
func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int, functionId int) (error, uint) {
	return dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg, functionId)
}

// FindMessageBySessionID returns the messages of session sessionID after
// checking that the session belongs to userID.
func FindMessageBySessionID(sessionID, userID int) (error, []dao.Message) {
	if _, err := requireSessionOwner(sessionID, userID); err != nil {
		return err, nil
	}
	return nil, dao.FindMessageBySessionIDV2(sessionID)
}

// CreateGeneralMessageService stores a message from fromID to toID in session
// sessionID with status 0 and function ID 0. The session must belong to fromID.
func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg string) (error, uint) {
	if _, err := requireSessionOwner(sessionID, fromID); err != nil {
		return err, 0
	}
	return CreateMessage(msgType, sessionID, fromID, toID, msg, 0, 0)
}

// client is the package-wide Ollama client, created on first use by
// lazyOllamaClient. Guarded by ollamaClientMu.
var client *api.Client

// ollamaClientMu serializes the lazy creation of client.
var ollamaClientMu sync.Mutex

// SessionContextMap stores, per session ID, the Ollama context returned when a
// generation completes; it is passed back as GenerateRequest.Context on the
// session's next prompt. Access it only while holding sessionContextMu
// (NOTE(review): other files in the package may still touch it unguarded).
var SessionContextMap map[int][]int

// sessionContextMu guards SessionContextMap against concurrent handlers.
var sessionContextMu sync.RWMutex

// lazyOllamaClient returns the shared Ollama client, creating it from the
// environment on first use.
func lazyOllamaClient() (*api.Client, error) {
	ollamaClientMu.Lock()
	defer ollamaClientMu.Unlock()
	if client != nil {
		return client, nil
	}
	c, err := api.ClientFromEnvironment()
	if err != nil {
		return nil, err
	}
	client = c
	return client, nil
}

// loadOllamaSessionContext returns the cached Ollama context for sessionID.
func loadOllamaSessionContext(sessionID int) ([]int, bool) {
	sessionContextMu.RLock()
	defer sessionContextMu.RUnlock()
	actx, ok := SessionContextMap[sessionID]
	return actx, ok
}

// saveOllamaSessionContext caches actx as the Ollama context for sessionID.
func saveOllamaSessionContext(sessionID int, actx []int) {
	sessionContextMu.Lock()
	defer sessionContextMu.Unlock()
	if SessionContextMap == nil {
		SessionContextMap = make(map[int][]int)
	}
	SessionContextMap[sessionID] = actx
}

// WSReceiveMessageService dispatches a user message to the handler for
// model.Type. The DouBao handler's outcome is not surfaced, so that branch
// always returns nil, 0.
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq, model dao.Model, modelParam proto.ModelParam) (error, uint) {
	switch model.Type {
	case proto.ModelTypeOllama:
		return ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model)
	case proto.ModelTypeSpark:
		return ReceiveSparkSession(userID, sessionID, channel, msg, modelParam, &model)
	case proto.ModelTypeDouBao:
		doubao(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID))
		return nil, 0
	default:
		return errors.New("model not exist"), 0
	}
}

// WSReceiveMessageServiceV2 dispatches a message described by ctx to the
// handler for ctx.ModelType. Only the Ollama path reports an error/message ID;
// the Spark and DouBao handlers' outcomes are not surfaced (nil, 0).
func WSReceiveMessageServiceV2(modelParam proto.ModelParam, ctx *proto.IMParamContext) (error, uint) {
	switch ctx.ModelType {
	case proto.ModelTypeOllama:
		return ReceiveOllamaSessionV2(modelParam, ctx)
	case proto.ModelTypeSpark:
		SparkV2(modelParam, ctx)
		return nil, 0
	case proto.ModelTypeDouBao:
		DouBaoV2(modelParam, ctx)
		return nil, 0
	default:
		return errors.New("model not exist"), 0
	}
}

// ReceiveOllamaSession streams an Ollama completion for msg.Msg. Every chunk
// is published to channel (60s) and recorded via CreateAIStreamMsg. When the
// stream finishes, the returned context is cached for the session and the full
// reply is stored as a message (type 2, from robot ID 2 to userID, status 1).
// It returns the stored message ID, or an error if the client cannot be
// created or generation fails.
func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, m *dao.Model) (error, uint) {
	c, err := lazyOllamaClient()
	if err != nil {
		log.Println("get ollama client error:", err)
		return fmt.Errorf("get ollama client: %w", err), 0
	}

	actx, ok := loadOllamaSessionContext(sessionID)
	if ok {
		log.Println("actx has get will use:", len(actx))
	}
	req := &api.GenerateRequest{
		Model:   modelParam.Model,
		Context: actx,
		Prompt:  msg.Msg,
	}

	var (
		resID    uint
		robotMsg strings.Builder
		aiMsg    proto.AIQueueMessage
		data     proto.WSMessage
	)
	respFunc := func(resp api.GenerateResponse) error {
		aiMsg.Type = proto.ModelTypeOllama
		aiMsg.Msg = resp
		robotMsg.WriteString(resp.Response)
		if resp.Done {
			// Generation finished: remember the context for the next prompt.
			saveOllamaSessionContext(sessionID, resp.Context)
			log.Println("actx has save:", len(resp.Context))
			// Persist the full reply: robot ID 2, message type 2, status 1.
			err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg.String(), 1, 0)
			if err2 != nil {
				log.Println("create robot message error:", err2)
			}
			aiMsg.ID = msgID
			resID = msgID
		}
		// Publish this chunk to the client's channel.
		data.Msg = aiMsg
		data.SessionID = sessionID
		data.ToID = userID
		aiMsgStr, mErr := json.Marshal(data)
		if mErr != nil {
			log.Println("marshal ai message error:", mErr)
		} else {
			worker.Publish(channel, string(aiMsgStr), time.Second*60)
		}
		go CreateAIStreamMsg(userID, int(m.ID), sessionID, resp, modelParam.Model)
		return nil
	}

	// Return the failure instead of log.Fatal, which would kill the server.
	if err := c.Generate(context.Background(), req, respFunc); err != nil {
		return fmt.Errorf("ollama generate: %w", err), resID
	}
	return nil, resID
}

// ReceiveOllamaSessionV2 is the IMParamContext-based variant of
// ReceiveOllamaSession. It prompts with imCtx.Question, publishes each chunk to
// imCtx.Channel, and on completion caches the context. The full reply is then
// stored as a message from imCtx.ModelID to imCtx.UserID with status 1 and
// imCtx.FunctionID. It returns the stored message ID, or an error if the
// client cannot be created or generation fails.
func ReceiveOllamaSessionV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) (error, uint) {
	c, err := lazyOllamaClient()
	if err != nil {
		log.Println("get ollama client error:", err)
		return fmt.Errorf("get ollama client: %w", err), 0
	}

	sessionID := int(imCtx.SessionID)
	actx, ok := loadOllamaSessionContext(sessionID)
	if ok {
		log.Println("actx has get will use:", len(actx))
	}
	req := &api.GenerateRequest{
		Model:   modelParam.Model,
		Context: actx,
		Prompt:  imCtx.Question,
	}

	var (
		resID    uint
		robotMsg strings.Builder
		aiMsg    proto.AIQueueMessage
		data     proto.WSMessage
	)
	respFunc := func(resp api.GenerateResponse) error {
		aiMsg.Type = proto.ModelTypeOllama
		aiMsg.Msg = resp
		robotMsg.WriteString(resp.Response)
		if resp.Done {
			// Generation finished: remember the context for the next prompt.
			saveOllamaSessionContext(sessionID, resp.Context)
			log.Println("actx has save:", len(resp.Context))
			// Persist the full reply (status 1), sent by the model to the user.
			// NOTE(review): the type is proto.UserToModelMsgType even though this
			// is a model-to-user reply — confirm this is intended.
			err2, msgID := CreateMessage(proto.UserToModelMsgType, imCtx.SessionID, imCtx.ModelID, imCtx.UserID, robotMsg.String(), 1, imCtx.FunctionID)
			if err2 != nil {
				log.Println("create robot message error:", err2)
			}
			aiMsg.ID = msgID
			resID = msgID
		}
		// Publish this chunk to the client's channel.
		data.Msg = aiMsg
		data.SessionID = sessionID
		data.ToID = imCtx.UserID
		aiMsgStr, mErr := json.Marshal(data)
		if mErr != nil {
			log.Println("marshal ai message error:", mErr)
		} else {
			worker.Publish(imCtx.Channel, string(aiMsgStr), time.Second*60)
		}
		go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, resp, modelParam.Model)
		return nil
	}

	// Return the failure instead of log.Fatal, which would kill the server.
	if err := c.Generate(context.Background(), req, respFunc); err != nil {
		return fmt.Errorf("ollama generate: %w", err), resID
	}
	return nil, resID
}

// ReceiveSparkSession forwards msg to the Spark handler. Spark's outcome is
// not reported back, so this always returns nil, 0.
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, model *dao.Model) (error, uint) {
	Spark(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID))
	return nil, 0
}

// CheckUserCreatePPTSessionPermission reports whether userID may create another
// PPT session. Admins are always allowed. Other users get an error once they
// already have more than 3 PPT sessions.
func CheckUserCreatePPTSessionPermission(userID int) error {
	user := GetUserByIDWithCache(userID)
	if user.Role == "admin" {
		return nil
	}
	// The session-count query is only needed for non-admin users.
	if dao.FindUserSessionCount(userID, proto.SessionTypeUserCreatePPT) > 3 {
		return errors.New("create ppt session count exceed limit")
	}
	return nil
}