229 lines
7.4 KiB
Go
229 lines
7.4 KiB
Go
package dao
|
||
|
||
import (
|
||
"StuAcaWorksAI/proto"
|
||
"encoding/json"
|
||
"gorm.io/gorm"
|
||
"log"
|
||
"time"
|
||
)
|
||
|
||
type Session struct {
|
||
gorm.Model
|
||
UserID int `gorm:"column:user_id"` //只能由用户创建
|
||
Name string `gorm:"column:name"`
|
||
Type int `gorm:"column:type"` //会话类型:1为通用聊天会话,2为提示词会话,3为制作ppt会话
|
||
Context json.RawMessage `gorm:"column:context;type:json"` //会话上下文
|
||
}
|
||
|
||
type Message struct {
|
||
gorm.Model
|
||
Type int `gorm:"column:type"` //1用户之间的消息,2为与模型消息
|
||
SessionID int `gorm:"column:session_id;index"` //添加索引
|
||
FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型
|
||
ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断
|
||
Msg string `gorm:"column:msg"` //消息内容
|
||
FunctionID int `gorm:"column:function_id"` //功能id
|
||
Status int `gorm:"column:status"` //0为未读,1为已读,3为图片消息,4为文件消息,5为制作ppt消息
|
||
}
|
||
|
||
func CreateSession(userID, Type int, name string) (error, uint) {
|
||
session := Session{UserID: userID, Name: name, Type: Type}
|
||
var res *gorm.DB
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
res = DB.Debug().Create(&session)
|
||
} else {
|
||
res = DB.Create(&session)
|
||
}
|
||
return res.Error, session.ID
|
||
}
|
||
|
||
func FindSessionByID(id int) Session {
|
||
var session Session
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
DB.Debug().Where("id = ?", id).First(&session)
|
||
} else {
|
||
DB.Where("id = ?", id).First(&session)
|
||
}
|
||
return session
|
||
}
|
||
|
||
// 获取用户最新的创建时间会话列表 50个
|
||
func FindSessionByUserID(userID, sessionType int) []Session {
|
||
var sessions []Session
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
DB.Debug().Where("user_id = ? and type = ?", userID, sessionType).Order("created_at DESC").Limit(50).Find(&sessions)
|
||
} else {
|
||
DB.Where("user_id = ? and type = ?", userID, sessionType).Order("created_at DESC").Limit(50).Find(&sessions)
|
||
}
|
||
return sessions
|
||
}
|
||
|
||
// 获取用户最新的创建时间会话列表 100个
|
||
func FindSessionByUserIDV2(userID int) []Session {
|
||
var sessions []Session
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
DB.Debug().Where("user_id = ?", userID).Order("created_at DESC").Limit(100).Find(&sessions)
|
||
} else {
|
||
DB.Where("user_id = ?", userID).Order("created_at DESC").Limit(100).Find(&sessions)
|
||
}
|
||
return sessions
|
||
|
||
}
|
||
|
||
// 更新会话的名字
|
||
func UpdateSessionByID(id int, userId int, name string, context []int) error {
|
||
var res *gorm.DB
|
||
contextJson, err := json.Marshal(context)
|
||
if err != nil {
|
||
log.Println("update session json marshal error:", err)
|
||
}
|
||
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
res = DB.Debug().Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name}).Update("context", contextJson)
|
||
} else {
|
||
res = DB.Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: nil}).Update("context", contextJson)
|
||
}
|
||
return res.Error
|
||
}
|
||
|
||
// 删除会话
|
||
func DeleteSessionByID(id int) error {
|
||
var res *gorm.DB
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
res = DB.Debug().Delete(&Session{}, id)
|
||
} else {
|
||
res = DB.Delete(&Session{}, id)
|
||
}
|
||
return res.Error
|
||
}
|
||
|
||
// 创建消息
|
||
func CreateMessage(sessionID, fromID, toID, msgType, status int, msg string, function_id int) (error, uint) {
|
||
message := Message{SessionID: sessionID, FromID: fromID, ToID: toID, Type: msgType, Status: status, Msg: msg, FunctionID: function_id}
|
||
var res *gorm.DB
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
res = DB.Debug().Create(&message)
|
||
} else {
|
||
res = DB.Create(&message)
|
||
}
|
||
return res.Error, message.ID
|
||
}
|
||
|
||
// 根据会话id获取消息
|
||
func FindMessageBySessionID(sessionID int) []Message {
|
||
var messages []Message
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
DB.Debug().Where("session_id = ?", sessionID).Order("created_at DESC").Find(&messages)
|
||
} else {
|
||
DB.Where("session_id = ?", sessionID).Order("created_at DESC").Find(&messages)
|
||
}
|
||
return messages
|
||
}
|
||
|
||
func FindMessageBySessionIDV2(sessionID int) []Message {
|
||
var messages []Message
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
DB.Debug().Where("session_id = ?", sessionID).Find(&messages)
|
||
} else {
|
||
DB.Where("session_id = ?", sessionID).Find(&messages)
|
||
}
|
||
return messages
|
||
}
|
||
|
||
// SessionMessageCount is one row of a per-session message tally: the session
// ID and the number of matching messages in it.
type SessionMessageCount struct {
	SessionID int `gorm:"column:session_id"` // session the count refers to
	Count     int `gorm:"column:count"`      // number of messages counted
}
|
||
|
||
// 获取属于用户的会话消息数量
|
||
func FindSessionMessageCountByUserID(userID int) []SessionMessageCount {
|
||
var sessionMessageCounts []SessionMessageCount
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
//DB.Debug().Exec("Select session_id, count(*) as count from messages where to_id = ? or from_id=? group by session_id", userID,userID).Scan(&sessionMessageCounts)
|
||
DB.Debug().Table("messages").Select("session_id, count(*) as count").Where("to_id = ? or from_id= ?", userID, userID).Group("session_id").Scan(&sessionMessageCounts)
|
||
} else {
|
||
DB.Table("messages").Select("session_id, count(*) as count").Where("to_id = ? or from_id= ?", userID, userID).Group("session_id").Scan(&sessionMessageCounts)
|
||
}
|
||
return sessionMessageCounts
|
||
}
|
||
|
||
// 获取会话数,问答数,今日问答数
|
||
func FindBaseSessionMessageStatisticsInfo() (int64, int64, int64, error) {
|
||
//所有会话数
|
||
var sessionCount int64
|
||
var messageCount int64
|
||
var todayMessageCount int64
|
||
today := time.Now().Truncate(24 * time.Hour)
|
||
var db2 *gorm.DB
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
db2 = DB.Debug()
|
||
} else {
|
||
db2 = DB
|
||
}
|
||
|
||
// 查询会话数量
|
||
if err := db2.Model(&Session{}).Count(&sessionCount).Error; err != nil {
|
||
return 0, 0, 0, err
|
||
}
|
||
// 查询消息数量
|
||
if err := db2.Model(&Message{}).Count(&messageCount).Error; err != nil {
|
||
return 0, 0, 0, err
|
||
}
|
||
// 查询今天的消息数量
|
||
if err := db2.Model(&Message{}).Where("created_at > ?", today).Count(&todayMessageCount).Error; err != nil {
|
||
return 0, 0, 0, err
|
||
}
|
||
return sessionCount, messageCount, todayMessageCount, nil
|
||
}
|
||
|
||
func FindUserSessionCount(userID, sessionType int) int64 {
|
||
var sessionCount int64
|
||
var db2 *gorm.DB
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
db2 = DB.Debug()
|
||
} else {
|
||
db2 = DB
|
||
}
|
||
if err := db2.Model(&Session{}).Where("user_id = ? and type = ?", userID, sessionType).Count(&sessionCount).Error; err != nil {
|
||
log.Println("find user session count error:", err)
|
||
}
|
||
log.Println("user session count:", sessionCount)
|
||
return sessionCount
|
||
}
|
||
|
||
// SessionOrMessageCount is one row of a per-day tally. Date is the value of
// SQL DATE(created_at) for the group, and Count is the number of rows
// created on that day.
type SessionOrMessageCount struct {
	Date  string `gorm:"column:date"`  // calendar day of the group
	Count int64  `gorm:"column:count"` // rows created on that day
}
|
||
|
||
// 获取最近一周每天的会话数
|
||
func FindSessionCountByDate() ([]SessionOrMessageCount, error) {
|
||
var sessionCounts []SessionOrMessageCount
|
||
var db2 *gorm.DB
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
db2 = DB.Debug()
|
||
} else {
|
||
db2 = DB
|
||
}
|
||
if err := db2.Model(&Session{}).Select("DATE(created_at) as date, COUNT(*) as count").Where("created_at >= DATE_SUB(CURDATE(), INTERVAL 7 DAY)").Group("date").Order("date DESC").Scan(&sessionCounts).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return sessionCounts, nil
|
||
}
|
||
|
||
// 获取最近一周每天的消息数
|
||
func FindMessageCountByDate() ([]SessionOrMessageCount, error) {
|
||
var messageCounts []SessionOrMessageCount
|
||
var db2 *gorm.DB
|
||
if proto.Config.SERVER_SQL_LOG {
|
||
db2 = DB.Debug()
|
||
} else {
|
||
db2 = DB
|
||
}
|
||
if err := db2.Model(&Message{}).Select("DATE(created_at) as date, COUNT(*) as count").Where("created_at >= DATE_SUB(CURDATE(), INTERVAL 7 DAY)").Group("date").Order("date DESC").Scan(&messageCounts).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return messageCounts, nil
|
||
}
|