videoplayer/service/dnsService.go

787 lines
19 KiB
Go
Raw Normal View History

2026-05-02 14:45:51 +08:00
package service
import (
"errors"
"fmt"
"net"
"sync"
"videoplayer/dao"
"videoplayer/proto"
"github.com/miekg/dns"
)
// DNSServiceInstance describes one DNS server tracked by the service manager.
type DNSServiceInstance struct {
	// ServerID is the ID of the DNS server configuration this instance belongs to.
	ServerID uint `json:"server_id"`
	// Config is the server configuration the instance was built from.
	Config *proto.DNSServer `json:"config"`
	// Server is the underlying miekg/dns server; it is excluded from JSON output.
	Server *dns.Server `json:"-"`
	// Running records whether the instance is currently considered to be serving.
	Running bool `json:"running"`
	// mutex guards Running. Unexported fields are never marshalled by
	// encoding/json, so the field carries no json tag.
	mutex sync.Mutex
}
// DNSServiceManager keeps track of DNS service instances, keyed by server ID.
type DNSServiceManager struct {
// instances maps a DNS server ID to its service instance.
instances map[uint]*DNSServiceInstance
// rwMutex guards access to instances.
rwMutex sync.RWMutex
}
// dnsServiceManager is the package-wide registry of DNS service instances.
// It is initialized at declaration, so it is already usable when the
// package's init functions run.
var dnsServiceManager = &DNSServiceManager{
	instances: make(map[uint]*DNSServiceInstance),
}
// ==================== DNSServer operations ====================
// CreateDNSServer stores a new DNS server configuration owned by userID.
//
// A zero port is replaced by 53 and an empty listen address by "0.0.0.0".
// The new server starts with status proto.DNS_STATUS_STOPPED. On success the
// returned value carries the ID returned by dao.CreateDNSServer.
func CreateDNSServer(req *proto.CreateDNSServerReq, userID uint) (proto.DNSServer, error) {
	created := proto.DNSServer{
		UserID:          userID,
		Name:            req.Name,
		Description:     req.Description,
		Port:            req.Port,
		ListenIP:        req.ListenIP,
		UpstreamDNS:     req.UpstreamDNS,
		EnableRecursion: req.EnableRecursion,
		Status:          proto.DNS_STATUS_STOPPED,
	}

	// Fill in defaults for the listening endpoint.
	if created.ListenIP == "" {
		created.ListenIP = "0.0.0.0"
	}
	if created.Port == 0 {
		created.Port = 53
	}

	newID, err := dao.CreateDNSServer(created)
	if err != nil {
		return proto.DNSServer{}, err
	}
	created.ID = newID
	return created, nil
}
// GetDNSServerList returns DNS server configurations visible to userID.
//
// When req.ServerID is set, only that server is returned, and only if the
// caller owns it or has the "admin" role. Otherwise req.GetType selects the
// listing: 0 returns the caller's own servers, 1 returns every server
// (admin only). Any other GetType is rejected.
func GetDNSServerList(req *proto.GetDNSServerReq, userID uint) ([]proto.DNSServer, error) {
	// Single-server lookup.
	if req.ServerID > 0 {
		found, err := dao.FindDNSServerByID(req.ServerID)
		if err != nil {
			return nil, err
		}
		if found.UserID != userID && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问DNS服务器")
		}
		return []proto.DNSServer{found}, nil
	}

	var (
		list    []proto.DNSServer
		listErr error
	)
	switch {
	case req.GetType == 0:
		list, listErr = dao.FindDNSServerByUserID(userID)
	case req.GetType == 1:
		if GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问仅管理员可获取全部DNS服务器")
		}
		list, listErr = dao.FindAllDNSServer()
	default:
		return nil, errors.New("无效的获取类型")
	}
	if listErr != nil {
		return nil, listErr
	}
	return list, nil
}
// UpdateDNSServer overwrites the editable fields of an existing DNS server
// with the values from req and persists the result.
//
// Only the server's owner or a user with the "admin" role may update it.
// As in CreateDNSServer, a zero port falls back to 53 and an empty listen
// address to "0.0.0.0", so an update cannot leave the server without a
// usable listening endpoint. After a successful update a restart of the
// running service (if any) is triggered in the background.
func UpdateDNSServer(req *proto.UpdateDNSServerReq, userID int) (proto.DNSServer, error) {
	server, err := dao.FindDNSServerByID(req.ServerID)
	if err != nil {
		return proto.DNSServer{}, err
	}
	if server.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return proto.DNSServer{}, errors.New("未授权访问DNS服务器")
	}

	server.Name = req.Name
	server.Port = req.Port
	if server.Port == 0 {
		server.Port = 53
	}
	server.ListenIP = req.ListenIP
	if server.ListenIP == "" {
		server.ListenIP = "0.0.0.0"
	}
	server.UpstreamDNS = req.UpstreamDNS
	server.EnableRecursion = req.EnableRecursion
	// NOTE(review): Status is also written by StartDNSServer/StopDNSServer;
	// confirm that clients are meant to set it directly here.
	server.Status = req.Status
	server.Description = req.Description

	if err := dao.UpdateDNSServer(server.ID, &server); err != nil {
		return proto.DNSServer{}, err
	}

	// If the service is running, restart it so the new configuration takes
	// effect. This runs in the background; its error is not reported to the
	// caller.
	go RestartDNSServerIfRunning(req.ServerID, userID)
	return server, nil
}
// DeleteDNSServer removes DNS servers together with their zones and records.
//
// With req.DelType == 0 and a non-zero req.ServerID, that single server is
// deleted; the caller must own it or have the "admin" role. With
// req.DelType == 1 every server is deleted (admin only). Any other
// combination is rejected. A running service for a deleted server is stopped
// first so it does not keep listening after its configuration is gone.
func DeleteDNSServer(req *proto.DeleteDNSServerReq, userID int) error {
	user := GetUserByIDFromUserCenter(userID)

	switch {
	case req.DelType == 0 && req.ServerID > 0:
		server, err := dao.FindDNSServerByID(req.ServerID)
		if err != nil {
			return err
		}
		if server.UserID != uint(userID) && user.Role != "admin" {
			return errors.New("未授权访问DNS服务器")
		}
		return deleteDNSServerCascade(server.ID, userID)

	case req.DelType == 1:
		if user.Role != "admin" {
			return errors.New("未授权访问仅管理员可删除所有DNS服务器")
		}
		servers, err := dao.FindAllDNSServer()
		if err != nil {
			return err
		}
		for _, server := range servers {
			if err := deleteDNSServerCascade(server.ID, userID); err != nil {
				return err
			}
		}
		return nil

	default:
		return errors.New("无效的删除类型或参数")
	}
}

// deleteDNSServerCascade stops the service for serverID if it is running and
// then deletes the server's records, zones and finally the server itself.
// The caller must already have verified that userID may delete the server.
func deleteDNSServerCascade(serverID uint, userID int) error {
	if err := StopDNSServer(serverID, userID); err != nil {
		return err
	}

	// Records are deleted per zone before the zones themselves, mirroring
	// what DeleteDNSZone does, so no record outlives its zone.
	zones, err := dao.FindDNSZoneByServerID(serverID)
	if err != nil {
		return err
	}
	for _, zone := range zones {
		if err := dao.DeleteDNSRecordByZoneID(zone.ID); err != nil {
			return err
		}
	}

	if err := dao.DeleteDNSZoneByServerID(serverID); err != nil {
		return err
	}
	return dao.DeleteDNSServerByID(serverID)
}
// ==================== DNSZone operations ====================
// CreateDNSZone creates a zone on the server identified by req.ServerID.
//
// The caller must own that server or have the "admin" role. SOA and TTL
// fields left at zero receive defaults: serial 1, refresh 86400, retry 7200,
// expire 3600000, minimum 3600 and TTL 3600. On success the returned zone
// carries the ID returned by dao.CreateDNSZone.
func CreateDNSZone(req *proto.CreateDNSZoneReq, userID uint) (proto.DNSZone, error) {
	parent, err := dao.FindDNSServerByID(req.ServerID)
	if err != nil {
		return proto.DNSZone{}, err
	}
	if parent.UserID != userID && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
		return proto.DNSZone{}, errors.New("未授权访问DNS服务器")
	}

	created := proto.DNSZone{
		ServerID:    req.ServerID,
		Domain:      req.Domain,
		Description: req.Description,
		TTL:         req.TTL,
		SOA_MName:   req.SOA_MName,
		SOA_RName:   req.SOA_RName,
		SOA_Serial:  req.SOA_Serial,
		SOA_Refresh: req.SOA_Refresh,
		SOA_Retry:   req.SOA_Retry,
		SOA_Expire:  req.SOA_Expire,
		SOA_Minimum: req.SOA_Minimum,
	}

	// Replace unset (zero) timing values with the defaults.
	if created.TTL == 0 {
		created.TTL = 3600
	}
	if created.SOA_Serial == 0 {
		created.SOA_Serial = 1
	}
	if created.SOA_Refresh == 0 {
		created.SOA_Refresh = 86400
	}
	if created.SOA_Retry == 0 {
		created.SOA_Retry = 7200
	}
	if created.SOA_Expire == 0 {
		created.SOA_Expire = 3600000
	}
	if created.SOA_Minimum == 0 {
		created.SOA_Minimum = 3600
	}

	newID, err := dao.CreateDNSZone(created)
	if err != nil {
		return proto.DNSZone{}, err
	}
	created.ID = newID
	return created, nil
}
// GetDNSZoneList returns DNS zones visible to userID.
//
// When req.ZoneID is set, only that zone is returned, provided the caller
// owns the zone's server or has the "admin" role. Otherwise req.GetType
// selects the listing:
//   - 0: zones of req.ServerID (the caller must own that server or be an
//     admin), or, when no ServerID is given, the zones of all servers owned
//     by the caller;
//   - 1: admin only; zones of req.ServerID, or every zone when no ServerID
//     is given.
//
// Any other GetType is rejected.
func GetDNSZoneList(req *proto.GetDNSZoneReq, userID uint) ([]proto.DNSZone, error) {
	// Single-zone lookup.
	if req.ZoneID > 0 {
		zone, err := dao.FindDNSZoneByID(req.ZoneID)
		if err != nil {
			return nil, err
		}
		server, err := dao.FindDNSServerByID(zone.ServerID)
		if err != nil {
			return nil, err
		}
		if server.UserID != userID && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问DNS区域")
		}
		return []proto.DNSZone{zone}, nil
	}

	var (
		zones []proto.DNSZone
		err   error
	)
	switch {
	case req.GetType == 0 && req.ServerID > 0:
		// Listing by server ID must not expose another user's zones, so the
		// server's ownership is checked before its zones are read.
		server, findErr := dao.FindDNSServerByID(req.ServerID)
		if findErr != nil {
			return nil, findErr
		}
		if server.UserID != userID && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问DNS服务器")
		}
		zones, err = dao.FindDNSZoneByServerID(req.ServerID)

	case req.GetType == 0:
		servers, findErr := dao.FindDNSServerByUserID(userID)
		if findErr != nil {
			return nil, findErr
		}
		for _, server := range servers {
			serverZones, zoneErr := dao.FindDNSZoneByServerID(server.ID)
			if zoneErr != nil {
				return nil, zoneErr
			}
			zones = append(zones, serverZones...)
		}

	case req.GetType == 1:
		if GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问仅管理员可获取全部DNS区域")
		}
		if req.ServerID > 0 {
			zones, err = dao.FindDNSZoneByServerID(req.ServerID)
		} else {
			zones, err = dao.FindAllDNSZone()
		}

	default:
		return nil, errors.New("无效的获取类型")
	}
	if err != nil {
		return nil, err
	}
	return zones, nil
}
// UpdateDNSZone overwrites the domain, SOA values, TTL and description of the
// zone identified by req.ZoneID with the values from req and persists it.
//
// The caller must own the zone's server or have the "admin" role.
func UpdateDNSZone(req *proto.UpdateDNSZoneReq, userID int) (proto.DNSZone, error) {
	var none proto.DNSZone

	target, err := dao.FindDNSZoneByID(req.ZoneID)
	if err != nil {
		return none, err
	}
	parent, err := dao.FindDNSServerByID(target.ServerID)
	if err != nil {
		return none, err
	}
	if parent.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return none, errors.New("未授权访问DNS区域")
	}

	// Descriptive fields.
	target.Domain = req.Domain
	target.Description = req.Description
	target.TTL = req.TTL

	// SOA fields.
	target.SOA_MName = req.SOA_MName
	target.SOA_RName = req.SOA_RName
	target.SOA_Serial = req.SOA_Serial
	target.SOA_Refresh = req.SOA_Refresh
	target.SOA_Retry = req.SOA_Retry
	target.SOA_Expire = req.SOA_Expire
	target.SOA_Minimum = req.SOA_Minimum

	if err := dao.UpdateDNSZone(target.ID, &target); err != nil {
		return none, err
	}
	return target, nil
}
// DeleteDNSZone removes zones together with their records.
//
// With req.DelType == 0 and a non-zero req.ZoneID, that zone is deleted; the
// caller must own the zone's server or have the "admin" role. With
// req.DelType == 1 every zone is deleted (admin only). Any other combination
// is rejected. Records are deleted before the zone they belong to.
func DeleteDNSZone(req *proto.DeleteDNSZoneReq, userID int) error {
	caller := GetUserByIDFromUserCenter(userID)

	switch {
	case req.DelType == 0 && req.ZoneID > 0:
		target, err := dao.FindDNSZoneByID(req.ZoneID)
		if err != nil {
			return err
		}
		parent, err := dao.FindDNSServerByID(target.ServerID)
		if err != nil {
			return err
		}
		if parent.UserID != uint(userID) && caller.Role != "admin" {
			return errors.New("未授权访问DNS区域")
		}
		if err := dao.DeleteDNSRecordByZoneID(req.ZoneID); err != nil {
			return err
		}
		if err := dao.DeleteDNSZoneByID(req.ZoneID); err != nil {
			return err
		}
		return nil

	case req.DelType == 1:
		if caller.Role != "admin" {
			return errors.New("未授权访问仅管理员可删除所有DNS区域")
		}
		all, err := dao.FindAllDNSZone()
		if err != nil {
			return err
		}
		for _, z := range all {
			if err := dao.DeleteDNSRecordByZoneID(z.ID); err != nil {
				return err
			}
			if err := dao.DeleteDNSZoneByID(z.ID); err != nil {
				return err
			}
		}
		return nil
	}

	return errors.New("无效的删除类型或参数")
}
// ==================== DNSRecord operations ====================
// CreateDNSRecord adds a record to the zone identified by req.ZoneID.
//
// The caller must own the zone's server or have the "admin" role. A zero
// TTL in the request is replaced by the zone's TTL. On success the returned
// record carries the ID returned by dao.CreateDNSRecord.
func CreateDNSRecord(req *proto.CreateDNSRecordReq, userID uint) (proto.DNSRecord, error) {
	var none proto.DNSRecord

	parentZone, err := dao.FindDNSZoneByID(req.ZoneID)
	if err != nil {
		return none, err
	}
	parentServer, err := dao.FindDNSServerByID(parentZone.ServerID)
	if err != nil {
		return none, err
	}
	if parentServer.UserID != userID && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
		return none, errors.New("未授权访问DNS区域")
	}

	created := proto.DNSRecord{
		ZoneID:   req.ZoneID,
		Type:     req.Type,
		Name:     req.Name,
		Value:    req.Value,
		Target:   req.Target,
		TTL:      req.TTL,
		Priority: req.Priority,
		Weight:   req.Weight,
		Port:     req.Port,
	}
	// Inherit the zone's TTL when the request does not specify one.
	if created.TTL == 0 {
		created.TTL = parentZone.TTL
	}

	newID, err := dao.CreateDNSRecord(created)
	if err != nil {
		return none, err
	}
	created.ID = newID
	return created, nil
}
// GetDNSRecordList returns DNS records visible to userID.
//
// When req.RecordID is set, only that record is returned, provided the
// caller owns the server of the record's zone or has the "admin" role.
// Otherwise req.GetType selects the listing:
//   - 0: records of req.ZoneID (the caller must own that zone's server or be
//     an admin), or, when no ZoneID is given, the records of every zone on
//     the servers owned by the caller;
//   - 1: admin only; records of req.ZoneID, or every record when no ZoneID
//     is given.
//
// Any other GetType is rejected.
func GetDNSRecordList(req *proto.GetDNSRecordReq, userID uint) ([]proto.DNSRecord, error) {
	// Single-record lookup.
	if req.RecordID > 0 {
		record, err := dao.FindDNSRecordByID(req.RecordID)
		if err != nil {
			return nil, err
		}
		zone, err := dao.FindDNSZoneByID(record.ZoneID)
		if err != nil {
			return nil, err
		}
		server, err := dao.FindDNSServerByID(zone.ServerID)
		if err != nil {
			return nil, err
		}
		if server.UserID != userID && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问DNS记录")
		}
		return []proto.DNSRecord{record}, nil
	}

	var (
		records []proto.DNSRecord
		err     error
	)
	switch {
	case req.GetType == 0 && req.ZoneID > 0:
		// Listing by zone ID must not expose another user's records, so the
		// ownership of the zone's server is checked before reading them.
		zone, findErr := dao.FindDNSZoneByID(req.ZoneID)
		if findErr != nil {
			return nil, findErr
		}
		server, findErr := dao.FindDNSServerByID(zone.ServerID)
		if findErr != nil {
			return nil, findErr
		}
		if server.UserID != userID && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问DNS区域")
		}
		records, err = dao.FindDNSRecordByZoneID(req.ZoneID)

	case req.GetType == 0:
		servers, findErr := dao.FindDNSServerByUserID(userID)
		if findErr != nil {
			return nil, findErr
		}
		for _, server := range servers {
			zones, zoneErr := dao.FindDNSZoneByServerID(server.ID)
			if zoneErr != nil {
				return nil, zoneErr
			}
			for _, zone := range zones {
				zoneRecords, recErr := dao.FindDNSRecordByZoneID(zone.ID)
				if recErr != nil {
					return nil, recErr
				}
				records = append(records, zoneRecords...)
			}
		}

	case req.GetType == 1:
		if GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问仅管理员可获取全部DNS记录")
		}
		if req.ZoneID > 0 {
			records, err = dao.FindDNSRecordByZoneID(req.ZoneID)
		} else {
			records, err = dao.FindAllDNSRecord()
		}

	default:
		return nil, errors.New("无效的获取类型")
	}
	if err != nil {
		return nil, err
	}
	return records, nil
}
// UpdateDNSRecord overwrites the name, type, value, TTL, priority, weight,
// port and target of the record identified by req.RecordID and persists it.
//
// The caller must own the server of the record's zone or have the "admin"
// role.
func UpdateDNSRecord(req *proto.UpdateDNSRecordReq, userID int) (proto.DNSRecord, error) {
	var none proto.DNSRecord

	target, err := dao.FindDNSRecordByID(req.RecordID)
	if err != nil {
		return none, err
	}
	parentZone, err := dao.FindDNSZoneByID(target.ZoneID)
	if err != nil {
		return none, err
	}
	parentServer, err := dao.FindDNSServerByID(parentZone.ServerID)
	if err != nil {
		return none, err
	}
	if parentServer.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return none, errors.New("未授权访问DNS记录")
	}

	// Identity of the record.
	target.Type = req.Type
	target.Name = req.Name
	target.Value = req.Value
	target.Target = req.Target

	// Numeric attributes.
	target.TTL = req.TTL
	target.Priority = req.Priority
	target.Weight = req.Weight
	target.Port = req.Port

	if err := dao.UpdateDNSRecord(target.ID, &target); err != nil {
		return none, err
	}
	return target, nil
}
// DeleteDNSRecord removes DNS records.
//
// With req.DelType == 0 and a non-zero req.RecordID, that record is deleted;
// the caller must own the server of the record's zone or have the "admin"
// role. With req.DelType == 1 every record is deleted (admin only). Any
// other combination is rejected.
func DeleteDNSRecord(req *proto.DeleteDNSRecordReq, userID int) error {
	caller := GetUserByIDFromUserCenter(userID)

	switch {
	case req.DelType == 0 && req.RecordID > 0:
		target, err := dao.FindDNSRecordByID(req.RecordID)
		if err != nil {
			return err
		}
		parentZone, err := dao.FindDNSZoneByID(target.ZoneID)
		if err != nil {
			return err
		}
		parentServer, err := dao.FindDNSServerByID(parentZone.ServerID)
		if err != nil {
			return err
		}
		if parentServer.UserID != uint(userID) && caller.Role != "admin" {
			return errors.New("未授权访问DNS记录")
		}
		if err := dao.DeleteDNSRecordByID(req.RecordID); err != nil {
			return err
		}
		return nil

	case req.DelType == 1:
		if caller.Role != "admin" {
			return errors.New("未授权访问仅管理员可删除所有DNS记录")
		}
		all, err := dao.FindAllDNSRecord()
		if err != nil {
			return err
		}
		for _, rec := range all {
			if err := dao.DeleteDNSRecordByID(rec.ID); err != nil {
				return err
			}
		}
		return nil
	}

	return errors.New("无效的删除类型或参数")
}
// ==================== DNS service runtime management ====================
// handleDNSRequest answers a DNS query from the zones and records stored for
// this instance's server.
//
// For every question the most specific zone (the one with the most labels)
// containing the queried name is selected. Records of that zone whose name
// and type match the question are returned; names are compared
// case-insensitively. A record TTL of zero falls back to the zone TTL.
// Supported record types are A, AAAA, CNAME, MX, NS, TXT and SRV.
//
// When no zone matches and recursion is enabled in the configuration, the
// request is forwarded to the configured upstream (8.8.8.8:53 when none is
// set) and the upstream's answers and response code are relayed as a
// non-authoritative reply. Failures to read zones or records, or to reach
// the upstream, are reported to the client as SERVFAIL.
func (instance *DNSServiceInstance) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
	m := new(dns.Msg)
	m.SetReply(r)
	m.Authoritative = true

	for _, q := range r.Question {
		zones, err := dao.FindDNSZoneByServerID(instance.ServerID)
		if err != nil {
			m.Rcode = dns.RcodeServerFailure
			continue
		}

		// Pick the longest matching zone so that a delegated sub-zone wins
		// over its parent. Indexing avoids taking the address of the loop
		// variable; on a tie the first zone wins.
		qName := dns.Fqdn(q.Name)
		var matchedZone *proto.DNSZone
		bestLabels := -1
		for i := range zones {
			zoneFqdn := dns.Fqdn(zones[i].Domain)
			if !dns.IsSubDomain(zoneFqdn, qName) {
				continue
			}
			if labels := dns.CountLabel(zoneFqdn); labels > bestLabels {
				bestLabels = labels
				matchedZone = &zones[i]
			}
		}

		if matchedZone == nil {
			// Not one of our zones: forward upstream if recursion is enabled.
			if instance.Config.EnableRecursion {
				c := new(dns.Client)
				resp, _, err := c.Exchange(r, dnsUpstreamAddress(instance.Config.UpstreamDNS))
				if err != nil || resp == nil {
					m.Rcode = dns.RcodeServerFailure
				} else {
					// Relayed data is not authoritative.
					m.Authoritative = false
					m.RecursionAvailable = true
					m.Rcode = resp.Rcode
					m.Answer = append(m.Answer, resp.Answer...)
				}
			}
			continue
		}

		records, err := dao.FindDNSRecordByZoneID(matchedZone.ID)
		if err != nil {
			m.Rcode = dns.RcodeServerFailure
			continue
		}

		for _, record := range records {
			if record.Type != uint(q.Qtype) {
				continue
			}
			// "@" denotes the zone apex; any other name is relative to the zone.
			recordName := dns.Fqdn(record.Name + "." + matchedZone.Domain)
			if record.Name == "@" {
				recordName = dns.Fqdn(matchedZone.Domain)
			}
			// Two names that are sub-domains of each other are equal; this
			// reuses the case-insensitive comparison of dns.IsSubDomain.
			if !dns.IsSubDomain(recordName, qName) || !dns.IsSubDomain(qName, recordName) {
				continue
			}

			ttl := uint32(matchedZone.TTL)
			if record.TTL > 0 {
				ttl = uint32(record.TTL)
			}

			var text string
			switch q.Qtype {
			case dns.TypeA:
				text = fmt.Sprintf("%s %d IN A %s", q.Name, ttl, record.Value)
			case dns.TypeAAAA:
				text = fmt.Sprintf("%s %d IN AAAA %s", q.Name, ttl, record.Value)
			case dns.TypeCNAME:
				text = fmt.Sprintf("%s %d IN CNAME %s", q.Name, ttl, record.Value)
			case dns.TypeMX:
				text = fmt.Sprintf("%s %d IN MX %d %s", q.Name, ttl, record.Priority, record.Value)
			case dns.TypeNS:
				text = fmt.Sprintf("%s %d IN NS %s", q.Name, ttl, record.Value)
			case dns.TypeTXT:
				text = fmt.Sprintf("%s %d IN TXT \"%s\"", q.Name, ttl, record.Value)
			case dns.TypeSRV:
				text = fmt.Sprintf("%s %d IN SRV %d %d %d %s", q.Name, ttl, record.Priority, record.Weight, record.Port, record.Target)
			default:
				// Unsupported record type: nothing to answer with.
				continue
			}

			// A stored record that does not parse is skipped rather than
			// failing the whole response.
			rr, err := dns.NewRR(text)
			if err != nil || rr == nil {
				continue
			}
			m.Answer = append(m.Answer, rr)
		}
	}

	// The write error is deliberately dropped: the client is gone or the
	// connection is broken, and there is no caller to report it to.
	_ = w.WriteMsg(m)
}

// dnsUpstreamAddress turns the configured upstream resolver into a host:port
// address. An empty value selects 8.8.8.8:53, and a value without a port gets
// port 53 appended.
func dnsUpstreamAddress(configured string) string {
	if configured == "" {
		return "8.8.8.8:53"
	}
	if _, _, err := net.SplitHostPort(configured); err == nil {
		return configured
	}
	// No usable port: treat the whole value as a host. Brackets around an
	// IPv6 literal are stripped because JoinHostPort adds its own.
	host := configured
	if len(host) > 1 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	return net.JoinHostPort(host, "53")
}
// StartDNSServer starts the DNS service for serverID and registers it with
// the service manager.
//
// The caller must own the server or have the "admin" role. If the service is
// already running, the existing instance is returned. Otherwise UDP and TCP
// sockets are bound on the configured address, the server is started with a
// handler private to this instance, and the stored status is set to
// proto.DNS_STATUS_RUNNING. The function returns only after the server has
// reported that it started, or with the error that prevented it.
//
// NOTE(review): miekg/dns ActivateAndServe appears to serve only PacketConn
// when both PacketConn and Listener are set, so the TCP listener is bound but
// probably not served — confirm, and use a second dns.Server for TCP if so.
func StartDNSServer(serverID uint, userID int) (*DNSServiceInstance, error) {
	// Permission check.
	server, err := dao.FindDNSServerByID(serverID)
	if err != nil {
		return nil, err
	}
	if server.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return nil, errors.New("未授权访问DNS服务器")
	}

	dnsServiceManager.rwMutex.Lock()
	defer dnsServiceManager.rwMutex.Unlock()

	// Already running: hand back the existing instance.
	if existing, ok := dnsServiceManager.instances[serverID]; ok {
		existing.mutex.Lock()
		running := existing.Running
		existing.mutex.Unlock()
		if running {
			return existing, nil
		}
	}

	instance := &DNSServiceInstance{
		ServerID: serverID,
		Config:   &server,
	}

	// JoinHostPort brackets IPv6 listen addresses correctly.
	addr := net.JoinHostPort(server.ListenIP, fmt.Sprint(server.Port))
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return nil, err
	}
	udpConn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return nil, err
	}
	tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		udpConn.Close()
		return nil, err
	}
	tcpListener, err := net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		udpConn.Close()
		return nil, err
	}

	// The handler is attached to this server only. Registering it on the
	// package-level mux would make every instance answer with the handler of
	// whichever server was started last.
	started := make(chan struct{})
	instance.Server = &dns.Server{
		Listener:          tcpListener,
		PacketConn:        udpConn,
		Handler:           dns.HandlerFunc(instance.handleDNSRequest),
		NotifyStartedFunc: func() { close(started) },
	}

	// Running is set before the goroutine starts so it can only go from
	// true to false afterwards.
	instance.Running = true
	serveDone := make(chan error, 1)
	go func() {
		serveErr := instance.Server.ActivateAndServe()
		instance.mutex.Lock()
		instance.Running = false
		instance.mutex.Unlock()
		serveDone <- serveErr
	}()

	// Wait until the server is really serving so that an immediate stop or
	// restart finds a started server.
	select {
	case <-started:
	case serveErr := <-serveDone:
		udpConn.Close()
		tcpListener.Close()
		if serveErr == nil {
			serveErr = errors.New("DNS服务启动失败")
		}
		return nil, serveErr
	}

	// Persist the new status; if that fails, do not leave an unregistered
	// server listening.
	server.Status = proto.DNS_STATUS_RUNNING
	if err := dao.UpdateDNSServer(serverID, &server); err != nil {
		_ = instance.Server.Shutdown()
		return nil, err
	}

	dnsServiceManager.instances[serverID] = instance
	return instance, nil
}
// StopDNSServer stops the running DNS service for serverID.
//
// The caller must own the server or have the "admin" role. If no instance is
// registered for the server, or the registered instance is not running, the
// call is a no-op and returns nil. Otherwise the server is shut down, the
// instance is removed from the manager and the stored status is set to
// proto.DNS_STATUS_STOPPED; a failure to persist that status is returned.
func StopDNSServer(serverID uint, userID int) error {
	server, err := dao.FindDNSServerByID(serverID)
	if err != nil {
		return err
	}
	if server.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return errors.New("未授权访问DNS服务器")
	}

	dnsServiceManager.rwMutex.Lock()
	defer dnsServiceManager.rwMutex.Unlock()

	instance, exists := dnsServiceManager.instances[serverID]
	if !exists {
		return nil
	}
	// Running is written by the serving goroutine, so read it under the lock.
	instance.mutex.Lock()
	running := instance.Running
	instance.mutex.Unlock()
	if !running {
		return nil
	}

	// Stop the service.
	if err := instance.Server.Shutdown(); err != nil {
		return err
	}
	instance.mutex.Lock()
	instance.Running = false
	instance.mutex.Unlock()
	delete(dnsServiceManager.instances, serverID)

	// Update the stored status. The service is already stopped at this
	// point, so an error here only means the stored status is stale.
	server.Status = proto.DNS_STATUS_STOPPED
	return dao.UpdateDNSServer(serverID, &server)
}
// RestartDNSServer stops the DNS service for serverID and starts it again.
// If stopping fails, that error is returned and no start is attempted.
func RestartDNSServer(serverID uint, userID int) (*DNSServiceInstance, error) {
	if stopErr := StopDNSServer(serverID, userID); stopErr != nil {
		return nil, stopErr
	}
	restarted, startErr := StartDNSServer(serverID, userID)
	return restarted, startErr
}
// GetDNSServerStatus reports the runtime state of the DNS service for
// serverID.
//
// The caller must own the server or have the "admin" role. If the manager
// holds an instance for the server, that instance is returned. Otherwise a
// fresh value is built from the stored configuration, with Running derived
// from whether the stored status equals proto.DNS_STATUS_RUNNING.
func GetDNSServerStatus(serverID uint, userID int) (*DNSServiceInstance, error) {
	cfg, err := dao.FindDNSServerByID(serverID)
	if err != nil {
		return nil, err
	}
	if cfg.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return nil, errors.New("未授权访问DNS服务器")
	}

	dnsServiceManager.rwMutex.RLock()
	tracked, found := dnsServiceManager.instances[serverID]
	dnsServiceManager.rwMutex.RUnlock()
	if found {
		return tracked, nil
	}

	// Not tracked by the manager: fall back to the stored status.
	fallback := &DNSServiceInstance{
		ServerID: serverID,
		Config:   &cfg,
	}
	fallback.Running = cfg.Status == proto.DNS_STATUS_RUNNING
	return fallback, nil
}
// RestartDNSServerIfRunning restarts the DNS service for serverID if the
// manager holds a running instance for it, and does nothing otherwise. It is
// called after a configuration update.
//
// The restart is performed on behalf of userID, so the permission checks in
// StopDNSServer and StartDNSServer apply to that user.
func RestartDNSServerIfRunning(serverID uint, userID int) error {
	dnsServiceManager.rwMutex.RLock()
	instance, exists := dnsServiceManager.instances[serverID]
	dnsServiceManager.rwMutex.RUnlock()
	if !exists {
		return nil
	}

	// Running is written by other goroutines, so read it under the lock.
	instance.mutex.Lock()
	running := instance.Running
	instance.mutex.Unlock()
	if !running {
		return nil
	}

	// StopDNSServer and StartDNSServer each reload the server from storage,
	// so the restart picks up the latest configuration.
	_, err := RestartDNSServer(serverID, userID)
	return err
}