From bc6c996b2f07042edb6e1f08709631d980ec7356 Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Thu, 3 Apr 2025 16:13:22 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=88=9B=E5=BB=BAppt?= =?UTF-8?q?=E6=95=B4=E4=B8=AA=E8=BF=87=E7=A8=8B=EF=BC=8C=E4=BF=9D=E5=AD=98?= =?UTF-8?q?=E5=88=9B=E5=BB=BA=E5=A4=A7=E7=BA=B2=E3=80=81=E5=88=9B=E5=BB=BA?= =?UTF-8?q?ppt=E5=8F=8A=E4=BF=9D=E5=AD=98=EF=BC=8C=E5=8F=8A=E6=95=B4?= =?UTF-8?q?=E4=B8=AA=E4=BC=9A=E8=AF=9D=E8=BF=87=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dao/im.go | 6 +- handler/im.go | 6 +- handler/session.go | 2 +- handler/tool.go | 16 +++- proto/spark.go | 25 +++++ service/fileService.go | 4 +- service/imService.go | 4 +- service/spark/spark-ppt.go | 183 +++++++++++++++++++++++++++++++++++-- 8 files changed, 226 insertions(+), 20 deletions(-) diff --git a/dao/im.go b/dao/im.go index 06806d7..bec4973 100644 --- a/dao/im.go +++ b/dao/im.go @@ -24,11 +24,11 @@ type Message struct { ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断 Msg string `gorm:"column:msg"` //消息内容 FunctionID int `gorm:"column:function_id"` //功能id - Status int `gorm:"column:status"` //0为未读,1为已读,3为图片消息,4为文件消息 + Status int `gorm:"column:status"` //0为未读,1为已读,3为图片消息,4为文件消息,5为制作ppt消息 } -func CreateSession(userID int, name string) (error, uint) { - session := Session{UserID: userID, Name: name} +func CreateSession(userID, Type int, name string) (error, uint) { + session := Session{UserID: userID, Name: name, Type: Type} var res *gorm.DB if proto.Config.SERVER_SQL_LOG { res = DB.Debug().Create(&session) diff --git a/handler/im.go b/handler/im.go index 3aec346..2a3d08c 100644 --- a/handler/im.go +++ b/handler/im.go @@ -224,7 +224,11 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR if sessionName == "" { sessionName = "新会话" } - err, *sessionID = service.CreateSession(userId, sessionName) + sessionType := proto.SessionTypeUserWithModelGeneration //默认通用聊天 + if data.Function == "gen-prompt" { + sessionType = proto.SessionTypeUserPrompt //提示词会话 + } + err, *sessionID = service.CreateSession(userId, sessionType, sessionName) log.Println("create session id:", *sessionID) if err != nil { log.Println("Create session error:", err) diff --git a/handler/session.go b/handler/session.go index f070a92..a83ab59 100644 --- a/handler/session.go +++ b/handler/session.go @@ -28,7 +28,7 @@ func CreateSession(c *gin.Context) { var req Session if err := c.ShouldBind(&req); err == nil { // 创建会话 - err2, mid := service.CreateSession(int(uint(userID)), req.Name) + err2, mid := service.CreateSession(int(uint(userID)), proto.SessionTypeUserWithModelGeneration, req.Name) if err2 == nil { c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid}) } else { diff --git a/handler/tool.go b/handler/tool.go index 78b65ae..a83552e 100644 --- a/handler/tool.go +++ b/handler/tool.go @@ -474,12 +474,14 @@ func CreateSparkPPTSOutline(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"error": "file name is empty", "code": proto.ParameterError, "message": "failed"}) return } - outlineResp, err2 := spark.SparkDoCreateOutline(&req, userID) + var base proto.SparkCreatePPTBaseInfo + base.UserID = userID + outlineResp, err2 := spark.SparkDoCreateOutline(&req, &base) if err2 != nil { c.JSON(http.StatusOK, gin.H{"error": "create outline error", "code": proto.ParameterError, "message": "failed"}) return } - c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": outlineResp}) + c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": outlineResp, "base": base}) } else { c.JSON(http.StatusOK, gin.H{"error": "parameter error", "code": proto.ParameterError, "message": "failed"}) return @@ -495,7 +497,15 @@ func CreateSparkPPT(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"error": "function or query is empty", "code": proto.ParameterError, "message": "failed"}) return } - spark.SparkDoCreatePPTByOutline(&req, userID) + base := proto.SparkCreatePPTBaseInfo{UserID: userID} + pptresp, err2 := spark.SparkDoCreatePPTByOutline(&req, &base) + if err2 != nil { + c.JSON(http.StatusOK, gin.H{"error": "create ppt error", "code": proto.ParameterError, "message": "failed"}) + return + } else { + c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": pptresp}) + return + } } else { c.JSON(http.StatusOK, gin.H{"error": "parameter error", "code": proto.ParameterError, "message": "failed"}) return diff --git a/proto/spark.go b/proto/spark.go index 40d76cf..3def9bb 100644 --- a/proto/spark.go +++ b/proto/spark.go @@ -155,3 +155,28 @@ type SparkCreatePPTByOutlineData struct { TotalPages int `json:"totalPages"` DonePages int `json:"donePages"` } + +type SparkCreatePPTBaseInfo struct { + UserID int `json:"user_id"` //用户id + SessionID uint `json:"session_id"` //会话id + FunctionID int `json:"function_id"` //功能id + ModelID int `json:"model_id"` //模型id +} + +// SparkCreatePPTResponse 定义对应 JSON 数据的外层结构体 +type SparkCreatePPTResponse struct { + Flag bool `json:"flag"` + Code int `json:"code"` + Desc string `json:"desc"` + Count interface{} `json:"count"` + Data SparkCreatePPTData `json:"data"` +} + +// SparkCreatePPTData 定义 JSON 中 data 字段对应的结构体 +type SparkCreatePPTData struct { + SID string `json:"sid"` + CoverImgSrc string `json:"coverImgSrc"` + Title string `json:"title"` + SubTitle string `json:"subTitle"` + Outline interface{} `json:"outline"` +} diff --git a/service/fileService.go b/service/fileService.go index ca2c7a1..000a85e 100644 --- a/service/fileService.go +++ b/service/fileService.go @@ -19,7 +19,7 @@ import ( ) // 检查path是否存在当前日期文件夹如(2024-08-09),不存在则path下当前日期文件夹创建,存在则返回 -func getFilePath(path string) string { +func GetFilePath(path string) string { //当前日期,格式为2024-08-09 date := time.Now().Format("2006-01-02") //拼接文件路径 @@ -39,7 +39,7 @@ func SaveFile(c *gin.Context, file *multipart.FileHeader, uploadType string) (st //生成文件名 fileStoreName := uuid.NewString() + fileSuffix //生成文件路径 - path_ := getFilePath(proto.FILE_BASE_DIR) + path_ := GetFilePath(proto.FILE_BASE_DIR) filePath := path_ + "/" + fileStoreName //保存文件 if err := c.SaveUploadedFile(file, filePath); err != nil { diff --git a/service/imService.go b/service/imService.go index 1ecb889..a22a1f3 100644 --- a/service/imService.go +++ b/service/imService.go @@ -13,8 +13,8 @@ import ( ) // 创建会话服务 -func CreateSession(userID int, name string) (error, uint) { - err, id := dao.CreateSession(userID, name) +func CreateSession(userID, Type int, name string) (error, uint) { + err, id := dao.CreateSession(userID, Type, name) return err, id } diff --git a/service/spark/spark-ppt.go b/service/spark/spark-ppt.go index 2d11d37..0f1677e 100644 --- a/service/spark/spark-ppt.go +++ b/service/spark/spark-ppt.go @@ -1,12 +1,19 @@ package spark import ( + "StuAcaWorksAI/dao" "StuAcaWorksAI/proto" "StuAcaWorksAI/service" "StuAcaWorksAI/worker" "encoding/json" + "errors" + "fmt" + "io" "log" + "net/http" + "os" "strconv" + "strings" "time" ) @@ -132,12 +139,12 @@ func SparkCreateOutline(model *proto.ModelParam, userReq *proto.SparkCreateOutli return response, err } -func SparkCreatePPTByOutline(model *proto.ModelParam, userReq *proto.SparkCreatePPTByOutlineUserRequest) (proto.SparkCreatePPTByOutlineResponse, error) { +func SparkCreatePPTByOutline(model *proto.ModelParam, userReq *proto.SparkCreatePPTByOutlineUserRequest) (proto.SparkCreatePPTResponse, error) { url := "https://zwapi.xfyun.cn/api/ppt/v2/createPptByOutline" headers, err := getSparkRequestHeaders(model.APPID, model.APISecret) if err != nil { log.Println("Spark create ppt by outline Error encoding request:", err) - return proto.SparkCreatePPTByOutlineResponse{}, err + return proto.SparkCreatePPTResponse{}, err } log.Println("Spark create ppt by outline headers:", headers) @@ -156,7 +163,7 @@ func SparkCreatePPTByOutline(model *proto.ModelParam, userReq *proto.SparkCreate err, resp := worker.DoPostRequestJSON(url, reqStr, headers) log.Println("Spark create ppt by outline Response:", string(resp)) - var response proto.SparkCreatePPTByOutlineResponse + var response proto.SparkCreatePPTResponse if err != nil { log.Println("Spark create ppt by outline Error:", err) } else { @@ -194,35 +201,195 @@ func SparkGetPPTInfoBySID(model *proto.ModelParam, userReq *proto.SparkCreatePPT } // 处理制作大纲请求 -func SparkDoCreateOutline(userReq *proto.SparkCreateOutlineRequest, userID int) (proto.SparkCreateOutlineResponse, error) { +func SparkDoCreateOutline(userReq *proto.SparkCreateOutlineRequest, base *proto.SparkCreatePPTBaseInfo) (proto.SparkCreateOutlineResponse, error) { var modelParam proto.ModelParam var err error + //需要保存会话 + sessionName := userReq.Query[:10] + err2, sessionID := service.CreateSession(base.UserID, proto.SessionTypeUserCreatePPT, sessionName) + if err2 != nil { + log.Println("Create session error:", err2) + return proto.SparkCreateOutlineResponse{}, err2 + } + base.SessionID = sessionID //获取模型 models, funcs, mferr := service.FindFuncModelListByFunctionV2(userReq.Function) if mferr != nil { return proto.SparkCreateOutlineResponse{}, mferr } model := models[0] + base.ModelID = int(model.ID) err = json.Unmarshal([]byte(model.Parameter), &modelParam) if err != nil { return proto.SparkCreateOutlineResponse{}, err } modelParam.Url = model.Url modelParam.System = funcs[0].Info //系统功能 - var outline proto.SparkCreateOutlineResponse + //保存消息 + userReqStr, err := json.Marshal(userReq) + if err != nil { + log.Println("Error parsing JSON:", err) + return proto.SparkCreateOutlineResponse{}, err + } + err3, userMsgID := service.CreateMessage(proto.UserToModelMsgType, int(sessionID), base.UserID, int(model.ID), string(userReqStr), proto.UserToModelPPTMsgType, int(funcs[0].ID)) + if err3 != nil { + log.Println("Create message error:", err3, userMsgID) + return proto.SparkCreateOutlineResponse{}, err3 + } - //获取主题列表 + var outline proto.SparkCreateOutlineResponse + //获取大纲 if userReq.FileUrl == "" { outline, err = SparkCreateOutline(&modelParam, userReq) //根据query生成大纲 } else { outline, err = SparkCreateOutlineByDoc(&modelParam, userReq) //根据文档生成大纲 } //进行保存 - + if err != nil { + log.Println("Spark create outline Error:", err) + return outline, err + } + outlineStr, _ := json.Marshal(userReq) + //返回的数据保存 + service.CreateAIStreamMsg(base.UserID, base.ModelID, int(base.SessionID), outline, modelParam.Model) + //保存大纲 + err3, userMsgID = service.CreateMessage(proto.ModelToUserMsgType, int(sessionID), base.UserID, int(model.ID), string(outlineStr), proto.UserToModelPPTMsgType, int(funcs[0].ID)) + if err3 != nil { + log.Println("Create message error:", err3, userMsgID) + return proto.SparkCreateOutlineResponse{}, err3 + } return outline, err } // 处理制作PPT请求 -func SparkDoCreatePPTByOutline(userReq *proto.SparkCreatePPTByOutlineUserRequest, userID int) { +func SparkDoCreatePPTByOutline(userReq *proto.SparkCreatePPTByOutlineUserRequest, base *proto.SparkCreatePPTBaseInfo) (proto.SparkCreatePPTResponse, error) { + var modelParam proto.ModelParam + var err error + //需要保存会话 + if base.SessionID == 0 { + log.Println("session id is 0, create ppt") + return proto.SparkCreatePPTResponse{}, errors.New("create ppt,session id is 0, please create session first") + } + //获取模型 + models, funcs, mferr := service.FindFuncModelListByFunctionV2(userReq.Function) + if mferr != nil { + return proto.SparkCreatePPTResponse{}, mferr + } + model := models[0] + base.ModelID = int(model.ID) + err = json.Unmarshal([]byte(model.Parameter), &modelParam) + if err != nil { + return proto.SparkCreatePPTResponse{}, err + } + modelParam.Url = model.Url + modelParam.System = funcs[0].Info //系统功能 + //保存消息 + userReqStr, err := json.Marshal(userReq) + if err != nil { + log.Println("Error parsing JSON:", err) + return proto.SparkCreatePPTResponse{}, err + } + err3, userMsgID := service.CreateMessage(proto.UserToModelMsgType, int(base.SessionID), base.UserID, int(model.ID), string(userReqStr), proto.UserToModelPPTMsgType, int(funcs[0].ID)) + if err3 != nil { + log.Println("Create message error:", err3, userMsgID) + return proto.SparkCreatePPTResponse{}, err3 + } + resp, err4 := SparkCreatePPTByOutline(&modelParam, userReq) //根据大纲生成ppt + if err4 != nil { + log.Println("Spark create ppt by outline Error:", err4) + return proto.SparkCreatePPTResponse{}, err4 + } + //进行保存 + service.CreateAIStreamMsg(base.UserID, base.ModelID, int(base.SessionID), resp, modelParam.Model) //保存消息 + //新协程获取ppt制作进度 + userReq.Sid = resp.Data.SID + go SparkDoGetPPTInfoBySID(&modelParam, userReq, base) + return resp, err4 +} + +// 每3秒查询一次ppt制作情况 +func SparkDoGetPPTInfoBySID(modelParam *proto.ModelParam, userReq *proto.SparkCreatePPTByOutlineUserRequest, base *proto.SparkCreatePPTBaseInfo) { + //定时 + ticker := time.NewTicker(3 * time.Second) + defer ticker.Stop() + i := 0 + for { + select { + case <-ticker.C: + i += 1 + if i > 20 { + log.Println("SparkDoGetPPTInfoBySID timeout more 1 min, stop ticker") + return + } + log.Println("SparkDoGetPPTInfoBySID ticker") + resp, err := SparkGetPPTInfoBySID(modelParam, userReq) //根据大纲生成ppt + if err != nil { + log.Println("Spark create ppt by outline Error:", err) + return + } + service.CreateAIStreamMsg(base.UserID, base.ModelID, int(base.SessionID), resp, modelParam.Model) //保存消息 + if resp.Data.PptStatus == "done" { + log.Println("ppt done") + //下载文件,保存到指定用户的文件 + fileName, filePath, err2 := SparkSavePPTFileToUser(resp.Data.PptUrl) + if err2 != nil { + log.Println("Spark save ppt file to user Error:", err2) + return + } + //保存文件路径到数据库 + fileRecord := dao.CreateFile(fileName, "", "", filePath, "", 0, base.UserID, false) + fileAuth := service.CreateUserFile(base.UserID, fileRecord.FileName, int(fileRecord.ID), proto.UserFileTypeFile) + if fileRecord.ID == 0 || fileAuth.ID == 0 { + log.Println("save file info failed", fileRecord.ID, fileAuth.ID) + } + return + } + } + } +} + +// 将下载的文件保存到对应指定用户的文件,并对该会话添加ppt完成消息 +func SparkSavePPTFileToUser(fileUrl string) (string, string, error) { + log.Println("Spark save ppt file to user") + //https://bjcdn.openstorage.cn/xinghuo-privatedata/zhiwen/2025-04-02/6ebfc0b2-824a-4acc-a030-01acce938dab/3209de1bca814ac48884e015277360fd.pptx + fileName := fileUrl[strings.LastIndex(fileUrl, "/")+1:] + //文件保存路径 + path := service.GetFilePath(proto.FILE_BASE_DIR) + saveFile := path + "/" + fileName + //下载文件 + err := DownloadFile(fileUrl, saveFile) + if err != nil { + log.Println("Download file error:", err) + return "", "", err + } + log.Println("Download file success:", fileName) + return fileName, path, nil +} + +func DownloadFile(url, filePath string) error { + // 创建 HTTP GET 请求 + resp, err := http.Get(url) + if err != nil { + return err + } + // 确保在函数结束时关闭响应体 + defer resp.Body.Close() + + // 检查 HTTP 响应状态码 + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + // 创建本地文件 + out, err := os.Create(filePath) + if err != nil { + return err + } + // 确保在函数结束时关闭文件 + defer out.Close() + + // 将响应体内容复制到本地文件 + _, err = io.Copy(out, resp.Body) + return err }