重构ws消息获取,添加根据功能查找对应模型

This commit is contained in:
junleea 2025-03-24 14:47:53 +08:00
parent e3a14eec54
commit 269bdea1c7
7 changed files with 88 additions and 39 deletions

View File

@ -135,3 +135,13 @@ func DeleteFunctionModelByID(id int) error {
} }
return res.Error return res.Error
} }
func FindFunctionModelByFunction(function string) []FunctionModel {
var models []FunctionModel
if proto.Config.SERVER_SQL_LOG {
DB.Debug().Where("function = ?", function).Find(&models)
} else {
DB.Where("function = ?", function).Find(&models)
}
return models
}

View File

@ -1,11 +1,13 @@
package handler package handler
import ( import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto" "StuAcaWorksAI/proto"
"StuAcaWorksAI/service" "StuAcaWorksAI/service"
"StuAcaWorksAI/worker" "StuAcaWorksAI/worker"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"log" "log"
@ -105,7 +107,6 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
//defer ticker.Stop() //defer ticker.Stop()
var sessionID uint var sessionID uint
var userMsgID uint
//接收客户端消息,发送消息 //接收客户端消息,发送消息
done := make(chan struct{}) done := make(chan struct{})
@ -123,31 +124,9 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
if err1 != nil { if err1 != nil {
log.Println("Error parsing JSON:", err1) log.Println("Error parsing JSON:", err1)
} }
if sessionID == 0 { err = doReceiveGenChatMessage(userId, sessionID, &data)
//创建会话
//会话名字为前10个字符如果
//var sessionName string
//if len(data.Msg) > 10 {
// sessionName = data.Msg[:10]
//} else {
// sessionName = data.Msg
//}
err, sessionID = service.CreateSession(userId, data.Msg)
log.Println("create session id:", sessionID)
if err != nil {
log.Println("Create session error:", err)
return
}
}
//将消息存入数据库
err, userMsgID = service.CreateMessage(1, int(sessionID), userId, 2, data.Msg, 1)
log.Println("create user message id:", userMsgID)
if err != nil {
return
} //机器人id为2,消息类型为1,状态为1
//需要后续处理
err, _ = service.WSReceiveMessageService(userId, int(sessionID), chanel, data)
if err != nil { if err != nil {
log.Println("do receive gen chat message error:", err)
return return
} }
} }
@ -202,3 +181,41 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
} }
} }
func doReceiveGenChatMessage(userId int, sessionID uint, data *proto.WSMessageReq) error {
var err error
chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg"
if sessionID == 0 {
//创建会话
err, sessionID = service.CreateSession(userId, data.Msg)
log.Println("create session id:", sessionID)
if err != nil {
log.Println("Create session error:", err)
}
}
var model dao.Model
//查看请求功能类型
if data.Function == "gen-ai-chat" {
model = service.FindFuncModelByFunction(data.Function, userId)[0]
} else {
return errors.New("function not exist")
}
//解析模型参数
var modelParam proto.ModelParam
err = json.Unmarshal([]byte(model.Parameter), &modelParam)
if err != nil {
return err
}
modelParam.Url = model.Url
var userMsgID uint
//将消息存入数据库
err, userMsgID = service.CreateMessage(proto.UserAndModelMsgType, int(sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead)
log.Println("create user message id:", userMsgID)
if err != nil {
return err
} //机器人id为2,消息类型为1,状态为1
//需要后续处理
err, _ = service.WSReceiveMessageService(userId, int(sessionID), chanel, *data, model, modelParam)
return err
}

View File

@ -15,10 +15,6 @@ type WSMessage struct {
ToID int `json:"to_id"` //接收者id ToID int `json:"to_id"` //接收者id
} }
type ModelParam struct {
Model string `json:"model"` //模型名称
}
type WSMessageReq struct { type WSMessageReq struct {
Type string `json:"type"` //接收及发送消息类型 Type string `json:"type"` //接收及发送消息类型
Function string `json:"function"` //功能名称 Function string `json:"function"` //功能名称
@ -26,3 +22,11 @@ type WSMessageReq struct {
SessionID int `json:"session_id"` //应用层会话id SessionID int `json:"session_id"` //应用层会话id
ToID int `json:"to_id"` //接收者id ToID int `json:"to_id"` //接收者id
} }
type ModelParam struct {
Model string `json:"model"` //模型名称
Url string `json:"url"` //模型地址
APPID string `json:"appid"` //应用id
APISecret string `json:"apiSecret"` //应用密钥
APIKey string `json:"apiKey"` //应用key
}

View File

@ -89,3 +89,9 @@ const (
FuncModelDeleteFailed = 132 // 删除功能模型失败 FuncModelDeleteFailed = 132 // 删除功能模型失败
FuncModelSearchFailed = 133 // 获取功能模型失败 FuncModelSearchFailed = 133 // 获取功能模型失败
) )
const (
// 代码中使用常量定义
UserAndModelMsgType = 2 // 用户与模型消息类型
MsgHasRead = 1 // 消息已读
)

View File

@ -79,15 +79,16 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin
var client *api.Client var client *api.Client
var SessionContextMap map[int][]int var SessionContextMap map[int][]int
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) { func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq, model dao.Model, modelParam proto.ModelParam) (error, uint) {
var resErr error var resErr error
var resID uint var resID uint
//处理消息 //处理消息
if msg.Type == "ollama" { if msg.Type == "ollama" {
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg) resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg)
} else if msg.Type == "spark" { } else if msg.Type == "spark" {
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg) resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam)
} else {
return errors.New("model not exist"), 0
} }
return resErr, resID return resErr, resID
} }
@ -150,10 +151,9 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes
return resErr, resID return resErr, resID
} }
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) { func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam) (error, uint) {
var resErr error var resErr error
var resID uint var resID uint
//查找模型 worker.Spark(modelParam, msg.Msg, channel, sessionID, userID)
return resErr, resID return resErr, resID
} }

View File

@ -106,3 +106,15 @@ func UpdateFuncModelByID(id int, userID, modelID uint, name, info, function stri
} }
return dao.UpdateFunctionModelByID(id, userID, modelID, name, info, function) return dao.UpdateFunctionModelByID(id, userID, modelID, name, info, function)
} }
// 根据功能查找对应功能
func FindFuncModelByFunction(function string, userID int) []dao.Model {
//先查找对应功能
funcModels := dao.FindFunctionModelByFunction(function)
if len(funcModels) == 0 {
return nil
}
modelID := funcModels[0].ModelID
//再查找对应模型
return dao.FindModelByID(int(modelID), userID)
}

View File

@ -105,13 +105,13 @@ type SparkResponse struct {
/***************** Spark响应结构 ***********************/ /***************** Spark响应结构 ***********************/
func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel string, SessionID, UserID int) { func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID int) {
//建立websocket连接 //建立websocket连接
dialer := websocket.Dialer{ dialer := websocket.Dialer{
HandshakeTimeout: 5 * time.Second, HandshakeTimeout: 5 * time.Second,
} }
//建立连接 //建立连接
conn, resp, err := dialer.Dial(assembleAuthUrl(hostUrl, apiKey, apiSecret), nil) conn, resp, err := dialer.Dial(assembleAuthUrl(modelParam.Url, modelParam.APIKey, modelParam.APISecret), nil)
if err != nil { if err != nil {
panic(readResp(resp) + err.Error()) panic(readResp(resp) + err.Error())
return return
@ -121,7 +121,7 @@ func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel st
//发送消息 //发送消息
go func() { go func() {
data := genSparkParams(appid, question) data := genSparkParams(modelParam.APPID, question)
//将数据转换为json //将数据转换为json
//dataByte, err := json.Marshal(data) //dataByte, err := json.Marshal(data)
if err != nil { if err != nil {
@ -172,7 +172,7 @@ func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel st
} }
} }
func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponsem, channel string, SessionID int, UserID int) { func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, SessionID int, UserID int) {
var aiMsg proto.AIQueueMessage var aiMsg proto.AIQueueMessage
var wsMsg proto.WSMessage var wsMsg proto.WSMessage
var apiMsg api.GenerateResponse var apiMsg api.GenerateResponse