diff --git a/core_impl.go b/core_impl.go index 6943672..6fa288c 100644 --- a/core_impl.go +++ b/core_impl.go @@ -332,9 +332,9 @@ func refillBucket(bucket *tokenBucket, limit int, now time.Time) { } return } - interval := time.Duration(60/limit) * time.Second + interval := time.Duration(time.Minute) / time.Duration(limit) if interval <= 0 { - interval = time.Second + interval = time.Nanosecond } elapsed := now.Sub(bucket.lastRefill) if elapsed < interval { diff --git a/proxy.go b/proxy.go index 4e2ccfa..35d994a 100644 --- a/proxy.go +++ b/proxy.go @@ -13,6 +13,7 @@ package proxy import ( "net/http" "sync" + "sync/atomic" "time" ) @@ -22,22 +23,23 @@ import ( // p, result := proxy.New(cfg) // if result.OK { p.Start() } type Proxy struct { - config *Config - splitter Splitter - stats *Stats - workers *Workers - events *EventBus - servers []*Server - ticker *time.Ticker - watcher *ConfigWatcher - done chan struct{} - stopOnce sync.Once - minersMu sync.RWMutex - miners map[int64]*Miner - customDiff *CustomDiff - rateLimit *RateLimiter - httpServer *http.Server - accessLog *accessLogSink + config *Config + splitter Splitter + stats *Stats + workers *Workers + events *EventBus + servers []*Server + ticker *time.Ticker + watcher *ConfigWatcher + done chan struct{} + stopOnce sync.Once + minersMu sync.RWMutex + miners map[int64]*Miner + customDiff *CustomDiff + rateLimit *RateLimiter + httpServer *http.Server + accessLog *accessLogSink + submitCount atomic.Int64 } // Splitter is the interface both NonceSplitter and SimpleSplitter satisfy. diff --git a/ratelimit_test.go b/ratelimit_test.go index b310dd7..6f10f00 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -1,6 +1,9 @@ package proxy -import "testing" +import ( + "testing" + "time" +) func TestRateLimiter_Allow(t *testing.T) { rl := NewRateLimiter(RateLimit{MaxConnectionsPerMinute: 1, BanDurationSeconds: 1}) @@ -11,3 +14,17 @@ func TestRateLimiter_Allow(t *testing.T) { t.Fatalf("expected second call to fail") } } + +func TestRateLimiter_Allow_ReplenishesHighLimits(t *testing.T) { + rl := NewRateLimiter(RateLimit{MaxConnectionsPerMinute: 120, BanDurationSeconds: 1}) + rl.mu.Lock() + rl.buckets["1.2.3.4"] = &tokenBucket{ + tokens: 0, + lastRefill: time.Now().Add(-30 * time.Second), + } + rl.mu.Unlock() + + if !rl.Allow("1.2.3.4:1234") { + t.Fatalf("expected bucket to replenish at 120/min") + } +} diff --git a/reload_test.go b/reload_test.go index 4a8cf43..07f7322 100644 --- a/reload_test.go +++ b/reload_test.go @@ -56,3 +56,30 @@ func TestProxy_Reload_Good(t *testing.T) { t.Fatalf("expected rate limiter to be replaced with active configuration") } } + +func TestProxy_Reload_UpdatesServers(t *testing.T) { + originalLimiter := NewRateLimiter(RateLimit{MaxConnectionsPerMinute: 1}) + p := &Proxy{ + config: &Config{Mode: "nicehash", Workers: WorkersByRigID}, + rateLimit: originalLimiter, + servers: []*Server{ + {limiter: originalLimiter}, + }, + } + + p.Reload(&Config{ + Mode: "nicehash", + Workers: WorkersByRigID, + Bind: []BindAddr{{Host: "127.0.0.1", Port: 3333}}, + Pools: []PoolConfig{{URL: "pool.example:3333", Enabled: true}}, + RateLimit: RateLimit{MaxConnectionsPerMinute: 10}, + AccessLogFile: "", + }) + + if got := p.servers[0].limiter; got != p.rateLimit { + t.Fatalf("expected server limiter to be updated") + } + if p.rateLimit == originalLimiter { + t.Fatalf("expected rate limiter instance to be replaced") + } +} diff --git a/state_impl.go b/state_impl.go index 39d8768..0d03c08 100644 --- a/state_impl.go +++ b/state_impl.go @@ -166,10 +166,23 @@ func (p *Proxy) Start() { } for _, bind := range p.config.Bind { var tlsCfg *tls.Config - if bind.TLS && p.config.TLS.Enabled { - tlsCfg = buildTLSConfig(p.config.TLS) + if bind.TLS { + if !p.config.TLS.Enabled { + p.Stop() + return + } + var result Result + tlsCfg, result = buildTLSConfig(p.config.TLS) + if !result.OK { + p.Stop() + return + } + } + server, result := NewServer(bind, tlsCfg, p.rateLimit, p.acceptMiner) + if !result.OK { + p.Stop() + return } - server, _ := NewServer(bind, tlsCfg, p.rateLimit, p.acceptMiner) p.servers = append(p.servers, server) server.Start() } @@ -234,6 +247,10 @@ func (p *Proxy) Stop() { defer cancel() _ = p.httpServer.Shutdown(ctx) } + deadline := time.Now().Add(5 * time.Second) + for p.submitCount.Load() > 0 && time.Now().Before(deadline) { + time.Sleep(10 * time.Millisecond) + } if p.accessLog != nil { p.accessLog.Close() } @@ -262,6 +279,11 @@ func (p *Proxy) Reload(cfg *Config) { p.customDiff.globalDiff = cfg.CustomDiff } p.rateLimit = NewRateLimiter(cfg.RateLimit) + for _, server := range p.servers { + if server != nil { + server.limiter = p.rateLimit + } + } if p.accessLog != nil { p.accessLog.SetPath(cfg.AccessLogFile) } @@ -289,6 +311,8 @@ func (p *Proxy) acceptMiner(conn net.Conn, localPort uint16) { } } miner.onSubmit = func(m *Miner, event *SubmitEvent) { + p.submitCount.Add(1) + defer p.submitCount.Add(-1) if p.splitter != nil { p.splitter.OnSubmit(event) } @@ -310,18 +334,21 @@ func (p *Proxy) acceptMiner(conn net.Conn, localPort uint16) { miner.Start() } -func buildTLSConfig(cfg TLSConfig) *tls.Config { - if !cfg.Enabled || cfg.CertFile == "" || cfg.KeyFile == "" { - return nil +func buildTLSConfig(cfg TLSConfig) (*tls.Config, Result) { + if !cfg.Enabled { + return nil, successResult() + } + if cfg.CertFile == "" || cfg.KeyFile == "" { + return nil, errorResult(errors.New("tls certificate or key path is empty")) } cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile) if err != nil { - return nil + return nil, errorResult(err) } tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}} applyTLSProtocols(tlsConfig, cfg.Protocols) applyTLSCiphers(tlsConfig, cfg.Ciphers) - return tlsConfig + return tlsConfig, successResult() } func applyTLSProtocols(tlsConfig *tls.Config, protocols string) { @@ -1344,13 +1371,17 @@ func NewServer(bind BindAddr, tlsCfg *tls.Config, limiter *RateLimiter, onAccept if onAccept == nil { onAccept = func(net.Conn, uint16) {} } - return &Server{ + server := &Server{ addr: bind, tlsCfg: tlsCfg, limiter: limiter, onAccept: onAccept, done: make(chan struct{}), - }, successResult() + } + if result := server.listen(); !result.OK { + return nil, result + } + return server, successResult() } // Start begins accepting connections in a goroutine. @@ -1358,19 +1389,12 @@ func (s *Server) Start() { if s == nil { return } + if result := s.listen(); !result.OK { + return + } go func() { - ln, err := net.Listen("tcp", net.JoinHostPort(s.addr.Host, strconv.Itoa(int(s.addr.Port)))) - if err != nil { - return - } - if s.tlsCfg != nil || s.addr.TLS { - if s.tlsCfg != nil { - ln = tls.NewListener(ln, s.tlsCfg) - } - } - s.listener = ln for { - conn, err := ln.Accept() + conn, err := s.listener.Accept() if err != nil { select { case <-s.done: @@ -1405,6 +1429,27 @@ func (s *Server) Stop() { } } +func (s *Server) listen() Result { + if s == nil { + return errorResult(errors.New("server is nil")) + } + if s.listener != nil { + return successResult() + } + if s.addr.TLS && s.tlsCfg == nil { + return errorResult(errors.New("tls listener requires a tls config")) + } + ln, err := net.Listen("tcp", net.JoinHostPort(s.addr.Host, strconv.Itoa(int(s.addr.Port)))) + if err != nil { + return errorResult(err) + } + if s.tlsCfg != nil { + ln = tls.NewListener(ln, s.tlsCfg) + } + s.listener = ln + return successResult() +} + // NewConfig returns a minimal config? not used. // NewRateLimiter, Allow, Tick are defined in core_impl.go.