From e8a268ca229a472d1be9b623e9e52fa11bca16e8 Mon Sep 17 00:00:00 2001 From: junleea <354425203@qq.com> Date: Tue, 25 Mar 2025 14:34:50 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0ai=E5=9B=9E=E5=A4=8D=E6=B5=81?= =?UTF-8?q?=E5=AD=98=E5=82=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dao/db.go | 6 ++++++ dao/msg.go | 39 ++++++++++++++++++++++++++++++++++++++ service/AIStreamService.go | 18 ++++++++++++++++++ service/imService.go | 5 +++-- worker/spark.go | 2 ++ 5 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 dao/msg.go create mode 100644 service/AIStreamService.go diff --git a/dao/db.go b/dao/db.go index 742dfa2..9ec40ea 100644 --- a/dao/db.go +++ b/dao/db.go @@ -77,6 +77,12 @@ func Init() error { fmt.Println("message table:", err) return err } + + err = db.AutoMigrate(&AIStreamMsg{}) + if err != nil { + fmt.Println("ai stream message table:", err) + return err + } DB = db return err } diff --git a/dao/msg.go b/dao/msg.go new file mode 100644 index 0000000..c2c0eee --- /dev/null +++ b/dao/msg.go @@ -0,0 +1,39 @@ +package dao + +import "gorm.io/gorm" + +type AIStreamMsg struct { + gorm.Model + Type string `gorm:"column:type"` + ModelID uint `gorm:"column:model_id"` + UserID int `gorm:"column:user_id"` //用户id + SessionID int `gorm:"column:session_id"` //会话id + MSG string `gorm:"column:msg"` //消息内容 +} + +func InsertAIStreamMsgToDB(userID, sessionID int, modelID uint, msg, tp string) uint { + aiStreamMsg := AIStreamMsg{UserID: userID, ModelID: modelID, MSG: msg, Type: tp, SessionID: sessionID} + DB.Create(&aiStreamMsg) + if aiStreamMsg.ID == 0 { + return 0 + } + return aiStreamMsg.ID +} + +func FindAIStreamMsgByUserID(userID int) []AIStreamMsg { + var aiStreamMsgs []AIStreamMsg + DB.Where("user_id = ?", userID).Find(&aiStreamMsgs) + return aiStreamMsgs +} + +func FindAIStreamMsgByModelID(modelID uint) []AIStreamMsg { + var aiStreamMsgs []AIStreamMsg + DB.Where("model_id = ?", modelID).Find(&aiStreamMsgs) + return aiStreamMsgs +} + +func FindAIStreamMsgByID(id uint) (error, AIStreamMsg) { + var aiStreamMsg AIStreamMsg + err := DB.Where("id = ?", id).First(&aiStreamMsg).Error + return err, aiStreamMsg +} diff --git a/service/AIStreamService.go b/service/AIStreamService.go new file mode 100644 index 0000000..58d2755 --- /dev/null +++ b/service/AIStreamService.go @@ -0,0 +1,18 @@ +package service + +import ( + "StuAcaWorksAI/dao" + "encoding/json" + "log" +) + +func CreateAIStreamMsg(userID, modelID, sessionID int, msg interface{}, t string) { + msgStr, err := json.Marshal(msg) + if err != nil { + log.Println("ai stream msg Error parsing JSON:", err) + } + id := dao.InsertAIStreamMsgToDB(userID, sessionID, uint(modelID), string(msgStr), t) + if id == 0 { + log.Println("ai stream msg insert error") + } +} diff --git a/service/imService.go b/service/imService.go index f3324df..4f1b6d2 100644 --- a/service/imService.go +++ b/service/imService.go @@ -100,7 +100,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS var resID uint //处理消息 if model.Type == "ollama" { - resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam) + resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model) } else if model.Type == "spark" { resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam) } else { @@ -109,7 +109,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS return resErr, resID } -func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam) (error, uint) { +func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, m *dao.Model) (error, uint) { var resErr error var resID uint var err error @@ -158,6 +158,7 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes data.ToID = userID aiMsgStr, _ := json.Marshal(data) worker.Publish(channel, string(aiMsgStr), time.Second*60) + go CreateAIStreamMsg(userID, int(m.ID), sessionID, resp, modelParam.Model) return nil } err = client.Generate(ctx, req, respFunc) diff --git a/worker/spark.go b/worker/spark.go index 594532c..4784a9c 100644 --- a/worker/spark.go +++ b/worker/spark.go @@ -3,6 +3,7 @@ package worker import ( "StuAcaWorksAI/dao" "StuAcaWorksAI/proto" + "StuAcaWorksAI/service" "crypto/hmac" "crypto/sha256" "encoding/base64" @@ -163,6 +164,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use } else { answer += data.Payload.Choices.Text[0].Content } + go service.CreateAIStreamMsg(UserID, 4, SessionID, data, modelParam.Model) SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID) }