feat(proxy): honour TLS config and pool keepalive
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
3376cea600
commit
465ea38308
4 changed files with 194 additions and 6 deletions
|
|
@ -15,6 +15,7 @@ import (
|
|||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"dappco.re/go/core/proxy"
|
||||
)
|
||||
|
|
@ -82,15 +83,28 @@ func NewStratumClient(cfg proxy.PoolConfig, listener StratumListener) *StratumCl
|
|||
func (c *StratumClient) Connect() error {
|
||||
var connection net.Conn
|
||||
var errorValue error
|
||||
dialer := net.Dialer{}
|
||||
if c.cfg.Keepalive {
|
||||
dialer.KeepAlive = 30 * time.Second
|
||||
}
|
||||
|
||||
if c.cfg.TLS {
|
||||
tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
connection, errorValue = tls.Dial("tcp", c.cfg.URL, tlsConfig)
|
||||
connection, errorValue = dialer.Dial("tcp", c.cfg.URL)
|
||||
if errorValue != nil {
|
||||
return errorValue
|
||||
}
|
||||
|
||||
tlsConnection := connection.(*tls.Conn)
|
||||
serverName := c.cfg.URL
|
||||
if host, _, splitError := net.SplitHostPort(c.cfg.URL); splitError == nil && host != "" {
|
||||
serverName = host
|
||||
}
|
||||
|
||||
tlsConnection := tls.Client(connection, &tls.Config{MinVersion: tls.VersionTLS12, ServerName: serverName})
|
||||
errorValue = tlsConnection.Handshake()
|
||||
if errorValue != nil {
|
||||
_ = connection.Close()
|
||||
return errorValue
|
||||
}
|
||||
if c.cfg.TLSFingerprint != "" {
|
||||
state := tlsConnection.ConnectionState()
|
||||
if len(state.PeerCertificates) == 0 {
|
||||
|
|
@ -104,9 +118,10 @@ func (c *StratumClient) Connect() error {
|
|||
return errors.New("pool fingerprint mismatch")
|
||||
}
|
||||
}
|
||||
connection = tlsConnection
|
||||
c.tlsConn = tlsConnection
|
||||
} else {
|
||||
connection, errorValue = net.Dial("tcp", c.cfg.URL)
|
||||
connection, errorValue = dialer.Dial("tcp", c.cfg.URL)
|
||||
if errorValue != nil {
|
||||
return errorValue
|
||||
}
|
||||
|
|
|
|||
|
|
@ -96,12 +96,17 @@ func (p *Proxy) Start() {
|
|||
if bind.TLS && p.config.TLS.Enabled {
|
||||
certificate, errorValue := tls.LoadX509KeyPair(p.config.TLS.CertFile, p.config.TLS.KeyFile)
|
||||
if errorValue == nil {
|
||||
tlsConfig = &tls.Config{Certificates: []tls.Certificate{certificate}}
|
||||
tlsConfig = buildTLSConfig(p.config.TLS)
|
||||
tlsConfig.Certificates = []tls.Certificate{certificate}
|
||||
} else {
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
server, errorValue := NewServer(bind, tlsConfig, p.rateLimiter, p.acceptConn)
|
||||
if errorValue != nil {
|
||||
continue
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
p.servers = append(p.servers, server)
|
||||
server.Start()
|
||||
|
|
|
|||
120
tls_runtime.go
Normal file
120
tls_runtime.go
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func buildTLSConfig(config TLSConfig) *tls.Config {
|
||||
tlsConfig := &tls.Config{}
|
||||
if versions := parseTLSVersions(config.Protocols); versions != nil {
|
||||
tlsConfig.MinVersion = versions.min
|
||||
tlsConfig.MaxVersion = versions.max
|
||||
}
|
||||
if suites := parseCipherSuites(config.Ciphers); len(suites) > 0 {
|
||||
tlsConfig.CipherSuites = suites
|
||||
}
|
||||
return tlsConfig
|
||||
}
|
||||
|
||||
type tlsVersionBounds struct {
|
||||
min uint16
|
||||
max uint16
|
||||
}
|
||||
|
||||
func parseTLSVersions(value string) *tlsVersionBounds {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
bounds := tlsVersionBounds{}
|
||||
for _, token := range splitTLSList(value) {
|
||||
version, ok := parseTLSVersionToken(token)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if bounds.min == 0 || version < bounds.min {
|
||||
bounds.min = version
|
||||
}
|
||||
if version > bounds.max {
|
||||
bounds.max = version
|
||||
}
|
||||
}
|
||||
|
||||
if bounds.min == 0 || bounds.max == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &bounds
|
||||
}
|
||||
|
||||
func parseTLSVersionToken(token string) (uint16, bool) {
|
||||
switch strings.ToLower(strings.TrimSpace(token)) {
|
||||
case "tls1.0", "tlsv1.0", "tls1", "tlsv1", "1.0", "tls10":
|
||||
return tls.VersionTLS10, true
|
||||
case "tls1.1", "tlsv1.1", "1.1", "tls11":
|
||||
return tls.VersionTLS11, true
|
||||
case "tls1.2", "tlsv1.2", "1.2", "tls12":
|
||||
return tls.VersionTLS12, true
|
||||
case "tls1.3", "tlsv1.3", "1.3", "tls13":
|
||||
return tls.VersionTLS13, true
|
||||
}
|
||||
|
||||
if raw, errorValue := strconv.ParseUint(strings.TrimSpace(token), 10, 16); errorValue == nil {
|
||||
switch uint16(raw) {
|
||||
case tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13:
|
||||
return uint16(raw), true
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func parseCipherSuites(value string) []uint16 {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var suites []uint16
|
||||
for _, token := range splitTLSList(value) {
|
||||
if suite, ok := tlsCipherSuiteNames[strings.ToUpper(strings.TrimSpace(token))]; ok {
|
||||
suites = append(suites, suite)
|
||||
}
|
||||
}
|
||||
return suites
|
||||
}
|
||||
|
||||
func splitTLSList(value string) []string {
|
||||
return strings.FieldsFunc(value, func(r rune) bool {
|
||||
switch r {
|
||||
case ':', ',', ' ', ';':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var tlsCipherSuiteNames = map[string]uint16{
|
||||
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
"TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256,
|
||||
"TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384,
|
||||
"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256,
|
||||
"ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
"ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
"AES128-GCM-SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
|
||||
"AES256-GCM-SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"ECDHE-RSA-CHACHA20-POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
"ECDHE-ECDSA-CHACHA20-POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
"CHACHA20-POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
}
|
||||
48
tls_runtime_test.go
Normal file
48
tls_runtime_test.go
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTLSRuntime_buildTLSConfig_Good(t *testing.T) {
|
||||
config := buildTLSConfig(TLSConfig{
|
||||
Ciphers: "ECDHE-RSA-AES128-GCM-SHA256:TLS_AES_128_GCM_SHA256",
|
||||
Protocols: "TLSv1.2,TLSv1.3",
|
||||
})
|
||||
|
||||
if config.MinVersion != tls.VersionTLS12 {
|
||||
t.Fatalf("expected min version TLS1.2, got %d", config.MinVersion)
|
||||
}
|
||||
if config.MaxVersion != tls.VersionTLS13 {
|
||||
t.Fatalf("expected max version TLS1.3, got %d", config.MaxVersion)
|
||||
}
|
||||
if len(config.CipherSuites) != 2 || config.CipherSuites[0] != tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 || config.CipherSuites[1] != tls.TLS_AES_128_GCM_SHA256 {
|
||||
t.Fatalf("unexpected cipher suites: %#v", config.CipherSuites)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSRuntime_buildTLSConfig_Bad(t *testing.T) {
|
||||
config := buildTLSConfig(TLSConfig{Protocols: "bogus", Ciphers: "bogus"})
|
||||
|
||||
if config.MinVersion != 0 || config.MaxVersion != 0 {
|
||||
t.Fatalf("expected default versions for invalid input, got min=%d max=%d", config.MinVersion, config.MaxVersion)
|
||||
}
|
||||
if len(config.CipherSuites) != 0 {
|
||||
t.Fatalf("expected no cipher suites for invalid input, got %#v", config.CipherSuites)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSRuntime_buildTLSConfig_Ugly(t *testing.T) {
|
||||
config := buildTLSConfig(TLSConfig{Protocols: "1.1:1.2:1.3", Ciphers: "AES128-GCM-SHA256,unknown"})
|
||||
|
||||
if config.MinVersion != tls.VersionTLS11 {
|
||||
t.Fatalf("expected min version TLS1.1, got %d", config.MinVersion)
|
||||
}
|
||||
if config.MaxVersion != tls.VersionTLS13 {
|
||||
t.Fatalf("expected max version TLS1.3, got %d", config.MaxVersion)
|
||||
}
|
||||
if len(config.CipherSuites) != 1 || config.CipherSuites[0] != tls.TLS_RSA_WITH_AES_128_GCM_SHA256 {
|
||||
t.Fatalf("unexpected cipher suites: %#v", config.CipherSuites)
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue