添加消息功能列,传参使用结构体来避免增加参数时修改过多

This commit is contained in:
junleea 2025-03-28 11:23:45 +08:00
parent 5c57689700
commit cd00f65ebe
6 changed files with 241 additions and 20 deletions

View File

@ -16,12 +16,13 @@ type Session struct {
type Message struct {
gorm.Model
Type int `gorm:"column:type"` //1用户之间的消息2为与模型消息
SessionID int `gorm:"column:session_id;index"` //添加索引
FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型
ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型则为模型id根据type判断
Msg string `gorm:"column:msg"` //消息内容
Status int `gorm:"column:status"` //0为未读1为已读
Type int `gorm:"column:type"` //1用户之间的消息2为与模型消息
SessionID int `gorm:"column:session_id;index"` //添加索引
FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型
ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型则为模型id根据type判断
Msg string `gorm:"column:msg"` //消息内容
FunctionID int `gorm:"column:function_id"` //功能id
Status int `gorm:"column:status"` //0为未读1为已读
}
func CreateSession(userID int, name string) (error, uint) {
@ -84,8 +85,8 @@ func DeleteSessionByID(id int) error {
}
// 创建消息
func CreateMessage(sessionID, fromID, toID, msgType, status int, msg string) (error, uint) {
message := Message{SessionID: sessionID, FromID: fromID, ToID: toID, Type: msgType, Status: status, Msg: msg}
func CreateMessage(sessionID, fromID, toID, msgType, status int, msg string, function_id int) (error, uint) {
message := Message{SessionID: sessionID, FromID: fromID, ToID: toID, Type: msgType, Status: status, Msg: msg, FunctionID: function_id}
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Create(&message)

View File

@ -9,6 +9,7 @@ import (
"encoding/json"
"errors"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"log"
"net/http"
@ -97,7 +98,9 @@ func SRMessage(c *gin.Context) {
func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
ctx := context.Background()
chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg"
//随机字符串
strRandom := uuid.NewString()
chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg_" + strRandom
pubsub := worker.RedisClient.Subscribe(ctx, chanel)
defer pubsub.Close()
defer ws.Close()
@ -125,7 +128,7 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
if err1 != nil {
log.Println("Error parsing JSON:", err1)
}
err = doReceiveGenChatMessage(userId, &sessionID, &data)
err = doReceiveGenChatMessage(userId, &sessionID, &data, chanel)
if err != nil {
log.Println("do receive gen chat message error:", err)
return
@ -183,9 +186,8 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
}
func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageReq) error {
func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageReq, chanel string) error {
var err error
chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg"
if *sessionID == 0 {
//创建会话
err, *sessionID = service.CreateSession(userId, data.Msg)
@ -212,15 +214,17 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR
}
modelParam.Url = model.Url
modelParam.System = funcs[0].Info //系统功能
imContext := proto.IMParamContext{UserID: userId, SessionID: int(*sessionID), FunctionID: int(funcs[0].ID), ModelID: int(model.ID), Question: data.Msg, Channel: chanel}
var userMsgID uint
//将消息存入数据库
err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead)
err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead, int(funcs[0].ID))
log.Println("create user message id:", userMsgID)
if err != nil {
return err
} //机器人id为2,消息类型为1,状态为1
//需要后续处理
err, _ = service.WSReceiveMessageService(userId, int(*sessionID), chanel, *data, model, modelParam)
//err, _ = service.WSReceiveMessageService(userId, int(*sessionID), chanel, *data, model, modelParam)
err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext)
return err
}

View File

@ -48,3 +48,13 @@ type SessionResponse struct {
Context []int `gorm:"column:context;type:json"` //会话上下文
MsgCount int `gorm:"column:msg_count"` //消息数量
}
type IMParamContext struct {
UserID int `json:"user_id"` //用户id
SessionID int `json:"session_id"` //会话id
FunctionID int `json:"function_id"` //功能id
ModelID int `json:"model_id"` //模型id
ModelType string `json:"model_type"` //模型类型
Question string `json:"question"` //问题
Channel string `json:"channel"` //消息队列
}

View File

@ -70,7 +70,66 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us
}
}
//将消息存入数据库
err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer)
err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, 2)
if err != nil {
log.Println("spark create message error:", err)
}
}
func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
doubaoClient := arkruntime.NewClientWithApiKey(
modelParam.APIKey,
arkruntime.WithBaseUrl(modelParam.Url),
)
ctx := context.Background()
messages := make([]*model.ChatCompletionMessage, 0)
err := GetDouBaoSessionHistoryMsg(imCtx.SessionID, modelParam.System, &messages, imCtx.Question)
if err != nil {
log.Println("get doubao session history message error:", err)
}
//messagesStr, _ := json.Marshal(messages)
//log.Println("doubao messages:", string(messagesStr))
req := model.CreateChatCompletionRequest{
Model: modelParam.Model,
Messages: messages,
}
stream, err := doubaoClient.CreateChatCompletionStream(ctx, req)
if err != nil {
fmt.Printf("stream chat error: %v\n", err)
return
}
defer func(stream *utils.ChatCompletionStreamReader) {
err2 := stream.Close()
if err2 != nil {
log.Println("close stream error:", err2)
}
}(stream)
answer := ""
for {
recv, err3 := stream.Recv()
if err3 == io.EOF {
log.Println("doubao stream end:", err3)
return
}
if err3 != nil {
fmt.Printf("doubao Stream chat error: %v\n", err)
return
}
go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, recv, modelParam.Model)
doubaoToGeneralMassageAndSendMsgQueue(&recv, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
choices := recv.Choices[0]
if choices.FinishReason == proto.FinishReasonStop {
answer += choices.Delta.Content
break
} else {
answer += choices.Delta.Content
}
}
//将消息存入数据库
err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID)
if err != nil {
log.Println("spark create message error:", err)
}

View File

@ -65,9 +65,9 @@ func UpdateSessionByID(id int, userId int, name string, context []int) error {
}
// 创建消息服务
func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int) (error, uint) {
func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int, functionId int) (error, uint) {
//sessionID, fromID, toID, msgType, status int, msg string
err, id := dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg)
err, id := dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg, functionId)
return err, id
}
@ -88,7 +88,7 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin
if session.UserID != fromID {
return errors.New("session not belong to user"), 0
}
err, mid := CreateMessage(msgType, sessionID, fromID, toID, msg, 0)
err, mid := CreateMessage(msgType, sessionID, fromID, toID, msg, 0, 0)
return err, mid
}
@ -111,6 +111,22 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS
return resErr, resID
}
func WSReceiveMessageServiceV2(modelParam proto.ModelParam, ctx *proto.IMParamContext) (error, uint) {
var resErr error
var resID uint
//处理消息
if ctx.ModelType == proto.ModelTypeOllama {
resErr, resID = ReceiveOllamaSessionV2(modelParam, ctx)
} else if ctx.ModelType == proto.ModelTypeSpark {
SparkV2(modelParam, ctx)
} else if ctx.ModelType == proto.ModelTypeDouBao {
DouBaoV2(modelParam, ctx)
} else {
return errors.New("model not exist"), 0
}
return resErr, resID
}
func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, m *dao.Model) (error, uint) {
var resErr error
var resID uint
@ -147,7 +163,7 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes
SessionContextMap[sessionID] = actx
log.Println("actx has save:", len(SessionContextMap[sessionID]))
//创建消息
err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1
err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1, 0) //机器人id为2,消息类型为2,状态为1
if err2 != nil {
log.Println("create robot message error:", err2)
}
@ -170,6 +186,65 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes
return resErr, resID
}
func ReceiveOllamaSessionV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) (error, uint) {
var resErr error
var resID uint
var err error
if client == nil {
client, err = api.ClientFromEnvironment()
}
if err != nil {
log.Println("get ollama client error:", err)
}
prompt := imCtx.Question
actx, ok := SessionContextMap[int(imCtx.SessionID)]
if ok {
log.Println("actx has get will use:", len(actx))
}
req := &api.GenerateRequest{
Model: modelParam.Model,
Context: actx,
Prompt: prompt,
}
ctx := context.Background()
robotMsg := ""
var aiMsg proto.AIQueueMessage
var data proto.WSMessage
respFunc := func(resp api.GenerateResponse) error {
aiMsg.Type = proto.ModelTypeOllama
aiMsg.Msg = resp
robotMsg += resp.Response
if resp.Done { //该消息完成
actx = resp.Context
if SessionContextMap == nil {
SessionContextMap = make(map[int][]int)
}
SessionContextMap[imCtx.SessionID] = actx
log.Println("actx has save:", len(SessionContextMap[imCtx.SessionID]))
//创建消息
err2, msgID := CreateMessage(proto.UserToModelMsgType, imCtx.SessionID, imCtx.ModelID, imCtx.UserID, robotMsg, 1, imCtx.FunctionID) //机器人id为2,消息类型为2,状态为1
if err2 != nil {
log.Println("create robot message error:", err2)
}
aiMsg.ID = msgID
resID = msgID
}
//发送消息
data.Msg = aiMsg
data.SessionID = int(imCtx.SessionID)
data.ToID = imCtx.UserID
aiMsgStr, _ := json.Marshal(data)
worker.Publish(imCtx.Channel, string(aiMsgStr), time.Second*60)
go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, resp, modelParam.Model)
return nil
}
err = client.Generate(ctx, req, respFunc)
if err != nil {
log.Fatal(err)
}
return resErr, resID
}
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, model *dao.Model) (error, uint) {
var resErr error
var resID uint

View File

@ -170,7 +170,79 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use
conn.Close()
//将消息存入数据库
err2, msgID := dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer)
err2, msgID := dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, 2)
if err2 != nil {
log.Println("spark create message error:", err2)
} else {
log.Println("spark create message success,message id:", msgID)
}
}
func SparkV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) {
log.Println("spark model param:", modelParam)
log.Println("spark question:", imCtx.Question)
//建立websocket连接
dialer := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
}
//建立连接
conn, resp, err := dialer.Dial(assembleAuthUrl(modelParam.Url, modelParam.APIKey, modelParam.APISecret), nil)
if err != nil {
panic(readResp(resp) + err.Error())
return
} else if resp.StatusCode != 101 {
panic(readResp(resp) + err.Error())
}
//发送消息
go func() {
data := genSparkParams(imCtx.Question, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System)
//将数据转换为json
//dataByte, err := json.Marshal(data)
if err != nil {
fmt.Println("Error parsing JSON:", err)
return
}
err2 := conn.WriteJSON(data)
if err != nil {
fmt.Println("write message error:", err2)
return
}
}()
answer := ""
//接收消息
for {
_, msg, err2 := conn.ReadMessage()
if err2 != nil {
log.Println("read message error:", err2)
break
}
var data SparkResponse
err = json.Unmarshal(msg, &data)
if err != nil {
log.Println("Error parsing JSON:", err)
return
}
if data.Header.Code != 0 {
log.Println("Error response:", data.Header.Message, string(msg))
break
}
if data.Header.Status == 2 {
//最后一个结果
answer += data.Payload.Choices.Text[0].Content
} else {
answer += data.Payload.Choices.Text[0].Content
}
go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, data, modelParam.Model)
SparkToGeneralMassageAndSendMsgQueue(&data, imCtx.Channel, imCtx.SessionID, imCtx.UserID)
}
conn.Close()
//将消息存入数据库
err2, msgID := dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID)
if err2 != nil {
log.Println("spark create message error:", err2)
} else {