添加ai回复流存储
This commit is contained in:
parent
4063c72166
commit
e8a268ca22
|
|
@ -77,6 +77,12 @@ func Init() error {
|
||||||
fmt.Println("message table:", err)
|
fmt.Println("message table:", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = db.AutoMigrate(&AIStreamMsg{})
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("ai stream message table:", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
DB = db
|
DB = db
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,39 @@
|
||||||
|
package dao
|
||||||
|
|
||||||
|
import "gorm.io/gorm"
|
||||||
|
|
||||||
|
type AIStreamMsg struct {
|
||||||
|
gorm.Model
|
||||||
|
Type string `gorm:"column:type"`
|
||||||
|
ModelID uint `gorm:"column:model_id"`
|
||||||
|
UserID int `gorm:"column:user_id"` //用户id
|
||||||
|
SessionID int `gorm:"column:session_id"` //会话id
|
||||||
|
MSG string `gorm:"column:msg"` //消息内容
|
||||||
|
}
|
||||||
|
|
||||||
|
func InsertAIStreamMsgToDB(userID, sessionID int, modelID uint, msg, tp string) uint {
|
||||||
|
aiStreamMsg := AIStreamMsg{UserID: userID, ModelID: modelID, MSG: msg, Type: tp, SessionID: sessionID}
|
||||||
|
DB.Create(&aiStreamMsg)
|
||||||
|
if aiStreamMsg.ID == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return aiStreamMsg.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindAIStreamMsgByUserID(userID int) []AIStreamMsg {
|
||||||
|
var aiStreamMsgs []AIStreamMsg
|
||||||
|
DB.Where("user_id = ?", userID).Find(&aiStreamMsgs)
|
||||||
|
return aiStreamMsgs
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindAIStreamMsgByModelID(modelID uint) []AIStreamMsg {
|
||||||
|
var aiStreamMsgs []AIStreamMsg
|
||||||
|
DB.Where("model_id = ?", modelID).Find(&aiStreamMsgs)
|
||||||
|
return aiStreamMsgs
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindAIStreamMsgByID(id uint) (error, AIStreamMsg) {
|
||||||
|
var aiStreamMsg AIStreamMsg
|
||||||
|
err := DB.Where("id = ?", id).First(&aiStreamMsg).Error
|
||||||
|
return err, aiStreamMsg
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,18 @@
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"StuAcaWorksAI/dao"
|
||||||
|
"encoding/json"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func CreateAIStreamMsg(userID, modelID, sessionID int, msg interface{}, t string) {
|
||||||
|
msgStr, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("ai stream msg Error parsing JSON:", err)
|
||||||
|
}
|
||||||
|
id := dao.InsertAIStreamMsgToDB(userID, sessionID, uint(modelID), string(msgStr), t)
|
||||||
|
if id == 0 {
|
||||||
|
log.Println("ai stream msg insert error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -100,7 +100,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
|
||||||
var resID uint
|
var resID uint
|
||||||
//处理消息
|
//处理消息
|
||||||
if model.Type == "ollama" {
|
if model.Type == "ollama" {
|
||||||
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam)
|
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model)
|
||||||
} else if model.Type == "spark" {
|
} else if model.Type == "spark" {
|
||||||
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam)
|
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam)
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -109,7 +109,7 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
|
||||||
return resErr, resID
|
return resErr, resID
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam) (error, uint) {
|
func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, m *dao.Model) (error, uint) {
|
||||||
var resErr error
|
var resErr error
|
||||||
var resID uint
|
var resID uint
|
||||||
var err error
|
var err error
|
||||||
|
|
@ -158,6 +158,7 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes
|
||||||
data.ToID = userID
|
data.ToID = userID
|
||||||
aiMsgStr, _ := json.Marshal(data)
|
aiMsgStr, _ := json.Marshal(data)
|
||||||
worker.Publish(channel, string(aiMsgStr), time.Second*60)
|
worker.Publish(channel, string(aiMsgStr), time.Second*60)
|
||||||
|
go CreateAIStreamMsg(userID, int(m.ID), sessionID, resp, modelParam.Model)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
err = client.Generate(ctx, req, respFunc)
|
err = client.Generate(ctx, req, respFunc)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package worker
|
||||||
import (
|
import (
|
||||||
"StuAcaWorksAI/dao"
|
"StuAcaWorksAI/dao"
|
||||||
"StuAcaWorksAI/proto"
|
"StuAcaWorksAI/proto"
|
||||||
|
"StuAcaWorksAI/service"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
|
@ -163,6 +164,7 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use
|
||||||
} else {
|
} else {
|
||||||
answer += data.Payload.Choices.Text[0].Content
|
answer += data.Payload.Choices.Text[0].Content
|
||||||
}
|
}
|
||||||
|
go service.CreateAIStreamMsg(UserID, 4, SessionID, data, modelParam.Model)
|
||||||
SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID)
|
SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue