diff --git a/proto/im.go b/proto/im.go index 8dc2bf8..cd18f51 100644 --- a/proto/im.go +++ b/proto/im.go @@ -88,3 +88,20 @@ type UserImageMsg struct { ImageContent []ImageMsgContent `json:"image_content"` //图片内容 Text string `json:"text"` //问题文本 } + +type UserFileResponse struct { + ID int `json:"ID"` //文件ID + FileID int `json:"FileID"` //文件ID + UserFileName string `json:"UserFileName"` //用户文件名称 + FileStoreName string `json:"file_store_name"` //文件存储名称 +} + +type FileMessageContent struct { + FileContent UserFileResponse `json:"file_content"` //文件内容 + FileType string `json:"file_type"` //文件类型 +} + +type UserFileMessage struct { + FileContent []FileMessageContent `json:"file_content"` //文件内容 + Text string `json:"text"` //问题文本 +} diff --git a/proto/status.go b/proto/status.go index efe69c0..0a843f1 100644 --- a/proto/status.go +++ b/proto/status.go @@ -104,6 +104,7 @@ const ( //用户发送图片对话 UserToModelImageMsgType = 3 + UserToModelFileMsgType = 3 //模型发送文件对话 ModelToUserFileMsgType = 4 //用户与模型制作ppt的会话 @@ -163,3 +164,10 @@ const ( SessionTypeUserPrompt = 2 // 用户与模型提示词 SessionTypeUserCreatePPT = 3 // 用户与模型制作PPT ) + +// 文件类型(文本、图片) +const ( + DownloadFileBaseURL = "https://pm.ljsea.top/tool/" + FileTypeText = "text_file" + FileTypeImage = "image_file" +) diff --git a/service/doubao.go b/service/doubao.go index d14871e..a27e731 100644 --- a/service/doubao.go +++ b/service/doubao.go @@ -254,3 +254,96 @@ func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[] //添加本次请求消息(本次消息已在上面添加) return err } + +// 当前消息包含在messages中 +func GetDouBaoSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[]*model.ChatCompletionMessage, userID int) error { + var err error + //添加系统 + if systemPrompt != "" { + var message model.ChatCompletionMessage + message.Role = model.ChatMessageRoleSystem + message.Content = &model.ChatCompletionMessageContent{ + StringValue: volcengine.String(systemPrompt), + } + *messages = append(*messages, &message) + } + //获取该会话的历史消息 + msgs := dao.FindMessageBySessionIDV2(sessionID) + if len(msgs) > proto.DouBaoContextLength { + //只取最后的 SparkContextLength 条消息 + msgs = msgs[len(msgs)-proto.DouBaoContextLength-1:] + } + for _, v := range msgs { + var message model.ChatCompletionMessage + if v.Type == proto.UserToModelMsgType { + message.Role = model.ChatMessageRoleUser + //用户消息为有图片类型 + if v.Status == proto.UserToModelFileMsgType { + listValue, err2 := decodeFileMessageToDouBaoMessage(v.Msg, userID) + if err2 != nil { + log.Println("unmarshal user image message error:", err2) + return err2 + } + message.Content = &model.ChatCompletionMessageContent{ + ListValue: listValue, + } + } else { + //用户消息为文本类型 + message.Content = &model.ChatCompletionMessageContent{ + StringValue: volcengine.String(v.Msg), + } + } + } else if v.Type == proto.ModelToUserMsgType { + message.Role = model.ChatMessageRoleAssistant + message.Content = &model.ChatCompletionMessageContent{ + StringValue: volcengine.String(v.Msg), + } + } else { + continue + } + + *messages = append(*messages, &message) + } + //添加本次请求消息(本次消息已在上面添加) + return err +} + +func decodeFileMessageToDouBaoMessage(msg string, userID int) ([]*model.ChatCompletionMessageContentPart, error) { + var fileMsgContent proto.UserFileMessage + err := json.Unmarshal([]byte(msg), &fileMsgContent) + if err != nil { + log.Println("DouBao unmarshal file message error:", err) + return nil, err + } + var listValue []*model.ChatCompletionMessageContentPart + for _, file := range fileMsgContent.FileContent { + //文件消息为文本类型,则获取文本内容,若为图片则添加图片链接 + if file.FileType == proto.FileTypeText { + fileContents, err2 := FindFileContentV2(file.FileContent.FileID, userID) + if err2 != nil { + log.Println("doubao find file content error:", err2) + return nil, err2 + } + fileContent := fileContents[0].FileContent + //文件文本内容 + listValue = append(listValue, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeText, + Text: "文件" + file.FileContent.UserFileName + ",内容:" + fileContent, + }) + } else if file.FileType == proto.FileTypeImage { + listValue = append(listValue, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeImageURL, + ImageURL: &model.ChatMessageImageURL{ + URL: proto.DownloadFileBaseURL + file.FileContent.FileStoreName, + }, + }) + } else { + log.Println("file type is not support:", file.FileType, "file name:", file.FileContent.UserFileName, "file id:", file.FileContent.FileID) + } + } + listValue = append(listValue, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeText, + Text: fileMsgContent.Text, + }) + return listValue, nil +} diff --git a/service/fileService.go b/service/fileService.go index feb9b34..6bf28d6 100644 --- a/service/fileService.go +++ b/service/fileService.go @@ -296,7 +296,7 @@ func FindFileContent(userID int, userReq *proto.FileContentReq) ([]dao.FileConte return nil, errors.New("file type is image") } //获取文件内容 - fileContents := dao.FindFileContentByFileID(userReq.UserFileID) + fileContents := dao.FindFileContentByFileID(fileAuth.FileID) if len(fileContents) == 0 { //直接读取文件内容 filePath := file.FilePath + "/" + file.FileStoreName @@ -326,3 +326,29 @@ func readFileContent(filePath string) (string, error) { return string(content), nil } + +func FindFileContentV2(fileID, userID int) ([]dao.FileContent, error) { + //获取文件基础信息 + file := dao.FindFileByID(fileID, userID) + //判断文件类型\ + fileType := strings.Split(file.FileStoreName, ".")[1] + //如果文件类型是图片 + if fileType == "jpg" || fileType == "png" || fileType == "jpeg" || fileType == "gif" { + return nil, errors.New("file type is image") + } + //获取文件内容 + fileContents := dao.FindFileContentByFileID(fileID) + if len(fileContents) == 0 { + //直接读取文件内容 + filePath := file.FilePath + "/" + file.FileStoreName + fileContentStr, err := readFileContent(filePath) + if err != nil { + return nil, err + } + var fileContent dao.FileContent + fileContent.FileID = -1 + fileContent.FileContent = fileContentStr + fileContents = append(fileContents, fileContent) + } + return fileContents, nil +}