diff --git a/dao/im.go b/dao/im.go index 29e02d3..f8c5660 100644 --- a/dao/im.go +++ b/dao/im.go @@ -24,9 +24,18 @@ type Message struct { ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断 Msg string `gorm:"column:msg"` //消息内容 FunctionID int `gorm:"column:function_id"` //功能id + Token int `gorm:"column:token"` //消息消耗的token Status int `gorm:"column:status"` //0为未读,1为已读,3为图片消息,4为文件消息,5为制作ppt消息 } +type MessageToken struct { + gorm.Model + SessionID int `gorm:"column:session_id"` //会话id + UserID int `gorm:"column:user_id"` //用户id + MessageID int `gorm:"column:message_id"` //消息id + Token int `gorm:"column:token"` //消息消耗token +} + func CreateSession(userID, Type int, name string) (error, uint) { session := Session{UserID: userID, Name: name, Type: Type} var res *gorm.DB @@ -109,6 +118,15 @@ func CreateMessage(sessionID, fromID, toID, msgType, status int, msg string, fun } return res.Error, message.ID } +func CreateMessageV2(message Message) (error, uint) { + var res *gorm.DB + if proto.Config.SERVER_SQL_LOG { + res = DB.Debug().Create(&message) + } else { + res = DB.Create(&message) + } + return res.Error, message.ID +} // 根据会话id获取消息 func FindMessageBySessionID(sessionID int) []Message { @@ -226,3 +244,62 @@ func FindMessageCountByDate() ([]SessionOrMessageCount, error) { } return messageCounts, nil } + +func CreateMessageToken(sessionID, userID, messageID, token int) (error, uint) { + messageToken := MessageToken{SessionID: sessionID, UserID: userID, MessageID: messageID, Token: token} + var res *gorm.DB + if proto.Config.SERVER_SQL_LOG { + res = DB.Debug().Create(&messageToken) + } else { + res = DB.Create(&messageToken) + } + return res.Error, messageToken.ID +} +func FindMessageTokenBySessionID(sessionID int) []MessageToken { + var messageTokens []MessageToken + if proto.Config.SERVER_SQL_LOG { + DB.Debug().Where("session_id = ?", sessionID).Find(&messageTokens) + } else { + DB.Where("session_id = ?", sessionID).Find(&messageTokens) + } + return 
messageTokens +} + +func FindMessageTokenByUserID(userID int) []MessageToken { + var messageTokens []MessageToken + if proto.Config.SERVER_SQL_LOG { + DB.Debug().Where("user_id = ?", userID).Find(&messageTokens) + } else { + DB.Where("user_id = ?", userID).Find(&messageTokens) + } + return messageTokens +} +func FindMessageTokenByMessageID(messageID int) []MessageToken { + var messageTokens []MessageToken + if proto.Config.SERVER_SQL_LOG { + DB.Debug().Where("message_id = ?", messageID).Find(&messageTokens) + } else { + DB.Where("message_id = ?", messageID).Find(&messageTokens) + } + return messageTokens +} + +func UpdateMessageTokenByID(messageID int, token int) error { + var res *gorm.DB + if proto.Config.SERVER_SQL_LOG { + res = DB.Debug().Model(&MessageToken{}).Where("id = ?", messageID).Update("token", token) + } else { + res = DB.Model(&MessageToken{}).Where("id = ?", messageID).Update("token", token) + } + return res.Error +} + +func DeleteMessageTokenByID(messageID int) error { + var res *gorm.DB + if proto.Config.SERVER_SQL_LOG { + res = DB.Debug().Delete(&MessageToken{}, messageID) + } else { + res = DB.Delete(&MessageToken{}, messageID) + } + return res.Error +} diff --git a/go.mod b/go.mod index b6bff46..8db5d04 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/ollama/ollama v0.6.2 + github.com/openai/openai-go v0.1.0-beta.10 github.com/robfig/cron/v3 v3.0.1 github.com/volcengine/volcengine-go-sdk v1.0.187 gorm.io/driver/mysql v1.5.7 @@ -62,6 +63,10 @@ require ( github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/viper v1.18.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/tidwall/gjson v1.14.4 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // 
indirect github.com/volcengine/volc-sdk-golang v1.0.23 // indirect diff --git a/go.sum b/go.sum index 176fbc6..8806908 100644 --- a/go.sum +++ b/go.sum @@ -127,6 +127,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs= +github.com/openai/openai-go v0.1.0-beta.10 h1:CknhGXe8aXQMRuqg255PFnWzgRY9nEryMxoNIBBM9tU= +github.com/openai/openai-go v0.1.0-beta.10/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -169,6 +171,16 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod 
h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= diff --git a/handler/im.go b/handler/im.go index c51fc1f..98090b2 100644 --- a/handler/im.go +++ b/handler/im.go @@ -271,6 +271,7 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR if err != nil { return err } //机器人id为2,消息类型为1,状态为1 + imContext.QuestionMessageID = int(userMsgID) //若为知识库消息,需要先根据用户消息获取知识库查询内容让ai进行最终结果生成 if data.IsKBase { diff --git a/proto/im.go b/proto/im.go index cc90144..affc8b5 100644 --- a/proto/im.go +++ b/proto/im.go @@ -78,14 +78,15 @@ type SessionResponse struct { } type IMParamContext struct { - UserID int `json:"user_id"` //用户id - SessionID int `json:"session_id"` //会话id - FunctionID int `json:"function_id"` //功能id - ModelID int `json:"model_id"` //模型id - ModelType string `json:"model_type"` //模型类型 - Question string `json:"question"` //问题 - Channel string `json:"channel"` //消息队列 - IsHasImage bool `json:"is_has_image"` //是否有图片 + UserID int `json:"user_id"` //用户id + SessionID int `json:"session_id"` //会话id + FunctionID int `json:"function_id"` //功能id + ModelID int `json:"model_id"` //模型id + ModelType string `json:"model_type"` //模型类型 + QuestionMessageID int `json:"question_message_id"` //问题消息id,用于后面的消息token消耗保存 + Question string `json:"question"` //问题 + Channel string `json:"channel"` //消息队列 + IsHasImage bool `json:"is_has_image"` //是否有图片 } type ImageMsgContent struct { diff --git a/proto/status.go b/proto/status.go index d88c4b0..8f93711 100644 --- a/proto/status.go +++ b/proto/status.go @@ -145,6 +145,7 @@ const ( 
ModelTypeDouBao = "doubao" ModelTypeOllama = "ollama" ModelTypeQianfan = "qianfan" + ModelTypeTongyi = "tongyi" KnowledgeBaseServerResponseType = "kbase_query_resp" ) diff --git a/service/imService.go b/service/imService.go index 3980548..297d73f 100644 --- a/service/imService.go +++ b/service/imService.go @@ -136,6 +136,8 @@ func WSReceiveMessageServiceV2(modelParam proto.ModelParam, ctx *proto.IMParamCo DouBaoV2(modelParam, ctx) case proto.ModelTypeQianfan: Qianfan(modelParam, ctx) + case proto.ModelTypeTongyi: + Tongyi(modelParam, ctx) default: resErr, resID = errors.New("model type not exist:"+ctx.ModelType), 0 } diff --git a/service/tongyi.go b/service/tongyi.go new file mode 100644 index 0000000..53bbba5 --- /dev/null +++ b/service/tongyi.go @@ -0,0 +1,150 @@ +package service + +import ( + "StuAcaWorksAI/dao" + "StuAcaWorksAI/proto" + "StuAcaWorksAI/worker" + "context" + "encoding/json" + "github.com/ollama/ollama/api" + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "log" + "time" +) + +// 通义模型 +func Tongyi(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { + tongyiApiKey := modelParam.APIKey + tongyiClient := openai.NewClient( + option.WithAPIKey(tongyiApiKey), + option.WithBaseURL(modelParam.Url), + ) + var messages []openai.ChatCompletionMessageParamUnion + err := GetTongyiHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID) + if err != nil { + println("GetTongyiHistoryMessages error:", err) + return + } + if proto.Config.LOG_OUTPUT { + messagesBytes, _ := json.Marshal(messages) + println("tongyi messages:", string(messagesBytes)) + } + + stream := tongyiClient.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{ + Messages: messages, + Model: modelParam.Model, + }) + acc := openai.ChatCompletionAccumulator{} + answer := "" + var tokens int64 + var PromptTokens int64 //用户问题token + + for stream.Next() { + chunk := stream.Current() + go CreateAIStreamMsg(imCtx.UserID, 
imCtx.ModelID, imCtx.SessionID, chunk, modelParam.Model)
+		acc.AddChunk(chunk)
+		if _, ok := acc.JustFinishedToolCall(); ok {
+			chunk.Created = -1 // no explicit end-of-stream flag is returned, so this field is used as the marker
+		}
+		if len(chunk.Choices) > 0 {
+			println(chunk.Choices[0].Delta.Content)
+			answer += chunk.Choices[0].Delta.Content
+		}
+		// usage statistics usually arrive on the final chunk, which may carry an
+		// empty Choices slice — capture them outside the choices guard so they
+		// are not silently dropped (previously tokens could stay 0)
+		if chunk.Usage.TotalTokens > 0 {
+			tokens = chunk.Usage.TotalTokens
+			PromptTokens = chunk.Usage.PromptTokens
+		}
+		TongyiToGeneralMassageAndSendMsgQueue(&chunk, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
+	}
+	if stream.Err() != nil {
+		log.Println("tongyi stream err:", stream.Err())
+	}
+	// persist the answer message together with its token consumption
+	DOSaveMessageAndToken(imCtx, answer, "tongyi", tokens, PromptTokens)
+}
+
+// DOSaveMessageAndToken stores the model answer as a Message, records the
+// prompt-token cost against the question message and the completion cost
+// against the answer message, and finally books the aggregate model token
+// consumption for the session.
+func DOSaveMessageAndToken(imCtx *proto.IMParamContext, answer, platform string, tokens, PromptTokens int64) {
+	// NOTE(review): FromID/ToID and Status values are kept as written —
+	// confirm FromID should not be the model id for a ModelToUserMsgType
+	// message, and that Status is really meant to hold UserToModelMsgType
+	// (the Status enum elsewhere is 0 unread / 1 read / 3..5 media types).
+	msg := dao.Message{SessionID: imCtx.SessionID, FromID: imCtx.UserID, ToID: imCtx.ModelID, Msg: answer, Type: proto.ModelToUserMsgType, FunctionID: imCtx.FunctionID, Token: int(tokens), Status: proto.UserToModelMsgType}
+	err2, msgID := dao.CreateMessageV2(msg)
+	if err2 != nil {
+		log.Printf("%s create message error:%v\n", platform, err2)
+	} else {
+		log.Printf("%s create message success,message id:%d\n", platform, msgID)
+	}
+	// save the token cost of the user question, attached to the question
+	// message id — QuestionMessageID is set by the handler for exactly this
+	// purpose; the answer's msgID was wrongly used here before, which left
+	// QuestionMessageID entirely unused
+	err, messageTokenID := dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, imCtx.QuestionMessageID, int(PromptTokens))
+	if err != nil {
+		log.Printf("%s user prompt create message token error:%v\n", platform, err)
+	} else {
+		log.Printf("%s user prompt create message token id:%d\n", platform, messageTokenID)
+	}
+
+	// save the token cost of the returned answer message
+	err, messageTokenID = dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(tokens))
+	if err != nil {
+		log.Printf("%s create message token error:%v\n", platform, err)
+	} else {
+		log.Printf("%s create message token id:%d\n", platform, messageTokenID)
+	}
+	// book the aggregate token consumption for this user/model/session
+	err3, modelTokenID := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
+	if err3 != nil {
+		log.Printf("%s create model 
tokens error:%v\n", platform, err3)
+	} else {
+		log.Printf("%s create model tokens success,message id:%d\n", platform, modelTokenID)
+	}
+}
+
+// GetTongyiHistoryMessages loads the session history into OpenAI-style chat
+// messages, prepending the system prompt when one is configured. File
+// messages are currently skipped (decoding is not implemented for tongyi).
+func GetTongyiHistoryMessages(sessionID int, systemPrompt string, messages *[]openai.ChatCompletionMessageParamUnion, userID int) error {
+	// fetch the session's history messages
+	msgs := dao.FindMessageBySessionIDV2(sessionID)
+	if len(msgs) > proto.SparkContextLength {
+		// keep only the last SparkContextLength messages; the previous
+		// "-1" off-by-one kept one message more than the context limit
+		msgs = msgs[len(msgs)-proto.SparkContextLength:]
+	}
+	// add the system message first
+	if systemPrompt != "" {
+		*messages = append(*messages, openai.SystemMessage(systemPrompt))
+	}
+	for _, msg := range msgs {
+		if msg.Type == proto.UserToModelMsgType {
+			if msg.Status == proto.UserToModelFileMsgType {
+				//err2 := decodeFileMessageToSparkMessage(msg.Msg, userID, messages)
+				//if err2 != nil {
+				//	log.Println("spark decode file message error:", err2)
+				//}
+			} else {
+				*messages = append(*messages, openai.UserMessage(msg.Msg))
+			}
+		} else if msg.Type == proto.ModelToUserMsgType {
+			*messages = append(*messages, openai.AssistantMessage(msg.Msg))
+		} else {
+			continue
+		}
+	}
+	return nil
+}
+
+// TongyiToGeneralMassageAndSendMsgQueue converts a tongyi stream chunk into
+// the general ws/queue message format and publishes it on the given channel.
+func TongyiToGeneralMassageAndSendMsgQueue(data *openai.ChatCompletionChunk, channel string, SessionID int, UserID int) {
+	var aiMsg proto.AIQueueMessage
+	var wsMsg proto.WSMessage
+	var apiMsg api.GenerateResponse
+	// NOTE(review): the type is labeled ModelTypeSpark so clients reuse the
+	// spark rendering path — confirm whether ModelTypeTongyi is intended here.
+	apiMsg.Model = proto.ModelTypeSpark
+	// the final/usage chunk can carry an empty Choices slice — guard the
+	// index access so it does not panic
+	if len(data.Choices) > 0 {
+		apiMsg.Response = data.Choices[0].Delta.Content
+	}
+	// Created < 0 is the locally-set end-of-stream marker (the API returns no
+	// explicit done flag)
+	apiMsg.Done = data.Created < 0
+	// message-queue part
+	aiMsg.Type = proto.ModelTypeSpark
+	aiMsg.Msg = apiMsg
+	// ws message part
+	wsMsg.Msg = aiMsg
+	wsMsg.SessionID = SessionID
+	wsMsg.ToID = UserID
+	wsMsg.Type = proto.ModelTypeSpark
+
+	// publish the message
+	wsMsgStr, _ := json.Marshal(wsMsg)
+	worker.Publish(channel, string(wsMsgStr), time.Second*60)
+}