添加ai回复流存储
This commit is contained in:
parent
4063c72166
commit
e8a268ca22
|
|
@ -77,6 +77,12 @@ func Init() error {
|
|||
fmt.Println("message table:", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = db.AutoMigrate(&AIStreamMsg{})
|
||||
if err != nil {
|
||||
fmt.Println("ai stream message table:", err)
|
||||
return err
|
||||
}
|
||||
DB = db
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
package dao
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type AIStreamMsg struct {
|
||||
gorm.Model
|
||||
Type string `gorm:"column:type"`
|
||||
ModelID uint `gorm:"column:model_id"`
|
||||
UserID int `gorm:"column:user_id"` //用户id
|
||||
SessionID int `gorm:"column:session_id"` //会话id
|
||||
MSG string `gorm:"column:msg"` //消息内容
|
||||
}
|
||||
|
||||
func InsertAIStreamMsgToDB(userID, sessionID int, modelID uint, msg, tp string) uint {
|
||||
aiStreamMsg := AIStreamMsg{UserID: userID, ModelID: modelID, MSG: msg, Type: tp, SessionID: sessionID}
|
||||
DB.Create(&aiStreamMsg)
|
||||
if aiStreamMsg.ID == 0 {
|
||||
return 0
|
||||
}
|
||||
return aiStreamMsg.ID
|
||||
}
|
||||
|
||||
func FindAIStreamMsgByUserID(userID int) []AIStreamMsg {
|
||||
var aiStreamMsgs []AIStreamMsg
|
||||
DB.Where("user_id = ?", userID).Find(&aiStreamMsgs)
|
||||
return aiStreamMsgs
|
||||
}
|
||||
|
||||
func FindAIStreamMsgByModelID(modelID uint) []AIStreamMsg {
|
||||
var aiStreamMsgs []AIStreamMsg
|
||||
DB.Where("model_id = ?", modelID).Find(&aiStreamMsgs)
|
||||
return aiStreamMsgs
|
||||
}
|
||||
|
||||
func FindAIStreamMsgByID(id uint) (error, AIStreamMsg) {
|
||||
var aiStreamMsg AIStreamMsg
|
||||
err := DB.Where("id = ?", id).First(&aiStreamMsg).Error
|
||||
return err, aiStreamMsg
|
||||
}
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"StuAcaWorksAI/dao"
|
||||
"encoding/json"
|
||||
"log"
|
||||
)
|
||||
|
||||
func CreateAIStreamMsg(userID, modelID, sessionID int, msg interface{}, t string) {
|
||||
msgStr, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
log.Println("ai stream msg Error parsing JSON:", err)
|
||||
}
|
||||
id := dao.InsertAIStreamMsgToDB(userID, sessionID, uint(modelID), string(msgStr), t)
|
||||
if id == 0 {
|
||||
log.Println("ai stream msg insert error")
|
||||
}
|
||||
}
|
||||
|
|
@ -100,7 +100,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
|
|||
var resID uint
|
||||
//处理消息
|
||||
if model.Type == "ollama" {
|
||||
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam)
|
||||
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model)
|
||||
} else if model.Type == "spark" {
|
||||
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam)
|
||||
} else {
|
||||
|
|
@ -109,7 +109,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
|
|||
return resErr, resID
|
||||
}
|
||||
|
||||
func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam) (error, uint) {
|
||||
func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, m *dao.Model) (error, uint) {
|
||||
var resErr error
|
||||
var resID uint
|
||||
var err error
|
||||
|
|
@ -158,6 +158,7 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes
|
|||
data.ToID = userID
|
||||
aiMsgStr, _ := json.Marshal(data)
|
||||
worker.Publish(channel, string(aiMsgStr), time.Second*60)
|
||||
go CreateAIStreamMsg(userID, int(m.ID), sessionID, resp, modelParam.Model)
|
||||
return nil
|
||||
}
|
||||
err = client.Generate(ctx, req, respFunc)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package worker
|
|||
import (
|
||||
"StuAcaWorksAI/dao"
|
||||
"StuAcaWorksAI/proto"
|
||||
"StuAcaWorksAI/service"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
|
|
@ -163,6 +164,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use
|
|||
} else {
|
||||
answer += data.Payload.Choices.Text[0].Content
|
||||
}
|
||||
go service.CreateAIStreamMsg(UserID, 4, SessionID, data, modelParam.Model)
|
||||
SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID)
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue