diff options
author | 2022-08-31 12:06:14 +0200 | |
---|---|---|
committer | 2022-08-31 12:06:14 +0200 | |
commit | bee8458a2d12bdd42079fcb2c4ca88ebeafe305b (patch) | |
tree | c114acf28a460c1ebaa85965c89f2e7fb540eecc /vendor/github.com/ulule/limiter/v3/drivers | |
parent | [feature] Sort follow requests, followers, and following by updated_at (#774) (diff) | |
download | gotosocial-bee8458a2d12bdd42079fcb2c4ca88ebeafe305b.tar.xz |
[feature] add rate limit middleware (#741)
* feat: add rate limit middleware
* chore: update vendor dir
* chore: update readme with new dependency
* chore: add rate limit infos to swagger.md file
* refactor: add ipv6 mask limiter option
Add IPv6 CIDR /64 mask
* refactor: increase rate limit to 1000
Address https://github.com/superseriousbusiness/gotosocial/pull/741#discussion_r945584800
Co-authored-by: tobi <31960611+tsmethurst@users.noreply.github.com>
Diffstat (limited to 'vendor/github.com/ulule/limiter/v3/drivers')
5 files changed, 486 insertions, 0 deletions
diff --git a/vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/middleware.go b/vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/middleware.go new file mode 100644 index 000000000..23bad417a --- /dev/null +++ b/vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/middleware.go @@ -0,0 +1,65 @@ +package gin + +import ( + "strconv" + + "github.com/gin-gonic/gin" + + "github.com/ulule/limiter/v3" +) + +// Middleware is the middleware for gin. +type Middleware struct { + Limiter *limiter.Limiter + OnError ErrorHandler + OnLimitReached LimitReachedHandler + KeyGetter KeyGetter + ExcludedKey func(string) bool +} + +// NewMiddleware return a new instance of a gin middleware. +func NewMiddleware(limiter *limiter.Limiter, options ...Option) gin.HandlerFunc { + middleware := &Middleware{ + Limiter: limiter, + OnError: DefaultErrorHandler, + OnLimitReached: DefaultLimitReachedHandler, + KeyGetter: DefaultKeyGetter, + ExcludedKey: nil, + } + + for _, option := range options { + option.apply(middleware) + } + + return func(ctx *gin.Context) { + middleware.Handle(ctx) + } +} + +// Handle gin request. +func (middleware *Middleware) Handle(c *gin.Context) { + key := middleware.KeyGetter(c) + if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) { + c.Next() + return + } + + context, err := middleware.Limiter.Get(c, key) + if err != nil { + middleware.OnError(c, err) + c.Abort() + return + } + + c.Header("X-RateLimit-Limit", strconv.FormatInt(context.Limit, 10)) + c.Header("X-RateLimit-Remaining", strconv.FormatInt(context.Remaining, 10)) + c.Header("X-RateLimit-Reset", strconv.FormatInt(context.Reset, 10)) + + if context.Reached { + middleware.OnLimitReached(c) + c.Abort() + return + } + + c.Next() +} diff --git a/vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/options.go b/vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/options.go new file mode 100644 index 000000000..604c6bc68 --- /dev/null +++ b/vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/options.go @@ -0,0 +1,71 @@ +package gin + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +// Option is used to define Middleware configuration. +type Option interface { + apply(*Middleware) +} + +type option func(*Middleware) + +func (o option) apply(middleware *Middleware) { + o(middleware) +} + +// ErrorHandler is an handler used to inform when an error has occurred. +type ErrorHandler func(c *gin.Context, err error) + +// WithErrorHandler will configure the Middleware to use the given ErrorHandler. +func WithErrorHandler(handler ErrorHandler) Option { + return option(func(middleware *Middleware) { + middleware.OnError = handler + }) +} + +// DefaultErrorHandler is the default ErrorHandler used by a new Middleware. +func DefaultErrorHandler(c *gin.Context, err error) { + panic(err) +} + +// LimitReachedHandler is an handler used to inform when the limit has exceeded. +type LimitReachedHandler func(c *gin.Context) + +// WithLimitReachedHandler will configure the Middleware to use the given LimitReachedHandler. +func WithLimitReachedHandler(handler LimitReachedHandler) Option { + return option(func(middleware *Middleware) { + middleware.OnLimitReached = handler + }) +} + +// DefaultLimitReachedHandler is the default LimitReachedHandler used by a new Middleware. +func DefaultLimitReachedHandler(c *gin.Context) { + c.String(http.StatusTooManyRequests, "Limit exceeded") +} + +// KeyGetter will define the rate limiter key given the gin Context. +type KeyGetter func(c *gin.Context) string + +// WithKeyGetter will configure the Middleware to use the given KeyGetter. +func WithKeyGetter(handler KeyGetter) Option { + return option(func(middleware *Middleware) { + middleware.KeyGetter = handler + }) +} + +// DefaultKeyGetter is the default KeyGetter used by a new Middleware. +// It returns the Client IP address. +func DefaultKeyGetter(c *gin.Context) string { + return c.ClientIP() +} + +// WithExcludedKey will configure the Middleware to ignore key(s) using the given function. +func WithExcludedKey(handler func(string) bool) Option { + return option(func(middleware *Middleware) { + middleware.ExcludedKey = handler + }) +} diff --git a/vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go b/vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go new file mode 100644 index 000000000..d181a460b --- /dev/null +++ b/vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go @@ -0,0 +1,28 @@ +package common + +import ( + "time" + + "github.com/ulule/limiter/v3" +) + +// GetContextFromState generate a new limiter.Context from given state. +func GetContextFromState(now time.Time, rate limiter.Rate, expiration time.Time, count int64) limiter.Context { + limit := rate.Limit + remaining := int64(0) + reached := true + + if count <= limit { + remaining = limit - count + reached = false + } + + reset := expiration.Unix() + + return limiter.Context{ + Limit: limit, + Remaining: remaining, + Reset: reset, + Reached: reached, + } +} diff --git a/vendor/github.com/ulule/limiter/v3/drivers/store/memory/cache.go b/vendor/github.com/ulule/limiter/v3/drivers/store/memory/cache.go new file mode 100644 index 000000000..ce9accd9c --- /dev/null +++ b/vendor/github.com/ulule/limiter/v3/drivers/store/memory/cache.go @@ -0,0 +1,240 @@ +package memory + +import ( + "runtime" + "sync" + "time" +) + +// Forked from https://github.com/patrickmn/go-cache + +// CacheWrapper is used to ensure that the underlying cleaner goroutine used to clean expired keys will not prevent +// Cache from being garbage collected. +type CacheWrapper struct { + *Cache +} + +// A cleaner will periodically delete expired keys from cache. +type cleaner struct { + interval time.Duration + stop chan bool +} + +// Run will periodically delete expired keys from given cache until GC notify that it should stop. +func (cleaner *cleaner) Run(cache *Cache) { + ticker := time.NewTicker(cleaner.interval) + for { + select { + case <-ticker.C: + cache.Clean() + case <-cleaner.stop: + ticker.Stop() + return + } + } +} + +// stopCleaner is a callback from GC used to stop cleaner goroutine. +func stopCleaner(wrapper *CacheWrapper) { + wrapper.cleaner.stop <- true + wrapper.cleaner = nil +} + +// startCleaner will start a cleaner goroutine for given cache. +func startCleaner(cache *Cache, interval time.Duration) { + cleaner := &cleaner{ + interval: interval, + stop: make(chan bool), + } + + cache.cleaner = cleaner + go cleaner.Run(cache) +} + +// Counter is a simple counter with an expiration. +type Counter struct { + mutex sync.RWMutex + value int64 + expiration int64 +} + +// Value returns the counter current value. +func (counter *Counter) Value() int64 { + counter.mutex.RLock() + defer counter.mutex.RUnlock() + return counter.value +} + +// Expiration returns the counter expiration. +func (counter *Counter) Expiration() int64 { + counter.mutex.RLock() + defer counter.mutex.RUnlock() + return counter.expiration +} + +// Expired returns true if the counter has expired. +func (counter *Counter) Expired() bool { + counter.mutex.RLock() + defer counter.mutex.RUnlock() + + return counter.expiration == 0 || time.Now().UnixNano() > counter.expiration +} + +// Load returns the value and the expiration of this counter. +// If the counter is expired, it will use the given expiration. +func (counter *Counter) Load(expiration int64) (int64, int64) { + counter.mutex.RLock() + defer counter.mutex.RUnlock() + + if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration { + return 0, expiration + } + + return counter.value, counter.expiration +} + +// Increment increments given value on this counter. +// If the counter is expired, it will use the given expiration. +// It returns its current value and expiration. +func (counter *Counter) Increment(value int64, expiration int64) (int64, int64) { + counter.mutex.Lock() + defer counter.mutex.Unlock() + + if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration { + counter.value = value + counter.expiration = expiration + return counter.value, counter.expiration + } + + counter.value += value + return counter.value, counter.expiration +} + +// Cache contains a collection of counters. +type Cache struct { + counters sync.Map + cleaner *cleaner +} + +// NewCache returns a new cache. +func NewCache(cleanInterval time.Duration) *CacheWrapper { + + cache := &Cache{} + wrapper := &CacheWrapper{Cache: cache} + + if cleanInterval > 0 { + startCleaner(cache, cleanInterval) + runtime.SetFinalizer(wrapper, stopCleaner) + } + + return wrapper +} + +// LoadOrStore returns the existing counter for the key if present. +// Otherwise, it stores and returns the given counter. +// The loaded result is true if the counter was loaded, false if stored. +func (cache *Cache) LoadOrStore(key string, counter *Counter) (*Counter, bool) { + val, loaded := cache.counters.LoadOrStore(key, counter) + if val == nil { + return counter, false + } + + actual := val.(*Counter) + return actual, loaded +} + +// Load returns the counter stored in the map for a key, or nil if no counter is present. +// The ok result indicates whether counter was found in the map. +func (cache *Cache) Load(key string) (*Counter, bool) { + val, ok := cache.counters.Load(key) + if val == nil || !ok { + return nil, false + } + actual := val.(*Counter) + return actual, true +} + +// Store sets the counter for a key. +func (cache *Cache) Store(key string, counter *Counter) { + cache.counters.Store(key, counter) +} + +// Delete deletes the value for a key. +func (cache *Cache) Delete(key string) { + cache.counters.Delete(key) +} + +// Range calls handler sequentially for each key and value present in the cache. +// If handler returns false, range stops the iteration. +func (cache *Cache) Range(handler func(key string, counter *Counter)) { + cache.counters.Range(func(k interface{}, v interface{}) bool { + if v == nil { + return true + } + + key := k.(string) + counter := v.(*Counter) + + handler(key, counter) + + return true + }) +} + +// Increment increments given value on key. +// If key is undefined or expired, it will create it. +func (cache *Cache) Increment(key string, value int64, duration time.Duration) (int64, time.Time) { + expiration := time.Now().Add(duration).UnixNano() + + // If counter is in cache, try to load it first. + counter, loaded := cache.Load(key) + if loaded { + value, expiration = counter.Increment(value, expiration) + return value, time.Unix(0, expiration) + } + + // If it's not in cache, try to atomically create it. + // We do that in two step to reduce memory allocation. + counter, loaded = cache.LoadOrStore(key, &Counter{ + mutex: sync.RWMutex{}, + value: value, + expiration: expiration, + }) + if loaded { + value, expiration = counter.Increment(value, expiration) + return value, time.Unix(0, expiration) + } + + // Otherwise, it has been created, return given value. + return value, time.Unix(0, expiration) +} + +// Get returns key's value and expiration. +func (cache *Cache) Get(key string, duration time.Duration) (int64, time.Time) { + expiration := time.Now().Add(duration).UnixNano() + + counter, ok := cache.Load(key) + if !ok { + return 0, time.Unix(0, expiration) + } + + value, expiration := counter.Load(expiration) + return value, time.Unix(0, expiration) +} + +// Clean will deleted any expired keys. +func (cache *Cache) Clean() { + cache.Range(func(key string, counter *Counter) { + if counter.Expired() { + cache.Delete(key) + } + }) +} + +// Reset changes the key's value and resets the expiration. +func (cache *Cache) Reset(key string, duration time.Duration) (int64, time.Time) { + cache.Delete(key) + + expiration := time.Now().Add(duration).UnixNano() + return 0, time.Unix(0, expiration) +} diff --git a/vendor/github.com/ulule/limiter/v3/drivers/store/memory/store.go b/vendor/github.com/ulule/limiter/v3/drivers/store/memory/store.go new file mode 100644 index 000000000..21b12f6bc --- /dev/null +++ b/vendor/github.com/ulule/limiter/v3/drivers/store/memory/store.go @@ -0,0 +1,82 @@ +package memory + +import ( + "context" + "time" + + "github.com/ulule/limiter/v3" + "github.com/ulule/limiter/v3/drivers/store/common" + "github.com/ulule/limiter/v3/internal/bytebuffer" +) + +// Store is the in-memory store. +type Store struct { + // Prefix used for the key. + Prefix string + // cache used to store values in-memory. + cache *CacheWrapper +} + +// NewStore creates a new instance of memory store with defaults. +func NewStore() limiter.Store { + return NewStoreWithOptions(limiter.StoreOptions{ + Prefix: limiter.DefaultPrefix, + CleanUpInterval: limiter.DefaultCleanUpInterval, + }) +} + +// NewStoreWithOptions creates a new instance of memory store with options. +func NewStoreWithOptions(options limiter.StoreOptions) limiter.Store { + return &Store{ + Prefix: options.Prefix, + cache: NewCache(options.CleanUpInterval), + } +} + +// Get returns the limit for given identifier. +func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { + buffer := bytebuffer.New() + defer buffer.Close() + buffer.Concat(store.Prefix, ":", key) + + count, expiration := store.cache.Increment(buffer.String(), 1, rate.Period) + + lctx := common.GetContextFromState(time.Now(), rate, expiration, count) + return lctx, nil +} + +// Increment increments the limit by given count & returns the new limit value for given identifier. +func (store *Store) Increment(ctx context.Context, key string, count int64, rate limiter.Rate) (limiter.Context, error) { + buffer := bytebuffer.New() + defer buffer.Close() + buffer.Concat(store.Prefix, ":", key) + + newCount, expiration := store.cache.Increment(buffer.String(), count, rate.Period) + + lctx := common.GetContextFromState(time.Now(), rate, expiration, newCount) + return lctx, nil +} + +// Peek returns the limit for given identifier, without modification on current values. +func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { + buffer := bytebuffer.New() + defer buffer.Close() + buffer.Concat(store.Prefix, ":", key) + + count, expiration := store.cache.Get(buffer.String(), rate.Period) + + lctx := common.GetContextFromState(time.Now(), rate, expiration, count) + return lctx, nil +} + +// Reset returns the limit for given identifier. +func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { + buffer := bytebuffer.New() + defer buffer.Close() + buffer.Concat(store.Prefix, ":", key) + + count, expiration := store.cache.Reset(buffer.String(), rate.Period) + + lctx := common.GetContextFromState(time.Now(), rate, expiration, count) + return lctx, nil +} |