添加豆包会话文件内容添加
This commit is contained in:
parent
44d7ce7f1b
commit
7850126c50
17
proto/im.go
17
proto/im.go
|
|
@ -88,3 +88,20 @@ type UserImageMsg struct {
|
|||
ImageContent []ImageMsgContent `json:"image_content"` //图片内容
|
||||
Text string `json:"text"` //问题文本
|
||||
}
|
||||
|
||||
type UserFileResponse struct {
|
||||
ID int `json:"ID"` //文件ID
|
||||
FileID int `json:"FileID"` //文件ID
|
||||
UserFileName string `json:"UserFileName"` //用户文件名称
|
||||
FileStoreName string `json:"file_store_name"` //文件存储名称
|
||||
}
|
||||
|
||||
type FileMessageContent struct {
|
||||
FileContent UserFileResponse `json:"file_content"` //文件内容
|
||||
FileType string `json:"file_type"` //文件类型
|
||||
}
|
||||
|
||||
type UserFileMessage struct {
|
||||
FileContent []FileMessageContent `json:"file_content"` //文件内容
|
||||
Text string `json:"text"` //问题文本
|
||||
}
|
||||
|
|
|
|||
|
|
@ -104,6 +104,7 @@ const (
|
|||
|
||||
//用户发送图片对话
|
||||
UserToModelImageMsgType = 3
|
||||
UserToModelFileMsgType = 3
|
||||
//模型发送文件对话
|
||||
ModelToUserFileMsgType = 4
|
||||
//用户与模型制作ppt的会话
|
||||
|
|
@ -163,3 +164,10 @@ const (
|
|||
SessionTypeUserPrompt = 2 // 用户与模型提示词
|
||||
SessionTypeUserCreatePPT = 3 // 用户与模型制作PPT
|
||||
)
|
||||
|
||||
// 文件类型(文本、图片)
|
||||
const (
|
||||
DownloadFileBaseURL = "https://pm.ljsea.top/tool/"
|
||||
FileTypeText = "text_file"
|
||||
FileTypeImage = "image_file"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -254,3 +254,96 @@ func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]
|
|||
//添加本次请求消息(本次消息已在上面添加)
|
||||
return err
|
||||
}
|
||||
|
||||
// 当前消息包含在messages中
|
||||
func GetDouBaoSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[]*model.ChatCompletionMessage, userID int) error {
|
||||
var err error
|
||||
//添加系统
|
||||
if systemPrompt != "" {
|
||||
var message model.ChatCompletionMessage
|
||||
message.Role = model.ChatMessageRoleSystem
|
||||
message.Content = &model.ChatCompletionMessageContent{
|
||||
StringValue: volcengine.String(systemPrompt),
|
||||
}
|
||||
*messages = append(*messages, &message)
|
||||
}
|
||||
//获取该会话的历史消息
|
||||
msgs := dao.FindMessageBySessionIDV2(sessionID)
|
||||
if len(msgs) > proto.DouBaoContextLength {
|
||||
//只取最后的 SparkContextLength 条消息
|
||||
msgs = msgs[len(msgs)-proto.DouBaoContextLength-1:]
|
||||
}
|
||||
for _, v := range msgs {
|
||||
var message model.ChatCompletionMessage
|
||||
if v.Type == proto.UserToModelMsgType {
|
||||
message.Role = model.ChatMessageRoleUser
|
||||
//用户消息为有图片类型
|
||||
if v.Status == proto.UserToModelFileMsgType {
|
||||
listValue, err2 := decodeFileMessageToDouBaoMessage(v.Msg, userID)
|
||||
if err2 != nil {
|
||||
log.Println("unmarshal user image message error:", err2)
|
||||
return err2
|
||||
}
|
||||
message.Content = &model.ChatCompletionMessageContent{
|
||||
ListValue: listValue,
|
||||
}
|
||||
} else {
|
||||
//用户消息为文本类型
|
||||
message.Content = &model.ChatCompletionMessageContent{
|
||||
StringValue: volcengine.String(v.Msg),
|
||||
}
|
||||
}
|
||||
} else if v.Type == proto.ModelToUserMsgType {
|
||||
message.Role = model.ChatMessageRoleAssistant
|
||||
message.Content = &model.ChatCompletionMessageContent{
|
||||
StringValue: volcengine.String(v.Msg),
|
||||
}
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
|
||||
*messages = append(*messages, &message)
|
||||
}
|
||||
//添加本次请求消息(本次消息已在上面添加)
|
||||
return err
|
||||
}
|
||||
|
||||
func decodeFileMessageToDouBaoMessage(msg string, userID int) ([]*model.ChatCompletionMessageContentPart, error) {
|
||||
var fileMsgContent proto.UserFileMessage
|
||||
err := json.Unmarshal([]byte(msg), &fileMsgContent)
|
||||
if err != nil {
|
||||
log.Println("DouBao unmarshal file message error:", err)
|
||||
return nil, err
|
||||
}
|
||||
var listValue []*model.ChatCompletionMessageContentPart
|
||||
for _, file := range fileMsgContent.FileContent {
|
||||
//文件消息为文本类型,则获取文本内容,若为图片则添加图片链接
|
||||
if file.FileType == proto.FileTypeText {
|
||||
fileContents, err2 := FindFileContentV2(file.FileContent.FileID, userID)
|
||||
if err2 != nil {
|
||||
log.Println("doubao find file content error:", err2)
|
||||
return nil, err2
|
||||
}
|
||||
fileContent := fileContents[0].FileContent
|
||||
//文件文本内容
|
||||
listValue = append(listValue, &model.ChatCompletionMessageContentPart{
|
||||
Type: model.ChatCompletionMessageContentPartTypeText,
|
||||
Text: "文件" + file.FileContent.UserFileName + ",内容:" + fileContent,
|
||||
})
|
||||
} else if file.FileType == proto.FileTypeImage {
|
||||
listValue = append(listValue, &model.ChatCompletionMessageContentPart{
|
||||
Type: model.ChatCompletionMessageContentPartTypeImageURL,
|
||||
ImageURL: &model.ChatMessageImageURL{
|
||||
URL: proto.DownloadFileBaseURL + file.FileContent.FileStoreName,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
log.Println("file type is not support:", file.FileType, "file name:", file.FileContent.UserFileName, "file id:", file.FileContent.FileID)
|
||||
}
|
||||
}
|
||||
listValue = append(listValue, &model.ChatCompletionMessageContentPart{
|
||||
Type: model.ChatCompletionMessageContentPartTypeText,
|
||||
Text: fileMsgContent.Text,
|
||||
})
|
||||
return listValue, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -296,7 +296,7 @@ func FindFileContent(userID int, userReq *proto.FileContentReq) ([]dao.FileConte
|
|||
return nil, errors.New("file type is image")
|
||||
}
|
||||
//获取文件内容
|
||||
fileContents := dao.FindFileContentByFileID(userReq.UserFileID)
|
||||
fileContents := dao.FindFileContentByFileID(fileAuth.FileID)
|
||||
if len(fileContents) == 0 {
|
||||
//直接读取文件内容
|
||||
filePath := file.FilePath + "/" + file.FileStoreName
|
||||
|
|
@ -326,3 +326,29 @@ func readFileContent(filePath string) (string, error) {
|
|||
|
||||
return string(content), nil
|
||||
}
|
||||
|
||||
func FindFileContentV2(fileID, userID int) ([]dao.FileContent, error) {
|
||||
//获取文件基础信息
|
||||
file := dao.FindFileByID(fileID, userID)
|
||||
//判断文件类型\
|
||||
fileType := strings.Split(file.FileStoreName, ".")[1]
|
||||
//如果文件类型是图片
|
||||
if fileType == "jpg" || fileType == "png" || fileType == "jpeg" || fileType == "gif" {
|
||||
return nil, errors.New("file type is image")
|
||||
}
|
||||
//获取文件内容
|
||||
fileContents := dao.FindFileContentByFileID(fileID)
|
||||
if len(fileContents) == 0 {
|
||||
//直接读取文件内容
|
||||
filePath := file.FilePath + "/" + file.FileStoreName
|
||||
fileContentStr, err := readFileContent(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var fileContent dao.FileContent
|
||||
fileContent.FileID = -1
|
||||
fileContent.FileContent = fileContentStr
|
||||
fileContents = append(fileContents, fileContent)
|
||||
}
|
||||
return fileContents, nil
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue