会话名称过长,使用ai进行总结更新

This commit is contained in:
junleea 2025-05-12 12:57:40 +08:00
parent ebe9e31de9
commit 1ccffac34b
3 changed files with 72 additions and 1 deletions

View File

@ -231,10 +231,22 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
sessionType = proto.SessionTypeUserPrompt //提示词会话 sessionType = proto.SessionTypeUserPrompt //提示词会话
} }
err, *sessionID = service.CreateSession(userId, sessionType, sessionName) err, *sessionID = service.CreateSession(userId, sessionType, sessionName)
log.Println("create session id:", *sessionID)
if err != nil { if err != nil {
log.Println("Create session error:", err) log.Println("Create session error:", err)
} }
log.Println("create session id:", *sessionID)
//会话名称过长使用ai进行总结
if len(sessionName) > 30 {
//通过ai总结会话名称
go func() {
sessionErr := service.UpdateSessionNameByAIAndID(int(*sessionID), userId, sessionName)
if sessionErr != nil {
log.Println("update ai session name error:", sessionErr)
return
} //更新会话名称
}()
}
} }
//获取模型 //获取模型
for _, v := range models { for _, v := range models {

View File

@ -159,3 +159,30 @@ func OpenAIToGeneralMassageAndSendMsgQueue(data *openai.ChatCompletionChunk, cha
wsMsgStr, _ := json.Marshal(wsMsg) wsMsgStr, _ := json.Marshal(wsMsg)
worker.Publish(channel, string(wsMsgStr), time.Second*60) worker.Publish(channel, string(wsMsgStr), time.Second*60)
} }
// openai 单轮对话,文本
func SingleTurnConversationText(prompt string, param proto.ModelParam) (string, error) {
var answer string
var err error
apiKey := param.APIKey
if apiKey == "" {
log.Printf("openai apiKey is empty\n")
return "", nil
}
openaiClient := openai.NewClient(
option.WithAPIKey(apiKey),
option.WithBaseURL(param.Url),
)
chatCompletion, err := openaiClient.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage(prompt),
},
Model: param.Model,
})
if err != nil {
log.Printf("openai single turn conversation error:%v\n", err)
return "", err
}
answer = chatCompletion.Choices[0].Message.Content
return answer, nil
}

View File

@ -20,6 +20,38 @@ func CreateSession(userID, Type int, name string) (error, uint) {
return err, id return err, id
} }
// 将用户的提问总计为会话名称并更新会话
func UpdateSessionNameByAIAndID(id, userID int, originSessionName string) error {
var err error
//获取ai总结的会话名称
prompt := "请根据下面的对话内容提取为20字以内的主题作为会话名称\n" + originSessionName
//ai总结模型id
modelID := 24 //固定为混元lite模型免费
models := FindModelByID(modelID, 1)
if models == nil {
log.Println("UpdateSessionNameByAIAndID model not exist")
return errors.New("session name ai model not exist")
}
model := models[0]
var modelParam proto.ModelParam
err = json.Unmarshal([]byte(model.Parameter), &modelParam)
if err != nil {
log.Println("UpdateSessionNameByAIAndID model param error:", err)
return errors.New("session name ai model param error")
}
newSessionName, err2 := generalModel.SingleTurnConversationText(prompt, modelParam)
if err2 != nil {
return err
}
//更新会话名称
err = UpdateSessionByID(id, userID, newSessionName, nil)
if err != nil {
log.Println("UpdateSessionNameByAIAndAI update session name error:", err, "session id:", id, "user id:", userID, "name:", newSessionName)
return err
}
return nil
}
// 查找用户的会话列表 // 查找用户的会话列表
func FindSessionByUserID(userID, sessionType int) []proto.SessionResponse { func FindSessionByUserID(userID, sessionType int) []proto.SessionResponse {
var sessions []dao.Session var sessions []dao.Session