-rw-r--r--  internal/api/client/streaming/stream.go   1
-rw-r--r--  internal/cache/timeline/status.go         76
-rw-r--r--  internal/cache/timeline/status_test.go    83
-rw-r--r--  internal/middleware/logger.go             69
-rw-r--r--  internal/paging/page.go                   8
5 files changed, 172 insertions, 65 deletions
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go
index 7bb65f7a1..5d7f17f94 100644
--- a/internal/api/client/streaming/stream.go
+++ b/internal/api/client/streaming/stream.go
@@ -170,7 +170,6 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
// Prefer query token else use header token.
token := cmp.Or(queryToken, headerToken)
-
if token != "" {
// Token was provided, use it to authorize stream.
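The cmp.Or call retained in this hunk is the Go standard library helper (package cmp, Go 1.22+): it returns the first of its arguments that is not the zero value, which is exactly the "prefer query token, else header token" selection. A minimal standalone sketch of that behaviour:

package main

import (
	"cmp"
	"fmt"
)

func main() {
	// cmp.Or returns its first non-zero argument, so a
	// non-empty query token wins over the header token.
	queryToken, headerToken := "", "header-tok"
	fmt.Println(cmp.Or(queryToken, headerToken)) // header-tok

	queryToken = "query-tok"
	fmt.Println(cmp.Or(queryToken, headerToken)) // query-tok
}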
diff --git a/internal/cache/timeline/status.go b/internal/cache/timeline/status.go
index 56d90e422..c0c394042 100644
--- a/internal/cache/timeline/status.go
+++ b/internal/cache/timeline/status.go
@@ -336,6 +336,14 @@ func (t *StatusTimeline) Load(
limit := page.Limit
order := page.Order()
dir := toDirection(order)
+ if limit <= 0 {
+
+ // A page limit MUST be set! This
+ // shouldn't be possible, but we
+ // check anyway to avoid any chance
+ // of limitless db calls!
+ panic("invalid page limit")
+ }
// Use a copy of current page so
// we can repeatedly update it.
@@ -344,11 +352,11 @@ func (t *StatusTimeline) Load(
nextPg.Min.Value = lo
nextPg.Max.Value = hi
- // Interstitial meta objects.
- var metas []*StatusMeta
+ // Preallocate slice of interstitial models.
+ metas := make([]*StatusMeta, 0, limit)
- // Returned frontend API statuses.
- var apiStatuses []*apimodel.Status
+ // Preallocate slice of required status API models.
+ apiStatuses := make([]*apimodel.Status, 0, limit)
// TODO: we can remove this nil
// check when we've updated all
@@ -362,13 +370,17 @@ func (t *StatusTimeline) Load(
return nil, "", "", err
}
+ // Load a little more than limit to
+ // reduce chance of db calls below.
+ limitPtr := util.Ptr(limit + 10)
+
// First we attempt to load status
// metadata entries from the timeline
// cache, up to given limit.
metas = t.cache.Select(
util.PtrIf(lo),
util.PtrIf(hi),
- util.PtrIf(limit),
+ limitPtr,
dir,
)
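The util.Ptr and util.PtrIf helpers used above are GoToSocial's generic pointer utilities; their definitions are not part of this diff, so the following is a hedged reimplementation under the assumed semantics: Ptr always returns a pointer to its argument (used to build the over-fetched limit pointer), while PtrIf returns nil for a zero value, so empty lo/hi bounds mean "unbounded" to the cache select.

package main

import "fmt"

// Ptr always returns a pointer to its argument.
func Ptr[T any](t T) *T { return &t }

// PtrIf returns a pointer to t only when t is non-zero;
// a nil result stands for "no bound" in a cache select.
func PtrIf[T comparable](t T) *T {
	var zero T
	if t == zero {
		return nil
	}
	return &t
}

func main() {
	fmt.Println(PtrIf("") == nil) // true: zero value -> unbounded
	fmt.Println(*PtrIf("01ABC"))  // 01ABC
	fmt.Println(*Ptr(20 + 10))    // 30, as in util.Ptr(limit + 10)
}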
@@ -384,9 +396,6 @@ func (t *StatusTimeline) Load(
lo = metas[len(metas)-1].ID
hi = metas[0].ID
- // Allocate slice of expected required API models.
- apiStatuses = make([]*apimodel.Status, 0, len(metas))
-
// Prepare frontend API models for
// the cached statuses. For now this
// also does its own extra filtering.
@@ -399,10 +408,10 @@ func (t *StatusTimeline) Load(
}
}
- // If no cached timeline statuses
- // were found for page, we need to
- // call through to the database.
- if len(apiStatuses) == 0 {
+ // If not enough cached timeline
+ // statuses were found for page,
+ // we need to call to database.
+ if len(apiStatuses) < limit {
// Pass through to main timeline db load function.
apiStatuses, lo, hi, err = loadStatusTimeline(ctx,
@@ -460,25 +469,31 @@ func loadStatusTimeline(
// vals of loaded statuses.
var lo, hi string
- // Extract paging params.
+ // Extract paging params. In particular,
+ // limit is used separately from nextPg to
+ // determine the *expected* return limit,
+ // not just what we use in db queries.
+ returnLimit := nextPg.Limit
order := nextPg.Order()
- limit := nextPg.Limit
-
- // Load a little more than
- // limit to reduce db calls.
- nextPg.Limit += 10
-
- // Ensure we have a slice of meta objects to
- // use in later preparation of the API models.
- metas = xslices.GrowJust(metas[:0], nextPg.Limit)
-
- // Ensure we have a slice of required frontend API models.
- apiStatuses = xslices.GrowJust(apiStatuses[:0], nextPg.Limit)
// Perform maximum of 5 load
// attempts fetching statuses.
for i := 0; i < 5; i++ {
+ // Update page limit to the *remaining*
+ // number of statuses we're expected to return.
+ nextPg.Limit = returnLimit - len(apiStatuses)
+ if nextPg.Limit <= 0 {
+
+ // We reached the end! Set lo paging value.
+ lo = apiStatuses[len(apiStatuses)-1].ID
+ break
+ }
+
+ // But load a bit more than
+ // limit to reduce db calls.
+ nextPg.Limit += 10
+
// Load next timeline statuses.
statuses, err := loadPage(nextPg)
if err != nil {
@@ -519,17 +534,8 @@ func loadStatusTimeline(
metas,
prepareAPI,
apiStatuses,
- limit,
+ returnLimit,
)
-
- // If we have anything, return
- // here. Even if below limit.
- if len(apiStatuses) > 0 {
-
- // Set returned lo status paging value.
- lo = apiStatuses[len(apiStatuses)-1].ID
- break
- }
}
return apiStatuses, lo, hi, nil
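Taken together, the rewritten loop is a generic "fill to limit with over-fetch" pattern: on each of up to 5 attempts, request remaining+10 items, filter them, and stop once the expected return limit is met or the source runs dry. A self-contained sketch of that shape, where fetch and keep are hypothetical stand-ins for the db page load and status filtering:

package main

import "fmt"

// fillToLimit gathers up to 'limit' kept values, fetching
// a little more than needed each round to cut round trips.
func fillToLimit(limit int, fetch func(n int) []int, keep func(int) bool) []int {
	const overFetch = 10
	var results []int
	for i := 0; i < 5; i++ {
		remaining := limit - len(results)
		if remaining <= 0 {
			break // reached the expected return limit
		}
		page := fetch(remaining + overFetch)
		if len(page) == 0 {
			break // source exhausted
		}
		for _, v := range page {
			if keep(v) && len(results) < limit {
				results = append(results, v)
			}
		}
	}
	return results
}

func main() {
	next := 0
	fetch := func(n int) []int {
		out := make([]int, 0, n)
		for i := 0; i < n && next < 100; i++ {
			out = append(out, next)
			next++
		}
		return out
	}
	keep := func(v int) bool { return v%2 == 0 } // drop half, like filtering
	fmt.Println(len(fillToLimit(20, fetch, keep))) // 20
}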
diff --git a/internal/cache/timeline/status_test.go b/internal/cache/timeline/status_test.go
index 6a288d2ea..fc7e43da8 100644
--- a/internal/cache/timeline/status_test.go
+++ b/internal/cache/timeline/status_test.go
@@ -18,11 +18,16 @@
package timeline
import (
+ "context"
+ "fmt"
"slices"
"testing"
apimodel "code.superseriousbusiness.org/gotosocial/internal/api/model"
"code.superseriousbusiness.org/gotosocial/internal/gtsmodel"
+ "code.superseriousbusiness.org/gotosocial/internal/id"
+ "code.superseriousbusiness.org/gotosocial/internal/log"
+ "code.superseriousbusiness.org/gotosocial/internal/paging"
"codeberg.org/gruf/go-structr"
"github.com/stretchr/testify/assert"
)
@@ -60,6 +65,46 @@ var testStatusMeta = []*StatusMeta{
},
}
+func TestStatusTimelineLoadLimit(t *testing.T) {
+ var tt StatusTimeline
+ tt.Init(1000)
+
+ // Prepare new context for the duration of this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Clone the input test status data.
+ data := slices.Clone(testStatusMeta)
+
+ // Insert test data into timeline.
+ _ = tt.cache.Insert(data...)
+
+ // Manually mark timeline as 'preloaded'.
+ tt.preloader.CheckPreload(tt.preloader.Done)
+
+ // Craft a new page for selection,
+ // setting placeholder min / max values
+ // but in particular setting a limit
+ // HIGHER than the number of cached entries.
+ page := new(paging.Page)
+ page.Min = paging.MinID(id.Lowest)
+ page.Max = paging.MaxID(id.Highest)
+ page.Limit = len(data) + 10
+
+ // Load crafted page from the cache. This
+ // SHOULD load all cached entries, then
+ // generate an extra 10 statuses up to limit.
+ apiStatuses, _, _, err := tt.Load(ctx,
+ page,
+ loadGeneratedStatusPage,
+ loadStatusIDsFrom(data),
+ nil, // no filtering
+ func(status *gtsmodel.Status) (*apimodel.Status, error) { return new(apimodel.Status), nil },
+ )
+ assert.NoError(t, err)
+ assert.Len(t, apiStatuses, page.Limit)
+}
+
func TestStatusTimelineUnprepare(t *testing.T) {
var tt StatusTimeline
tt.Init(1000)
@@ -301,6 +346,44 @@ func TestStatusTimelineTrim(t *testing.T) {
assert.Equal(t, before, tt.cache.Len())
}
+ // loadStatusIDsFrom imitates loading statuses with the given IDs from the database, instead selecting
+ // entries with matching IDs from the given slice of status meta and converting them to statuses.
+func loadStatusIDsFrom(data []*StatusMeta) func(ids []string) ([]*gtsmodel.Status, error) {
+ return func(ids []string) ([]*gtsmodel.Status, error) {
+ var statuses []*gtsmodel.Status
+ for _, id := range ids {
+ i := slices.IndexFunc(data, func(s *StatusMeta) bool {
+ return s.ID == id
+ })
+ if i < 0 { // IndexFunc returns -1 when not found
+ panic(fmt.Sprintf("could not find %s in %v", id, log.VarDump(data)))
+ }
+ statuses = append(statuses, &gtsmodel.Status{
+ ID: data[i].ID,
+ AccountID: data[i].AccountID,
+ BoostOfID: data[i].BoostOfID,
+ BoostOfAccountID: data[i].BoostOfAccountID,
+ })
+ }
+ return statuses, nil
+ }
+}
+
+// loadGeneratedStatusPage imitates loading of a given page of statuses,
+// simply generating new statuses until the given page's limit is reached.
+func loadGeneratedStatusPage(page *paging.Page) ([]*gtsmodel.Status, error) {
+ var statuses []*gtsmodel.Status
+ for range page.Limit {
+ statuses = append(statuses, &gtsmodel.Status{
+ ID: id.NewULID(),
+ AccountID: id.NewULID(),
+ BoostOfID: id.NewULID(),
+ BoostOfAccountID: id.NewULID(),
+ })
+ }
+ return statuses, nil
+}
+
// containsStatusID returns whether timeline contains a status with ID.
func containsStatusID(t *StatusTimeline, id string) bool {
return getStatusByID(t, id) != nil
diff --git a/internal/middleware/logger.go b/internal/middleware/logger.go
index 9fa245666..00e940992 100644
--- a/internal/middleware/logger.go
+++ b/internal/middleware/logger.go
@@ -21,6 +21,7 @@ import (
"fmt"
"net/http"
"runtime"
+ "strings"
"time"
"code.superseriousbusiness.org/gotosocial/internal/gtscontext"
@@ -35,19 +36,21 @@ import (
// Logger returns a gin middleware which provides request logging and panic recovery.
func Logger(logClientIP bool) gin.HandlerFunc {
return func(c *gin.Context) {
- // Initialize the logging fields
- fields := make(kv.Fields, 5, 7)
-
// Determine pre-handler time
before := time.Now()
- // defer so that we log *after the request has completed*
+ // defer so that we log *after
+ // the request has completed*
defer func() {
+
+ // Get response status code.
code := c.Writer.Status()
- path := c.Request.URL.Path
+
+ // Get request context.
+ ctx := c.Request.Context()
if r := recover(); r != nil {
- if c.Writer.Status() == 0 {
+ if code == 0 {
// No response was written, send a generic Internal Error
c.Writer.WriteHeader(http.StatusInternalServerError)
}
@@ -65,37 +68,51 @@ func Logger(logClientIP bool) gin.HandlerFunc {
WithField("stacktrace", callers).Error(err)
}
- // NOTE:
- // It is very important here that we are ONLY logging
- // the request path, and none of the query parameters.
- // Query parameters can contain sensitive information
- // and could lead to storing plaintext API keys in logs
+ // Initialize the logging fields
+ fields := make(kv.Fields, 5, 8)
// Set request logging fields
fields[0] = kv.Field{"latency", time.Since(before)}
fields[1] = kv.Field{"userAgent", c.Request.UserAgent()}
fields[2] = kv.Field{"method", c.Request.Method}
fields[3] = kv.Field{"statusCode", code}
- fields[4] = kv.Field{"path", path}
- // Set optional request logging fields.
+ // If the request contains sensitive query
+ // data, only log the path, else log entire URI.
+ if sensitiveQuery(c.Request.URL.RawQuery) {
+ path := c.Request.URL.Path
+ fields[4] = kv.Field{"uri", path}
+ } else {
+ uri := c.Request.RequestURI
+ fields[4] = kv.Field{"uri", uri}
+ }
+
if logClientIP {
+ // Append IP only if configured to.
fields = append(fields, kv.Field{
"clientIP", c.ClientIP(),
})
}
- ctx := c.Request.Context()
if pubKeyID := gtscontext.HTTPSignaturePubKeyID(ctx); pubKeyID != nil {
+ // Append public key ID if attached.
fields = append(fields, kv.Field{
"pubKeyID", pubKeyID.String(),
})
}
- // Create log entry with fields
- l := log.New()
- l = l.WithContext(ctx)
- l = l.WithFields(fields...)
+ if len(c.Errors) > 0 {
+ // Always attach any found errors.
+ fields = append(fields, kv.Field{
+ "errors", c.Errors,
+ })
+ }
+
+ // Create entry
+ // with fields.
+ l := log.New().
+ WithContext(ctx).
+ WithFields(fields...)
// Default is info
lvl := log.INFO
@@ -105,11 +122,6 @@ func Logger(logClientIP bool) gin.HandlerFunc {
lvl = log.ERROR
}
- if len(c.Errors) > 0 {
- // Always attach any found errors.
- l = l.WithField("errors", c.Errors)
- }
-
// Get appropriate text for this code.
statusText := http.StatusText(code)
if statusText == "" {
@@ -125,15 +137,22 @@ func Logger(logClientIP bool) gin.HandlerFunc {
// Generate a nicer looking bytecount
size := bytesize.Size(c.Writer.Size()) // #nosec G115 -- Just logging
- // Finally, write log entry with status text + body size.
+ // Write log entry with status text + body size.
l.Logf(lvl, "%s: wrote %s", statusText, size)
}()
- // Process request
+ // Process
+ // request.
c.Next()
}
}
+// sensitiveQuery checks whether given query string
+// contains sensitive data that shouldn't be logged.
+func sensitiveQuery(query string) bool {
+ return strings.Contains(query, "token")
+}
+
// gatherFrames gathers runtime frames from a frame iterator.
func gatherFrames(iter *runtime.Frames, n int) []runtime.Frame {
if iter == nil {
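Note that the new sensitiveQuery check is a plain substring match: any raw query containing "token" (access_token=..., but also an innocuous ?tokens=5) falls back to logging only the path, trading a few false positives for never writing credentials to the logs. A quick sketch of the resulting uri field selection, using net/url to split path from query:

package main

import (
	"fmt"
	"net/url"
	"strings"
)

// sensitiveQuery mirrors the middleware check above:
// deliberately broad, so anything token-like stays out of logs.
func sensitiveQuery(query string) bool {
	return strings.Contains(query, "token")
}

func main() {
	for _, raw := range []string{
		"https://example.org/api/v1/streaming?access_token=secret",
		"https://example.org/api/v1/timelines/home?limit=20",
	} {
		u, err := url.Parse(raw)
		if err != nil {
			panic(err)
		}
		// Log only the path when the query looks sensitive,
		// otherwise the full request URI.
		if sensitiveQuery(u.RawQuery) {
			fmt.Println("uri:", u.Path)
		} else {
			fmt.Println("uri:", u.RequestURI())
		}
	}
}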
diff --git a/internal/paging/page.go b/internal/paging/page.go
index 6c91da6b2..708ab1bd7 100644
--- a/internal/paging/page.go
+++ b/internal/paging/page.go
@@ -278,10 +278,10 @@ func (p *Page) ToLinkURL(proto, host, path string, queryParams url.Values) *url.
if queryParams == nil {
// Allocate new query parameters.
- queryParams = make(url.Values)
+ queryParams = make(url.Values, 2)
} else {
// Before edit clone existing params.
- queryParams = cloneQuery(queryParams)
+ queryParams = cloneQuery(queryParams, 2)
}
if p.Min.Value != "" {
@@ -309,8 +309,8 @@ func (p *Page) ToLinkURL(proto, host, path string, queryParams url.Values) *url.
}
// cloneQuery clones input map of url values.
-func cloneQuery(src url.Values) url.Values {
- dst := make(url.Values, len(src))
+func cloneQuery(src url.Values, extra int) url.Values {
+ dst := make(url.Values, len(src)+extra)
for k, vs := range src {
dst[k] = slices.Clone(vs)
}
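The extra capacity hint exists because ToLinkURL adds up to two paging parameters right after cloning, so sizing the map at len(src)+2 avoids a rehash. A small usage sketch of the updated helper (the min_id/max_id param names are illustrative):

package main

import (
	"fmt"
	"net/url"
	"slices"
)

// cloneQuery as in the diff: deep-copies the values while
// reserving room for 'extra' keys to be added afterwards.
func cloneQuery(src url.Values, extra int) url.Values {
	dst := make(url.Values, len(src)+extra)
	for k, vs := range src {
		dst[k] = slices.Clone(vs)
	}
	return dst
}

func main() {
	src := url.Values{"limit": {"20"}}
	dst := cloneQuery(src, 2) // room for the two paging params
	dst.Set("min_id", "01ABC")
	dst.Set("max_id", "01XYZ")
	// The original is untouched by edits to the clone.
	fmt.Println(src.Encode()) // limit=20
	fmt.Println(dst.Encode()) // limit=20&max_id=01XYZ&min_id=01ABC
}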