diff --git a/dao/db.go b/dao/db.go index 1eaafa5..7561ab5 100644 --- a/dao/db.go +++ b/dao/db.go @@ -55,6 +55,12 @@ func Init() error { return err } + err = db.AutoMigrate(&FileContent{}) + if err != nil { + fmt.Println("file content table:", err) + return err + } + err = db.AutoMigrate(&Model{}) if err != nil { fmt.Println("model table:", err) diff --git a/handler/im.go b/handler/im.go index 2a3d08c..4fccc5f 100644 --- a/handler/im.go +++ b/handler/im.go @@ -206,13 +206,13 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR } var err error sessionName := "" - if data.IsImage { - var imgMsg proto.UserImageMsg - err = json.Unmarshal([]byte(data.Msg), &imgMsg) + if data.IsImage || data.IsFile { + var fileMsg proto.UserFileMessage + err = json.Unmarshal([]byte(data.Msg), &fileMsg) if err != nil { log.Println("Error parsing JSON:", err) } - sessionName = imgMsg.Text + sessionName = fileMsg.Text } else { sessionName = data.Msg } @@ -258,7 +258,7 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR var userMsgID uint status := proto.MsgHasRead - if data.IsImage { + if data.IsImage || data.IsFile { status = proto.UserToModelImageMsgType //图片类型 } //将消息存入数据库 diff --git a/proto/im.go b/proto/im.go index 8dc2bf8..7df3b52 100644 --- a/proto/im.go +++ b/proto/im.go @@ -26,6 +26,7 @@ type WSMessageReq struct { ToID int `json:"to_id"` //接收者id ModelID uint `json:"model_id"` //模型id IsImage bool `json:"is_image"` //是否为图片(图片消息),如果是图片消息,则msg为图片消息结构为{"img_url":"","text":""} + IsFile bool `json:"is_file"` //是否为文件(文件消息),如果是文件消息,则msg为文件消息结构为{"file_url":"","text":""} } type ModelParam struct { @@ -88,3 +89,20 @@ type UserImageMsg struct { ImageContent []ImageMsgContent `json:"image_content"` //图片内容 Text string `json:"text"` //问题文本 } + +type UserFileResponse struct { + ID int `json:"ID"` //文件ID + FileID int `json:"FileID"` //文件ID + UserFileName string `json:"UserFileName"` //用户文件名称 + FileStoreName string `json:"file_store_name"` //文件存储名称 +} + +type FileMessageContent struct { + FileContent UserFileResponse `json:"file_content"` //文件内容 + FileType string `json:"file_type"` //文件类型 +} + +type UserFileMessage struct { + FileContent []FileMessageContent `json:"file_content"` //文件内容 + Text string `json:"text"` //问题文本 +} diff --git a/proto/status.go b/proto/status.go index efe69c0..9d9dfbd 100644 --- a/proto/status.go +++ b/proto/status.go @@ -104,6 +104,7 @@ const ( //用户发送图片对话 UserToModelImageMsgType = 3 + UserToModelFileMsgType = 3 //模型发送文件对话 ModelToUserFileMsgType = 4 //用户与模型制作ppt的会话 @@ -163,3 +164,10 @@ const ( SessionTypeUserPrompt = 2 // 用户与模型提示词 SessionTypeUserCreatePPT = 3 // 用户与模型制作PPT ) + +// 文件类型(文本、图片) +const ( + DownloadFileBaseURL = "https://pm.ljsea.top/tool/file/" + FileTypeText = "text_file" + FileTypeImage = "image_file" +) diff --git a/service/doubao.go b/service/doubao.go index d14871e..0109480 100644 --- a/service/doubao.go +++ b/service/doubao.go @@ -86,7 +86,7 @@ func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { ) ctx := context.Background() messages := make([]*model.ChatCompletionMessage, 0) - err := GetDouBaoSessionHistoryMsg(imCtx.SessionID, modelParam.System, &messages, imCtx.Question) + err := GetDouBaoSessionHistoryMsgV2(imCtx.SessionID, modelParam.System, &messages, imCtx.UserID) if err != nil { log.Println("get doubao session history message error:", err) } @@ -254,3 +254,98 @@ func GetDouBaoSessionHistoryMsg(sessionID int, systemPrompt string, messages *[] //添加本次请求消息(本次消息已在上面添加) return err } + +// 当前消息包含在messages中 +func GetDouBaoSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[]*model.ChatCompletionMessage, userID int) error { + var err error + //添加系统 + if systemPrompt != "" { + var message model.ChatCompletionMessage + message.Role = model.ChatMessageRoleSystem + message.Content = &model.ChatCompletionMessageContent{ + StringValue: volcengine.String(systemPrompt), + } + *messages = append(*messages, &message) + } + //获取该会话的历史消息 + msgs := dao.FindMessageBySessionIDV2(sessionID) + if len(msgs) > proto.DouBaoContextLength { + //只取最后的 SparkContextLength 条消息 + msgs = msgs[len(msgs)-proto.DouBaoContextLength-1:] + } + //log.Println("get doubao session history message:", msgs) + for i, v := range msgs { + var message model.ChatCompletionMessage + if v.Type == proto.UserToModelMsgType { + message.Role = model.ChatMessageRoleUser + //用户消息为有图片类型且是最后一条消息 + if v.Status == proto.UserToModelFileMsgType && i == len(msgs)-1 { + listValue, err2 := decodeFileMessageToDouBaoMessage(v.Msg, userID) + if err2 != nil { + log.Println("unmarshal user image message error:", err2) + return err2 + } + message.Content = &model.ChatCompletionMessageContent{ + ListValue: listValue, + } + } else { + //用户消息为文本类型 + message.Content = &model.ChatCompletionMessageContent{ + StringValue: volcengine.String(v.Msg), + } + } + } else if v.Type == proto.ModelToUserMsgType { + message.Role = model.ChatMessageRoleAssistant + message.Content = &model.ChatCompletionMessageContent{ + StringValue: volcengine.String(v.Msg), + } + } else { + continue + } + + *messages = append(*messages, &message) + } + //添加本次请求消息(本次消息已在上面添加) + return err +} + +func decodeFileMessageToDouBaoMessage(msg string, userID int) ([]*model.ChatCompletionMessageContentPart, error) { + var fileMsgContent proto.UserFileMessage + err := json.Unmarshal([]byte(msg), &fileMsgContent) + if err != nil { + log.Println("DouBao unmarshal file message error:", err) + return nil, err + } + log.Println("doubaov2,decode file message:", fileMsgContent) + var listValue []*model.ChatCompletionMessageContentPart + for _, file := range fileMsgContent.FileContent { + //文件消息为文本类型,则获取文本内容,若为图片则添加图片链接 + if file.FileType == proto.FileTypeText { + fileContents, err2 := FindFileContentV2(file.FileContent.FileID, userID) + if err2 != nil { + log.Println("doubao find file content error:", err2, "\t,file id:", file.FileContent.FileID, "\t,file name:", file.FileContent.UserFileName) + continue + } + fileContent := fileContents[0].FileContent + //文件文本内容 + listValue = append(listValue, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeText, + Text: "文件" + file.FileContent.UserFileName + ",内容:" + fileContent, + }) + } else if file.FileType == proto.FileTypeImage { + listValue = append(listValue, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeImageURL, + ImageURL: &model.ChatMessageImageURL{ + URL: proto.DownloadFileBaseURL + file.FileContent.FileStoreName, + }, + }) + } else { + log.Println("file type is not support:", file.FileType, "file name:", file.FileContent.UserFileName, "file id:", file.FileContent.FileID) + } + } + listValue = append(listValue, &model.ChatCompletionMessageContentPart{ + Type: model.ChatCompletionMessageContentPartTypeText, + Text: fileMsgContent.Text, + }) + return listValue, nil +} diff --git a/service/fileService.go b/service/fileService.go index feb9b34..a52a5c0 100644 --- a/service/fileService.go +++ b/service/fileService.go @@ -5,6 +5,7 @@ import ( "StuAcaWorksAI/proto" "StuAcaWorksAI/worker" "crypto/md5" + "encoding/base64" "errors" "fmt" "github.com/gin-gonic/gin" @@ -296,7 +297,7 @@ func FindFileContent(userID int, userReq *proto.FileContentReq) ([]dao.FileConte return nil, errors.New("file type is image") } //获取文件内容 - fileContents := dao.FindFileContentByFileID(userReq.UserFileID) + fileContents := dao.FindFileContentByFileID(fileAuth.FileID) if len(fileContents) == 0 { //直接读取文件内容 filePath := file.FilePath + "/" + file.FileStoreName @@ -326,3 +327,68 @@ func readFileContent(filePath string) (string, error) { return string(content), nil } + +func FindFileContentV2(fileID, userID int) ([]dao.FileContent, error) { + //获取文件基础信息 + file := dao.FindFileByID(fileID, userID) + //判断文件类型\ + fileType := strings.Split(file.FileStoreName, ".")[1] + //如果文件类型是图片 + if fileType == "jpg" || fileType == "png" || fileType == "jpeg" || fileType == "gif" { + return nil, errors.New("file type is image") + } + //获取文件内容 + fileContents := dao.FindFileContentByFileID(fileID) + if len(fileContents) == 0 { + //直接读取文件内容 + filePath := file.FilePath + "/" + file.FileStoreName + fileContentStr, err := readFileContent(filePath) + if err != nil { + return nil, err + } + var fileContent dao.FileContent + fileContent.FileID = -1 + fileContent.FileContent = fileContentStr + fileContents = append(fileContents, fileContent) + } + return fileContents, nil +} + +func FindImageFileContent(fileID, userID int) (string, error) { + //获取文件基础信息 + file := dao.FindFileByID(fileID, userID) + //判断文件类型\ + fileType := strings.Split(file.FileStoreName, ".")[1] + //如果文件类型是图片 + if fileType != "jpg" && fileType != "png" && fileType != "jpeg" && fileType != "gif" { + return "", errors.New("file type is not image") + } + //直接读取文件内容 + filePath := file.FilePath + "/" + file.FileStoreName + base64Str, err := readImageFile(filePath) + if err != nil { + return "", err + } + return base64Str, nil +} + +// readImageFile 函数用于读取指定路径的图片文件并将其内容转换为 Base64 编码字符串 +func readImageFile(filePath string) (string, error) { + // 打开图片文件 + file, err := os.Open(filePath) + if err != nil { + return "", fmt.Errorf("无法打开文件: %w", err) + } + // 确保文件在函数结束时关闭 + defer file.Close() + + // 读取图片文件内容 + imageData, err := io.ReadAll(file) + if err != nil { + return "", fmt.Errorf("无法读取文件内容: %w", err) + } + + // 将图片数据进行 Base64 编码 + base64Encoded := base64.StdEncoding.EncodeToString(imageData) + return base64Encoded, nil +} diff --git a/service/spark.go b/service/spark.go index c4e3f4f..c4edffd 100644 --- a/service/spark.go +++ b/service/spark.go @@ -36,8 +36,9 @@ type Parameter struct { } type MessageContent struct { - Role string `json:"role"` - Content string `json:"content"` + Role string `json:"role"` + Content string `json:"content"` + ContentType string `json:"content_type,omitempty"` } type SparkMessage struct { @@ -124,13 +125,15 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use //发送消息 go func() { - data := genSparkParams(question, modelParam.APPID, modelParam.Domain, SessionID, modelParam.System) + data := genSparkParams(UserID, modelParam.APPID, modelParam.Domain, SessionID, modelParam.System) //将数据转换为json - //dataByte, err := json.Marshal(data) - if err != nil { + dataByte, err3 := json.Marshal(data) + if err3 != nil { fmt.Println("Error parsing JSON:", err) return } + log.Println("send message:", string(dataByte)) + return err2 := conn.WriteJSON(data) if err != nil { fmt.Println("write message error:", err2) @@ -196,13 +199,14 @@ func SparkV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { //发送消息 go func() { - data := genSparkParams(imCtx.Question, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System) + data := genSparkParams(imCtx.UserID, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System) //将数据转换为json - //dataByte, err := json.Marshal(data) - if err != nil { + dataByte, err3 := json.Marshal(data) + if err3 != nil { fmt.Println("Error parsing JSON:", err) return } + log.Println("spark send message:", string(dataByte)) err2 := conn.WriteJSON(data) if err != nil { fmt.Println("write message error:", err2) @@ -286,12 +290,13 @@ func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, S } // 生成参数 -func genSparkParams(question, appid string, domain string, sessionID int, systemPrompt string) SparkRequest { +func genSparkParams(userID int, appid string, domain string, sessionID int, systemPrompt string) SparkRequest { //messages := []MessageContent{ // {Role: "user", Content: question}, //} messages2 := make([]MessageContent, 0) - err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question) + //err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question) + err := GetSparkSessionHistoryMsgV2(sessionID, systemPrompt, &messages2, userID) if err != nil { log.Println("get spark session history message error:", err) } @@ -402,3 +407,93 @@ func GetSparkSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]M //*messages = append(*messages, message) return nil } + +// 获取spark模型会话的历史消息 +func GetSparkSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[]MessageContent, userID int) error { + //获取会话的历史消息 + msgs := dao.FindMessageBySessionIDV2(sessionID) + if len(msgs) > proto.SparkContextLength { + //只取最后的 SparkContextLength 条消息 + msgs = msgs[len(msgs)-proto.SparkContextLength-1:] + } + //添加系统消息 + if systemPrompt != "" { + var message MessageContent + message.Role = "system" + message.Content = systemPrompt + *messages = append(*messages, message) + } + for _, msg := range msgs { + var message MessageContent + if msg.Type == proto.UserToModelMsgType { + message.Role = proto.SparkRoleUser + if msg.Status == proto.UserToModelFileMsgType { + err2 := decodeFileMessageToSparkMessage(msg.Msg, userID, messages) + if err2 != nil { + log.Println("spark decode file message error:", err2) + } + } else { + message.Content = msg.Msg + *messages = append(*messages, message) + } + } else if msg.Type == proto.ModelToUserMsgType { + message.Role = proto.SparkRoleAssistant + message.Content = msg.Msg + *messages = append(*messages, message) + } else { + continue + } + } + //添加本次请求消息(上面获取到的已包含本次请求消息) + //var message MessageContent + //message.Role = "user" + //message.Content = prompt + //*messages = append(*messages, message) + return nil +} + +func decodeFileMessageToSparkMessage(msg string, userID int, message *[]MessageContent) error { + var fileMsgContent proto.UserFileMessage + err := json.Unmarshal([]byte(msg), &fileMsgContent) + if err != nil { + log.Println("spark unmarshal file message error:", err) + return err + } + log.Println("sparkv2,decode file message:", fileMsgContent) + for _, file := range fileMsgContent.FileContent { + //文件消息为文本类型,则获取文本内容,若为图片则添加图片链接 + if file.FileType == proto.FileTypeText { + fileContents, err2 := FindFileContentV2(file.FileContent.FileID, userID) + if err2 != nil { + log.Println("doubao find file content error:", err2, "\t,file id:", file.FileContent.FileID, "\t,file name:", file.FileContent.UserFileName) + continue + } + fileContent := fileContents[0].FileContent + //文件文本内容 + var messageContent MessageContent + messageContent.Role = "user" + messageContent.Content = fileContent + messageContent.ContentType = "text" + *message = append(*message, messageContent) + } else if file.FileType == proto.FileTypeImage { + imgBase64, err2 := FindImageFileContent(file.FileContent.FileID, userID) + if err2 != nil { + continue + } + //图片文件内容 + var messageContent MessageContent + messageContent.Role = "user" + messageContent.Content = imgBase64 + messageContent.ContentType = "image" + *message = append(*message, messageContent) + } else { + log.Println("file type is not support:", file.FileType, "file name:", file.FileContent.UserFileName, "file id:", file.FileContent.FileID) + } + } + var messageContent MessageContent + messageContent.Role = "user" + messageContent.Content = fileMsgContent.Text + messageContent.ContentType = "text" + *message = append(*message, messageContent) + return nil +}