diff --git a/dao/file.go b/dao/file.go index 1df5005..203f1aa 100644 --- a/dao/file.go +++ b/dao/file.go @@ -29,6 +29,13 @@ type FileAuth struct { ShareCode string `gorm:"column:share_code"` // 分享码,用于分享时的验证,构建分享链接 } +// 存储文本文件的内容,主要是pdf、docx等文件的内容,其它文本文件也可以存储 +type FileContent struct { + gorm.Model + FileID int `gorm:"column:file_id"` // 文件id + FileContent string `gorm:"column:file_content"` // 文件内容最大长度为16777215,16M +} + type ConfigFile struct { gorm.Model AuthID int `gorm:"column:auth_id"` @@ -294,3 +301,47 @@ func FindFileAuthByName(fileName string, userID int) FileAuth { db2.Where("user_file_name = ? and auth_id = ?", fileName, userID).First(&file) return file } + +func FindFileContentByFileID(id int) []FileContent { + var fileContent []FileContent + var db2 *gorm.DB + if proto.Config.SERVER_SQL_LOG { + db2 = DB.Debug() + } else { + db2 = DB + } + db2.Where("file_id = ?", id).Find(&fileContent) + return fileContent +} + +func CreateFileContent(fileID int, fileContent string) (uint, error) { + content := FileContent{FileID: fileID, FileContent: fileContent} + var result *gorm.DB + if proto.Config.SERVER_SQL_LOG { + result = DB.Debug().Create(&content) + } else { + result = DB.Create(&content) + } + return content.ID, result.Error +} + +func DeleteFileContentByID(id int) error { + var res *gorm.DB + if proto.Config.SERVER_SQL_LOG { + res = DB.Debug().Delete(&FileContent{}, id) + } else { + res = DB.Delete(&FileContent{}, id) + } + return res.Error +} + +func UpdateFileContentByID(id int, fileContent string) error { + var db2 *gorm.DB + if proto.Config.SERVER_SQL_LOG { + db2 = DB.Debug() + } else { + db2 = DB + } + res := db2.Model(&FileContent{}).Where("id = ?", id).Updates(FileContent{FileContent: fileContent}) + return res.Error +} diff --git a/handler/file.go b/handler/file.go index bd232d5..f243b43 100644 --- a/handler/file.go +++ b/handler/file.go @@ -20,6 +20,7 @@ func SetUpFileGroup(router *gin.Engine) { fileGroup.POST("/file_list", GetUserFileList) fileGroup.POST("/file_delete", DeleteUserFile) fileGroup.POST("/file_update", UpdateUserFile) + fileGroup.POST("/find_file_content", FindFileContent) } @@ -278,3 +279,25 @@ func UploadFileV2(c *gin.Context) { file_record.FilePath = "" c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": file_record, "file_auth": fileAuth}) } + +func FindFileContent(c *gin.Context) { + id, _ := c.Get("id") + userId := int(id.(float64)) + var req proto.FileContentReq + var resp proto.FileContentResp + if err := c.ShouldBind(&req); err == nil { + fileContent, err2 := service.FindFileContent(userId, &req) + if err2 != nil { + resp.Code = proto.ParameterError + resp.Message = "find file content failed:" + err2.Error() + } else { + resp.Code = proto.SuccessCode + resp.Message = "success" + resp.Data = fileContent + } + } else { + resp.Code = proto.ParameterError + resp.Message = "upload form parameter decode error:" + err.Error() + } + c.JSON(http.StatusOK, resp) +} diff --git a/main.go b/main.go index 0f3ec43..3612b2d 100644 --- a/main.go +++ b/main.go @@ -58,9 +58,16 @@ func init() { os.MkdirAll(proto.CID_BASE_DIR, os.ModePerm) os.MkdirAll(proto.CID_BASE_DIR+"script", os.ModePerm) os.MkdirAll(proto.CID_BASE_DIR+"workspace", os.ModePerm) - //读取配置文件 - //文件地址/home/saw-ai/saw-ai.conf - configPath := "/home/saw/saw-ai-go/saw-ai.conf" + //系统是linux、macos还是windows + var configPath string + if os.Getenv("OS") == "Windows_NT" { + configPath = "E:/Code/saw-ai/saw-ai.conf" + } else if os.Getenv("OS") == "linux" { + //文件地址/home/saw-ai/saw-ai.conf + configPath = "/home/saw/saw-ai-go/saw-ai.conf" + } else { + configPath = "/home/saw/saw-ai-go/saw-ai.conf" + } //读取配置文件 err := proto.ReadConfig(configPath) if err != nil { diff --git a/proto/tool.go b/proto/tool.go index cab0411..581f430 100644 --- a/proto/tool.go +++ b/proto/tool.go @@ -30,3 +30,13 @@ type MessageModelIDCountSt struct { ModelName string `json:"name"` Count int `json:"value"` } + +type FileContentReq struct { + UserFileID int `json:"user_file_id" form:"user_file_id"` // 用户文件ID +} + +type FileContentResp struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data"` +} diff --git a/service/fileService.go b/service/fileService.go index c8bc25e..feb9b34 100644 --- a/service/fileService.go +++ b/service/fileService.go @@ -278,3 +278,51 @@ func FindUserFileList(userID int, uploadType string) []dao.UserFileListResp { func CreateUserFileAfterUnique() { } + +func FindFileContent(userID int, userReq *proto.FileContentReq) ([]dao.FileContent, error) { + fileAuth := dao.FindFileAuthByID(userReq.UserFileID) + if fileAuth.ID == 0 { + return nil, errors.New("file auth not found") + } + if fileAuth.AuthID != userID { + return nil, errors.New("no permission") + } + //获取文件基础信息 + file := dao.FindFileByID(fileAuth.FileID, userID) + //判断文件类型\ + fileType := strings.Split(file.FileStoreName, ".")[1] + //如果文件类型是图片 + if fileType == "jpg" || fileType == "png" || fileType == "jpeg" || fileType == "gif" { + return nil, errors.New("file type is image") + } + //获取文件内容 + fileContents := dao.FindFileContentByFileID(userReq.UserFileID) + if len(fileContents) == 0 { + //直接读取文件内容 + filePath := file.FilePath + "/" + file.FileStoreName + fileContentStr, err := readFileContent(filePath) + if err != nil { + return nil, err + } + var fileContent dao.FileContent + fileContent.FileID = -1 + fileContent.FileContent = fileContentStr + fileContents = append(fileContents, fileContent) + } + return fileContents, nil +} + +func readFileContent(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", err + } + defer file.Close() + + content, err := io.ReadAll(file) + if err != nil { + return "", err + } + + return string(content), nil +} diff --git a/service/spark.go b/service/spark.go index fa5fd5c..c4e3f4f 100644 --- a/service/spark.go +++ b/service/spark.go @@ -306,7 +306,7 @@ func genSparkParams(question, appid string, domain string, sessionID int, system Chat: ChatParameter{ Domain: domain, Temperature: 0.8, - MaxTokens: 2048, + MaxTokens: 4096, }, }, Payload: SparkRequestPayload{ diff --git a/service/spark/spark-image.go b/service/spark/spark-image.go index 0883be2..16aa1a6 100644 --- a/service/spark/spark-image.go +++ b/service/spark/spark-image.go @@ -26,9 +26,8 @@ var ( func sparkImage() { d := websocket.Dialer{ - HandshakeTimeout: 5 * time.Second, + HandshakeTimeout: 20 * time.Second, } - //鎻℃墜骞跺缓绔媤ebsocket 杩炴帴 conn, resp, err := d.Dial(assembleAuthUrl(hostUrl, apiKey, apiSecret), nil) if err != nil { panic(readResp(resp) + err.Error())