diff --git a/handler/im.go b/handler/im.go index 98090b2..6efefba 100644 --- a/handler/im.go +++ b/handler/im.go @@ -231,10 +231,22 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR sessionType = proto.SessionTypeUserPrompt //提示词会话 } err, *sessionID = service.CreateSession(userId, sessionType, sessionName) - log.Println("create session id:", *sessionID) if err != nil { log.Println("Create session error:", err) } + log.Println("create session id:", *sessionID) + //会话名称过长,使用ai进行总结 + if len(sessionName) > 30 { + //通过ai总结会话名称 + go func() { + sessionErr := service.UpdateSessionNameByAIAndID(int(*sessionID), userId, sessionName) + if sessionErr != nil { + log.Println("update ai session name error:", sessionErr) + return + } //更新会话名称 + }() + } + } //获取模型 for _, v := range models { diff --git a/service/generalModel/openai.go b/service/generalModel/openai.go index b8dc5d1..2e89040 100644 --- a/service/generalModel/openai.go +++ b/service/generalModel/openai.go @@ -159,3 +159,30 @@ func OpenAIToGeneralMassageAndSendMsgQueue(data *openai.ChatCompletionChunk, cha wsMsgStr, _ := json.Marshal(wsMsg) worker.Publish(channel, string(wsMsgStr), time.Second*60) } + +// openai 单轮对话,文本 +func SingleTurnConversationText(prompt string, param proto.ModelParam) (string, error) { + var answer string + var err error + apiKey := param.APIKey + if apiKey == "" { + log.Printf("openai apiKey is empty\n") + return "", nil + } + openaiClient := openai.NewClient( + option.WithAPIKey(apiKey), + option.WithBaseURL(param.Url), + ) + chatCompletion, err := openaiClient.Chat.Completions.New(context.TODO(), openai.ChatCompletionNewParams{ + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(prompt), + }, + Model: param.Model, + }) + if err != nil { + log.Printf("openai single turn conversation error:%v\n", err) + return "", err + } + answer = chatCompletion.Choices[0].Message.Content + return answer, nil +} diff --git a/service/imService.go b/service/imService.go index 86054c1..6b9bddb 100644 --- a/service/imService.go +++ b/service/imService.go @@ -20,6 +20,38 @@ func CreateSession(userID, Type int, name string) (error, uint) { return err, id } +// 将用户的提问总计为会话名称并更新会话 +func UpdateSessionNameByAIAndID(id, userID int, originSessionName string) error { + var err error + //获取ai总结的会话名称 + prompt := "请根据下面的对话内容,提取为20字以内的主题作为会话名称:\n" + originSessionName + //ai总结模型id + modelID := 24 //固定为混元lite模型,免费 + models := FindModelByID(modelID, 1) + if models == nil { + log.Println("UpdateSessionNameByAIAndID model not exist") + return errors.New("session name ai model not exist") + } + model := models[0] + var modelParam proto.ModelParam + err = json.Unmarshal([]byte(model.Parameter), &modelParam) + if err != nil { + log.Println("UpdateSessionNameByAIAndID model param error:", err) + return errors.New("session name ai model param error") + } + newSessionName, err2 := generalModel.SingleTurnConversationText(prompt, modelParam) + if err2 != nil { + return err + } + //更新会话名称 + err = UpdateSessionByID(id, userID, newSessionName, nil) + if err != nil { + log.Println("UpdateSessionNameByAIAndAI update session name error:", err, "session id:", id, "user id:", userID, "name:", newSessionName) + return err + } + return nil +} + // 查找用户的会话列表 func FindSessionByUserID(userID, sessionType int) []proto.SessionResponse { var sessions []dao.Session