添加模型及temperature参数及附属参数设置

This commit is contained in:
junleea 2025-04-22 16:46:19 +08:00
parent ed030e3acd
commit d3d3e6f594
4 changed files with 96 additions and 20 deletions

View File

@ -256,6 +256,8 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
} }
modelParam.Url = model.Url modelParam.Url = model.Url
modelParam.System = funcs[0].Info //系统功能 modelParam.System = funcs[0].Info //系统功能
//设置模型temperature及其他参数
checkAndSetModelTemperatureParam(&modelParam, data, &model)
imContext := proto.IMParamContext{UserID: userId, SessionID: int(*sessionID), FunctionID: int(funcs[0].ID), ModelID: int(model.ID), Question: data.Msg, Channel: chanel, ModelType: model.Type} imContext := proto.IMParamContext{UserID: userId, SessionID: int(*sessionID), FunctionID: int(funcs[0].ID), ModelID: int(model.ID), Question: data.Msg, Channel: chanel, ModelType: model.Type}
var userMsgID uint var userMsgID uint
@ -381,3 +383,31 @@ func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseSe
} }
} }
} }
// checkAndSetModelTemperatureParam range-checks the sampling parameters sent
// by the client in data and copies the accepted ones into modelParam.
//
// Accepted ranges:
//   - Temperature: [0, 2], for every model type.
//   - TopP: [0, 1], only read when model.Type is proto.ModelTypeDouBao.
//   - TopK: [0, 1], only read when model.Type is proto.ModelTypeSpark.
//
// If any checked value is out of range, or model.Type is neither of the two
// types above, modelParam.SetDefaultParams() is called at the end. That call
// happens after the in-range values were already copied, so what survives
// depends on SetDefaultParams (not visible here).
//
// NOTE(review): every model type other than DouBao/Spark always ends up on the
// defaults path, even with a valid temperature — confirm this is intended.
// NOTE(review): TopK is checked against [0, 1]; verify this matches the range
// the Spark API actually expects for top_k.
func checkAndSetModelTemperatureParam(modelParam *proto.ModelParam, data *proto.WSMessageReq, model *dao.Model) {
	// outOfRange reports whether v lies outside the closed interval [lo, hi].
	outOfRange := func(v, lo, hi float32) bool {
		return v < lo || v > hi
	}

	allValid := true

	if outOfRange(data.Temperature, 0, 2) {
		allValid = false
	} else {
		modelParam.Temperature = data.Temperature
	}

	switch model.Type {
	case proto.ModelTypeDouBao:
		if outOfRange(data.TopP, 0, 1) {
			allValid = false
		} else {
			modelParam.TopP = data.TopP
		}
	case proto.ModelTypeSpark:
		if outOfRange(data.TopK, 0, 1) {
			allValid = false
		} else {
			modelParam.TopK = data.TopK
		}
	default:
		// No extra sampling parameter is handled for other model types.
		allValid = false
	}

	if !allValid {
		modelParam.SetDefaultParams()
	}
}

View File

@ -20,17 +20,20 @@ type WSMessage struct {
} }
type WSMessageReq struct { type WSMessageReq struct {
Type string `json:"type"` //接收及发送消息类型 Type string `json:"type"` //接收及发送消息类型
Function string `json:"function"` //功能名称 Function string `json:"function"` //功能名称
Msg string `json:"msg"` //消息内容,只进行转发,不做处理 Msg string `json:"msg"` //消息内容,只进行转发,不做处理
SessionID int `json:"session_id"` //应用层会话id SessionID int `json:"session_id"` //应用层会话id
ToID int `json:"to_id"` //接收者id ToID int `json:"to_id"` //接收者id
ModelID uint `json:"model_id"` //模型id ModelID uint `json:"model_id"` //模型id
IsImage bool `json:"is_image"` //是否为图片(图片消息)如果是图片消息则msg为图片消息结构为{"img_url":"","text":""} IsImage bool `json:"is_image"` //是否为图片(图片消息)如果是图片消息则msg为图片消息结构为{"img_url":"","text":""}
IsFile bool `json:"is_file"` //是否为文件(文件消息)如果是文件消息则msg为文件消息结构为{"file_url":"","text":""} IsFile bool `json:"is_file"` //是否为文件(文件消息)如果是文件消息则msg为文件消息结构为{"file_url":"","text":""}
IsKBase bool `json:"is_kbase"` //是否为知识库消息 IsKBase bool `json:"is_kbase"` //是否为知识库消息
KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传 KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传
KBaseType string `json:"kbase_type"` //知识库模式1为query 2为chat KBaseType string `json:"kbase_type"` //知识库模式1为query 2为chat
Temperature float32 `json:"temperature"` //温度,默认0.5,核采样阈值。取值越高随机性越强,即相同的问题得到的不同答案的可能性越大
TopP float32 `json:"top_p"` //doubao 模型的top_p参数
TopK float32 `json:"top_k"` //spark 模型的top_k参数
} }
type ModelParam struct { type ModelParam struct {
@ -40,12 +43,12 @@ type ModelParam struct {
APPID string `json:"appid"` //应用id APPID string `json:"appid"` //应用id
APISecret string `json:"apiSecret"` //应用密钥 APISecret string `json:"apiSecret"` //应用密钥
APIKey string `json:"apiKey"` //应用key APIKey string `json:"apiKey"` //应用key
System string `json:"system"` // System string `json:"system"` //系统参数
//其他参数 //其他参数
Temperature float64 `json:"temperature"` //温度,默认0.5,核采样阈值。取值越高随机性越强,即相同的问题得到的不同答案的可能性越大 Temperature float32 `json:"temperature"` //温度,默认0.5,核采样阈值。取值越高随机性越强,即相同的问题得到的不同答案的可能性越大
MaxTokens int `json:"maxTokens"` //最大生成长度,默认4096 MaxTokens int `json:"maxTokens"` //最大生成长度,默认4096
TopK float64 `json:"topK"` //取概率最大的前k个词 TopK float32 `json:"topK"` //取概率最大的前k个词
TopP float64 `json:"topP"` //取 TopP float32 `json:"topP"` //取
} }
func (m *ModelParam) SetDefaultParams() { func (m *ModelParam) SetDefaultParams() {

View File

@ -90,10 +90,14 @@ func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
if err != nil { if err != nil {
log.Println("get doubao session history message error:", err) log.Println("get doubao session history message error:", err)
} }
maxTokens := 10240
req := model.CreateChatCompletionRequest{ req := model.CreateChatCompletionRequest{
Model: modelParam.Model, Model: modelParam.Model,
Messages: messages, Temperature: &modelParam.Temperature,
TopP: &modelParam.TopP,
MaxTokens: &maxTokens, //设置最大token数设置为10k
Messages: messages,
StreamOptions: &model.StreamOptions{ StreamOptions: &model.StreamOptions{
IncludeUsage: true, IncludeUsage: true,
}, },
@ -165,8 +169,13 @@ func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamRespo
apiMsg.Model = proto.ModelTypeDouBao apiMsg.Model = proto.ModelTypeDouBao
choices := data.Choices[0] choices := data.Choices[0]
apiMsg.Response = choices.Delta.Content
if choices.FinishReason == proto.FinishReasonStop { if choices.Delta.Content != "" {
apiMsg.Response = choices.Delta.Content
} else if choices.Delta.ReasoningContent != nil {
apiMsg.Response = *choices.Delta.ReasoningContent //深度思考部分
}
if choices.FinishReason == proto.FinishReasonStop || choices.FinishReason == proto.FinishReasonLength { //停止原因是stop或长度超限
apiMsg.Done = true apiMsg.Done = true
} else { } else {
apiMsg.Done = false apiMsg.Done = false

View File

@ -199,7 +199,7 @@ func SparkV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
//发送消息 //发送消息
go func() { go func() {
data := genSparkParams(imCtx.UserID, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System) data := genSparkParamsV2(imCtx.UserID, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System, float64(modelParam.Temperature))
//将数据转换为json //将数据转换为json
//dataByte, err3 := json.Marshal(data) //dataByte, err3 := json.Marshal(data)
//if err3 != nil { //if err3 != nil {
@ -323,6 +323,40 @@ func genSparkParams(userID int, appid string, domain string, sessionID int, syst
return data return data
} }
// genSparkParamsV2 builds the SparkRequest for one chat turn.
//
// The message list is filled by GetSparkSessionHistoryMsgV2 from sessionID,
// systemPrompt and userID. If that call fails, the error is only logged and
// the request is still returned with whatever messages were collected.
// appid goes into the request header, domain and temperature into the chat
// parameters, and MaxTokens is fixed at 10240.
func genSparkParamsV2(userID int, appid string, domain string, sessionID int, systemPrompt string, temperature float64) SparkRequest {
	// Kept as a non-nil empty slice (not a nil slice) so that it would encode
	// as [] rather than null if the request is marshalled to JSON.
	history := make([]MessageContent, 0)
	if err := GetSparkSessionHistoryMsgV2(sessionID, systemPrompt, &history, userID); err != nil {
		log.Println("get spark session history message error:", err)
	}

	var req SparkRequest
	req.Header.AppID = appid
	req.Parameter.Chat.Domain = domain
	req.Parameter.Chat.Temperature = temperature
	req.Parameter.Chat.MaxTokens = 10240 // upper bound of 10k tokens
	req.Payload.Message.Text = history
	return req
}
// 创建鉴权url apikey 即 hmac username // 创建鉴权url apikey 即 hmac username
func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string { func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string {
ul, err := url.Parse(hostUrl) ul, err := url.Parse(hostUrl)