添加通义模型接入,添加消息的token保存,修改保存模型返回消息及消息的token
This commit is contained in:
parent
0c48fc83a9
commit
bdf374ba91
77
dao/im.go
77
dao/im.go
|
|
@ -24,9 +24,18 @@ type Message struct {
|
|||
ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断
|
||||
Msg string `gorm:"column:msg"` //消息内容
|
||||
FunctionID int `gorm:"column:function_id"` //功能id
|
||||
Token int `gorm:"column:token"` //消息消耗的token
|
||||
Status int `gorm:"column:status"` //0为未读,1为已读,3为图片消息,4为文件消息,5为制作ppt消息
|
||||
}
|
||||
|
||||
type MessageToken struct {
|
||||
gorm.Model
|
||||
SessionID int `gorm:"column:session_id"` //会话id
|
||||
UserID int `gorm:"column:user_id"` //用户id
|
||||
MessageID int `gorm:"column:message_id"` //消息id
|
||||
Token int `gorm:"column:token"` //消息消耗token
|
||||
}
|
||||
|
||||
func CreateSession(userID, Type int, name string) (error, uint) {
|
||||
session := Session{UserID: userID, Name: name, Type: Type}
|
||||
var res *gorm.DB
|
||||
|
|
@ -109,6 +118,15 @@ func CreateMessage(sessionID, fromID, toID, msgType, status int, msg string, fun
|
|||
}
|
||||
return res.Error, message.ID
|
||||
}
|
||||
func CreateMessageV2(message Message) (error, uint) {
|
||||
var res *gorm.DB
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
res = DB.Debug().Create(&message)
|
||||
} else {
|
||||
res = DB.Create(&message)
|
||||
}
|
||||
return res.Error, message.ID
|
||||
}
|
||||
|
||||
// 根据会话id获取消息
|
||||
func FindMessageBySessionID(sessionID int) []Message {
|
||||
|
|
@ -226,3 +244,62 @@ func FindMessageCountByDate() ([]SessionOrMessageCount, error) {
|
|||
}
|
||||
return messageCounts, nil
|
||||
}
|
||||
|
||||
func CreateMessageToken(sessionID, userID, messageID, token int) (error, uint) {
|
||||
messageToken := MessageToken{SessionID: sessionID, UserID: userID, MessageID: messageID, Token: token}
|
||||
var res *gorm.DB
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
res = DB.Debug().Create(&messageToken)
|
||||
} else {
|
||||
res = DB.Create(&messageToken)
|
||||
}
|
||||
return res.Error, messageToken.ID
|
||||
}
|
||||
func FindMessageTokenBySessionID(sessionID int) []MessageToken {
|
||||
var messageTokens []MessageToken
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
DB.Debug().Where("session_id = ?", sessionID).Find(&messageTokens)
|
||||
} else {
|
||||
DB.Where("session_id = ?", sessionID).Find(&messageTokens)
|
||||
}
|
||||
return messageTokens
|
||||
}
|
||||
|
||||
func FindMessageTokenByUserID(userID int) []MessageToken {
|
||||
var messageTokens []MessageToken
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
DB.Debug().Where("user_id = ?", userID).Find(&messageTokens)
|
||||
} else {
|
||||
DB.Where("user_id = ?", userID).Find(&messageTokens)
|
||||
}
|
||||
return messageTokens
|
||||
}
|
||||
func FindMessageTokenByMessageID(messageID int) []MessageToken {
|
||||
var messageTokens []MessageToken
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
DB.Debug().Where("message_id = ?", messageID).Find(&messageTokens)
|
||||
} else {
|
||||
DB.Where("message_id = ?", messageID).Find(&messageTokens)
|
||||
}
|
||||
return messageTokens
|
||||
}
|
||||
|
||||
func UpdateMessageTokenByID(messageID int, token int) error {
|
||||
var res *gorm.DB
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
res = DB.Debug().Model(&MessageToken{}).Where("id = ?", messageID).Update("token", token)
|
||||
} else {
|
||||
res = DB.Model(&MessageToken{}).Where("id = ?", messageID).Update("token", token)
|
||||
}
|
||||
return res.Error
|
||||
}
|
||||
|
||||
func DeleteMessageTokenByID(messageID int) error {
|
||||
var res *gorm.DB
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
res = DB.Debug().Delete(&MessageToken{}, messageID)
|
||||
} else {
|
||||
res = DB.Delete(&MessageToken{}, messageID)
|
||||
}
|
||||
return res.Error
|
||||
}
|
||||
|
|
|
|||
5
go.mod
5
go.mod
|
|
@ -12,6 +12,7 @@ require (
|
|||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/ollama/ollama v0.6.2
|
||||
github.com/openai/openai-go v0.1.0-beta.10
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
github.com/volcengine/volcengine-go-sdk v1.0.187
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
|
|
@ -62,6 +63,10 @@ require (
|
|||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
github.com/spf13/viper v1.18.2 // indirect
|
||||
github.com/subosito/gotenv v1.6.0 // indirect
|
||||
github.com/tidwall/gjson v1.14.4 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/volcengine/volc-sdk-golang v1.0.23 // indirect
|
||||
|
|
|
|||
12
go.sum
12
go.sum
|
|
@ -127,6 +127,8 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
|||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
github.com/onsi/gomega v1.18.1/go.mod h1:0q+aL8jAiMXy9hbwj2mr5GziHiwhAIQpFmmtT5hitRs=
|
||||
github.com/openai/openai-go v0.1.0-beta.10 h1:CknhGXe8aXQMRuqg255PFnWzgRY9nEryMxoNIBBM9tU=
|
||||
github.com/openai/openai-go v0.1.0-beta.10/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
|
|
@ -169,6 +171,16 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
|
|||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
|
||||
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
|
||||
github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||
|
|
|
|||
|
|
@ -271,6 +271,7 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
|||
if err != nil {
|
||||
return err
|
||||
} //机器人id为2,消息类型为1,状态为1
|
||||
imContext.QuestionMessageID = int(userMsgID)
|
||||
|
||||
//若为知识库消息,需要先根据用户消息获取知识库查询内容让ai进行最终结果生成
|
||||
if data.IsKBase {
|
||||
|
|
|
|||
17
proto/im.go
17
proto/im.go
|
|
@ -78,14 +78,15 @@ type SessionResponse struct {
|
|||
}
|
||||
|
||||
type IMParamContext struct {
|
||||
UserID int `json:"user_id"` //用户id
|
||||
SessionID int `json:"session_id"` //会话id
|
||||
FunctionID int `json:"function_id"` //功能id
|
||||
ModelID int `json:"model_id"` //模型id
|
||||
ModelType string `json:"model_type"` //模型类型
|
||||
Question string `json:"question"` //问题
|
||||
Channel string `json:"channel"` //消息队列
|
||||
IsHasImage bool `json:"is_has_image"` //是否有图片
|
||||
UserID int `json:"user_id"` //用户id
|
||||
SessionID int `json:"session_id"` //会话id
|
||||
FunctionID int `json:"function_id"` //功能id
|
||||
ModelID int `json:"model_id"` //模型id
|
||||
ModelType string `json:"model_type"` //模型类型
|
||||
QuestionMessageID int `json:"question_message_id"` //问题消息id,用于后面的消息token消耗保存
|
||||
Question string `json:"question"` //问题
|
||||
Channel string `json:"channel"` //消息队列
|
||||
IsHasImage bool `json:"is_has_image"` //是否有图片
|
||||
}
|
||||
|
||||
type ImageMsgContent struct {
|
||||
|
|
|
|||
|
|
@ -145,6 +145,7 @@ const (
|
|||
ModelTypeDouBao = "doubao"
|
||||
ModelTypeOllama = "ollama"
|
||||
ModelTypeQianfan = "qianfan"
|
||||
ModelTypeTongyi = "tongyi"
|
||||
KnowledgeBaseServerResponseType = "kbase_query_resp"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -136,6 +136,8 @@ func WSReceiveMessageServiceV2(modelParam proto.ModelParam, ctx *proto.IMParamCo
|
|||
DouBaoV2(modelParam, ctx)
|
||||
case proto.ModelTypeQianfan:
|
||||
Qianfan(modelParam, ctx)
|
||||
case proto.ModelTypeTongyi:
|
||||
Tongyi(modelParam, ctx)
|
||||
default:
|
||||
resErr, resID = errors.New("model type not exist:"+ctx.ModelType), 0
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,150 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"StuAcaWorksAI/dao"
|
||||
"StuAcaWorksAI/proto"
|
||||
"StuAcaWorksAI/worker"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/openai/openai-go"
|
||||
"github.com/openai/openai-go/option"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 通义模型
|
||||
func Tongyi(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
|
||||
tongyiApiKey := modelParam.APIKey
|
||||
tongyiClient := openai.NewClient(
|
||||
option.WithAPIKey(tongyiApiKey),
|
||||
option.WithBaseURL(modelParam.Url),
|
||||
)
|
||||
var messages []openai.ChatCompletionMessageParamUnion
|
||||
err := GetTongyiHistoryMessages(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID)
|
||||
if err != nil {
|
||||
println("GetTongyiHistoryMessages error:", err)
|
||||
return
|
||||
}
|
||||
if proto.Config.LOG_OUTPUT {
|
||||
messagesBytes, _ := json.Marshal(messages)
|
||||
println("tongyi messages:", string(messagesBytes))
|
||||
}
|
||||
|
||||
stream := tongyiClient.Chat.Completions.NewStreaming(context.TODO(), openai.ChatCompletionNewParams{
|
||||
Messages: messages,
|
||||
Model: modelParam.Model,
|
||||
})
|
||||
acc := openai.ChatCompletionAccumulator{}
|
||||
answer := ""
|
||||
var tokens int64
|
||||
var PromptTokens int64 //用户问题token
|
||||
|
||||
for stream.Next() {
|
||||
chunk := stream.Current()
|
||||
go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, chunk, modelParam.Model)
|
||||
acc.AddChunk(chunk)
|
||||
if _, ok := acc.JustFinishedToolCall(); ok {
|
||||
chunk.Created = -1 // 这里由于没有返回结束标志,所以用该字段来判断
|
||||
}
|
||||
if len(chunk.Choices) > 0 {
|
||||
println(chunk.Choices[0].Delta.Content)
|
||||
answer += chunk.Choices[0].Delta.Content
|
||||
tokens = chunk.Usage.TotalTokens
|
||||
PromptTokens = chunk.Usage.PromptTokens
|
||||
}
|
||||
TongyiToGeneralMassageAndSendMsgQueue(&chunk, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
|
||||
}
|
||||
if stream.Err() != nil {
|
||||
log.Println("tongyi stream err:", stream.Err())
|
||||
}
|
||||
//将消息存入数据库
|
||||
DOSaveMessageAndToken(imCtx, answer, "tongyi", tokens, PromptTokens)
|
||||
}
|
||||
|
||||
func DOSaveMessageAndToken(imCtx *proto.IMParamContext, answer, platform string, tokens, PromptTokens int64) {
|
||||
msg := dao.Message{SessionID: imCtx.SessionID, FromID: imCtx.UserID, ToID: imCtx.ModelID, Msg: answer, Type: proto.ModelToUserMsgType, FunctionID: imCtx.FunctionID, Token: int(tokens), Status: proto.UserToModelMsgType}
|
||||
err2, msgID := dao.CreateMessageV2(msg)
|
||||
if err2 != nil {
|
||||
log.Printf("%s create message error:%v\n", platform, err2)
|
||||
} else {
|
||||
log.Printf("%s create message success,message id:%d\n", platform, msgID)
|
||||
}
|
||||
//保存用户问题消耗token
|
||||
err, messageTokenID := dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(PromptTokens))
|
||||
if err != nil {
|
||||
log.Printf("%s user prompt create message token error:%v\n", platform, err)
|
||||
} else {
|
||||
log.Printf("%s user prompt create message token id:%d\n", platform, messageTokenID)
|
||||
}
|
||||
|
||||
//保存返回的消息token消耗
|
||||
err, messageTokenID = dao.CreateMessageToken(imCtx.SessionID, imCtx.UserID, int(msgID), int(tokens))
|
||||
if err != nil {
|
||||
log.Printf("%s create message token error:%v\n", platform, err)
|
||||
} else {
|
||||
log.Printf("%s create message token id:%d\n", platform, messageTokenID)
|
||||
}
|
||||
//存入tokens消耗信息
|
||||
err3, modelTokenID := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens))
|
||||
if err3 != nil {
|
||||
log.Printf("%s create model tokens error:%v\n", platform, err3)
|
||||
} else {
|
||||
log.Printf("%s create model tokens success,message id:%d\n", platform, modelTokenID)
|
||||
}
|
||||
}
|
||||
|
||||
func GetTongyiHistoryMessages(sessionID int, systemPrompt string, messages *[]openai.ChatCompletionMessageParamUnion, userID int) error {
|
||||
//获取会话的历史消息
|
||||
msgs := dao.FindMessageBySessionIDV2(sessionID)
|
||||
if len(msgs) > proto.SparkContextLength {
|
||||
//只取最后的 SparkContextLength 条消息
|
||||
msgs = msgs[len(msgs)-proto.SparkContextLength-1:]
|
||||
}
|
||||
//添加系统消息
|
||||
if systemPrompt != "" {
|
||||
*messages = append(*messages, openai.SystemMessage(systemPrompt))
|
||||
}
|
||||
for _, msg := range msgs {
|
||||
if msg.Type == proto.UserToModelMsgType {
|
||||
if msg.Status == proto.UserToModelFileMsgType {
|
||||
//err2 := decodeFileMessageToSparkMessage(msg.Msg, userID, messages)
|
||||
//if err2 != nil {
|
||||
// log.Println("spark decode file message error:", err2)
|
||||
//}
|
||||
} else {
|
||||
*messages = append(*messages, openai.UserMessage(msg.Msg))
|
||||
}
|
||||
} else if msg.Type == proto.ModelToUserMsgType {
|
||||
*messages = append(*messages, openai.AssistantMessage(msg.Msg))
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TongyiToGeneralMassageAndSendMsgQueue(data *openai.ChatCompletionChunk, channel string, SessionID int, UserID int) {
|
||||
var aiMsg proto.AIQueueMessage
|
||||
var wsMsg proto.WSMessage
|
||||
var apiMsg api.GenerateResponse
|
||||
apiMsg.Model = proto.ModelTypeSpark
|
||||
apiMsg.Response = data.Choices[0].Delta.Content
|
||||
if data.Created < 0 { //这里由于没有返回结束标志,所以用该字段来判断
|
||||
apiMsg.Done = true
|
||||
} else {
|
||||
apiMsg.Done = false
|
||||
}
|
||||
//消息队列部分
|
||||
aiMsg.Type = proto.ModelTypeSpark
|
||||
aiMsg.Msg = apiMsg
|
||||
//ws发送消息部分
|
||||
wsMsg.Msg = aiMsg
|
||||
wsMsg.SessionID = SessionID
|
||||
wsMsg.ToID = UserID
|
||||
wsMsg.Type = proto.ModelTypeSpark
|
||||
|
||||
//发送消息
|
||||
wsMsgStr, _ := json.Marshal(wsMsg)
|
||||
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
||||
}
|
||||
Loading…
Reference in New Issue