saw-go/dao/im.go

229 lines
7.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package dao
import (
"StuAcaWorksAI/proto"
"encoding/json"
"gorm.io/gorm"
"log"
"time"
)
// Session is a conversation persisted through GORM; the embedded gorm.Model
// supplies the ID, CreatedAt, UpdatedAt and DeletedAt columns.
type Session struct {
gorm.Model
UserID int `gorm:"column:user_id"` // owning user; sessions can only be created by users
Name string `gorm:"column:name"`
Type int `gorm:"column:type"` // session type: 1 = general chat, 2 = prompt session, 3 = PPT-generation session
Context json.RawMessage `gorm:"column:context;type:json"` // session context, stored in a JSON column
}
// Message is a single chat message that belongs to a Session (via SessionID).
type Message struct {
gorm.Model
Type int `gorm:"column:type"` // 1 = message between users, 2 = message exchanged with a model
SessionID int `gorm:"column:session_id;index"` // indexed because messages are queried by session_id
FromID int `gorm:"column:from_id"` // sender: either a user ID or a model ID
ToID int `gorm:"column:to_id"` // receiver: either a user ID or a model ID; when it is a model this holds the model ID, as decided by Type
Msg string `gorm:"column:msg"` // message content
FunctionID int `gorm:"column:function_id"` // function ID (presumably the feature that produced the message — confirm)
Status int `gorm:"column:status"` // 0 = unread, 1 = read, 3 = image message, 4 = file message, 5 = PPT-generation message
// NOTE(review): Status mixes read state (0/1) with content kinds (3-5), and value 2 is not listed — confirm the intended encoding.
}
func CreateSession(userID, Type int, name string) (error, uint) {
session := Session{UserID: userID, Name: name, Type: Type}
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Create(&session)
} else {
res = DB.Create(&session)
}
return res.Error, session.ID
}
// FindSessionByID loads the session whose primary key is id.
//
// Query errors, including "record not found", are not reported. The caller
// gets whatever First left in the result, which is presumably the zero Session
// when no row matched. Checking ID == 0 is the likely way to detect absence;
// confirm this against the callers.
func FindSessionByID(id int) Session {
	query := DB
	if proto.Config.SERVER_SQL_LOG {
		query = DB.Debug()
	}

	var found Session
	query.Where("id = ?", id).First(&found)
	return found
}
// FindSessionByUserID returns at most 50 sessions of type sessionType owned by
// userID, ordered newest-created first. Query errors are ignored.
func FindSessionByUserID(userID, sessionType int) []Session {
	const limit = 50

	tx := DB
	if proto.Config.SERVER_SQL_LOG {
		tx = DB.Debug()
	}

	var result []Session
	tx.Where("user_id = ? and type = ?", userID, sessionType).
		Order("created_at DESC").
		Limit(limit).
		Find(&result)
	return result
}
// FindSessionByUserIDV2 returns at most 100 sessions of any type owned by
// userID, ordered newest-created first. Query errors are ignored.
func FindSessionByUserIDV2(userID int) []Session {
	const limit = 100

	tx := DB
	if proto.Config.SERVER_SQL_LOG {
		tx = DB.Debug()
	}

	var result []Session
	tx.Where("user_id = ?", userID).
		Order("created_at DESC").
		Limit(limit).
		Find(&result)
	return result
}
// 更新会话的名字
func UpdateSessionByID(id int, userId int, name string, context []int) error {
var res *gorm.DB
contextJson, err := json.Marshal(context)
if err != nil {
log.Println("update session json marshal error:", err)
}
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name}).Update("context", contextJson)
} else {
res = DB.Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: nil}).Update("context", contextJson)
}
return res.Error
}
// 删除会话
func DeleteSessionByID(id int) error {
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Delete(&Session{}, id)
} else {
res = DB.Delete(&Session{}, id)
}
return res.Error
}
// 创建消息
func CreateMessage(sessionID, fromID, toID, msgType, status int, msg string, function_id int) (error, uint) {
message := Message{SessionID: sessionID, FromID: fromID, ToID: toID, Type: msgType, Status: status, Msg: msg, FunctionID: function_id}
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Create(&message)
} else {
res = DB.Create(&message)
}
return res.Error, message.ID
}
// FindMessageBySessionID returns all messages of session sessionID, ordered
// newest-created first. Query errors are ignored.
func FindMessageBySessionID(sessionID int) []Message {
	tx := DB
	if proto.Config.SERVER_SQL_LOG {
		tx = DB.Debug()
	}

	var result []Message
	tx.Where("session_id = ?", sessionID).
		Order("created_at DESC").
		Find(&result)
	return result
}
// FindMessageBySessionIDV2 returns all messages of session sessionID in
// ascending ID order, which is oldest first for auto-increment IDs. Query
// errors are ignored.
//
// Without an ORDER BY, SQL leaves the row order unspecified. Any ordering the
// unordered query appeared to have was incidental, so the order is now pinned
// to keep the conversation history deterministic.
func FindMessageBySessionIDV2(sessionID int) []Message {
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = DB.Debug()
	}

	var messages []Message
	db.Where("session_id = ?", sessionID).Order("id ASC").Find(&messages)
	return messages
}
// SessionMessageCount is a scan target for a per-session aggregate: one row
// per session_id with the number of matching messages in Count.
type SessionMessageCount struct {
SessionID int `gorm:"column:session_id"`
Count int `gorm:"column:count"`
}
// FindSessionMessageCountByUserID counts the messages sent to or from userID
// and groups them by session. It returns one SessionMessageCount per session
// that has at least one such message. Query errors are ignored.
//
// NOTE(review): from_id/to_id may hold model IDs as well as user IDs. Messages
// involving a model whose ID equals userID are therefore also counted. Confirm
// whether filtering by the session owner was intended.
func FindSessionMessageCountByUserID(userID int) []SessionMessageCount {
	q := DB
	if proto.Config.SERVER_SQL_LOG {
		q = DB.Debug()
	}

	var counts []SessionMessageCount
	q.Table("messages").
		Select("session_id, count(*) as count").
		Where("to_id = ? or from_id= ?", userID, userID).
		Group("session_id").
		Scan(&counts)
	return counts
}
// 获取会话数,问答数,今日问答数
func FindBaseSessionMessageStatisticsInfo() (int64, int64, int64, error) {
//所有会话数
var sessionCount int64
var messageCount int64
var todayMessageCount int64
today := time.Now().Truncate(24 * time.Hour)
var db2 *gorm.DB
if proto.Config.SERVER_SQL_LOG {
db2 = DB.Debug()
} else {
db2 = DB
}
// 查询会话数量
if err := db2.Model(&Session{}).Count(&sessionCount).Error; err != nil {
return 0, 0, 0, err
}
// 查询消息数量
if err := db2.Model(&Message{}).Count(&messageCount).Error; err != nil {
return 0, 0, 0, err
}
// 查询今天的消息数量
if err := db2.Model(&Message{}).Where("created_at > ?", today).Count(&todayMessageCount).Error; err != nil {
return 0, 0, 0, err
}
return sessionCount, messageCount, todayMessageCount, nil
}
func FindUserSessionCount(userID, sessionType int) int64 {
var sessionCount int64
var db2 *gorm.DB
if proto.Config.SERVER_SQL_LOG {
db2 = DB.Debug()
} else {
db2 = DB
}
if err := db2.Model(&Session{}).Where("user_id = ? and type = ?", userID, sessionType).Count(&sessionCount).Error; err != nil {
log.Println("find user session count error:", err)
}
log.Println("user session count:", sessionCount)
return sessionCount
}
// SessionOrMessageCount is a scan target for per-day counts. Date holds the
// DATE(created_at) value (its exact string format depends on the SQL driver;
// confirm), and Count holds the number of rows created on that day.
type SessionOrMessageCount struct {
Date string `gorm:"column:date"`
Count int64 `gorm:"column:count"`
}
// FindSessionCountByDate returns the number of sessions created per day over
// the recent week, newest day first. Days without sessions produce no row.
// The SQL uses MySQL-specific functions (DATE_SUB, CURDATE).
//
// NOTE(review): "created_at >= CURDATE() - 7 days" covers today plus the 7
// previous days, so up to 8 dates can be returned. Confirm whether INTERVAL
// 6 DAY was intended.
func FindSessionCountByDate() ([]SessionOrMessageCount, error) {
	conn := DB
	if proto.Config.SERVER_SQL_LOG {
		conn = DB.Debug()
	}

	var rows []SessionOrMessageCount
	err := conn.Model(&Session{}).
		Select("DATE(created_at) as date, COUNT(*) as count").
		Where("created_at >= DATE_SUB(CURDATE(), INTERVAL 7 DAY)").
		Group("date").
		Order("date DESC").
		Scan(&rows).Error
	if err != nil {
		return nil, err
	}
	return rows, nil
}
// FindMessageCountByDate returns the number of messages created per day over
// the recent week, newest day first. Days without messages produce no row.
// The SQL uses MySQL-specific functions (DATE_SUB, CURDATE).
//
// NOTE(review): "created_at >= CURDATE() - 7 days" covers today plus the 7
// previous days, so up to 8 dates can be returned. Confirm whether INTERVAL
// 6 DAY was intended.
func FindMessageCountByDate() ([]SessionOrMessageCount, error) {
	conn := DB
	if proto.Config.SERVER_SQL_LOG {
		conn = DB.Debug()
	}

	var rows []SessionOrMessageCount
	err := conn.Model(&Message{}).
		Select("DATE(created_at) as date, COUNT(*) as count").
		Where("created_at >= DATE_SUB(CURDATE(), INTERVAL 7 DAY)").
		Group("date").
		Order("date DESC").
		Scan(&rows).Error
	if err != nil {
		return nil, err
	}
	return rows, nil
}