diff --git a/pool/client.go b/pool/client.go index 0c74e6b..df46ffa 100644 --- a/pool/client.go +++ b/pool/client.go @@ -15,6 +15,7 @@ import ( "strings" "sync" "sync/atomic" + "time" "dappco.re/go/core/proxy" ) @@ -82,15 +83,28 @@ func NewStratumClient(cfg proxy.PoolConfig, listener StratumListener) *StratumCl func (c *StratumClient) Connect() error { var connection net.Conn var errorValue error + dialer := net.Dialer{} + if c.cfg.Keepalive { + dialer.KeepAlive = 30 * time.Second + } if c.cfg.TLS { - tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12} - connection, errorValue = tls.Dial("tcp", c.cfg.URL, tlsConfig) + connection, errorValue = dialer.Dial("tcp", c.cfg.URL) if errorValue != nil { return errorValue } - tlsConnection := connection.(*tls.Conn) + serverName := c.cfg.URL + if host, _, splitError := net.SplitHostPort(c.cfg.URL); splitError == nil && host != "" { + serverName = host + } + + tlsConnection := tls.Client(connection, &tls.Config{MinVersion: tls.VersionTLS12, ServerName: serverName}) + errorValue = tlsConnection.Handshake() + if errorValue != nil { + _ = connection.Close() + return errorValue + } if c.cfg.TLSFingerprint != "" { state := tlsConnection.ConnectionState() if len(state.PeerCertificates) == 0 { @@ -104,9 +118,10 @@ func (c *StratumClient) Connect() error { return errors.New("pool fingerprint mismatch") } } + connection = tlsConnection c.tlsConn = tlsConnection } else { - connection, errorValue = net.Dial("tcp", c.cfg.URL) + connection, errorValue = dialer.Dial("tcp", c.cfg.URL) if errorValue != nil { return errorValue } diff --git a/proxy_runtime.go b/proxy_runtime.go index 87a8329..f940b17 100644 --- a/proxy_runtime.go +++ b/proxy_runtime.go @@ -96,12 +96,17 @@ func (p *Proxy) Start() { if bind.TLS && p.config.TLS.Enabled { certificate, errorValue := tls.LoadX509KeyPair(p.config.TLS.CertFile, p.config.TLS.KeyFile) if errorValue == nil { - tlsConfig = &tls.Config{Certificates: []tls.Certificate{certificate}} + tlsConfig = buildTLSConfig(p.config.TLS) + tlsConfig.Certificates = []tls.Certificate{certificate} + } else { + p.Stop() + return } } server, errorValue := NewServer(bind, tlsConfig, p.rateLimiter, p.acceptConn) if errorValue != nil { - continue + p.Stop() + return } p.servers = append(p.servers, server) server.Start() diff --git a/tls_runtime.go b/tls_runtime.go new file mode 100644 index 0000000..c19bdc0 --- /dev/null +++ b/tls_runtime.go @@ -0,0 +1,120 @@ +package proxy + +import ( + "crypto/tls" + "strconv" + "strings" +) + +func buildTLSConfig(config TLSConfig) *tls.Config { + tlsConfig := &tls.Config{} + if versions := parseTLSVersions(config.Protocols); versions != nil { + tlsConfig.MinVersion = versions.min + tlsConfig.MaxVersion = versions.max + } + if suites := parseCipherSuites(config.Ciphers); len(suites) > 0 { + tlsConfig.CipherSuites = suites + } + return tlsConfig +} + +type tlsVersionBounds struct { + min uint16 + max uint16 +} + +func parseTLSVersions(value string) *tlsVersionBounds { + if strings.TrimSpace(value) == "" { + return nil + } + + bounds := tlsVersionBounds{} + for _, token := range splitTLSList(value) { + version, ok := parseTLSVersionToken(token) + if !ok { + continue + } + if bounds.min == 0 || version < bounds.min { + bounds.min = version + } + if version > bounds.max { + bounds.max = version + } + } + + if bounds.min == 0 || bounds.max == 0 { + return nil + } + + return &bounds +} + +func parseTLSVersionToken(token string) (uint16, bool) { + switch strings.ToLower(strings.TrimSpace(token)) { + case "tls1.0", "tlsv1.0", "tls1", "tlsv1", "1.0", "tls10": + return tls.VersionTLS10, true + case "tls1.1", "tlsv1.1", "1.1", "tls11": + return tls.VersionTLS11, true + case "tls1.2", "tlsv1.2", "1.2", "tls12": + return tls.VersionTLS12, true + case "tls1.3", "tlsv1.3", "1.3", "tls13": + return tls.VersionTLS13, true + } + + if raw, errorValue := strconv.ParseUint(strings.TrimSpace(token), 10, 16); errorValue == nil { + switch uint16(raw) { + case tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13: + return uint16(raw), true + } + } + + return 0, false +} + +func parseCipherSuites(value string) []uint16 { + if strings.TrimSpace(value) == "" { + return nil + } + + var suites []uint16 + for _, token := range splitTLSList(value) { + if suite, ok := tlsCipherSuiteNames[strings.ToUpper(strings.TrimSpace(token))]; ok { + suites = append(suites, suite) + } + } + return suites +} + +func splitTLSList(value string) []string { + return strings.FieldsFunc(value, func(r rune) bool { + switch r { + case ':', ',', ' ', ';': + return true + default: + return false + } + }) +} + +var tlsCipherSuiteNames = map[string]uint16{ + "TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + "TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + "TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256, + "TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384, + "TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256, + "ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + "AES128-GCM-SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + "AES256-GCM-SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + "ECDHE-RSA-CHACHA20-POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + "ECDHE-ECDSA-CHACHA20-POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + "CHACHA20-POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, +} diff --git a/tls_runtime_test.go b/tls_runtime_test.go new file mode 100644 index 0000000..3327181 --- /dev/null +++ b/tls_runtime_test.go @@ -0,0 +1,48 @@ +package proxy + +import ( + "crypto/tls" + "testing" +) + +func TestTLSRuntime_buildTLSConfig_Good(t *testing.T) { + config := buildTLSConfig(TLSConfig{ + Ciphers: "ECDHE-RSA-AES128-GCM-SHA256:TLS_AES_128_GCM_SHA256", + Protocols: "TLSv1.2,TLSv1.3", + }) + + if config.MinVersion != tls.VersionTLS12 { + t.Fatalf("expected min version TLS1.2, got %d", config.MinVersion) + } + if config.MaxVersion != tls.VersionTLS13 { + t.Fatalf("expected max version TLS1.3, got %d", config.MaxVersion) + } + if len(config.CipherSuites) != 2 || config.CipherSuites[0] != tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 || config.CipherSuites[1] != tls.TLS_AES_128_GCM_SHA256 { + t.Fatalf("unexpected cipher suites: %#v", config.CipherSuites) + } +} + +func TestTLSRuntime_buildTLSConfig_Bad(t *testing.T) { + config := buildTLSConfig(TLSConfig{Protocols: "bogus", Ciphers: "bogus"}) + + if config.MinVersion != 0 || config.MaxVersion != 0 { + t.Fatalf("expected default versions for invalid input, got min=%d max=%d", config.MinVersion, config.MaxVersion) + } + if len(config.CipherSuites) != 0 { + t.Fatalf("expected no cipher suites for invalid input, got %#v", config.CipherSuites) + } +} + +func TestTLSRuntime_buildTLSConfig_Ugly(t *testing.T) { + config := buildTLSConfig(TLSConfig{Protocols: "1.1:1.2:1.3", Ciphers: "AES128-GCM-SHA256,unknown"}) + + if config.MinVersion != tls.VersionTLS11 { + t.Fatalf("expected min version TLS1.1, got %d", config.MinVersion) + } + if config.MaxVersion != tls.VersionTLS13 { + t.Fatalf("expected max version TLS1.3, got %d", config.MaxVersion) + } + if len(config.CipherSuites) != 1 || config.CipherSuites[0] != tls.TLS_RSA_WITH_AES_128_GCM_SHA256 { + t.Fatalf("unexpected cipher suites: %#v", config.CipherSuites) + } +}