diff options
author | 2024-02-20 18:07:49 +0000 | |
---|---|---|
committer | 2024-02-20 18:07:49 +0000 | |
commit | 291e18099050ff9e19b8ee25c2ffad68d9baafef (patch) | |
tree | 0ad1be36b4c958830d1371f3b9a32f017c5dcff0 /internal/processing | |
parent | [feature] Add `requested_by` to relationship model (#2672) (diff) | |
download | gotosocial-291e18099050ff9e19b8ee25c2ffad68d9baafef.tar.xz |
[bugfix] fix possible mutex lockup during streaming code (#2633)
* rewrite Stream{} to use much less mutex locking, update related code
* use new context for the stream context
* ensure stream gets closed on return of writeTo / readFrom WSConn()
* ensure stream write timeout gets cancelled
* remove embedded context type from Stream{}, reformat log messages for consistency
* use c.Request.Context() for context passed into Stream().Open()
* only return 1 boolean, fix tests to expect multiple stream types in messages
* changes to ping logic
* further improved ping logic
* don't export unused function types, update message sending to only include relevant stream type
* ensure stream gets closed :facepalm:
* update to error log on failed json marshal (instead of panic)
* inverse websocket read error checking to _ignore_ expected close errors
Diffstat (limited to 'internal/processing')
-rw-r--r-- | internal/processing/stream/delete.go | 34 | ||||
-rw-r--r-- | internal/processing/stream/notification.go | 21 | ||||
-rw-r--r-- | internal/processing/stream/notification_test.go | 7 | ||||
-rw-r--r-- | internal/processing/stream/open.go | 97 | ||||
-rw-r--r-- | internal/processing/stream/statusupdate.go | 21 | ||||
-rw-r--r-- | internal/processing/stream/statusupdate_test.go | 7 | ||||
-rw-r--r-- | internal/processing/stream/stream.go | 46 | ||||
-rw-r--r-- | internal/processing/stream/update.go | 18 | ||||
-rw-r--r-- | internal/processing/workers/fromclientapi_test.go | 25 | ||||
-rw-r--r-- | internal/processing/workers/fromfediapi_test.go | 53 | ||||
-rw-r--r-- | internal/processing/workers/surfacenotify.go | 5 | ||||
-rw-r--r-- | internal/processing/workers/surfacetimeline.go | 18 |
12 files changed, 92 insertions, 260 deletions
diff --git a/internal/processing/stream/delete.go b/internal/processing/stream/delete.go index d7745eef8..1c61b98d3 100644 --- a/internal/processing/stream/delete.go +++ b/internal/processing/stream/delete.go @@ -18,38 +18,16 @@ package stream import ( - "fmt" - "strings" + "context" "github.com/superseriousbusiness/gotosocial/internal/stream" ) // Delete streams the delete of the given statusID to *ALL* open streams. -func (p *Processor) Delete(statusID string) error { - errs := []string{} - - // get all account IDs with open streams - accountIDs := []string{} - p.streamMap.Range(func(k interface{}, _ interface{}) bool { - key, ok := k.(string) - if !ok { - panic("streamMap key was not a string (account id)") - } - - accountIDs = append(accountIDs, key) - return true +func (p *Processor) Delete(ctx context.Context, statusID string) { + p.streams.PostAll(ctx, stream.Message{ + Payload: statusID, + Event: stream.EventTypeDelete, + Stream: stream.AllStatusTimelines, }) - - // stream the delete to every account - for _, accountID := range accountIDs { - if err := p.toAccount(statusID, stream.EventTypeDelete, stream.AllStatusTimelines, accountID); err != nil { - errs = append(errs, err.Error()) - } - } - - if len(errs) != 0 { - return fmt.Errorf("one or more errors streaming status delete: %s", strings.Join(errs, ";")) - } - - return nil } diff --git a/internal/processing/stream/notification.go b/internal/processing/stream/notification.go index 63d7c5d11..a16da11e6 100644 --- a/internal/processing/stream/notification.go +++ b/internal/processing/stream/notification.go @@ -18,20 +18,29 @@ package stream import ( + "context" "encoding/json" - "fmt" + "codeberg.org/gruf/go-byteutil" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/stream" ) // Notify streams the given notification to any open, appropriate streams belonging to the given account. -func (p *Processor) Notify(n *apimodel.Notification, account *gtsmodel.Account) error { - bytes, err := json.Marshal(n) +func (p *Processor) Notify(ctx context.Context, account *gtsmodel.Account, notif *apimodel.Notification) { + b, err := json.Marshal(notif) if err != nil { - return fmt.Errorf("error marshalling notification to json: %s", err) + log.Errorf(ctx, "error marshaling json: %v", err) + return } - - return p.toAccount(string(bytes), stream.EventTypeNotification, []string{stream.TimelineNotifications, stream.TimelineHome}, account.ID) + p.streams.Post(ctx, account.ID, stream.Message{ + Payload: byteutil.B2S(b), + Event: stream.EventTypeNotification, + Stream: []string{ + stream.TimelineNotifications, + stream.TimelineHome, + }, + }) } diff --git a/internal/processing/stream/notification_test.go b/internal/processing/stream/notification_test.go index 2138f0025..e12f23abe 100644 --- a/internal/processing/stream/notification_test.go +++ b/internal/processing/stream/notification_test.go @@ -49,10 +49,11 @@ func (suite *NotificationTestSuite) TestStreamNotification() { Account: followAccountAPIModel, } - err = suite.streamProcessor.Notify(notification, account) - suite.NoError(err) + suite.streamProcessor.Notify(context.Background(), account, notification) + + msg, ok := openStream.Recv(context.Background()) + suite.True(ok) - msg := <-openStream.Messages dst := new(bytes.Buffer) err = json.Indent(dst, []byte(msg.Payload), "", " ") suite.NoError(err) diff --git a/internal/processing/stream/open.go b/internal/processing/stream/open.go index 1c041309f..2f2bbd4a3 100644 --- a/internal/processing/stream/open.go +++ b/internal/processing/stream/open.go @@ -19,13 +19,10 @@ package stream import ( "context" - "errors" - "fmt" "codeberg.org/gruf/go-kv" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/stream" ) @@ -37,97 +34,5 @@ func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamT {"streamType", streamType}, }...) l.Debug("received open stream request") - - var ( - streamID string - err error - ) - - // Each stream needs a unique ID so we know to close it. - streamID, err = id.NewRandomULID() - if err != nil { - return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %w", err)) - } - - // Each stream can be subscibed to multiple types. - // Record them in a set, and include the initial one - // if it was given to us. - streamTypes := map[string]any{} - if streamType != "" { - streamTypes[streamType] = true - } - - newStream := &stream.Stream{ - ID: streamID, - StreamTypes: streamTypes, - Messages: make(chan *stream.Message, 100), - Hangup: make(chan interface{}, 1), - Connected: true, - } - go p.waitToCloseStream(account, newStream) - - v, ok := p.streamMap.Load(account.ID) - if ok { - // There is an entry in the streamMap - // for this account. Parse it out. - streamsForAccount, ok := v.(*stream.StreamsForAccount) - if !ok { - return nil, gtserror.NewErrorInternalError(errors.New("stream map error")) - } - - // Append new stream to existing entry. - streamsForAccount.Lock() - streamsForAccount.Streams = append(streamsForAccount.Streams, newStream) - streamsForAccount.Unlock() - } else { - // There is no entry in the streamMap for - // this account yet. Create one and store it. - p.streamMap.Store(account.ID, &stream.StreamsForAccount{ - Streams: []*stream.Stream{ - newStream, - }, - }) - } - - return newStream, nil -} - -// waitToCloseStream waits until the hangup channel is closed for the given stream. -// It then iterates through the map of streams stored by the processor, removes the stream from it, -// and then closes the messages channel of the stream to indicate that the channel should no longer be read from. -func (p *Processor) waitToCloseStream(account *gtsmodel.Account, thisStream *stream.Stream) { - <-thisStream.Hangup // wait for a hangup message - - // lock the stream to prevent more messages being put in it while we work - thisStream.Lock() - defer thisStream.Unlock() - - // indicate the stream is no longer connected - thisStream.Connected = false - - // load and parse the entry for this account from the stream map - v, ok := p.streamMap.Load(account.ID) - if !ok || v == nil { - return - } - streamsForAccount, ok := v.(*stream.StreamsForAccount) - if !ok { - return - } - - // lock the streams for account while we remove this stream from its slice - streamsForAccount.Lock() - defer streamsForAccount.Unlock() - - // put everything into modified streams *except* the stream we're removing - modifiedStreams := []*stream.Stream{} - for _, s := range streamsForAccount.Streams { - if s.ID != thisStream.ID { - modifiedStreams = append(modifiedStreams, s) - } - } - streamsForAccount.Streams = modifiedStreams - - // finally close the messages channel so no more messages can be read from it - close(thisStream.Messages) + return p.streams.Open(account.ID, streamType), nil } diff --git a/internal/processing/stream/statusupdate.go b/internal/processing/stream/statusupdate.go index fd8e388ce..bd4658873 100644 --- a/internal/processing/stream/statusupdate.go +++ b/internal/processing/stream/statusupdate.go @@ -18,21 +18,26 @@ package stream import ( + "context" "encoding/json" - "fmt" + "codeberg.org/gruf/go-byteutil" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/stream" ) -// StatusUpdate streams the given edited status to any open, appropriate -// streams belonging to the given account. -func (p *Processor) StatusUpdate(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error { - bytes, err := json.Marshal(s) +// StatusUpdate streams the given edited status to any open, appropriate streams belonging to the given account. +func (p *Processor) StatusUpdate(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) { + b, err := json.Marshal(status) if err != nil { - return fmt.Errorf("error marshalling status to json: %s", err) + log.Errorf(ctx, "error marshaling json: %v", err) + return } - - return p.toAccount(string(bytes), stream.EventTypeStatusUpdate, streamTypes, account.ID) + p.streams.Post(ctx, account.ID, stream.Message{ + Payload: byteutil.B2S(b), + Event: stream.EventTypeStatusUpdate, + Stream: []string{streamType}, + }) } diff --git a/internal/processing/stream/statusupdate_test.go b/internal/processing/stream/statusupdate_test.go index 7b987b412..8814c966f 100644 --- a/internal/processing/stream/statusupdate_test.go +++ b/internal/processing/stream/statusupdate_test.go @@ -42,10 +42,11 @@ func (suite *StatusUpdateTestSuite) TestStreamNotification() { apiStatus, err := typeutils.NewConverter(&suite.state).StatusToAPIStatus(context.Background(), editedStatus, account) suite.NoError(err) - err = suite.streamProcessor.StatusUpdate(apiStatus, account, []string{stream.TimelineHome}) - suite.NoError(err) + suite.streamProcessor.StatusUpdate(context.Background(), account, apiStatus, stream.TimelineHome) + + msg, ok := openStream.Recv(context.Background()) + suite.True(ok) - msg := <-openStream.Messages dst := new(bytes.Buffer) err = json.Indent(dst, []byte(msg.Payload), "", " ") suite.NoError(err) diff --git a/internal/processing/stream/stream.go b/internal/processing/stream/stream.go index a5b3b9386..0b7285b58 100644 --- a/internal/processing/stream/stream.go +++ b/internal/processing/stream/stream.go @@ -18,8 +18,6 @@ package stream import ( - "sync" - "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/stream" @@ -28,53 +26,13 @@ import ( type Processor struct { state *state.State oauthServer oauth.Server - streamMap *sync.Map + streams stream.Streams } func New(state *state.State, oauthServer oauth.Server) Processor { return Processor{ state: state, oauthServer: oauthServer, - streamMap: &sync.Map{}, - } -} - -// toAccount streams the given payload with the given event type to any streams currently open for the given account ID. -func (p *Processor) toAccount(payload string, event string, streamTypes []string, accountID string) error { - // Load all streams open for this account. - v, ok := p.streamMap.Load(accountID) - if !ok { - return nil // No entry = nothing to stream. + streams: stream.Streams{}, } - streamsForAccount := v.(*stream.StreamsForAccount) - - streamsForAccount.Lock() - defer streamsForAccount.Unlock() - - for _, s := range streamsForAccount.Streams { - s.Lock() - defer s.Unlock() - - if !s.Connected { - continue - } - - typeLoop: - for _, streamType := range streamTypes { - if _, found := s.StreamTypes[streamType]; found { - s.Messages <- &stream.Message{ - Stream: []string{streamType}, - Event: string(event), - Payload: payload, - } - - // Break out to the outer loop, - // to avoid sending duplicates of - // the same event to the same stream. - break typeLoop - } - } - } - - return nil } diff --git a/internal/processing/stream/update.go b/internal/processing/stream/update.go index ee70bda11..a84763d51 100644 --- a/internal/processing/stream/update.go +++ b/internal/processing/stream/update.go @@ -18,20 +18,26 @@ package stream import ( + "context" "encoding/json" - "fmt" + "codeberg.org/gruf/go-byteutil" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/stream" ) // Update streams the given update to any open, appropriate streams belonging to the given account. -func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error { - bytes, err := json.Marshal(s) +func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) { + b, err := json.Marshal(status) if err != nil { - return fmt.Errorf("error marshalling status to json: %s", err) + log.Errorf(ctx, "error marshaling json: %v", err) + return } - - return p.toAccount(string(bytes), stream.EventTypeUpdate, streamTypes, account.ID) + p.streams.Post(ctx, account.ID, stream.Message{ + Payload: byteutil.B2S(b), + Event: stream.EventTypeUpdate, + Stream: []string{streamType}, + }) } diff --git a/internal/processing/workers/fromclientapi_test.go b/internal/processing/workers/fromclientapi_test.go index 05526f437..3d3630b11 100644 --- a/internal/processing/workers/fromclientapi_test.go +++ b/internal/processing/workers/fromclientapi_test.go @@ -116,23 +116,20 @@ func (suite *FromClientAPITestSuite) checkStreamed( expectPayload string, expectEventType string, ) { - var msg *stream.Message -streamLoop: - for { - select { - case msg = <-str.Messages: - break streamLoop // Got it. - case <-time.After(5 * time.Second): - break streamLoop // Didn't get it. - } - } - if expectMessage && msg == nil { - suite.FailNow("expected a message but message was nil") + // Set a 5s timeout on context. + ctx := context.Background() + ctx, cncl := context.WithTimeout(ctx, time.Second*5) + defer cncl() + + msg, ok := str.Recv(ctx) + + if expectMessage && !ok { + suite.FailNow("expected a message but message was not received") } - if !expectMessage && msg != nil { - suite.FailNow("expected no message but message was not nil") + if !expectMessage && ok { + suite.FailNow("expected no message but message was received") } if expectPayload != "" && msg.Payload != expectPayload { diff --git a/internal/processing/workers/fromfediapi_test.go b/internal/processing/workers/fromfediapi_test.go index 799eaf2dc..446355628 100644 --- a/internal/processing/workers/fromfediapi_test.go +++ b/internal/processing/workers/fromfediapi_test.go @@ -130,14 +130,9 @@ func (suite *FromFediAPITestSuite) TestProcessReplyMention() { suite.Equal(replyingStatus.ID, notif.StatusID) suite.False(*notif.Read) - // the notification should be streamed - var msg *stream.Message - select { - case msg = <-wssStream.Messages: - // fine - case <-time.After(5 * time.Second): - suite.FailNow("no message from wssStream") - } + ctx, _ := context.WithTimeout(context.Background(), time.Second*5) + msg, ok := wssStream.Recv(ctx) + suite.True(ok) suite.Equal(stream.EventTypeNotification, msg.Event) suite.NotEmpty(msg.Payload) @@ -203,14 +198,10 @@ func (suite *FromFediAPITestSuite) TestProcessFave() { suite.Equal(fave.StatusID, notif.StatusID) suite.False(*notif.Read) - // 2. a notification should be streamed - var msg *stream.Message - select { - case msg = <-wssStream.Messages: - // fine - case <-time.After(5 * time.Second): - suite.FailNow("no message from wssStream") - } + ctx, _ := context.WithTimeout(context.Background(), time.Second*5) + msg, ok := wssStream.Recv(ctx) + suite.True(ok) + suite.Equal(stream.EventTypeNotification, msg.Event) suite.NotEmpty(msg.Payload) suite.EqualValues([]string{stream.TimelineNotifications}, msg.Stream) @@ -277,7 +268,9 @@ func (suite *FromFediAPITestSuite) TestProcessFaveWithDifferentReceivingAccount( suite.False(*notif.Read) // 2. no notification should be streamed to the account that received the fave message, because they weren't the target - suite.Empty(wssStream.Messages) + ctx, _ := context.WithTimeout(context.Background(), time.Second*5) + _, ok := wssStream.Recv(ctx) + suite.False(ok) } func (suite *FromFediAPITestSuite) TestProcessAccountDelete() { @@ -405,14 +398,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() { }) suite.NoError(err) - // a notification should be streamed - var msg *stream.Message - select { - case msg = <-wssStream.Messages: - // fine - case <-time.After(5 * time.Second): - suite.FailNow("no message from wssStream") - } + ctx, _ = context.WithTimeout(ctx, time.Second*5) + msg, ok := wssStream.Recv(context.Background()) + suite.True(ok) + suite.Equal(stream.EventTypeNotification, msg.Event) suite.NotEmpty(msg.Payload) suite.EqualValues([]string{stream.TimelineHome}, msg.Stream) @@ -423,7 +412,7 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() { suite.Equal(originAccount.ID, notif.Account.ID) // no messages should have been sent out, since we didn't need to federate an accept - suite.Empty(suite.httpClient.SentMessages) + suite.Empty(&suite.httpClient.SentMessages) } func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() { @@ -503,14 +492,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() { suite.Equal(originAccount.URI, accept.To) suite.Equal("Accept", accept.Type) - // a notification should be streamed - var msg *stream.Message - select { - case msg = <-wssStream.Messages: - // fine - case <-time.After(5 * time.Second): - suite.FailNow("no message from wssStream") - } + ctx, _ = context.WithTimeout(ctx, time.Second*5) + msg, ok := wssStream.Recv(context.Background()) + suite.True(ok) + suite.Equal(stream.EventTypeNotification, msg.Event) suite.NotEmpty(msg.Payload) suite.EqualValues([]string{stream.TimelineHome}, msg.Stream) diff --git a/internal/processing/workers/surfacenotify.go b/internal/processing/workers/surfacenotify.go index 39798f45e..a8c36248c 100644 --- a/internal/processing/workers/surfacenotify.go +++ b/internal/processing/workers/surfacenotify.go @@ -394,10 +394,7 @@ func (s *surface) notify( if err != nil { return gtserror.Newf("error converting notification to api representation: %w", err) } - - if err := s.stream.Notify(apiNotif, targetAccount); err != nil { - return gtserror.Newf("error streaming notification to account: %w", err) - } + s.stream.Notify(ctx, targetAccount, apiNotif) return nil } diff --git a/internal/processing/workers/surfacetimeline.go b/internal/processing/workers/surfacetimeline.go index e63b8a7c0..14634f846 100644 --- a/internal/processing/workers/surfacetimeline.go +++ b/internal/processing/workers/surfacetimeline.go @@ -348,11 +348,7 @@ func (s *surface) timelineStatus( err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err) return true, err } - - if err := s.stream.Update(apiStatus, account, []string{streamType}); err != nil { - err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err) - return true, err - } + s.stream.Update(ctx, account, apiStatus, streamType) return true, nil } @@ -363,12 +359,11 @@ func (s *surface) deleteStatusFromTimelines(ctx context.Context, statusID string if err := s.state.Timelines.Home.WipeItemFromAllTimelines(ctx, statusID); err != nil { return err } - if err := s.state.Timelines.List.WipeItemFromAllTimelines(ctx, statusID); err != nil { return err } - - return s.stream.Delete(statusID) + s.stream.Delete(ctx, statusID) + return nil } // invalidateStatusFromTimelines does cache invalidation on the given status by @@ -555,11 +550,6 @@ func (s *surface) timelineStreamStatusUpdate( err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err) return err } - - if err := s.stream.StatusUpdate(apiStatus, account, []string{streamType}); err != nil { - err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err) - return err - } - + s.stream.StatusUpdate(ctx, account, apiStatus, streamType) return nil } |