summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
authorLibravatar kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com>2023-12-16 11:53:42 +0000
committerLibravatar GitHub <noreply@github.com>2023-12-16 12:53:42 +0100
commitd56a8d095e8fe84422ef4098d1e1a25198da17a1 (patch)
treecb34cb50335098492c863e4630dfd0b8da10d6c5 /internal
parent[docs]: Update FAQ and ROADMAP (#2458) (diff)
downloadgotosocial-d56a8d095e8fe84422ef4098d1e1a25198da17a1.tar.xz
[performance] simpler throttling logic (#2407)
* reduce complexity of throttling logic to use 1 queue and an atomic int * use atomic add instead of CAS, add throttling test
Diffstat (limited to 'internal')
-rw-r--r--internal/api/util/response.go6
-rw-r--r--internal/middleware/ratelimit.go8
-rw-r--r--internal/middleware/throttling.go81
-rw-r--r--internal/middleware/throttling_test.go149
4 files changed, 206 insertions, 38 deletions
diff --git a/internal/api/util/response.go b/internal/api/util/response.go
index e22bac545..150d2ac2e 100644
--- a/internal/api/util/response.go
+++ b/internal/api/util/response.go
@@ -42,6 +42,12 @@ var (
StatusInternalServerErrorJSON = mustJSON(map[string]string{
"status": http.StatusText(http.StatusInternalServerError),
})
+ ErrorCapacityExceeded = mustJSON(map[string]string{
+ "error": "server capacity exceeded!",
+ })
+ ErrorRateLimitReached = mustJSON(map[string]string{
+ "error": "rate limit reached!",
+ })
EmptyJSONObject = mustJSON("{}")
EmptyJSONArray = mustJSON("[]")
diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go
index a59a3e608..57055fe70 100644
--- a/internal/middleware/ratelimit.go
+++ b/internal/middleware/ratelimit.go
@@ -29,6 +29,8 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/ulule/limiter/v3"
"github.com/ulule/limiter/v3/drivers/store/memory"
+
+ apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
)
const rateLimitPeriod = 5 * time.Minute
@@ -141,10 +143,12 @@ func RateLimit(limit int, exceptions []string) gin.HandlerFunc {
if context.Reached {
// Return JSON error message for
// consistency with other endpoints.
- c.AbortWithStatusJSON(
+ apiutil.Data(c,
http.StatusTooManyRequests,
- gin.H{"error": "rate limit reached"},
+ apiutil.AppJSON,
+ apiutil.ErrorRateLimitReached,
)
+ c.Abort()
return
}
diff --git a/internal/middleware/throttling.go b/internal/middleware/throttling.go
index 589671547..33f46f175 100644
--- a/internal/middleware/throttling.go
+++ b/internal/middleware/throttling.go
@@ -29,9 +29,12 @@ import (
"net/http"
"runtime"
"strconv"
+ "sync/atomic"
"time"
"github.com/gin-gonic/gin"
+
+ apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
)
// token represents a request that is being processed.
@@ -80,55 +83,61 @@ func Throttle(cpuMultiplier int, retryAfter time.Duration) gin.HandlerFunc {
}
var (
- limit = runtime.GOMAXPROCS(0) * cpuMultiplier
- backlogLimit = limit * cpuMultiplier
- backlogChannelSize = limit + backlogLimit
- tokens = make(chan token, limit)
- backlogTokens = make(chan token, backlogChannelSize)
- retryAfterStr = strconv.FormatUint(uint64(retryAfter/time.Second), 10)
+ limit = runtime.GOMAXPROCS(0) * cpuMultiplier
+ queueLimit = limit * cpuMultiplier
+ tokens = make(chan token, limit)
+ requestCount = atomic.Int64{}
+ retryAfterStr = strconv.FormatUint(uint64(retryAfter/time.Second), 10)
)
- // prefill token channels
+ // prefill token channel
for i := 0; i < limit; i++ {
tokens <- token{}
}
- for i := 0; i < backlogChannelSize; i++ {
- backlogTokens <- token{}
- }
return func(c *gin.Context) {
- // inside this select, the caller tries to get a backlog token
+ // Always decrement request counter.
+ defer func() { requestCount.Add(-1) }()
+
+ // Increment request count.
+ n := requestCount.Add(1)
+
+ // Check whether the request
+ // count is over queue limit.
+ if n > int64(queueLimit) {
+ c.Header("Retry-After", retryAfterStr)
+ apiutil.Data(c,
+ http.StatusTooManyRequests,
+ apiutil.AppJSON,
+ apiutil.ErrorCapacityExceeded,
+ )
+ c.Abort()
+ return
+ }
+
+ // Sit and wait in the
+ // queue for free token.
select {
+
case <-c.Request.Context().Done():
- // request context has been canceled already
+ // request context has
+ // been canceled already.
return
- case btok := <-backlogTokens:
+
+ case tok := <-tokens:
+ // caller has successfully
+ // received a token, allowing
+ // request to be processed.
+
defer func() {
- // when we're finished, return the backlog token to the bucket
- backlogTokens <- btok
+ // when we're finished, return
+ // this token to the bucket.
+ tokens <- tok
}()
- // inside *this* select, the caller has a backlog token,
- // and they're waiting for their turn to be processed
- select {
- case <-c.Request.Context().Done():
- // the request context has been canceled already
- return
- case tok := <-tokens:
- // the caller gets a token, so their request can now be processed
- defer func() {
- // whatever happens to the request, put the
- // token back in the bucket when we're finished
- tokens <- tok
- }()
- c.Next() // <- finally process the caller's request
- }
-
- default:
- // we don't have space in the backlog queue
- c.Header("Retry-After", retryAfterStr)
- c.JSON(http.StatusTooManyRequests, gin.H{"error": "server capacity exceeded"})
- c.Abort()
+ // Process
+ // request!
+ c.Next()
}
}
}
diff --git a/internal/middleware/throttling_test.go b/internal/middleware/throttling_test.go
new file mode 100644
index 000000000..2a716ec53
--- /dev/null
+++ b/internal/middleware/throttling_test.go
@@ -0,0 +1,149 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+/*
+ The code in this file is adapted from MIT-licensed code in github.com/go-chi/chi. Thanks chi (thi)!
+
+ See: https://github.com/go-chi/chi/blob/e6baba61759b26ddf7b14d1e02d1da81a4d76c08/middleware/throttle.go
+
+ And: https://github.com/sponsors/pkieltyka
+*/
+
+package middleware_test
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "runtime"
+ "strconv"
+ "testing"
+ "time"
+
+ "github.com/gin-gonic/gin"
+ "github.com/superseriousbusiness/gotosocial/internal/middleware"
+)
+
+func TestThrottlingMiddleware(t *testing.T) {
+ testThrottlingMiddleware(t, 2, time.Second*10)
+ testThrottlingMiddleware(t, 4, time.Second*15)
+ testThrottlingMiddleware(t, 8, time.Second*30)
+}
+
+func testThrottlingMiddleware(t *testing.T, cpuMulti int, retryAfter time.Duration) {
+ // Calculate expected request limit + queue.
+ limit := runtime.GOMAXPROCS(0) * cpuMulti
+ queueLimit := limit * cpuMulti
+
+ // Calculate expected retry-after header string.
+ retryAfterStr := strconv.FormatUint(uint64(retryAfter/time.Second), 10)
+
+ // Gin test http engine
+ // (used for ctx init).
+ e := gin.New()
+
+ // Add middleware to the gin engine handler stack.
+ middleware := middleware.Throttle(cpuMulti, retryAfter)
+ e.Use(middleware)
+
+ // Set the blocking gin handler.
+ handler := blockingHandler()
+ e.Handle("GET", "/", handler)
+
+ var cncls []func()
+
+ for i := 0; i < queueLimit+limit; i++ {
+ // Prepare a gin test context.
+ r := httptest.NewRequest("GET", "/", nil)
+ rw := httptest.NewRecorder()
+
+ // Wrap request with new cancel context.
+ ctx, cncl := context.WithCancel(r.Context())
+ r = r.WithContext(ctx)
+
+ // Pass req through
+ // engine handler.
+ go e.ServeHTTP(rw, r)
+ time.Sleep(time.Millisecond)
+
+ // Get http result.
+ res := rw.Result()
+
+ if i < queueLimit {
+
+ // Check status == 200 (default, i.e not set).
+ if res.StatusCode != http.StatusOK {
+ t.Fatalf("status code was set (%d) with queueLimit=%d and request=%d", res.StatusCode, queueLimit, i)
+ }
+
+ // Add cancel to func slice.
+ cncls = append(cncls, cncl)
+
+ } else {
+
+ // Check the returned status code is expected.
+ if res.StatusCode != http.StatusTooManyRequests {
+ t.Fatalf("did not return status 429 (%d) with queueLimit=%d and request=%d", res.StatusCode, queueLimit, i)
+ }
+
+ // Check the returned retry-after header is set.
+ if res.Header.Get("Retry-After") != retryAfterStr {
+ t.Fatalf("did not return retry-after %s with queueLimit=%d and request=%d", retryAfterStr, queueLimit, i)
+ }
+
+ // Cancel on return.
+ defer cncl()
+
+ }
+ }
+
+ // Cancel all blocked reqs.
+ for _, cncl := range cncls {
+ cncl()
+ }
+ time.Sleep(time.Second)
+
+ // Check a bunchh more requests
+ // can now make it through after
+ // previous requests were released!
+ for i := 0; i < limit; i++ {
+
+ // Prepare a gin test context.
+ r := httptest.NewRequest("GET", "/", nil)
+ rw := httptest.NewRecorder()
+
+ // Pass req through
+ // engine handler.
+ go e.ServeHTTP(rw, r)
+ time.Sleep(time.Millisecond)
+
+ // Get http result.
+ res := rw.Result()
+
+ // Check status == 200 (default, i.e not set).
+ if res.StatusCode != http.StatusOK {
+ t.Fatalf("status code was set (%d) with queueLimit=%d and request=%d", res.StatusCode, queueLimit, i)
+ }
+ }
+}
+
+func blockingHandler() gin.HandlerFunc {
+ return func(ctx *gin.Context) {
+ <-ctx.Done()
+ ctx.Status(201) // specifically not 200
+ }
+}