fix(proxy): tighten listener and limiter lifecycle
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
31a8ba558f
commit
bc67e73ca0
5 changed files with 131 additions and 40 deletions
|
|
@ -332,9 +332,9 @@ func refillBucket(bucket *tokenBucket, limit int, now time.Time) {
|
|||
}
|
||||
return
|
||||
}
|
||||
interval := time.Duration(60/limit) * time.Second
|
||||
interval := time.Duration(time.Minute) / time.Duration(limit)
|
||||
if interval <= 0 {
|
||||
interval = time.Second
|
||||
interval = time.Nanosecond
|
||||
}
|
||||
elapsed := now.Sub(bucket.lastRefill)
|
||||
if elapsed < interval {
|
||||
|
|
|
|||
34
proxy.go
34
proxy.go
|
|
@ -13,6 +13,7 @@ package proxy
|
|||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
|
@ -22,22 +23,23 @@ import (
|
|||
// p, result := proxy.New(cfg)
|
||||
// if result.OK { p.Start() }
|
||||
type Proxy struct {
|
||||
config *Config
|
||||
splitter Splitter
|
||||
stats *Stats
|
||||
workers *Workers
|
||||
events *EventBus
|
||||
servers []*Server
|
||||
ticker *time.Ticker
|
||||
watcher *ConfigWatcher
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
minersMu sync.RWMutex
|
||||
miners map[int64]*Miner
|
||||
customDiff *CustomDiff
|
||||
rateLimit *RateLimiter
|
||||
httpServer *http.Server
|
||||
accessLog *accessLogSink
|
||||
config *Config
|
||||
splitter Splitter
|
||||
stats *Stats
|
||||
workers *Workers
|
||||
events *EventBus
|
||||
servers []*Server
|
||||
ticker *time.Ticker
|
||||
watcher *ConfigWatcher
|
||||
done chan struct{}
|
||||
stopOnce sync.Once
|
||||
minersMu sync.RWMutex
|
||||
miners map[int64]*Miner
|
||||
customDiff *CustomDiff
|
||||
rateLimit *RateLimiter
|
||||
httpServer *http.Server
|
||||
accessLog *accessLogSink
|
||||
submitCount atomic.Int64
|
||||
}
|
||||
|
||||
// Splitter is the interface both NonceSplitter and SimpleSplitter satisfy.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
package proxy
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRateLimiter_Allow(t *testing.T) {
|
||||
rl := NewRateLimiter(RateLimit{MaxConnectionsPerMinute: 1, BanDurationSeconds: 1})
|
||||
|
|
@ -11,3 +14,17 @@ func TestRateLimiter_Allow(t *testing.T) {
|
|||
t.Fatalf("expected second call to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiter_Allow_ReplenishesHighLimits(t *testing.T) {
|
||||
rl := NewRateLimiter(RateLimit{MaxConnectionsPerMinute: 120, BanDurationSeconds: 1})
|
||||
rl.mu.Lock()
|
||||
rl.buckets["1.2.3.4"] = &tokenBucket{
|
||||
tokens: 0,
|
||||
lastRefill: time.Now().Add(-30 * time.Second),
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
|
||||
if !rl.Allow("1.2.3.4:1234") {
|
||||
t.Fatalf("expected bucket to replenish at 120/min")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,3 +56,30 @@ func TestProxy_Reload_Good(t *testing.T) {
|
|||
t.Fatalf("expected rate limiter to be replaced with active configuration")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxy_Reload_UpdatesServers(t *testing.T) {
|
||||
originalLimiter := NewRateLimiter(RateLimit{MaxConnectionsPerMinute: 1})
|
||||
p := &Proxy{
|
||||
config: &Config{Mode: "nicehash", Workers: WorkersByRigID},
|
||||
rateLimit: originalLimiter,
|
||||
servers: []*Server{
|
||||
{limiter: originalLimiter},
|
||||
},
|
||||
}
|
||||
|
||||
p.Reload(&Config{
|
||||
Mode: "nicehash",
|
||||
Workers: WorkersByRigID,
|
||||
Bind: []BindAddr{{Host: "127.0.0.1", Port: 3333}},
|
||||
Pools: []PoolConfig{{URL: "pool.example:3333", Enabled: true}},
|
||||
RateLimit: RateLimit{MaxConnectionsPerMinute: 10},
|
||||
AccessLogFile: "",
|
||||
})
|
||||
|
||||
if got := p.servers[0].limiter; got != p.rateLimit {
|
||||
t.Fatalf("expected server limiter to be updated")
|
||||
}
|
||||
if p.rateLimit == originalLimiter {
|
||||
t.Fatalf("expected rate limiter instance to be replaced")
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -166,10 +166,23 @@ func (p *Proxy) Start() {
|
|||
}
|
||||
for _, bind := range p.config.Bind {
|
||||
var tlsCfg *tls.Config
|
||||
if bind.TLS && p.config.TLS.Enabled {
|
||||
tlsCfg = buildTLSConfig(p.config.TLS)
|
||||
if bind.TLS {
|
||||
if !p.config.TLS.Enabled {
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
var result Result
|
||||
tlsCfg, result = buildTLSConfig(p.config.TLS)
|
||||
if !result.OK {
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
server, result := NewServer(bind, tlsCfg, p.rateLimit, p.acceptMiner)
|
||||
if !result.OK {
|
||||
p.Stop()
|
||||
return
|
||||
}
|
||||
server, _ := NewServer(bind, tlsCfg, p.rateLimit, p.acceptMiner)
|
||||
p.servers = append(p.servers, server)
|
||||
server.Start()
|
||||
}
|
||||
|
|
@ -234,6 +247,10 @@ func (p *Proxy) Stop() {
|
|||
defer cancel()
|
||||
_ = p.httpServer.Shutdown(ctx)
|
||||
}
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
for p.submitCount.Load() > 0 && time.Now().Before(deadline) {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
if p.accessLog != nil {
|
||||
p.accessLog.Close()
|
||||
}
|
||||
|
|
@ -262,6 +279,11 @@ func (p *Proxy) Reload(cfg *Config) {
|
|||
p.customDiff.globalDiff = cfg.CustomDiff
|
||||
}
|
||||
p.rateLimit = NewRateLimiter(cfg.RateLimit)
|
||||
for _, server := range p.servers {
|
||||
if server != nil {
|
||||
server.limiter = p.rateLimit
|
||||
}
|
||||
}
|
||||
if p.accessLog != nil {
|
||||
p.accessLog.SetPath(cfg.AccessLogFile)
|
||||
}
|
||||
|
|
@ -289,6 +311,8 @@ func (p *Proxy) acceptMiner(conn net.Conn, localPort uint16) {
|
|||
}
|
||||
}
|
||||
miner.onSubmit = func(m *Miner, event *SubmitEvent) {
|
||||
p.submitCount.Add(1)
|
||||
defer p.submitCount.Add(-1)
|
||||
if p.splitter != nil {
|
||||
p.splitter.OnSubmit(event)
|
||||
}
|
||||
|
|
@ -310,18 +334,21 @@ func (p *Proxy) acceptMiner(conn net.Conn, localPort uint16) {
|
|||
miner.Start()
|
||||
}
|
||||
|
||||
func buildTLSConfig(cfg TLSConfig) *tls.Config {
|
||||
if !cfg.Enabled || cfg.CertFile == "" || cfg.KeyFile == "" {
|
||||
return nil
|
||||
func buildTLSConfig(cfg TLSConfig) (*tls.Config, Result) {
|
||||
if !cfg.Enabled {
|
||||
return nil, successResult()
|
||||
}
|
||||
if cfg.CertFile == "" || cfg.KeyFile == "" {
|
||||
return nil, errorResult(errors.New("tls certificate or key path is empty"))
|
||||
}
|
||||
cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.KeyFile)
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, errorResult(err)
|
||||
}
|
||||
tlsConfig := &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||
applyTLSProtocols(tlsConfig, cfg.Protocols)
|
||||
applyTLSCiphers(tlsConfig, cfg.Ciphers)
|
||||
return tlsConfig
|
||||
return tlsConfig, successResult()
|
||||
}
|
||||
|
||||
func applyTLSProtocols(tlsConfig *tls.Config, protocols string) {
|
||||
|
|
@ -1344,13 +1371,17 @@ func NewServer(bind BindAddr, tlsCfg *tls.Config, limiter *RateLimiter, onAccept
|
|||
if onAccept == nil {
|
||||
onAccept = func(net.Conn, uint16) {}
|
||||
}
|
||||
return &Server{
|
||||
server := &Server{
|
||||
addr: bind,
|
||||
tlsCfg: tlsCfg,
|
||||
limiter: limiter,
|
||||
onAccept: onAccept,
|
||||
done: make(chan struct{}),
|
||||
}, successResult()
|
||||
}
|
||||
if result := server.listen(); !result.OK {
|
||||
return nil, result
|
||||
}
|
||||
return server, successResult()
|
||||
}
|
||||
|
||||
// Start begins accepting connections in a goroutine.
|
||||
|
|
@ -1358,19 +1389,12 @@ func (s *Server) Start() {
|
|||
if s == nil {
|
||||
return
|
||||
}
|
||||
if result := s.listen(); !result.OK {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(s.addr.Host, strconv.Itoa(int(s.addr.Port))))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.tlsCfg != nil || s.addr.TLS {
|
||||
if s.tlsCfg != nil {
|
||||
ln = tls.NewListener(ln, s.tlsCfg)
|
||||
}
|
||||
}
|
||||
s.listener = ln
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.done:
|
||||
|
|
@ -1405,6 +1429,27 @@ func (s *Server) Stop() {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) listen() Result {
|
||||
if s == nil {
|
||||
return errorResult(errors.New("server is nil"))
|
||||
}
|
||||
if s.listener != nil {
|
||||
return successResult()
|
||||
}
|
||||
if s.addr.TLS && s.tlsCfg == nil {
|
||||
return errorResult(errors.New("tls listener requires a tls config"))
|
||||
}
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(s.addr.Host, strconv.Itoa(int(s.addr.Port))))
|
||||
if err != nil {
|
||||
return errorResult(err)
|
||||
}
|
||||
if s.tlsCfg != nil {
|
||||
ln = tls.NewListener(ln, s.tlsCfg)
|
||||
}
|
||||
s.listener = ln
|
||||
return successResult()
|
||||
}
|
||||
|
||||
// NewConfig returns a minimal config? not used.
|
||||
|
||||
// NewRateLimiter, Allow, Tick are defined in core_impl.go.
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue