go-proxy/runtime_support.go

184 lines
4.2 KiB
Go
Raw Normal View History

package proxy
import (
"strconv"
"strings"
"time"
)
// NewRateLimiter returns a per-IP connection limiter that starts with the
// given config and empty bucket and ban tables.
//
//	rl := proxy.NewRateLimiter(cfg.RateLimit)
func NewRateLimiter(config RateLimit) *RateLimiter {
	limiter := &RateLimiter{config: config}
	limiter.buckets = make(map[string]*tokenBucket)
	limiter.banned = make(map[string]time.Time)
	return limiter
}
// SetConfig replaces the limiter's configuration under its lock, e.g. on a
// live reload. A nil receiver is a no-op.
//
//	rl.SetConfig(proxy.RateLimit{MaxConnectionsPerMinute: 30, BanDurationSeconds: 300})
func (rateLimiter *RateLimiter) SetConfig(config RateLimit) {
	if rateLimiter != nil {
		rateLimiter.mu.Lock()
		defer rateLimiter.mu.Unlock()
		rateLimiter.config = config
	}
}
// Allow reports whether the client at ip may open a new connection. It is
// safe for concurrent use (all state is touched under rateLimiter.mu).
//
// ip is reduced to a key by remoteHost (presumably stripping a port from a
// "host:port" address — confirm). A nil limiter, or one whose
// MaxConnectionsPerMinute is <= 0, allows everything. A host with an active
// ban is refused; an expired ban is discarded. Otherwise one token is taken
// from the host's bucket (created full on first sight). When the bucket is
// empty the connection is refused and, if BanDurationSeconds > 0, the host is
// banned for that long.
//
//	if rl.Allow(conn.RemoteAddr().String()) { proceed() }
func (rateLimiter *RateLimiter) Allow(ip string) bool {
	if rateLimiter == nil {
		return true
	}
	key := remoteHost(ip)
	current := time.Now().UTC()

	rateLimiter.mu.Lock()
	defer rateLimiter.mu.Unlock()

	limit := rateLimiter.config.MaxConnectionsPerMinute
	if limit <= 0 {
		return true
	}

	if until, found := rateLimiter.banned[key]; found {
		if until.After(current) {
			return false
		}
		delete(rateLimiter.banned, key)
	}

	bucket, found := rateLimiter.buckets[key]
	if !found {
		bucket = &tokenBucket{tokens: limit, lastRefill: current}
		rateLimiter.buckets[key] = bucket
	}
	rateLimiter.refillBucket(bucket, current)

	if bucket.tokens > 0 {
		bucket.tokens--
		return true
	}
	if banFor := rateLimiter.config.BanDurationSeconds; banFor > 0 {
		rateLimiter.banned[key] = current.Add(time.Duration(banFor) * time.Second)
	}
	return false
}
// Tick performs periodic housekeeping and is intended to be called about once
// per second. It removes expired ban entries and, when rate limiting is
// enabled, refills every token bucket and forgets buckets that are back at
// full capacity. A nil receiver is a no-op.
//
//	rl.Tick()
func (rateLimiter *RateLimiter) Tick() {
	if rateLimiter == nil {
		return
	}
	now := time.Now().UTC()
	rateLimiter.mu.Lock()
	defer rateLimiter.mu.Unlock()

	// Expired bans are swept even while limiting is disabled so stale entries
	// are not retained indefinitely; Allow would discard them lazily anyway,
	// so this does not change which connections are admitted.
	for host, bannedUntil := range rateLimiter.banned {
		if !bannedUntil.After(now) {
			delete(rateLimiter.banned, host)
		}
	}

	limit := rateLimiter.config.MaxConnectionsPerMinute
	if limit <= 0 {
		return
	}
	for host, bucket := range rateLimiter.buckets {
		rateLimiter.refillBucket(bucket, now)
		// A full bucket holds no state beyond what Allow recreates for an
		// unseen host (tokens = MaxConnectionsPerMinute), so drop it. Without
		// this the map grows with every distinct client address ever seen.
		// Deleting during range is permitted by the Go spec.
		if bucket == nil || bucket.tokens >= limit {
			delete(rateLimiter.buckets, host)
		}
	}
}
// refillBucket credits bucket with one token per whole refill interval
// (one minute / MaxConnectionsPerMinute, at least 1ns-resolution, falling back
// to one second if that rounds to zero) elapsed since bucket.lastRefill. The
// balance is capped at MaxConnectionsPerMinute. It is a no-op for a nil bucket
// or when limiting is disabled. Callers in this file hold rateLimiter.mu.
func (rateLimiter *RateLimiter) refillBucket(bucket *tokenBucket, now time.Time) {
	limit := rateLimiter.config.MaxConnectionsPerMinute
	if bucket == nil || limit <= 0 {
		return
	}
	// Cap unconditionally so a live reload that lowers the limit takes effect
	// immediately instead of only after the next full refill interval.
	if bucket.tokens > limit {
		bucket.tokens = limit
	}
	refillEvery := time.Minute / time.Duration(limit)
	if refillEvery <= 0 {
		refillEvery = time.Second
	}
	elapsed := now.Sub(bucket.lastRefill)
	if elapsed < 0 {
		// The wall clock stepped backwards (time.Now().UTC() carries no
		// monotonic reading). Re-anchor rather than stalling refills until
		// the clock catches up with lastRefill again.
		bucket.lastRefill = now
		return
	}
	if elapsed < refillEvery {
		return
	}
	intervals := int64(elapsed / refillEvery)
	// Clamp before adding so huge gaps cannot overflow the token count.
	if room := int64(limit - bucket.tokens); intervals >= room {
		bucket.tokens = limit
	} else {
		bucket.tokens += int(intervals)
	}
	// Advance by whole intervals only, keeping any partial interval credited.
	bucket.lastRefill = bucket.lastRefill.Add(time.Duration(intervals) * refillEvery)
}
// NewCustomDiff returns a CustomDiff whose global default difficulty is
// globalDiff.
//
//	cd := proxy.NewCustomDiff(50000)
func NewCustomDiff(globalDiff uint64) *CustomDiff {
	customDiff := new(CustomDiff)
	customDiff.defaultDifficulty = globalDiff
	return customDiff
}
// SetGlobalDiff replaces the global default difficulty under the write lock.
// A nil receiver is a no-op.
//
//	cd.SetGlobalDiff(100000)
func (customDiff *CustomDiff) SetGlobalDiff(globalDiff uint64) {
	if customDiff != nil {
		customDiff.mu.Lock()
		defer customDiff.mu.Unlock()
		customDiff.defaultDifficulty = globalDiff
	}
}
// OnLogin assigns the logging-in miner's custom difficulty.
//
// A user of the form `WALLET+50000` is split: the miner's user becomes
// `WALLET` and its custom difficulty becomes 50000. The form requires a
// non-empty wallet, the last '+' in the string, and a base-10 unsigned 64-bit
// number after it. Every other user is left unchanged and receives the global
// default difficulty. That default is 0 when none is set or customDiff is nil.
// Events without a miner are ignored.
//
//	cd.OnLogin(proxy.Event{Miner: miner})
func (customDiff *CustomDiff) OnLogin(event Event) {
	if event.Miner == nil {
		return
	}
	user := event.Miner.User()
	if index := strings.LastIndex(user, "+"); index > 0 && index < len(user)-1 {
		if value, err := strconv.ParseUint(user[index+1:], 10, 64); err == nil {
			event.Miner.SetUser(user[:index])
			event.Miner.SetCustomDiff(value)
			return
		}
		// A suffix that is not a valid difficulty (e.g. `WALLET+rig1`) is
		// treated like a plain login and gets the global default, instead of
		// forcing 0 — matching the no-suffix and trailing-'+' cases.
	}
	var globalDiff uint64
	if customDiff != nil {
		customDiff.mu.RLock()
		globalDiff = customDiff.defaultDifficulty
		customDiff.mu.RUnlock()
	}
	event.Miner.SetCustomDiff(globalDiff)
}