From d3d3e6f594e6e1b4deda69b4bc0fbe95589730a3 Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Tue, 22 Apr 2025 16:46:19 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=A8=A1=E5=9E=8B=E5=8F=8Ate?= =?UTF-8?q?mperature=E5=8F=82=E6=95=B0=E5=8F=8A=E9=99=84=E5=B1=9E=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- handler/im.go | 30 ++++++++++++++++++++++++++++++ proto/im.go | 33 ++++++++++++++++++--------------- service/doubao.go | 17 +++++++++++++---- service/spark.go | 36 +++++++++++++++++++++++++++++++++++- 4 files changed, 96 insertions(+), 20 deletions(-) diff --git a/handler/im.go b/handler/im.go index c103739..c51fc1f 100644 --- a/handler/im.go +++ b/handler/im.go @@ -256,6 +256,8 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR } modelParam.Url = model.Url modelParam.System = funcs[0].Info //系统功能 + //设置模型temperature及其他参数 + checkAndSetModelTemperatureParam(&modelParam, data, &model) imContext := proto.IMParamContext{UserID: userId, SessionID: int(*sessionID), FunctionID: int(funcs[0].ID), ModelID: int(model.ID), Question: data.Msg, Channel: chanel, ModelType: model.Type} var userMsgID uint @@ -381,3 +383,31 @@ func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseSe } } } + +func checkAndSetModelTemperatureParam(modelParam *proto.ModelParam, data *proto.WSMessageReq, model *dao.Model) { + isErrParam := false + if data.Temperature < 0 || data.Temperature > 2 { + isErrParam = true + } else { + modelParam.Temperature = data.Temperature + } + if model.Type == proto.ModelTypeDouBao { + if data.TopP < 0 || data.TopP > 1 { + isErrParam = true + } else { + modelParam.TopP = data.TopP + } + } else if model.Type == proto.ModelTypeSpark { + if data.TopK < 0 || data.TopK > 1 { + isErrParam = true + } else { + modelParam.TopK = data.TopK + } + } else { + isErrParam = true + } + + if isErrParam { + modelParam.SetDefaultParams() + } +} diff --git a/proto/im.go b/proto/im.go index 5762408..bf77294 100644 --- a/proto/im.go +++ b/proto/im.go @@ -20,17 +20,20 @@ type WSMessage struct { } type WSMessageReq struct { - Type string `json:"type"` //接收及发送消息类型 - Function string `json:"function"` //功能名称 - Msg string `json:"msg"` //消息内容,只进行转发,不做处理 - SessionID int `json:"session_id"` //应用层会话id - ToID int `json:"to_id"` //接收者id - ModelID uint `json:"model_id"` //模型id - IsImage bool `json:"is_image"` //是否为图片(图片消息),如果是图片消息,则msg为图片消息结构为{"img_url":"","text":""} - IsFile bool `json:"is_file"` //是否为文件(文件消息),如果是文件消息,则msg为文件消息结构为{"file_url":"","text":""} - IsKBase bool `json:"is_kbase"` //是否为知识库消息 - KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传 - KBaseType string `json:"kbase_type"` //知识库模式,1为query 2为chat + Type string `json:"type"` //接收及发送消息类型 + Function string `json:"function"` //功能名称 + Msg string `json:"msg"` //消息内容,只进行转发,不做处理 + SessionID int `json:"session_id"` //应用层会话id + ToID int `json:"to_id"` //接收者id + ModelID uint `json:"model_id"` //模型id + IsImage bool `json:"is_image"` //是否为图片(图片消息),如果是图片消息,则msg为图片消息结构为{"img_url":"","text":""} + IsFile bool `json:"is_file"` //是否为文件(文件消息),如果是文件消息,则msg为文件消息结构为{"file_url":"","text":""} + IsKBase bool `json:"is_kbase"` //是否为知识库消息 + KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传 + KBaseType string `json:"kbase_type"` //知识库模式,1为query 2为chat + Temperature float32 `json:"temperature"` //温度,默认0.5,核采样阈值。取值越高随机性越强,即相同的问题得到的不同答案的可能性越大 + TopP float32 `json:"top_p"` //doubao 模型的top_p参数 + TopK float32 `json:"top_k"` //spark 模型的top_k参数 } type ModelParam struct { @@ -40,12 +43,12 @@ type ModelParam struct { APPID string `json:"appid"` //应用id APISecret string `json:"apiSecret"` //应用密钥 APIKey string `json:"apiKey"` //应用key - System string `json:"system"` // + System string `json:"system"` //系统参数 //其他参数 - Temperature float64 `json:"temperature"` //温度,默认0.5,核采样阈值。取值越高随机性越强,即相同的问题得到的不同答案的可能性越大 + Temperature float32 `json:"temperature"` //温度,默认0.5,核采样阈值。取值越高随机性越强,即相同的问题得到的不同答案的可能性越大 MaxTokens int `json:"maxTokens"` //最大生成长度,默认4096 - TopK float64 `json:"topK"` //取概率最大的前k个词 - TopP float64 `json:"topP"` //取 + TopK float32 `json:"topK"` //取概率最大的前k个词 + TopP float32 `json:"topP"` //取 } func (m *ModelParam) SetDefaultParams() { diff --git a/service/doubao.go b/service/doubao.go index 7a68dff..4dec8ac 100644 --- a/service/doubao.go +++ b/service/doubao.go @@ -90,10 +90,14 @@ func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { if err != nil { log.Println("get doubao session history message error:", err) } + maxTokens := 10240 req := model.CreateChatCompletionRequest{ - Model: modelParam.Model, - Messages: messages, + Model: modelParam.Model, + Temperature: &modelParam.Temperature, + TopP: &modelParam.TopP, + MaxTokens: &maxTokens, //设置最大token数设置为10k + Messages: messages, StreamOptions: &model.StreamOptions{ IncludeUsage: true, }, @@ -165,8 +169,13 @@ func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamRespo apiMsg.Model = proto.ModelTypeDouBao choices := data.Choices[0] - apiMsg.Response = choices.Delta.Content - if choices.FinishReason == proto.FinishReasonStop { + + if choices.Delta.Content != "" { + apiMsg.Response = choices.Delta.Content + } else if choices.Delta.ReasoningContent != nil { + apiMsg.Response = *choices.Delta.ReasoningContent //深度思考部分 + } + if choices.FinishReason == proto.FinishReasonStop || choices.FinishReason == proto.FinishReasonLength { //停止原因是stop或长度超限 apiMsg.Done = true } else { apiMsg.Done = false diff --git a/service/spark.go b/service/spark.go index 17762ce..226bf41 100644 --- a/service/spark.go +++ b/service/spark.go @@ -199,7 +199,7 @@ func SparkV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { //发送消息 go func() { - data := genSparkParams(imCtx.UserID, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System) + data := genSparkParamsV2(imCtx.UserID, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System, float64(modelParam.Temperature)) //将数据转换为json //dataByte, err3 := json.Marshal(data) //if err3 != nil { @@ -323,6 +323,40 @@ func genSparkParams(userID int, appid string, domain string, sessionID int, syst return data } +// 生成参数 +func genSparkParamsV2(userID int, appid string, domain string, sessionID int, systemPrompt string, temperature float64) SparkRequest { + //messages := []MessageContent{ + // {Role: "user", Content: question}, + //} + messages2 := make([]MessageContent, 0) + //err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question) + err := GetSparkSessionHistoryMsgV2(sessionID, systemPrompt, &messages2, userID) + if err != nil { + log.Println("get spark session history message error:", err) + } + //messages2Str, _ := json.Marshal(messages2) + //log.Println("spark messages2:", string(messages2Str)) + + data := SparkRequest{ + Header: SparkRequestHeader{ + AppID: appid, + }, + Parameter: Parameter{ + Chat: ChatParameter{ + Domain: domain, + Temperature: temperature, + MaxTokens: 10240, //最大10k + }, + }, + Payload: SparkRequestPayload{ + Message: SparkMessage{ + Text: messages2, + }, + }, + } + return data +} + // 创建鉴权url apikey 即 hmac username func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string { ul, err := url.Parse(hostUrl)