diff --git a/proto/im.go b/proto/im.go index 0cabe7b..34247c2 100644 --- a/proto/im.go +++ b/proto/im.go @@ -33,7 +33,12 @@ type ModelParam struct { APPID string `json:"appid"` //应用id APISecret string `json:"apiSecret"` //应用密钥 APIKey string `json:"apiKey"` //应用key - System string `json:"system"` //系统参数 + System string `json:"system"` // + //其他参数 + Temperature float64 `json:"temperature"` //温度,默认0.5,核采样阈值。取值越高随机性越强,即相同的问题得到的不同答案的可能性越大 + MaxTokens int `json:"maxTokens"` //最大生成长度,默认4096 + TopK float64 `json:"topK"` //取概率最大的前k个词 + TopP float64 `json:"topP"` //取 } type SessionResponse struct { diff --git a/proto/status.go b/proto/status.go index 8f7a148..4327d56 100644 --- a/proto/status.go +++ b/proto/status.go @@ -124,3 +124,9 @@ const ( ModelTypeDouBao = "doubao" ModelTypeOllama = "ollama" ) + +// 其它 +const ( + SparkContextLength = 6 + DouBaoContextLength = 6 +) diff --git a/service/doubao.go b/service/doubao.go index 9a40975..9c19cba 100644 --- a/service/doubao.go +++ b/service/doubao.go @@ -116,6 +116,10 @@ func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[] } //获取该会话的历史消息 msgs := dao.FindMessageBySessionIDV2(sessionID) + if len(msgs) > proto.DouBaoContextLength { + //只取最后的 SparkContextLength 条消息 + msgs = msgs[len(msgs)-proto.DouBaoContextLength-1:] + } for _, v := range msgs { var message model.ChatCompletionMessage if v.Type == proto.UserToModelMsgType { diff --git a/service/spark.go b/service/spark.go index ef43442..65aabb2 100644 --- a/service/spark.go +++ b/service/spark.go @@ -290,6 +290,10 @@ func readResp(resp *http.Response) string { func GetSparkSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]MessageContent, prompt string) error { //获取会话的历史消息 msgs := dao.FindMessageBySessionIDV2(sessionID) + if len(msgs) > proto.SparkContextLength { + //只取最后的 SparkContextLength 条消息 + msgs = msgs[len(msgs)-proto.SparkContextLength-1:] + } //添加系统消息 if systemPrompt != "" { var message MessageContent