添加通义模型接入,添加消息的token保存,修改保存模型返回消息及消息的token

This commit is contained in:
junleea 2025-05-11 20:31:39 +08:00
parent 0c48fc83a9
commit bdf374ba91
8 changed files with 257 additions and 8 deletions

View File

@ -24,9 +24,18 @@ type Message struct {
ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型则为模型id根据type判断
Msg string `gorm:"column:msg"` //消息内容
FunctionID int `gorm:"column:function_id"` //功能id
Token int `gorm:"column:token"` //消息消耗的token
Status int `gorm:"column:status"` //0为未读1为已读3为图片消息4为文件消息,5为制作ppt消息
}
// MessageToken records the token consumption of a single chat message,
// linked back to its session, user, and message rows.
type MessageToken struct {
	gorm.Model
	SessionID int `gorm:"column:session_id"` // session ID
	UserID    int `gorm:"column:user_id"`    // user ID
	MessageID int `gorm:"column:message_id"` // message ID this cost belongs to
	Token     int `gorm:"column:token"`      // tokens consumed by the message
}
func CreateSession(userID, Type int, name string) (error, uint) {
session := Session{UserID: userID, Name: name, Type: Type}
var res *gorm.DB
@ -109,6 +118,15 @@ func CreateMessage(sessionID, fromID, toID, msgType, status int, msg string, fun
}
return res.Error, message.ID
}
// CreateMessageV2 inserts the given message row and returns any database
// error together with the ID assigned to the new record.
func CreateMessageV2(message Message) (error, uint) {
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = db.Debug() // echo the generated SQL when SQL logging is enabled
	}
	res := db.Create(&message)
	return res.Error, message.ID
}
// 根据会话id获取消息
func FindMessageBySessionID(sessionID int) []Message {
@ -226,3 +244,62 @@ func FindMessageCountByDate() ([]SessionOrMessageCount, error) {
}
return messageCounts, nil
}
// CreateMessageToken stores one token-consumption record for a message and
// returns any database error together with the new record's ID.
func CreateMessageToken(sessionID, userID, messageID, token int) (error, uint) {
	record := MessageToken{
		SessionID: sessionID,
		UserID:    userID,
		MessageID: messageID,
		Token:     token,
	}
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = db.Debug() // echo the generated SQL when SQL logging is enabled
	}
	res := db.Create(&record)
	return res.Error, record.ID
}
// FindMessageTokenBySessionID returns all token records belonging to the
// given session. Query errors are ignored; an empty slice is returned instead.
func FindMessageTokenBySessionID(sessionID int) []MessageToken {
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = db.Debug()
	}
	var records []MessageToken
	db.Where("session_id = ?", sessionID).Find(&records)
	return records
}
// FindMessageTokenByUserID returns all token records produced by the given
// user. Query errors are ignored; an empty slice is returned instead.
func FindMessageTokenByUserID(userID int) []MessageToken {
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = db.Debug()
	}
	var records []MessageToken
	db.Where("user_id = ?", userID).Find(&records)
	return records
}
// FindMessageTokenByMessageID returns all token records attached to the given
// message. Query errors are ignored; an empty slice is returned instead.
func FindMessageTokenByMessageID(messageID int) []MessageToken {
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = db.Debug()
	}
	var records []MessageToken
	db.Where("message_id = ?", messageID).Find(&records)
	return records
}
func UpdateMessageTokenByID(messageID int, token int) error {
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Model(&MessageToken{}).Where("id = ?", messageID).Update("token", token)
} else {
res = DB.Model(&MessageToken{}).Where("id = ?", messageID).Update("token", token)
}
return res.Error
}
func DeleteMessageTokenByID(messageID int) error {
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Delete(&MessageToken{}, messageID)
} else {
res = DB.Delete(&MessageToken{}, messageID)
}
return res.Error
}

5
go.mod
View File

@ -12,6 +12,7 @@ require (
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/ollama/ollama v0.6.2
github.com/openai/openai-go v0.1.0-beta.10
github.com/robfig/cron/v3 v3.0.1
github.com/volcengine/volcengine-go-sdk v1.0.187
gorm.io/driver/mysql v1.5.7
@ -62,6 +63,10 @@ require (
github.com/spf13/pflag v1.0.5 // indirect
github.com/spf13/viper v1.18.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/gjson v1.14.4 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/volcengine/volc-sdk-golang v1.0.23 // indirect

12
go.sum
View File

@ -127,6 +127,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
github.com/openai/openai-go v0.1.0-beta.10 h1:CknhGXe8aXQMRuqg255PFnWzgRY9nEryMxoNIBBM9tU=
github.com/openai/openai-go v0.1.0-beta.10/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -169,6 +171,16 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=

View File

@ -271,6 +271,7 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
if err != nil {
return err
} //机器人id为2,消息类型为1,状态为1
imContext.QuestionMessageID = int(userMsgID)
//若为知识库消息需要先根据用户消息获取知识库查询内容让ai进行最终结果生成
if data.IsKBase {

View File

@ -78,14 +78,15 @@ type SessionResponse struct {
}
type IMParamContext struct {
UserID int `json:"user_id"` //用户id
SessionID int `json:"session_id"` //会话id
FunctionID int `json:"function_id"` //功能id
ModelID int `json:"model_id"` //模型id
ModelType string `json:"model_type"` //模型类型
Question string `json:"question"` //问题
Channel string `json:"channel"` //消息队列
IsHasImage bool `json:"is_has_image"` //是否有图片
UserID int `json:"user_id"` //用户id
SessionID int `json:"session_id"` //会话id
FunctionID int `json:"function_id"` //功能id
ModelID int `json:"model_id"` //模型id
ModelType string `json:"model_type"` //模型类型
QuestionMessageID int `json:"question_message_id"` //问题消息id,用于后面的消息token消耗保存
Question string `json:"question"` //问题
Channel string `json:"channel"` //消息队列
IsHasImage bool `json:"is_has_image"` //是否有图片
}
type ImageMsgContent struct {

View File

@ -145,6 +145,7 @@ const (
ModelTypeDouBao = "doubao"
ModelTypeOllama = "ollama"
ModelTypeQianfan = "qianfan"
ModelTypeTongyi = "tongyi"
KnowledgeBaseServerResponseType = "kbase_query_resp"
)

View File

@ -136,6 +136,8 @@ func WSReceiveMessageServiceV2(modelParam proto.ModelParam, ctx *proto.IMParamCo
DouBaoV2(modelParam, ctx)
case proto.ModelTypeQianfan:
Qianfan(modelParam, ctx)
case proto.ModelTypeTongyi:
Tongyi(modelParam, ctx)
default:
resErr, resID = errors.New("model type not exist:"+ctx.ModelType), 0
}

150
service/tongyi.go Normal file
View File

@ -0,0 +1,150 @@
package service
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
"StuAcaWorksAI/worker"
"context"
"encoding/json"
"github.com/ollama/ollama/api"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"
"log"
"time"
)
// 通义模型
// Tongyi streams a chat completion from the Tongyi OpenAI-compatible endpoint,
// forwards each chunk to the client through the message queue, and finally
// persists the full answer together with its token usage.
//
// Fixes vs. the original:
//   - errors are logged with log.Println instead of the builtin println,
//     which printed the error interface as a raw pointer;
//   - token usage is captured on every chunk, not only when Choices is
//     non-empty — OpenAI-compatible streams commonly deliver usage on a final
//     chunk whose Choices slice is empty, so the old code could save 0 tokens;
//   - the unconditional per-chunk stdout print of delta content was removed
//     (debug noise).
func Tongyi(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
	client := openai.NewClient(
		option.WithAPIKey(modelParam.APIKey),
		option.WithBaseURL(modelParam.Url),
	)
	var messages []openai.ChatCompletionMessageParamUnion
	if err := GetTongyiHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID); err != nil {
		log.Println("GetTongyiHistoryMessages error:", err)
		return
	}
	if proto.Config.LOG_OUTPUT {
		messagesBytes, _ := json.Marshal(messages)
		log.Println("tongyi messages:", string(messagesBytes))
	}
	stream := client.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{
		Messages: messages,
		Model:    modelParam.Model,
	})
	acc := openai.ChatCompletionAccumulator{}
	answer := ""
	var totalTokens int64
	var promptTokens int64 // tokens consumed by the user question
	for stream.Next() {
		chunk := stream.Current()
		go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, chunk, modelParam.Model)
		acc.AddChunk(chunk)
		if _, ok := acc.JustFinishedToolCall(); ok {
			// the API gives no explicit end flag, so Created < 0 marks "done"
			// for the downstream converter
			chunk.Created = -1
		}
		// Usage may arrive on a choices-less final chunk; record it whenever present.
		if chunk.Usage.TotalTokens > 0 {
			totalTokens = chunk.Usage.TotalTokens
			promptTokens = chunk.Usage.PromptTokens
		}
		if len(chunk.Choices) > 0 {
			answer += chunk.Choices[0].Delta.Content
		}
		TongyiToGeneralMassageAndSendMsgQueue(&chunk, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
	}
	if stream.Err() != nil {
		log.Println("tongyi stream err:", stream.Err())
	}
	// persist the assembled answer and its token costs
	DOSaveMessageAndToken(imCtx, answer, "tongyi", totalTokens, promptTokens)
}
// DOSaveMessageAndToken persists the model's answer as a message, records the
// token cost of both the user question and the answer, and stores aggregated
// model token usage. platform is used only as a log prefix.
//
// Fix vs. the original: the user-prompt token record was attributed to the
// newly created answer message (msgID). IMParamContext.QuestionMessageID was
// added precisely to carry the user question's message ID for token saving,
// so the prompt-token record now references it instead.
func DOSaveMessageAndToken(imCtx *proto.IMParamContext, answer, platform string, tokens, PromptTokens int64) {
	// NOTE(review): Status is set to proto.UserToModelMsgType, a message *type*
	// constant, not an obvious status value — confirm this is intentional.
	msg := dao.Message{SessionID: imCtx.SessionID, FromID: imCtx.UserID, ToID: imCtx.ModelID, Msg: answer, Type: proto.ModelToUserMsgType, FunctionID: imCtx.FunctionID, Token: int(tokens), Status: proto.UserToModelMsgType}
	err2, msgID := dao.CreateMessageV2(msg)
	if err2 != nil {
		log.Printf("%s create message error:%v\n", platform, err2)
	} else {
		log.Printf("%s create message success,message id:%d\n", platform, msgID)
	}
	// save the token cost of the user question against the question message
	err, messageTokenID := dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, imCtx.QuestionMessageID, int(PromptTokens))
	if err != nil {
		log.Printf("%s user prompt create message token error:%v\n", platform, err)
	} else {
		log.Printf("%s user prompt create message token id:%d\n", platform, messageTokenID)
	}
	// save the token cost of the returned answer against the answer message
	err, messageTokenID = dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(tokens))
	if err != nil {
		log.Printf("%s create message token error:%v\n", platform, err)
	} else {
		log.Printf("%s create message token id:%d\n", platform, messageTokenID)
	}
	// store aggregated token consumption for the model
	err3, modelTokenID := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
	if err3 != nil {
		log.Printf("%s create model tokens error:%v\n", platform, err3)
	} else {
		log.Printf("%s create model tokens success,message id:%d\n", platform, modelTokenID)
	}
}
// GetTongyiHistoryMessages loads the session's chat history into *messages in
// OpenAI chat format, prefixed with the optional system prompt. Only the most
// recent proto.SparkContextLength messages are kept to bound the context.
// userID is currently unused (reserved for file-message decoding).
//
// Fix vs. the original: the truncation slice kept SparkContextLength+1
// messages (off by one against its own comment).
func GetTongyiHistoryMessages(sessionID int, systemPrompt string, messages *[]openai.ChatCompletionMessageParamUnion, userID int) error {
	msgs := dao.FindMessageBySessionIDV2(sessionID)
	if len(msgs) > proto.SparkContextLength {
		// keep only the last SparkContextLength messages
		msgs = msgs[len(msgs)-proto.SparkContextLength:]
	}
	if systemPrompt != "" {
		*messages = append(*messages, openai.SystemMessage(systemPrompt))
	}
	for _, msg := range msgs {
		switch msg.Type {
		case proto.UserToModelMsgType:
			if msg.Status == proto.UserToModelFileMsgType {
				// TODO: file messages are not decoded for Tongyi yet; see
				// decodeFileMessageToSparkMessage in the Spark implementation.
				continue
			}
			*messages = append(*messages, openai.UserMessage(msg.Msg))
		case proto.ModelToUserMsgType:
			*messages = append(*messages, openai.AssistantMessage(msg.Msg))
		}
	}
	return nil
}
// TongyiToGeneralMassageAndSendMsgQueue converts one Tongyi stream chunk into
// the app's generic WS message envelope and publishes it on the user channel.
//
// Fix vs. the original: data.Choices[0] was indexed unconditionally, which
// panics on usage-only/final chunks whose Choices slice is empty.
//
// NOTE(review): Type is set to proto.ModelTypeSpark, carried over from the
// Spark implementation this was adapted from — confirm whether clients
// dispatch on this field and whether ModelTypeTongyi should be used instead.
func TongyiToGeneralMassageAndSendMsgQueue(data *openai.ChatCompletionChunk, channel string, SessionID int, UserID int) {
	var apiMsg api.GenerateResponse
	apiMsg.Model = proto.ModelTypeSpark
	// guard: chunks may arrive with no choices (e.g. the final usage chunk)
	if len(data.Choices) > 0 {
		apiMsg.Response = data.Choices[0].Delta.Content
	}
	// the API gives no explicit end flag; the caller sets Created < 0 to mark "done"
	apiMsg.Done = data.Created < 0

	// message-queue envelope
	var aiMsg proto.AIQueueMessage
	aiMsg.Type = proto.ModelTypeSpark
	aiMsg.Msg = apiMsg

	// websocket envelope
	var wsMsg proto.WSMessage
	wsMsg.Msg = aiMsg
	wsMsg.SessionID = SessionID
	wsMsg.ToID = UserID
	wsMsg.Type = proto.ModelTypeSpark

	wsMsgStr, _ := json.Marshal(wsMsg)
	worker.Publish(channel, string(wsMsgStr), time.Second*60)
}