saw-go/service/imService.go

457 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
"StuAcaWorksAI/service/generalModel"
"StuAcaWorksAI/worker"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/ollama/ollama/api"
"log"
"time"
)
// CreateSession creates a session of type Type named name for userID.
// It forwards directly to dao.CreateSession and returns that call's error and
// the new session's ID unchanged.
func CreateSession(userID, Type int, name string) (error, uint) {
	return dao.CreateSession(userID, Type, name)
}
// UpdateSessionNameByAIAndID asks the summarization model configured in
// proto.Config.AISessionNameModelID to condense originSessionName (the user's
// question) into a short topic, and saves that topic as the name of session id
// owned by userID.
//
// It returns an error if the model cannot be found, its stored parameters
// cannot be decoded, the model call fails, or the session update fails.
func UpdateSessionNameByAIAndID(id, userID int, originSessionName string) error {
	// Prompt (Chinese): "Based on the conversation below, extract a topic of at
	// most 20 characters to use as the session name".
	prompt := "请根据下面的对话内容提取为20字以内的主题作为会话名称\n" + originSessionName

	// The summarization model comes from configuration (the original author
	// notes it is pinned to the free Hunyuan-lite model).
	modelID := proto.Config.AISessionNameModelID
	models := FindModelByID(modelID, 1)
	// len covers both a nil and an empty result; models[0] on an empty,
	// non-nil result would panic.
	if len(models) == 0 {
		log.Println("UpdateSessionNameByAIAndID model not exist")
		return errors.New("session name ai model not exist")
	}
	model := models[0]

	var modelParam proto.ModelParam
	if err := json.Unmarshal([]byte(model.Parameter), &modelParam); err != nil {
		log.Println("UpdateSessionNameByAIAndID model param error:", err)
		return errors.New("session name ai model param error")
	}
	modelParam.Url = model.Url

	newSessionName, err := generalModel.SingleTurnConversationText(prompt, modelParam)
	if err != nil {
		// Report the model failure itself. Previously the stale (nil) error from
		// the Unmarshal step was returned here, so a failed call looked like success.
		log.Println("UpdateSessionNameByAIAndID generate session name error:", err, "session id:", id)
		return fmt.Errorf("generate session name: %w", err)
	}

	// Persist the new name; a nil context argument is passed as before.
	if err = UpdateSessionByID(id, userID, newSessionName, nil); err != nil {
		log.Println("UpdateSessionNameByAIAndID update session name error:", err, "session id:", id, "user id:", userID, "name:", newSessionName)
		return err
	}
	return nil
}
// FindSessionByUserID returns the sessions owned by userID, each annotated
// with its message count and token usage.
//
// A sessionType of -1 lists every session via dao.FindSessionByUserIDV2; any
// other value restricts the list to that type via dao.FindSessionByUserID.
// Sessions with no recorded messages or token usage report 0. The result is
// nil when no sessions match.
func FindSessionByUserID(userID, sessionType int) []proto.SessionResponse {
	var sessions []dao.Session
	if sessionType != -1 {
		sessions = dao.FindSessionByUserID(userID, sessionType)
	} else {
		sessions = dao.FindSessionByUserIDV2(userID)
	}

	// Per-session aggregates, keyed by session ID.
	msgCountRows := dao.FindSessionMessageCountByUserID(userID)
	tokenRows := dao.FindModelTotalTokensBySessionID(userID)
	msgCountByID := make(map[int]int)
	for _, row := range msgCountRows {
		msgCountByID[row.SessionID] = row.Count
	}
	tokensByID := make(map[int]uint)
	for _, row := range tokenRows {
		tokensByID[int(row.SessionID)] = row.Token
	}

	var result []proto.SessionResponse
	for _, s := range sessions {
		var item proto.SessionResponse
		item.ID = s.ID
		item.CreatedAt, item.UpdatedAt = s.CreatedAt, s.UpdatedAt
		item.Name, item.UserID = s.Name, s.UserID
		item.MsgCount = msgCountByID[int(s.ID)]
		item.TokenUsage = tokensByID[int(s.ID)]
		result = append(result, item)
	}
	return result
}
// FindSessionByID loads session id and, when it belongs to userID, returns it
// as a one-element slice. Otherwise it returns an ownership error and a nil
// slice. sessionType is currently unused.
func FindSessionByID(id, userID, sessionType int) (error, []dao.Session) {
	found := dao.FindSessionByID(id)
	if found.UserID == userID {
		return nil, []dao.Session{found}
	}
	return errors.New("session not belong to user"), nil
}
// DeleteSessionByID deletes session id after verifying that it belongs to
// userID. It returns an ownership error, or whatever dao.DeleteSessionByID returns.
func DeleteSessionByID(id, userID int) error {
	if owner := dao.FindSessionByID(id).UserID; owner != userID {
		return errors.New("session not belong to user")
	}
	return dao.DeleteSessionByID(id)
}
// UpdateSessionByID updates the name and context of session id owned by
// userID by delegating to dao.UpdateSessionByID. Which columns are actually
// written when name is empty or sessionContext is nil is decided by the dao
// layer (not visible here).
//
// The last parameter was previously named `context`, which shadowed the
// imported context package inside this function.
func UpdateSessionByID(id int, userID int, name string, sessionContext []int) error {
	return dao.UpdateSessionByID(id, userID, name, sessionContext)
}
// CreateMessage stores msg as a message in sessionID sent from fromID to toID.
// It only reorders its arguments to match dao.CreateMessage, whose parameter
// order is (sessionID, fromID, toID, msgType, status, msg, functionId). It
// returns the dao's error and the new message ID.
func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int, functionId int) (error, uint) {
	return dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg, functionId)
}
// FindMessageBySessionID returns the messages of sessionID, provided the
// session belongs to userID; otherwise it returns an ownership error and nil.
func FindMessageBySessionID(sessionID, userID int) (error, []dao.Message) {
	// Only the session's owner may read its messages.
	if dao.FindSessionByID(sessionID).UserID != userID {
		return errors.New("session not belong to user"), nil
	}
	return nil, dao.FindMessageBySessionIDV2(sessionID)
}
// CreateGeneralMessageService stores msg (type msgType) from fromID to toID in
// sessionID after checking that fromID owns the session. The message is
// created with status 0 and function ID 0. It returns the new message's ID.
func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg string) (error, uint) {
	owner := dao.FindSessionByID(sessionID).UserID
	if owner != fromID {
		return errors.New("session not belong to user"), 0
	}
	const status, functionID = 0, 0
	return CreateMessage(msgType, sessionID, fromID, toID, msg, status, functionID)
}
var (
	// client is the shared Ollama API client. The Ollama handlers in this file
	// create it from the environment on first use.
	client *api.Client

	// SessionContextMap caches, per session ID, the context tokens returned
	// with the final chunk of an Ollama generation. The next prompt in that
	// session passes them back so the conversation continues. The map is
	// created lazily on first write.
	//
	// NOTE(review): both variables are read and written without any locking;
	// if sessions are handled concurrently this is a data race — confirm and
	// guard with a mutex.
	SessionContextMap map[int][]int
)
// WSReceiveMessageService dispatches a websocket chat message to the backend
// selected by model.Type (Ollama, Spark or DouBao) and returns that backend's
// error and message ID.
//
// The DouBao path reports nothing back and always yields (nil, 0). An
// unsupported model type yields a "model not exist" error.
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq, model dao.Model, modelParam proto.ModelParam) (error, uint) {
	switch model.Type {
	case proto.ModelTypeOllama:
		return ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model)
	case proto.ModelTypeSpark:
		return ReceiveSparkSession(userID, sessionID, channel, msg, modelParam, &model)
	case proto.ModelTypeDouBao:
		doubao(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID))
		return nil, 0
	default:
		return errors.New("model not exist"), 0
	}
}
// WSReceiveMessageServiceV2 routes the chat request described by ctx to the
// handler for ctx.ModelType.
//
// Only the Ollama handler reports an error and message ID back. The other
// handlers are called for their side effects, after which (nil, 0) is
// returned. An unknown model type yields an error naming that type.
func WSReceiveMessageServiceV2(modelParam proto.ModelParam, ctx *proto.IMParamContext) (error, uint) {
	switch ctx.ModelType {
	case proto.ModelTypeOllama:
		return ReceiveOllamaSessionV2(modelParam, ctx)
	case proto.ModelTypeSpark:
		SparkV2(modelParam, ctx)
	case proto.ModelTypeDouBao:
		DouBaoV2(modelParam, ctx)
	case proto.ModelTypeQianfan:
		Qianfan(modelParam, ctx)
	case proto.ModelTypeTongyi:
		Tongyi(modelParam, ctx)
	case proto.ModelTypeHunyuan:
		generalModel.OpenAI(modelParam, ctx, ctx.ModelType)
	case proto.ModelTypeGemini:
		generalModel.SendOpenAIForeignAI(modelParam, ctx, ctx.ModelType)
	default:
		return errors.New("model type not exist:" + ctx.ModelType), 0
	}
	return nil, 0
}
// ReceiveOllamaSession sends msg.Msg as a prompt to the Ollama model named by
// modelParam.Model and streams the reply to the client.
//
// Every streamed chunk is:
//   - published on channel, wrapped in a proto.WSMessage addressed to userID;
//   - recorded via CreateAIStreamMsg.
//
// When the final chunk arrives (resp.Done):
//   - the full reply is stored as a message from the robot user (fromID 2,
//     msgType 2, status 1);
//   - the returned context tokens are cached in SessionContextMap so the next
//     prompt in sessionID continues the conversation.
//
// It returns the stored reply's message ID (0 if the stream never completed).
// It returns an error if the Ollama client cannot be created or the generation
// fails.
func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, m *dao.Model) (error, uint) {
	var resID uint

	// Lazily create the shared Ollama client. Bail out on failure: continuing
	// would call Generate on a nil client and panic.
	if client == nil {
		c, err := api.ClientFromEnvironment()
		if err != nil {
			log.Println("get ollama client error:", err)
			return fmt.Errorf("create ollama client: %w", err), 0
		}
		client = c
	}

	// Resume from the context tokens cached for this session, if any.
	// NOTE(review): SessionContextMap is accessed without locking — racy if
	// sessions are processed concurrently.
	actx, ok := SessionContextMap[sessionID]
	if ok {
		log.Println("actx has get will use:", len(actx))
	}
	req := &api.GenerateRequest{
		Model:   modelParam.Model,
		Context: actx,
		Prompt:  msg.Msg,
	}

	robotMsg := "" // full reply accumulated across streamed chunks
	var aiMsg proto.AIQueueMessage
	var data proto.WSMessage
	respFunc := func(resp api.GenerateResponse) error {
		aiMsg.Type = proto.ModelTypeOllama
		aiMsg.Msg = resp
		robotMsg += resp.Response
		if resp.Done {
			// Final chunk: cache the context for the next prompt and persist the reply.
			actx = resp.Context
			if SessionContextMap == nil {
				SessionContextMap = make(map[int][]int)
			}
			SessionContextMap[sessionID] = actx
			log.Println("actx has save:", len(SessionContextMap[sessionID]))
			// Robot user ID 2, message type 2, status 1.
			err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1, 0)
			if err2 != nil {
				log.Println("create robot message error:", err2)
			}
			aiMsg.ID = msgID
			resID = msgID
		}
		// Forward every chunk, including the final one, to the client.
		data.Msg = aiMsg
		data.SessionID = sessionID
		data.ToID = userID
		aiMsgStr, _ := json.Marshal(data)
		worker.Publish(channel, string(aiMsgStr), time.Second*60)
		go CreateAIStreamMsg(userID, int(m.ID), sessionID, resp, modelParam.Model)
		return nil
	}

	if err := client.Generate(context.Background(), req, respFunc); err != nil {
		// This used to be log.Fatal, which terminated the whole process on a
		// single failed generation; report it to the caller instead.
		log.Println("ollama generate error:", err)
		return fmt.Errorf("ollama generate: %w", err), resID
	}
	return nil, resID
}
// ReceiveOllamaSessionV2 is the IMParamContext-based variant of
// ReceiveOllamaSession. It sends imCtx.Question to the Ollama model
// modelParam.Model and streams the reply to imCtx.Channel.
//
// Each streamed chunk is published (addressed to imCtx.UserID) and recorded
// via CreateAIStreamMsg.
//
// On the final chunk:
//   - the full reply is stored as a proto.ModelToUserMsgType message from
//     imCtx.ModelID to imCtx.UserID, with status 1 and tagged with
//     imCtx.FunctionID;
//   - the returned context tokens are cached in SessionContextMap for the
//     next prompt of imCtx.SessionID.
//
// It returns the stored reply's message ID (0 if the stream never completed).
// It returns an error if the Ollama client cannot be created or the generation
// fails.
func ReceiveOllamaSessionV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) (error, uint) {
	var resID uint

	// Lazily create the shared Ollama client. Bail out on failure: continuing
	// would call Generate on a nil client and panic.
	if client == nil {
		c, err := api.ClientFromEnvironment()
		if err != nil {
			log.Println("get ollama client error:", err)
			return fmt.Errorf("create ollama client: %w", err), 0
		}
		client = c
	}

	// Resume from the context tokens cached for this session, if any.
	// NOTE(review): SessionContextMap is accessed without locking — racy if
	// sessions are processed concurrently.
	actx, ok := SessionContextMap[imCtx.SessionID]
	if ok {
		log.Println("actx has get will use:", len(actx))
	}
	req := &api.GenerateRequest{
		Model:   modelParam.Model,
		Context: actx,
		Prompt:  imCtx.Question,
	}

	robotMsg := "" // full reply accumulated across streamed chunks
	var aiMsg proto.AIQueueMessage
	var data proto.WSMessage
	respFunc := func(resp api.GenerateResponse) error {
		aiMsg.Type = proto.ModelTypeOllama
		aiMsg.Msg = resp
		robotMsg += resp.Response
		if resp.Done {
			// Final chunk: cache the context for the next prompt and persist the reply.
			actx = resp.Context
			if SessionContextMap == nil {
				SessionContextMap = make(map[int][]int)
			}
			SessionContextMap[imCtx.SessionID] = actx
			log.Println("actx has save:", len(SessionContextMap[imCtx.SessionID]))
			// Reply goes from the model (imCtx.ModelID) to the user, status 1.
			err2, msgID := CreateMessage(proto.ModelToUserMsgType, imCtx.SessionID, imCtx.ModelID, imCtx.UserID, robotMsg, 1, imCtx.FunctionID)
			if err2 != nil {
				log.Println("create robot message error:", err2)
			}
			aiMsg.ID = msgID
			resID = msgID
		}
		// Forward every chunk, including the final one, to the client.
		data.Msg = aiMsg
		data.SessionID = imCtx.SessionID
		data.ToID = imCtx.UserID
		aiMsgStr, _ := json.Marshal(data)
		worker.Publish(imCtx.Channel, string(aiMsgStr), time.Second*60)
		go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, resp, modelParam.Model)
		return nil
	}

	if err := client.Generate(context.Background(), req, respFunc); err != nil {
		// This used to be log.Fatal, which terminated the whole process on a
		// single failed generation; report it to the caller instead.
		log.Println("ollama generate error:", err)
		return fmt.Errorf("ollama generate: %w", err), resID
	}
	return nil, resID
}
// ReceiveSparkSession hands msg.Msg to the Spark model handler (Spark), along
// with the channel, session, user and model IDs. Spark reports no result, so
// this function always returns (nil, 0).
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, model *dao.Model) (error, uint) {
	modelID := int(model.ID)
	Spark(modelParam, msg.Msg, channel, sessionID, userID, modelID)
	return nil, 0
}
// CheckUserCreatePPTSessionPermission reports whether userID may create another
// PPT-generation session.
//
// Users whose user-center Role is "admin" are always allowed. Other users get
// an error once their count of proto.SessionTypeUserCreatePPT sessions
// exceeds 3.
//
// NOTE(review): with a strict "> 3" check, a user holding exactly 3 sessions
// can still create a 4th — confirm this is the intended limit.
func CheckUserCreatePPTSessionPermission(userID int) error {
	const maxPPTSessions = 3
	count := dao.FindUserSessionCount(userID, proto.SessionTypeUserCreatePPT)
	if GetUserByIDFromUserCenter(userID).Role == "admin" {
		return nil
	}
	if count <= maxPPTSessions {
		return nil
	}
	return errors.New("create ppt session count exceed limit")
}
// SelectKBaseContentByQuery runs a knowledge-base lookup for the chat described
// by imCtx against knowledge base data.KBaseID.
//
// Steps:
//  1. Resolve the knowledge-base server. Use the ID cached under Redis key
//     "kbase_<uuid>"; if none is cached, pick a random member of the
//     "kbase_server_id_set_is_available" set and cache it for 7 days.
//  2. Fail if that server is not currently in the available set.
//  3. Subscribe to the reply channel "kbase_<uuid>", publish the request on
//     "kbase_server_id_<serverID>", and wait up to 3 minutes for one reply.
//  4. Act on data.KBaseType:
//     - "query": store the retrieved text as a model-to-user message and push
//       it to imCtx.Channel;
//     - "chat": rewrite data.Msg in place into a prompt made of the retrieved
//       text followed by the user's original question;
//     - anything else: return an error.
func SelectKBaseContentByQuery(imCtx *proto.IMParamContext, data *proto.WSMessageReq) error {
	kbase, err := dao.GetKnowledgeBaseByID(data.KBaseID)
	if err != nil {
		return err
	}
	var req proto.KnowledgeBaseServerRequest
	req.IMContext = *imCtx
	req.KnowledgeBase = proto.KBase{ID: kbase.ID, UUID: kbase.UUID, Name: kbase.Name, Description: kbase.Description, FileIDs: kbase.FileIDs, SessionID: kbase.SessionID}

	// 1. Find the server hosting this knowledge base, assigning one on first use.
	const availableSetKey = "kbase_server_id_set_is_available" // set of available server IDs
	key := fmt.Sprintf("kbase_%s", kbase.UUID)
	serverID := worker.GetRedis(key)
	if serverID == "" {
		serverID = worker.GetRedisSetRandomMember(availableSetKey)
		if serverID == "" {
			return errors.New("no kbase server id")
		}
		res := worker.SetRedisWithExpire(key, serverID, time.Hour*24*7)
		log.Println("分配的知识库服务器ID:", serverID, "\t key:", key, "\tset:", res)
	}

	// 2. The assigned server must still be available. No reassignment happens
	// here; the caller just gets an error.
	if !worker.SetRedisSetIsMember(availableSetKey, serverID) {
		log.Println("当前知识库服务器不可用,重新分配")
		return errors.New("kbase server id is not available")
	}
	log.Println("KBaseServerID:", serverID, " KBaseUUID:", kbase.UUID, "\tName:", kbase.Name)

	// 3. A single 3-minute deadline covers the subscription handshake and the wait.
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*3)
	defer cancel()

	// Subscribe BEFORE publishing. Redis pub/sub only delivers to subscribers
	// that exist at publish time. Previously the subscription was made after
	// the request was sent, so a fast reply was lost and the query timed out.
	// NOTE(review): the reply channel is per knowledge base, not per request,
	// so concurrent queries on the same knowledge base may receive each other's
	// replies — consider a per-request reply channel.
	replyChannel := fmt.Sprintf("kbase_%s", kbase.UUID)
	pubsub := worker.RedisClient.Subscribe(context.Background(), replyChannel)
	defer pubsub.Close()
	// Receive blocks until Redis confirms the subscription is active.
	if _, err := pubsub.Receive(ctx); err != nil {
		return fmt.Errorf("subscribe to kbase reply channel %s: %w", replyChannel, err)
	}
	ch := pubsub.Channel()

	reqStr, err := json.Marshal(req)
	if err != nil {
		return fmt.Errorf("marshal kbase request: %w", err)
	}
	worker.Publish(fmt.Sprintf("kbase_server_id_%s", serverID), string(reqStr), time.Minute*3)

	var resp proto.KnowledgeBaseServerResponse
	select {
	case m, ok := <-ch:
		if !ok {
			return errors.New("kbase reply channel closed")
		}
		// An unparsable reply used to be logged and then processed as empty.
		if err := json.Unmarshal([]byte(m.Payload), &resp); err != nil {
			log.Println("Error parsing JSON:", err)
			return fmt.Errorf("parse kbase reply: %w", err)
		}
	case <-ctx.Done():
		log.Println("kbase query timeout")
		return errors.New("kbase query timeout")
	}

	// 4. Turn the retrieved snippets into a direct answer or a chat prompt.
	qText := getSelectedDocument(resp.QuerySelect)
	log.Println("kbase type:", data.KBaseType)
	switch data.KBaseType {
	case "query":
		// The retrieved text is the reply itself: persist it, then push it to the client.
		if err, _ := dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, qText, imCtx.FunctionID); err != nil {
			log.Println("kbase query message save error:", err)
		}
		kbaseQueryToGeneralMassageAndSendMsgQueue(qText, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
	case "chat":
		// Prompt (Chinese): "Based on the document snippets below: <snippets>
		// answer the user's question: <question>".
		text := "根据下面文档片段:\n" + qText + "回答用户问题:" + data.Msg
		data.Msg = text
		log.Println("text:", text)
	default:
		return errors.New("unknow kbase type")
	}
	return nil
}
// getSelectedDocument flattens a knowledge-base query result into plain text.
//
// For each query i, it writes:
//   - "文件:" ("file:") followed by the concatenated Source of every metadata
//     entry of that query;
//   - ":\n内容" ("content") followed by the concatenated document texts;
//   - a newline.
//
// Processing stops at the first query that has no documents.
func getSelectedDocument(querySelect proto.KnowledgeBaseServerResponseSelect) string {
	text := ""
	for i, docs := range querySelect.Documents {
		if len(docs) == 0 {
			break
		}
		text += "文件:"
		// Bound the loop by this query's own metadata list. The old bound,
		// len(querySelect.Metadatas), counts queries rather than results: with
		// one query and k results only the first source was emitted, and
		// Metadatas[i][j] could index out of range. A missing metadata list
		// for query i is tolerated.
		if i < len(querySelect.Metadatas) {
			for _, meta := range querySelect.Metadatas[i] {
				text += meta.Source
			}
		}
		text += ":\n内容"
		for _, doc := range docs {
			text += doc
		}
		text += "\n"
	}
	return text
}
// WSReturnErrorMessage publishes an error notification on channel. The
// proto.WSMessage carries code in Code and, for errors, the message text in
// Type. SessionID and ToID are both 0.
func WSReturnErrorMessage(channel string, msg string, code int) {
	var payload proto.WSMessage
	payload.Code, payload.Type = code, msg // on errors Type holds the error text
	payload.SessionID, payload.ToID = 0, 0
	encoded, _ := json.Marshal(payload) // marshal error ignored, as elsewhere in this file
	worker.Publish(channel, string(encoded), 60*time.Second)
}
func kbaseQueryToGeneralMassageAndSendMsgQueue(answer, channel string, SessionID, UserID int) {
var aiMsg proto.AIQueueMessage
var wsMsg proto.WSMessage
var apiMsg api.GenerateResponse
apiMsg.Done = true
apiMsg.Model = proto.KnowledgeBaseServerResponseType
apiMsg.Response = answer
//消息队列部分
aiMsg.Type = proto.KnowledgeBaseServerResponseType
aiMsg.Msg = apiMsg
//ws发送消息部分
wsMsg.Msg = aiMsg
wsMsg.SessionID = SessionID
wsMsg.ToID = UserID
wsMsg.Type = proto.KnowledgeBaseServerResponseType
//发送消息
wsMsgStr, _ := json.Marshal(wsMsg)
worker.Publish(channel, string(wsMsgStr), time.Second*60)
}