重构ws消息获取,添加根据功能查找对应模型
This commit is contained in:
parent
e3a14eec54
commit
269bdea1c7
10
dao/model.go
10
dao/model.go
|
|
@ -135,3 +135,13 @@ func DeleteFunctionModelByID(id int) error {
|
|||
}
|
||||
return res.Error
|
||||
}
|
||||
|
||||
func FindFunctionModelByFunction(function string) []FunctionModel {
|
||||
var models []FunctionModel
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
DB.Debug().Where("function = ?", function).Find(&models)
|
||||
} else {
|
||||
DB.Where("function = ?", function).Find(&models)
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"StuAcaWorksAI/dao"
|
||||
"StuAcaWorksAI/proto"
|
||||
"StuAcaWorksAI/service"
|
||||
"StuAcaWorksAI/worker"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"log"
|
||||
|
|
@ -105,7 +107,6 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
|||
//defer ticker.Stop()
|
||||
|
||||
var sessionID uint
|
||||
var userMsgID uint
|
||||
|
||||
//接收客户端消息,发送消息
|
||||
done := make(chan struct{})
|
||||
|
|
@ -123,31 +124,9 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
|||
if err1 != nil {
|
||||
log.Println("Error parsing JSON:", err1)
|
||||
}
|
||||
if sessionID == 0 {
|
||||
//创建会话
|
||||
//会话名字为前10个字符,如果
|
||||
//var sessionName string
|
||||
//if len(data.Msg) > 10 {
|
||||
// sessionName = data.Msg[:10]
|
||||
//} else {
|
||||
// sessionName = data.Msg
|
||||
//}
|
||||
err, sessionID = service.CreateSession(userId, data.Msg)
|
||||
log.Println("create session id:", sessionID)
|
||||
if err != nil {
|
||||
log.Println("Create session error:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
//将消息存入数据库
|
||||
err, userMsgID = service.CreateMessage(1, int(sessionID), userId, 2, data.Msg, 1)
|
||||
log.Println("create user message id:", userMsgID)
|
||||
if err != nil {
|
||||
return
|
||||
} //机器人id为2,消息类型为1,状态为1
|
||||
//需要后续处理
|
||||
err, _ = service.WSReceiveMessageService(userId, int(sessionID), chanel, data)
|
||||
err = doReceiveGenChatMessage(userId, sessionID, &data)
|
||||
if err != nil {
|
||||
log.Println("do receive gen chat message error:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -202,3 +181,41 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
func doReceiveGenChatMessage(userId int, sessionID uint, data *proto.WSMessageReq) error {
|
||||
var err error
|
||||
chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg"
|
||||
if sessionID == 0 {
|
||||
//创建会话
|
||||
err, sessionID = service.CreateSession(userId, data.Msg)
|
||||
log.Println("create session id:", sessionID)
|
||||
if err != nil {
|
||||
log.Println("Create session error:", err)
|
||||
}
|
||||
}
|
||||
var model dao.Model
|
||||
//查看请求功能类型
|
||||
if data.Function == "gen-ai-chat" {
|
||||
model = service.FindFuncModelByFunction(data.Function, userId)[0]
|
||||
} else {
|
||||
return errors.New("function not exist")
|
||||
}
|
||||
//解析模型参数
|
||||
var modelParam proto.ModelParam
|
||||
err = json.Unmarshal([]byte(model.Parameter), &modelParam)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
modelParam.Url = model.Url
|
||||
|
||||
var userMsgID uint
|
||||
//将消息存入数据库
|
||||
err, userMsgID = service.CreateMessage(proto.UserAndModelMsgType, int(sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead)
|
||||
log.Println("create user message id:", userMsgID)
|
||||
if err != nil {
|
||||
return err
|
||||
} //机器人id为2,消息类型为1,状态为1
|
||||
//需要后续处理
|
||||
err, _ = service.WSReceiveMessageService(userId, int(sessionID), chanel, *data, model, modelParam)
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
12
proto/im.go
12
proto/im.go
|
|
@ -15,10 +15,6 @@ type WSMessage struct {
|
|||
ToID int `json:"to_id"` //接收者id
|
||||
}
|
||||
|
||||
type ModelParam struct {
|
||||
Model string `json:"model"` //模型名称
|
||||
}
|
||||
|
||||
type WSMessageReq struct {
|
||||
Type string `json:"type"` //接收及发送消息类型
|
||||
Function string `json:"function"` //功能名称
|
||||
|
|
@ -26,3 +22,11 @@ type WSMessageReq struct {
|
|||
SessionID int `json:"session_id"` //应用层会话id
|
||||
ToID int `json:"to_id"` //接收者id
|
||||
}
|
||||
|
||||
type ModelParam struct {
|
||||
Model string `json:"model"` //模型名称
|
||||
Url string `json:"url"` //模型地址
|
||||
APPID string `json:"appid"` //应用id
|
||||
APISecret string `json:"apiSecret"` //应用密钥
|
||||
APIKey string `json:"apiKey"` //应用key
|
||||
}
|
||||
|
|
|
|||
|
|
@ -89,3 +89,9 @@ const (
|
|||
FuncModelDeleteFailed = 132 // 删除功能模型失败
|
||||
FuncModelSearchFailed = 133 // 获取功能模型失败
|
||||
)
|
||||
|
||||
const (
|
||||
// 代码中使用常量定义
|
||||
UserAndModelMsgType = 2 // 用户与模型消息类型
|
||||
MsgHasRead = 1 // 消息已读
|
||||
)
|
||||
|
|
|
|||
|
|
@ -79,15 +79,16 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin
|
|||
var client *api.Client
|
||||
var SessionContextMap map[int][]int
|
||||
|
||||
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
|
||||
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq, model dao.Model, modelParam proto.ModelParam) (error, uint) {
|
||||
var resErr error
|
||||
var resID uint
|
||||
//处理消息
|
||||
if msg.Type == "ollama" {
|
||||
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg)
|
||||
} else if msg.Type == "spark" {
|
||||
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg)
|
||||
|
||||
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam)
|
||||
} else {
|
||||
return errors.New("model not exist"), 0
|
||||
}
|
||||
return resErr, resID
|
||||
}
|
||||
|
|
@ -150,10 +151,9 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes
|
|||
return resErr, resID
|
||||
}
|
||||
|
||||
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
|
||||
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam) (error, uint) {
|
||||
var resErr error
|
||||
var resID uint
|
||||
//查找模型
|
||||
|
||||
worker.Spark(modelParam, msg.Msg, channel, sessionID, userID)
|
||||
return resErr, resID
|
||||
}
|
||||
|
|
|
|||
|
|
@ -106,3 +106,15 @@ func UpdateFuncModelByID(id int, userID, modelID uint, name, info, function stri
|
|||
}
|
||||
return dao.UpdateFunctionModelByID(id, userID, modelID, name, info, function)
|
||||
}
|
||||
|
||||
// 根据功能查找对应功能
|
||||
func FindFuncModelByFunction(function string, userID int) []dao.Model {
|
||||
//先查找对应功能
|
||||
funcModels := dao.FindFunctionModelByFunction(function)
|
||||
if len(funcModels) == 0 {
|
||||
return nil
|
||||
}
|
||||
modelID := funcModels[0].ModelID
|
||||
//再查找对应模型
|
||||
return dao.FindModelByID(int(modelID), userID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -105,13 +105,13 @@ type SparkResponse struct {
|
|||
|
||||
/***************** Spark响应结构 ***********************/
|
||||
|
||||
func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel string, SessionID, UserID int) {
|
||||
func Spark(modelParam proto.ModelParam, question, channel string, SessionID, UserID int) {
|
||||
//建立websocket连接
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
}
|
||||
//建立连接
|
||||
conn, resp, err := dialer.Dial(assembleAuthUrl(hostUrl, apiKey, apiSecret), nil)
|
||||
conn, resp, err := dialer.Dial(assembleAuthUrl(modelParam.Url, modelParam.APIKey, modelParam.APISecret), nil)
|
||||
if err != nil {
|
||||
panic(readResp(resp) + err.Error())
|
||||
return
|
||||
|
|
@ -121,7 +121,7 @@ func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel st
|
|||
|
||||
//发送消息
|
||||
go func() {
|
||||
data := genSparkParams(appid, question)
|
||||
data := genSparkParams(modelParam.APPID, question)
|
||||
//将数据转换为json
|
||||
//dataByte, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
|
|
@ -172,7 +172,7 @@ func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel st
|
|||
}
|
||||
}
|
||||
|
||||
func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponsem, channel string, SessionID int, UserID int) {
|
||||
func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponse, channel string, SessionID int, UserID int) {
|
||||
var aiMsg proto.AIQueueMessage
|
||||
var wsMsg proto.WSMessage
|
||||
var apiMsg api.GenerateResponse
|
||||
|
|
|
|||
Loading…
Reference in New Issue