mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-12-22 19:33:32 +02:00
Fix OCSP Stapling (#172)
Co-authored-by: RPRX <63339210+rprx@users.noreply.github.com>
This commit is contained in:
parent
4cd343f2d5
commit
c13b8ec9bb
2 changed files with 76 additions and 14 deletions
|
@ -42,8 +42,8 @@ func (c *Config) loadSelfCertPool() (*x509.CertPool, error) {
|
|||
}
|
||||
|
||||
// BuildCertificates builds a list of TLS certificates from proto definition.
|
||||
func (c *Config) BuildCertificates() []tls.Certificate {
|
||||
certs := make([]tls.Certificate, 0, len(c.Certificate))
|
||||
func (c *Config) BuildCertificates() []*tls.Certificate {
|
||||
certs := make([]*tls.Certificate, 0, len(c.Certificate))
|
||||
for _, entry := range c.Certificate {
|
||||
if entry.Usage != Certificate_ENCIPHERMENT {
|
||||
continue
|
||||
|
@ -53,7 +53,12 @@ func (c *Config) BuildCertificates() []tls.Certificate {
|
|||
newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog()
|
||||
continue
|
||||
}
|
||||
certs = append(certs, keyPair)
|
||||
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
|
||||
if err != nil {
|
||||
newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog()
|
||||
continue
|
||||
}
|
||||
certs = append(certs, &keyPair)
|
||||
if entry.OcspStapling != 0 {
|
||||
go func(cert *tls.Certificate) {
|
||||
t := time.NewTicker(time.Duration(entry.OcspStapling) * time.Second)
|
||||
|
@ -65,7 +70,7 @@ func (c *Config) BuildCertificates() []tls.Certificate {
|
|||
}
|
||||
<-t.C
|
||||
}
|
||||
}(&certs[len(certs)-1])
|
||||
}(certs[len(certs)-1])
|
||||
}
|
||||
}
|
||||
return certs
|
||||
|
@ -169,6 +174,33 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
|
|||
}
|
||||
}
|
||||
|
||||
func getNewGetCertficateFunc(certs []*tls.Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if len(certs) == 0 {
|
||||
return nil, newError("empty certs")
|
||||
}
|
||||
sni := strings.ToLower(hello.ServerName)
|
||||
if len(certs) == 1 || sni == "" {
|
||||
return certs[0], nil
|
||||
}
|
||||
gsni := "*"
|
||||
if index := strings.IndexByte(sni, '.'); index != -1 {
|
||||
gsni += sni[index:]
|
||||
}
|
||||
for _, keyPair := range certs {
|
||||
if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
|
||||
return keyPair, nil
|
||||
}
|
||||
for _, name := range keyPair.Leaf.DNSNames {
|
||||
if name == sni || name == gsni {
|
||||
return keyPair, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return certs[0], nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) IsExperiment8357() bool {
|
||||
return strings.HasPrefix(c.ServerName, exp8357)
|
||||
}
|
||||
|
@ -210,12 +242,11 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
|
|||
opt(config)
|
||||
}
|
||||
|
||||
config.Certificates = c.BuildCertificates()
|
||||
config.BuildNameToCertificate()
|
||||
|
||||
caCerts := c.getCustomCA()
|
||||
if len(caCerts) > 0 {
|
||||
config.GetCertificate = getGetCertificateFunc(config, caCerts)
|
||||
} else {
|
||||
config.GetCertificate = getNewGetCertficateFunc(c.BuildCertificates())
|
||||
}
|
||||
|
||||
if sn := c.parseServerName(); len(sn) > 0 {
|
||||
|
|
|
@ -41,8 +41,8 @@ func (c *Config) loadSelfCertPool() (*x509.CertPool, error) {
|
|||
}
|
||||
|
||||
// BuildCertificates builds a list of TLS certificates from proto definition.
|
||||
func (c *Config) BuildCertificates() []xtls.Certificate {
|
||||
certs := make([]xtls.Certificate, 0, len(c.Certificate))
|
||||
func (c *Config) BuildCertificates() []*xtls.Certificate {
|
||||
certs := make([]*xtls.Certificate, 0, len(c.Certificate))
|
||||
for _, entry := range c.Certificate {
|
||||
if entry.Usage != Certificate_ENCIPHERMENT {
|
||||
continue
|
||||
|
@ -52,7 +52,12 @@ func (c *Config) BuildCertificates() []xtls.Certificate {
|
|||
newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog()
|
||||
continue
|
||||
}
|
||||
certs = append(certs, keyPair)
|
||||
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
|
||||
if err != nil {
|
||||
newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog()
|
||||
continue
|
||||
}
|
||||
certs = append(certs, &keyPair)
|
||||
if entry.OcspStapling != 0 {
|
||||
go func(cert *xtls.Certificate) {
|
||||
t := time.NewTicker(time.Duration(entry.OcspStapling) * time.Second)
|
||||
|
@ -64,7 +69,7 @@ func (c *Config) BuildCertificates() []xtls.Certificate {
|
|||
}
|
||||
<-t.C
|
||||
}
|
||||
}(&certs[len(certs)-1])
|
||||
}(certs[len(certs)-1])
|
||||
}
|
||||
}
|
||||
return certs
|
||||
|
@ -168,6 +173,33 @@ func getGetCertificateFunc(c *xtls.Config, ca []*Certificate) func(hello *xtls.C
|
|||
}
|
||||
}
|
||||
|
||||
func getNewGetCertficateFunc(certs []*xtls.Certificate) func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
|
||||
return func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
|
||||
if len(certs) == 0 {
|
||||
return nil, newError("empty certs")
|
||||
}
|
||||
sni := strings.ToLower(hello.ServerName)
|
||||
if len(certs) == 1 || sni == "" {
|
||||
return certs[0], nil
|
||||
}
|
||||
gsni := "*"
|
||||
if index := strings.IndexByte(sni, '.'); index != -1 {
|
||||
gsni += sni[index:]
|
||||
}
|
||||
for _, keyPair := range certs {
|
||||
if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
|
||||
return keyPair, nil
|
||||
}
|
||||
for _, name := range keyPair.Leaf.DNSNames {
|
||||
if name == sni || name == gsni {
|
||||
return keyPair, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return certs[0], nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) parseServerName() string {
|
||||
return c.ServerName
|
||||
}
|
||||
|
@ -201,12 +233,11 @@ func (c *Config) GetXTLSConfig(opts ...Option) *xtls.Config {
|
|||
opt(config)
|
||||
}
|
||||
|
||||
config.Certificates = c.BuildCertificates()
|
||||
config.BuildNameToCertificate()
|
||||
|
||||
caCerts := c.getCustomCA()
|
||||
if len(caCerts) > 0 {
|
||||
config.GetCertificate = getGetCertificateFunc(config, caCerts)
|
||||
} else {
|
||||
config.GetCertificate = getNewGetCertficateFunc(c.BuildCertificates())
|
||||
}
|
||||
|
||||
if sn := c.parseServerName(); len(sn) > 0 {
|
||||
|
|
Loading…
Reference in a new issue