160 lines
4.2 KiB
Go
160 lines
4.2 KiB
Go
package service
|
|
|
|
import (
|
|
"StuAcaWorksAI/dao"
|
|
"StuAcaWorksAI/proto"
|
|
"StuAcaWorksAI/worker"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"github.com/ollama/ollama/api"
|
|
"log"
|
|
"time"
|
|
)
|
|
|
|
// 创建会话服务
|
|
func CreateSession(userID int, name string) (error, uint) {
|
|
err, id := dao.CreateSession(userID, name)
|
|
return err, id
|
|
}
|
|
|
|
// 查找用户的会话列表
|
|
func FindSessionByUserID(userID int) []dao.Session {
|
|
sessions := dao.FindSessionByUserID(userID)
|
|
return sessions
|
|
}
|
|
func FindSessionByID(id, userID int) (error, []dao.Session) {
|
|
session := dao.FindSessionByID(id)
|
|
var res []dao.Session
|
|
if session.UserID != userID {
|
|
return errors.New("session not belong to user"), res
|
|
}
|
|
res = append(res, session)
|
|
return nil, res
|
|
}
|
|
|
|
func DeleteSessionByID(id, userID int) error {
|
|
session := dao.FindSessionByID(id)
|
|
if session.UserID != userID {
|
|
return errors.New("session not belong to user")
|
|
}
|
|
err := dao.DeleteSessionByID(id)
|
|
return err
|
|
}
|
|
|
|
// 更新会话名字
|
|
func UpdateSessionByID(id int, userId int, name string, context []int) error {
|
|
err := dao.UpdateSessionByID(id, userId, name, context)
|
|
return err
|
|
}
|
|
|
|
// 创建消息服务
|
|
func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int) (error, uint) {
|
|
//sessionID, fromID, toID, msgType, status int, msg string
|
|
err, id := dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg)
|
|
return err, id
|
|
}
|
|
|
|
// 通过会话id查询消息
|
|
func FindMessageBySessionID(sessionID, userID int) (error, []dao.Message) {
|
|
//查看会话是否属于该用户
|
|
session := dao.FindSessionByID(sessionID)
|
|
if session.UserID != userID {
|
|
return errors.New("session not belong to user"), nil
|
|
}
|
|
msgs := dao.FindMessageBySessionID(sessionID)
|
|
return nil, msgs
|
|
}
|
|
|
|
func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg string) (error, uint) {
|
|
//查看会话是否属于该用户
|
|
session := dao.FindSessionByID(sessionID)
|
|
if session.UserID != fromID {
|
|
return errors.New("session not belong to user"), 0
|
|
}
|
|
err, mid := CreateMessage(msgType, sessionID, fromID, toID, msg, 0)
|
|
return err, mid
|
|
}
|
|
|
|
var client *api.Client
|
|
var SessionContextMap map[int][]int
|
|
|
|
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
|
|
var resErr error
|
|
var resID uint
|
|
//处理消息
|
|
if msg.Type == "ollama" {
|
|
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg)
|
|
} else if msg.Type == "spark" {
|
|
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg)
|
|
|
|
}
|
|
return resErr, resID
|
|
}
|
|
|
|
func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
|
|
var resErr error
|
|
var resID uint
|
|
var err error
|
|
if client == nil {
|
|
client, err = api.ClientFromEnvironment()
|
|
}
|
|
if err != nil {
|
|
log.Println("get ollama client error:", err)
|
|
}
|
|
prompt := msg.Msg
|
|
actx, ok := SessionContextMap[sessionID]
|
|
if ok {
|
|
log.Println("actx has get will use:", len(actx))
|
|
}
|
|
req := &api.GenerateRequest{
|
|
Model: "qwen2.5:0.5b",
|
|
Context: actx,
|
|
Prompt: prompt,
|
|
}
|
|
ctx := context.Background()
|
|
robotMsg := ""
|
|
var aiMsg proto.AIQueueMessage
|
|
var data proto.WSMessage
|
|
respFunc := func(resp api.GenerateResponse) error {
|
|
aiMsg.Type = "ollama"
|
|
aiMsg.Msg = resp
|
|
robotMsg += resp.Response
|
|
if resp.Done { //该消息完成
|
|
actx = resp.Context
|
|
if SessionContextMap == nil {
|
|
SessionContextMap = make(map[int][]int)
|
|
}
|
|
SessionContextMap[sessionID] = actx
|
|
log.Println("actx has save:", len(SessionContextMap[sessionID]))
|
|
//创建消息
|
|
err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1
|
|
if err2 != nil {
|
|
log.Println("create robot message error:", err2)
|
|
}
|
|
aiMsg.ID = msgID
|
|
resID = msgID
|
|
}
|
|
//发送消息
|
|
data.Msg = aiMsg
|
|
data.SessionID = sessionID
|
|
data.ToID = userID
|
|
aiMsgStr, _ := json.Marshal(data)
|
|
worker.Publish(channel, string(aiMsgStr), time.Second*60)
|
|
return nil
|
|
}
|
|
err = client.Generate(ctx, req, respFunc)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
return resErr, resID
|
|
}
|
|
|
|
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq) (error, uint) {
|
|
var resErr error
|
|
var resID uint
|
|
//查找模型
|
|
|
|
return resErr, resID
|
|
}
|