mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-11-29 03:51:28 +00:00
Fix OCSP Stapling (#172)
Co-authored-by: RPRX <63339210+rprx@users.noreply.github.com>
This commit is contained in:
parent
4cd343f2d5
commit
c13b8ec9bb
|
@ -42,8 +42,8 @@ func (c *Config) loadSelfCertPool() (*x509.CertPool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildCertificates builds a list of TLS certificates from proto definition.
|
// BuildCertificates builds a list of TLS certificates from proto definition.
|
||||||
func (c *Config) BuildCertificates() []tls.Certificate {
|
func (c *Config) BuildCertificates() []*tls.Certificate {
|
||||||
certs := make([]tls.Certificate, 0, len(c.Certificate))
|
certs := make([]*tls.Certificate, 0, len(c.Certificate))
|
||||||
for _, entry := range c.Certificate {
|
for _, entry := range c.Certificate {
|
||||||
if entry.Usage != Certificate_ENCIPHERMENT {
|
if entry.Usage != Certificate_ENCIPHERMENT {
|
||||||
continue
|
continue
|
||||||
|
@ -53,7 +53,12 @@ func (c *Config) BuildCertificates() []tls.Certificate {
|
||||||
newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog()
|
newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
certs = append(certs, keyPair)
|
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
certs = append(certs, &keyPair)
|
||||||
if entry.OcspStapling != 0 {
|
if entry.OcspStapling != 0 {
|
||||||
go func(cert *tls.Certificate) {
|
go func(cert *tls.Certificate) {
|
||||||
t := time.NewTicker(time.Duration(entry.OcspStapling) * time.Second)
|
t := time.NewTicker(time.Duration(entry.OcspStapling) * time.Second)
|
||||||
|
@ -65,7 +70,7 @@ func (c *Config) BuildCertificates() []tls.Certificate {
|
||||||
}
|
}
|
||||||
<-t.C
|
<-t.C
|
||||||
}
|
}
|
||||||
}(&certs[len(certs)-1])
|
}(certs[len(certs)-1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return certs
|
return certs
|
||||||
|
@ -169,6 +174,33 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getNewGetCertficateFunc(certs []*tls.Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
if len(certs) == 0 {
|
||||||
|
return nil, newError("empty certs")
|
||||||
|
}
|
||||||
|
sni := strings.ToLower(hello.ServerName)
|
||||||
|
if len(certs) == 1 || sni == "" {
|
||||||
|
return certs[0], nil
|
||||||
|
}
|
||||||
|
gsni := "*"
|
||||||
|
if index := strings.IndexByte(sni, '.'); index != -1 {
|
||||||
|
gsni += sni[index:]
|
||||||
|
}
|
||||||
|
for _, keyPair := range certs {
|
||||||
|
if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
|
||||||
|
return keyPair, nil
|
||||||
|
}
|
||||||
|
for _, name := range keyPair.Leaf.DNSNames {
|
||||||
|
if name == sni || name == gsni {
|
||||||
|
return keyPair, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return certs[0], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Config) IsExperiment8357() bool {
|
func (c *Config) IsExperiment8357() bool {
|
||||||
return strings.HasPrefix(c.ServerName, exp8357)
|
return strings.HasPrefix(c.ServerName, exp8357)
|
||||||
}
|
}
|
||||||
|
@ -210,12 +242,11 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
|
||||||
opt(config)
|
opt(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Certificates = c.BuildCertificates()
|
|
||||||
config.BuildNameToCertificate()
|
|
||||||
|
|
||||||
caCerts := c.getCustomCA()
|
caCerts := c.getCustomCA()
|
||||||
if len(caCerts) > 0 {
|
if len(caCerts) > 0 {
|
||||||
config.GetCertificate = getGetCertificateFunc(config, caCerts)
|
config.GetCertificate = getGetCertificateFunc(config, caCerts)
|
||||||
|
} else {
|
||||||
|
config.GetCertificate = getNewGetCertficateFunc(c.BuildCertificates())
|
||||||
}
|
}
|
||||||
|
|
||||||
if sn := c.parseServerName(); len(sn) > 0 {
|
if sn := c.parseServerName(); len(sn) > 0 {
|
||||||
|
|
|
@ -41,8 +41,8 @@ func (c *Config) loadSelfCertPool() (*x509.CertPool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildCertificates builds a list of TLS certificates from proto definition.
|
// BuildCertificates builds a list of TLS certificates from proto definition.
|
||||||
func (c *Config) BuildCertificates() []xtls.Certificate {
|
func (c *Config) BuildCertificates() []*xtls.Certificate {
|
||||||
certs := make([]xtls.Certificate, 0, len(c.Certificate))
|
certs := make([]*xtls.Certificate, 0, len(c.Certificate))
|
||||||
for _, entry := range c.Certificate {
|
for _, entry := range c.Certificate {
|
||||||
if entry.Usage != Certificate_ENCIPHERMENT {
|
if entry.Usage != Certificate_ENCIPHERMENT {
|
||||||
continue
|
continue
|
||||||
|
@ -52,7 +52,12 @@ func (c *Config) BuildCertificates() []xtls.Certificate {
|
||||||
newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog()
|
newError("ignoring invalid X509 key pair").Base(err).AtWarning().WriteToLog()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
certs = append(certs, keyPair)
|
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
newError("ignoring invalid certificate").Base(err).AtWarning().WriteToLog()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
certs = append(certs, &keyPair)
|
||||||
if entry.OcspStapling != 0 {
|
if entry.OcspStapling != 0 {
|
||||||
go func(cert *xtls.Certificate) {
|
go func(cert *xtls.Certificate) {
|
||||||
t := time.NewTicker(time.Duration(entry.OcspStapling) * time.Second)
|
t := time.NewTicker(time.Duration(entry.OcspStapling) * time.Second)
|
||||||
|
@ -64,7 +69,7 @@ func (c *Config) BuildCertificates() []xtls.Certificate {
|
||||||
}
|
}
|
||||||
<-t.C
|
<-t.C
|
||||||
}
|
}
|
||||||
}(&certs[len(certs)-1])
|
}(certs[len(certs)-1])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return certs
|
return certs
|
||||||
|
@ -168,6 +173,33 @@ func getGetCertificateFunc(c *xtls.Config, ca []*Certificate) func(hello *xtls.C
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getNewGetCertficateFunc(certs []*xtls.Certificate) func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
|
||||||
|
return func(hello *xtls.ClientHelloInfo) (*xtls.Certificate, error) {
|
||||||
|
if len(certs) == 0 {
|
||||||
|
return nil, newError("empty certs")
|
||||||
|
}
|
||||||
|
sni := strings.ToLower(hello.ServerName)
|
||||||
|
if len(certs) == 1 || sni == "" {
|
||||||
|
return certs[0], nil
|
||||||
|
}
|
||||||
|
gsni := "*"
|
||||||
|
if index := strings.IndexByte(sni, '.'); index != -1 {
|
||||||
|
gsni += sni[index:]
|
||||||
|
}
|
||||||
|
for _, keyPair := range certs {
|
||||||
|
if keyPair.Leaf.Subject.CommonName == sni || keyPair.Leaf.Subject.CommonName == gsni {
|
||||||
|
return keyPair, nil
|
||||||
|
}
|
||||||
|
for _, name := range keyPair.Leaf.DNSNames {
|
||||||
|
if name == sni || name == gsni {
|
||||||
|
return keyPair, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return certs[0], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Config) parseServerName() string {
|
func (c *Config) parseServerName() string {
|
||||||
return c.ServerName
|
return c.ServerName
|
||||||
}
|
}
|
||||||
|
@ -201,12 +233,11 @@ func (c *Config) GetXTLSConfig(opts ...Option) *xtls.Config {
|
||||||
opt(config)
|
opt(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Certificates = c.BuildCertificates()
|
|
||||||
config.BuildNameToCertificate()
|
|
||||||
|
|
||||||
caCerts := c.getCustomCA()
|
caCerts := c.getCustomCA()
|
||||||
if len(caCerts) > 0 {
|
if len(caCerts) > 0 {
|
||||||
config.GetCertificate = getGetCertificateFunc(config, caCerts)
|
config.GetCertificate = getGetCertificateFunc(config, caCerts)
|
||||||
|
} else {
|
||||||
|
config.GetCertificate = getNewGetCertficateFunc(c.BuildCertificates())
|
||||||
}
|
}
|
||||||
|
|
||||||
if sn := c.parseServerName(); len(sn) > 0 {
|
if sn := c.parseServerName(); len(sn) > 0 {
|
||||||
|
|
Loading…
Reference in a new issue