添加模型及temperature参数及附属参数设置

This commit is contained in:
junleea 2025-04-22 16:46:19 +08:00
parent ed030e3acd
commit d3d3e6f594
4 changed files with 96 additions and 20 deletions

View File

@ -256,6 +256,8 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
} }
modelParam.Url = model.Url modelParam.Url = model.Url
modelParam.System = funcs[0].Info //系统功能 modelParam.System = funcs[0].Info //系统功能
//设置模型temperature及其他参数
checkAndSetModelTemperatureParam(&modelParam, data, &model)
imContext := proto.IMParamContext{UserID: userId, SessionID: int(*sessionID), FunctionID: int(funcs[0].ID), ModelID: int(model.ID), Question: data.Msg, Channel: chanel, ModelType: model.Type} imContext := proto.IMParamContext{UserID: userId, SessionID: int(*sessionID), FunctionID: int(funcs[0].ID), ModelID: int(model.ID), Question: data.Msg, Channel: chanel, ModelType: model.Type}
var userMsgID uint var userMsgID uint
@ -381,3 +383,31 @@ func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseSe
} }
} }
} }
func checkAndSetModelTemperatureParam(modelParam *proto.ModelParam, data *proto.WSMessageReq, model *dao.Model) {
isErrParam := false
if data.Temperature < 0 || data.Temperature > 2 {
isErrParam = true
} else {
modelParam.Temperature = data.Temperature
}
if model.Type == proto.ModelTypeDouBao {
if data.TopP < 0 || data.TopP > 1 {
isErrParam = true
} else {
modelParam.TopP = data.TopP
}
} else if model.Type == proto.ModelTypeSpark {
if data.TopK < 0 || data.TopK > 1 {
isErrParam = true
} else {
modelParam.TopK = data.TopK
}
} else {
isErrParam = true
}
if isErrParam {
modelParam.SetDefaultParams()
}
}

View File

@ -31,6 +31,9 @@ type WSMessageReq struct {
IsKBase bool `json:"is_kbase"` //是否为知识库消息 IsKBase bool `json:"is_kbase"` //是否为知识库消息
KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传 KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传
KBaseType string `json:"kbase_type"` //知识库模式1为query 2为chat KBaseType string `json:"kbase_type"` //知识库模式1为query 2为chat
Temperature float32 `json:"temperature"` //温度,默认0.5,核采样阈值。取值越高随机性越强,即相同的问题得到的不同答案的可能性越大
TopP float32 `json:"top_p"` //doubao 模型的top_p参数
TopK float32 `json:"top_k"` //spark 模型的top_k参数
} }
type ModelParam struct { type ModelParam struct {
@ -40,12 +43,12 @@ type ModelParam struct {
APPID string `json:"appid"` //应用id APPID string `json:"appid"` //应用id
APISecret string `json:"apiSecret"` //应用密钥 APISecret string `json:"apiSecret"` //应用密钥
APIKey string `json:"apiKey"` //应用key APIKey string `json:"apiKey"` //应用key
System string `json:"system"` // System string `json:"system"` //系统参数
//其他参数 //其他参数
Temperature float64 `json:"temperature"` //温度,默认0.5,核采样阈值。取值越高随机性越强,即相同的问题得到的不同答案的可能性越大 Temperature float32 `json:"temperature"` //温度,默认0.5,核采样阈值。取值越高随机性越强,即相同的问题得到的不同答案的可能性越大
MaxTokens int `json:"maxTokens"` //最大生成长度,默认4096 MaxTokens int `json:"maxTokens"` //最大生成长度,默认4096
TopK float64 `json:"topK"` //取概率最大的前k个词 TopK float32 `json:"topK"` //取概率最大的前k个词
TopP float64 `json:"topP"` //取 TopP float32 `json:"topP"` //取
} }
func (m *ModelParam) SetDefaultParams() { func (m *ModelParam) SetDefaultParams() {

View File

@ -90,9 +90,13 @@ func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
if err != nil { if err != nil {
log.Println("get doubao session history message error:", err) log.Println("get doubao session history message error:", err)
} }
maxTokens := 10240
req := model.CreateChatCompletionRequest{ req := model.CreateChatCompletionRequest{
Model: modelParam.Model, Model: modelParam.Model,
Temperature: &modelParam.Temperature,
TopP: &modelParam.TopP,
MaxTokens: &maxTokens, //设置最大token数设置为10k
Messages: messages, Messages: messages,
StreamOptions: &model.StreamOptions{ StreamOptions: &model.StreamOptions{
IncludeUsage: true, IncludeUsage: true,
@ -165,8 +169,13 @@ func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamRespo
apiMsg.Model = proto.ModelTypeDouBao apiMsg.Model = proto.ModelTypeDouBao
choices := data.Choices[0] choices := data.Choices[0]
if choices.Delta.Content != "" {
apiMsg.Response = choices.Delta.Content apiMsg.Response = choices.Delta.Content
if choices.FinishReason == proto.FinishReasonStop { } else if choices.Delta.ReasoningContent != nil {
apiMsg.Response = *choices.Delta.ReasoningContent //深度思考部分
}
if choices.FinishReason == proto.FinishReasonStop || choices.FinishReason == proto.FinishReasonLength { //停止原因是stop或长度超限
apiMsg.Done = true apiMsg.Done = true
} else { } else {
apiMsg.Done = false apiMsg.Done = false

View File

@ -199,7 +199,7 @@ func SparkV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
//发送消息 //发送消息
go func() { go func() {
data := genSparkParams(imCtx.UserID, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System) data := genSparkParamsV2(imCtx.UserID, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System, float64(modelParam.Temperature))
//将数据转换为json //将数据转换为json
//dataByte, err3 := json.Marshal(data) //dataByte, err3 := json.Marshal(data)
//if err3 != nil { //if err3 != nil {
@ -323,6 +323,40 @@ func genSparkParams(userID int, appid string, domain string, sessionID int, syst
return data return data
} }
// 生成参数
func genSparkParamsV2(userID int, appid string, domain string, sessionID int, systemPrompt string, temperature float64) SparkRequest {
//messages := []MessageContent{
// {Role: "user", Content: question},
//}
messages2 := make([]MessageContent, 0)
//err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question)
err := GetSparkSessionHistoryMsgV2(sessionID, systemPrompt, &messages2, userID)
if err != nil {
log.Println("get spark session history message error:", err)
}
//messages2Str, _ := json.Marshal(messages2)
//log.Println("spark messages2:", string(messages2Str))
data := SparkRequest{
Header: SparkRequestHeader{
AppID: appid,
},
Parameter: Parameter{
Chat: ChatParameter{
Domain: domain,
Temperature: temperature,
MaxTokens: 10240, //最大10k
},
},
Payload: SparkRequestPayload{
Message: SparkMessage{
Text: messages2,
},
},
}
return data
}
// 创建鉴权url apikey 即 hmac username // 创建鉴权url apikey 即 hmac username
func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string { func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string {
ul, err := url.Parse(hostUrl) ul, err := url.Parse(hostUrl)