2025-03-21 14:44:40 +08:00
package dao
import (
"StuAcaWorksAI/proto"
2025-03-25 15:12:24 +08:00
"encoding/json"
2025-03-21 14:44:40 +08:00
"gorm.io/gorm"
2025-03-25 15:12:24 +08:00
"log"
2025-03-21 14:44:40 +08:00
)
type Session struct {
gorm . Model
2025-03-25 15:19:15 +08:00
UserID int ` gorm:"column:user_id" ` //只能由用户创建
Name string ` gorm:"column:name" `
Context json . RawMessage ` gorm:"column:context;type:json" ` //会话上下文
2025-03-21 14:44:40 +08:00
}
type Message struct {
gorm . Model
2025-03-26 13:51:38 +08:00
Type int ` gorm:"column:type" ` //1用户之间的消息, 2为与模型消息
SessionID int ` gorm:"column:session_id;index" ` //添加索引
FromID int ` gorm:"column:from_id" ` //发送者,可以为用户或者模型
ToID int ` gorm:"column:to_id" ` //接收者,可以为用户或者模型,如果为模型, 则为模型id, 根据type判断
Msg string ` gorm:"column:msg" ` //消息内容
Status int ` gorm:"column:status" ` //0为未读, 1为已读
2025-03-21 14:44:40 +08:00
}
func CreateSession ( userID int , name string ) ( error , uint ) {
session := Session { UserID : userID , Name : name }
var res * gorm . DB
if proto . Config . SERVER_SQL_LOG {
res = DB . Debug ( ) . Create ( & session )
} else {
res = DB . Create ( & session )
}
return res . Error , session . ID
}
// FindSessionByID loads the session whose primary key is id.
// Errors, including "record not found", are not reported. When no row is
// loaded the zero-value Session is returned, so callers can test ID == 0.
func FindSessionByID(id int) Session {
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = DB.Debug()
	}

	var found Session
	db.Where("id = ?", id).First(&found)
	return found
}
// FindSessionByUserID returns up to 50 sessions owned by userID, most
// recently created first (created_at descending). Query errors are not
// reported; callers just receive whatever rows were loaded.
func FindSessionByUserID(userID int) []Session {
	const maxSessions = 50

	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = DB.Debug()
	}

	var sessions []Session
	db.Where("user_id = ?", userID).
		Order("created_at DESC").
		Limit(maxSessions).
		Find(&sessions)
	return sessions
}
// 更新会话的名字
2025-03-23 13:17:03 +08:00
func UpdateSessionByID ( id int , userId int , name string , context [ ] int ) error {
2025-03-21 14:44:40 +08:00
var res * gorm . DB
2025-03-25 15:12:24 +08:00
contextJson , err := json . Marshal ( context )
if err != nil {
log . Println ( "update session json marshal error:" , err )
}
2025-03-21 14:44:40 +08:00
if proto . Config . SERVER_SQL_LOG {
2025-03-25 15:12:24 +08:00
res = DB . Debug ( ) . Model ( & Session { } ) . Where ( "id = ? AND user_id = ?" , id , userId ) . Updates ( Session { Name : name } ) . Update ( "context" , contextJson )
2025-03-21 14:44:40 +08:00
} else {
2025-03-25 15:12:24 +08:00
res = DB . Model ( & Session { } ) . Where ( "id = ? AND user_id = ?" , id , userId ) . Updates ( Session { Name : name , Context : nil } ) . Update ( "context" , contextJson )
2025-03-21 14:44:40 +08:00
}
return res . Error
}
// 删除会话
2025-03-21 16:20:24 +08:00
func DeleteSessionByID ( id int ) error {
2025-03-21 14:44:40 +08:00
var res * gorm . DB
if proto . Config . SERVER_SQL_LOG {
res = DB . Debug ( ) . Delete ( & Session { } , id )
} else {
res = DB . Delete ( & Session { } , id )
}
return res . Error
}
// 创建消息
func CreateMessage ( sessionID , fromID , toID , msgType , status int , msg string ) ( error , uint ) {
message := Message { SessionID : sessionID , FromID : fromID , ToID : toID , Type : msgType , Status : status , Msg : msg }
var res * gorm . DB
if proto . Config . SERVER_SQL_LOG {
res = DB . Debug ( ) . Create ( & message )
} else {
res = DB . Create ( & message )
}
return res . Error , message . ID
}
// FindMessageBySessionID returns every message in session sessionID, newest
// first (created_at descending). Query errors are not reported.
func FindMessageBySessionID(sessionID int) []Message {
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = DB.Debug()
	}

	var msgs []Message
	db.Where("session_id = ?", sessionID).Order("created_at DESC").Find(&msgs)
	return msgs
}
2025-03-25 13:26:51 +08:00
2025-03-26 13:51:38 +08:00
func FindMessageBySessionIDV2 ( sessionID int ) [ ] Message {
var messages [ ] Message
if proto . Config . SERVER_SQL_LOG {
DB . Debug ( ) . Where ( "session_id = ?" , sessionID ) . Find ( & messages )
} else {
DB . Where ( "session_id = ?" , sessionID ) . Find ( & messages )
}
return messages
}
// SessionMessageCount is one row of the per-session aggregation returned by
// FindSessionMessageCountByUserID.
type SessionMessageCount struct {
	// SessionID is the session whose messages were counted.
	SessionID int `gorm:"column:session_id"`
	// Count is the number of matching messages in that session.
	Count int `gorm:"column:count"`
}
// FindSessionMessageCountByUserID counts, for each session, the messages that
// userID sent or received (from_id or to_id equal to userID).
// Query errors are not reported.
//
// The query reads the raw "messages" table rather than a model, so GORM's
// automatic soft-delete filter does not apply. deleted_at IS NULL is added
// explicitly so that soft-deleted messages are not counted.
//
// NOTE(review): for model conversations, from_id/to_id may hold a model ID.
// A model whose ID equals userID would therefore also be counted. Confirm
// whether the query should also filter by message type or session owner.
func FindSessionMessageCountByUserID(userID int) []SessionMessageCount {
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = DB.Debug()
	}

	var counts []SessionMessageCount
	db.Table("messages").
		Select("session_id, count(*) as count").
		Where("(to_id = ? OR from_id = ?) AND deleted_at IS NULL", userID, userID).
		Group("session_id").
		Scan(&counts)
	return counts
}