完成spark读取图片及文件内容

This commit is contained in:
junleea 2025-04-07 19:09:58 +08:00
parent 7064c156a8
commit f10384a4a5
2 changed files with 145 additions and 10 deletions

View File

@ -5,6 +5,7 @@ import (
"StuAcaWorksAI/proto"
"StuAcaWorksAI/worker"
"crypto/md5"
"encoding/base64"
"errors"
"fmt"
"github.com/gin-gonic/gin"
@ -352,3 +353,42 @@ func FindFileContentV2(fileID, userID int) ([]dao.FileContent, error) {
}
return fileContents, nil
}
func FindImageFileContent(fileID, userID int) (string, error) {
//获取文件基础信息
file := dao.FindFileByID(fileID, userID)
//判断文件类型\
fileType := strings.Split(file.FileStoreName, ".")[1]
//如果文件类型是图片
if fileType != "jpg" && fileType != "png" && fileType != "jpeg" && fileType != "gif" {
return "", errors.New("file type is not image")
}
//直接读取文件内容
filePath := file.FilePath + "/" + file.FileStoreName
base64Str, err := readImageFile(filePath)
if err != nil {
return "", err
}
return base64Str, nil
}
// readImageFile 函数用于读取指定路径的图片文件并将其内容转换为 Base64 编码字符串
func readImageFile(filePath string) (string, error) {
// 打开图片文件
file, err := os.Open(filePath)
if err != nil {
return "", fmt.Errorf("无法打开文件: %w", err)
}
// 确保文件在函数结束时关闭
defer file.Close()
// 读取图片文件内容
imageData, err := io.ReadAll(file)
if err != nil {
return "", fmt.Errorf("无法读取文件内容: %w", err)
}
// 将图片数据进行 Base64 编码
base64Encoded := base64.StdEncoding.EncodeToString(imageData)
return base64Encoded, nil
}

View File

@ -36,8 +36,9 @@ type Parameter struct {
}
type MessageContent struct {
Role string `json:"role"`
Content string `json:"content"`
Role string `json:"role"`
Content string `json:"content"`
ContentType string `json:"content_type,omitempty"`
}
type SparkMessage struct {
@ -124,13 +125,15 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use
//发送消息
go func() {
data := genSparkParams(question, modelParam.APPID, modelParam.Domain, SessionID, modelParam.System)
data := genSparkParams(UserID, modelParam.APPID, modelParam.Domain, SessionID, modelParam.System)
//将数据转换为json
//dataByte, err := json.Marshal(data)
if err != nil {
dataByte, err3 := json.Marshal(data)
if err3 != nil {
fmt.Println("Error parsing JSON:", err)
return
}
log.Println("send message:", string(dataByte))
return
err2 := conn.WriteJSON(data)
if err != nil {
fmt.Println("write message error:", err2)
@ -196,13 +199,14 @@ func SparkV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
//发送消息
go func() {
data := genSparkParams(imCtx.Question, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System)
data := genSparkParams(imCtx.UserID, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System)
//将数据转换为json
//dataByte, err := json.Marshal(data)
if err != nil {
dataByte, err3 := json.Marshal(data)
if err3 != nil {
fmt.Println("Error parsing JSON:", err)
return
}
log.Println("spark send message:", string(dataByte))
err2 := conn.WriteJSON(data)
if err != nil {
fmt.Println("write message error:", err2)
@ -286,12 +290,13 @@ func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, S
}
// 生成参数
func genSparkParams(question, appid string, domain string, sessionID int, systemPrompt string) SparkRequest {
func genSparkParams(userID int, appid string, domain string, sessionID int, systemPrompt string) SparkRequest {
//messages := []MessageContent{
// {Role: "user", Content: question},
//}
messages2 := make([]MessageContent, 0)
err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question)
//err := GetSparkSessionHistoryMsg(sessionID, systemPrompt, &messages2, question)
err := GetSparkSessionHistoryMsgV2(sessionID, systemPrompt, &messages2, userID)
if err != nil {
log.Println("get spark session history message error:", err)
}
@ -402,3 +407,93 @@ func GetSparkSessionHistoryMsg(sessionID int, systemPrompt string, messages *[]M
//*messages = append(*messages, message)
return nil
}
// 获取spark模型会话的历史消息
func GetSparkSessionHistoryMsgV2(sessionID int, systemPrompt string, messages *[]MessageContent, userID int) error {
//获取会话的历史消息
msgs := dao.FindMessageBySessionIDV2(sessionID)
if len(msgs) > proto.SparkContextLength {
//只取最后的 SparkContextLength 条消息
msgs = msgs[len(msgs)-proto.SparkContextLength-1:]
}
//添加系统消息
if systemPrompt != "" {
var message MessageContent
message.Role = "system"
message.Content = systemPrompt
*messages = append(*messages, message)
}
for _, msg := range msgs {
var message MessageContent
if msg.Type == proto.UserToModelMsgType {
message.Role = proto.SparkRoleUser
if msg.Status == proto.UserToModelFileMsgType {
err2 := decodeFileMessageToSparkMessage(msg.Msg, userID, messages)
if err2 != nil {
log.Println("spark decode file message error:", err2)
}
} else {
message.Content = msg.Msg
*messages = append(*messages, message)
}
} else if msg.Type == proto.ModelToUserMsgType {
message.Role = proto.SparkRoleAssistant
message.Content = msg.Msg
*messages = append(*messages, message)
} else {
continue
}
}
//添加本次请求消息(上面获取到的已包含本次请求消息)
//var message MessageContent
//message.Role = "user"
//message.Content = prompt
//*messages = append(*messages, message)
return nil
}
func decodeFileMessageToSparkMessage(msg string, userID int, message *[]MessageContent) error {
var fileMsgContent proto.UserFileMessage
err := json.Unmarshal([]byte(msg), &fileMsgContent)
if err != nil {
log.Println("spark unmarshal file message error:", err)
return err
}
log.Println("sparkv2,decode file message:", fileMsgContent)
for _, file := range fileMsgContent.FileContent {
//文件消息为文本类型,则获取文本内容,若为图片则添加图片链接
if file.FileType == proto.FileTypeText {
fileContents, err2 := FindFileContentV2(file.FileContent.FileID, userID)
if err2 != nil {
log.Println("doubao find file content error:", err2, "\t,file id:", file.FileContent.FileID, "\t,file name:", file.FileContent.UserFileName)
continue
}
fileContent := fileContents[0].FileContent
//文件文本内容
var messageContent MessageContent
messageContent.Role = "user"
messageContent.Content = fileContent
messageContent.ContentType = "text"
*message = append(*message, messageContent)
} else if file.FileType == proto.FileTypeImage {
imgBase64, err2 := FindImageFileContent(file.FileContent.FileID, userID)
if err2 != nil {
continue
}
//图片文件内容
var messageContent MessageContent
messageContent.Role = "user"
messageContent.Content = imgBase64
messageContent.ContentType = "image"
*message = append(*message, messageContent)
} else {
log.Println("file type is not support:", file.FileType, "file name:", file.FileContent.UserFileName, "file id:", file.FileContent.FileID)
}
}
var messageContent MessageContent
messageContent.Role = "user"
messageContent.Content = fileMsgContent.Text
messageContent.ContentType = "text"
*message = append(*message, messageContent)
return nil
}