go-proxy/runtime_support.go
Virgil 0ab02e9e4b docs(ax): clarify public api examples
Co-Authored-By: Virgil <virgil@lethean.io>
2026-04-04 13:07:29 +00:00

178 lines
4 KiB
Go

package proxy
import (
"strconv"
"strings"
"time"
)
// NewRateLimiter returns a limiter that starts from config with empty
// bucket and ban tables, both keyed by string.
//
//	rl := proxy.NewRateLimiter(cfg.RateLimit)
func NewRateLimiter(config RateLimit) *RateLimiter {
	limiter := new(RateLimiter)
	limiter.cfg = config
	// Both maps are allocated up front so later writes never hit a nil map.
	limiter.buckets = map[string]*tokenBucket{}
	limiter.banned = map[string]time.Time{}
	return limiter
}
// SetConfig replaces the limiter's configuration under its lock, e.g. on a
// live reload. Only the stored config changes; existing buckets and bans are
// not touched here. Calling it on a nil limiter does nothing.
//
//	rl.SetConfig(proxy.RateLimit{MaxConnectionsPerMinute: 30, BanDurationSeconds: 300})
func (rateLimiter *RateLimiter) SetConfig(config RateLimit) {
	if rateLimiter == nil {
		return
	}

	rateLimiter.mu.Lock()
	defer rateLimiter.mu.Unlock()

	rateLimiter.cfg = config
}
// Allow reports whether the peer identified by ip may open a new connection,
// consuming one token from that peer's bucket when it may. The limiter's
// mutex is held for the whole decision.
//
// ip is reduced to a map key by remoteHost (defined elsewhere; presumably it
// strips the port from a host:port address — confirm there).
//
// Allow returns true without any bookkeeping when the limiter is nil or
// MaxConnectionsPerMinute is not positive. It returns false while a ban is
// active, and when the bucket is empty; in the latter case the peer is also
// banned for BanDurationSeconds if that value is positive.
//
//	if rl.Allow(conn.RemoteAddr().String()) { proceed() }
func (rateLimiter *RateLimiter) Allow(ip string) bool {
	if rateLimiter == nil {
		return true
	}

	key := remoteHost(ip)
	current := time.Now().UTC()

	rateLimiter.mu.Lock()
	defer rateLimiter.mu.Unlock()

	limit := rateLimiter.cfg.MaxConnectionsPerMinute
	if limit <= 0 {
		return true
	}

	until, isBanned := rateLimiter.banned[key]
	if isBanned && until.After(current) {
		return false
	}
	// Deleting a missing key is a no-op, so this only drops a lapsed ban.
	delete(rateLimiter.banned, key)

	entry, known := rateLimiter.buckets[key]
	if !known {
		// First sighting of this peer: start with a full allowance.
		entry = &tokenBucket{tokens: limit, lastRefill: current}
		rateLimiter.buckets[key] = entry
	}
	rateLimiter.refillBucket(entry, current)

	if entry.tokens > 0 {
		entry.tokens--
		return true
	}

	// Out of tokens: refuse, and start a ban when one is configured.
	if banSeconds := rateLimiter.cfg.BanDurationSeconds; banSeconds > 0 {
		rateLimiter.banned[key] = current.Add(time.Duration(banSeconds) * time.Second)
	}
	return false
}
// Tick performs periodic housekeeping under the limiter's lock: it drops
// bans whose deadline has passed, refills every token bucket, and evicts
// buckets that are back at (or above) the configured allowance.
//
// Eviction keeps the bucket table from growing by one entry per distinct
// remote host forever: a full bucket carries no information, because Allow
// creates a fresh, full bucket for any host it has no entry for.
//
// Tick is a no-op on a nil limiter and while MaxConnectionsPerMinute is not
// positive. It is intended to be called periodically (every second).
//
//	rl.Tick()
func (rateLimiter *RateLimiter) Tick() {
	if rateLimiter == nil {
		return
	}

	now := time.Now().UTC()

	rateLimiter.mu.Lock()
	defer rateLimiter.mu.Unlock()

	limit := rateLimiter.cfg.MaxConnectionsPerMinute
	if limit <= 0 {
		return
	}

	// Deleting from a map while ranging over it is well-defined in Go.
	for host, bannedUntil := range rateLimiter.banned {
		if !bannedUntil.After(now) {
			delete(rateLimiter.banned, host)
		}
	}

	for host, bucket := range rateLimiter.buckets {
		rateLimiter.refillBucket(bucket, now)
		// A nil entry is unusable and a full one is redundant; dropping
		// either lets Allow rebuild the bucket on the next connection.
		if bucket == nil || bucket.tokens >= limit {
			delete(rateLimiter.buckets, host)
		}
	}
}
// refillBucket credits bucket with the tokens earned between its lastRefill
// and now, one token per (minute / MaxConnectionsPerMinute), capped at
// MaxConnectionsPerMinute. It reads the limiter's config without locking, so
// callers are expected to hold the limiter's mutex already.
//
// lastRefill advances only by whole refill intervals, so the remainder of a
// partially elapsed interval carries over to the next call. Nothing happens
// for a nil bucket, a non-positive limit, or when less than one interval has
// elapsed (including a now that lies before lastRefill).
func (rateLimiter *RateLimiter) refillBucket(bucket *tokenBucket, now time.Time) {
	if bucket == nil {
		return
	}
	capacity := rateLimiter.cfg.MaxConnectionsPerMinute
	if capacity <= 0 {
		return
	}

	interval := time.Minute / time.Duration(capacity)
	if interval <= 0 {
		// The integer division rounded down to zero; fall back to one second.
		interval = time.Second
	}

	// Number of complete intervals since the last refill.
	whole := now.Sub(bucket.lastRefill) / interval
	if whole < 1 {
		return
	}

	gained := int(whole)
	total := bucket.tokens + gained
	if total > capacity {
		total = capacity
	}
	bucket.tokens = total
	bucket.lastRefill = bucket.lastRefill.Add(time.Duration(gained) * interval)
}
// NewCustomDiff returns a CustomDiff whose global difficulty override is
// globalDiff.
//
//	cd := proxy.NewCustomDiff(50000)
func NewCustomDiff(globalDiff uint64) *CustomDiff {
	created := new(CustomDiff)
	created.globalDiff = globalDiff
	return created
}
// SetGlobalDiff replaces the global difficulty override under the write
// lock. Calling it on a nil CustomDiff does nothing.
//
//	cd.SetGlobalDiff(100000)
func (customDiff *CustomDiff) SetGlobalDiff(globalDiff uint64) {
	if customDiff == nil {
		return
	}

	customDiff.mu.Lock()
	defer customDiff.mu.Unlock()

	customDiff.globalDiff = globalDiff
}
// OnLogin applies a custom difficulty to the miner carried by event.
//
// When the miner's user has the form `WALLET+50000` (a "+" that is neither
// the first nor the last byte, followed by a base-10 unsigned integer), the
// user is rewritten to `WALLET` and the custom difficulty is set to 50000.
// When there is no such "+", the global override is applied if it is
// non-zero and the receiver is non-nil. An event without a miner is ignored.
//
//	cd.OnLogin(proxy.Event{Miner: miner})
func (customDiff *CustomDiff) OnLogin(event Event) {
	miner := event.Miner
	if miner == nil {
		return
	}

	login := miner.User()
	if plus := strings.LastIndex(login, "+"); plus > 0 && plus+1 < len(login) {
		parsed, err := strconv.ParseUint(login[plus+1:], 10, 64)
		if err != nil {
			// NOTE(review): a non-numeric suffix leaves the miner untouched
			// and also skips the global default, whereas `WALLET+` (trailing
			// plus) does receive it — confirm this asymmetry is intended.
			return
		}
		miner.SetUser(login[:plus])
		miner.SetCustomDiff(parsed)
		return
	}

	// The suffix path above never touches the receiver, so it works even on
	// a nil CustomDiff; only the global fallback needs one.
	if customDiff == nil {
		return
	}

	customDiff.mu.RLock()
	fallback := customDiff.globalDiff
	customDiff.mu.RUnlock()

	if fallback > 0 {
		miner.SetCustomDiff(fallback)
	}
}