package dao import ( "StuAcaWorksAI/proto" "encoding/json" "gorm.io/gorm" "log" "time" ) type Session struct { gorm.Model UserID int `gorm:"column:user_id"` //只能由用户创建 Name string `gorm:"column:name"` Type int `gorm:"column:type"` //会话类型:1为通用聊天会话,2为提示词会话,3为制作ppt会话 Context json.RawMessage `gorm:"column:context;type:json"` //会话上下文 } type Message struct { gorm.Model Type int `gorm:"column:type"` //1用户之间的消息,2为与模型消息 SessionID int `gorm:"column:session_id;index"` //添加索引 FromID int `gorm:"column:from_id"` //发送者,可以为用户或者模型 ToID int `gorm:"column:to_id"` //接收者,可以为用户或者模型,如果为模型,则为模型id,根据type判断 Msg string `gorm:"column:msg"` //消息内容 FunctionID int `gorm:"column:function_id"` //功能id Status int `gorm:"column:status"` //0为未读,1为已读,3为图片消息,4为文件消息,5为制作ppt消息 } func CreateSession(userID, Type int, name string) (error, uint) { session := Session{UserID: userID, Name: name, Type: Type} var res *gorm.DB if proto.Config.SERVER_SQL_LOG { res = DB.Debug().Create(&session) } else { res = DB.Create(&session) } return res.Error, session.ID } func FindSessionByID(id int) Session { var session Session if proto.Config.SERVER_SQL_LOG { DB.Debug().Where("id = ?", id).First(&session) } else { DB.Where("id = ?", id).First(&session) } return session } // 获取用户最新的创建时间会话列表 50个 func FindSessionByUserID(userID, sessionType int) []Session { var sessions []Session if proto.Config.SERVER_SQL_LOG { DB.Debug().Where("user_id = ? and type = ?", userID, sessionType).Order("created_at DESC").Limit(50).Find(&sessions) } else { DB.Where("user_id = ? and type = ?", userID, sessionType).Order("created_at DESC").Limit(50).Find(&sessions) } return sessions } // 更新会话的名字 func UpdateSessionByID(id int, userId int, name string, context []int) error { var res *gorm.DB contextJson, err := json.Marshal(context) if err != nil { log.Println("update session json marshal error:", err) } if proto.Config.SERVER_SQL_LOG { res = DB.Debug().Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name}).Update("context", contextJson) } else { res = DB.Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: nil}).Update("context", contextJson) } return res.Error } // 删除会话 func DeleteSessionByID(id int) error { var res *gorm.DB if proto.Config.SERVER_SQL_LOG { res = DB.Debug().Delete(&Session{}, id) } else { res = DB.Delete(&Session{}, id) } return res.Error } // 创建消息 func CreateMessage(sessionID, fromID, toID, msgType, status int, msg string, function_id int) (error, uint) { message := Message{SessionID: sessionID, FromID: fromID, ToID: toID, Type: msgType, Status: status, Msg: msg, FunctionID: function_id} var res *gorm.DB if proto.Config.SERVER_SQL_LOG { res = DB.Debug().Create(&message) } else { res = DB.Create(&message) } return res.Error, message.ID } // 根据会话id获取消息 func FindMessageBySessionID(sessionID int) []Message { var messages []Message if proto.Config.SERVER_SQL_LOG { DB.Debug().Where("session_id = ?", sessionID).Order("created_at DESC").Find(&messages) } else { DB.Where("session_id = ?", sessionID).Order("created_at DESC").Find(&messages) } return messages } func FindMessageBySessionIDV2(sessionID int) []Message { var messages []Message if proto.Config.SERVER_SQL_LOG { DB.Debug().Where("session_id = ?", sessionID).Find(&messages) } else { DB.Where("session_id = ?", sessionID).Find(&messages) } return messages } type SessionMessageCount struct { SessionID int `gorm:"column:session_id"` Count int `gorm:"column:count"` } // 获取属于用户的会话消息数量 func FindSessionMessageCountByUserID(userID int) []SessionMessageCount { var sessionMessageCounts []SessionMessageCount if proto.Config.SERVER_SQL_LOG { //DB.Debug().Exec("Select session_id, count(*) as count from messages where to_id = ? or from_id=? group by session_id", userID,userID).Scan(&sessionMessageCounts) DB.Debug().Table("messages").Select("session_id, count(*) as count").Where("to_id = ? or from_id= ?", userID, userID).Group("session_id").Scan(&sessionMessageCounts) } else { DB.Table("messages").Select("session_id, count(*) as count").Where("to_id = ? or from_id= ?", userID, userID).Group("session_id").Scan(&sessionMessageCounts) } return sessionMessageCounts } // 获取会话数,问答数,今日问答数 func FindBaseSessionMessageStatisticsInfo() (int64, int64, int64, error) { //所有会话数 var sessionCount int64 var messageCount int64 var todayMessageCount int64 today := time.Now().Truncate(24 * time.Hour) var db2 *gorm.DB if proto.Config.SERVER_SQL_LOG { db2 = DB.Debug() } else { db2 = DB } // 查询会话数量 if err := db2.Model(&Session{}).Count(&sessionCount).Error; err != nil { return 0, 0, 0, err } // 查询消息数量 if err := db2.Model(&Message{}).Count(&messageCount).Error; err != nil { return 0, 0, 0, err } // 查询今天的消息数量 if err := db2.Model(&Message{}).Where("created_at > ?", today).Count(&todayMessageCount).Error; err != nil { return 0, 0, 0, err } return sessionCount, messageCount, todayMessageCount, nil }