diff --git a/go.mod b/go.mod index f9c02ad..0addf2b 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module StuAcaWorksAI -go 1.23 +go 1.24.0 + +toolchain go1.24.1 require ( github.com/gin-gonic/gin v1.10.0 @@ -8,6 +10,7 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 + github.com/ollama/ollama v0.6.2 github.com/robfig/cron/v3 v3.0.1 gorm.io/driver/mysql v1.5.7 gorm.io/driver/postgres v1.5.11 @@ -36,7 +39,6 @@ require ( github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect - github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect @@ -46,11 +48,11 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.8.0 // indirect - golang.org/x/crypto v0.23.0 // indirect - golang.org/x/net v0.25.0 // indirect - golang.org/x/sync v0.1.0 // indirect - golang.org/x/sys v0.26.0 // indirect - golang.org/x/text v0.15.0 // indirect + golang.org/x/crypto v0.33.0 // indirect + golang.org/x/net v0.35.0 // indirect + golang.org/x/sync v0.11.0 // indirect + golang.org/x/sys v0.30.0 // indirect + golang.org/x/text v0.22.0 // indirect google.golang.org/protobuf v1.34.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 7fc7e68..0086496 100644 --- a/go.sum +++ b/go.sum @@ -8,7 +8,6 @@ github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/ github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -38,8 +37,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -78,6 +77,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/ollama/ollama v0.6.2 h1:IMUxPByUqXY4fvt/5Rsm6zuffN1X+7jEWIjkqo4arK4= +github.com/ollama/ollama v0.6.2/go.mod h1:pGgtoNyc9DdM6oZI6yMfI6jTk2Eh4c36c2GpfQCH7PY= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= @@ -109,20 +110,18 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= +golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= +golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= +golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= +golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= +golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/handler/im.go b/handler/im.go index 5b7b17f..9fd6da0 100644 --- a/handler/im.go +++ b/handler/im.go @@ -3,8 +3,16 @@ package handler import ( "StuAcaWorksAI/proto" "StuAcaWorksAI/service" + "StuAcaWorksAI/worker" + "context" + "encoding/json" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "log" "net/http" + "strconv" + "sync" + "time" ) type SMessage struct { @@ -14,9 +22,27 @@ type SMessage struct { SessionID int `json:"session_id" form:"session_id"` } +var ( + upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + // 允许所有来源的连接 + return true + }, + } +) + +// 创建一个用于存储WebSocket连接的map和互斥锁 +var ( + clients = make(map[*websocket.Conn]bool) + clientsMux sync.Mutex +) + func SetUpIMGroup(router *gin.Engine) { imGroup := router.Group("/im") imGroup.POST("/send_message", SendMessage) + imGroup.GET("/ai_chat_ws", SRMessage) } func SendMessage(c *gin.Context) { @@ -37,4 +63,100 @@ func SendMessage(c *gin.Context) { } } -//ws长连接 +// ws长连接 +func SRMessage(c *gin.Context) { + id, _ := c.Get("id") + userID := int(id.(float64)) + // 升级HTTP连接为WebSocket连接 + ws, err1 := upgrader.Upgrade(c.Writer, c.Request, nil) + clients[ws] = true + if err1 != nil { + log.Println("升级为WebSocket时发生错误:", err1) + return + } + subscribeAndHandleIMMessages(ws, userID) +} + +func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) { + ctx := context.Background() + chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg" + pubsub := worker.RedisClient.Subscribe(ctx, chanel) + defer pubsub.Close() + defer ws.Close() + ch := pubsub.Channel() + //定时器,每3秒发送一次心跳 + ticker := time.NewTicker(3 * time.Second) + defer ticker.Stop() + + var sessionID uint + + //接收客户端消息,发送消息 + done := make(chan struct{}) + go func() { + defer close(done) + for { + _, message, err := ws.ReadMessage() + if err != nil { + log.Println("Read error:", err) + } + var data proto.WSMessage + err1 := json.Unmarshal(message, &data) + if err1 != nil { + log.Println("Error parsing JSON:", err1) + } + if sessionID == 0 { + //创建会话 + sessionName := data.Msg[:10] + err, sessionID = service.CreateSession(userId, sessionName) + if err != nil { + log.Println("Create session error:", err) + return + } + } + //需要后续处理 + err, _ = service.WSReceiveMessageService(userId, int(sessionID), chanel, data) + if err != nil { + return + } + } + }() + + go func() { + var data proto.WSMessage + //从redis订阅消息 + for m := range ch { + msg := m.Payload // 获取消息,消息格式为json + if msg != "" { + data.Msg = msg + data.Type = "msg" + data.SessionID = int(sessionID) + data.ToID = 0 + msg_, _ := json.Marshal(data) + err2 := ws.WriteMessage(websocket.TextMessage, msg_) + if err2 != nil { + // 发生错误,删除连接 + clientsMux.Lock() + delete(clients, ws) + clientsMux.Unlock() + break + } + } + } + }() + + //发送心跳包 + for { + select { + case <-done: + //客户端断开连接 + return + case t := <-ticker.C: + //发送心跳 + err := ws.WriteMessage(websocket.PingMessage, nil) + if err != nil { + log.Println("Ping error:", err, t) + return + } + } + } +} diff --git a/main.go b/main.go index 4a4dd33..3263608 100644 --- a/main.go +++ b/main.go @@ -37,6 +37,7 @@ func main() { handler.SetUpFileGroup(r) // File handler.SetUpSessionGroup(r) // Session handler.SetUpModelGroup(r) // Model + handler.SetUpIMGroup(r) // IM defer dao.Close() defer worker.CloseRedis() //定时任务 diff --git a/proto/im.go b/proto/im.go new file mode 100644 index 0000000..7542a9e --- /dev/null +++ b/proto/im.go @@ -0,0 +1,14 @@ +package proto + +type AIQueueMessage struct { + Type string `json:"type"` //声明不同消息类型 + ID uint `json:"id"` //消息id + Msg string `json:"msg"` //原始json消息 +} + +type WSMessage struct { + Type string `json:"type"` //接收及发送消息类型 + Msg string `json:"msg"` //消息内容,只进行转发,不做处理 + SessionID int `json:"session_id"` //应用层会话id + ToID int `json:"to_id"` //接收者id +} diff --git a/service/imService.go b/service/imService.go index 1015ada..4036353 100644 --- a/service/imService.go +++ b/service/imService.go @@ -2,7 +2,14 @@ package service import ( "StuAcaWorksAI/dao" + "StuAcaWorksAI/proto" + "StuAcaWorksAI/worker" + "context" + "encoding/json" "errors" + "github.com/ollama/ollama/api" + "log" + "time" ) // 创建会话服务 @@ -68,3 +75,55 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin err, mid := CreateMessage(msgType, sessionID, fromID, toID, msg, 0) return err, mid } + +var client *api.Client + +func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessage) (error, uint) { + var resErr error + var resID uint + //处理消息 + if msg.Type == "ollama" { + var err error + if client == nil { + client, err = api.ClientFromEnvironment() + } + if err != nil { + log.Println("get ollama client error:", err) + } + prompt := msg.Msg + var actx []int + req := &api.GenerateRequest{ + Model: "qwen2.5-coder:latest", + Context: actx, + Prompt: prompt, + } + ctx := context.Background() + robotMsg := "" + var aiMsg proto.AIQueueMessage + respFunc := func(resp api.GenerateResponse) error { + aiMsg.Type = "ollama" + respStr, _ := json.Marshal(resp) + aiMsg.Msg = string(respStr) + robotMsg += resp.Response + if resp.Done { //该消息完成 + actx = resp.Context + //创建消息 + err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1 + if err2 != nil { + log.Println("create robot message error:", err2) + } + aiMsg.ID = msgID + resID = msgID + } + //发送消息 + aiMsgStr, _ := json.Marshal(aiMsg) + worker.Publish(channel, string(aiMsgStr), time.Second*60) + return nil + } + err = client.Generate(ctx, req, respFunc) + if err != nil { + log.Fatal(err) + } + } + return resErr, resID +}