重构ws消息获取,添加根据功能查找对应模型
This commit is contained in:
parent
e3a14eec54
commit
269bdea1c7
10
dao/model.go
10
dao/model.go
|
|
@ -135,3 +135,13 @@ func DeleteFunctionModelByID(id int) error {
|
||||||
}
|
}
|
||||||
return res.Error
|
return res.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func FindFunctionModelByFunction(function string) []FunctionModel {
|
||||||
|
var models []FunctionModel
|
||||||
|
if proto.Config.SERVER_SQL_LOG {
|
||||||
|
DB.Debug().Where("function = ?", function).Find(&models)
|
||||||
|
} else {
|
||||||
|
DB.Where("function = ?", function).Find(&models)
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
package handler
|
package handler
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"StuAcaWorksAI/dao"
|
||||||
"StuAcaWorksAI/proto"
|
"StuAcaWorksAI/proto"
|
||||||
"StuAcaWorksAI/service"
|
"StuAcaWorksAI/service"
|
||||||
"StuAcaWorksAI/worker"
|
"StuAcaWorksAI/worker"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"log"
|
"log"
|
||||||
|
|
@ -105,7 +107,6 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
||||||
//defer ticker.Stop()
|
//defer ticker.Stop()
|
||||||
|
|
||||||
var sessionID uint
|
var sessionID uint
|
||||||
var userMsgID uint
|
|
||||||
|
|
||||||
//接收客户端消息,发送消息
|
//接收客户端消息,发送消息
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|
@ -123,31 +124,9 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
||||||
if err1 != nil {
|
if err1 != nil {
|
||||||
log.Println("Error parsing JSON:", err1)
|
log.Println("Error parsing JSON:", err1)
|
||||||
}
|
}
|
||||||
if sessionID == 0 {
|
err = doReceiveGenChatMessage(userId, sessionID, &data)
|
||||||
//创建会话
|
|
||||||
//会话名字为前10个字符,如果
|
|
||||||
//var sessionName string
|
|
||||||
//if len(data.Msg) > 10 {
|
|
||||||
// sessionName = data.Msg[:10]
|
|
||||||
//} else {
|
|
||||||
// sessionName = data.Msg
|
|
||||||
//}
|
|
||||||
err, sessionID = service.CreateSession(userId, data.Msg)
|
|
||||||
log.Println("create session id:", sessionID)
|
|
||||||
if err != nil {
|
|
||||||
log.Println("Create session error:", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
//将消息存入数据库
|
|
||||||
err, userMsgID = service.CreateMessage(1, int(sessionID), userId, 2, data.Msg, 1)
|
|
||||||
log.Println("create user message id:", userMsgID)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
} //机器人id为2,消息类型为1,状态为1
|
|
||||||
//需要后续处理
|
|
||||||
err, _ = service.WSReceiveMessageService(userId, int(sessionID), chanel, data)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Println("do receive gen chat message error:", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -202,3 +181,41 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func doReceiveGenChatMessage(userId int, sessionID uint, data *proto.WSMessageReq) error {
|
||||||
|
var err error
|
||||||
|
chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg"
|
||||||
|
if sessionID == 0 {
|
||||||
|
//创建会话
|
||||||
|
err, sessionID = service.CreateSession(userId, data.Msg)
|
||||||
|
log.Println("create session id:", sessionID)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Create session error:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var model dao.Model
|
||||||
|
//查看请求功能类型
|
||||||
|
if data.Function == "gen-ai-chat" {
|
||||||
|
model = service.FindFuncModelByFunction(data.Function, userId)[0]
|
||||||
|
} else {
|
||||||
|
return errors.New("function not exist")
|
||||||
|
}
|
||||||
|
//解析模型参数
|
||||||
|
var modelParam proto.ModelParam
|
||||||
|
err = json.Unmarshal([]byte(model.Parameter), &modelParam)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
modelParam.Url = model.Url
|
||||||
|
|
||||||
|
var userMsgID uint
|
||||||
|
//将消息存入数据库
|
||||||
|
err, userMsgID = service.CreateMessage(proto.UserAndModelMsgType, int(sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead)
|
||||||
|
log.Println("create user message id:", userMsgID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} //机器人id为2,消息类型为1,状态为1
|
||||||
|
//需要后续处理
|
||||||
|
err, _ = service.WSReceiveMessageService(userId, int(sessionID), chanel, *data, model, modelParam)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
||||||
12
proto/im.go
12
proto/im.go
|
|
@ -15,10 +15,6 @@ type WSMessage struct {
|
||||||
ToID int `json:"to_id"` //接收者id
|
ToID int `json:"to_id"` //接收者id
|
||||||
}
|
}
|
||||||
|
|
||||||
type ModelParam struct {
|
|
||||||
Model string `json:"model"` //模型名称
|
|
||||||
}
|
|
||||||
|
|
||||||
type WSMessageReq struct {
|
type WSMessageReq struct {
|
||||||
Type string `json:"type"` //接收及发送消息类型
|
Type string `json:"type"` //接收及发送消息类型
|
||||||
Function string `json:"function"` //功能名称
|
Function string `json:"function"` //功能名称
|
||||||
|
|
@ -26,3 +22,11 @@ type WSMessageReq struct {
|
||||||
SessionID int `json:"session_id"` //应用层会话id
|
SessionID int `json:"session_id"` //应用层会话id
|
||||||
ToID int `json:"to_id"` //接收者id
|
ToID int `json:"to_id"` //接收者id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ModelParam struct {
|
||||||
|
Model string `json:"model"` //模型名称
|
||||||
|
Url string `json:"url"` //模型地址
|
||||||
|
APPID string `json:"appid"` //应用id
|
||||||
|
APISecret string `json:"apiSecret"` //应用密钥
|
||||||
|
APIKey string `json:"apiKey"` //应用key
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -89,3 +89,9 @@ const (
|
||||||
FuncModelDeleteFailed = 132 // 删除功能模型失败
|
FuncModelDeleteFailed = 132 // 删除功能模型失败
|
||||||
FuncModelSearchFailed = 133 // 获取功能模型失败
|
FuncModelSearchFailed = 133 // 获取功能模型失败
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// 代码中使用常量定义
|
||||||
|
UserAndModelMsgType = 2 // 用户与模型消息类型
|
||||||
|
MsgHasRead = 1 // 消息已读
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -79,15 +79,16 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin
|
||||||
var client *api.Client
|
var client *api.Client
|
||||||
var SessionContextMap map[int][]int
|
var SessionContextMap map[int][]int
|
||||||
|
|
||||||
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
|
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq, model dao.Model, modelParam proto.ModelParam) (error, uint) {
|
||||||
var resErr error
|
var resErr error
|
||||||
var resID uint
|
var resID uint
|
||||||
//处理消息
|
//处理消息
|
||||||
if msg.Type == "ollama" {
|
if msg.Type == "ollama" {
|
||||||
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg)
|
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg)
|
||||||
} else if msg.Type == "spark" {
|
} else if msg.Type == "spark" {
|
||||||
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg)
|
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam)
|
||||||
|
} else {
|
||||||
|
return errors.New("model not exist"), 0
|
||||||
}
|
}
|
||||||
return resErr, resID
|
return resErr, resID
|
||||||
}
|
}
|
||||||
|
|
@ -150,10 +151,9 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes
|
||||||
return resErr, resID
|
return resErr, resID
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
|
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam) (error, uint) {
|
||||||
var resErr error
|
var resErr error
|
||||||
var resID uint
|
var resID uint
|
||||||
//查找模型
|
worker.Spark(modelParam, msg.Msg, channel, sessionID, userID)
|
||||||
|
|
||||||
return resErr, resID
|
return resErr, resID
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -106,3 +106,15 @@ func UpdateFuncModelByID(id int, userID, modelID uint, name, info, function stri
|
||||||
}
|
}
|
||||||
return dao.UpdateFunctionModelByID(id, userID, modelID, name, info, function)
|
return dao.UpdateFunctionModelByID(id, userID, modelID, name, info, function)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 根据功能查找对应功能
|
||||||
|
func FindFuncModelByFunction(function string, userID int) []dao.Model {
|
||||||
|
//先查找对应功能
|
||||||
|
funcModels := dao.FindFunctionModelByFunction(function)
|
||||||
|
if len(funcModels) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
modelID := funcModels[0].ModelID
|
||||||
|
//再查找对应模型
|
||||||
|
return dao.FindModelByID(int(modelID), userID)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -105,13 +105,13 @@ type SparkResponse struct {
|
||||||
|
|
||||||
/***************** Spark响应结构 ***********************/
|
/***************** Spark响应结构 ***********************/
|
||||||
|
|
||||||
func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel string, SessionID, UserID int) {
|
func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID int) {
|
||||||
//建立websocket连接
|
//建立websocket连接
|
||||||
dialer := websocket.Dialer{
|
dialer := websocket.Dialer{
|
||||||
HandshakeTimeout: 5 * time.Second,
|
HandshakeTimeout: 5 * time.Second,
|
||||||
}
|
}
|
||||||
//建立连接
|
//建立连接
|
||||||
conn, resp, err := dialer.Dial(assembleAuthUrl(hostUrl, apiKey, apiSecret), nil)
|
conn, resp, err := dialer.Dial(assembleAuthUrl(modelParam.Url, modelParam.APIKey, modelParam.APISecret), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(readResp(resp) + err.Error())
|
panic(readResp(resp) + err.Error())
|
||||||
return
|
return
|
||||||
|
|
@ -121,7 +121,7 @@ func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel st
|
||||||
|
|
||||||
//发送消息
|
//发送消息
|
||||||
go func() {
|
go func() {
|
||||||
data := genSparkParams(appid, question)
|
data := genSparkParams(modelParam.APPID, question)
|
||||||
//将数据转换为json
|
//将数据转换为json
|
||||||
//dataByte, err := json.Marshal(data)
|
//dataByte, err := json.Marshal(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -172,7 +172,7 @@ func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel st
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponsem, channel string, SessionID int, UserID int) {
|
func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, SessionID int, UserID int) {
|
||||||
var aiMsg proto.AIQueueMessage
|
var aiMsg proto.AIQueueMessage
|
||||||
var wsMsg proto.WSMessage
|
var wsMsg proto.WSMessage
|
||||||
var apiMsg api.GenerateResponse
|
var apiMsg api.GenerateResponse
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue