feat(proxy): honour TLS config and pool keepalive

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-04 11:16:29 +00:00
parent 3376cea600
commit 465ea38308
4 changed files with 194 additions and 6 deletions

View file

@ -15,6 +15,7 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
"dappco.re/go/core/proxy"
)
@ -82,15 +83,28 @@ func NewStratumClient(cfg proxy.PoolConfig, listener StratumListener) *StratumCl
func (c *StratumClient) Connect() error {
var connection net.Conn
var errorValue error
dialer := net.Dialer{}
if c.cfg.Keepalive {
dialer.KeepAlive = 30 * time.Second
}
if c.cfg.TLS {
tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}
connection, errorValue = tls.Dial("tcp", c.cfg.URL, tlsConfig)
connection, errorValue = dialer.Dial("tcp", c.cfg.URL)
if errorValue != nil {
return errorValue
}
tlsConnection := connection.(*tls.Conn)
serverName := c.cfg.URL
if host, _, splitError := net.SplitHostPort(c.cfg.URL); splitError == nil && host != "" {
serverName = host
}
tlsConnection := tls.Client(connection, &tls.Config{MinVersion: tls.VersionTLS12, ServerName: serverName})
errorValue = tlsConnection.Handshake()
if errorValue != nil {
_ = connection.Close()
return errorValue
}
if c.cfg.TLSFingerprint != "" {
state := tlsConnection.ConnectionState()
if len(state.PeerCertificates) == 0 {
@ -104,9 +118,10 @@ func (c *StratumClient) Connect() error {
return errors.New("pool fingerprint mismatch")
}
}
connection = tlsConnection
c.tlsConn = tlsConnection
} else {
connection, errorValue = net.Dial("tcp", c.cfg.URL)
connection, errorValue = dialer.Dial("tcp", c.cfg.URL)
if errorValue != nil {
return errorValue
}

View file

@ -96,12 +96,17 @@ func (p *Proxy) Start() {
if bind.TLS && p.config.TLS.Enabled {
certificate, errorValue := tls.LoadX509KeyPair(p.config.TLS.CertFile, p.config.TLS.KeyFile)
if errorValue == nil {
tlsConfig = &tls.Config{Certificates: []tls.Certificate{certificate}}
tlsConfig = buildTLSConfig(p.config.TLS)
tlsConfig.Certificates = []tls.Certificate{certificate}
} else {
p.Stop()
return
}
}
server, errorValue := NewServer(bind, tlsConfig, p.rateLimiter, p.acceptConn)
if errorValue != nil {
continue
p.Stop()
return
}
p.servers = append(p.servers, server)
server.Start()

120
tls_runtime.go Normal file
View file

@ -0,0 +1,120 @@
package proxy
import (
"crypto/tls"
"strconv"
"strings"
)
func buildTLSConfig(config TLSConfig) *tls.Config {
tlsConfig := &tls.Config{}
if versions := parseTLSVersions(config.Protocols); versions != nil {
tlsConfig.MinVersion = versions.min
tlsConfig.MaxVersion = versions.max
}
if suites := parseCipherSuites(config.Ciphers); len(suites) > 0 {
tlsConfig.CipherSuites = suites
}
return tlsConfig
}
type tlsVersionBounds struct {
min uint16
max uint16
}
func parseTLSVersions(value string) *tlsVersionBounds {
if strings.TrimSpace(value) == "" {
return nil
}
bounds := tlsVersionBounds{}
for _, token := range splitTLSList(value) {
version, ok := parseTLSVersionToken(token)
if !ok {
continue
}
if bounds.min == 0 || version < bounds.min {
bounds.min = version
}
if version > bounds.max {
bounds.max = version
}
}
if bounds.min == 0 || bounds.max == 0 {
return nil
}
return &bounds
}
func parseTLSVersionToken(token string) (uint16, bool) {
switch strings.ToLower(strings.TrimSpace(token)) {
case "tls1.0", "tlsv1.0", "tls1", "tlsv1", "1.0", "tls10":
return tls.VersionTLS10, true
case "tls1.1", "tlsv1.1", "1.1", "tls11":
return tls.VersionTLS11, true
case "tls1.2", "tlsv1.2", "1.2", "tls12":
return tls.VersionTLS12, true
case "tls1.3", "tlsv1.3", "1.3", "tls13":
return tls.VersionTLS13, true
}
if raw, errorValue := strconv.ParseUint(strings.TrimSpace(token), 10, 16); errorValue == nil {
switch uint16(raw) {
case tls.VersionTLS10, tls.VersionTLS11, tls.VersionTLS12, tls.VersionTLS13:
return uint16(raw), true
}
}
return 0, false
}
func parseCipherSuites(value string) []uint16 {
if strings.TrimSpace(value) == "" {
return nil
}
var suites []uint16
for _, token := range splitTLSList(value) {
if suite, ok := tlsCipherSuiteNames[strings.ToUpper(strings.TrimSpace(token))]; ok {
suites = append(suites, suite)
}
}
return suites
}
func splitTLSList(value string) []string {
return strings.FieldsFunc(value, func(r rune) bool {
switch r {
case ':', ',', ' ', ';':
return true
default:
return false
}
})
}
var tlsCipherSuiteNames = map[string]uint16{
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
"TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256,
"TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384,
"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256,
"ECDHE-RSA-AES128-GCM-SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
"ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
"ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
"ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
"AES128-GCM-SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
"AES256-GCM-SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
"ECDHE-RSA-CHACHA20-POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
"ECDHE-ECDSA-CHACHA20-POLY1305": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
"CHACHA20-POLY1305": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
}

48
tls_runtime_test.go Normal file
View file

@ -0,0 +1,48 @@
package proxy
import (
"crypto/tls"
"testing"
)
func TestTLSRuntime_buildTLSConfig_Good(t *testing.T) {
config := buildTLSConfig(TLSConfig{
Ciphers: "ECDHE-RSA-AES128-GCM-SHA256:TLS_AES_128_GCM_SHA256",
Protocols: "TLSv1.2,TLSv1.3",
})
if config.MinVersion != tls.VersionTLS12 {
t.Fatalf("expected min version TLS1.2, got %d", config.MinVersion)
}
if config.MaxVersion != tls.VersionTLS13 {
t.Fatalf("expected max version TLS1.3, got %d", config.MaxVersion)
}
if len(config.CipherSuites) != 2 || config.CipherSuites[0] != tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 || config.CipherSuites[1] != tls.TLS_AES_128_GCM_SHA256 {
t.Fatalf("unexpected cipher suites: %#v", config.CipherSuites)
}
}
func TestTLSRuntime_buildTLSConfig_Bad(t *testing.T) {
config := buildTLSConfig(TLSConfig{Protocols: "bogus", Ciphers: "bogus"})
if config.MinVersion != 0 || config.MaxVersion != 0 {
t.Fatalf("expected default versions for invalid input, got min=%d max=%d", config.MinVersion, config.MaxVersion)
}
if len(config.CipherSuites) != 0 {
t.Fatalf("expected no cipher suites for invalid input, got %#v", config.CipherSuites)
}
}
func TestTLSRuntime_buildTLSConfig_Ugly(t *testing.T) {
config := buildTLSConfig(TLSConfig{Protocols: "1.1:1.2:1.3", Ciphers: "AES128-GCM-SHA256,unknown"})
if config.MinVersion != tls.VersionTLS11 {
t.Fatalf("expected min version TLS1.1, got %d", config.MinVersion)
}
if config.MaxVersion != tls.VersionTLS13 {
t.Fatalf("expected max version TLS1.3, got %d", config.MaxVersion)
}
if len(config.CipherSuites) != 1 || config.CipherSuites[0] != tls.TLS_RSA_WITH_AES_128_GCM_SHA256 {
t.Fatalf("unexpected cipher suites: %#v", config.CipherSuites)
}
}