Clean code dependencies on github.com/miekg/dns (#2099)

This commit is contained in:
yuhan6665 2023-05-20 23:40:56 -04:00 committed by GitHub
parent 51b2922427
commit c80646a045
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3,8 +3,8 @@ package dns
import ( import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"github.com/miekg/dns"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/dice"
) )
@ -36,7 +36,7 @@ func NewDNS(ctx context.Context, config interface{}) (interface{}, error) {
buf := make([]byte, 0x100) buf := make([]byte, 0x100)
off1, err := dns.PackDomainName(dns.Fqdn(config.(*Config).Domain), buf, 0, nil, false) off1, err := packDomainName(config.(*Config).Domain + ".", buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -51,6 +51,73 @@ func NewDNS(ctx context.Context, config interface{}) (interface{}, error) {
}, nil }, nil
} }
// copied from github.com/miekg/dns
func packDomainName(s string, msg []byte) (off1 int, err error) {
off := 0
ls := len(s)
// Each dot ends a segment of the name.
// We trade each dot byte for a length byte.
// Except for escaped dots (\.), which are normal dots.
// There is also a trailing zero.
// Emit sequence of counted strings, chopping at dots.
var (
begin int
bs []byte
)
for i := 0; i < ls; i++ {
var c byte
if bs == nil {
c = s[i]
} else {
c = bs[i]
}
switch c {
case '\\':
if off+1 > len(msg) {
return len(msg), errors.New("buffer size too small")
}
if bs == nil {
bs = []byte(s)
}
copy(bs[i:ls-1], bs[i+1:])
ls--
case '.':
labelLen := i - begin
if labelLen >= 1<<6 { // top two bits of length must be clear
return len(msg), errors.New("bad rdata")
}
// off can already (we're in a loop) be bigger than len(msg)
// this happens when a name isn't fully qualified
if off+1+labelLen > len(msg) {
return len(msg), errors.New("buffer size too small")
}
// The following is covered by the length check above.
msg[off] = byte(labelLen)
if bs == nil {
copy(msg[off+1:], s[begin:i])
} else {
copy(msg[off+1:], bs[begin:i])
}
off += 1 + labelLen
begin = i + 1
default:
}
}
if off < len(msg) {
msg[off] = 0
}
return off + 1, nil
}
func init() { func init() {
common.Must(common.RegisterConfig((*Config)(nil), NewDNS)) common.Must(common.RegisterConfig((*Config)(nil), NewDNS))
} }