saw-go/service/imService.go

179 lines
5.1 KiB
Go
Raw Normal View History

package service
import (
"StuAcaWorksAI/dao"
"StuAcaWorksAI/proto"
"StuAcaWorksAI/worker"
"context"
"encoding/json"
"errors"
"github.com/ollama/ollama/api"
"log"
"time"
)
// CreateSession creates a session called name for the user userID.
//
// It is a thin wrapper around dao.CreateSession and returns that call's
// results unchanged: the error first, then the ID reported by the DAO.
func CreateSession(userID int, name string) (error, uint) {
	return dao.CreateSession(userID, name)
}
// FindSessionByUserID lists the sessions of the user userID, each annotated
// with its message count.
//
// Sessions are read with dao.FindSessionByUserID and the per-session counts
// with dao.FindSessionMessageCountByUserID. A session that has no entry in
// the count result gets a MsgCount of 0. The returned slice is nil when the
// DAO yields no sessions.
func FindSessionByUserID(userID int) []proto.SessionResponse {
	userSessions := dao.FindSessionByUserID(userID)
	msgCounts := dao.FindSessionMessageCountByUserID(userID)

	// Index the counts by session ID so each session needs a single lookup.
	countBySession := make(map[int]int)
	for _, c := range msgCounts {
		countBySession[c.SessionID] = c.Count
	}

	var out []proto.SessionResponse
	for _, s := range userSessions {
		var item proto.SessionResponse
		item.ID = s.ID
		item.Name = s.Name
		item.UserID = s.UserID
		item.CreatedAt = s.CreatedAt
		item.UpdatedAt = s.UpdatedAt
		// A missing key yields the zero value, i.e. a count of 0.
		item.MsgCount = countBySession[int(s.ID)]
		out = append(out, item)
	}
	return out
}
2025-03-21 16:20:24 +08:00
// FindSessionByID loads the session with the given id and returns it wrapped
// in a one-element slice, provided the session's UserID equals userID.
//
// When the session belongs to a different user it returns the error
// "session not belong to user" together with a nil slice.
func FindSessionByID(id, userID int) (error, []dao.Session) {
	found := dao.FindSessionByID(id)
	if found.UserID == userID {
		return nil, []dao.Session{found}
	}
	return errors.New("session not belong to user"), nil
}
// DeleteSessionByID deletes the session with the given id on behalf of the
// user userID.
//
// The session is looked up first; if its UserID differs from userID the
// delete is refused with the error "session not belong to user". Otherwise
// the result of dao.DeleteSessionByID is returned as is.
func DeleteSessionByID(id, userID int) error {
	if owner := dao.FindSessionByID(id).UserID; owner != userID {
		return errors.New("session not belong to user")
	}
	return dao.DeleteSessionByID(id)
}
// 更新会话名字
2025-03-23 13:17:03 +08:00
func UpdateSessionByID(id int, userId int, name string, context []int) error {
err := dao.UpdateSessionByID(id, userId, name, context)
return err
}
// CreateMessage stores a message through dao.CreateMessage and returns the
// DAO's results unchanged: the error first, then the ID it reports.
func CreateMessage(msgType int, sessionID int, fromID int, toID int, msg string, status int) (error, uint) {
	// The DAO takes the same values in a different order than this function:
	// sessionID, fromID, toID, msgType, status, msg.
	return dao.CreateMessage(sessionID, fromID, toID, msgType, status, msg)
}
// FindMessageBySessionID returns the messages of the session sessionID as
// read by dao.FindMessageBySessionID.
//
// The session is looked up first; if its UserID differs from userID the
// function returns the error "session not belong to user" and a nil slice
// without querying any messages.
func FindMessageBySessionID(sessionID, userID int) (error, []dao.Message) {
	// Check that the session belongs to the requesting user.
	if dao.FindSessionByID(sessionID).UserID != userID {
		return errors.New("session not belong to user"), nil
	}
	return nil, dao.FindMessageBySessionID(sessionID)
}
// CreateGeneralMessageService stores a message sent by fromID to toID in the
// session sessionID.
//
// The session is looked up first; if its UserID differs from fromID the
// function returns the error "session not belong to user" and an ID of 0.
// Otherwise the message is created via CreateMessage with status 0 and that
// call's results are returned.
func CreateGeneralMessageService(fromID, toID, msgType, sessionID int, msg string) (error, uint) {
	// Check that the session belongs to the sender.
	owner := dao.FindSessionByID(sessionID).UserID
	if owner != fromID {
		return errors.New("session not belong to user"), 0
	}
	return CreateMessage(msgType, sessionID, fromID, toID, msg, 0)
}
var client *api.Client
2025-03-23 13:17:03 +08:00
var SessionContextMap map[int][]int
func WSReceiveMessageService(userID, sessionID int, channel string, msg proto.WSMessageReq, model dao.Model, modelParam proto.ModelParam) (error, uint) {
var resErr error
var resID uint
//处理消息
2025-03-24 16:31:39 +08:00
if model.Type == "ollama" {
2025-03-25 14:34:50 +08:00
resErr, resID = ReceiveOllamaSession(userID, sessionID, channel, msg, modelParam, &model)
2025-03-24 16:31:39 +08:00
} else if model.Type == "spark" {
2025-03-25 15:51:05 +08:00
resErr, resID = ReceiveSparkSession(userID, sessionID, channel, msg, modelParam, &model)
2025-03-25 15:48:56 +08:00
} else if model.Type == "doubao" {
doubao(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID))
} else {
return errors.New("model not exist"), 0
2025-03-23 16:03:17 +08:00
}
return resErr, resID
}
2025-03-25 14:34:50 +08:00
func ReceiveOllamaSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, m *dao.Model) (error, uint) {
2025-03-23 16:03:17 +08:00
var resErr error
var resID uint
var err error
if client == nil {
client, err = api.ClientFromEnvironment()
}
if err != nil {
log.Println("get ollama client error:", err)
}
prompt := msg.Msg
actx, ok := SessionContextMap[sessionID]
if ok {
log.Println("actx has get will use:", len(actx))
}
req := &api.GenerateRequest{
Model: modelParam.Model,
2025-03-23 16:03:17 +08:00
Context: actx,
Prompt: prompt,
}
ctx := context.Background()
robotMsg := ""
var aiMsg proto.AIQueueMessage
var data proto.WSMessage
respFunc := func(resp api.GenerateResponse) error {
aiMsg.Type = "ollama"
aiMsg.Msg = resp
robotMsg += resp.Response
if resp.Done { //该消息完成
actx = resp.Context
if SessionContextMap == nil {
SessionContextMap = make(map[int][]int)
}
2025-03-23 16:03:17 +08:00
SessionContextMap[sessionID] = actx
log.Println("actx has save:", len(SessionContextMap[sessionID]))
//创建消息
err2, msgID := CreateMessage(2, sessionID, 2, userID, robotMsg, 1) //机器人id为2,消息类型为2,状态为1
if err2 != nil {
log.Println("create robot message error:", err2)
}
aiMsg.ID = msgID
resID = msgID
}
2025-03-23 16:03:17 +08:00
//发送消息
data.Msg = aiMsg
data.SessionID = sessionID
data.ToID = userID
aiMsgStr, _ := json.Marshal(data)
worker.Publish(channel, string(aiMsgStr), time.Second*60)
2025-03-25 14:34:50 +08:00
go CreateAIStreamMsg(userID, int(m.ID), sessionID, resp, modelParam.Model)
2025-03-23 16:03:17 +08:00
return nil
}
err = client.Generate(ctx, req, respFunc)
if err != nil {
log.Fatal(err)
}
return resErr, resID
}
2025-03-23 16:03:17 +08:00
2025-03-25 15:51:05 +08:00
func ReceiveSparkSession(userID, sessionID int, channel string, msg proto.WSMessageReq, modelParam proto.ModelParam, model *dao.Model) (error, uint) {
2025-03-24 13:36:42 +08:00
var resErr error
var resID uint
2025-03-25 15:51:05 +08:00
Spark(modelParam, msg.Msg, channel, sessionID, userID, int(model.ID))
2025-03-24 13:36:42 +08:00
return resErr, resID
2025-03-23 16:03:17 +08:00
}