summaryrefslogtreecommitdiff
path: root/internal/processing
diff options
context:
space:
mode:
authorLibravatar kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com>2024-02-20 18:07:49 +0000
committerLibravatar GitHub <noreply@github.com>2024-02-20 18:07:49 +0000
commit291e18099050ff9e19b8ee25c2ffad68d9baafef (patch)
tree0ad1be36b4c958830d1371f3b9a32f017c5dcff0 /internal/processing
parent[feature] Add `requested_by` to relationship model (#2672) (diff)
downloadgotosocial-291e18099050ff9e19b8ee25c2ffad68d9baafef.tar.xz
[bugfix] fix possible mutex lockup during streaming code (#2633)
* rewrite Stream{} to use much less mutex locking, update related code * use new context for the stream context * ensure stream gets closed on return of writeTo / readFrom WSConn() * ensure stream write timeout gets cancelled * remove embedded context type from Stream{}, reformat log messages for consistency * use c.Request.Context() for context passed into Stream().Open() * only return 1 boolean, fix tests to expect multiple stream types in messages * changes to ping logic * further improved ping logic * don't export unused function types, update message sending to only include relevant stream type * ensure stream gets closed :facepalm: * update to error log on failed json marshal (instead of panic) * inverse websocket read error checking to _ignore_ expected close errors
Diffstat (limited to 'internal/processing')
-rw-r--r--internal/processing/stream/delete.go34
-rw-r--r--internal/processing/stream/notification.go21
-rw-r--r--internal/processing/stream/notification_test.go7
-rw-r--r--internal/processing/stream/open.go97
-rw-r--r--internal/processing/stream/statusupdate.go21
-rw-r--r--internal/processing/stream/statusupdate_test.go7
-rw-r--r--internal/processing/stream/stream.go46
-rw-r--r--internal/processing/stream/update.go18
-rw-r--r--internal/processing/workers/fromclientapi_test.go25
-rw-r--r--internal/processing/workers/fromfediapi_test.go53
-rw-r--r--internal/processing/workers/surfacenotify.go5
-rw-r--r--internal/processing/workers/surfacetimeline.go18
12 files changed, 92 insertions, 260 deletions
diff --git a/internal/processing/stream/delete.go b/internal/processing/stream/delete.go
index d7745eef8..1c61b98d3 100644
--- a/internal/processing/stream/delete.go
+++ b/internal/processing/stream/delete.go
@@ -18,38 +18,16 @@
package stream
import (
- "fmt"
- "strings"
+ "context"
"github.com/superseriousbusiness/gotosocial/internal/stream"
)
// Delete streams the delete of the given statusID to *ALL* open streams.
-func (p *Processor) Delete(statusID string) error {
- errs := []string{}
-
- // get all account IDs with open streams
- accountIDs := []string{}
- p.streamMap.Range(func(k interface{}, _ interface{}) bool {
- key, ok := k.(string)
- if !ok {
- panic("streamMap key was not a string (account id)")
- }
-
- accountIDs = append(accountIDs, key)
- return true
+func (p *Processor) Delete(ctx context.Context, statusID string) {
+ p.streams.PostAll(ctx, stream.Message{
+ Payload: statusID,
+ Event: stream.EventTypeDelete,
+ Stream: stream.AllStatusTimelines,
})
-
- // stream the delete to every account
- for _, accountID := range accountIDs {
- if err := p.toAccount(statusID, stream.EventTypeDelete, stream.AllStatusTimelines, accountID); err != nil {
- errs = append(errs, err.Error())
- }
- }
-
- if len(errs) != 0 {
- return fmt.Errorf("one or more errors streaming status delete: %s", strings.Join(errs, ";"))
- }
-
- return nil
}
diff --git a/internal/processing/stream/notification.go b/internal/processing/stream/notification.go
index 63d7c5d11..a16da11e6 100644
--- a/internal/processing/stream/notification.go
+++ b/internal/processing/stream/notification.go
@@ -18,20 +18,29 @@
package stream
import (
+ "context"
"encoding/json"
- "fmt"
+ "codeberg.org/gruf/go-byteutil"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/stream"
)
// Notify streams the given notification to any open, appropriate streams belonging to the given account.
-func (p *Processor) Notify(n *apimodel.Notification, account *gtsmodel.Account) error {
- bytes, err := json.Marshal(n)
+func (p *Processor) Notify(ctx context.Context, account *gtsmodel.Account, notif *apimodel.Notification) {
+ b, err := json.Marshal(notif)
if err != nil {
- return fmt.Errorf("error marshalling notification to json: %s", err)
+ log.Errorf(ctx, "error marshaling json: %v", err)
+ return
}
-
- return p.toAccount(string(bytes), stream.EventTypeNotification, []string{stream.TimelineNotifications, stream.TimelineHome}, account.ID)
+ p.streams.Post(ctx, account.ID, stream.Message{
+ Payload: byteutil.B2S(b),
+ Event: stream.EventTypeNotification,
+ Stream: []string{
+ stream.TimelineNotifications,
+ stream.TimelineHome,
+ },
+ })
}
diff --git a/internal/processing/stream/notification_test.go b/internal/processing/stream/notification_test.go
index 2138f0025..e12f23abe 100644
--- a/internal/processing/stream/notification_test.go
+++ b/internal/processing/stream/notification_test.go
@@ -49,10 +49,11 @@ func (suite *NotificationTestSuite) TestStreamNotification() {
Account: followAccountAPIModel,
}
- err = suite.streamProcessor.Notify(notification, account)
- suite.NoError(err)
+ suite.streamProcessor.Notify(context.Background(), account, notification)
+
+ msg, ok := openStream.Recv(context.Background())
+ suite.True(ok)
- msg := <-openStream.Messages
dst := new(bytes.Buffer)
err = json.Indent(dst, []byte(msg.Payload), "", " ")
suite.NoError(err)
diff --git a/internal/processing/stream/open.go b/internal/processing/stream/open.go
index 1c041309f..2f2bbd4a3 100644
--- a/internal/processing/stream/open.go
+++ b/internal/processing/stream/open.go
@@ -19,13 +19,10 @@ package stream
import (
"context"
- "errors"
- "fmt"
"codeberg.org/gruf/go-kv"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/stream"
)
@@ -37,97 +34,5 @@ func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamT
{"streamType", streamType},
}...)
l.Debug("received open stream request")
-
- var (
- streamID string
- err error
- )
-
- // Each stream needs a unique ID so we know to close it.
- streamID, err = id.NewRandomULID()
- if err != nil {
- return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %w", err))
- }
-
- // Each stream can be subscibed to multiple types.
- // Record them in a set, and include the initial one
- // if it was given to us.
- streamTypes := map[string]any{}
- if streamType != "" {
- streamTypes[streamType] = true
- }
-
- newStream := &stream.Stream{
- ID: streamID,
- StreamTypes: streamTypes,
- Messages: make(chan *stream.Message, 100),
- Hangup: make(chan interface{}, 1),
- Connected: true,
- }
- go p.waitToCloseStream(account, newStream)
-
- v, ok := p.streamMap.Load(account.ID)
- if ok {
- // There is an entry in the streamMap
- // for this account. Parse it out.
- streamsForAccount, ok := v.(*stream.StreamsForAccount)
- if !ok {
- return nil, gtserror.NewErrorInternalError(errors.New("stream map error"))
- }
-
- // Append new stream to existing entry.
- streamsForAccount.Lock()
- streamsForAccount.Streams = append(streamsForAccount.Streams, newStream)
- streamsForAccount.Unlock()
- } else {
- // There is no entry in the streamMap for
- // this account yet. Create one and store it.
- p.streamMap.Store(account.ID, &stream.StreamsForAccount{
- Streams: []*stream.Stream{
- newStream,
- },
- })
- }
-
- return newStream, nil
-}
-
-// waitToCloseStream waits until the hangup channel is closed for the given stream.
-// It then iterates through the map of streams stored by the processor, removes the stream from it,
-// and then closes the messages channel of the stream to indicate that the channel should no longer be read from.
-func (p *Processor) waitToCloseStream(account *gtsmodel.Account, thisStream *stream.Stream) {
- <-thisStream.Hangup // wait for a hangup message
-
- // lock the stream to prevent more messages being put in it while we work
- thisStream.Lock()
- defer thisStream.Unlock()
-
- // indicate the stream is no longer connected
- thisStream.Connected = false
-
- // load and parse the entry for this account from the stream map
- v, ok := p.streamMap.Load(account.ID)
- if !ok || v == nil {
- return
- }
- streamsForAccount, ok := v.(*stream.StreamsForAccount)
- if !ok {
- return
- }
-
- // lock the streams for account while we remove this stream from its slice
- streamsForAccount.Lock()
- defer streamsForAccount.Unlock()
-
- // put everything into modified streams *except* the stream we're removing
- modifiedStreams := []*stream.Stream{}
- for _, s := range streamsForAccount.Streams {
- if s.ID != thisStream.ID {
- modifiedStreams = append(modifiedStreams, s)
- }
- }
- streamsForAccount.Streams = modifiedStreams
-
- // finally close the messages channel so no more messages can be read from it
- close(thisStream.Messages)
+ return p.streams.Open(account.ID, streamType), nil
}
diff --git a/internal/processing/stream/statusupdate.go b/internal/processing/stream/statusupdate.go
index fd8e388ce..bd4658873 100644
--- a/internal/processing/stream/statusupdate.go
+++ b/internal/processing/stream/statusupdate.go
@@ -18,21 +18,26 @@
package stream
import (
+ "context"
"encoding/json"
- "fmt"
+ "codeberg.org/gruf/go-byteutil"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/stream"
)
-// StatusUpdate streams the given edited status to any open, appropriate
-// streams belonging to the given account.
-func (p *Processor) StatusUpdate(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error {
- bytes, err := json.Marshal(s)
+// StatusUpdate streams the given edited status to any open, appropriate streams belonging to the given account.
+func (p *Processor) StatusUpdate(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) {
+ b, err := json.Marshal(status)
if err != nil {
- return fmt.Errorf("error marshalling status to json: %s", err)
+ log.Errorf(ctx, "error marshaling json: %v", err)
+ return
}
-
- return p.toAccount(string(bytes), stream.EventTypeStatusUpdate, streamTypes, account.ID)
+ p.streams.Post(ctx, account.ID, stream.Message{
+ Payload: byteutil.B2S(b),
+ Event: stream.EventTypeStatusUpdate,
+ Stream: []string{streamType},
+ })
}
diff --git a/internal/processing/stream/statusupdate_test.go b/internal/processing/stream/statusupdate_test.go
index 7b987b412..8814c966f 100644
--- a/internal/processing/stream/statusupdate_test.go
+++ b/internal/processing/stream/statusupdate_test.go
@@ -42,10 +42,11 @@ func (suite *StatusUpdateTestSuite) TestStreamNotification() {
apiStatus, err := typeutils.NewConverter(&suite.state).StatusToAPIStatus(context.Background(), editedStatus, account)
suite.NoError(err)
- err = suite.streamProcessor.StatusUpdate(apiStatus, account, []string{stream.TimelineHome})
- suite.NoError(err)
+ suite.streamProcessor.StatusUpdate(context.Background(), account, apiStatus, stream.TimelineHome)
+
+ msg, ok := openStream.Recv(context.Background())
+ suite.True(ok)
- msg := <-openStream.Messages
dst := new(bytes.Buffer)
err = json.Indent(dst, []byte(msg.Payload), "", " ")
suite.NoError(err)
diff --git a/internal/processing/stream/stream.go b/internal/processing/stream/stream.go
index a5b3b9386..0b7285b58 100644
--- a/internal/processing/stream/stream.go
+++ b/internal/processing/stream/stream.go
@@ -18,8 +18,6 @@
package stream
import (
- "sync"
-
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/stream"
@@ -28,53 +26,13 @@ import (
type Processor struct {
state *state.State
oauthServer oauth.Server
- streamMap *sync.Map
+ streams stream.Streams
}
func New(state *state.State, oauthServer oauth.Server) Processor {
return Processor{
state: state,
oauthServer: oauthServer,
- streamMap: &sync.Map{},
- }
-}
-
-// toAccount streams the given payload with the given event type to any streams currently open for the given account ID.
-func (p *Processor) toAccount(payload string, event string, streamTypes []string, accountID string) error {
- // Load all streams open for this account.
- v, ok := p.streamMap.Load(accountID)
- if !ok {
- return nil // No entry = nothing to stream.
+ streams: stream.Streams{},
}
- streamsForAccount := v.(*stream.StreamsForAccount)
-
- streamsForAccount.Lock()
- defer streamsForAccount.Unlock()
-
- for _, s := range streamsForAccount.Streams {
- s.Lock()
- defer s.Unlock()
-
- if !s.Connected {
- continue
- }
-
- typeLoop:
- for _, streamType := range streamTypes {
- if _, found := s.StreamTypes[streamType]; found {
- s.Messages <- &stream.Message{
- Stream: []string{streamType},
- Event: string(event),
- Payload: payload,
- }
-
- // Break out to the outer loop,
- // to avoid sending duplicates of
- // the same event to the same stream.
- break typeLoop
- }
- }
- }
-
- return nil
}
diff --git a/internal/processing/stream/update.go b/internal/processing/stream/update.go
index ee70bda11..a84763d51 100644
--- a/internal/processing/stream/update.go
+++ b/internal/processing/stream/update.go
@@ -18,20 +18,26 @@
package stream
import (
+ "context"
"encoding/json"
- "fmt"
+ "codeberg.org/gruf/go-byteutil"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/stream"
)
// Update streams the given update to any open, appropriate streams belonging to the given account.
-func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error {
- bytes, err := json.Marshal(s)
+func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) {
+ b, err := json.Marshal(status)
if err != nil {
- return fmt.Errorf("error marshalling status to json: %s", err)
+ log.Errorf(ctx, "error marshaling json: %v", err)
+ return
}
-
- return p.toAccount(string(bytes), stream.EventTypeUpdate, streamTypes, account.ID)
+ p.streams.Post(ctx, account.ID, stream.Message{
+ Payload: byteutil.B2S(b),
+ Event: stream.EventTypeUpdate,
+ Stream: []string{streamType},
+ })
}
diff --git a/internal/processing/workers/fromclientapi_test.go b/internal/processing/workers/fromclientapi_test.go
index 05526f437..3d3630b11 100644
--- a/internal/processing/workers/fromclientapi_test.go
+++ b/internal/processing/workers/fromclientapi_test.go
@@ -116,23 +116,20 @@ func (suite *FromClientAPITestSuite) checkStreamed(
expectPayload string,
expectEventType string,
) {
- var msg *stream.Message
-streamLoop:
- for {
- select {
- case msg = <-str.Messages:
- break streamLoop // Got it.
- case <-time.After(5 * time.Second):
- break streamLoop // Didn't get it.
- }
- }
- if expectMessage && msg == nil {
- suite.FailNow("expected a message but message was nil")
+ // Set a 5s timeout on context.
+ ctx := context.Background()
+ ctx, cncl := context.WithTimeout(ctx, time.Second*5)
+ defer cncl()
+
+ msg, ok := str.Recv(ctx)
+
+ if expectMessage && !ok {
+ suite.FailNow("expected a message but message was not received")
}
- if !expectMessage && msg != nil {
- suite.FailNow("expected no message but message was not nil")
+ if !expectMessage && ok {
+ suite.FailNow("expected no message but message was received")
}
if expectPayload != "" && msg.Payload != expectPayload {
diff --git a/internal/processing/workers/fromfediapi_test.go b/internal/processing/workers/fromfediapi_test.go
index 799eaf2dc..446355628 100644
--- a/internal/processing/workers/fromfediapi_test.go
+++ b/internal/processing/workers/fromfediapi_test.go
@@ -130,14 +130,9 @@ func (suite *FromFediAPITestSuite) TestProcessReplyMention() {
suite.Equal(replyingStatus.ID, notif.StatusID)
suite.False(*notif.Read)
- // the notification should be streamed
- var msg *stream.Message
- select {
- case msg = <-wssStream.Messages:
- // fine
- case <-time.After(5 * time.Second):
- suite.FailNow("no message from wssStream")
- }
+ ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
+ msg, ok := wssStream.Recv(ctx)
+ suite.True(ok)
suite.Equal(stream.EventTypeNotification, msg.Event)
suite.NotEmpty(msg.Payload)
@@ -203,14 +198,10 @@ func (suite *FromFediAPITestSuite) TestProcessFave() {
suite.Equal(fave.StatusID, notif.StatusID)
suite.False(*notif.Read)
- // 2. a notification should be streamed
- var msg *stream.Message
- select {
- case msg = <-wssStream.Messages:
- // fine
- case <-time.After(5 * time.Second):
- suite.FailNow("no message from wssStream")
- }
+ ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
+ msg, ok := wssStream.Recv(ctx)
+ suite.True(ok)
+
suite.Equal(stream.EventTypeNotification, msg.Event)
suite.NotEmpty(msg.Payload)
suite.EqualValues([]string{stream.TimelineNotifications}, msg.Stream)
@@ -277,7 +268,9 @@ func (suite *FromFediAPITestSuite) TestProcessFaveWithDifferentReceivingAccount(
suite.False(*notif.Read)
// 2. no notification should be streamed to the account that received the fave message, because they weren't the target
- suite.Empty(wssStream.Messages)
+ ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
+ _, ok := wssStream.Recv(ctx)
+ suite.False(ok)
}
func (suite *FromFediAPITestSuite) TestProcessAccountDelete() {
@@ -405,14 +398,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() {
})
suite.NoError(err)
- // a notification should be streamed
- var msg *stream.Message
- select {
- case msg = <-wssStream.Messages:
- // fine
- case <-time.After(5 * time.Second):
- suite.FailNow("no message from wssStream")
- }
+ ctx, _ = context.WithTimeout(ctx, time.Second*5)
+ msg, ok := wssStream.Recv(context.Background())
+ suite.True(ok)
+
suite.Equal(stream.EventTypeNotification, msg.Event)
suite.NotEmpty(msg.Payload)
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
@@ -423,7 +412,7 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() {
suite.Equal(originAccount.ID, notif.Account.ID)
// no messages should have been sent out, since we didn't need to federate an accept
- suite.Empty(suite.httpClient.SentMessages)
+ suite.Empty(&suite.httpClient.SentMessages)
}
func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {
@@ -503,14 +492,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {
suite.Equal(originAccount.URI, accept.To)
suite.Equal("Accept", accept.Type)
- // a notification should be streamed
- var msg *stream.Message
- select {
- case msg = <-wssStream.Messages:
- // fine
- case <-time.After(5 * time.Second):
- suite.FailNow("no message from wssStream")
- }
+ ctx, _ = context.WithTimeout(ctx, time.Second*5)
+ msg, ok := wssStream.Recv(context.Background())
+ suite.True(ok)
+
suite.Equal(stream.EventTypeNotification, msg.Event)
suite.NotEmpty(msg.Payload)
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
diff --git a/internal/processing/workers/surfacenotify.go b/internal/processing/workers/surfacenotify.go
index 39798f45e..a8c36248c 100644
--- a/internal/processing/workers/surfacenotify.go
+++ b/internal/processing/workers/surfacenotify.go
@@ -394,10 +394,7 @@ func (s *surface) notify(
if err != nil {
return gtserror.Newf("error converting notification to api representation: %w", err)
}
-
- if err := s.stream.Notify(apiNotif, targetAccount); err != nil {
- return gtserror.Newf("error streaming notification to account: %w", err)
- }
+ s.stream.Notify(ctx, targetAccount, apiNotif)
return nil
}
diff --git a/internal/processing/workers/surfacetimeline.go b/internal/processing/workers/surfacetimeline.go
index e63b8a7c0..14634f846 100644
--- a/internal/processing/workers/surfacetimeline.go
+++ b/internal/processing/workers/surfacetimeline.go
@@ -348,11 +348,7 @@ func (s *surface) timelineStatus(
err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err)
return true, err
}
-
- if err := s.stream.Update(apiStatus, account, []string{streamType}); err != nil {
- err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err)
- return true, err
- }
+ s.stream.Update(ctx, account, apiStatus, streamType)
return true, nil
}
@@ -363,12 +359,11 @@ func (s *surface) deleteStatusFromTimelines(ctx context.Context, statusID string
if err := s.state.Timelines.Home.WipeItemFromAllTimelines(ctx, statusID); err != nil {
return err
}
-
if err := s.state.Timelines.List.WipeItemFromAllTimelines(ctx, statusID); err != nil {
return err
}
-
- return s.stream.Delete(statusID)
+ s.stream.Delete(ctx, statusID)
+ return nil
}
// invalidateStatusFromTimelines does cache invalidation on the given status by
@@ -555,11 +550,6 @@ func (s *surface) timelineStreamStatusUpdate(
err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err)
return err
}
-
- if err := s.stream.StatusUpdate(apiStatus, account, []string{streamType}); err != nil {
- err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err)
- return err
- }
-
+ s.stream.StatusUpdate(ctx, account, apiStatus, streamType)
return nil
}