添加模型及temperature参数及附属参数设置
This commit is contained in:
parent
ed030e3acd
commit
d3d3e6f594
|
|
@ -256,6 +256,8 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
||||||
}
|
}
|
||||||
modelParam.Url = model.Url
|
modelParam.Url = model.Url
|
||||||
modelParam.System = funcs[0].Info //系统功能
|
modelParam.System = funcs[0].Info //系统功能
|
||||||
|
//设置模型temperature及其他参数
|
||||||
|
checkAndSetModelTemperatureParam(&modelParam, data, &model)
|
||||||
imContext := proto.IMParamContext{UserID: userId, SessionID: int(*sessionID), FunctionID: int(funcs[0].ID), ModelID: int(model.ID), Question: data.Msg, Channel: chanel, ModelType: model.Type}
|
imContext := proto.IMParamContext{UserID: userId, SessionID: int(*sessionID), FunctionID: int(funcs[0].ID), ModelID: int(model.ID), Question: data.Msg, Channel: chanel, ModelType: model.Type}
|
||||||
|
|
||||||
var userMsgID uint
|
var userMsgID uint
|
||||||
|
|
@ -381,3 +383,31 @@ func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseSe
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func checkAndSetModelTemperatureParam(modelParam *proto.ModelParam, data *proto.WSMessageReq, model *dao.Model) {
|
||||||
|
isErrParam := false
|
||||||
|
if data.Temperature < 0 || data.Temperature > 2 {
|
||||||
|
isErrParam = true
|
||||||
|
} else {
|
||||||
|
modelParam.Temperature = data.Temperature
|
||||||
|
}
|
||||||
|
if model.Type == proto.ModelTypeDouBao {
|
||||||
|
if data.TopP < 0 || data.TopP > 1 {
|
||||||
|
isErrParam = true
|
||||||
|
} else {
|
||||||
|
modelParam.TopP = data.TopP
|
||||||
|
}
|
||||||
|
} else if model.Type == proto.ModelTypeSpark {
|
||||||
|
if data.TopK < 0 || data.TopK > 1 {
|
||||||
|
isErrParam = true
|
||||||
|
} else {
|
||||||
|
modelParam.TopK = data.TopK
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
isErrParam = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if isErrParam {
|
||||||
|
modelParam.SetDefaultParams()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
11
proto/im.go
11
proto/im.go
|
|
@ -31,6 +31,9 @@ type WSMessageReq struct {
|
||||||
IsKBase bool `json:"is_kbase"` //是否为知识库消息
|
IsKBase bool `json:"is_kbase"` //是否为知识库消息
|
||||||
KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传
|
KBaseID uint `json:"kbase_id"` //知识库id,为知识库消息时必传
|
||||||
KBaseType string `json:"kbase_type"` //知识库模式,1为query 2为chat
|
KBaseType string `json:"kbase_type"` //知识库模式,1为query 2为chat
|
||||||
|
Temperature float32 `json:"temperature"` //温度,默认0.5,核采样阈值。取值越高随机性越强,即相同的问题得到的不同答案的可能性越大
|
||||||
|
TopP float32 `json:"top_p"` //doubao 模型的top_p参数
|
||||||
|
TopK float32 `json:"top_k"` //spark 模型的top_k参数
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelParam struct {
|
type ModelParam struct {
|
||||||
|
|
@ -40,12 +43,12 @@ type ModelParam struct {
|
||||||
APPID string `json:"appid"` //应用id
|
APPID string `json:"appid"` //应用id
|
||||||
APISecret string `json:"apiSecret"` //应用密钥
|
APISecret string `json:"apiSecret"` //应用密钥
|
||||||
APIKey string `json:"apiKey"` //应用key
|
APIKey string `json:"apiKey"` //应用key
|
||||||
System string `json:"system"` //
|
System string `json:"system"` //系统参数
|
||||||
//其他参数
|
//其他参数
|
||||||
Temperature float64 `json:"temperature"` //温度,默认0.5,核采样阈值。取值越高随机性越强,即相同的问题得到的不同答案的可能性越大
|
Temperature float32 `json:"temperature"` //温度,默认0.5,核采样阈值。取值越高随机性越强,即相同的问题得到的不同答案的可能性越大
|
||||||
MaxTokens int `json:"maxTokens"` //最大生成长度,默认4096
|
MaxTokens int `json:"maxTokens"` //最大生成长度,默认4096
|
||||||
TopK float64 `json:"topK"` //取概率最大的前k个词
|
TopK float32 `json:"topK"` //取概率最大的前k个词
|
||||||
TopP float64 `json:"topP"` //取
|
TopP float32 `json:"topP"` //取
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ModelParam) SetDefaultParams() {
|
func (m *ModelParam) SetDefaultParams() {
|
||||||
|
|
|
||||||
|
|
@ -90,9 +90,13 @@ func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("get doubao session history message error:", err)
|
log.Println("get doubao session history message error:", err)
|
||||||
}
|
}
|
||||||
|
maxTokens := 10240
|
||||||
|
|
||||||
req := model.CreateChatCompletionRequest{
|
req := model.CreateChatCompletionRequest{
|
||||||
Model: modelParam.Model,
|
Model: modelParam.Model,
|
||||||
|
Temperature: &modelParam.Temperature,
|
||||||
|
TopP: &modelParam.TopP,
|
||||||
|
MaxTokens: &maxTokens, //设置最大token数设置为10k
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
StreamOptions: &model.StreamOptions{
|
StreamOptions: &model.StreamOptions{
|
||||||
IncludeUsage: true,
|
IncludeUsage: true,
|
||||||
|
|
@ -165,8 +169,13 @@ func doubaoToGeneralMassageAndSendMsgQueue(data *model.ChatCompletionStreamRespo
|
||||||
|
|
||||||
apiMsg.Model = proto.ModelTypeDouBao
|
apiMsg.Model = proto.ModelTypeDouBao
|
||||||
choices := data.Choices[0]
|
choices := data.Choices[0]
|
||||||
|
|
||||||
|
if choices.Delta.Content != "" {
|
||||||
apiMsg.Response = choices.Delta.Content
|
apiMsg.Response = choices.Delta.Content
|
||||||
if choices.FinishReason == proto.FinishReasonStop {
|
} else if choices.Delta.ReasoningContent != nil {
|
||||||
|
apiMsg.Response = *choices.Delta.ReasoningContent //深度思考部分
|
||||||
|
}
|
||||||
|
if choices.FinishReason == proto.FinishReasonStop || choices.FinishReason == proto.FinishReasonLength { //停止原因是stop或长度超限
|
||||||
apiMsg.Done = true
|
apiMsg.Done = true
|
||||||
} else {
|
} else {
|
||||||
apiMsg.Done = false
|
apiMsg.Done = false
|
||||||
|
|
|
||||||
|
|
@ -199,7 +199,7 @@ func SparkV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
|
||||||
|
|
||||||
//发送消息
|
//发送消息
|
||||||
go func() {
|
go func() {
|
||||||
data := genSparkParams(imCtx.UserID, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System)
|
data := genSparkParamsV2(imCtx.UserID, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System, float64(modelParam.Temperature))
|
||||||
//将数据转换为json
|
//将数据转换为json
|
||||||
//dataByte, err3 := json.Marshal(data)
|
//dataByte, err3 := json.Marshal(data)
|
||||||
//if err3 != nil {
|
//if err3 != nil {
|
||||||
|
|
@ -323,6 +323,40 @@ func genSparkParams(userID int, appid string, domain string, sessionID int, syst
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 生成参数
|
||||||
|
func genSparkParamsV2(userID int, appid string, domain string, sessionID int, systemPrompt string, temperature float64) SparkRequest {
|
||||||
|
//messages := []MessageContent{
|
||||||
|
// {Role: "user", Content: question},
|
||||||
|
//}
|
||||||
|
messages2 := make([]MessageContent, 0)
|
||||||
|
//err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question)
|
||||||
|
err := GetSparkSessionHistoryMsgV2(sessionID, systemPrompt, &messages2, userID)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("get spark session history message error:", err)
|
||||||
|
}
|
||||||
|
//messages2Str, _ := json.Marshal(messages2)
|
||||||
|
//log.Println("spark messages2:", string(messages2Str))
|
||||||
|
|
||||||
|
data := SparkRequest{
|
||||||
|
Header: SparkRequestHeader{
|
||||||
|
AppID: appid,
|
||||||
|
},
|
||||||
|
Parameter: Parameter{
|
||||||
|
Chat: ChatParameter{
|
||||||
|
Domain: domain,
|
||||||
|
Temperature: temperature,
|
||||||
|
MaxTokens: 10240, //最大10k
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Payload: SparkRequestPayload{
|
||||||
|
Message: SparkMessage{
|
||||||
|
Text: messages2,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
// 创建鉴权url apikey 即 hmac username
|
// 创建鉴权url apikey 即 hmac username
|
||||||
func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
||||||
ul, err := url.Parse(hostUrl)
|
ul, err := url.Parse(hostUrl)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue