saw-go/dao/im.go

125 lines
4.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package dao
import (
"StuAcaWorksAI/proto"
"encoding/json"
"gorm.io/gorm"
"log"
)
// Session is a chat session. The embedded gorm.Model supplies ID, CreatedAt,
// UpdatedAt and the soft-delete DeletedAt column.
type Session struct {
	gorm.Model
	UserID  int    `gorm:"column:user_id"`           // owning user; sessions can only be created by users
	Name    string `gorm:"column:name"`              // session name
	Context []int  `gorm:"column:context;type:json"` // session context, stored in a JSON column
	// NOTE(review): GORM does not convert a []int field to/from a JSON column on
	// its own unless a serializer (e.g. `serializer:json`) is configured, and
	// UpdateSessionByID marshals this field by hand. Confirm that Create and the
	// Find* paths round-trip Context with the GORM/driver versions in use.
}
// Message is a single chat message stored in the session identified by
// SessionID. The embedded gorm.Model supplies ID, CreatedAt, UpdatedAt and the
// soft-delete DeletedAt column.
type Message struct {
	gorm.Model
	Type      int    `gorm:"column:type"`       // 1 = message between users, 2 = message with a model
	SessionID int    `gorm:"column:session_id"` // session the message belongs to
	FromID    int    `gorm:"column:from_id"`    // sender; may be a user or a model
	ToID      int    `gorm:"column:to_id"`      // receiver; may be a user or a model (a model ID, as indicated by Type)
	Msg       string `gorm:"column:msg"`        // message content
	Status    int    `gorm:"column:status"`     // 0 = unread, 1 = read
}
func CreateSession(userID int, name string) (error, uint) {
session := Session{UserID: userID, Name: name}
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Create(&session)
} else {
res = DB.Create(&session)
}
return res.Error, session.ID
}
// FindSessionByID returns the session whose primary key is id.
//
// A zero-value Session (ID == 0) is returned when no row matches or when the
// query fails. The signature has no error return, so real query failures
// (anything other than "no row") are logged instead of silently dropped.
func FindSessionByID(id int) Session {
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = DB.Debug()
	}
	var session Session
	// Limit(1).Find rather than First: Find does not report a missing row as
	// gorm.ErrRecordNotFound, so any error here is a genuine failure. The ID
	// filter matches at most one row, so First's ORDER BY id is not needed.
	if err := db.Where("id = ?", id).Limit(1).Find(&session).Error; err != nil {
		log.Println("find session by id error:", err)
		return Session{}
	}
	return session
}
// FindSessionByUserID returns up to 50 of the user's sessions, most recently
// created first. Query failures are logged (the signature has no error
// return), and whatever GORM scanned is returned.
func FindSessionByUserID(userID int) []Session {
	const maxSessions = 50

	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = DB.Debug()
	}
	var sessions []Session
	err := db.Where("user_id = ?", userID).
		Order("created_at DESC").
		Limit(maxSessions).
		Find(&sessions).Error
	if err != nil {
		log.Println("find sessions by user id error:", err)
	}
	return sessions
}
// 更新会话的名字
func UpdateSessionByID(id int, userId int, name string, context []int) error {
var res *gorm.DB
contextJson, err := json.Marshal(context)
if err != nil {
log.Println("update session json marshal error:", err)
}
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name}).Update("context", contextJson)
} else {
res = DB.Model(&Session{}).Where("id = ? AND user_id = ?", id, userId).Updates(Session{Name: name, Context: nil}).Update("context", contextJson)
}
return res.Error
}
// 删除会话
func DeleteSessionByID(id int) error {
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Delete(&Session{}, id)
} else {
res = DB.Delete(&Session{}, id)
}
return res.Error
}
// 创建消息
func CreateMessage(sessionID, fromID, toID, msgType, status int, msg string) (error, uint) {
message := Message{SessionID: sessionID, FromID: fromID, ToID: toID, Type: msgType, Status: status, Msg: msg}
var res *gorm.DB
if proto.Config.SERVER_SQL_LOG {
res = DB.Debug().Create(&message)
} else {
res = DB.Create(&message)
}
return res.Error, message.ID
}
// FindMessageBySessionID returns all messages of the given session, newest
// first (created_at DESC). Query failures are logged (the signature has no
// error return), and whatever GORM scanned is returned.
func FindMessageBySessionID(sessionID int) []Message {
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = DB.Debug()
	}
	var messages []Message
	err := db.Where("session_id = ?", sessionID).
		Order("created_at DESC").
		Find(&messages).Error
	if err != nil {
		log.Println("find messages by session id error:", err)
	}
	return messages
}
// SessionMessageCount is one result row of FindSessionMessageCountByUserID:
// a session ID and the number of matching messages in that session.
type SessionMessageCount struct {
	SessionID int `gorm:"column:session_id"`
	Count     int `gorm:"column:count"`
}
// FindSessionMessageCountByUserID returns, per session, how many messages have
// userID as sender or receiver (from_id or to_id). Soft-deleted messages are
// excluded. Query failures are logged (the signature has no error return), and
// whatever was scanned is returned.
//
// NOTE(review): from_id/to_id may also hold model IDs (see Message.Type), and
// this filter ignores Type, so a model whose ID equals userID would pull other
// users' model conversations into the count. Filtering by session ownership or
// by Type may be intended — confirm.
func FindSessionMessageCountByUserID(userID int) []SessionMessageCount {
	db := DB
	if proto.Config.SERVER_SQL_LOG {
		db = DB.Debug()
	}
	var counts []SessionMessageCount
	// Model(&Message{}) rather than Table("messages"): GORM then resolves the
	// table name from the model and applies gorm.Model's soft-delete filter
	// (deleted_at IS NULL), which a raw Table()+Scan into a non-model struct skips.
	err := db.Model(&Message{}).
		Select("session_id, count(*) as count").
		Where("(to_id = ? OR from_id = ?)", userID, userID).
		Group("session_id").
		Scan(&counts).Error
	if err != nil {
		log.Println("find session message count error:", err)
	}
	return counts
}