package service import ( "StuAcaWorksAI/dao" "StuAcaWorksAI/proto" "StuAcaWorksAI/service/generalModel" "StuAcaWorksAI/worker" "context" "encoding/json" "errors" "fmt" "github.com/ollama/ollama/api" "log" "time" ) // 创建会话服务 func CreateSession(userID, Type int, name string) (error, uint) { err, id := dao.CreateSession(userID, Type, name) return err, id } // 查找用户的会话列表 func FindSessionByUserID(userID, sessionType int) []proto.SessionResponse { var sessions []dao.Session if sessionType == -1 { sessions = dao.FindSessionByUserIDV2(userID) } else { sessions = dao.FindSessionByUserID(userID, sessionType) } sessionsMsgCounts := dao.FindSessionMessageCountByUserID(userID) sessionsTokens := dao.FindModelTotalTokensBySessionID(userID) smap := make(map[int]int) st := make(map[int]uint) for _, v := range sessionsMsgCounts { smap[v.SessionID] = v.Count } for _, v := range sessionsTokens { st[int(v.SessionID)] = v.Token } var res []proto.SessionResponse for _, v := range sessions { var session proto.SessionResponse session.ID = v.ID session.CreatedAt = v.CreatedAt session.UpdatedAt = v.UpdatedAt session.Name = v.Name session.UserID = v.UserID session.MsgCount = smap[int(v.ID)] session.TokenUsage = st[int(v.ID)] res = append(res, session) } return res } func FindSessionByID(id, userID, sessionType int) (error, []dao.Session) { session := dao.FindSessionByID(id) var res []dao.Session if session.UserID != userID { return errors.New("session not belong to user"), res } res = append(res, session) return nil, res } func DeleteSessionByID(id, userID int) error { session := dao.FindSessionByID(id) if session.UserID != userID { return errors.New("session not belong to user") } err := dao.DeleteSessionByID(id) return err } // 更新会话名字 func UpdateSessionByID(id int, userId int, name string, context []int) error { err := dao.UpdateSessionByID(id, userId, name, context) return err } // 创建消息服务 func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int, functionId int) (error, uint) { //sessionID, fromID, toID, msgType, status int, msg string err, id := dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg, functionId) return err, id } // 通过会话id查询消息 func FindMessageBySessionID(sessionID, userID int) (error, []dao.Message) { //查看会话是否属于该用户 session := dao.FindSessionByID(sessionID) if session.UserID != userID { return errors.New("session not belong to user"), nil } msgs := dao.FindMessageBySessionIDV2(sessionID) return nil, msgs } func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg string) (error, uint) { //查看会话是否属于该用户 session := dao.FindSessionByID(sessionID) if session.UserID != fromID { return errors.New("session not belong to user"), 0 } err, mid := CreateMessage(msgType, sessionID, fromID, toID, msg, 0, 0) return err, mid } var client *api.Client var SessionContextMap map[int][]int func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq, model dao.Model, modelParam proto.ModelParam) (error, uint) { var resErr error var resID uint //处理消息 if model.Type == proto.ModelTypeOllama { resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model) } else if model.Type == proto.ModelTypeSpark { resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam, &model) } else if model.Type == proto.ModelTypeDouBao { doubao(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID)) } else { return errors.New("model not exist"), 0 } return resErr, resID } func WSReceiveMessageServiceV2(modelParam proto.ModelParam, ctx *proto.IMParamContext) (error, uint) { var resErr error var resID uint //处理消息 switch ctx.ModelType { case proto.ModelTypeOllama: resErr, resID = ReceiveOllamaSessionV2(modelParam, ctx) case proto.ModelTypeSpark: SparkV2(modelParam, ctx) case proto.ModelTypeDouBao: DouBaoV2(modelParam, ctx) case proto.ModelTypeQianfan: Qianfan(modelParam, ctx) case proto.ModelTypeTongyi: Tongyi(modelParam, ctx) case proto.ModelTypeHunyuan: generalModel.OpenAI(modelParam, ctx, ctx.ModelType) default: resErr, resID = errors.New("model type not exist:"+ctx.ModelType), 0 } return resErr, resID } func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, m *dao.Model) (error, uint) { var resErr error var resID uint var err error if client == nil { client, err = api.ClientFromEnvironment() } if err != nil { log.Println("get ollama client error:", err) } prompt := msg.Msg actx, ok := SessionContextMap[sessionID] if ok { log.Println("actx has get will use:", len(actx)) } req := &api.GenerateRequest{ Model: modelParam.Model, Context: actx, Prompt: prompt, } ctx := context.Background() robotMsg := "" var aiMsg proto.AIQueueMessage var data proto.WSMessage respFunc := func(resp api.GenerateResponse) error { aiMsg.Type = proto.ModelTypeOllama aiMsg.Msg = resp robotMsg += resp.Response if resp.Done { //该消息完成 actx = resp.Context if SessionContextMap == nil { SessionContextMap = make(map[int][]int) } SessionContextMap[sessionID] = actx log.Println("actx has save:", len(SessionContextMap[sessionID])) //创建消息 err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1, 0) //机器人id为2,消息类型为2,状态为1 if err2 != nil { log.Println("create robot message error:", err2) } aiMsg.ID = msgID resID = msgID } //发送消息 data.Msg = aiMsg data.SessionID = sessionID data.ToID = userID aiMsgStr, _ := json.Marshal(data) worker.Publish(channel, string(aiMsgStr), time.Second*60) go CreateAIStreamMsg(userID, int(m.ID), sessionID, resp, modelParam.Model) return nil } err = client.Generate(ctx, req, respFunc) if err != nil { log.Fatal(err) } return resErr, resID } func ReceiveOllamaSessionV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) (error, uint) { var resErr error var resID uint var err error if client == nil { client, err = api.ClientFromEnvironment() } if err != nil { log.Println("get ollama client error:", err) } prompt := imCtx.Question actx, ok := SessionContextMap[int(imCtx.SessionID)] if ok { log.Println("actx has get will use:", len(actx)) } req := &api.GenerateRequest{ Model: modelParam.Model, Context: actx, Prompt: prompt, } ctx := context.Background() robotMsg := "" var aiMsg proto.AIQueueMessage var data proto.WSMessage respFunc := func(resp api.GenerateResponse) error { aiMsg.Type = proto.ModelTypeOllama aiMsg.Msg = resp robotMsg += resp.Response if resp.Done { //该消息完成 actx = resp.Context if SessionContextMap == nil { SessionContextMap = make(map[int][]int) } SessionContextMap[imCtx.SessionID] = actx log.Println("actx has save:", len(SessionContextMap[imCtx.SessionID])) //创建消息 err2, msgID := CreateMessage(proto.ModelToUserMsgType, imCtx.SessionID, imCtx.ModelID, imCtx.UserID, robotMsg, 1, imCtx.FunctionID) //机器人id为2,消息类型为2,状态为1 if err2 != nil { log.Println("create robot message error:", err2) } aiMsg.ID = msgID resID = msgID } //发送消息 data.Msg = aiMsg data.SessionID = int(imCtx.SessionID) data.ToID = imCtx.UserID aiMsgStr, _ := json.Marshal(data) worker.Publish(imCtx.Channel, string(aiMsgStr), time.Second*60) go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, resp, modelParam.Model) return nil } err = client.Generate(ctx, req, respFunc) if err != nil { log.Fatal(err) } return resErr, resID } func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, model *dao.Model) (error, uint) { var resErr error var resID uint Spark(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID)) return resErr, resID } func CheckUserCreatePPTSessionPermission(userID int) error { sessionCount := dao.FindUserSessionCount(userID, proto.SessionTypeUserCreatePPT) var err error user := GetUserByIDWithCache(userID) if user.Role == "admin" { return nil } if sessionCount > 3 { err = errors.New("create ppt session count exceed limit") } return err } // 知识库消息,从原始请求数据,上下文数据 func SelectKBaseContentByQuery(imCtx *proto.IMParamContext, data *proto.WSMessageReq) error { var req proto.KnowledgeBaseServerRequest req.IMContext = *imCtx kbase, err := dao.GetKnowledgeBaseByID(data.KBaseID) if err != nil { return err } req.KnowledgeBase = proto.KBase{ID: kbase.ID, UUID: kbase.UUID, Name: kbase.Name, Description: kbase.Description, FileIDs: kbase.FileIDs, SessionID: kbase.SessionID} //选取合适的服务器进行查询 //先查看改知识库是否已经创建 key := fmt.Sprintf("kbase_%s", kbase.UUID) KBaseServerID := worker.GetRedis(key) //若未分配,从可用的知识库服务器中进行分配 setKey := "kbase_server_id_set_is_available" //可用状态集合 if KBaseServerID == "" { //KBaseServerSet := "kbase_server_id_set" //进行分配 KBaseServerID = worker.GetRedisSetRandomMember(setKey) if KBaseServerID == "" { return errors.New("no kbase server id") } //记录分配的知识库服务器 res := worker.SetRedisWithExpire(key, KBaseServerID, time.Hour*24*7) log.Println("分配的知识库服务器ID:", KBaseServerID, "\t key:", key, "\tset:", res) } //查看KBaseServerID是否可用 isAvailable := worker.SetRedisSetIsMember(setKey, KBaseServerID) if isAvailable == false { log.Println("当前知识库服务器不可用,重新分配") return errors.New("kbase server id is not available") } log.Println("KBaseServerID:", KBaseServerID, " KBaseUUID:", kbase.UUID, "\tName:", kbase.Name) reqChanel := fmt.Sprintf("kbase_server_id_%s", KBaseServerID) // 频道名称 reqStr, _ := json.Marshal(req) //将消息转发到对应频道 worker.Publish(reqChanel, string(reqStr), time.Minute*3) chanel := fmt.Sprintf("kbase_%s", kbase.UUID) //获取返回消息频道 pubsub := worker.RedisClient.Subscribe(context.Background(), chanel) defer pubsub.Close() ch := pubsub.Channel() //1min定时器,超时则返回错误 timer := time.NewTimer(time.Minute * 3) var resp proto.KnowledgeBaseServerResponse select { case m := <-ch: msg := m.Payload // 获取消息,消息格式为json //log.Println("resp:", msg) err2 := json.Unmarshal([]byte(msg), &resp) if err2 != nil { log.Println("Error parsing JSON:", err2) } break case <-timer.C: log.Println("kbase query timeout") return errors.New("kbase query timeout") } //解析返回的查询消息 text := "根据下面文档片段:\n" querySelect := resp.QuerySelect qText := getSelectedDocument(querySelect) log.Println("kbase type:", data.KBaseType) if data.KBaseType == "query" { //则将query为模型返回结果 //将消息存入数据库 err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, qText, imCtx.FunctionID) if err != nil { log.Println("kbase query message save error:", err) } kbaseQueryToGeneralMassageAndSendMsgQueue(qText, imCtx.Channel, imCtx.SessionID, imCtx.UserID) //将消息发送到消息队列 } else if data.KBaseType == "chat" { //用户问题 text += qText text += "回答用户问题:" + data.Msg data.Msg = text log.Println("text:", text) } else { return errors.New("unknow kbase type") } return nil } func getSelectedDocument(querySelect proto.KnowledgeBaseServerResponseSelect) string { text := "" for i := 0; i < len(querySelect.Documents); i++ { if len(querySelect.Documents[i]) == 0 { break } text += "文件:" for j := 0; j < len(querySelect.Metadatas); j++ { text += querySelect.Metadatas[i][j].Source + "," } text += ":\n内容:" for j := 0; j < len(querySelect.Documents[i]); j++ { text += querySelect.Documents[i][j] + "," } text += "\n" } return text } func WSReturnErrorMessage(channel string, msg string, code int) { var data proto.WSMessage data.Code = code data.Type = msg //错误时type为错误信息 data.SessionID = 0 data.ToID = 0 errStr, _ := json.Marshal(data) worker.Publish(channel, string(errStr), time.Second*60) } func kbaseQueryToGeneralMassageAndSendMsgQueue(answer, channel string, SessionID, UserID int) { var aiMsg proto.AIQueueMessage var wsMsg proto.WSMessage var apiMsg api.GenerateResponse apiMsg.Done = true apiMsg.Model = proto.KnowledgeBaseServerResponseType apiMsg.Response = answer //消息队列部分 aiMsg.Type = proto.KnowledgeBaseServerResponseType aiMsg.Msg = apiMsg //ws发送消息部分 wsMsg.Msg = aiMsg wsMsg.SessionID = SessionID wsMsg.ToID = UserID wsMsg.Type = proto.KnowledgeBaseServerResponseType //发送消息 wsMsgStr, _ := json.Marshal(wsMsg) worker.Publish(channel, string(wsMsgStr), time.Second*60) }