saw-go/handler/im.go

500 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handler
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
"StuAcaWorksAI/service"
"StuAcaWorksAI/service/generalModel"
"StuAcaWorksAI/worker"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"log"
"net/http"
"strconv"
"sync"
"time"
)
// SMessage is the request body bound by the plain-HTTP message endpoints
// (SendMessage and GetMessage). Each field binds from either a JSON key or a
// form field of the same snake_case name.
type SMessage struct {
	// Type is forwarded as the message type to the message service.
	Type int `json:"type" form:"type"`

	// Msg is the message content.
	Msg string `json:"msg" form:"msg"`

	// ToID is the target ID of the message (presumably the recipient).
	ToID int `json:"to_id" form:"to_id"`

	// SessionID selects the chat session the message is stored in or read from.
	SessionID int `json:"session_id" form:"session_id"`
}
// upgrader turns the /im WebSocket requests into WebSocket connections, using
// 1 KiB read and write buffers.
//
// NOTE(review): CheckOrigin accepts every Origin header. If authentication can
// be carried by cookies, any web page could open these sockets on a visitor's
// behalf (cross-site WebSocket hijacking). Consider restricting it to the
// frontend's origin(s).
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	// Allow connections from any origin.
	CheckOrigin: func(*http.Request) bool { return true },
}
// clients records open AI-chat WebSocket connections; the handlers in this
// file add and remove entries. Guard every read and write with clientsMux.
var (
	clients    = map[*websocket.Conn]bool{}
	clientsMux sync.Mutex
)
// SetUpIMGroup registers the instant-messaging routes under the /im prefix:
// the HTTP message endpoints, the AI-chat and knowledge-base WebSockets, and
// the two service-to-service endpoints for foreign AI providers.
func SetUpIMGroup(router *gin.Engine) {
	im := router.Group("/im")

	// Registered in this order; each entry is (HTTP method, path, handler).
	routes := []struct {
		method  string
		path    string
		handler gin.HandlerFunc
	}{
		{http.MethodPost, "/send_message", SendMessage},
		{http.MethodGet, "/ai_chat_ws", SRMessage},
		{http.MethodPost, "/get_message", GetMessage},
		{http.MethodGet, "/kbase_text_ws", KBaseTextWS},
		// Service-to-service calls for foreign AI processing.
		{http.MethodPost, "/send_message_foreign_ai", SendMessageForeignAI},
		{http.MethodPost, "/receive_message_foreign_ai", ReceiveMessageForeignAI},
	}
	for _, r := range routes {
		im.Handle(r.method, r.path, r.handler)
	}
}
// GetMessage handles POST /im/get_message. It returns the messages of
// req.SessionID for the calling user via service.FindMessageBySessionID.
//
// The caller's ID is read from the "id" context key as a float64. That key is
// set by earlier middleware, presumably from decoded JSON/JWT claims; verify.
// Every outcome is answered with HTTP 200 and a proto code in the JSON body:
//   - proto.PermissionDenied: no usable user ID in the context;
//   - proto.ParameterError: the request could not be bound;
//   - proto.MsgSendFailed: the lookup failed;
//   - proto.SuccessCode: "data" holds the messages.
func GetMessage(c *gin.Context) {
	rawID, _ := c.Get("id")
	idVal, ok := rawID.(float64)
	if !ok {
		// A missing or mistyped "id" used to panic the handler.
		c.JSON(http.StatusOK, gin.H{"error": "user id missing from request context", "code": proto.PermissionDenied, "message": "failed"})
		return
	}
	userID := int(idVal)

	var req SMessage
	if err := c.ShouldBind(&req); err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
		return
	}

	err, messages := service.FindMessageBySessionID(req.SessionID, userID)
	if err != nil {
		// NOTE(review): a read failure reuses MsgSendFailed; confirm whether a
		// dedicated code exists.
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.MsgSendFailed, "message": "failed"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": messages})
}
// SendMessage handles POST /im/send_message. It stores one message through
// service.CreateGeneralMessageService and returns the new message ID.
//
// The caller's ID is read from the "id" context key as a float64. That key is
// set by earlier middleware, presumably from decoded JSON/JWT claims; verify.
// Every outcome is answered with HTTP 200 and a proto code in the JSON body:
//   - proto.PermissionDenied: no usable user ID in the context;
//   - proto.ParameterError: the request could not be bound;
//   - proto.MsgSendFailed: the message could not be created;
//   - proto.SuccessCode: "data" holds the new message ID.
func SendMessage(c *gin.Context) {
	rawID, _ := c.Get("id")
	idVal, ok := rawID.(float64)
	if !ok {
		// A missing or mistyped "id" used to panic the handler.
		c.JSON(http.StatusOK, gin.H{"error": "user id missing from request context", "code": proto.PermissionDenied, "message": "failed"})
		return
	}
	userID := int(idVal)

	var req SMessage
	if err := c.ShouldBind(&req); err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.ParameterError, "message": "failed"})
		return
	}

	err, mid := service.CreateGeneralMessageService(userID, req.ToID, req.Type, req.SessionID, req.Msg)
	if err != nil {
		c.JSON(http.StatusOK, gin.H{"error": err.Error(), "code": proto.MsgSendFailed, "message": "failed"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success", "data": mid})
}
// SRMessage handles GET /im/ai_chat_ws, the long-lived AI-chat WebSocket.
//
// It reads the caller's ID from the "id" context key (a float64). If the key is
// missing, it answers with a JSON error (HTTP 200) and does not upgrade.
// Otherwise it upgrades the request and registers the connection in clients.
// It then serves the socket with subscribeAndHandleIMMessages and unregisters
// the connection when that returns.
func SRMessage(c *gin.Context) {
	rawID, _ := c.Get("id")
	idVal, ok := rawID.(float64)
	if !ok {
		c.JSON(http.StatusOK, gin.H{"error": "user id missing from request context", "code": proto.PermissionDenied, "message": "failed"})
		return
	}
	userID := int(idVal)

	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Upgrade has already replied to the client; ws is nil here, so it
		// must not be registered.
		log.Println("升级为WebSocket时发生错误:", err)
		return
	}

	// clients is also modified by the writer goroutine, so every access is locked.
	clientsMux.Lock()
	clients[ws] = true
	clientsMux.Unlock()
	defer func() {
		clientsMux.Lock()
		delete(clients, ws)
		clientsMux.Unlock()
	}()

	subscribeAndHandleIMMessages(ws, userID)
}
// subscribeAndHandleIMMessages serves one AI-chat WebSocket connection.
//
// It subscribes to a Redis channel unique to this connection,
// "user_<userId>_ai_chat_msg_<uuid>", and runs two goroutines:
//   - a reader that decodes each client frame as a proto.WSMessageReq and passes
//     it, with the channel name, to doReceiveGenChatMessage;
//   - a writer that forwards every non-empty payload published on the channel
//     to the client as a text frame.
//
// It blocks until the reader stops, either on a read error or on a processing
// error that is reported to the client. Before returning it looks up the entry
// of service.SessionContextMap for the session ID last recorded by
// doReceiveGenChatMessage. If an entry exists, it saves it with
// service.UpdateSessionByID and removes it from the map. The WebSocket and the
// Redis subscription are closed on return.
func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
	ctx := context.Background()
	// The random suffix gives every connection its own channel, so output
	// published for this socket reaches only this socket.
	chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg_" + uuid.NewString()
	pubsub := worker.RedisClient.Subscribe(ctx, chanel)
	defer pubsub.Close()
	defer ws.Close()
	ch := pubsub.Channel()

	// sessionID is written only by the reader goroutine (through
	// doReceiveGenChatMessage). It is read below only after <-done, and the
	// reader's close(done) happens after its last write.
	var sessionID uint

	done := make(chan struct{})
	go func() {
		defer close(done)
		for {
			_, message, err := ws.ReadMessage()
			if err != nil {
				log.Println("Read error:", err)
				return
			}
			log.Println("rec message:", string(message))

			var data proto.WSMessageReq
			if err := json.Unmarshal(message, &data); err != nil {
				// Best effort: json.Unmarshal may still have filled some fields
				// (e.g. on a type mismatch), so the frame is processed anyway.
				log.Println("Error parsing JSON:", err)
			}
			if err := doReceiveGenChatMessage(userId, &sessionID, &data, chanel); err != nil {
				log.Println("do receive gen chat message error:", err)
				// Report the failure on the channel, then stop reading. Returning
				// ends the connection.
				errMsg := proto.WSMessage{Code: proto.OperationFailed, ErrorMsg: "服务器处理错误", SessionID: int(sessionID)}
				errMsgStr, _ := json.Marshal(errMsg)
				worker.Publish(chanel, string(errMsgStr), time.Second*60)
				// NOTE(review): the deferred ws.Close() may run before the writer
				// forwards this message, so the client can miss it.
				return
			}
		}
	}()

	go func() {
		// go-redis closes ch when pubsub is closed, which ends this loop on shutdown.
		for m := range ch {
			if m.Payload == "" {
				continue
			}
			if err := ws.WriteMessage(websocket.TextMessage, []byte(m.Payload)); err != nil {
				log.Println("Write error:", err)
				clientsMux.Lock()
				delete(clients, ws)
				clientsMux.Unlock()
				// Closing the socket unblocks the reader's ReadMessage, so the
				// handler cleans up now instead of waiting on a dead peer.
				ws.Close()
				return
			}
		}
	}()

	<-done

	// The client is gone: persist the cached context of the last session.
	// NOTE(review): service.SessionContextMap is accessed here without a lock
	// while other connections may use it concurrently; confirm it is synchronized.
	sid := int(sessionID)
	sessionCtx, ok := service.SessionContextMap[sid]
	if !ok {
		log.Println("session context not found")
		return
	}
	if err := service.UpdateSessionByID(sid, userId, "", sessionCtx); err != nil {
		log.Println("update session context error:", err)
	}
	delete(service.SessionContextMap, sid)
}
// doReceiveGenChatMessage processes one chat request received on the AI-chat
// WebSocket. Output for the client is published on chanel.
//
// Steps:
//  1. Look up the function and its candidate models for data.Function.
//  2. Set *sessionID to data.SessionID. If that is 0, create a session named
//     after the message (for image/file messages, the Text of the JSON payload).
//     For "gen-chat" sessions whose name is longer than 30 bytes, also start a
//     background AI rename.
//  3. Pick the model whose ID equals data.ModelID (else the first candidate),
//     build its parameters, and store the user's message.
//  4. For knowledge-base requests, query the session's knowledge base; a
//     KBaseType of "query" stops here.
//  5. Hand everything to service.WSReceiveMessageServiceV2.
//
// It returns the first error that stops processing.
func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageReq, chanel string) error {
	models, funcs, mferr := service.FindFuncModelListByFunctionV2(data.Function)
	if mferr != nil {
		return mferr
	}
	if len(funcs) == 0 {
		return errors.New("function not exist")
	}
	if len(models) == 0 {
		return errors.New("model not exist")
	}

	var err error

	// The session name is the message text. Image/file messages carry a JSON
	// payload whose Text field is used instead (best effort on parse errors).
	sessionName := data.Msg
	if data.IsImage || data.IsFile {
		var fileMsg proto.UserFileMessage
		if err = json.Unmarshal([]byte(data.Msg), &fileMsg); err != nil {
			log.Println("Error parsing JSON:", err)
		}
		sessionName = fileMsg.Text
	}

	// A SessionID of 0 from the client starts a new session.
	*sessionID = uint(data.SessionID)
	if *sessionID == 0 {
		if sessionName == "" {
			sessionName = "新会话"
		}
		sessionType := proto.SessionTypeUserWithModelGeneration // general chat by default
		if data.Function == "gen-prompt" {
			sessionType = proto.SessionTypeUserPrompt // prompt session
		}
		err, *sessionID = service.CreateSession(userId, sessionType, sessionName)
		if err != nil {
			// Continuing would store the message under a session that was never created.
			return fmt.Errorf("create session: %w", err)
		}
		log.Println("create session id:", *sessionID)

		// Long names are summarized by the AI, but only for general chats.
		// NOTE(review): len counts bytes, so 11+ CJK characters already exceed 30.
		if len(sessionName) > 30 && data.Function == "gen-chat" {
			// Copy the ID now: the reader rewrites *sessionID on the next message
			// while this goroutine may still be running.
			newSessionID := int(*sessionID)
			go func() {
				if sessionErr := service.UpdateSessionNameByAIAndID(newSessionID, userId, sessionName); sessionErr != nil {
					log.Println("update ai session name error:", sessionErr)
				}
			}()
		}
	}

	// Use the requested model; an unknown ID falls back to the first candidate.
	var model dao.Model
	for _, v := range models {
		if v.ID == data.ModelID {
			model = v
			break
		}
	}
	if model.ID == 0 {
		model = models[0]
	}

	var modelParam proto.ModelParam
	if err = json.Unmarshal([]byte(model.Parameter), &modelParam); err != nil {
		return err
	}
	modelParam.Url = model.Url
	modelParam.System = funcs[0].Info // system content taken from the function definition
	checkAndSetModelTemperatureParam(&modelParam, data, &model)

	imContext := proto.IMParamContext{UserID: userId, SessionID: int(*sessionID), FunctionID: int(funcs[0].ID), ModelID: int(model.ID), Question: data.Msg, Channel: chanel, ModelType: model.Type}

	status := proto.MsgHasRead
	if data.IsImage || data.IsFile {
		// NOTE(review): a message-type constant is stored as the status here; confirm intent.
		status = proto.UserToModelImageMsgType
	}

	// Store the user's message.
	var userMsgID uint
	err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, status, int(funcs[0].ID))
	log.Println("create user message id:", userMsgID)
	if err != nil {
		return err
	}
	imContext.QuestionMessageID = int(userMsgID)

	// Knowledge-base messages first fetch matching knowledge-base content, which
	// the model then uses to generate the final answer.
	if data.IsKBase {
		kbaseMsg, err2 := service.FindKnowledgeBaseBySessionID(*sessionID, uint(userId))
		if err2 != nil {
			log.Println("find knowledge base by session id error:", err2)
			service.WSReturnErrorMessage(imContext.Channel, "find knowledge base error:"+err2.Error(), proto.WSKBaseSessionError)
			return err2
		}
		data.KBaseID = kbaseMsg.ID
		err2 = service.SelectKBaseContentByQuery(&imContext, data)
		if err2 != nil {
			log.Println("select kbase content error:", err2)
			service.WSReturnErrorMessage(imContext.Channel, "select kbase error:"+err2.Error(), proto.WSKBaseServerError)
			return err2
		}
		if data.KBaseType == "query" {
			// A pure query needs no model generation.
			return nil
		}
	}

	err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext)
	return err
}
// KBaseTextWS handles GET /im/kbase_text_ws?kbase_server_id=... . The peer on
// this socket sends proto.KnowledgeBaseServerResponse frames, presumably from
// a knowledge-base server.
//
// It requires a usable user ID in the "id" context key (a float64) and a
// non-empty kbase_server_id query parameter. If either is missing, it answers
// with a JSON error (HTTP 200) and does not upgrade. Otherwise the upgraded
// socket is served by subscribeAndHandleKBaseTextMessages.
func KBaseTextWS(c *gin.Context) {
	rawID, _ := c.Get("id")
	idVal, ok := rawID.(float64)
	if !ok {
		c.JSON(http.StatusOK, gin.H{"error": "user id missing from request context", "code": proto.PermissionDenied, "message": "failed"})
		return
	}
	userID := int(idVal)

	kBaseServerID := c.Query("kbase_server_id")
	if kBaseServerID == "" {
		c.JSON(http.StatusOK, gin.H{"code": proto.ParameterError, "message": "kbase_server_id不能为空"})
		return
	}

	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		log.Println("升级为WebSocket时发生错误:", err)
		return
	}
	subscribeAndHandleKBaseTextMessages(ws, userID, kBaseServerID)
}
// subscribeAndHandleKBaseTextMessages bridges a knowledge-base WebSocket peer
// and Redis:
//   - payloads published on "kbase_server_id_<kBaseServerID>" are written to
//     the socket as text frames;
//   - each frame read from the socket is decoded as a
//     proto.KnowledgeBaseServerResponse, stored via service.CreateAIStreamMsg,
//     and republished unchanged on "kbase_<KnowledgeBase.UUID>".
//
// kBaseServerID is added to the Redis set "kbase_server_id_set_is_available"
// only after Redis has confirmed the subscription. It is removed when the
// function returns, which happens once the reader stops.
func subscribeAndHandleKBaseTextMessages(ws *websocket.Conn, userId int, kBaseServerID string) {
	// Upper bound on waiting for Redis to confirm the subscription.
	const subscribeTimeout = 5 * time.Second

	ctx := context.Background()
	chanel := fmt.Sprintf("kbase_server_id_%s", kBaseServerID)
	setKey := "kbase_server_id_set_is_available"

	defer ws.Close()
	pubsub := worker.RedisClient.Subscribe(ctx, chanel)
	defer pubsub.Close()

	// Subscribe does not wait for Redis. Confirm the subscription before
	// advertising this server, so that nothing published right after it shows up
	// in the available set is lost.
	subCtx, cancel := context.WithTimeout(ctx, subscribeTimeout)
	_, err := pubsub.Receive(subCtx)
	cancel()
	if err != nil {
		log.Println("KBase Text WS subscribe error:", err)
		return
	}

	worker.SetRedisSetAdd(setKey, kBaseServerID)
	defer worker.SetRedisSetRemove(setKey, kBaseServerID)

	ch := pubsub.Channel()

	done := make(chan struct{})
	go func() {
		defer close(done)
		for {
			_, message, err := ws.ReadMessage()
			if err != nil {
				log.Println("Read error:", err)
				return
			}
			var data proto.KnowledgeBaseServerResponse
			if err := json.Unmarshal(message, &data); err != nil {
				log.Println("Error parsing JSON:", err)
				continue
			}
			// Store the message.
			service.CreateAIStreamMsg(userId, data.IMContext.ModelID, data.IMContext.SessionID, data, proto.KnowledgeBaseServerResponseType)
			// Route the raw frame to the channel of the knowledge base it belongs to.
			targetChanel := fmt.Sprintf("kbase_%s", data.KnowledgeBase.UUID)
			if err := worker.RedisClient.Publish(ctx, targetChanel, message).Err(); err != nil {
				log.Println("KBase Text WS Publish error:", err)
			}
		}
	}()

	go func() {
		// go-redis closes ch when pubsub is closed, which ends this loop on shutdown.
		for m := range ch {
			if m.Payload == "" {
				continue
			}
			if err := ws.WriteMessage(websocket.TextMessage, []byte(m.Payload)); err != nil {
				log.Println("KBase Text WS write error:", err)
				// Unblock the reader so the server leaves the available set promptly.
				ws.Close()
				return
			}
		}
	}()

	<-done
}
// checkAndSetModelTemperatureParam copies the client's sampling settings from
// data into modelParam, range-checking each one for model.Type:
//   - Temperature is accepted within [0, 2];
//   - DouBao models take TopP, accepted within [0, 1];
//   - Spark models take TopK, accepted within [0, 1];
//   - any other model type counts as invalid.
//
// Accepted values are copied even when another check fails. If any check
// fails, modelParam.SetDefaultParams() is called last.
//
// NOTE(review): if ModelTypeSpark is iFlytek Spark, its documented top_k is an
// integer in 1-6, so the [0, 1] bound rejects most real values. Confirm the
// intended range.
func checkAndSetModelTemperatureParam(modelParam *proto.ModelParam, data *proto.WSMessageReq, model *dao.Model) {
	invalid := false

	if !(data.Temperature < 0 || data.Temperature > 2) {
		modelParam.Temperature = data.Temperature
	} else {
		invalid = true
	}

	switch model.Type {
	case proto.ModelTypeDouBao:
		if !(data.TopP < 0 || data.TopP > 1) {
			modelParam.TopP = data.TopP
		} else {
			invalid = true
		}
	case proto.ModelTypeSpark:
		if !(data.TopK < 0 || data.TopK > 1) {
			modelParam.TopK = data.TopK
		} else {
			invalid = true
		}
	default:
		invalid = true
	}

	if invalid {
		modelParam.SetDefaultParams()
	}
}
// SendMessageForeignAI handles POST /im/send_message_foreign_ai, a
// service-to-service endpoint that hands a model request to
// generalModel.OpenAIForeignAI.
//
// Only a caller whose user (from the "user_id" context key, an int) has Role
// "admin" may use it. It always answers HTTP 200 with a proto.GenerateResp:
//   - PermissionDenied: no usable user ID, or the user is not an admin;
//   - ParameterError: the body cannot be bound, or ModelParam.Model, Url or
//     APIKey is empty;
//   - SuccessCode: generalModel.OpenAIForeignAI has returned.
func SendMessageForeignAI(c *gin.Context) {
	var resp proto.GenerateResp
	var req proto.SendMessageForeignAIRRequest

	rawID, _ := c.Get("user_id")
	userID, ok := rawID.(int)
	// A missing or non-int "user_id" is treated like a non-admin caller
	// instead of panicking on the type assertion.
	if !ok || service.GetUserByIDWithCache(userID).Role != "admin" {
		resp.Code = proto.PermissionDenied
		resp.Message = "没有权限"
		c.JSON(http.StatusOK, resp)
		return
	}

	if err := c.ShouldBind(&req); err != nil {
		resp.Code = proto.ParameterError
		resp.Message = "参数错误"
		c.JSON(http.StatusOK, resp)
		return
	}
	// Log only non-secret fields: the bound request carries ModelParam.APIKey,
	// which must never reach the logs.
	log.Println("SendMessageForeignAI req: model:", req.ModelParam.Model, "platform:", req.PlatForm, "session:", req.IMParamContext.SessionID)

	if req.ModelParam.Model == "" || req.ModelParam.Url == "" || req.ModelParam.APIKey == "" {
		resp.Code = proto.ParameterError
		resp.Message = "模型必要参数不能为空"
		c.JSON(http.StatusOK, resp)
		return
	}

	generalModel.OpenAIForeignAI(req.ModelParam, &req.IMParamContext, req.PlatForm, req.ChatCompletionMessageParamUnion)
	resp.Code = proto.SuccessCode
	resp.Message = "success"
	c.JSON(http.StatusOK, resp)
}
// ReceiveMessageForeignAI handles POST /im/receive_message_foreign_ai, the
// service-to-service endpoint that delivers one streamed chunk
// (req.ChatCompletionChunk) for req.IMParamContext. The chunk is passed to
// generalModel.OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue.
//
// Only a caller whose user (from the "user_id" context key, an int) has Role
// "admin" may use it. It always answers HTTP 200 with a proto.GenerateResp:
//   - PermissionDenied: no usable user ID, or the user is not an admin;
//   - ParameterError: the body cannot be bound, IMParamContext.SessionID is 0,
//     or ChatCompletionChunk.Choices is nil;
//   - SuccessCode: the chunk was handed off.
func ReceiveMessageForeignAI(c *gin.Context) {
	var resp proto.GenerateResp
	var req proto.ReceiveMessageForeignAIRRequest

	rawID, _ := c.Get("user_id")
	userID, ok := rawID.(int)

	switch {
	case !ok || service.GetUserByIDWithCache(userID).Role != "admin":
		// A missing or non-int "user_id" is treated like a non-admin caller
		// instead of panicking on the type assertion.
		resp.Code = proto.PermissionDenied
		resp.Message = "没有权限"
	case c.ShouldBind(&req) != nil:
		resp.Code = proto.ParameterError
		resp.Message = "参数错误"
	case req.IMParamContext.SessionID == 0 || req.ChatCompletionChunk.Choices == nil:
		// NOTE(review): an empty JSON array ([]) decodes to a non-nil empty
		// slice and passes this check. Confirm the downstream handler copes with
		// chunks that have no choices (e.g. usage-only chunks).
		resp.Code = proto.ParameterError
		resp.Message = "模型必要参数不能为空"
	default:
		generalModel.OpenAIForeignAIResponseToGeneralMassageAndSendMsgQueue(&req)
		resp.Code = proto.SuccessCode
		resp.Message = "success"
	}
	c.JSON(http.StatusOK, resp)
}