websocket接口与服务端进行连接
This commit is contained in:
parent
8a6949c5fd
commit
611e027262
16
go.mod
16
go.mod
|
|
@ -1,6 +1,8 @@
|
|||
module StuAcaWorksAI
|
||||
|
||||
go 1.23
|
||||
go 1.24.0
|
||||
|
||||
toolchain go1.24.1
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
|
|
@ -8,6 +10,7 @@ require (
|
|||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/ollama/ollama v0.6.2
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
|
|
@ -36,7 +39,6 @@ require (
|
|||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
|
|
@ -46,11 +48,11 @@ require (
|
|||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.23.0 // indirect
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.26.0 // indirect
|
||||
golang.org/x/text v0.15.0 // indirect
|
||||
golang.org/x/crypto v0.33.0 // indirect
|
||||
golang.org/x/net v0.35.0 // indirect
|
||||
golang.org/x/sync v0.11.0 // indirect
|
||||
golang.org/x/sys v0.30.0 // indirect
|
||||
golang.org/x/text v0.22.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
|
|
|||
29
go.sum
29
go.sum
|
|
@ -8,7 +8,6 @@ github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/
|
|||
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
|
||||
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
|
||||
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
|
|
@ -38,8 +37,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
|||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
|
|
@ -78,6 +77,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
|
|||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
|
||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||
github.com/ollama/ollama v0.6.2 h1:IMUxPByUqXY4fvt/5Rsm6zuffN1X+7jEWIjkqo4arK4=
|
||||
github.com/ollama/ollama v0.6.2/go.mod h1:pGgtoNyc9DdM6oZI6yMfI6jTk2Eh4c36c2GpfQCH7PY=
|
||||
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
|
||||
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
|
||||
github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE=
|
||||
|
|
@ -109,20 +110,18 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
|||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
|
|
|||
122
handler/im.go
122
handler/im.go
|
|
@ -3,8 +3,16 @@ package handler
|
|||
import (
|
||||
"StuAcaWorksAI/proto"
|
||||
"StuAcaWorksAI/service"
|
||||
"StuAcaWorksAI/worker"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SMessage struct {
|
||||
|
|
@ -14,9 +22,27 @@ type SMessage struct {
|
|||
SessionID int `json:"session_id" form:"session_id"`
|
||||
}
|
||||
|
||||
var (
|
||||
upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
// 允许所有来源的连接
|
||||
return true
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// 创建一个用于存储WebSocket连接的map和互斥锁
|
||||
var (
|
||||
clients = make(map[*websocket.Conn]bool)
|
||||
clientsMux sync.Mutex
|
||||
)
|
||||
|
||||
func SetUpIMGroup(router *gin.Engine) {
|
||||
imGroup := router.Group("/im")
|
||||
imGroup.POST("/send_message", SendMessage)
|
||||
imGroup.GET("/ai_chat_ws", SRMessage)
|
||||
}
|
||||
|
||||
func SendMessage(c *gin.Context) {
|
||||
|
|
@ -38,3 +64,99 @@ func SendMessage(c *gin.Context) {
|
|||
}
|
||||
|
||||
// ws长连接
|
||||
func SRMessage(c *gin.Context) {
|
||||
id, _ := c.Get("id")
|
||||
userID := int(id.(float64))
|
||||
// 升级HTTP连接为WebSocket连接
|
||||
ws, err1 := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
clients[ws] = true
|
||||
if err1 != nil {
|
||||
log.Println("升级为WebSocket时发生错误:", err1)
|
||||
return
|
||||
}
|
||||
subscribeAndHandleIMMessages(ws, userID)
|
||||
}
|
||||
|
||||
func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
||||
ctx := context.Background()
|
||||
chanel := "user_" + strconv.Itoa(userId) + "_ai_chat_msg"
|
||||
pubsub := worker.RedisClient.Subscribe(ctx, chanel)
|
||||
defer pubsub.Close()
|
||||
defer ws.Close()
|
||||
ch := pubsub.Channel()
|
||||
//定时器,每3秒发送一次心跳
|
||||
ticker := time.NewTicker(3 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
var sessionID uint
|
||||
|
||||
//接收客户端消息,发送消息
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
for {
|
||||
_, message, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
log.Println("Read error:", err)
|
||||
}
|
||||
var data proto.WSMessage
|
||||
err1 := json.Unmarshal(message, &data)
|
||||
if err1 != nil {
|
||||
log.Println("Error parsing JSON:", err1)
|
||||
}
|
||||
if sessionID == 0 {
|
||||
//创建会话
|
||||
sessionName := data.Msg[:10]
|
||||
err, sessionID = service.CreateSession(userId, sessionName)
|
||||
if err != nil {
|
||||
log.Println("Create session error:", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
//需要后续处理
|
||||
err, _ = service.WSReceiveMessageService(userId, int(sessionID), chanel, data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
var data proto.WSMessage
|
||||
//从redis订阅消息
|
||||
for m := range ch {
|
||||
msg := m.Payload // 获取消息,消息格式为json
|
||||
if msg != "" {
|
||||
data.Msg = msg
|
||||
data.Type = "msg"
|
||||
data.SessionID = int(sessionID)
|
||||
data.ToID = 0
|
||||
msg_, _ := json.Marshal(data)
|
||||
err2 := ws.WriteMessage(websocket.TextMessage, msg_)
|
||||
if err2 != nil {
|
||||
// 发生错误,删除连接
|
||||
clientsMux.Lock()
|
||||
delete(clients, ws)
|
||||
clientsMux.Unlock()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
//发送心跳包
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
//客户端断开连接
|
||||
return
|
||||
case t := <-ticker.C:
|
||||
//发送心跳
|
||||
err := ws.WriteMessage(websocket.PingMessage, nil)
|
||||
if err != nil {
|
||||
log.Println("Ping error:", err, t)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
1
main.go
1
main.go
|
|
@ -37,6 +37,7 @@ func main() {
|
|||
handler.SetUpFileGroup(r) // File
|
||||
handler.SetUpSessionGroup(r) // Session
|
||||
handler.SetUpModelGroup(r) // Model
|
||||
handler.SetUpIMGroup(r) // IM
|
||||
defer dao.Close()
|
||||
defer worker.CloseRedis()
|
||||
//定时任务
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
package proto
|
||||
|
||||
type AIQueueMessage struct {
|
||||
Type string `json:"type"` //声明不同消息类型
|
||||
ID uint `json:"id"` //消息id
|
||||
Msg string `json:"msg"` //原始json消息
|
||||
}
|
||||
|
||||
type WSMessage struct {
|
||||
Type string `json:"type"` //接收及发送消息类型
|
||||
Msg string `json:"msg"` //消息内容,只进行转发,不做处理
|
||||
SessionID int `json:"session_id"` //应用层会话id
|
||||
ToID int `json:"to_id"` //接收者id
|
||||
}
|
||||
|
|
@ -2,7 +2,14 @@ package service
|
|||
|
||||
import (
|
||||
"StuAcaWorksAI/dao"
|
||||
"StuAcaWorksAI/proto"
|
||||
"StuAcaWorksAI/worker"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/ollama/ollama/api"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 创建会话服务
|
||||
|
|
@ -68,3 +75,55 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin
|
|||
err, mid := CreateMessage(msgType, sessionID, fromID, toID, msg, 0)
|
||||
return err, mid
|
||||
}
|
||||
|
||||
var client *api.Client
|
||||
|
||||
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessage) (error, uint) {
|
||||
var resErr error
|
||||
var resID uint
|
||||
//处理消息
|
||||
if msg.Type == "ollama" {
|
||||
var err error
|
||||
if client == nil {
|
||||
client, err = api.ClientFromEnvironment()
|
||||
}
|
||||
if err != nil {
|
||||
log.Println("get ollama client error:", err)
|
||||
}
|
||||
prompt := msg.Msg
|
||||
var actx []int
|
||||
req := &api.GenerateRequest{
|
||||
Model: "qwen2.5-coder:latest",
|
||||
Context: actx,
|
||||
Prompt: prompt,
|
||||
}
|
||||
ctx := context.Background()
|
||||
robotMsg := ""
|
||||
var aiMsg proto.AIQueueMessage
|
||||
respFunc := func(resp api.GenerateResponse) error {
|
||||
aiMsg.Type = "ollama"
|
||||
respStr, _ := json.Marshal(resp)
|
||||
aiMsg.Msg = string(respStr)
|
||||
robotMsg += resp.Response
|
||||
if resp.Done { //该消息完成
|
||||
actx = resp.Context
|
||||
//创建消息
|
||||
err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1
|
||||
if err2 != nil {
|
||||
log.Println("create robot message error:", err2)
|
||||
}
|
||||
aiMsg.ID = msgID
|
||||
resID = msgID
|
||||
}
|
||||
//发送消息
|
||||
aiMsgStr, _ := json.Marshal(aiMsg)
|
||||
worker.Publish(channel, string(aiMsgStr), time.Second*60)
|
||||
return nil
|
||||
}
|
||||
err = client.Generate(ctx, req, respFunc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
return resErr, resID
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue