diff --git a/dao/im.go b/dao/im.go index 84d9b7e..3a0c00d 100644 --- a/dao/im.go +++ b/dao/im.go @@ -16,12 +16,13 @@ type Session struct { type Message struct { gorm.Model - Type int `gorm:"column:type"` //1用户之间的消息,2为与模型消息 - SessionID int `gorm:"column:session_id;index"` //添加索引 - FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型 - ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断 - Msg string `gorm:"column:msg"` //消息内容 - Status int `gorm:"column:status"` //0为未读,1为已读 + Type int `gorm:"column:type"` //1用户之间的消息,2为与模型消息 + SessionID int `gorm:"column:session_id;index"` //添加索引 + FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型 + ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断 + Msg string `gorm:"column:msg"` //消息内容 + FunctionID int `gorm:"column:function_id"` //功能id + Status int `gorm:"column:status"` //0为未读,1为已读 } func CreateSession(userID int, name string) (error, uint) { @@ -84,8 +85,8 @@ func DeleteSessionByID(id int) error { } // 创建消息 -func CreateMessage(sessionID, fromID, toID, msgType, status int, msg string) (error, uint) { - message := Message{SessionID: sessionID, FromID: fromID, ToID: toID, Type: msgType, Status: status, Msg: msg} +func CreateMessage(sessionID, fromID, toID, msgType, status int, msg string, function_id int) (error, uint) { + message := Message{SessionID: sessionID, FromID: fromID, ToID: toID, Type: msgType, Status: status, Msg: msg, FunctionID: function_id} var res *gorm.DB if proto.Config.SERVER_SQL_LOG { res = DB.Debug().Create(&message) diff --git a/handler/im.go b/handler/im.go index eff1040..daf7413 100644 --- a/handler/im.go +++ b/handler/im.go @@ -9,6 +9,7 @@ import ( "encoding/json" "errors" "github.com/gin-gonic/gin" + "github.com/google/uuid" "github.com/gorilla/websocket" "log" "net/http" @@ -97,7 +98,9 @@ func SRMessage(c *gin.Context) { func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) { ctx := context.Background() - chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg" + //随机字符串 + strRandom := uuid.NewString() + chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg_" + strRandom pubsub := worker.RedisClient.Subscribe(ctx, chanel) defer pubsub.Close() defer ws.Close() @@ -125,7 +128,7 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) { if err1 != nil { log.Println("Error parsing JSON:", err1) } - err = doReceiveGenChatMessage(userId, &sessionID, &data) + err = doReceiveGenChatMessage(userId, &sessionID, &data, chanel) if err != nil { log.Println("do receive gen chat message error:", err) return @@ -183,9 +186,8 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) { } -func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageReq) error { +func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageReq, chanel string) error { var err error - chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg" if *sessionID == 0 { //创建会话 err, *sessionID = service.CreateSession(userId, data.Msg) @@ -212,15 +214,17 @@ func doReceiveGenChatMessage(userId int, sessionID *uint, data *proto.WSMessageR } modelParam.Url = model.Url modelParam.System = funcs[0].Info //系统功能 + imContext := proto.IMParamContext{UserID: userId, SessionID: int(*sessionID), FunctionID: int(funcs[0].ID), ModelID: int(model.ID), Question: data.Msg, Channel: chanel} var userMsgID uint //将消息存入数据库 - err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead) + err, userMsgID = service.CreateMessage(proto.UserToModelMsgType, int(*sessionID), userId, int(model.ID), data.Msg, proto.MsgHasRead, int(funcs[0].ID)) log.Println("create user message id:", userMsgID) if err != nil { return err } //机器人id为2,消息类型为1,状态为1 //需要后续处理 - err, _ = service.WSReceiveMessageService(userId, int(*sessionID), chanel, *data, model, modelParam) + //err, _ = service.WSReceiveMessageService(userId, int(*sessionID), chanel, *data, model, modelParam) + err, _ = service.WSReceiveMessageServiceV2(modelParam, &imContext) return err } diff --git a/proto/im.go b/proto/im.go index 34247c2..201bb3f 100644 --- a/proto/im.go +++ b/proto/im.go @@ -48,3 +48,13 @@ type SessionResponse struct { Context []int `gorm:"column:context;type:json"` //会话上下文 MsgCount int `gorm:"column:msg_count"` //消息数量 } + +type IMParamContext struct { + UserID int `json:"user_id"` //用户id + SessionID int `json:"session_id"` //会话id + FunctionID int `json:"function_id"` //功能id + ModelID int `json:"model_id"` //模型id + ModelType string `json:"model_type"` //模型类型 + Question string `json:"question"` //问题 + Channel string `json:"channel"` //消息队列 +} diff --git a/service/doubao.go b/service/doubao.go index 9c19cba..4249e14 100644 --- a/service/doubao.go +++ b/service/doubao.go @@ -70,7 +70,66 @@ func doubao(modelParam proto.ModelParam, question, channel string, SessionID, Us } } //将消息存入数据库 - err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer) + err, _ = dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, 2) + if err != nil { + log.Println("spark create message error:", err) + } +} + +func DouBaoV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { + doubaoClient := arkruntime.NewClientWithApiKey( + modelParam.APIKey, + arkruntime.WithBaseUrl(modelParam.Url), + ) + ctx := context.Background() + messages := make([]*model.ChatCompletionMessage, 0) + err := GetDouBaoSessionHistoryMsg(imCtx.SessionID, modelParam.System, &messages, imCtx.Question) + if err != nil { + log.Println("get doubao session history message error:", err) + } + //messagesStr, _ := json.Marshal(messages) + //log.Println("doubao messages:", string(messagesStr)) + + req := model.CreateChatCompletionRequest{ + Model: modelParam.Model, + Messages: messages, + } + + stream, err := doubaoClient.CreateChatCompletionStream(ctx, req) + if err != nil { + fmt.Printf("stream chat error: %v\n", err) + return + } + defer func(stream *utils.ChatCompletionStreamReader) { + err2 := stream.Close() + if err2 != nil { + log.Println("close stream error:", err2) + } + }(stream) + + answer := "" + for { + recv, err3 := stream.Recv() + if err3 == io.EOF { + log.Println("doubao stream end:", err3) + return + } + if err3 != nil { + fmt.Printf("doubao Stream chat error: %v\n", err) + return + } + go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, recv, modelParam.Model) + doubaoToGeneralMassageAndSendMsgQueue(&recv, imCtx.Channel, imCtx.SessionID, imCtx.UserID) + choices := recv.Choices[0] + if choices.FinishReason == proto.FinishReasonStop { + answer += choices.Delta.Content + break + } else { + answer += choices.Delta.Content + } + } + //将消息存入数据库 + err, _ = dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID) if err != nil { log.Println("spark create message error:", err) } diff --git a/service/imService.go b/service/imService.go index 051dc38..29b9f6b 100644 --- a/service/imService.go +++ b/service/imService.go @@ -65,9 +65,9 @@ func UpdateSessionByID(id int, userId int, name string, context []int) error { } // 创建消息服务 -func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int) (error, uint) { +func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int, functionId int) (error, uint) { //sessionID, fromID, toID, msgType, status int, msg string - err, id := dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg) + err, id := dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg, functionId) return err, id } @@ -88,7 +88,7 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin if session.UserID != fromID { return errors.New("session not belong to user"), 0 } - err, mid := CreateMessage(msgType, sessionID, fromID, toID, msg, 0) + err, mid := CreateMessage(msgType, sessionID, fromID, toID, msg, 0, 0) return err, mid } @@ -111,6 +111,22 @@ func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WS return resErr, resID } +func WSReceiveMessageServiceV2(modelParam proto.ModelParam, ctx *proto.IMParamContext) (error, uint) { + var resErr error + var resID uint + //处理消息 + if ctx.ModelType == proto.ModelTypeOllama { + resErr, resID = ReceiveOllamaSessionV2(modelParam, ctx) + } else if ctx.ModelType == proto.ModelTypeSpark { + SparkV2(modelParam, ctx) + } else if ctx.ModelType == proto.ModelTypeDouBao { + DouBaoV2(modelParam, ctx) + } else { + return errors.New("model not exist"), 0 + } + return resErr, resID +} + func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, m *dao.Model) (error, uint) { var resErr error var resID uint @@ -147,7 +163,7 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes SessionContextMap[sessionID] = actx log.Println("actx has save:", len(SessionContextMap[sessionID])) //创建消息 - err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1 + err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1, 0) //机器人id为2,消息类型为2,状态为1 if err2 != nil { log.Println("create robot message error:", err2) } @@ -170,6 +186,65 @@ func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMes return resErr, resID } +func ReceiveOllamaSessionV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) (error, uint) { + var resErr error + var resID uint + var err error + if client == nil { + client, err = api.ClientFromEnvironment() + } + if err != nil { + log.Println("get ollama client error:", err) + } + prompt := imCtx.Question + actx, ok := SessionContextMap[int(imCtx.SessionID)] + if ok { + log.Println("actx has get will use:", len(actx)) + } + req := &api.GenerateRequest{ + Model: modelParam.Model, + Context: actx, + Prompt: prompt, + } + ctx := context.Background() + robotMsg := "" + var aiMsg proto.AIQueueMessage + var data proto.WSMessage + respFunc := func(resp api.GenerateResponse) error { + aiMsg.Type = proto.ModelTypeOllama + aiMsg.Msg = resp + robotMsg += resp.Response + if resp.Done { //该消息完成 + actx = resp.Context + if SessionContextMap == nil { + SessionContextMap = make(map[int][]int) + } + SessionContextMap[imCtx.SessionID] = actx + log.Println("actx has save:", len(SessionContextMap[imCtx.SessionID])) + //创建消息 + err2, msgID := CreateMessage(proto.UserToModelMsgType, imCtx.SessionID, imCtx.ModelID, imCtx.UserID, robotMsg, 1, imCtx.FunctionID) //机器人id为2,消息类型为2,状态为1 + if err2 != nil { + log.Println("create robot message error:", err2) + } + aiMsg.ID = msgID + resID = msgID + } + //发送消息 + data.Msg = aiMsg + data.SessionID = int(imCtx.SessionID) + data.ToID = imCtx.UserID + aiMsgStr, _ := json.Marshal(data) + worker.Publish(imCtx.Channel, string(aiMsgStr), time.Second*60) + go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, resp, modelParam.Model) + return nil + } + err = client.Generate(ctx, req, respFunc) + if err != nil { + log.Fatal(err) + } + return resErr, resID +} + func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, model *dao.Model) (error, uint) { var resErr error var resID uint diff --git a/service/spark.go b/service/spark.go index 65aabb2..b634574 100644 --- a/service/spark.go +++ b/service/spark.go @@ -170,7 +170,79 @@ func Spark(modelParam proto.ModelParam, question, channel string, SessionID, Use conn.Close() //将消息存入数据库 - err2, msgID := dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer) + err2, msgID := dao.CreateMessage(SessionID, modelID, UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, 2) + if err2 != nil { + log.Println("spark create message error:", err2) + } else { + log.Println("spark create message success,message id:", msgID) + } +} + +func SparkV2(modelParam proto.ModelParam, imCtx *proto.IMParamContext) { + log.Println("spark model param:", modelParam) + log.Println("spark question:", imCtx.Question) + //建立websocket连接 + dialer := websocket.Dialer{ + HandshakeTimeout: 5 * time.Second, + } + //建立连接 + conn, resp, err := dialer.Dial(assembleAuthUrl(modelParam.Url, modelParam.APIKey, modelParam.APISecret), nil) + if err != nil { + panic(readResp(resp) + err.Error()) + return + } else if resp.StatusCode != 101 { + panic(readResp(resp) + err.Error()) + } + + //发送消息 + go func() { + data := genSparkParams(imCtx.Question, modelParam.APPID, modelParam.Domain, imCtx.SessionID, modelParam.System) + //将数据转换为json + //dataByte, err := json.Marshal(data) + if err != nil { + fmt.Println("Error parsing JSON:", err) + return + } + err2 := conn.WriteJSON(data) + if err != nil { + fmt.Println("write message error:", err2) + return + } + }() + + answer := "" + + //接收消息 + for { + _, msg, err2 := conn.ReadMessage() + if err2 != nil { + log.Println("read message error:", err2) + break + } + var data SparkResponse + err = json.Unmarshal(msg, &data) + if err != nil { + log.Println("Error parsing JSON:", err) + return + } + + if data.Header.Code != 0 { + log.Println("Error response:", data.Header.Message, string(msg)) + break + } + if data.Header.Status == 2 { + //最后一个结果 + answer += data.Payload.Choices.Text[0].Content + } else { + answer += data.Payload.Choices.Text[0].Content + } + go CreateAIStreamMsg(imCtx.UserID, imCtx.ModelID, imCtx.SessionID, data, modelParam.Model) + SparkToGeneralMassageAndSendMsgQueue(&data, imCtx.Channel, imCtx.SessionID, imCtx.UserID) + } + + conn.Close() + //将消息存入数据库 + err2, msgID := dao.CreateMessage(imCtx.SessionID, imCtx.ModelID, imCtx.UserID, proto.ModelToUserMsgType, proto.MsgHasRead, answer, imCtx.FunctionID) if err2 != nil { log.Println("spark create message error:", err2) } else {