videoplayer/service/dnsService.go

787 lines
19 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"errors"
"fmt"
"net"
"sync"
"videoplayer/dao"
"videoplayer/proto"
"github.com/miekg/dns"
)
// DNSServiceInstance is the in-memory record of a DNS server started by this
// process (see StartDNSServer). GetDNSServerStatus also returns it, which is
// why it carries JSON tags; Server and mutex are excluded from JSON.
type DNSServiceInstance struct {
// ServerID is the ID of the stored proto.DNSServer this instance serves.
ServerID uint `json:"server_id"`
// Config is the stored server configuration loaded when the instance was created.
Config *proto.DNSServer `json:"config"`
// Server is the miekg/dns server handle that StopDNSServer shuts down.
Server *dns.Server `json:"-"`
// Running reports whether the instance is serving. Writes to it are made
// while holding mutex.
Running bool `json:"running"`
// mutex guards Running.
mutex sync.Mutex `json:"-"`
}
// DNSServiceManager is the registry of DNS service instances started by this
// process, keyed by DNS server ID.
type DNSServiceManager struct {
// instances maps a server ID to its running (or failed) instance; guarded by rwMutex.
instances map[uint]*DNSServiceInstance
// rwMutex guards instances.
rwMutex sync.RWMutex
}
// dnsServiceManager is the process-wide registry of DNS service instances.
// It is built by a package-level initializer, which the Go runtime runs before
// any init function, so no separate init() is needed.
var dnsServiceManager = &DNSServiceManager{
	instances: make(map[uint]*DNSServiceInstance),
}
// ==================== DNSServer 相关操作 ====================
// CreateDNSServer stores a new DNS server configuration owned by userID and
// returns it with its database-assigned ID. The server starts out in the
// stopped state; a zero Port means 53 and an empty ListenIP means "0.0.0.0".
func CreateDNSServer(req *proto.CreateDNSServerReq, userID uint) (proto.DNSServer, error) {
	port := req.Port
	if port == 0 {
		port = 53
	}
	listenIP := req.ListenIP
	if listenIP == "" {
		listenIP = "0.0.0.0"
	}

	created := proto.DNSServer{
		UserID:          userID,
		Name:            req.Name,
		Description:     req.Description,
		Port:            port,
		ListenIP:        listenIP,
		UpstreamDNS:     req.UpstreamDNS,
		EnableRecursion: req.EnableRecursion,
		Status:          proto.DNS_STATUS_STOPPED,
	}

	newID, err := dao.CreateDNSServer(created)
	if err != nil {
		return proto.DNSServer{}, err
	}
	created.ID = newID
	return created, nil
}
// GetDNSServerList returns DNS servers visible to userID.
//
//   - req.ServerID > 0: that single server, if userID owns it or is an admin.
//   - req.GetType == 0: every server owned by userID.
//   - req.GetType == 1: every server (admins only).
//
// Any other GetType is rejected.
func GetDNSServerList(req *proto.GetDNSServerReq, userID uint) ([]proto.DNSServer, error) {
	if req.ServerID > 0 {
		one, err := dao.FindDNSServerByID(req.ServerID)
		if err != nil {
			return nil, err
		}
		if one.UserID != userID && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问DNS服务器")
		}
		return []proto.DNSServer{one}, nil
	}

	var (
		list []proto.DNSServer
		err  error
	)
	switch req.GetType {
	case 0:
		list, err = dao.FindDNSServerByUserID(userID)
	case 1:
		if GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问仅管理员可获取全部DNS服务器")
		}
		list, err = dao.FindAllDNSServer()
	default:
		return nil, errors.New("无效的获取类型")
	}
	if err != nil {
		return nil, err
	}
	return list, nil
}
// UpdateDNSServer overwrites the stored configuration of req.ServerID with the
// values in req, after checking that userID owns the server or is an admin.
// As in CreateDNSServer, a zero Port means 53 and an empty ListenIP means
// "0.0.0.0". If the server is currently running it is restarted in the
// background so the new configuration takes effect.
func UpdateDNSServer(req *proto.UpdateDNSServerReq, userID int) (proto.DNSServer, error) {
	server, err := dao.FindDNSServerByID(req.ServerID)
	if err != nil {
		return proto.DNSServer{}, err
	}
	if server.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return proto.DNSServer{}, errors.New("未授权访问DNS服务器")
	}

	server.Name = req.Name
	server.Port = req.Port
	if server.Port == 0 {
		// Keep the CreateDNSServer default; port 0 would make a restart bind an
		// arbitrary ephemeral port.
		server.Port = 53
	}
	server.ListenIP = req.ListenIP
	if server.ListenIP == "" {
		server.ListenIP = "0.0.0.0"
	}
	server.UpstreamDNS = req.UpstreamDNS
	server.EnableRecursion = req.EnableRecursion
	server.Status = req.Status
	server.Description = req.Description

	if err = dao.UpdateDNSServer(server.ID, &server); err != nil {
		return proto.DNSServer{}, err
	}

	// Restart asynchronously; the update itself is already persisted, so the
	// restart's error is intentionally not reported to this caller.
	go RestartDNSServerIfRunning(req.ServerID, userID)
	return server, nil
}
// DeleteDNSServer removes DNS servers together with their zones and records.
//
// DelType 0 with a positive ServerID deletes that server when userID owns it or
// is an admin. DelType 1 (admins only) deletes every server. Any other
// combination is rejected. A running server is stopped before its data is
// removed, so it does not keep answering queries after deletion.
func DeleteDNSServer(req *proto.DeleteDNSServerReq, userID int) error {
	user := GetUserByIDFromUserCenter(userID)
	switch {
	case req.DelType == 0 && req.ServerID > 0:
		server, err := dao.FindDNSServerByID(req.ServerID)
		if err != nil {
			return err
		}
		if server.UserID != uint(userID) && user.Role != "admin" {
			return errors.New("未授权访问DNS服务器")
		}
		return deleteDNSServerCascade(req.ServerID, userID)
	case req.DelType == 1:
		if user.Role != "admin" {
			return errors.New("未授权访问仅管理员可删除所有DNS服务器")
		}
		servers, err := dao.FindAllDNSServer()
		if err != nil {
			return err
		}
		for _, server := range servers {
			if err := deleteDNSServerCascade(server.ID, userID); err != nil {
				return err
			}
		}
		return nil
	default:
		return errors.New("无效的删除类型或参数")
	}
}

// deleteDNSServerCascade stops serverID's running instance (if any), then
// deletes the records of each of its zones, the zones, and finally the server.
// The caller must already have authorized userID for this server.
func deleteDNSServerCascade(serverID uint, userID int) error {
	if err := StopDNSServer(serverID, userID); err != nil {
		return err
	}
	zones, err := dao.FindDNSZoneByServerID(serverID)
	if err != nil {
		return err
	}
	// Records hang off zones; delete them first so none are left orphaned.
	for _, zone := range zones {
		if err := dao.DeleteDNSRecordByZoneID(zone.ID); err != nil {
			return err
		}
	}
	if err := dao.DeleteDNSZoneByServerID(serverID); err != nil {
		return err
	}
	return dao.DeleteDNSServerByID(serverID)
}
// ==================== DNSZone 相关操作 ====================
// CreateDNSZone adds a zone to DNS server req.ServerID and returns it with its
// newly assigned ID. userID must own the server or be an admin. Fields left at
// zero get defaults: serial 1, refresh 86400, retry 7200, expire 3600000,
// minimum 3600 and TTL 3600.
func CreateDNSZone(req *proto.CreateDNSZoneReq, userID uint) (proto.DNSZone, error) {
	parent, err := dao.FindDNSServerByID(req.ServerID)
	if err != nil {
		return proto.DNSZone{}, err
	}
	if !(parent.UserID == userID || GetUserByIDFromUserCenter(int(userID)).Role == "admin") {
		return proto.DNSZone{}, errors.New("未授权访问DNS服务器")
	}

	newZone := proto.DNSZone{
		ServerID:    req.ServerID,
		Domain:      req.Domain,
		Description: req.Description,
		SOA_MName:   req.SOA_MName,
		SOA_RName:   req.SOA_RName,
		SOA_Serial:  req.SOA_Serial,
		SOA_Refresh: req.SOA_Refresh,
		SOA_Retry:   req.SOA_Retry,
		SOA_Expire:  req.SOA_Expire,
		SOA_Minimum: req.SOA_Minimum,
		TTL:         req.TTL,
	}

	// Fill in defaults for anything the caller left at zero.
	if newZone.SOA_Serial == 0 {
		newZone.SOA_Serial = 1
	}
	if newZone.SOA_Refresh == 0 {
		newZone.SOA_Refresh = 86400
	}
	if newZone.SOA_Retry == 0 {
		newZone.SOA_Retry = 7200
	}
	if newZone.SOA_Expire == 0 {
		newZone.SOA_Expire = 3600000
	}
	if newZone.SOA_Minimum == 0 {
		newZone.SOA_Minimum = 3600
	}
	if newZone.TTL == 0 {
		newZone.TTL = 3600
	}

	newID, err := dao.CreateDNSZone(newZone)
	if err != nil {
		return proto.DNSZone{}, err
	}
	newZone.ID = newID
	return newZone, nil
}
// GetDNSZoneList returns DNS zones visible to userID.
//
//   - req.ZoneID > 0: that single zone, if userID owns its server or is an admin.
//   - req.GetType == 0: the zones of req.ServerID when it is > 0 (subject to the
//     same owner-or-admin check), otherwise the zones of every server owned by
//     userID.
//   - req.GetType == 1 (admins only): the zones of req.ServerID, or all zones.
//
// Any other GetType is rejected.
func GetDNSZoneList(req *proto.GetDNSZoneReq, userID uint) ([]proto.DNSZone, error) {
	isAdmin := func() bool {
		return GetUserByIDFromUserCenter(int(userID)).Role == "admin"
	}
	// result normalizes a DAO (slice, error) pair: nil slice whenever err != nil.
	result := func(zones []proto.DNSZone, err error) ([]proto.DNSZone, error) {
		if err != nil {
			return nil, err
		}
		return zones, nil
	}

	if req.ZoneID > 0 {
		zone, err := dao.FindDNSZoneByID(req.ZoneID)
		if err != nil {
			return nil, err
		}
		server, err := dao.FindDNSServerByID(zone.ServerID)
		if err != nil {
			return nil, err
		}
		if server.UserID != userID && !isAdmin() {
			return nil, errors.New("未授权访问DNS区域")
		}
		return []proto.DNSZone{zone}, nil
	}

	switch req.GetType {
	case 0:
		if req.ServerID > 0 {
			// A caller-supplied ServerID must belong to the caller (or the caller
			// must be an admin); otherwise any user could read others' zones.
			server, err := dao.FindDNSServerByID(req.ServerID)
			if err != nil {
				return nil, err
			}
			if server.UserID != userID && !isAdmin() {
				return nil, errors.New("未授权访问DNS服务器")
			}
			return result(dao.FindDNSZoneByServerID(req.ServerID))
		}
		servers, err := dao.FindDNSServerByUserID(userID)
		if err != nil {
			return nil, err
		}
		var zones []proto.DNSZone
		for _, server := range servers {
			serverZones, err := dao.FindDNSZoneByServerID(server.ID)
			if err != nil {
				return nil, err
			}
			zones = append(zones, serverZones...)
		}
		return zones, nil
	case 1:
		if !isAdmin() {
			return nil, errors.New("未授权访问仅管理员可获取全部DNS区域")
		}
		if req.ServerID > 0 {
			return result(dao.FindDNSZoneByServerID(req.ServerID))
		}
		return result(dao.FindAllDNSZone())
	default:
		return nil, errors.New("无效的获取类型")
	}
}
// UpdateDNSZone overwrites zone req.ZoneID with the values in req, after
// verifying that userID owns the zone's server or is an admin, and returns the
// updated zone. Zero-valued SOA fields and TTL fall back to the same defaults
// CreateDNSZone applies (serial 1, refresh 86400, retry 7200, expire 3600000,
// minimum 3600, TTL 3600). Without this, an update could store a zero zone TTL,
// which handleDNSRequest serves for records that have no TTL of their own.
func UpdateDNSZone(req *proto.UpdateDNSZoneReq, userID int) (proto.DNSZone, error) {
	zone, err := dao.FindDNSZoneByID(req.ZoneID)
	if err != nil {
		return proto.DNSZone{}, err
	}
	server, err := dao.FindDNSServerByID(zone.ServerID)
	if err != nil {
		return proto.DNSZone{}, err
	}
	if server.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return proto.DNSZone{}, errors.New("未授权访问DNS区域")
	}

	zone.Domain = req.Domain
	zone.SOA_MName = req.SOA_MName
	zone.SOA_RName = req.SOA_RName
	zone.SOA_Serial = req.SOA_Serial
	zone.SOA_Refresh = req.SOA_Refresh
	zone.SOA_Retry = req.SOA_Retry
	zone.SOA_Expire = req.SOA_Expire
	zone.SOA_Minimum = req.SOA_Minimum
	zone.TTL = req.TTL
	zone.Description = req.Description

	// Same defaults as CreateDNSZone for fields left at zero.
	if zone.SOA_Serial == 0 {
		zone.SOA_Serial = 1
	}
	if zone.SOA_Refresh == 0 {
		zone.SOA_Refresh = 86400
	}
	if zone.SOA_Retry == 0 {
		zone.SOA_Retry = 7200
	}
	if zone.SOA_Expire == 0 {
		zone.SOA_Expire = 3600000
	}
	if zone.SOA_Minimum == 0 {
		zone.SOA_Minimum = 3600
	}
	if zone.TTL == 0 {
		zone.TTL = 3600
	}

	if err = dao.UpdateDNSZone(zone.ID, &zone); err != nil {
		return proto.DNSZone{}, err
	}
	return zone, nil
}
// DeleteDNSZone removes DNS zones and the records they contain.
//
// DelType 0 with a positive ZoneID deletes that zone when userID owns the
// zone's server or is an admin. DelType 1 (admins only) deletes every zone.
// Any other combination is rejected.
func DeleteDNSZone(req *proto.DeleteDNSZoneReq, userID int) error {
	caller := GetUserByIDFromUserCenter(userID)
	switch {
	case req.DelType == 0 && req.ZoneID > 0:
		target, err := dao.FindDNSZoneByID(req.ZoneID)
		if err != nil {
			return err
		}
		parent, err := dao.FindDNSServerByID(target.ServerID)
		if err != nil {
			return err
		}
		if parent.UserID != uint(userID) && caller.Role != "admin" {
			return errors.New("未授权访问DNS区域")
		}
		// Records first, then the zone that owns them.
		if err := dao.DeleteDNSRecordByZoneID(req.ZoneID); err != nil {
			return err
		}
		return dao.DeleteDNSZoneByID(req.ZoneID)
	case req.DelType == 1:
		if caller.Role != "admin" {
			return errors.New("未授权访问仅管理员可删除所有DNS区域")
		}
		all, err := dao.FindAllDNSZone()
		if err != nil {
			return err
		}
		for _, z := range all {
			if err := dao.DeleteDNSRecordByZoneID(z.ID); err != nil {
				return err
			}
			if err := dao.DeleteDNSZoneByID(z.ID); err != nil {
				return err
			}
		}
		return nil
	default:
		return errors.New("无效的删除类型或参数")
	}
}
// ==================== DNSRecord 相关操作 ====================
// CreateDNSRecord adds a record to zone req.ZoneID, after verifying that userID
// owns the zone's server or is an admin, and returns it with its new ID.
// A zero TTL inherits the zone's TTL.
func CreateDNSRecord(req *proto.CreateDNSRecordReq, userID uint) (proto.DNSRecord, error) {
	var none proto.DNSRecord

	parentZone, err := dao.FindDNSZoneByID(req.ZoneID)
	if err != nil {
		return none, err
	}
	parentServer, err := dao.FindDNSServerByID(parentZone.ServerID)
	if err != nil {
		return none, err
	}
	if !(parentServer.UserID == userID || GetUserByIDFromUserCenter(int(userID)).Role == "admin") {
		return none, errors.New("未授权访问DNS区域")
	}

	rec := proto.DNSRecord{
		ZoneID:   req.ZoneID,
		Name:     req.Name,
		Type:     req.Type,
		Value:    req.Value,
		Target:   req.Target,
		TTL:      req.TTL,
		Priority: req.Priority,
		Weight:   req.Weight,
		Port:     req.Port,
	}
	if rec.TTL == 0 {
		rec.TTL = parentZone.TTL
	}

	newID, err := dao.CreateDNSRecord(rec)
	if err != nil {
		return none, err
	}
	rec.ID = newID
	return rec, nil
}
// GetDNSRecordList returns DNS records visible to userID.
//
//   - req.RecordID > 0: that single record, if userID owns the server of the
//     record's zone or is an admin.
//   - req.GetType == 0: the records of zone req.ZoneID when it is > 0 (subject
//     to the same owner-or-admin check), otherwise every record in every zone
//     of every server owned by userID.
//   - req.GetType == 1 (admins only): the records of req.ZoneID, or all records.
//
// Any other GetType is rejected.
func GetDNSRecordList(req *proto.GetDNSRecordReq, userID uint) ([]proto.DNSRecord, error) {
	isAdmin := func() bool {
		return GetUserByIDFromUserCenter(int(userID)).Role == "admin"
	}
	// result normalizes a DAO (slice, error) pair: nil slice whenever err != nil.
	result := func(records []proto.DNSRecord, err error) ([]proto.DNSRecord, error) {
		if err != nil {
			return nil, err
		}
		return records, nil
	}

	if req.RecordID > 0 {
		record, err := dao.FindDNSRecordByID(req.RecordID)
		if err != nil {
			return nil, err
		}
		zone, err := dao.FindDNSZoneByID(record.ZoneID)
		if err != nil {
			return nil, err
		}
		server, err := dao.FindDNSServerByID(zone.ServerID)
		if err != nil {
			return nil, err
		}
		if server.UserID != userID && !isAdmin() {
			return nil, errors.New("未授权访问DNS记录")
		}
		return []proto.DNSRecord{record}, nil
	}

	switch req.GetType {
	case 0:
		if req.ZoneID > 0 {
			// A caller-supplied ZoneID must belong to one of the caller's servers
			// (or the caller must be an admin); otherwise any user could read
			// others' records.
			zone, err := dao.FindDNSZoneByID(req.ZoneID)
			if err != nil {
				return nil, err
			}
			server, err := dao.FindDNSServerByID(zone.ServerID)
			if err != nil {
				return nil, err
			}
			if server.UserID != userID && !isAdmin() {
				return nil, errors.New("未授权访问DNS区域")
			}
			return result(dao.FindDNSRecordByZoneID(req.ZoneID))
		}
		servers, err := dao.FindDNSServerByUserID(userID)
		if err != nil {
			return nil, err
		}
		var records []proto.DNSRecord
		for _, server := range servers {
			zones, err := dao.FindDNSZoneByServerID(server.ID)
			if err != nil {
				return nil, err
			}
			for _, zone := range zones {
				zoneRecords, err := dao.FindDNSRecordByZoneID(zone.ID)
				if err != nil {
					return nil, err
				}
				records = append(records, zoneRecords...)
			}
		}
		return records, nil
	case 1:
		if !isAdmin() {
			return nil, errors.New("未授权访问仅管理员可获取全部DNS记录")
		}
		if req.ZoneID > 0 {
			return result(dao.FindDNSRecordByZoneID(req.ZoneID))
		}
		return result(dao.FindAllDNSRecord())
	default:
		return nil, errors.New("无效的获取类型")
	}
}
// UpdateDNSRecord overwrites record req.RecordID with the values in req, after
// verifying that userID owns the server of the record's zone or is an admin,
// and returns the updated record. The record's ID and ZoneID are not changed.
func UpdateDNSRecord(req *proto.UpdateDNSRecordReq, userID int) (proto.DNSRecord, error) {
	fail := func(err error) (proto.DNSRecord, error) {
		return proto.DNSRecord{}, err
	}

	rec, err := dao.FindDNSRecordByID(req.RecordID)
	if err != nil {
		return fail(err)
	}
	zone, err := dao.FindDNSZoneByID(rec.ZoneID)
	if err != nil {
		return fail(err)
	}
	owner, err := dao.FindDNSServerByID(zone.ServerID)
	if err != nil {
		return fail(err)
	}
	if owner.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return fail(errors.New("未授权访问DNS记录"))
	}

	rec.Name, rec.Type, rec.Value = req.Name, req.Type, req.Value
	rec.TTL = req.TTL
	rec.Priority, rec.Weight, rec.Port = req.Priority, req.Weight, req.Port
	rec.Target = req.Target

	if err := dao.UpdateDNSRecord(rec.ID, &rec); err != nil {
		return fail(err)
	}
	return rec, nil
}
// DeleteDNSRecord removes DNS records.
//
// DelType 0 with a positive RecordID deletes that record when userID owns the
// server of the record's zone or is an admin. DelType 1 (admins only) deletes
// every record. Any other combination is rejected.
func DeleteDNSRecord(req *proto.DeleteDNSRecordReq, userID int) error {
	caller := GetUserByIDFromUserCenter(userID)
	switch {
	case req.DelType == 0 && req.RecordID > 0:
		rec, err := dao.FindDNSRecordByID(req.RecordID)
		if err != nil {
			return err
		}
		zone, err := dao.FindDNSZoneByID(rec.ZoneID)
		if err != nil {
			return err
		}
		owner, err := dao.FindDNSServerByID(zone.ServerID)
		if err != nil {
			return err
		}
		if owner.UserID != uint(userID) && caller.Role != "admin" {
			return errors.New("未授权访问DNS记录")
		}
		return dao.DeleteDNSRecordByID(req.RecordID)
	case req.DelType == 1:
		if caller.Role != "admin" {
			return errors.New("未授权访问仅管理员可删除所有DNS记录")
		}
		all, err := dao.FindAllDNSRecord()
		if err != nil {
			return err
		}
		for i := range all {
			if err := dao.DeleteDNSRecordByID(all[i].ID); err != nil {
				return err
			}
		}
		return nil
	default:
		return errors.New("无效的删除类型或参数")
	}
}
// ==================== DNS 服务运行管理 ====================
// handleDNSRequest answers one DNS query on behalf of this instance.
//
// The zones configured for the server are loaded once per request. For each
// question the most specific matching zone (the one with the most labels) is
// chosen. Answers are then built from that zone's records whose type equals the
// question type and whose owner name equals the query name; names are compared
// case-insensitively, as DNS requires.
//
// Questions that fall outside every zone are forwarded upstream when recursion
// is enabled, and answered REFUSED otherwise; such replies are not marked
// authoritative. Database failures produce SERVFAIL rather than an empty
// success.
func (instance *DNSServiceInstance) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
	m := new(dns.Msg)
	m.SetReply(r)
	m.Authoritative = true
	m.RecursionAvailable = instance.Config.EnableRecursion

	// The zone list depends only on the server, so fetch it once, not per question.
	zones, err := dao.FindDNSZoneByServerID(instance.ServerID)
	if err != nil {
		m.Rcode = dns.RcodeServerFailure
		w.WriteMsg(m)
		return
	}

	for _, q := range r.Question {
		qName := dns.Fqdn(q.Name)
		zone := bestDNSZone(zones, qName)
		if zone == nil {
			// Not our data, so the reply cannot claim authority for it.
			m.Authoritative = false
			if !instance.Config.EnableRecursion {
				m.Rcode = dns.RcodeRefused
				continue
			}
			answers, rcode := forwardDNSQuestion(q, instance.Config.UpstreamDNS)
			m.Answer = append(m.Answer, answers...)
			if rcode != dns.RcodeSuccess {
				m.Rcode = rcode
			}
			continue
		}

		records, err := dao.FindDNSRecordByZoneID(zone.ID)
		if err != nil {
			m.Rcode = dns.RcodeServerFailure
			continue
		}
		for _, record := range records {
			if record.Type != uint(q.Qtype) || !sameDNSName(dnsRecordOwner(record.Name, zone.Domain), qName) {
				continue
			}
			ttl := uint32(zone.TTL)
			if record.TTL > 0 {
				ttl = uint32(record.TTL)
			}
			if rr := buildDNSAnswer(q, record, ttl); rr != nil {
				m.Answer = append(m.Answer, rr)
			}
		}
	}
	// A write error means the client is unreachable; there is nobody to report it to.
	w.WriteMsg(m)
}

// bestDNSZone returns the most specific zone (most labels) whose domain contains
// qName, or nil when no zone does.
func bestDNSZone(zones []proto.DNSZone, qName string) *proto.DNSZone {
	var best *proto.DNSZone
	bestLabels := -1
	for i := range zones {
		apex := dns.Fqdn(zones[i].Domain)
		if !dns.IsSubDomain(apex, qName) {
			continue
		}
		if n := dns.CountLabel(apex); n > bestLabels {
			best, bestLabels = &zones[i], n
		}
	}
	return best
}

// sameDNSName reports whether two fully-qualified names are equal ignoring case
// (dns.CompareDomainName downcases both names before comparing labels).
func sameDNSName(a, b string) bool {
	n := dns.CountLabel(a)
	return n == dns.CountLabel(b) && dns.CompareDomainName(a, b) == n
}

// dnsRecordOwner returns the fully-qualified owner name of a record called name
// inside zoneDomain. "@" (or an empty name) denotes the zone apex.
func dnsRecordOwner(name, zoneDomain string) string {
	if name == "@" || name == "" {
		return dns.Fqdn(zoneDomain)
	}
	return dns.Fqdn(name + "." + zoneDomain)
}

// buildDNSAnswer renders record as a resource record owned by q.Name with the
// given TTL. It returns nil for unsupported types or values that do not parse,
// so a single bad row cannot break the whole reply.
func buildDNSAnswer(q dns.Question, record proto.DNSRecord, ttl uint32) dns.RR {
	var text string
	switch q.Qtype {
	case dns.TypeA:
		text = fmt.Sprintf("%s %d IN A %s", q.Name, ttl, record.Value)
	case dns.TypeAAAA:
		text = fmt.Sprintf("%s %d IN AAAA %s", q.Name, ttl, record.Value)
	case dns.TypeCNAME:
		text = fmt.Sprintf("%s %d IN CNAME %s", q.Name, ttl, record.Value)
	case dns.TypeMX:
		text = fmt.Sprintf("%s %d IN MX %d %s", q.Name, ttl, record.Priority, record.Value)
	case dns.TypeNS:
		text = fmt.Sprintf("%s %d IN NS %s", q.Name, ttl, record.Value)
	case dns.TypeTXT:
		// Escape so quotes/backslashes in the value cannot break the quoted string.
		text = fmt.Sprintf("%s %d IN TXT \"%s\"", q.Name, ttl, escapeDNSText(record.Value))
	case dns.TypeSRV:
		text = fmt.Sprintf("%s %d IN SRV %d %d %d %s", q.Name, ttl, record.Priority, record.Weight, record.Port, record.Target)
	default:
		return nil
	}
	rr, err := dns.NewRR(text)
	if err != nil {
		return nil
	}
	return rr
}

// escapeDNSText backslash-escapes '"' and '\' for use inside a quoted
// zone-file character-string.
func escapeDNSText(s string) string {
	buf := make([]byte, 0, len(s)+2)
	for i := 0; i < len(s); i++ {
		if s[i] == '"' || s[i] == '\\' {
			buf = append(buf, '\\')
		}
		buf = append(buf, s[i])
	}
	return string(buf)
}

// forwardDNSQuestion resolves q recursively against upstream, which may be
// "host" or "host:port"; an empty value means 8.8.8.8:53 and a missing port
// means 53. It returns the upstream answers and rcode, or SERVFAIL when the
// upstream cannot be reached.
func forwardDNSQuestion(q dns.Question, upstream string) ([]dns.RR, int) {
	if upstream == "" {
		upstream = "8.8.8.8:53"
	} else if _, _, err := net.SplitHostPort(upstream); err != nil {
		upstream = net.JoinHostPort(upstream, "53")
	}
	// Forward only this question, with a fresh ID and RD set, rather than the
	// client's whole message once per unmatched question.
	query := new(dns.Msg)
	query.SetQuestion(q.Name, q.Qtype)
	query.Question[0].Qclass = q.Qclass
	resp, _, err := new(dns.Client).Exchange(query, upstream)
	if err != nil || resp == nil {
		return nil, dns.RcodeServerFailure
	}
	return resp.Answer, resp.Rcode
}
// StartDNSServer 启动DNS服务
func StartDNSServer(serverID uint, userID int) (*DNSServiceInstance, error) {
// 检查权限
server, err := dao.FindDNSServerByID(serverID)
if err != nil {
return nil, err
}
if server.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
return nil, errors.New("未授权访问DNS服务器")
}
dnsServiceManager.rwMutex.Lock()
defer dnsServiceManager.rwMutex.Unlock()
// 检查是否已经在运行
if instance, exists := dnsServiceManager.instances[serverID]; exists && instance.Running {
return instance, nil
}
// 创建服务实例
instance := &DNSServiceInstance{
ServerID: serverID,
Config: &server,
Running: false,
}
// 创建DNS服务器
dns.HandleFunc(".", instance.handleDNSRequest)
addr := fmt.Sprintf("%s:%d", server.ListenIP, server.Port)
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return nil, err
}
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
udpConn.Close()
return nil, err
}
tcpListener, err := net.ListenTCP("tcp", tcpAddr)
if err != nil {
udpConn.Close()
return nil, err
}
instance.Server = &dns.Server{
Listener: tcpListener,
PacketConn: udpConn,
}
// 启动服务
go func() {
instance.mutex.Lock()
instance.Running = true
instance.mutex.Unlock()
err := instance.Server.ActivateAndServe()
if err != nil {
instance.mutex.Lock()
instance.Running = false
instance.mutex.Unlock()
}
}()
// 更新数据库状态
server.Status = proto.DNS_STATUS_RUNNING
dao.UpdateDNSServer(serverID, &server)
dnsServiceManager.instances[serverID] = instance
return instance, nil
}
// StopDNSServer 停止DNS服务
func StopDNSServer(serverID uint, userID int) error {
server, err := dao.FindDNSServerByID(serverID)
if err != nil {
return err
}
if server.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
return errors.New("未授权访问DNS服务器")
}
dnsServiceManager.rwMutex.Lock()
defer dnsServiceManager.rwMutex.Unlock()
instance, exists := dnsServiceManager.instances[serverID]
if !exists || !instance.Running {
return nil
}
// 停止服务
err = instance.Server.Shutdown()
if err != nil {
return err
}
instance.mutex.Lock()
instance.Running = false
instance.mutex.Unlock()
// 更新数据库状态
server.Status = proto.DNS_STATUS_STOPPED
dao.UpdateDNSServer(serverID, &server)
delete(dnsServiceManager.instances, serverID)
return nil
}
// RestartDNSServer stops serverID's running instance (if any) and starts it
// again with the configuration currently stored in the database. Both steps
// apply the usual owner-or-admin authorization for userID.
func RestartDNSServer(serverID uint, userID int) (*DNSServiceInstance, error) {
	if stopErr := StopDNSServer(serverID, userID); stopErr != nil {
		return nil, stopErr
	}
	return StartDNSServer(serverID, userID)
}
// GetDNSServerStatus reports the runtime state of serverID after checking that
// userID owns the server or is an admin.
//
// For a registered instance it returns a snapshot copied under the instance
// lock. The caller (presumably a JSON response encoder) can then read Running
// without racing the serve goroutine that writes it. Otherwise it returns a
// synthetic instance whose Running flag mirrors the stored status.
func GetDNSServerStatus(serverID uint, userID int) (*DNSServiceInstance, error) {
	server, err := dao.FindDNSServerByID(serverID)
	if err != nil {
		return nil, err
	}
	if server.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return nil, errors.New("未授权访问DNS服务器")
	}

	dnsServiceManager.rwMutex.RLock()
	instance, exists := dnsServiceManager.instances[serverID]
	dnsServiceManager.rwMutex.RUnlock()
	if exists {
		instance.mutex.Lock()
		defer instance.mutex.Unlock()
		return &DNSServiceInstance{
			ServerID: instance.ServerID,
			Config:   instance.Config,
			Server:   instance.Server,
			Running:  instance.Running,
		}, nil
	}

	// Not registered in this process: fall back to the stored status.
	return &DNSServiceInstance{
		ServerID: serverID,
		Config:   &server,
		Running:  server.Status == proto.DNS_STATUS_RUNNING,
	}, nil
}
// RestartDNSServerIfRunning restarts serverID's instance when it is currently
// running, so that configuration changes (e.g. from UpdateDNSServer) take
// effect; otherwise it does nothing. userID is passed on to RestartDNSServer,
// which applies the usual owner-or-admin authorization.
func RestartDNSServerIfRunning(serverID uint, userID int) error {
	dnsServiceManager.rwMutex.RLock()
	instance, exists := dnsServiceManager.instances[serverID]
	dnsServiceManager.rwMutex.RUnlock()
	if !exists {
		return nil
	}

	// Running is written by the serve goroutine; read it under the instance lock.
	instance.mutex.Lock()
	running := instance.Running
	instance.mutex.Unlock()
	if !running {
		return nil
	}

	// RestartDNSServer reloads the configuration from the database itself.
	_, err := RestartDNSServer(serverID, userID)
	return err
}