会话名称过长,使用ai进行总结更新
This commit is contained in:
parent
ebe9e31de9
commit
1ccffac34b
|
|
@ -231,10 +231,22 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
|
||||||
sessionType = proto.SessionTypeUserPrompt //提示词会话
|
sessionType = proto.SessionTypeUserPrompt //提示词会话
|
||||||
}
|
}
|
||||||
err, *sessionID = service.CreateSession(userId, sessionType, sessionName)
|
err, *sessionID = service.CreateSession(userId, sessionType, sessionName)
|
||||||
log.Println("create session id:", *sessionID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Create session error:", err)
|
log.Println("Create session error:", err)
|
||||||
}
|
}
|
||||||
|
log.Println("create session id:", *sessionID)
|
||||||
|
//会话名称过长,使用ai进行总结
|
||||||
|
if len(sessionName) > 30 {
|
||||||
|
//通过ai总结会话名称
|
||||||
|
go func() {
|
||||||
|
sessionErr := service.UpdateSessionNameByAIAndID(int(*sessionID), userId, sessionName)
|
||||||
|
if sessionErr != nil {
|
||||||
|
log.Println("update ai session name error:", sessionErr)
|
||||||
|
return
|
||||||
|
} //更新会话名称
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
//获取模型
|
//获取模型
|
||||||
for _, v := range models {
|
for _, v := range models {
|
||||||
|
|
|
||||||
|
|
@ -159,3 +159,30 @@ func OpenAIToGeneralMassageAndSendMsgQueue(data *openai.ChatCompletionChunk, cha
|
||||||
wsMsgStr, _ := json.Marshal(wsMsg)
|
wsMsgStr, _ := json.Marshal(wsMsg)
|
||||||
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
worker.Publish(channel, string(wsMsgStr), time.Second*60)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// openai 单轮对话,文本
|
||||||
|
func SingleTurnConversationText(prompt string, param proto.ModelParam) (string, error) {
|
||||||
|
var answer string
|
||||||
|
var err error
|
||||||
|
apiKey := param.APIKey
|
||||||
|
if apiKey == "" {
|
||||||
|
log.Printf("openai apiKey is empty\n")
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
openaiClient := openai.NewClient(
|
||||||
|
option.WithAPIKey(apiKey),
|
||||||
|
option.WithBaseURL(param.Url),
|
||||||
|
)
|
||||||
|
chatCompletion, err := openaiClient.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{
|
||||||
|
Messages: []openai.ChatCompletionMessageParamUnion{
|
||||||
|
openai.UserMessage(prompt),
|
||||||
|
},
|
||||||
|
Model: param.Model,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("openai single turn conversation error:%v\n", err)
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
answer = chatCompletion.Choices[0].Message.Content
|
||||||
|
return answer, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,38 @@ func CreateSession(userID, Type int, name string) (error, uint) {
|
||||||
return err, id
|
return err, id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 将用户的提问总计为会话名称并更新会话
|
||||||
|
func UpdateSessionNameByAIAndID(id, userID int, originSessionName string) error {
|
||||||
|
var err error
|
||||||
|
//获取ai总结的会话名称
|
||||||
|
prompt := "请根据下面的对话内容,提取为20字以内的主题作为会话名称:\n" + originSessionName
|
||||||
|
//ai总结模型id
|
||||||
|
modelID := 24 //固定为混元lite模型,免费
|
||||||
|
models := FindModelByID(modelID, 1)
|
||||||
|
if models == nil {
|
||||||
|
log.Println("UpdateSessionNameByAIAndID model not exist")
|
||||||
|
return errors.New("session name ai model not exist")
|
||||||
|
}
|
||||||
|
model := models[0]
|
||||||
|
var modelParam proto.ModelParam
|
||||||
|
err = json.Unmarshal([]byte(model.Parameter), &modelParam)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("UpdateSessionNameByAIAndID model param error:", err)
|
||||||
|
return errors.New("session name ai model param error")
|
||||||
|
}
|
||||||
|
newSessionName, err2 := generalModel.SingleTurnConversationText(prompt, modelParam)
|
||||||
|
if err2 != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
//更新会话名称
|
||||||
|
err = UpdateSessionByID(id, userID, newSessionName, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("UpdateSessionNameByAIAndAI update session name error:", err, "session id:", id, "user id:", userID, "name:", newSessionName)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// 查找用户的会话列表
|
// 查找用户的会话列表
|
||||||
func FindSessionByUserID(userID, sessionType int) []proto.SessionResponse {
|
func FindSessionByUserID(userID, sessionType int) []proto.SessionResponse {
|
||||||
var sessions []dao.Session
|
var sessions []dao.Session
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue