Merge branch 'refs/heads/feat-im-ws'
This commit is contained in:
commit
e3a14eec54
11
dao/im.go
11
dao/im.go
|
|
@ -7,8 +7,9 @@ import (
|
|||
|
||||
type Session struct {
|
||||
gorm.Model
|
||||
UserID int `gorm:"column:user_id"` //只能由用户创建
|
||||
Name string `gorm:"column:name"`
|
||||
UserID int `gorm:"column:user_id"` //只能由用户创建
|
||||
Name string `gorm:"column:name"`
|
||||
Context []int `gorm:"column:context"` //会话上下文
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
|
|
@ -54,12 +55,12 @@ func FindSessionByUserID(userID int) []Session {
|
|||
}
|
||||
|
||||
// 更新会话的名字
|
||||
func UpdateSessionByID(id int, userId int, name string) error {
|
||||
func UpdateSessionByID(id int, userId int, name string, context []int) error {
|
||||
var res *gorm.DB
|
||||
if proto.Config.SERVER_SQL_LOG {
|
||||
res = DB.Debug().Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name)
|
||||
res = DB.Debug().Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: context})
|
||||
} else {
|
||||
res = DB.Model(&Session{}).Where("id = ? and user_id = ?", id, userId).Update("name", name)
|
||||
res = DB.Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: context})
|
||||
}
|
||||
return res.Error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -176,6 +176,18 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
|||
for {
|
||||
select {
|
||||
case <-done:
|
||||
//结束会话将context保存
|
||||
ctx2, ok := service.SessionContextMap[int(sessionID)]
|
||||
if !ok {
|
||||
log.Println("session context not found")
|
||||
return
|
||||
}
|
||||
err := service.UpdateSessionByID(int(sessionID), userId, "", ctx2)
|
||||
if err != nil {
|
||||
log.Println("update session context error:", err)
|
||||
}
|
||||
//保存的会话
|
||||
delete(service.SessionContextMap, int(sessionID))
|
||||
//客户端断开连接
|
||||
return
|
||||
//case t := <-ticker.C:
|
||||
|
|
@ -188,4 +200,5 @@ func subscribeAndHandleIMMessages(ws *websocket.Conn, userId int) {
|
|||
//}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ func UpdateSession(c *gin.Context) {
|
|||
var req Session
|
||||
if err := c.ShouldBind(&req); err == nil {
|
||||
// 修改会话
|
||||
err2 := service.UpdateSessionByID(req.ID, userID, req.Name)
|
||||
err2 := service.UpdateSessionByID(req.ID, userID, req.Name, nil)
|
||||
if err2 == nil {
|
||||
c.JSON(http.StatusOK, gin.H{"code": proto.SuccessCode, "message": "success"})
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ type ModelParam struct {
|
|||
|
||||
type WSMessageReq struct {
|
||||
Type string `json:"type"` //接收及发送消息类型
|
||||
Function string `json:"function"` //功能名称
|
||||
Msg string `json:"msg"` //消息内容,只进行转发,不做处理
|
||||
SessionID int `json:"session_id"` //应用层会话id
|
||||
ToID int `json:"to_id"` //接收者id
|
||||
|
|
|
|||
|
|
@ -43,8 +43,8 @@ func DeleteSessionByID(id, userID int) error {
|
|||
}
|
||||
|
||||
// 更新会话名字
|
||||
func UpdateSessionByID(id int, userId int, name string) error {
|
||||
err := dao.UpdateSessionByID(id, userId, name)
|
||||
func UpdateSessionByID(id int, userId int, name string, context []int) error {
|
||||
err := dao.UpdateSessionByID(id, userId, name, context)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -77,65 +77,83 @@ func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg strin
|
|||
}
|
||||
|
||||
var client *api.Client
|
||||
var actxMap map[int][]int
|
||||
var SessionContextMap map[int][]int
|
||||
|
||||
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
|
||||
var resErr error
|
||||
var resID uint
|
||||
//处理消息
|
||||
if msg.Type == "ollama" {
|
||||
var err error
|
||||
if client == nil {
|
||||
client, err = api.ClientFromEnvironment()
|
||||
}
|
||||
if err != nil {
|
||||
log.Println("get ollama client error:", err)
|
||||
}
|
||||
prompt := msg.Msg
|
||||
actx, ok := actxMap[sessionID]
|
||||
if ok {
|
||||
log.Println("actx has get will use:", len(actx))
|
||||
}
|
||||
req := &api.GenerateRequest{
|
||||
Model: "qwen2.5:0.5b",
|
||||
Context: actx,
|
||||
Prompt: prompt,
|
||||
}
|
||||
ctx := context.Background()
|
||||
robotMsg := ""
|
||||
var aiMsg proto.AIQueueMessage
|
||||
var data proto.WSMessage
|
||||
respFunc := func(resp api.GenerateResponse) error {
|
||||
aiMsg.Type = "ollama"
|
||||
aiMsg.Msg = resp
|
||||
robotMsg += resp.Response
|
||||
if resp.Done { //该消息完成
|
||||
actx = resp.Context
|
||||
if actxMap == nil {
|
||||
actxMap = make(map[int][]int)
|
||||
}
|
||||
actxMap[sessionID] = actx
|
||||
log.Println("actx has save:", len(actxMap[sessionID]))
|
||||
//创建消息
|
||||
err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1
|
||||
if err2 != nil {
|
||||
log.Println("create robot message error:", err2)
|
||||
}
|
||||
aiMsg.ID = msgID
|
||||
resID = msgID
|
||||
}
|
||||
//发送消息
|
||||
data.Msg = aiMsg
|
||||
data.SessionID = sessionID
|
||||
data.ToID = userID
|
||||
aiMsgStr, _ := json.Marshal(data)
|
||||
worker.Publish(channel, string(aiMsgStr), time.Second*60)
|
||||
return nil
|
||||
}
|
||||
err = client.Generate(ctx, req, respFunc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg)
|
||||
} else if msg.Type == "spark" {
|
||||
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg)
|
||||
|
||||
}
|
||||
return resErr, resID
|
||||
}
|
||||
|
||||
func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
|
||||
var resErr error
|
||||
var resID uint
|
||||
var err error
|
||||
if client == nil {
|
||||
client, err = api.ClientFromEnvironment()
|
||||
}
|
||||
if err != nil {
|
||||
log.Println("get ollama client error:", err)
|
||||
}
|
||||
prompt := msg.Msg
|
||||
actx, ok := SessionContextMap[sessionID]
|
||||
if ok {
|
||||
log.Println("actx has get will use:", len(actx))
|
||||
}
|
||||
req := &api.GenerateRequest{
|
||||
Model: "qwen2.5:0.5b",
|
||||
Context: actx,
|
||||
Prompt: prompt,
|
||||
}
|
||||
ctx := context.Background()
|
||||
robotMsg := ""
|
||||
var aiMsg proto.AIQueueMessage
|
||||
var data proto.WSMessage
|
||||
respFunc := func(resp api.GenerateResponse) error {
|
||||
aiMsg.Type = "ollama"
|
||||
aiMsg.Msg = resp
|
||||
robotMsg += resp.Response
|
||||
if resp.Done { //该消息完成
|
||||
actx = resp.Context
|
||||
if SessionContextMap == nil {
|
||||
SessionContextMap = make(map[int][]int)
|
||||
}
|
||||
SessionContextMap[sessionID] = actx
|
||||
log.Println("actx has save:", len(SessionContextMap[sessionID]))
|
||||
//创建消息
|
||||
err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1
|
||||
if err2 != nil {
|
||||
log.Println("create robot message error:", err2)
|
||||
}
|
||||
aiMsg.ID = msgID
|
||||
resID = msgID
|
||||
}
|
||||
//发送消息
|
||||
data.Msg = aiMsg
|
||||
data.SessionID = sessionID
|
||||
data.ToID = userID
|
||||
aiMsgStr, _ := json.Marshal(data)
|
||||
worker.Publish(channel, string(aiMsgStr), time.Second*60)
|
||||
return nil
|
||||
}
|
||||
err = client.Generate(ctx, req, respFunc)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return resErr, resID
|
||||
}
|
||||
|
||||
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
|
||||
var resErr error
|
||||
var resID uint
|
||||
//查找模型
|
||||
|
||||
return resErr, resID
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,274 @@
|
|||
package worker
|
||||
|
||||
import (
|
||||
"StuAcaWorksAI/dao"
|
||||
"StuAcaWorksAI/proto"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/ollama/ollama/api"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
/***************** Spark请求结构 ***********************/
|
||||
type SparkRequestHeader struct {
|
||||
AppID string `json:"app_id"`
|
||||
UID string `json:"uid"`
|
||||
}
|
||||
|
||||
type ChatParameter struct {
|
||||
Domain string `json:"domain"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
}
|
||||
|
||||
type Parameter struct {
|
||||
Chat ChatParameter `json:"chat"`
|
||||
}
|
||||
|
||||
type MessageContent struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type SparkMessage struct {
|
||||
Text []MessageContent `json:"text"`
|
||||
}
|
||||
|
||||
type SparkRequestPayload struct {
|
||||
Message SparkMessage `json:"message"`
|
||||
}
|
||||
|
||||
type SparkRequest struct {
|
||||
Header SparkRequestHeader `json:"header"`
|
||||
Parameter Parameter `json:"parameter"`
|
||||
Payload SparkRequestPayload `json:"payload"`
|
||||
}
|
||||
|
||||
/***************** Spark请求结构 ***********************/
|
||||
|
||||
/***************** Spark响应结构 ***********************/
|
||||
|
||||
type SparkResponseHeader struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
SID string `json:"sid"`
|
||||
Status int `json:"status"` //会话状态,取值为[0,1,2];0代表首次结果;1代表中间结果;2代表最后一个结果
|
||||
}
|
||||
|
||||
// ChoiceText 结构体对应 JSON 中 "choices.text[].text" 部分
|
||||
type ChoiceText struct {
|
||||
Content string `json:"content"`
|
||||
Role string `json:"role"`
|
||||
Index int `json:"index"`
|
||||
}
|
||||
|
||||
// Choices 结构体对应 JSON 中 "choices" 部分
|
||||
type Choices struct {
|
||||
Status int `json:"status"`
|
||||
Seq int `json:"seq"`
|
||||
Text []ChoiceText `json:"text"`
|
||||
}
|
||||
|
||||
// UsageText 结构体对应 JSON 中 "usage.text" 部分
|
||||
type UsageText struct {
|
||||
QuestionTokens int `json:"question_tokens"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// Usage 结构体对应 JSON 中 "usage" 部分
|
||||
type Usage struct {
|
||||
Text UsageText `json:"text"`
|
||||
}
|
||||
|
||||
// Payload 结构体对应 JSON 中 "payload" 部分
|
||||
type SparkResponsePayload struct {
|
||||
Choices Choices `json:"choices"`
|
||||
Usage Usage `json:"usage"`
|
||||
}
|
||||
|
||||
// Request 结构体对应整个 JSON 数据
|
||||
type SparkResponse struct {
|
||||
Header SparkResponseHeader `json:"header"`
|
||||
Payload SparkResponsePayload `json:"payload"`
|
||||
}
|
||||
|
||||
/***************** Spark响应结构 ***********************/
|
||||
|
||||
func Spark(hostUrl string, apiKey, apiSecret, appid string, question, channel string, SessionID, UserID int) {
|
||||
//建立websocket连接
|
||||
dialer := websocket.Dialer{
|
||||
HandshakeTimeout: 5 * time.Second,
|
||||
}
|
||||
//建立连接
|
||||
conn, resp, err := dialer.Dial(assembleAuthUrl(hostUrl, apiKey, apiSecret), nil)
|
||||
if err != nil {
|
||||
panic(readResp(resp) + err.Error())
|
||||
return
|
||||
} else if resp.StatusCode != 101 {
|
||||
panic(readResp(resp) + err.Error())
|
||||
}
|
||||
|
||||
//发送消息
|
||||
go func() {
|
||||
data := genSparkParams(appid, question)
|
||||
//将数据转换为json
|
||||
//dataByte, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
fmt.Println("Error parsing JSON:", err)
|
||||
return
|
||||
}
|
||||
err2 := conn.WriteJSON(data)
|
||||
if err != nil {
|
||||
fmt.Println("write message error:", err2)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
answer := ""
|
||||
|
||||
//接收消息
|
||||
for {
|
||||
_, msg, err2 := conn.ReadMessage()
|
||||
if err2 != nil {
|
||||
log.Println("read message error:", err2)
|
||||
break
|
||||
}
|
||||
var data SparkResponse
|
||||
err = json.Unmarshal(msg, &data)
|
||||
if err != nil {
|
||||
log.Println("Error parsing JSON:", err)
|
||||
return
|
||||
}
|
||||
|
||||
if data.Header.Code != 0 {
|
||||
log.Println("Error response:", data.Header.Message)
|
||||
break
|
||||
}
|
||||
if data.Header.Status == 2 {
|
||||
//最后一个结果
|
||||
answer += data.Payload.Choices.Text[0].Content
|
||||
} else {
|
||||
answer += data.Payload.Choices.Text[0].Content
|
||||
}
|
||||
SparkToGeneralMassageAndSendMsgQueue(&data, channel, SessionID, UserID)
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
//将消息存入数据库
|
||||
err, _ = dao.CreateMessage(SessionID, 3, UserID, 2, 1, answer)
|
||||
if err != nil {
|
||||
log.Println("spark create message error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func SparkToGeneralMassageAndSendMsgQueue(data *SparkResponsem, channel string, SessionID int, UserID int) {
|
||||
var aiMsg proto.AIQueueMessage
|
||||
var wsMsg proto.WSMessage
|
||||
var apiMsg api.GenerateResponse
|
||||
apiMsg.Model = "spark"
|
||||
apiMsg.Response = data.Payload.Choices.Text[0].Content
|
||||
if data.Header.Status == 2 {
|
||||
apiMsg.Done = true
|
||||
} else {
|
||||
apiMsg.Done = false
|
||||
}
|
||||
//消息队列部分
|
||||
aiMsg.Type = "spark"
|
||||
aiMsg.Msg = apiMsg
|
||||
//ws发送消息部分
|
||||
wsMsg.Msg = aiMsg
|
||||
wsMsg.SessionID = SessionID
|
||||
wsMsg.ToID = UserID
|
||||
wsMsg.Type = "spark"
|
||||
|
||||
//发送消息
|
||||
wsMsgStr, _ := json.Marshal(wsMsg)
|
||||
Publish(channel, string(wsMsgStr), time.Second*60)
|
||||
}
|
||||
|
||||
// 生成参数
|
||||
func genSparkParams(question, appid string) SparkRequest {
|
||||
messages := []MessageContent{
|
||||
{Role: "user", Content: question},
|
||||
}
|
||||
|
||||
data := SparkRequest{
|
||||
Header: SparkRequestHeader{
|
||||
AppID: appid,
|
||||
},
|
||||
Parameter: Parameter{
|
||||
Chat: ChatParameter{
|
||||
Domain: "4.0Ultra",
|
||||
Temperature: 0.8,
|
||||
MaxTokens: 2048,
|
||||
},
|
||||
},
|
||||
Payload: SparkRequestPayload{
|
||||
Message: SparkMessage{
|
||||
Text: messages,
|
||||
},
|
||||
},
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// 创建鉴权url apikey 即 hmac username
|
||||
func assembleAuthUrl(hostUrl string, apiKey, apiSecret string) string {
|
||||
ul, err := url.Parse(hostUrl)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
//签名时间
|
||||
date := time.Now().UTC().Format(time.RFC1123)
|
||||
//date = "Tue, 28 May 2019 09:10:42 MST"
|
||||
//参与签名的字段 host ,date, request-line
|
||||
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
|
||||
//拼接签名字符串
|
||||
sgin := strings.Join(signString, "\n")
|
||||
// fmt.Println(sgin)
|
||||
//签名结果
|
||||
sha := HmacWithShaTobase64("hmac-sha256", sgin, apiSecret)
|
||||
// fmt.Println(sha)
|
||||
//构建请求参数 此时不需要urlencoding
|
||||
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
|
||||
"hmac-sha256", "host date request-line", sha)
|
||||
//将请求参数使用base64编码
|
||||
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
|
||||
|
||||
v := url.Values{}
|
||||
v.Add("host", ul.Host)
|
||||
v.Add("date", date)
|
||||
v.Add("authorization", authorization)
|
||||
//将编码后的字符串url encode后添加到url后面
|
||||
callurl := hostUrl + "?" + v.Encode()
|
||||
return callurl
|
||||
}
|
||||
|
||||
func HmacWithShaTobase64(algorithm, data, key string) string {
|
||||
mac := hmac.New(sha256.New, []byte(key))
|
||||
mac.Write([]byte(data))
|
||||
encodeData := mac.Sum(nil)
|
||||
return base64.StdEncoding.EncodeToString(encodeData)
|
||||
}
|
||||
|
||||
func readResp(resp *http.Response) string {
|
||||
if resp == nil {
|
||||
return ""
|
||||
}
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
|
||||
}
|
||||
Loading…
Reference in New Issue