会话名称过长,使用ai进行总结更新
This commit is contained in:
parent
ebe9e31de9
commit
1ccffac34b
|
|
@ -231,10 +231,22 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
|||
sessionType = proto.SessionTypeUserPrompt //提示词会话
|
||||
}
|
||||
err, *sessionID = service.CreateSession(userId, sessionType, sessionName)
|
||||
log.Println("create session id:", *sessionID)
|
||||
if err != nil {
|
||||
log.Println("Create session error:", err)
|
||||
}
|
||||
log.Println("create session id:", *sessionID)
|
||||
//会话名称过长,使用ai进行总结
|
||||
if len(sessionName) > 30 {
|
||||
//通过ai总结会话名称
|
||||
go func() {
|
||||
sessionErr := service.UpdateSessionNameByAIAndID(int(*sessionID), userId, sessionName)
|
||||
if sessionErr != nil {
|
||||
log.Println("update ai session name error:", sessionErr)
|
||||
return
|
||||
} //更新会话名称
|
||||
}()
|
||||
}
|
||||
|
||||
}
|
||||
//获取模型
|
||||
for _, v := range models {
|
||||
|
|
|
|||
|
|
@ -159,3 +159,30 @@ func OpenAIToGeneralMassageAndSendMsgQueue(data *openai.ChatCompletionChunk, cha
|
|||
wsMsgStr, _ := json.Marshal(wsMsg)
|
||||
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
||||
}
|
||||
|
||||
// openai 单轮对话,文本
|
||||
func SingleTurnConversationText(prompt string, param proto.ModelParam) (string, error) {
|
||||
var answer string
|
||||
var err error
|
||||
apiKey := param.APIKey
|
||||
if apiKey == "" {
|
||||
log.Printf("openai apiKey is empty\n")
|
||||
return "", nil
|
||||
}
|
||||
openaiClient := openai.NewClient(
|
||||
option.WithAPIKey(apiKey),
|
||||
option.WithBaseURL(param.Url),
|
||||
)
|
||||
chatCompletion, err := openaiClient.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
|
||||
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||
openai.UserMessage(prompt),
|
||||
},
|
||||
Model: param.Model,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("openai single turn conversation error:%v\n", err)
|
||||
return "", err
|
||||
}
|
||||
answer = chatCompletion.Choices[0].Message.Content
|
||||
return answer, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -20,6 +20,38 @@ func CreateSession(userID, Type int, name string) (error, uint) {
|
|||
return err, id
|
||||
}
|
||||
|
||||
// 将用户的提问总计为会话名称并更新会话
|
||||
func UpdateSessionNameByAIAndID(id, userID int, originSessionName string) error {
|
||||
var err error
|
||||
//获取ai总结的会话名称
|
||||
prompt := "请根据下面的对话内容,提取为20字以内的主题作为会话名称:\n" + originSessionName
|
||||
//ai总结模型id
|
||||
modelID := 24 //固定为混元lite模型,免费
|
||||
models := FindModelByID(modelID, 1)
|
||||
if models == nil {
|
||||
log.Println("UpdateSessionNameByAIAndID model not exist")
|
||||
return errors.New("session name ai model not exist")
|
||||
}
|
||||
model := models[0]
|
||||
var modelParam proto.ModelParam
|
||||
err = json.Unmarshal([]byte(model.Parameter), &modelParam)
|
||||
if err != nil {
|
||||
log.Println("UpdateSessionNameByAIAndID model param error:", err)
|
||||
return errors.New("session name ai model param error")
|
||||
}
|
||||
newSessionName, err2 := generalModel.SingleTurnConversationText(prompt, modelParam)
|
||||
if err2 != nil {
|
||||
return err
|
||||
}
|
||||
//更新会话名称
|
||||
err = UpdateSessionByID(id, userID, newSessionName, nil)
|
||||
if err != nil {
|
||||
log.Println("UpdateSessionNameByAIAndAI update session name error:", err, "session id:", id, "user id:", userID, "name:", newSessionName)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 查找用户的会话列表
|
||||
func FindSessionByUserID(userID, sessionType int) []proto.SessionResponse {
|
||||
var sessions []dao.Session
|
||||
|
|
|
|||
Loading…
Reference in New Issue