fix(proxy): tighten listener and limiter lifecycle

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-04 18:49:03 +00:00
parent 31a8ba558f
commit bc67e73ca0
5 changed files with 131 additions and 40 deletions

View file

@ -332,9 +332,9 @@ func refillBucket(bucket *tokenBucket, limit int, now time.Time) {
}
return
}
interval := time.Duration(60/limit) * time.Second
interval := time.Duration(time.Minute) / time.Duration(limit)
if interval <= 0 {
interval = time.Second
interval = time.Nanosecond
}
elapsed := now.Sub(bucket.lastRefill)
if elapsed < interval {

View file

@ -13,6 +13,7 @@ package proxy
import (
"net/http"
"sync"
"sync/atomic"
"time"
)
@ -22,22 +23,23 @@ import (
// p, result := proxy.New(cfg)
// if result.OK { p.Start() }
type Proxy struct {
config *Config
splitter Splitter
stats *Stats
workers *Workers
events *EventBus
servers []*Server
ticker *time.Ticker
watcher *ConfigWatcher
done chan struct{}
stopOnce sync.Once
minersMu sync.RWMutex
miners map[int64]*Miner
customDiff *CustomDiff
rateLimit *RateLimiter
httpServer *http.Server
accessLog *accessLogSink
config *Config
splitter Splitter
stats *Stats
workers *Workers
events *EventBus
servers []*Server
ticker *time.Ticker
watcher *ConfigWatcher
done chan struct{}
stopOnce sync.Once
minersMu sync.RWMutex
miners map[int64]*Miner
customDiff *CustomDiff
rateLimit *RateLimiter
httpServer *http.Server
accessLog *accessLogSink
submitCount atomic.Int64
}
// Splitter is the interface both NonceSplitter and SimpleSplitter satisfy.

View file

@ -1,6 +1,9 @@
package proxy
import "testing"
import (
"testing"
"time"
)
func TestRateLimiter_Allow(t *testing.T) {
rl := NewRateLimiter(RateLimit{MaxConnectionsPerMinute: 1, BanDurationSeconds: 1})
@ -11,3 +14,17 @@ func TestRateLimiter_Allow(t *testing.T) {
t.Fatalf("expected second call to fail")
}
}
func TestRateLimiter_Allow_ReplenishesHighLimits(t *testing.T) {
rl := NewRateLimiter(RateLimit{MaxConnectionsPerMinute: 120, BanDurationSeconds: 1})
rl.mu.Lock()
rl.buckets["1.2.3.4"] = &tokenBucket{
tokens: 0,
lastRefill: time.Now().Add(-30 * time.Second),
}
rl.mu.Unlock()
if !rl.Allow("1.2.3.4:1234") {
t.Fatalf("expected bucket to replenish at 120/min")
}
}

View file

@ -56,3 +56,30 @@ func TestProxy_Reload_Good(t *testing.T) {
t.Fatalf("expected rate limiter to be replaced with active configuration")
}
}
func TestProxy_Reload_UpdatesServers(t *testing.T) {
originalLimiter := NewRateLimiter(RateLimit{MaxConnectionsPerMinute: 1})
p := &Proxy{
config: &Config{Mode: "nicehash", Workers: WorkersByRigID},
rateLimit: originalLimiter,
servers: []*Server{
{limiter: originalLimiter},
},
}
p.Reload(&Config{
Mode: "nicehash",
Workers: WorkersByRigID,
Bind: []BindAddr{{Host: "127.0.0.1", Port: 3333}},
Pools: []PoolConfig{{URL: "pool.example:3333", Enabled: true}},
RateLimit: RateLimit{MaxConnectionsPerMinute: 10},
AccessLogFile: "",
})
if got := p.servers[0].limiter; got != p.rateLimit {
t.Fatalf("expected server limiter to be updated")
}
if p.rateLimit == originalLimiter {
t.Fatalf("expected rate limiter instance to be replaced")
}
}

View file

@ -166,10 +166,23 @@ func (p *Proxy) Start() {
}
for _, bind := range p.config.Bind {
var tlsCfg *tls.Config
if bind.TLS && p.config.TLS.Enabled {
tlsCfg = buildTLSConfig(p.config.TLS)
if bind.TLS {
if !p.config.TLS.Enabled {
p.Stop()
return
}
var result Result
tlsCfg, result = buildTLSConfig(p.config.TLS)
if !result.OK {
p.Stop()
return
}
}
server, result := NewServer(bind, tlsCfg, p.rateLimit, p.acceptMiner)
if !result.OK {
p.Stop()
return
}
server, _ := NewServer(bind, tlsCfg, p.rateLimit, p.acceptMiner)
p.servers = append(p.servers, server)
server.Start()
}
@ -234,6 +247,10 @@ func (p *Proxy) Stop() {
defer cancel()
_ = p.httpServer.Shutdown(ctx)
}
deadline := time.Now().Add(5 * time.Second)
for p.submitCount.Load() > 0 && time.Now().Before(deadline) {
time.Sleep(10 * time.Millisecond)
}
if p.accessLog != nil {
p.accessLog.Close()
}
@ -262,6 +279,11 @@ func (p *Proxy) Reload(cfg *Config) {
p.customDiff.globalDiff = cfg.CustomDiff
}
p.rateLimit = NewRateLimiter(cfg.RateLimit)
for _, server := range p.servers {
if server != nil {
server.limiter = p.rateLimit
}
}
if p.accessLog != nil {
p.accessLog.SetPath(cfg.AccessLogFile)
}
@ -289,6 +311,8 @@ func (p *Proxy) acceptMiner(conn net.Conn, localPort uint16) {
}
}
miner.onSubmit = func(m *Miner, event *SubmitEvent) {
p.submitCount.Add(1)
defer p.submitCount.Add(-1)
if p.splitter != nil {
p.splitter.OnSubmit(event)
}
@ -310,18 +334,21 @@ func (p *Proxy) acceptMiner(conn net.Conn, localPort uint16) {
miner.Start()
}
func buildTLSConfig(cfg TLSConfig) *tls.Config {
if !cfg.Enabled || cfg.CertFile == "" || cfg.KeyFile == "" {
return nil
func buildTLSConfig(cfg TLSConfig) (*tls.Config, Result) {
if !cfg.Enabled {
return nil, successResult()
}
if cfg.CertFile == "" || cfg.KeyFile == "" {
return nil, errorResult(errors.New("tls certificate or key path is empty"))
}
cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
if err != nil {
return nil
return nil, errorResult(err)
}
tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
applyTLSProtocols(tlsConfig, cfg.Protocols)
applyTLSCiphers(tlsConfig, cfg.Ciphers)
return tlsConfig
return tlsConfig, successResult()
}
func applyTLSProtocols(tlsConfig *tls.Config, protocols string) {
@ -1344,13 +1371,17 @@ func NewServer(bind BindAddr, tlsCfg *tls.Config, limiter *RateLimiter, onAccept
if onAccept == nil {
onAccept = func(net.Conn, uint16) {}
}
return &Server{
server := &Server{
addr: bind,
tlsCfg: tlsCfg,
limiter: limiter,
onAccept: onAccept,
done: make(chan struct{}),
}, successResult()
}
if result := server.listen(); !result.OK {
return nil, result
}
return server, successResult()
}
// Start begins accepting connections in a goroutine.
@ -1358,19 +1389,12 @@ func (s *Server) Start() {
if s == nil {
return
}
if result := s.listen(); !result.OK {
return
}
go func() {
ln, err := net.Listen("tcp", net.JoinHostPort(s.addr.Host, strconv.Itoa(int(s.addr.Port))))
if err != nil {
return
}
if s.tlsCfg != nil || s.addr.TLS {
if s.tlsCfg != nil {
ln = tls.NewListener(ln, s.tlsCfg)
}
}
s.listener = ln
for {
conn, err := ln.Accept()
conn, err := s.listener.Accept()
if err != nil {
select {
case <-s.done:
@ -1405,6 +1429,27 @@ func (s *Server) Stop() {
}
}
func (s *Server) listen() Result {
if s == nil {
return errorResult(errors.New("server is nil"))
}
if s.listener != nil {
return successResult()
}
if s.addr.TLS && s.tlsCfg == nil {
return errorResult(errors.New("tls listener requires a tls config"))
}
ln, err := net.Listen("tcp", net.JoinHostPort(s.addr.Host, strconv.Itoa(int(s.addr.Port))))
if err != nil {
return errorResult(err)
}
if s.tlsCfg != nil {
ln = tls.NewListener(ln, s.tlsCfg)
}
s.listener = ln
return successResult()
}
// NewConfig returns a minimal config? not used.
// NewRateLimiter, Allow, Tick are defined in core_impl.go.