2025-03-21 14:44:40 +08:00
package dao
import (
"StuAcaWorksAI/proto"
2025-03-25 15:12:24 +08:00
"encoding/json"
2025-03-21 14:44:40 +08:00
"gorm.io/gorm"
2025-03-25 15:12:24 +08:00
"log"
2025-03-31 15:36:34 +08:00
"time"
2025-03-21 14:44:40 +08:00
)
type Session struct {
gorm . Model
2025-03-25 15:19:15 +08:00
UserID int ` gorm:"column:user_id" ` //只能由用户创建
Name string ` gorm:"column:name" `
2025-04-03 14:14:02 +08:00
Type int ` gorm:"column:type" ` //会话类型:1为通用聊天会话, 2为提示词会话, 3为制作ppt会话
2025-03-25 15:19:15 +08:00
Context json . RawMessage ` gorm:"column:context;type:json" ` //会话上下文
2025-03-21 14:44:40 +08:00
}
type Message struct {
gorm . Model
2025-03-28 11:23:45 +08:00
Type int ` gorm:"column:type" ` //1用户之间的消息, 2为与模型消息
SessionID int ` gorm:"column:session_id;index" ` //添加索引
FromID int ` gorm:"column:from_id" ` //发送者,可以为用户或者模型
ToID int ` gorm:"column:to_id" ` //接收者,可以为用户或者模型,如果为模型, 则为模型id, 根据type判断
Msg string ` gorm:"column:msg" ` //消息内容
FunctionID int ` gorm:"column:function_id" ` //功能id
2025-04-03 16:13:22 +08:00
Status int ` gorm:"column:status" ` //0为未读, 1为已读, 3为图片消息, 4为文件消息,5为制作ppt消息
2025-03-21 14:44:40 +08:00
}
2025-04-03 16:13:22 +08:00
func CreateSession ( userID , Type int , name string ) ( error , uint ) {
session := Session { UserID : userID , Name : name , Type : Type }
2025-03-21 14:44:40 +08:00
var res * gorm . DB
if proto . Config . SERVER_SQL_LOG {
res = DB . Debug ( ) . Create ( & session )
} else {
res = DB . Create ( & session )
}
return res . Error , session . ID
}
// FindSessionByID loads the session whose primary key is id.
//
// A zero-value Session (ID == 0) is returned when no matching, non-deleted
// row exists or when the query fails. Query failures are logged so that a
// database outage is not silently indistinguishable from "not found".
func FindSessionByID(id int) Session {
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = DB.Debug()
	}
	var session Session
	// Limit(1).Find does not report gorm.ErrRecordNotFound for a missing row,
	// so any error seen here is a genuine query failure worth logging.
	if err := db.Where("id = ?", id).Limit(1).Find(&session).Error; err != nil {
		log.Println("find session by id error:", err)
		return Session{}
	}
	return session
}
// 获取用户最新的创建时间会话列表 50个
2025-04-03 19:16:14 +08:00
func FindSessionByUserID ( userID , sessionType int ) [ ] Session {
2025-03-21 14:44:40 +08:00
var sessions [ ] Session
if proto . Config . SERVER_SQL_LOG {
2025-04-03 19:16:14 +08:00
DB . Debug ( ) . Where ( "user_id = ? and type = ?" , userID , sessionType ) . Order ( "created_at DESC" ) . Limit ( 50 ) . Find ( & sessions )
2025-03-21 14:44:40 +08:00
} else {
2025-04-03 19:16:14 +08:00
DB . Where ( "user_id = ? and type = ?" , userID , sessionType ) . Order ( "created_at DESC" ) . Limit ( 50 ) . Find ( & sessions )
2025-03-21 14:44:40 +08:00
}
return sessions
}
// 更新会话的名字
2025-03-23 13:17:03 +08:00
func UpdateSessionByID ( id int , userId int , name string , context [ ] int ) error {
2025-03-21 14:44:40 +08:00
var res * gorm . DB
2025-03-25 15:12:24 +08:00
contextJson , err := json . Marshal ( context )
if err != nil {
log . Println ( "update session json marshal error:" , err )
}
2025-03-21 14:44:40 +08:00
if proto . Config . SERVER_SQL_LOG {
2025-03-25 15:12:24 +08:00
res = DB . Debug ( ) . Model ( & Session { } ) . Where ( "id = ? AND user_id = ?" , id , userId ) . Updates ( Session { Name : name } ) . Update ( "context" , contextJson )
2025-03-21 14:44:40 +08:00
} else {
2025-03-25 15:12:24 +08:00
res = DB . Model ( & Session { } ) . Where ( "id = ? AND user_id = ?" , id , userId ) . Updates ( Session { Name : name , Context : nil } ) . Update ( "context" , contextJson )
2025-03-21 14:44:40 +08:00
}
return res . Error
}
// 删除会话
2025-03-21 16:20:24 +08:00
func DeleteSessionByID ( id int ) error {
2025-03-21 14:44:40 +08:00
var res * gorm . DB
if proto . Config . SERVER_SQL_LOG {
res = DB . Debug ( ) . Delete ( & Session { } , id )
} else {
res = DB . Delete ( & Session { } , id )
}
return res . Error
}
// 创建消息
2025-03-28 11:23:45 +08:00
func CreateMessage ( sessionID , fromID , toID , msgType , status int , msg string , function_id int ) ( error , uint ) {
message := Message { SessionID : sessionID , FromID : fromID , ToID : toID , Type : msgType , Status : status , Msg : msg , FunctionID : function_id }
2025-03-21 14:44:40 +08:00
var res * gorm . DB
if proto . Config . SERVER_SQL_LOG {
res = DB . Debug ( ) . Create ( & message )
} else {
res = DB . Create ( & message )
}
return res . Error , message . ID
}
// FindMessageBySessionID returns every message in session sessionID, newest
// first (ordered by created_at descending). Query errors are not reported to
// the caller.
func FindMessageBySessionID(sessionID int) []Message {
	tx := DB
	if proto.Config.SERVER_SQL_LOG {
		tx = tx.Debug()
	}

	var messages []Message
	tx.Where("session_id = ?", sessionID).
		Order("created_at DESC").
		Find(&messages)
	return messages
}
2025-03-25 13:26:51 +08:00
2025-03-26 13:51:38 +08:00
// FindMessageBySessionIDV2 returns every message in session sessionID ordered
// by ascending id (the gorm.Model primary key; presumably auto-increment, so
// effectively oldest first). Query errors are not reported to the caller.
func FindMessageBySessionIDV2(sessionID int) []Message {
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = DB.Debug()
	}
	var messages []Message
	// Without an ORDER BY the database may return rows in any order; an
	// explicit ordering makes the message sequence deterministic.
	db.Where("session_id = ?", sessionID).Order("id ASC").Find(&messages)
	return messages
}
2025-03-25 13:26:51 +08:00
// SessionMessageCount is one row of a per-session message tally: a session ID
// and the number of messages counted for it.
type SessionMessageCount struct {
	SessionID int `gorm:"column:session_id"` // session the messages belong to
	Count     int `gorm:"column:count"`      // number of messages counted
}
// FindSessionMessageCountByUserID groups messages by session_id and counts
// those in which userID appears as recipient (to_id) or sender (from_id).
// Query errors are not reported to the caller.
//
// NOTE(review): per the Message field comments, from_id/to_id may also hold
// model IDs, so messages to/from a model whose ID equals userID are counted
// too. Confirm whether filtering by Type or joining on sessions.user_id is
// intended. The query goes through Table("messages") rather than
// Model(&Message{}), so GORM's soft-delete filter is presumably not applied.
// Verify this if messages are ever soft-deleted.
func FindSessionMessageCountByUserID(userID int) []SessionMessageCount {
	conn := DB
	if proto.Config.SERVER_SQL_LOG {
		conn = conn.Debug()
	}

	var counts []SessionMessageCount
	conn.Table("messages").
		Select("session_id, count(*) as count").
		Where("to_id = ? or from_id= ?", userID, userID).
		Group("session_id").
		Scan(&counts)
	return counts
}
2025-03-31 15:36:34 +08:00
// 获取会话数,问答数,今日问答数
func FindBaseSessionMessageStatisticsInfo ( ) ( int64 , int64 , int64 , error ) {
//所有会话数
var sessionCount int64
var messageCount int64
var todayMessageCount int64
today := time . Now ( ) . Truncate ( 24 * time . Hour )
var db2 * gorm . DB
if proto . Config . SERVER_SQL_LOG {
db2 = DB . Debug ( )
} else {
db2 = DB
}
// 查询会话数量
if err := db2 . Model ( & Session { } ) . Count ( & sessionCount ) . Error ; err != nil {
return 0 , 0 , 0 , err
}
// 查询消息数量
if err := db2 . Model ( & Message { } ) . Count ( & messageCount ) . Error ; err != nil {
return 0 , 0 , 0 , err
}
// 查询今天的消息数量
if err := db2 . Model ( & Message { } ) . Where ( "created_at > ?" , today ) . Count ( & todayMessageCount ) . Error ; err != nil {
return 0 , 0 , 0 , err
}
return sessionCount , messageCount , todayMessageCount , nil
}
2025-04-05 19:57:50 +08:00
func FindUserSessionCount ( userID , sessionType int ) int64 {
var sessionCount int64
var db2 * gorm . DB
if proto . Config . SERVER_SQL_LOG {
db2 = DB . Debug ( )
} else {
db2 = DB
}
if err := db2 . Model ( & Session { } ) . Where ( "user_id = ? and type = ?" , userID , sessionType ) . Count ( & sessionCount ) . Error ; err != nil {
log . Println ( "find user session count error:" , err )
}
log . Println ( "user session count:" , sessionCount )
return sessionCount
}
2025-04-06 14:59:44 +08:00
2025-04-06 15:14:29 +08:00
// SessionOrMessageCount is one row of a per-day tally: a calendar date and the
// number of sessions or messages created on it. In this file it is filled from
// "DATE(created_at) as date, COUNT(*) as count" queries.
type SessionOrMessageCount struct {
	Date  string `gorm:"column:date"`  // the day, from DATE(created_at)
	Count int64  `gorm:"column:count"` // number of rows created that day
}
2025-04-06 14:59:44 +08:00
// 获取最近一周每天的会话数
2025-04-06 15:14:29 +08:00
func FindSessionCountByDate ( ) ( [ ] SessionOrMessageCount , error ) {
var sessionCounts [ ] SessionOrMessageCount
2025-04-06 14:59:44 +08:00
var db2 * gorm . DB
if proto . Config . SERVER_SQL_LOG {
db2 = DB . Debug ( )
} else {
db2 = DB
}
if err := db2 . Model ( & Session { } ) . Select ( "DATE(created_at) as date, COUNT(*) as count" ) . Where ( "created_at >= DATE_SUB(CURDATE(), INTERVAL 7 DAY)" ) . Group ( "date" ) . Order ( "date DESC" ) . Scan ( & sessionCounts ) . Error ; err != nil {
return nil , err
}
return sessionCounts , nil
}
// 获取最近一周每天的消息数
2025-04-06 15:14:29 +08:00
func FindMessageCountByDate ( ) ( [ ] SessionOrMessageCount , error ) {
var messageCounts [ ] SessionOrMessageCount
2025-04-06 14:59:44 +08:00
var db2 * gorm . DB
if proto . Config . SERVER_SQL_LOG {
db2 = DB . Debug ( )
} else {
db2 = DB
}
if err := db2 . Model ( & Message { } ) . Select ( "DATE(created_at) as date, COUNT(*) as count" ) . Where ( "created_at >= DATE_SUB(CURDATE(), INTERVAL 7 DAY)" ) . Group ( "date" ) . Order ( "date DESC" ) . Scan ( & messageCounts ) . Error ; err != nil {
return nil , err
}
return messageCounts , nil
}