787 lines
19 KiB
Go
787 lines
19 KiB
Go
package service
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"net"
|
||
"sync"
|
||
"videoplayer/dao"
|
||
"videoplayer/proto"
|
||
|
||
"github.com/miekg/dns"
|
||
)
|
||
|
||
// DNSServiceInstance is a DNS service instance that is (or was) running in this
// process for one DNSServer configuration.
type DNSServiceInstance struct {
	// ServerID is the ID of the DNSServer this instance serves.
	ServerID uint `json:"server_id"`
	// Config points at the DNSServer configuration the instance was built from.
	Config *proto.DNSServer `json:"config"`
	// Server is the underlying miekg/dns server; it is omitted from JSON output.
	Server *dns.Server `json:"-"`
	// Running reports whether the serve loop is active. Access it while
	// holding mutex.
	Running bool `json:"running"`
	// mutex guards Running. The struct tag has no effect: encoding/json
	// ignores unexported fields.
	mutex sync.Mutex `json:"-"`
}
|
||
|
||
// DNSServiceManager keeps track of the DNS service instances managed by this
// process.
type DNSServiceManager struct {
	// instances maps a DNSServer ID to its in-process service instance.
	instances map[uint]*DNSServiceInstance
	// rwMutex guards instances.
	rwMutex sync.RWMutex
}
|
||
|
||
var dnsServiceManager *DNSServiceManager
|
||
|
||
func init() {
|
||
dnsServiceManager = &DNSServiceManager{
|
||
instances: make(map[uint]*DNSServiceInstance),
|
||
}
|
||
}
|
||
|
||
// ==================== DNSServer 相关操作 ====================
|
||
|
||
func CreateDNSServer(req *proto.CreateDNSServerReq, userID uint) (proto.DNSServer, error) {
|
||
server := proto.DNSServer{
|
||
UserID: userID,
|
||
Name: req.Name,
|
||
Port: req.Port,
|
||
ListenIP: req.ListenIP,
|
||
UpstreamDNS: req.UpstreamDNS,
|
||
EnableRecursion: req.EnableRecursion,
|
||
Status: proto.DNS_STATUS_STOPPED,
|
||
Description: req.Description,
|
||
}
|
||
|
||
if server.Port == 0 {
|
||
server.Port = 53
|
||
}
|
||
if server.ListenIP == "" {
|
||
server.ListenIP = "0.0.0.0"
|
||
}
|
||
|
||
id, err := dao.CreateDNSServer(server)
|
||
if err != nil {
|
||
return proto.DNSServer{}, err
|
||
}
|
||
server.ID = id
|
||
return server, nil
|
||
}
|
||
|
||
// GetDNSServerList returns the DNS servers visible to userID.
//
// When req.ServerID is set, only that server is returned, and the caller must
// own it or be an admin. Otherwise req.GetType selects the scope: 0 lists the
// caller's own servers, and 1 lists every server (admin only). Any other
// GetType is rejected.
func GetDNSServerList(req *proto.GetDNSServerReq, userID uint) ([]proto.DNSServer, error) {
	if req.ServerID > 0 {
		server, err := dao.FindDNSServerByID(req.ServerID)
		if err != nil {
			return nil, err
		}
		isOwner := server.UserID == userID
		if !isOwner && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问DNS服务器")
		}
		return []proto.DNSServer{server}, nil
	}

	var (
		servers []proto.DNSServer
		err     error
	)
	switch req.GetType {
	case 0:
		servers, err = dao.FindDNSServerByUserID(userID)
	case 1:
		if GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问,仅管理员可获取全部DNS服务器")
		}
		servers, err = dao.FindAllDNSServer()
	default:
		return nil, errors.New("无效的获取类型")
	}
	if err != nil {
		return nil, err
	}
	return servers, nil
}
|
||
|
||
// UpdateDNSServer overwrites the editable fields of an existing DNS server with
// the values in req and returns the stored result. Only the owner or an admin
// may update a server.
//
// A zero port and an empty listen IP are normalised to the defaults that
// CreateDNSServer applies (53 and 0.0.0.0). Without this, an update that omits
// the port would leave the server configured for port 0, which the net package
// turns into an OS-chosen ephemeral port when the service is started.
//
// If the server is running in this process, it is restarted in the background
// so the new settings take effect. A restart failure is not reported to the
// caller, because the update itself has already been persisted.
func UpdateDNSServer(req *proto.UpdateDNSServerReq, userID int) (proto.DNSServer, error) {
	server, err := dao.FindDNSServerByID(req.ServerID)
	if err != nil {
		return proto.DNSServer{}, err
	}
	if server.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return proto.DNSServer{}, errors.New("未授权访问DNS服务器")
	}

	server.Name = req.Name
	server.Port = req.Port
	server.ListenIP = req.ListenIP
	server.UpstreamDNS = req.UpstreamDNS
	server.EnableRecursion = req.EnableRecursion
	server.Status = req.Status
	server.Description = req.Description

	// Same defaults as CreateDNSServer.
	if server.Port == 0 {
		server.Port = 53
	}
	if server.ListenIP == "" {
		server.ListenIP = "0.0.0.0"
	}

	if err := dao.UpdateDNSServer(server.ID, &server); err != nil {
		return proto.DNSServer{}, err
	}

	// Apply the new configuration to a running instance without blocking the
	// caller.
	go RestartDNSServerIfRunning(req.ServerID, userID)

	return server, nil
}
|
||
|
||
// DeleteDNSServer deletes DNS servers together with everything that belongs to
// them.
//
// req.DelType selects the scope:
//   - 0 (with req.ServerID > 0): delete that server; the caller must own it or
//     be an admin.
//   - 1: delete every server (admin only).
//
// Any other combination is rejected. Each server is removed by
// deleteDNSServerCascade: its running service is stopped first, then the
// records of its zones, its zones, and finally the server row are deleted.
// Records are deleted before zones (as DeleteDNSZone does), so no records are
// left pointing at a zone that no longer exists.
func DeleteDNSServer(req *proto.DeleteDNSServerReq, userID int) error {
	user := GetUserByIDFromUserCenter(userID)
	switch {
	case req.DelType == 0 && req.ServerID > 0:
		server, err := dao.FindDNSServerByID(req.ServerID)
		if err != nil {
			return err
		}
		if server.UserID != uint(userID) && user.Role != "admin" {
			return errors.New("未授权访问DNS服务器")
		}
		return deleteDNSServerCascade(req.ServerID, userID)
	case req.DelType == 1:
		if user.Role != "admin" {
			return errors.New("未授权访问,仅管理员可删除所有DNS服务器")
		}
		servers, err := dao.FindAllDNSServer()
		if err != nil {
			return err
		}
		for _, server := range servers {
			if err := deleteDNSServerCascade(server.ID, userID); err != nil {
				return err
			}
		}
		return nil
	default:
		return errors.New("无效的删除类型或参数")
	}
}

// deleteDNSServerCascade stops the service for serverID (a no-op when it is not
// running) and then deletes the server's records, zones and row, in that order.
// userID must already be authorized for the server; StopDNSServer repeats the
// owner-or-admin check.
func deleteDNSServerCascade(serverID uint, userID int) error {
	// Stop first so the listening sockets are released and nothing keeps
	// answering queries from zones that are about to disappear.
	if err := StopDNSServer(serverID, userID); err != nil {
		return err
	}
	zones, err := dao.FindDNSZoneByServerID(serverID)
	if err != nil {
		return err
	}
	for _, zone := range zones {
		if err := dao.DeleteDNSRecordByZoneID(zone.ID); err != nil {
			return err
		}
	}
	if err := dao.DeleteDNSZoneByServerID(serverID); err != nil {
		return err
	}
	return dao.DeleteDNSServerByID(serverID)
}
|
||
|
||
// ==================== DNSZone 相关操作 ====================
|
||
|
||
// CreateDNSZone adds a zone to the DNS server req.ServerID and returns it with
// its new ID. The caller must own that server or be an admin.
//
// SOA fields and the TTL that are zero in req are replaced with defaults:
// serial 1, refresh 86400, retry 7200, expire 3600000, minimum 3600, TTL 3600.
func CreateDNSZone(req *proto.CreateDNSZoneReq, userID uint) (proto.DNSZone, error) {
	server, err := dao.FindDNSServerByID(req.ServerID)
	if err != nil {
		return proto.DNSZone{}, err
	}
	callerOwnsServer := server.UserID == userID
	if !callerOwnsServer && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
		return proto.DNSZone{}, errors.New("未授权访问DNS服务器")
	}

	zone := proto.DNSZone{
		ServerID:    req.ServerID,
		Domain:      req.Domain,
		SOA_MName:   req.SOA_MName,
		SOA_RName:   req.SOA_RName,
		SOA_Serial:  req.SOA_Serial,
		SOA_Refresh: req.SOA_Refresh,
		SOA_Retry:   req.SOA_Retry,
		SOA_Expire:  req.SOA_Expire,
		SOA_Minimum: req.SOA_Minimum,
		TTL:         req.TTL,
		Description: req.Description,
	}

	// Zero-valued numeric fields get defaults.
	if zone.TTL == 0 {
		zone.TTL = 3600
	}
	if zone.SOA_Minimum == 0 {
		zone.SOA_Minimum = 3600
	}
	if zone.SOA_Expire == 0 {
		zone.SOA_Expire = 3600000
	}
	if zone.SOA_Retry == 0 {
		zone.SOA_Retry = 7200
	}
	if zone.SOA_Refresh == 0 {
		zone.SOA_Refresh = 86400
	}
	if zone.SOA_Serial == 0 {
		zone.SOA_Serial = 1
	}

	newID, err := dao.CreateDNSZone(zone)
	if err != nil {
		return proto.DNSZone{}, err
	}
	zone.ID = newID
	return zone, nil
}
|
||
|
||
// GetDNSZoneList returns the DNS zones visible to userID.
//
// When req.ZoneID is set, only that zone is returned, and the caller must own
// the zone's server or be an admin. Otherwise req.GetType selects the scope:
//   - 0: zones of req.ServerID when it is set (the caller must own that server
//     or be an admin), else the zones of all servers the caller owns.
//   - 1: zones of req.ServerID when it is set, else every zone (admin only).
//
// Any other GetType is rejected. The ownership check on req.ServerID in the
// GetType 0 branch keeps a user from listing another user's zones just by
// passing that user's server ID.
func GetDNSZoneList(req *proto.GetDNSZoneReq, userID uint) ([]proto.DNSZone, error) {
	if req.ZoneID > 0 {
		zone, err := dao.FindDNSZoneByID(req.ZoneID)
		if err != nil {
			return nil, err
		}
		server, err := dao.FindDNSServerByID(zone.ServerID)
		if err != nil {
			return nil, err
		}
		if server.UserID != userID && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问DNS区域")
		}
		return []proto.DNSZone{zone}, nil
	}

	var zones []proto.DNSZone
	var err error
	switch req.GetType {
	case 0:
		if req.ServerID > 0 {
			server, findErr := dao.FindDNSServerByID(req.ServerID)
			if findErr != nil {
				return nil, findErr
			}
			if server.UserID != userID && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
				return nil, errors.New("未授权访问DNS服务器")
			}
			zones, err = dao.FindDNSZoneByServerID(req.ServerID)
		} else {
			servers, findErr := dao.FindDNSServerByUserID(userID)
			if findErr != nil {
				return nil, findErr
			}
			for _, server := range servers {
				serverZones, zoneErr := dao.FindDNSZoneByServerID(server.ID)
				if zoneErr != nil {
					return nil, zoneErr
				}
				zones = append(zones, serverZones...)
			}
		}
	case 1:
		if GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问,仅管理员可获取全部DNS区域")
		}
		if req.ServerID > 0 {
			zones, err = dao.FindDNSZoneByServerID(req.ServerID)
		} else {
			zones, err = dao.FindAllDNSZone()
		}
	default:
		return nil, errors.New("无效的获取类型")
	}

	if err != nil {
		return nil, err
	}
	return zones, nil
}
|
||
|
||
// UpdateDNSZone overwrites the editable fields of an existing zone with the
// values in req and returns the stored result. The caller must own the zone's
// server or be an admin.
//
// SOA fields and the TTL that are zero in req receive the defaults that
// CreateDNSZone applies (serial 1, refresh 86400, retry 7200, expire 3600000,
// minimum 3600, TTL 3600). Without this, an update that omits them would store
// zeros. A zero zone TTL is also served as the TTL of every record in the zone
// that has no TTL of its own.
func UpdateDNSZone(req *proto.UpdateDNSZoneReq, userID int) (proto.DNSZone, error) {
	zone, err := dao.FindDNSZoneByID(req.ZoneID)
	if err != nil {
		return proto.DNSZone{}, err
	}
	server, err := dao.FindDNSServerByID(zone.ServerID)
	if err != nil {
		return proto.DNSZone{}, err
	}
	if server.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return proto.DNSZone{}, errors.New("未授权访问DNS区域")
	}

	zone.Domain = req.Domain
	zone.SOA_MName = req.SOA_MName
	zone.SOA_RName = req.SOA_RName
	zone.SOA_Serial = req.SOA_Serial
	zone.SOA_Refresh = req.SOA_Refresh
	zone.SOA_Retry = req.SOA_Retry
	zone.SOA_Expire = req.SOA_Expire
	zone.SOA_Minimum = req.SOA_Minimum
	zone.TTL = req.TTL
	zone.Description = req.Description

	// Same defaults as CreateDNSZone.
	if zone.SOA_Serial == 0 {
		zone.SOA_Serial = 1
	}
	if zone.SOA_Refresh == 0 {
		zone.SOA_Refresh = 86400
	}
	if zone.SOA_Retry == 0 {
		zone.SOA_Retry = 7200
	}
	if zone.SOA_Expire == 0 {
		zone.SOA_Expire = 3600000
	}
	if zone.SOA_Minimum == 0 {
		zone.SOA_Minimum = 3600
	}
	if zone.TTL == 0 {
		zone.TTL = 3600
	}

	if err := dao.UpdateDNSZone(zone.ID, &zone); err != nil {
		return proto.DNSZone{}, err
	}
	return zone, nil
}
|
||
|
||
// DeleteDNSZone deletes DNS zones together with their records.
//
// req.DelType selects the scope:
//   - 0 (with req.ZoneID > 0): delete that zone; the caller must own the zone's
//     server or be an admin.
//   - 1: delete every zone (admin only).
//
// Any other combination is rejected. A zone's records are always deleted
// before the zone itself.
func DeleteDNSZone(req *proto.DeleteDNSZoneReq, userID int) error {
	user := GetUserByIDFromUserCenter(userID)
	isAdmin := user.Role == "admin"

	switch {
	case req.DelType == 0 && req.ZoneID > 0:
		zone, err := dao.FindDNSZoneByID(req.ZoneID)
		if err != nil {
			return err
		}
		server, err := dao.FindDNSServerByID(zone.ServerID)
		if err != nil {
			return err
		}
		if server.UserID != uint(userID) && !isAdmin {
			return errors.New("未授权访问DNS区域")
		}
		if err := dao.DeleteDNSRecordByZoneID(req.ZoneID); err != nil {
			return err
		}
		return dao.DeleteDNSZoneByID(req.ZoneID)
	case req.DelType == 1:
		if !isAdmin {
			return errors.New("未授权访问,仅管理员可删除所有DNS区域")
		}
		zones, err := dao.FindAllDNSZone()
		if err != nil {
			return err
		}
		for _, zone := range zones {
			if err := dao.DeleteDNSRecordByZoneID(zone.ID); err != nil {
				return err
			}
			if err := dao.DeleteDNSZoneByID(zone.ID); err != nil {
				return err
			}
		}
		return nil
	default:
		return errors.New("无效的删除类型或参数")
	}
}
|
||
|
||
// ==================== DNSRecord 相关操作 ====================
|
||
|
||
// CreateDNSRecord adds a record to the zone req.ZoneID and returns it with its
// new ID. The caller must own the server hosting that zone, or be an admin. A
// zero TTL in req is replaced by the zone's TTL.
func CreateDNSRecord(req *proto.CreateDNSRecordReq, userID uint) (proto.DNSRecord, error) {
	zone, err := dao.FindDNSZoneByID(req.ZoneID)
	if err != nil {
		return proto.DNSRecord{}, err
	}
	server, err := dao.FindDNSServerByID(zone.ServerID)
	if err != nil {
		return proto.DNSRecord{}, err
	}
	callerOwnsServer := server.UserID == userID
	if !callerOwnsServer && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
		return proto.DNSRecord{}, errors.New("未授权访问DNS区域")
	}

	ttl := req.TTL
	if ttl == 0 {
		ttl = zone.TTL
	}

	record := proto.DNSRecord{
		ZoneID:   req.ZoneID,
		Name:     req.Name,
		Type:     req.Type,
		Value:    req.Value,
		TTL:      ttl,
		Priority: req.Priority,
		Weight:   req.Weight,
		Port:     req.Port,
		Target:   req.Target,
	}

	newID, err := dao.CreateDNSRecord(record)
	if err != nil {
		return proto.DNSRecord{}, err
	}
	record.ID = newID
	return record, nil
}
|
||
|
||
// GetDNSRecordList returns the DNS records visible to userID.
//
// When req.RecordID is set, only that record is returned, and the caller must
// own the server hosting the record's zone or be an admin. Otherwise
// req.GetType selects the scope:
//   - 0: records of zone req.ZoneID when it is set (the caller must own the
//     zone's server or be an admin), else the records of every zone on every
//     server the caller owns.
//   - 1: records of req.ZoneID when it is set, else every record (admin only).
//
// Any other GetType is rejected. The ownership check on req.ZoneID in the
// GetType 0 branch keeps a user from reading another user's records just by
// passing that user's zone ID.
func GetDNSRecordList(req *proto.GetDNSRecordReq, userID uint) ([]proto.DNSRecord, error) {
	if req.RecordID > 0 {
		record, err := dao.FindDNSRecordByID(req.RecordID)
		if err != nil {
			return nil, err
		}
		zone, err := dao.FindDNSZoneByID(record.ZoneID)
		if err != nil {
			return nil, err
		}
		server, err := dao.FindDNSServerByID(zone.ServerID)
		if err != nil {
			return nil, err
		}
		if server.UserID != userID && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问DNS记录")
		}
		return []proto.DNSRecord{record}, nil
	}

	var records []proto.DNSRecord
	var err error
	switch req.GetType {
	case 0:
		if req.ZoneID > 0 {
			zone, findErr := dao.FindDNSZoneByID(req.ZoneID)
			if findErr != nil {
				return nil, findErr
			}
			server, findErr := dao.FindDNSServerByID(zone.ServerID)
			if findErr != nil {
				return nil, findErr
			}
			if server.UserID != userID && GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
				return nil, errors.New("未授权访问DNS区域")
			}
			records, err = dao.FindDNSRecordByZoneID(req.ZoneID)
		} else {
			servers, findErr := dao.FindDNSServerByUserID(userID)
			if findErr != nil {
				return nil, findErr
			}
			for _, server := range servers {
				zones, zoneErr := dao.FindDNSZoneByServerID(server.ID)
				if zoneErr != nil {
					return nil, zoneErr
				}
				for _, zone := range zones {
					zoneRecords, recErr := dao.FindDNSRecordByZoneID(zone.ID)
					if recErr != nil {
						return nil, recErr
					}
					records = append(records, zoneRecords...)
				}
			}
		}
	case 1:
		if GetUserByIDFromUserCenter(int(userID)).Role != "admin" {
			return nil, errors.New("未授权访问,仅管理员可获取全部DNS记录")
		}
		if req.ZoneID > 0 {
			records, err = dao.FindDNSRecordByZoneID(req.ZoneID)
		} else {
			records, err = dao.FindAllDNSRecord()
		}
	default:
		return nil, errors.New("无效的获取类型")
	}

	if err != nil {
		return nil, err
	}
	return records, nil
}
|
||
|
||
// UpdateDNSRecord replaces the editable fields of an existing DNS record with
// the values in req and returns the stored result. The caller must own the
// server hosting the record's zone, or be an admin. The record keeps its ID
// and zone.
func UpdateDNSRecord(req *proto.UpdateDNSRecordReq, userID int) (proto.DNSRecord, error) {
	record, err := dao.FindDNSRecordByID(req.RecordID)
	if err != nil {
		return proto.DNSRecord{}, err
	}
	zone, err := dao.FindDNSZoneByID(record.ZoneID)
	if err != nil {
		return proto.DNSRecord{}, err
	}
	server, err := dao.FindDNSServerByID(zone.ServerID)
	if err != nil {
		return proto.DNSRecord{}, err
	}
	ownedByCaller := server.UserID == uint(userID)
	if !ownedByCaller && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return proto.DNSRecord{}, errors.New("未授权访问DNS记录")
	}

	record.Name, record.Type, record.Value = req.Name, req.Type, req.Value
	record.TTL = req.TTL
	record.Priority, record.Weight, record.Port = req.Priority, req.Weight, req.Port
	record.Target = req.Target

	if err = dao.UpdateDNSRecord(record.ID, &record); err != nil {
		return proto.DNSRecord{}, err
	}
	return record, nil
}
|
||
|
||
// DeleteDNSRecord deletes DNS records.
//
// req.DelType selects the scope:
//   - 0 (with req.RecordID > 0): delete that record; the caller must own the
//     server hosting the record's zone, or be an admin.
//   - 1: delete every record (admin only).
//
// Any other combination is rejected.
func DeleteDNSRecord(req *proto.DeleteDNSRecordReq, userID int) error {
	user := GetUserByIDFromUserCenter(userID)
	isAdmin := user.Role == "admin"

	switch {
	case req.DelType == 0 && req.RecordID > 0:
		record, err := dao.FindDNSRecordByID(req.RecordID)
		if err != nil {
			return err
		}
		zone, err := dao.FindDNSZoneByID(record.ZoneID)
		if err != nil {
			return err
		}
		server, err := dao.FindDNSServerByID(zone.ServerID)
		if err != nil {
			return err
		}
		if server.UserID != uint(userID) && !isAdmin {
			return errors.New("未授权访问DNS记录")
		}
		return dao.DeleteDNSRecordByID(req.RecordID)
	case req.DelType == 1:
		if !isAdmin {
			return errors.New("未授权访问,仅管理员可删除所有DNS记录")
		}
		records, err := dao.FindAllDNSRecord()
		if err != nil {
			return err
		}
		for _, record := range records {
			if err := dao.DeleteDNSRecordByID(record.ID); err != nil {
				return err
			}
		}
		return nil
	default:
		return errors.New("无效的删除类型或参数")
	}
}
|
||
|
||
// ==================== DNS 服务运行管理 ====================
|
||
|
||
// handleDNSRequest answers the DNS query r for this instance.
//
// The zones configured for instance.ServerID are loaded once per request. For
// each question:
//   - The most specific zone containing the queried name (the one with the
//     most labels) is selected.
//   - That zone's records are returned when their owner name equals the
//     question name and their type equals the question type. The AA bit is
//     set on the reply.
//   - Owner names are compared case-insensitively, because DNS names are
//     case-insensitive.
//   - A record named "@" (or left empty) belongs to the zone apex.
//   - A record without its own TTL uses the zone TTL.
//
// A question outside every zone is forwarded upstream when the server has
// recursion enabled; otherwise it is answered REFUSED. A database failure is
// answered SERVFAIL.
func (instance *DNSServiceInstance) handleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
	m := new(dns.Msg)
	m.SetReply(r)
	m.RecursionAvailable = instance.Config.EnableRecursion

	// The zone list does not depend on the question, so load it once.
	zones, err := dao.FindDNSZoneByServerID(instance.ServerID)
	if err != nil {
		m.Rcode = dns.RcodeServerFailure
		_ = w.WriteMsg(m) // nothing more can be done if the client is unreachable
		return
	}

	for _, q := range r.Question {
		qName := dns.Fqdn(q.Name)
		zone := longestMatchingDNSZone(zones, qName)
		if zone == nil {
			forwardDNSQuery(instance.Config, r, m)
			continue
		}
		m.Authoritative = true

		records, err := dao.FindDNSRecordByZoneID(zone.ID)
		if err != nil {
			m.Rcode = dns.RcodeServerFailure
			continue
		}

		apex := dns.Fqdn(zone.Domain)
		for _, record := range records {
			if record.Type != uint(q.Qtype) {
				continue
			}
			owner := apex
			if record.Name != "@" && record.Name != "" {
				owner = dns.Fqdn(record.Name + "." + zone.Domain)
			}
			if !dnsNamesEqual(owner, qName) {
				continue
			}
			ttl := uint32(zone.TTL)
			if record.TTL > 0 {
				ttl = uint32(record.TTL)
			}
			if rr := dnsAnswerFromRecord(record, q.Name, q.Qtype, ttl); rr != nil {
				m.Answer = append(m.Answer, rr)
			}
		}
	}

	// A failed write means the client is gone; there is nobody to report it to.
	_ = w.WriteMsg(m)
}

// longestMatchingDNSZone returns the zone in zones that contains qName and has
// the most labels, or nil when no zone contains qName.
func longestMatchingDNSZone(zones []proto.DNSZone, qName string) *proto.DNSZone {
	var best *proto.DNSZone
	bestLabels := -1
	for i := range zones {
		origin := dns.Fqdn(zones[i].Domain)
		if !dns.IsSubDomain(origin, qName) {
			continue
		}
		if n := dns.CountLabel(origin); n > bestLabels {
			best, bestLabels = &zones[i], n
		}
	}
	return best
}

// dnsNamesEqual reports whether a and b are the same domain name. The
// comparison is case-insensitive, because dns.IsSubDomain compares downcased
// labels.
func dnsNamesEqual(a, b string) bool {
	return dns.CountLabel(a) == dns.CountLabel(b) && dns.IsSubDomain(a, b)
}

// dnsAnswerFromRecord renders record as a resource record of type qtype owned
// by owner. It returns nil when qtype is not supported or the stored data does
// not parse.
//
// Domain-name data (CNAME, MX and NS values, and the SRV target) is made fully
// qualified. dns.NewRR has no origin to resolve relative names against, so a
// stored value such as "mail.example.com" would otherwise never produce an
// answer. NOTE(review): this treats stored names as absolute, not as relative
// to the zone; confirm that matches how records are entered.
func dnsAnswerFromRecord(record proto.DNSRecord, owner string, qtype uint16, ttl uint32) dns.RR {
	var zoneLine string
	switch qtype {
	case dns.TypeA:
		zoneLine = fmt.Sprintf("%s %d IN A %s", owner, ttl, record.Value)
	case dns.TypeAAAA:
		zoneLine = fmt.Sprintf("%s %d IN AAAA %s", owner, ttl, record.Value)
	case dns.TypeCNAME:
		zoneLine = fmt.Sprintf("%s %d IN CNAME %s", owner, ttl, dns.Fqdn(record.Value))
	case dns.TypeMX:
		zoneLine = fmt.Sprintf("%s %d IN MX %d %s", owner, ttl, record.Priority, dns.Fqdn(record.Value))
	case dns.TypeNS:
		zoneLine = fmt.Sprintf("%s %d IN NS %s", owner, ttl, dns.Fqdn(record.Value))
	case dns.TypeTXT:
		zoneLine = fmt.Sprintf("%s %d IN TXT \"%s\"", owner, ttl, record.Value)
	case dns.TypeSRV:
		zoneLine = fmt.Sprintf("%s %d IN SRV %d %d %d %s", owner, ttl, record.Priority, record.Weight, record.Port, dns.Fqdn(record.Target))
	default:
		return nil
	}
	rr, err := dns.NewRR(zoneLine)
	if err != nil {
		// Malformed stored data yields no answer rather than a broken reply.
		return nil
	}
	return rr
}

// forwardDNSQuery handles a question that no local zone covers.
//
// With recursion enabled in cfg, r is relayed to the upstream resolver, and
// the upstream answer section, authority section and response code are copied
// into m. A failed exchange yields SERVFAIL. With recursion disabled, the query
// is REFUSED.
func forwardDNSQuery(cfg *proto.DNSServer, r, m *dns.Msg) {
	if !cfg.EnableRecursion {
		m.Rcode = dns.RcodeRefused
		return
	}
	resp, _, err := new(dns.Client).Exchange(r, dnsUpstreamAddress(cfg.UpstreamDNS))
	if err != nil || resp == nil {
		m.Rcode = dns.RcodeServerFailure
		return
	}
	m.Answer = append(m.Answer, resp.Answer...)
	m.Ns = append(m.Ns, resp.Ns...)
	m.Rcode = resp.Rcode
}

// dnsUpstreamAddress turns the configured upstream into a host:port address.
// An empty value falls back to 8.8.8.8:53. A value without a port gets port 53,
// so "1.1.1.1" and "[2001:db8::1]" work as well as "1.1.1.1:53".
func dnsUpstreamAddress(upstream string) string {
	if upstream == "" {
		return "8.8.8.8:53"
	}
	if _, _, err := net.SplitHostPort(upstream); err == nil {
		return upstream
	}
	// Strip brackets from an IPv6 literal given without a port; JoinHostPort
	// adds them back.
	if n := len(upstream); n > 2 && upstream[0] == '[' && upstream[n-1] == ']' {
		upstream = upstream[1 : n-1]
	}
	return net.JoinHostPort(upstream, "53")
}
|
||
|
||
// dnsTCPServers holds, per DNS server ID, the dns.Server that serves TCP for a
// running instance. dns.Server.ActivateAndServe serves only PacketConn when
// both PacketConn and Listener are set. UDP (DNSServiceInstance.Server) and TCP
// therefore each get their own dns.Server, and the TCP one is kept here so
// StopDNSServer can shut it down. Guarded by dnsServiceManager.rwMutex.
var dnsTCPServers = make(map[uint]*dns.Server)

// StartDNSServer starts serving DNS over UDP and TCP for the server identified
// by serverID, registers the instance, and marks the server running in the
// database. The caller must own the server or be an admin. An instance that is
// already running is returned unchanged.
//
// Both dns.Server values carry instance.handleDNSRequest as their own Handler.
// It is not registered on the package-global dns.DefaultServeMux, where every
// start would replace the handler of every other running server.
func StartDNSServer(serverID uint, userID int) (*DNSServiceInstance, error) {
	server, err := dao.FindDNSServerByID(serverID)
	if err != nil {
		return nil, err
	}
	if server.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return nil, errors.New("未授权访问DNS服务器")
	}

	dnsServiceManager.rwMutex.Lock()
	defer dnsServiceManager.rwMutex.Unlock()

	if existing, ok := dnsServiceManager.instances[serverID]; ok {
		if dnsInstanceRunning(existing) {
			return existing, nil
		}
		// The previous serve loop ended on its own. Unregister it so its
		// address can be bound again. This cannot fail, because nothing is
		// running that needs a shutdown.
		_ = releaseDNSInstanceLocked(serverID)
	}

	// JoinHostPort brackets IPv6 listen addresses such as "::". A "%s:%d"
	// format would produce the unparseable ":::53".
	addr := net.JoinHostPort(server.ListenIP, fmt.Sprint(server.Port))
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return nil, err
	}
	udpConn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return nil, err
	}
	tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		udpConn.Close()
		return nil, err
	}
	tcpListener, err := net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		udpConn.Close()
		return nil, err
	}

	server.Status = proto.DNS_STATUS_RUNNING
	instance := &DNSServiceInstance{ServerID: serverID, Config: &server}
	handler := dns.HandlerFunc(instance.handleDNSRequest)
	instance.Server = &dns.Server{PacketConn: udpConn, Handler: handler}
	tcpServer := &dns.Server{Listener: tcpListener, Handler: handler}

	udpDone, err := activateDNSServer(instance.Server)
	if err != nil {
		udpConn.Close()
		tcpListener.Close()
		return nil, err
	}
	if _, err := activateDNSServer(tcpServer); err != nil {
		_ = instance.Server.Shutdown()
		tcpListener.Close()
		return nil, err
	}
	// Set synchronously, so a Stop issued right after Start returns sees it.
	setDNSInstanceRunning(instance, true)

	// When the UDP serve loop ends (via Shutdown or a fatal error), mark the
	// instance stopped and take the TCP side down with it. The TCP Shutdown
	// error is ignored, because StopDNSServer may already have stopped it.
	go func() {
		<-udpDone
		setDNSInstanceRunning(instance, false)
		_ = tcpServer.Shutdown()
	}()

	// Best effort: the service is already serving, so a failed status write is
	// not treated as a start failure.
	_ = dao.UpdateDNSServer(serverID, &server)

	dnsServiceManager.instances[serverID] = instance
	dnsTCPServers[serverID] = tcpServer
	return instance, nil
}

// StopDNSServer stops the in-process DNS service for serverID and marks the
// server stopped in the database. The caller must own the server or be an
// admin. It returns nil when no instance is registered for serverID. An
// instance whose serve loop already ended on its own is cleaned up and marked
// stopped as well.
func StopDNSServer(serverID uint, userID int) error {
	server, err := dao.FindDNSServerByID(serverID)
	if err != nil {
		return err
	}
	if server.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return errors.New("未授权访问DNS服务器")
	}

	dnsServiceManager.rwMutex.Lock()
	defer dnsServiceManager.rwMutex.Unlock()

	if _, exists := dnsServiceManager.instances[serverID]; !exists {
		return nil
	}
	if err := releaseDNSInstanceLocked(serverID); err != nil {
		return err
	}

	// Best effort: the service is already down at this point.
	server.Status = proto.DNS_STATUS_STOPPED
	_ = dao.UpdateDNSServer(serverID, &server)
	return nil
}

// releaseDNSInstanceLocked shuts down and unregisters everything registered
// for serverID: first the TCP server, then the UDP server (only while the
// instance is still marked running), then the instance entry. Every shutdown is
// synchronous, so the addresses can be bound again as soon as it returns. The
// caller must hold dnsServiceManager.rwMutex for writing.
func releaseDNSInstanceLocked(serverID uint) error {
	if tcpServer, ok := dnsTCPServers[serverID]; ok {
		// Ignored: the TCP server is already stopped if the UDP loop ended.
		_ = tcpServer.Shutdown()
		delete(dnsTCPServers, serverID)
	}
	instance, ok := dnsServiceManager.instances[serverID]
	if !ok {
		return nil
	}
	if dnsInstanceRunning(instance) && instance.Server != nil {
		if err := instance.Server.Shutdown(); err != nil {
			return err
		}
	}
	setDNSInstanceRunning(instance, false)
	delete(dnsServiceManager.instances, serverID)
	return nil
}

// activateDNSServer runs srv.ActivateAndServe on a new goroutine. It blocks
// until the server reports that it is serving (via NotifyStartedFunc) or the
// call fails first. On success, done receives ActivateAndServe's result once
// the serve loop exits. done is buffered, so the goroutine never leaks when
// nobody reads it.
func activateDNSServer(srv *dns.Server) (done <-chan error, err error) {
	started := make(chan struct{})
	srv.NotifyStartedFunc = func() { close(started) }
	result := make(chan error, 1)
	go func() { result <- srv.ActivateAndServe() }()
	select {
	case <-started:
		return result, nil
	case err := <-result:
		if err == nil {
			err = errors.New("dns server exited before it started serving")
		}
		return nil, err
	}
}

// dnsInstanceRunning reads instance.Running under instance.mutex.
func dnsInstanceRunning(instance *DNSServiceInstance) bool {
	instance.mutex.Lock()
	defer instance.mutex.Unlock()
	return instance.Running
}

// setDNSInstanceRunning writes instance.Running under instance.mutex.
func setDNSInstanceRunning(instance *DNSServiceInstance, running bool) {
	instance.mutex.Lock()
	instance.Running = running
	instance.mutex.Unlock()
}
|
||
|
||
// RestartDNSServer 重启DNS服务
|
||
func RestartDNSServer(serverID uint, userID int) (*DNSServiceInstance, error) {
|
||
err := StopDNSServer(serverID, userID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return StartDNSServer(serverID, userID)
|
||
}
|
||
|
||
// GetDNSServerStatus reports the runtime status of a DNS server. The caller
// must own the server or be an admin.
//
// For a server registered in this process, the result is a snapshot of its
// instance, and Running is read under the instance's mutex. Callers can
// therefore serialise the result without racing the goroutine that flips
// Running when the serve loop ends. For any other server, Running is derived
// from the status stored in the database.
func GetDNSServerStatus(serverID uint, userID int) (*DNSServiceInstance, error) {
	server, err := dao.FindDNSServerByID(serverID)
	if err != nil {
		return nil, err
	}
	if server.UserID != uint(userID) && GetUserByIDFromUserCenter(userID).Role != "admin" {
		return nil, errors.New("未授权访问DNS服务器")
	}

	dnsServiceManager.rwMutex.RLock()
	instance, exists := dnsServiceManager.instances[serverID]
	dnsServiceManager.rwMutex.RUnlock()

	if exists {
		instance.mutex.Lock()
		defer instance.mutex.Unlock()
		return &DNSServiceInstance{
			ServerID: instance.ServerID,
			Config:   instance.Config,
			Server:   instance.Server,
			Running:  instance.Running,
		}, nil
	}

	// Not managed by this process: fall back to the persisted status.
	return &DNSServiceInstance{
		ServerID: serverID,
		Config:   &server,
		Running:  server.Status == proto.DNS_STATUS_RUNNING,
	}, nil
}
|
||
|
||
// RestartDNSServerIfRunning 如果服务正在运行则重启(配置更新时调用)
|
||
func RestartDNSServerIfRunning(serverID uint, userID int) error {
|
||
dnsServiceManager.rwMutex.RLock()
|
||
instance, exists := dnsServiceManager.instances[serverID]
|
||
dnsServiceManager.rwMutex.RUnlock()
|
||
|
||
if exists && instance.Running {
|
||
// 获取最新配置
|
||
server, err := dao.FindDNSServerByID(serverID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
// 重启服务(管理员权限)
|
||
_, err = RestartDNSServer(server.ID, userID) // 假设管理员ID为1
|
||
return err
|
||
}
|
||
return nil
|
||
}
|