diff --git a/dao/model.go b/dao/model.go index 56bda81..8568f58 100644 --- a/dao/model.go +++ b/dao/model.go @@ -23,6 +23,14 @@ type FunctionModel struct { //System string `gorm:"column:system"` //系统功能,对应模型系统参数 } +type ModelToken struct { + gorm.Model + ModelID uint `gorm:"column:model_id"` //模型id + UserID uint `gorm:"column:user_id"` //用户id + SessionID uint `gorm:"column:session_id"` //会话id + Token uint `gorm:"column:token"` //模型使用token +} + // 创建模型 func CreateModel(userID uint, modelType, url, parameter, description string) (error, uint) { model := Model{UserID: userID, Type: modelType, Url: url, Parameter: parameter, Description: description} @@ -157,3 +165,34 @@ func FindFunctionModelByFunction(function string) []FunctionModel { } return models } + +func CreateModelToken(userID uint, modelID uint, sessionID, token uint) (error, uint) { + modelToken := ModelToken{UserID: userID, ModelID: modelID, Token: token, SessionID: sessionID} + var res *gorm.DB + if proto.Config.SERVER_SQL_LOG { + res = DB.Debug().Create(&modelToken) + } else { + res = DB.Create(&modelToken) + } + return res.Error, modelToken.ID +} + +func FindModelTokenByModelID(modelID uint) []ModelToken { + var modelToken []ModelToken + if proto.Config.SERVER_SQL_LOG { + DB.Debug().Where("model_id = ?", modelID).Find(&modelToken) + } else { + DB.Where("model_id = ?", modelID).Find(&modelToken) + } + return modelToken +} + +func FindModelTokenByUserID(userID uint) []ModelToken { + var modelToken []ModelToken + if proto.Config.SERVER_SQL_LOG { + DB.Debug().Where("user_id = ?", userID).Find(&modelToken) + } else { + DB.Where("user_id = ?", userID).Find(&modelToken) + } + return modelToken +} diff --git a/service/doubao.go b/service/doubao.go index 4249e14..17679e1 100644 --- a/service/doubao.go +++ b/service/doubao.go @@ -108,6 +108,7 @@ func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { }(stream) answer := "" + tokens := 0 for { recv, err3 := stream.Recv() if err3 == io.EOF { @@ -123,6 +124,7 @@ func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { choices := recv.Choices[0] if choices.FinishReason == proto.FinishReasonStop { answer += choices.Delta.Content + tokens += recv.Usage.TotalTokens break } else { answer += choices.Delta.Content @@ -133,6 +135,12 @@ func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { if err != nil { log.Println("spark create message error:", err) } + //将本次请求的tokens存入数据库 + err2, _ := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens)) + if err2 != nil { + log.Println("create model token error:", err2) + } + } func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamResponse, channel string, SessionID, UserID int) { diff --git a/service/spark.go b/service/spark.go index b634574..fa5fd5c 100644 --- a/service/spark.go +++ b/service/spark.go @@ -211,6 +211,7 @@ func SparkV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { }() answer := "" + tokens := 0 //接收消息 for { @@ -233,6 +234,7 @@ func SparkV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { if data.Header.Status == 2 { //最后一个结果 answer += data.Payload.Choices.Text[0].Content + tokens += data.Payload.Usage.Text.TotalTokens } else { answer += data.Payload.Choices.Text[0].Content } @@ -248,6 +250,14 @@ func SparkV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { } else { log.Println("spark create message success,message id:", msgID) } + //存入tokens消耗信息 + err3, modelTokenID := dao.CreateModelToken(uint(imCtx.UserID), uint(imCtx.ModelID), uint(imCtx.SessionID), uint(tokens)) + if err3 != nil { + log.Println("spark create message error:", err3) + } else { + log.Println("spark create model tokens id:", modelTokenID) + } + } func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, SessionID int, UserID int) {