diff --git a/transport/internet/tls/config.go b/transport/internet/tls/config.go index a3bab60a..9df57b84 100644 --- a/transport/internet/tls/config.go +++ b/transport/internet/tls/config.go @@ -42,8 +42,8 @@ func (c *Config) loadSelfCertPool() (*x509.CertPool, error) { } // BuildCertificates builds a list of TLS certificates from proto definition. -func (c *Config) BuildCertificates() []tls.Certificate { - certs := make([]tls.Certificate, 0, len(c.Certificate)) +func (c *Config) BuildCertificates() []*tls.Certificate { + certs := make([]*tls.Certificate, 0, len(c.Certificate)) for _, entry := range c.Certificate { if entry.Usage != Certificate_ENCIPHERMENT { continue @@ -53,7 +53,12 @@ func (c *Config) BuildCertificates() []tls.Certificate { newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog() continue } - certs = append(certs, keyPair) + keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) + if err != nil { + newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog() + continue + } + certs = append(certs, &keyPair) if entry.OcspStapling != 0 { go func(cert *tls.Certificate) { t := time.NewTicker(time.Duration(entry.OcspStapling) * time.Second) @@ -65,7 +70,7 @@ func (c *Config) BuildCertificates() []tls.Certificate { } <-t.C } - }(&certs[len(certs)-1]) + }(certs[len(certs)-1]) } } return certs @@ -169,6 +174,33 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli } } +func getNewGetCertficateFunc(certs []*tls.Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + if len(certs) == 0 { + return nil, newError("empty certs") + } + sni := strings.ToLower(hello.ServerName) + if len(certs) == 1 || sni == "" { + return certs[0], nil + } + gsni := "*" + if index := strings.IndexByte(sni, '.'); index != -1 { + gsni += sni[index:] + } + for _, keyPair := range certs { + if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni { + return keyPair, nil + } + for _, name := range keyPair.Leaf.DNSNames { + if name == sni || name == gsni { + return keyPair, nil + } + } + } + return certs[0], nil + } +} + func (c *Config) IsExperiment8357() bool { return strings.HasPrefix(c.ServerName, exp8357) } @@ -210,12 +242,11 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config { opt(config) } - config.Certificates = c.BuildCertificates() - config.BuildNameToCertificate() - caCerts := c.getCustomCA() if len(caCerts) > 0 { config.GetCertificate = getGetCertificateFunc(config, caCerts) + } else { + config.GetCertificate = getNewGetCertficateFunc(c.BuildCertificates()) } if sn := c.parseServerName(); len(sn) > 0 { diff --git a/transport/internet/xtls/config.go b/transport/internet/xtls/config.go index df82bcec..3330e3ac 100644 --- a/transport/internet/xtls/config.go +++ b/transport/internet/xtls/config.go @@ -41,8 +41,8 @@ func (c *Config) loadSelfCertPool() (*x509.CertPool, error) { } // BuildCertificates builds a list of TLS certificates from proto definition. -func (c *Config) BuildCertificates() []xtls.Certificate { - certs := make([]xtls.Certificate, 0, len(c.Certificate)) +func (c *Config) BuildCertificates() []*xtls.Certificate { + certs := make([]*xtls.Certificate, 0, len(c.Certificate)) for _, entry := range c.Certificate { if entry.Usage != Certificate_ENCIPHERMENT { continue @@ -52,7 +52,12 @@ func (c *Config) BuildCertificates() []xtls.Certificate { newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog() continue } - certs = append(certs, keyPair) + keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0]) + if err != nil { + newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog() + continue + } + certs = append(certs, &keyPair) if entry.OcspStapling != 0 { go func(cert *xtls.Certificate) { t := time.NewTicker(time.Duration(entry.OcspStapling) * time.Second) @@ -64,7 +69,7 @@ func (c *Config) BuildCertificates() []xtls.Certificate { } <-t.C } - }(&certs[len(certs)-1]) + }(certs[len(certs)-1]) } } return certs @@ -168,6 +173,33 @@ func getGetCertificateFunc(c *xtls.Config, ca []*Certificate) func(hello *xtls.C } } +func getNewGetCertficateFunc(certs []*xtls.Certificate) func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) { + return func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) { + if len(certs) == 0 { + return nil, newError("empty certs") + } + sni := strings.ToLower(hello.ServerName) + if len(certs) == 1 || sni == "" { + return certs[0], nil + } + gsni := "*" + if index := strings.IndexByte(sni, '.'); index != -1 { + gsni += sni[index:] + } + for _, keyPair := range certs { + if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni { + return keyPair, nil + } + for _, name := range keyPair.Leaf.DNSNames { + if name == sni || name == gsni { + return keyPair, nil + } + } + } + return certs[0], nil + } +} + func (c *Config) parseServerName() string { return c.ServerName } @@ -201,12 +233,11 @@ func (c *Config) GetXTLSConfig(opts ...Option) *xtls.Config { opt(config) } - config.Certificates = c.BuildCertificates() - config.BuildNameToCertificate() - caCerts := c.getCustomCA() if len(caCerts) > 0 { config.GetCertificate = getGetCertificateFunc(config, caCerts) + } else { + config.GetCertificate = getNewGetCertficateFunc(c.BuildCertificates()) } if sn := c.parseServerName(); len(sn) > 0 {