summaryrefslogtreecommitdiff
path: root/vendor/github.com/ulule/limiter/v3/drivers
diff options
context:
space:
mode:
authorLibravatar nya1 <nya1git@imap.cc>2022-08-31 12:06:14 +0200
committerLibravatar GitHub <noreply@github.com>2022-08-31 12:06:14 +0200
commitbee8458a2d12bdd42079fcb2c4ca88ebeafe305b (patch)
treec114acf28a460c1ebaa85965c89f2e7fb540eecc /vendor/github.com/ulule/limiter/v3/drivers
parent[feature] Sort follow requests, followers, and following by updated_at (#774) (diff)
downloadgotosocial-bee8458a2d12bdd42079fcb2c4ca88ebeafe305b.tar.xz
[feature] add rate limit middleware (#741)
* feat: add rate limit middleware * chore: update vendor dir * chore: update readme with new dependency * chore: add rate limit infos to swagger.md file * refactor: add ipv6 mask limiter option Add IPv6 CIDR /64 mask * refactor: increase rate limit to 1000 Address https://github.com/superseriousbusiness/gotosocial/pull/741#discussion_r945584800 Co-authored-by: tobi <31960611+tsmethurst@users.noreply.github.com>
Diffstat (limited to 'vendor/github.com/ulule/limiter/v3/drivers')
-rw-r--r--vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/middleware.go65
-rw-r--r--vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/options.go71
-rw-r--r--vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go28
-rw-r--r--vendor/github.com/ulule/limiter/v3/drivers/store/memory/cache.go240
-rw-r--r--vendor/github.com/ulule/limiter/v3/drivers/store/memory/store.go82
5 files changed, 486 insertions, 0 deletions
diff --git a/vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/middleware.go b/vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/middleware.go
new file mode 100644
index 000000000..23bad417a
--- /dev/null
+++ b/vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/middleware.go
@@ -0,0 +1,65 @@
+package gin
+
+import (
+ "strconv"
+
+ "github.com/gin-gonic/gin"
+
+ "github.com/ulule/limiter/v3"
+)
+
+// Middleware is the middleware for gin.
+type Middleware struct {
+ Limiter *limiter.Limiter
+ OnError ErrorHandler
+ OnLimitReached LimitReachedHandler
+ KeyGetter KeyGetter
+ ExcludedKey func(string) bool
+}
+
+// NewMiddleware return a new instance of a gin middleware.
+func NewMiddleware(limiter *limiter.Limiter, options ...Option) gin.HandlerFunc {
+ middleware := &Middleware{
+ Limiter: limiter,
+ OnError: DefaultErrorHandler,
+ OnLimitReached: DefaultLimitReachedHandler,
+ KeyGetter: DefaultKeyGetter,
+ ExcludedKey: nil,
+ }
+
+ for _, option := range options {
+ option.apply(middleware)
+ }
+
+ return func(ctx *gin.Context) {
+ middleware.Handle(ctx)
+ }
+}
+
+// Handle gin request.
+func (middleware *Middleware) Handle(c *gin.Context) {
+ key := middleware.KeyGetter(c)
+ if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) {
+ c.Next()
+ return
+ }
+
+ context, err := middleware.Limiter.Get(c, key)
+ if err != nil {
+ middleware.OnError(c, err)
+ c.Abort()
+ return
+ }
+
+ c.Header("X-RateLimit-Limit", strconv.FormatInt(context.Limit, 10))
+ c.Header("X-RateLimit-Remaining", strconv.FormatInt(context.Remaining, 10))
+ c.Header("X-RateLimit-Reset", strconv.FormatInt(context.Reset, 10))
+
+ if context.Reached {
+ middleware.OnLimitReached(c)
+ c.Abort()
+ return
+ }
+
+ c.Next()
+}
diff --git a/vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/options.go b/vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/options.go
new file mode 100644
index 000000000..604c6bc68
--- /dev/null
+++ b/vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/options.go
@@ -0,0 +1,71 @@
+package gin
+
+import (
+ "net/http"
+
+ "github.com/gin-gonic/gin"
+)
+
+// Option is used to define Middleware configuration.
+type Option interface {
+ apply(*Middleware)
+}
+
+type option func(*Middleware)
+
+func (o option) apply(middleware *Middleware) {
+ o(middleware)
+}
+
+// ErrorHandler is an handler used to inform when an error has occurred.
+type ErrorHandler func(c *gin.Context, err error)
+
+// WithErrorHandler will configure the Middleware to use the given ErrorHandler.
+func WithErrorHandler(handler ErrorHandler) Option {
+ return option(func(middleware *Middleware) {
+ middleware.OnError = handler
+ })
+}
+
+// DefaultErrorHandler is the default ErrorHandler used by a new Middleware.
+func DefaultErrorHandler(c *gin.Context, err error) {
+ panic(err)
+}
+
+// LimitReachedHandler is an handler used to inform when the limit has exceeded.
+type LimitReachedHandler func(c *gin.Context)
+
+// WithLimitReachedHandler will configure the Middleware to use the given LimitReachedHandler.
+func WithLimitReachedHandler(handler LimitReachedHandler) Option {
+ return option(func(middleware *Middleware) {
+ middleware.OnLimitReached = handler
+ })
+}
+
+// DefaultLimitReachedHandler is the default LimitReachedHandler used by a new Middleware.
+func DefaultLimitReachedHandler(c *gin.Context) {
+ c.String(http.StatusTooManyRequests, "Limit exceeded")
+}
+
+// KeyGetter will define the rate limiter key given the gin Context.
+type KeyGetter func(c *gin.Context) string
+
+// WithKeyGetter will configure the Middleware to use the given KeyGetter.
+func WithKeyGetter(handler KeyGetter) Option {
+ return option(func(middleware *Middleware) {
+ middleware.KeyGetter = handler
+ })
+}
+
+// DefaultKeyGetter is the default KeyGetter used by a new Middleware.
+// It returns the Client IP address.
+func DefaultKeyGetter(c *gin.Context) string {
+ return c.ClientIP()
+}
+
+// WithExcludedKey will configure the Middleware to ignore key(s) using the given function.
+func WithExcludedKey(handler func(string) bool) Option {
+ return option(func(middleware *Middleware) {
+ middleware.ExcludedKey = handler
+ })
+}
diff --git a/vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go b/vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go
new file mode 100644
index 000000000..d181a460b
--- /dev/null
+++ b/vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go
@@ -0,0 +1,28 @@
+package common
+
+import (
+ "time"
+
+ "github.com/ulule/limiter/v3"
+)
+
+// GetContextFromState generate a new limiter.Context from given state.
+func GetContextFromState(now time.Time, rate limiter.Rate, expiration time.Time, count int64) limiter.Context {
+ limit := rate.Limit
+ remaining := int64(0)
+ reached := true
+
+ if count <= limit {
+ remaining = limit - count
+ reached = false
+ }
+
+ reset := expiration.Unix()
+
+ return limiter.Context{
+ Limit: limit,
+ Remaining: remaining,
+ Reset: reset,
+ Reached: reached,
+ }
+}
diff --git a/vendor/github.com/ulule/limiter/v3/drivers/store/memory/cache.go b/vendor/github.com/ulule/limiter/v3/drivers/store/memory/cache.go
new file mode 100644
index 000000000..ce9accd9c
--- /dev/null
+++ b/vendor/github.com/ulule/limiter/v3/drivers/store/memory/cache.go
@@ -0,0 +1,240 @@
+package memory
+
+import (
+ "runtime"
+ "sync"
+ "time"
+)
+
+// Forked from https://github.com/patrickmn/go-cache
+
+// CacheWrapper is used to ensure that the underlying cleaner goroutine used to clean expired keys will not prevent
+// Cache from being garbage collected.
+type CacheWrapper struct {
+ *Cache
+}
+
+// A cleaner will periodically delete expired keys from cache.
+type cleaner struct {
+ interval time.Duration
+ stop chan bool
+}
+
+// Run will periodically delete expired keys from given cache until GC notify that it should stop.
+func (cleaner *cleaner) Run(cache *Cache) {
+ ticker := time.NewTicker(cleaner.interval)
+ for {
+ select {
+ case <-ticker.C:
+ cache.Clean()
+ case <-cleaner.stop:
+ ticker.Stop()
+ return
+ }
+ }
+}
+
+// stopCleaner is a callback from GC used to stop cleaner goroutine.
+func stopCleaner(wrapper *CacheWrapper) {
+ wrapper.cleaner.stop <- true
+ wrapper.cleaner = nil
+}
+
+// startCleaner will start a cleaner goroutine for given cache.
+func startCleaner(cache *Cache, interval time.Duration) {
+ cleaner := &cleaner{
+ interval: interval,
+ stop: make(chan bool),
+ }
+
+ cache.cleaner = cleaner
+ go cleaner.Run(cache)
+}
+
+// Counter is a simple counter with an expiration.
+type Counter struct {
+ mutex sync.RWMutex
+ value int64
+ expiration int64
+}
+
+// Value returns the counter current value.
+func (counter *Counter) Value() int64 {
+ counter.mutex.RLock()
+ defer counter.mutex.RUnlock()
+ return counter.value
+}
+
+// Expiration returns the counter expiration.
+func (counter *Counter) Expiration() int64 {
+ counter.mutex.RLock()
+ defer counter.mutex.RUnlock()
+ return counter.expiration
+}
+
+// Expired returns true if the counter has expired.
+func (counter *Counter) Expired() bool {
+ counter.mutex.RLock()
+ defer counter.mutex.RUnlock()
+
+ return counter.expiration == 0 || time.Now().UnixNano() > counter.expiration
+}
+
+// Load returns the value and the expiration of this counter.
+// If the counter is expired, it will use the given expiration.
+func (counter *Counter) Load(expiration int64) (int64, int64) {
+ counter.mutex.RLock()
+ defer counter.mutex.RUnlock()
+
+ if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration {
+ return 0, expiration
+ }
+
+ return counter.value, counter.expiration
+}
+
+// Increment increments given value on this counter.
+// If the counter is expired, it will use the given expiration.
+// It returns its current value and expiration.
+func (counter *Counter) Increment(value int64, expiration int64) (int64, int64) {
+ counter.mutex.Lock()
+ defer counter.mutex.Unlock()
+
+ if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration {
+ counter.value = value
+ counter.expiration = expiration
+ return counter.value, counter.expiration
+ }
+
+ counter.value += value
+ return counter.value, counter.expiration
+}
+
+// Cache contains a collection of counters.
+type Cache struct {
+ counters sync.Map
+ cleaner *cleaner
+}
+
+// NewCache returns a new cache.
+func NewCache(cleanInterval time.Duration) *CacheWrapper {
+
+ cache := &Cache{}
+ wrapper := &CacheWrapper{Cache: cache}
+
+ if cleanInterval > 0 {
+ startCleaner(cache, cleanInterval)
+ runtime.SetFinalizer(wrapper, stopCleaner)
+ }
+
+ return wrapper
+}
+
+// LoadOrStore returns the existing counter for the key if present.
+// Otherwise, it stores and returns the given counter.
+// The loaded result is true if the counter was loaded, false if stored.
+func (cache *Cache) LoadOrStore(key string, counter *Counter) (*Counter, bool) {
+ val, loaded := cache.counters.LoadOrStore(key, counter)
+ if val == nil {
+ return counter, false
+ }
+
+ actual := val.(*Counter)
+ return actual, loaded
+}
+
+// Load returns the counter stored in the map for a key, or nil if no counter is present.
+// The ok result indicates whether counter was found in the map.
+func (cache *Cache) Load(key string) (*Counter, bool) {
+ val, ok := cache.counters.Load(key)
+ if val == nil || !ok {
+ return nil, false
+ }
+ actual := val.(*Counter)
+ return actual, true
+}
+
+// Store sets the counter for a key.
+func (cache *Cache) Store(key string, counter *Counter) {
+ cache.counters.Store(key, counter)
+}
+
+// Delete deletes the value for a key.
+func (cache *Cache) Delete(key string) {
+ cache.counters.Delete(key)
+}
+
+// Range calls handler sequentially for each key and value present in the cache.
+// If handler returns false, range stops the iteration.
+func (cache *Cache) Range(handler func(key string, counter *Counter)) {
+ cache.counters.Range(func(k interface{}, v interface{}) bool {
+ if v == nil {
+ return true
+ }
+
+ key := k.(string)
+ counter := v.(*Counter)
+
+ handler(key, counter)
+
+ return true
+ })
+}
+
+// Increment increments given value on key.
+// If key is undefined or expired, it will create it.
+func (cache *Cache) Increment(key string, value int64, duration time.Duration) (int64, time.Time) {
+ expiration := time.Now().Add(duration).UnixNano()
+
+ // If counter is in cache, try to load it first.
+ counter, loaded := cache.Load(key)
+ if loaded {
+ value, expiration = counter.Increment(value, expiration)
+ return value, time.Unix(0, expiration)
+ }
+
+ // If it's not in cache, try to atomically create it.
+ // We do that in two step to reduce memory allocation.
+ counter, loaded = cache.LoadOrStore(key, &Counter{
+ mutex: sync.RWMutex{},
+ value: value,
+ expiration: expiration,
+ })
+ if loaded {
+ value, expiration = counter.Increment(value, expiration)
+ return value, time.Unix(0, expiration)
+ }
+
+ // Otherwise, it has been created, return given value.
+ return value, time.Unix(0, expiration)
+}
+
+// Get returns key's value and expiration.
+func (cache *Cache) Get(key string, duration time.Duration) (int64, time.Time) {
+ expiration := time.Now().Add(duration).UnixNano()
+
+ counter, ok := cache.Load(key)
+ if !ok {
+ return 0, time.Unix(0, expiration)
+ }
+
+ value, expiration := counter.Load(expiration)
+ return value, time.Unix(0, expiration)
+}
+
+// Clean will deleted any expired keys.
+func (cache *Cache) Clean() {
+ cache.Range(func(key string, counter *Counter) {
+ if counter.Expired() {
+ cache.Delete(key)
+ }
+ })
+}
+
+// Reset changes the key's value and resets the expiration.
+func (cache *Cache) Reset(key string, duration time.Duration) (int64, time.Time) {
+ cache.Delete(key)
+
+ expiration := time.Now().Add(duration).UnixNano()
+ return 0, time.Unix(0, expiration)
+}
diff --git a/vendor/github.com/ulule/limiter/v3/drivers/store/memory/store.go b/vendor/github.com/ulule/limiter/v3/drivers/store/memory/store.go
new file mode 100644
index 000000000..21b12f6bc
--- /dev/null
+++ b/vendor/github.com/ulule/limiter/v3/drivers/store/memory/store.go
@@ -0,0 +1,82 @@
+package memory
+
+import (
+ "context"
+ "time"
+
+ "github.com/ulule/limiter/v3"
+ "github.com/ulule/limiter/v3/drivers/store/common"
+ "github.com/ulule/limiter/v3/internal/bytebuffer"
+)
+
+// Store is the in-memory store.
+type Store struct {
+ // Prefix used for the key.
+ Prefix string
+ // cache used to store values in-memory.
+ cache *CacheWrapper
+}
+
+// NewStore creates a new instance of memory store with defaults.
+func NewStore() limiter.Store {
+ return NewStoreWithOptions(limiter.StoreOptions{
+ Prefix: limiter.DefaultPrefix,
+ CleanUpInterval: limiter.DefaultCleanUpInterval,
+ })
+}
+
+// NewStoreWithOptions creates a new instance of memory store with options.
+func NewStoreWithOptions(options limiter.StoreOptions) limiter.Store {
+ return &Store{
+ Prefix: options.Prefix,
+ cache: NewCache(options.CleanUpInterval),
+ }
+}
+
+// Get returns the limit for given identifier.
+func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
+ buffer := bytebuffer.New()
+ defer buffer.Close()
+ buffer.Concat(store.Prefix, ":", key)
+
+ count, expiration := store.cache.Increment(buffer.String(), 1, rate.Period)
+
+ lctx := common.GetContextFromState(time.Now(), rate, expiration, count)
+ return lctx, nil
+}
+
+// Increment increments the limit by given count & returns the new limit value for given identifier.
+func (store *Store) Increment(ctx context.Context, key string, count int64, rate limiter.Rate) (limiter.Context, error) {
+ buffer := bytebuffer.New()
+ defer buffer.Close()
+ buffer.Concat(store.Prefix, ":", key)
+
+ newCount, expiration := store.cache.Increment(buffer.String(), count, rate.Period)
+
+ lctx := common.GetContextFromState(time.Now(), rate, expiration, newCount)
+ return lctx, nil
+}
+
+// Peek returns the limit for given identifier, without modification on current values.
+func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
+ buffer := bytebuffer.New()
+ defer buffer.Close()
+ buffer.Concat(store.Prefix, ":", key)
+
+ count, expiration := store.cache.Get(buffer.String(), rate.Period)
+
+ lctx := common.GetContextFromState(time.Now(), rate, expiration, count)
+ return lctx, nil
+}
+
+// Reset returns the limit for given identifier.
+func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
+ buffer := bytebuffer.New()
+ defer buffer.Close()
+ buffer.Concat(store.Prefix, ":", key)
+
+ count, expiration := store.cache.Reset(buffer.String(), rate.Period)
+
+ lctx := common.GetContextFromState(time.Now(), rate, expiration, count)
+ return lctx, nil
+}