From bee8458a2d12bdd42079fcb2c4ca88ebeafe305b Mon Sep 17 00:00:00 2001 From: nya1 Date: Wed, 31 Aug 2022 12:06:14 +0200 Subject: [feature] add rate limit middleware (#741) * feat: add rate limit middleware * chore: update vendor dir * chore: update readme with new dependency * chore: add rate limit infos to swagger.md file * refactor: add ipv6 mask limiter option Add IPv6 CIDR /64 mask * refactor: increase rate limit to 1000 Address https://github.com/superseriousbusiness/gotosocial/pull/741#discussion_r945584800 Co-authored-by: tobi <31960611+tsmethurst@users.noreply.github.com> --- .../limiter/v3/drivers/store/common/context.go | 28 +++ .../ulule/limiter/v3/drivers/store/memory/cache.go | 240 +++++++++++++++++++++ .../ulule/limiter/v3/drivers/store/memory/store.go | 82 +++++++ 3 files changed, 350 insertions(+) create mode 100644 vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go create mode 100644 vendor/github.com/ulule/limiter/v3/drivers/store/memory/cache.go create mode 100644 vendor/github.com/ulule/limiter/v3/drivers/store/memory/store.go (limited to 'vendor/github.com/ulule/limiter/v3/drivers/store') diff --git a/vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go b/vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go new file mode 100644 index 000000000..d181a460b --- /dev/null +++ b/vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go @@ -0,0 +1,28 @@ +package common + +import ( + "time" + + "github.com/ulule/limiter/v3" +) + +// GetContextFromState generate a new limiter.Context from given state. +func GetContextFromState(now time.Time, rate limiter.Rate, expiration time.Time, count int64) limiter.Context { + limit := rate.Limit + remaining := int64(0) + reached := true + + if count <= limit { + remaining = limit - count + reached = false + } + + reset := expiration.Unix() + + return limiter.Context{ + Limit: limit, + Remaining: remaining, + Reset: reset, + Reached: reached, + } +} diff --git a/vendor/github.com/ulule/limiter/v3/drivers/store/memory/cache.go b/vendor/github.com/ulule/limiter/v3/drivers/store/memory/cache.go new file mode 100644 index 000000000..ce9accd9c --- /dev/null +++ b/vendor/github.com/ulule/limiter/v3/drivers/store/memory/cache.go @@ -0,0 +1,240 @@ +package memory + +import ( + "runtime" + "sync" + "time" +) + +// Forked from https://github.com/patrickmn/go-cache + +// CacheWrapper is used to ensure that the underlying cleaner goroutine used to clean expired keys will not prevent +// Cache from being garbage collected. +type CacheWrapper struct { + *Cache +} + +// A cleaner will periodically delete expired keys from cache. +type cleaner struct { + interval time.Duration + stop chan bool +} + +// Run will periodically delete expired keys from given cache until GC notify that it should stop. +func (cleaner *cleaner) Run(cache *Cache) { + ticker := time.NewTicker(cleaner.interval) + for { + select { + case <-ticker.C: + cache.Clean() + case <-cleaner.stop: + ticker.Stop() + return + } + } +} + +// stopCleaner is a callback from GC used to stop cleaner goroutine. +func stopCleaner(wrapper *CacheWrapper) { + wrapper.cleaner.stop <- true + wrapper.cleaner = nil +} + +// startCleaner will start a cleaner goroutine for given cache. +func startCleaner(cache *Cache, interval time.Duration) { + cleaner := &cleaner{ + interval: interval, + stop: make(chan bool), + } + + cache.cleaner = cleaner + go cleaner.Run(cache) +} + +// Counter is a simple counter with an expiration. +type Counter struct { + mutex sync.RWMutex + value int64 + expiration int64 +} + +// Value returns the counter current value. +func (counter *Counter) Value() int64 { + counter.mutex.RLock() + defer counter.mutex.RUnlock() + return counter.value +} + +// Expiration returns the counter expiration. +func (counter *Counter) Expiration() int64 { + counter.mutex.RLock() + defer counter.mutex.RUnlock() + return counter.expiration +} + +// Expired returns true if the counter has expired. +func (counter *Counter) Expired() bool { + counter.mutex.RLock() + defer counter.mutex.RUnlock() + + return counter.expiration == 0 || time.Now().UnixNano() > counter.expiration +} + +// Load returns the value and the expiration of this counter. +// If the counter is expired, it will use the given expiration. +func (counter *Counter) Load(expiration int64) (int64, int64) { + counter.mutex.RLock() + defer counter.mutex.RUnlock() + + if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration { + return 0, expiration + } + + return counter.value, counter.expiration +} + +// Increment increments given value on this counter. +// If the counter is expired, it will use the given expiration. +// It returns its current value and expiration. +func (counter *Counter) Increment(value int64, expiration int64) (int64, int64) { + counter.mutex.Lock() + defer counter.mutex.Unlock() + + if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration { + counter.value = value + counter.expiration = expiration + return counter.value, counter.expiration + } + + counter.value += value + return counter.value, counter.expiration +} + +// Cache contains a collection of counters. +type Cache struct { + counters sync.Map + cleaner *cleaner +} + +// NewCache returns a new cache. +func NewCache(cleanInterval time.Duration) *CacheWrapper { + + cache := &Cache{} + wrapper := &CacheWrapper{Cache: cache} + + if cleanInterval > 0 { + startCleaner(cache, cleanInterval) + runtime.SetFinalizer(wrapper, stopCleaner) + } + + return wrapper +} + +// LoadOrStore returns the existing counter for the key if present. +// Otherwise, it stores and returns the given counter. +// The loaded result is true if the counter was loaded, false if stored. +func (cache *Cache) LoadOrStore(key string, counter *Counter) (*Counter, bool) { + val, loaded := cache.counters.LoadOrStore(key, counter) + if val == nil { + return counter, false + } + + actual := val.(*Counter) + return actual, loaded +} + +// Load returns the counter stored in the map for a key, or nil if no counter is present. +// The ok result indicates whether counter was found in the map. +func (cache *Cache) Load(key string) (*Counter, bool) { + val, ok := cache.counters.Load(key) + if val == nil || !ok { + return nil, false + } + actual := val.(*Counter) + return actual, true +} + +// Store sets the counter for a key. +func (cache *Cache) Store(key string, counter *Counter) { + cache.counters.Store(key, counter) +} + +// Delete deletes the value for a key. +func (cache *Cache) Delete(key string) { + cache.counters.Delete(key) +} + +// Range calls handler sequentially for each key and value present in the cache. +// If handler returns false, range stops the iteration. +func (cache *Cache) Range(handler func(key string, counter *Counter)) { + cache.counters.Range(func(k interface{}, v interface{}) bool { + if v == nil { + return true + } + + key := k.(string) + counter := v.(*Counter) + + handler(key, counter) + + return true + }) +} + +// Increment increments given value on key. +// If key is undefined or expired, it will create it. +func (cache *Cache) Increment(key string, value int64, duration time.Duration) (int64, time.Time) { + expiration := time.Now().Add(duration).UnixNano() + + // If counter is in cache, try to load it first. + counter, loaded := cache.Load(key) + if loaded { + value, expiration = counter.Increment(value, expiration) + return value, time.Unix(0, expiration) + } + + // If it's not in cache, try to atomically create it. + // We do that in two step to reduce memory allocation. + counter, loaded = cache.LoadOrStore(key, &Counter{ + mutex: sync.RWMutex{}, + value: value, + expiration: expiration, + }) + if loaded { + value, expiration = counter.Increment(value, expiration) + return value, time.Unix(0, expiration) + } + + // Otherwise, it has been created, return given value. + return value, time.Unix(0, expiration) +} + +// Get returns key's value and expiration. +func (cache *Cache) Get(key string, duration time.Duration) (int64, time.Time) { + expiration := time.Now().Add(duration).UnixNano() + + counter, ok := cache.Load(key) + if !ok { + return 0, time.Unix(0, expiration) + } + + value, expiration := counter.Load(expiration) + return value, time.Unix(0, expiration) +} + +// Clean will deleted any expired keys. +func (cache *Cache) Clean() { + cache.Range(func(key string, counter *Counter) { + if counter.Expired() { + cache.Delete(key) + } + }) +} + +// Reset changes the key's value and resets the expiration. +func (cache *Cache) Reset(key string, duration time.Duration) (int64, time.Time) { + cache.Delete(key) + + expiration := time.Now().Add(duration).UnixNano() + return 0, time.Unix(0, expiration) +} diff --git a/vendor/github.com/ulule/limiter/v3/drivers/store/memory/store.go b/vendor/github.com/ulule/limiter/v3/drivers/store/memory/store.go new file mode 100644 index 000000000..21b12f6bc --- /dev/null +++ b/vendor/github.com/ulule/limiter/v3/drivers/store/memory/store.go @@ -0,0 +1,82 @@ +package memory + +import ( + "context" + "time" + + "github.com/ulule/limiter/v3" + "github.com/ulule/limiter/v3/drivers/store/common" + "github.com/ulule/limiter/v3/internal/bytebuffer" +) + +// Store is the in-memory store. +type Store struct { + // Prefix used for the key. + Prefix string + // cache used to store values in-memory. + cache *CacheWrapper +} + +// NewStore creates a new instance of memory store with defaults. +func NewStore() limiter.Store { + return NewStoreWithOptions(limiter.StoreOptions{ + Prefix: limiter.DefaultPrefix, + CleanUpInterval: limiter.DefaultCleanUpInterval, + }) +} + +// NewStoreWithOptions creates a new instance of memory store with options. +func NewStoreWithOptions(options limiter.StoreOptions) limiter.Store { + return &Store{ + Prefix: options.Prefix, + cache: NewCache(options.CleanUpInterval), + } +} + +// Get returns the limit for given identifier. +func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { + buffer := bytebuffer.New() + defer buffer.Close() + buffer.Concat(store.Prefix, ":", key) + + count, expiration := store.cache.Increment(buffer.String(), 1, rate.Period) + + lctx := common.GetContextFromState(time.Now(), rate, expiration, count) + return lctx, nil +} + +// Increment increments the limit by given count & returns the new limit value for given identifier. +func (store *Store) Increment(ctx context.Context, key string, count int64, rate limiter.Rate) (limiter.Context, error) { + buffer := bytebuffer.New() + defer buffer.Close() + buffer.Concat(store.Prefix, ":", key) + + newCount, expiration := store.cache.Increment(buffer.String(), count, rate.Period) + + lctx := common.GetContextFromState(time.Now(), rate, expiration, newCount) + return lctx, nil +} + +// Peek returns the limit for given identifier, without modification on current values. +func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { + buffer := bytebuffer.New() + defer buffer.Close() + buffer.Concat(store.Prefix, ":", key) + + count, expiration := store.cache.Get(buffer.String(), rate.Period) + + lctx := common.GetContextFromState(time.Now(), rate, expiration, count) + return lctx, nil +} + +// Reset returns the limit for given identifier. +func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { + buffer := bytebuffer.New() + defer buffer.Close() + buffer.Concat(store.Prefix, ":", key) + + count, expiration := store.cache.Reset(buffer.String(), rate.Period) + + lctx := common.GetContextFromState(time.Now(), rate, expiration, count) + return lctx, nil +} -- cgit v1.3