author     kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com>  2024-02-20 18:07:49 +0000
committer  GitHub <noreply@github.com>  2024-02-20 18:07:49 +0000
commit     291e18099050ff9e19b8ee25c2ffad68d9baafef (patch)
tree       0ad1be36b4c958830d1371f3b9a32f017c5dcff0
parent     [feature] Add `requested_by` to relationship model (#2672) (diff)
download   gotosocial-291e18099050ff9e19b8ee25c2ffad68d9baafef.tar.xz
[bugfix] fix possible mutex lockup during streaming code (#2633)
* rewrite Stream{} to use much less mutex locking, update related code
* use new context for the stream context
* ensure stream gets closed on return of writeTo / readFrom WSConn()
* ensure stream write timeout gets cancelled
* remove embedded context type from Stream{}, reformat log messages for consistency
* use c.Request.Context() for context passed into Stream().Open()
* only return 1 boolean, fix tests to expect multiple stream types in messages
* changes to ping logic
* further improved ping logic
* don't export unused function types, update message sending to only include relevant stream type
* ensure stream gets closed :facepalm:
* update to error log on failed json marshal (instead of panic)
* inverse websocket read error checking to _ignore_ expected close errors
-rw-r--r--  internal/api/client/streaming/stream.go  235
-rw-r--r--  internal/processing/stream/delete.go  34
-rw-r--r--  internal/processing/stream/notification.go  21
-rw-r--r--  internal/processing/stream/notification_test.go  7
-rw-r--r--  internal/processing/stream/open.go  97
-rw-r--r--  internal/processing/stream/statusupdate.go  21
-rw-r--r--  internal/processing/stream/statusupdate_test.go  7
-rw-r--r--  internal/processing/stream/stream.go  46
-rw-r--r--  internal/processing/stream/update.go  18
-rw-r--r--  internal/processing/workers/fromclientapi_test.go  25
-rw-r--r--  internal/processing/workers/fromfediapi_test.go  53
-rw-r--r--  internal/processing/workers/surfacenotify.go  5
-rw-r--r--  internal/processing/workers/surfacetimeline.go  18
-rw-r--r--  internal/stream/stream.go  385
14 files changed, 528 insertions, 444 deletions
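
For orientation before the per-file diffs: the commit replaces the old per-stream Messages/Hangup channel pair with a self-contained Stream type exposing Open, Subscribe/Unsubscribe, Recv and Close. A minimal consumer-side sketch of that lifecycle, using only signatures that appear in the internal/stream diff further down; the surrounding function and variable names are illustrative:

    import (
        "context"
        "log"

        "github.com/superseriousbusiness/gotosocial/internal/stream"
    )

    // consume opens a stream for one account and drains it until
    // the stream is closed or ctx is cancelled.
    func consume(ctx context.Context, streams *stream.Streams, accountID string) {
        // Open a stream subscribed to the home timeline; Close
        // deregisters it from the per-account map again.
        str := streams.Open(accountID, stream.TimelineHome)
        defer str.Close()

        for {
            // Recv blocks until a message arrives, the stream
            // is closed, or ctx is done.
            msg, ok := str.Recv(ctx)
            if !ok {
                return
            }
            log.Printf("event=%s streams=%v payload=%s", msg.Event, msg.Stream, msg.Payload)
        }
    }

Messages are posted from the processor side with Streams.Post / Streams.PostAll, as the processing/stream diffs below show.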
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go
index 266b64976..8df4e9e76 100644
--- a/internal/api/client/streaming/stream.go
+++ b/internal/api/client/streaming/stream.go
@@ -22,10 +22,10 @@ import (
"slices"
"time"
- "codeberg.org/gruf/go-kv"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
streampkg "github.com/superseriousbusiness/gotosocial/internal/stream"
@@ -202,7 +202,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
// functions pass messages into a channel, which we can
// then read from and put into a websockets connection.
stream, errWithCode := m.processor.Stream().Open(
- c.Request.Context(),
+ c.Request.Context(), // this ctx is only used for logging
account,
streamType,
)
@@ -213,10 +213,8 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
l := log.
WithContext(c.Request.Context()).
- WithFields(kv.Fields{
- {"username", account.Username},
- {"streamID", stream.ID},
- }...)
+ WithField("streamID", id.NewULID()).
+ WithField("username", account.Username)
// Upgrade the incoming HTTP request. This hijacks the
// underlying connection and reuses it for the websocket
@@ -227,18 +225,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
if err != nil {
l.Errorf("error upgrading websocket connection: %v", err)
- close(stream.Hangup)
+ stream.Close()
return
}
- l.Info("opened websocket connection")
-
// We perform the main websocket rw loops in a separate
// goroutine in order to let the upgrade handler return.
// This prevents the upgrade handler from holding open any
// throttle / rate-limit request tokens which could become
// problematic on instances with multiple users.
- go m.handleWSConn(account.Username, wsConn, stream)
+ go m.handleWSConn(&l, wsConn, stream)
}
// handleWSConn handles a two-way websocket streaming connection.
@@ -246,48 +242,39 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
// into the connection. If any errors are encountered while reading
// or writing (including expected errors like clients leaving), the
// connection will be closed.
-func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) {
- // Create new context for the lifetime of this connection.
- ctx, cancel := context.WithCancel(context.Background())
-
- l := log.
- WithContext(ctx).
- WithFields(kv.Fields{
- {"username", username},
- {"streamID", stream.ID},
- }...)
+func (m *Module) handleWSConn(l *log.Entry, wsConn *websocket.Conn, stream *streampkg.Stream) {
+ l.Info("opened websocket connection")
- // Create ticker to send keepalive pings
- pinger := time.NewTicker(m.dTicker)
+ // Create new async context with cancel.
+ ctx, cncl := context.WithCancel(context.Background())
- // Read messages coming from the Websocket client connection into the server.
go func() {
- defer cancel()
- m.readFromWSConn(ctx, username, wsConn, stream)
+ defer cncl()
+
+ // Read messages from websocket to server.
+ m.readFromWSConn(ctx, wsConn, stream, l)
}()
- // Write messages coming from the processor into the Websocket client connection.
go func() {
- defer cancel()
- m.writeToWSConn(ctx, username, wsConn, stream, pinger)
+ defer cncl()
+
+ // Write messages from processor in websocket conn.
+ m.writeToWSConn(ctx, wsConn, stream, m.dTicker, l)
}()
- // Wait for either the read or write functions to close, to indicate
- // that the client has left, or something else has gone wrong.
+ // Wait for ctx
+ // to be closed.
<-ctx.Done()
+ // Close stream
+ // straightaway.
+ stream.Close()
+
// Tidy up underlying websocket connection.
if err := wsConn.Close(); err != nil {
l.Errorf("error closing websocket connection: %v", err)
}
- // Close processor channel so the processor knows
- // not to send any more messages to this stream.
- close(stream.Hangup)
-
- // Stop ping ticker (tiny resource saving).
- pinger.Stop()
-
l.Info("closed websocket connection")
}
@@ -299,89 +286,64 @@ func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *s
// if the given context is canceled.
func (m *Module) readFromWSConn(
ctx context.Context,
- username string,
wsConn *websocket.Conn,
stream *streampkg.Stream,
+ l *log.Entry,
) {
- l := log.
- WithContext(ctx).
- WithFields(kv.Fields{
- {"username", username},
- {"streamID", stream.ID},
- }...)
-readLoop:
for {
- select {
- case <-ctx.Done():
- // Connection closed.
- break readLoop
+ var msg struct {
+ Type string `json:"type"`
+ Stream string `json:"stream"`
+ List string `json:"list,omitempty"`
+ }
- default:
- // Read JSON objects from the client and act on them.
- var msg map[string]string
- if err := wsConn.ReadJSON(&msg); err != nil {
- // Only log an error if something weird happened.
- // See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7
- if websocket.IsUnexpectedCloseError(err, []int{
- websocket.CloseNormalClosure,
- websocket.CloseGoingAway,
- websocket.CloseNoStatusReceived,
- }...) {
- l.Errorf("error reading from websocket: %v", err)
- }
-
- // The connection is gone; no
- // further streaming possible.
- break readLoop
+ // Read JSON objects from the client and act on them.
+ if err := wsConn.ReadJSON(&msg); err != nil {
+ // Only log an error if something weird happened.
+ // See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7
+ if !websocket.IsCloseError(err, []int{
+ websocket.CloseNormalClosure,
+ websocket.CloseGoingAway,
+ websocket.CloseNoStatusReceived,
+ }...) {
+ l.Errorf("error during websocket read: %v", err)
}
- // Messages *from* the WS connection are infrequent
- // and usually interesting, so log this at info.
- l.Infof("received message from websocket: %v", msg)
-
- // If the message contains 'stream' and 'type' fields, we can
- // update the set of timelines that are subscribed for events.
- updateType, ok := msg["type"]
- if !ok {
- l.Warn("'type' field not provided")
- continue
- }
+ // The connection is gone; no
+ // further streaming possible.
+ break
+ }
- updateStream, ok := msg["stream"]
- if !ok {
- l.Warn("'stream' field not provided")
- continue
- }
+ // Messages *from* the WS connection are infrequent
+ // and usually interesting, so log this at info.
+ l.Infof("received websocket message: %+v", msg)
- // Ignore if the updateStreamType is unknown (or missing),
- // so a bad client can't cause extra memory allocations
- if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
- l.Warnf("unknown 'stream' field: %v", msg)
- continue
- }
+ // Ignore if the given stream type is unknown (or missing),
+ // so a bad client can't cause extra memory allocations
+ if !slices.Contains(streampkg.AllStatusTimelines, msg.Stream) {
+ l.Warnf("unknown 'stream' field: %v", msg)
+ continue
+ }
- updateList, ok := msg["list"]
- if ok {
- updateStream += ":" + updateList
- }
+ if msg.List != "" {
+ // If a list is given, add this to
+ // the stream name, as this is how
+ // we track stream types internally.
+ msg.Stream += ":" + msg.List
+ }
- switch updateType {
- case "subscribe":
- stream.Lock()
- stream.StreamTypes[updateStream] = true
- stream.Unlock()
- case "unsubscribe":
- stream.Lock()
- delete(stream.StreamTypes, updateStream)
- stream.Unlock()
- default:
- l.Warnf("invalid 'type' field: %v", msg)
- }
+ switch msg.Type {
+ case "subscribe":
+ stream.Subscribe(msg.Stream)
+ case "unsubscribe":
+ stream.Unsubscribe(msg.Stream)
+ default:
+ l.Warnf("invalid 'type' field: %v", msg)
}
}
- l.Debug("finished reading from websocket connection")
+ l.Debug("finished websocket read")
}
// writeToWSConn receives messages coming from the processor via the
@@ -393,46 +355,47 @@ readLoop:
// if the given context is canceled.
func (m *Module) writeToWSConn(
ctx context.Context,
- username string,
wsConn *websocket.Conn,
stream *streampkg.Stream,
- pinger *time.Ticker,
+ ping time.Duration,
+ l *log.Entry,
) {
- l := log.
- WithContext(ctx).
- WithFields(kv.Fields{
- {"username", username},
- {"streamID", stream.ID},
- }...)
-
-writeLoop:
for {
- select {
- case <-ctx.Done():
- // Connection closed.
- break writeLoop
-
- case msg := <-stream.Messages:
- // Received a new message from the processor.
- l.Tracef("writing message to websocket: %+v", msg)
- if err := wsConn.WriteJSON(msg); err != nil {
- l.Debugf("error writing json to websocket: %v", err)
- break writeLoop
- }
+ // Wrap context with timeout to send a ping.
+ pingctx, cncl := context.WithTimeout(ctx, ping)
+
+ // Block on receipt of msg.
+ msg, ok := stream.Recv(pingctx)
- // Reset pinger on successful send, since
- // we know the connection is still there.
- pinger.Reset(m.dTicker)
+ // Check if cancelled due to ping timeout.
+ pinged := (pingctx.Err() != nil)
+ cncl()
- case <-pinger.C:
- // Time to send a keep-alive "ping".
- l.Trace("writing ping control message to websocket")
+ switch {
+ case !ok && pinged:
+ // The ping context timed out!
+ l.Trace("writing websocket ping")
+
+ // Wrapped context time-out, send a keep-alive "ping".
if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil {
- l.Debugf("error writing ping to websocket: %v", err)
- break writeLoop
+ l.Debugf("error writing websocket ping: %v", err)
+ break
}
+
+ case !ok:
+ // Stream was
+ // closed.
+ return
+ }
+
+ l.Trace("writing websocket message: %+v", msg)
+
+ // Received a new message from the processor.
+ if err := wsConn.WriteJSON(msg); err != nil {
+ l.Debugf("error writing websocket message: %v", err)
+ break
}
}
- l.Debug("finished writing to websocket connection")
+ l.Debug("finished websocket write")
}
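
The reworked writeToWSConn above folds the old keepalive ticker into the receive call: every Recv is wrapped in a context carrying a ping-length timeout, and a timeout (rather than a closed stream) is what triggers a websocket ping. A condensed sketch of that idea, not a line-for-line mirror of the loop above; it reuses the Stream.Recv and websocket calls already shown, with imports and logging omitted:

    func writeLoop(ctx context.Context, wsConn *websocket.Conn, str *stream.Stream, ping time.Duration) {
        for {
            // Wait for a message, but only for up to one ping interval.
            pingCtx, cancel := context.WithTimeout(ctx, ping)
            msg, ok := str.Recv(pingCtx)
            timedOut := pingCtx.Err() == context.DeadlineExceeded
            cancel()

            switch {
            case ok:
                // Got a message from the processor: write it out.
                if err := wsConn.WriteJSON(msg); err != nil {
                    return
                }

            case timedOut:
                // No message within the ping interval: send a
                // keepalive ping and go back to waiting.
                if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil {
                    return
                }

            default:
                // Stream closed or parent ctx cancelled.
                return
            }
        }
    }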
diff --git a/internal/processing/stream/delete.go b/internal/processing/stream/delete.go
index d7745eef8..1c61b98d3 100644
--- a/internal/processing/stream/delete.go
+++ b/internal/processing/stream/delete.go
@@ -18,38 +18,16 @@
package stream
import (
- "fmt"
- "strings"
+ "context"
"github.com/superseriousbusiness/gotosocial/internal/stream"
)
// Delete streams the delete of the given statusID to *ALL* open streams.
-func (p *Processor) Delete(statusID string) error {
- errs := []string{}
-
- // get all account IDs with open streams
- accountIDs := []string{}
- p.streamMap.Range(func(k interface{}, _ interface{}) bool {
- key, ok := k.(string)
- if !ok {
- panic("streamMap key was not a string (account id)")
- }
-
- accountIDs = append(accountIDs, key)
- return true
+func (p *Processor) Delete(ctx context.Context, statusID string) {
+ p.streams.PostAll(ctx, stream.Message{
+ Payload: statusID,
+ Event: stream.EventTypeDelete,
+ Stream: stream.AllStatusTimelines,
})
-
- // stream the delete to every account
- for _, accountID := range accountIDs {
- if err := p.toAccount(statusID, stream.EventTypeDelete, stream.AllStatusTimelines, accountID); err != nil {
- errs = append(errs, err.Error())
- }
- }
-
- if len(errs) != 0 {
- return fmt.Errorf("one or more errors streaming status delete: %s", strings.Join(errs, ";"))
- }
-
- return nil
}
diff --git a/internal/processing/stream/notification.go b/internal/processing/stream/notification.go
index 63d7c5d11..a16da11e6 100644
--- a/internal/processing/stream/notification.go
+++ b/internal/processing/stream/notification.go
@@ -18,20 +18,29 @@
package stream
import (
+ "context"
"encoding/json"
- "fmt"
+ "codeberg.org/gruf/go-byteutil"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/stream"
)
// Notify streams the given notification to any open, appropriate streams belonging to the given account.
-func (p *Processor) Notify(n *apimodel.Notification, account *gtsmodel.Account) error {
- bytes, err := json.Marshal(n)
+func (p *Processor) Notify(ctx context.Context, account *gtsmodel.Account, notif *apimodel.Notification) {
+ b, err := json.Marshal(notif)
if err != nil {
- return fmt.Errorf("error marshalling notification to json: %s", err)
+ log.Errorf(ctx, "error marshaling json: %v", err)
+ return
}
-
- return p.toAccount(string(bytes), stream.EventTypeNotification, []string{stream.TimelineNotifications, stream.TimelineHome}, account.ID)
+ p.streams.Post(ctx, account.ID, stream.Message{
+ Payload: byteutil.B2S(b),
+ Event: stream.EventTypeNotification,
+ Stream: []string{
+ stream.TimelineNotifications,
+ stream.TimelineHome,
+ },
+ })
}
diff --git a/internal/processing/stream/notification_test.go b/internal/processing/stream/notification_test.go
index 2138f0025..e12f23abe 100644
--- a/internal/processing/stream/notification_test.go
+++ b/internal/processing/stream/notification_test.go
@@ -49,10 +49,11 @@ func (suite *NotificationTestSuite) TestStreamNotification() {
Account: followAccountAPIModel,
}
- err = suite.streamProcessor.Notify(notification, account)
- suite.NoError(err)
+ suite.streamProcessor.Notify(context.Background(), account, notification)
+
+ msg, ok := openStream.Recv(context.Background())
+ suite.True(ok)
- msg := <-openStream.Messages
dst := new(bytes.Buffer)
err = json.Indent(dst, []byte(msg.Payload), "", " ")
suite.NoError(err)
diff --git a/internal/processing/stream/open.go b/internal/processing/stream/open.go
index 1c041309f..2f2bbd4a3 100644
--- a/internal/processing/stream/open.go
+++ b/internal/processing/stream/open.go
@@ -19,13 +19,10 @@ package stream
import (
"context"
- "errors"
- "fmt"
"codeberg.org/gruf/go-kv"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/stream"
)
@@ -37,97 +34,5 @@ func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamT
{"streamType", streamType},
}...)
l.Debug("received open stream request")
-
- var (
- streamID string
- err error
- )
-
- // Each stream needs a unique ID so we know to close it.
- streamID, err = id.NewRandomULID()
- if err != nil {
- return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %w", err))
- }
-
- // Each stream can be subscibed to multiple types.
- // Record them in a set, and include the initial one
- // if it was given to us.
- streamTypes := map[string]any{}
- if streamType != "" {
- streamTypes[streamType] = true
- }
-
- newStream := &stream.Stream{
- ID: streamID,
- StreamTypes: streamTypes,
- Messages: make(chan *stream.Message, 100),
- Hangup: make(chan interface{}, 1),
- Connected: true,
- }
- go p.waitToCloseStream(account, newStream)
-
- v, ok := p.streamMap.Load(account.ID)
- if ok {
- // There is an entry in the streamMap
- // for this account. Parse it out.
- streamsForAccount, ok := v.(*stream.StreamsForAccount)
- if !ok {
- return nil, gtserror.NewErrorInternalError(errors.New("stream map error"))
- }
-
- // Append new stream to existing entry.
- streamsForAccount.Lock()
- streamsForAccount.Streams = append(streamsForAccount.Streams, newStream)
- streamsForAccount.Unlock()
- } else {
- // There is no entry in the streamMap for
- // this account yet. Create one and store it.
- p.streamMap.Store(account.ID, &stream.StreamsForAccount{
- Streams: []*stream.Stream{
- newStream,
- },
- })
- }
-
- return newStream, nil
-}
-
-// waitToCloseStream waits until the hangup channel is closed for the given stream.
-// It then iterates through the map of streams stored by the processor, removes the stream from it,
-// and then closes the messages channel of the stream to indicate that the channel should no longer be read from.
-func (p *Processor) waitToCloseStream(account *gtsmodel.Account, thisStream *stream.Stream) {
- <-thisStream.Hangup // wait for a hangup message
-
- // lock the stream to prevent more messages being put in it while we work
- thisStream.Lock()
- defer thisStream.Unlock()
-
- // indicate the stream is no longer connected
- thisStream.Connected = false
-
- // load and parse the entry for this account from the stream map
- v, ok := p.streamMap.Load(account.ID)
- if !ok || v == nil {
- return
- }
- streamsForAccount, ok := v.(*stream.StreamsForAccount)
- if !ok {
- return
- }
-
- // lock the streams for account while we remove this stream from its slice
- streamsForAccount.Lock()
- defer streamsForAccount.Unlock()
-
- // put everything into modified streams *except* the stream we're removing
- modifiedStreams := []*stream.Stream{}
- for _, s := range streamsForAccount.Streams {
- if s.ID != thisStream.ID {
- modifiedStreams = append(modifiedStreams, s)
- }
- }
- streamsForAccount.Streams = modifiedStreams
-
- // finally close the messages channel so no more messages can be read from it
- close(thisStream.Messages)
+ return p.streams.Open(account.ID, streamType), nil
}
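
Everything deleted above (the per-account streamMap bookkeeping and waitToCloseStream) now lives behind Streams.Open in internal/stream, which wires a close hook into each new stream so that Stream.Close removes it from the per-account map again. The shape of that pattern in isolation, as a sketch with illustrative names (registry, handle), not code from the repository:

    import (
        "slices"
        "sync"
    )

    // registry hands out handles that deregister themselves on Close.
    type registry struct {
        mu      sync.Mutex
        handles map[string][]*handle
    }

    type handle struct {
        close func() // set by the registry at open time
    }

    func (r *registry) open(key string) *handle {
        h := new(handle)

        r.mu.Lock()
        if r.handles == nil {
            r.handles = make(map[string][]*handle)
        }
        r.handles[key] = append(r.handles[key], h)

        // The close hook captures key and h, so the handle can
        // remove itself without knowing about the map at all.
        h.close = func() {
            r.mu.Lock()
            r.handles[key] = slices.DeleteFunc(r.handles[key], func(x *handle) bool {
                return x == h
            })
            r.mu.Unlock()
        }
        r.mu.Unlock()

        return h
    }

    func (h *handle) Close() { h.close() }

(The real Stream.Close additionally guards against double-close via its done channel, as the internal/stream diff below shows.)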
diff --git a/internal/processing/stream/statusupdate.go b/internal/processing/stream/statusupdate.go
index fd8e388ce..bd4658873 100644
--- a/internal/processing/stream/statusupdate.go
+++ b/internal/processing/stream/statusupdate.go
@@ -18,21 +18,26 @@
package stream
import (
+ "context"
"encoding/json"
- "fmt"
+ "codeberg.org/gruf/go-byteutil"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/stream"
)
-// StatusUpdate streams the given edited status to any open, appropriate
-// streams belonging to the given account.
-func (p *Processor) StatusUpdate(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error {
- bytes, err := json.Marshal(s)
+// StatusUpdate streams the given edited status to any open, appropriate streams belonging to the given account.
+func (p *Processor) StatusUpdate(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) {
+ b, err := json.Marshal(status)
if err != nil {
- return fmt.Errorf("error marshalling status to json: %s", err)
+ log.Errorf(ctx, "error marshaling json: %v", err)
+ return
}
-
- return p.toAccount(string(bytes), stream.EventTypeStatusUpdate, streamTypes, account.ID)
+ p.streams.Post(ctx, account.ID, stream.Message{
+ Payload: byteutil.B2S(b),
+ Event: stream.EventTypeStatusUpdate,
+ Stream: []string{streamType},
+ })
}
diff --git a/internal/processing/stream/statusupdate_test.go b/internal/processing/stream/statusupdate_test.go
index 7b987b412..8814c966f 100644
--- a/internal/processing/stream/statusupdate_test.go
+++ b/internal/processing/stream/statusupdate_test.go
@@ -42,10 +42,11 @@ func (suite *StatusUpdateTestSuite) TestStreamNotification() {
apiStatus, err := typeutils.NewConverter(&suite.state).StatusToAPIStatus(context.Background(), editedStatus, account)
suite.NoError(err)
- err = suite.streamProcessor.StatusUpdate(apiStatus, account, []string{stream.TimelineHome})
- suite.NoError(err)
+ suite.streamProcessor.StatusUpdate(context.Background(), account, apiStatus, stream.TimelineHome)
+
+ msg, ok := openStream.Recv(context.Background())
+ suite.True(ok)
- msg := <-openStream.Messages
dst := new(bytes.Buffer)
err = json.Indent(dst, []byte(msg.Payload), "", " ")
suite.NoError(err)
diff --git a/internal/processing/stream/stream.go b/internal/processing/stream/stream.go
index a5b3b9386..0b7285b58 100644
--- a/internal/processing/stream/stream.go
+++ b/internal/processing/stream/stream.go
@@ -18,8 +18,6 @@
package stream
import (
- "sync"
-
"github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/stream"
@@ -28,53 +26,13 @@ import (
type Processor struct {
state *state.State
oauthServer oauth.Server
- streamMap *sync.Map
+ streams stream.Streams
}
func New(state *state.State, oauthServer oauth.Server) Processor {
return Processor{
state: state,
oauthServer: oauthServer,
- streamMap: &sync.Map{},
- }
-}
-
-// toAccount streams the given payload with the given event type to any streams currently open for the given account ID.
-func (p *Processor) toAccount(payload string, event string, streamTypes []string, accountID string) error {
- // Load all streams open for this account.
- v, ok := p.streamMap.Load(accountID)
- if !ok {
- return nil // No entry = nothing to stream.
+ streams: stream.Streams{},
}
- streamsForAccount := v.(*stream.StreamsForAccount)
-
- streamsForAccount.Lock()
- defer streamsForAccount.Unlock()
-
- for _, s := range streamsForAccount.Streams {
- s.Lock()
- defer s.Unlock()
-
- if !s.Connected {
- continue
- }
-
- typeLoop:
- for _, streamType := range streamTypes {
- if _, found := s.StreamTypes[streamType]; found {
- s.Messages <- &stream.Message{
- Stream: []string{streamType},
- Event: string(event),
- Payload: payload,
- }
-
- // Break out to the outer loop,
- // to avoid sending duplicates of
- // the same event to the same stream.
- break typeLoop
- }
- }
- }
-
- return nil
}
diff --git a/internal/processing/stream/update.go b/internal/processing/stream/update.go
index ee70bda11..a84763d51 100644
--- a/internal/processing/stream/update.go
+++ b/internal/processing/stream/update.go
@@ -18,20 +18,26 @@
package stream
import (
+ "context"
"encoding/json"
- "fmt"
+ "codeberg.org/gruf/go-byteutil"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/stream"
)
// Update streams the given update to any open, appropriate streams belonging to the given account.
-func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error {
- bytes, err := json.Marshal(s)
+func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) {
+ b, err := json.Marshal(status)
if err != nil {
- return fmt.Errorf("error marshalling status to json: %s", err)
+ log.Errorf(ctx, "error marshaling json: %v", err)
+ return
}
-
- return p.toAccount(string(bytes), stream.EventTypeUpdate, streamTypes, account.ID)
+ p.streams.Post(ctx, account.ID, stream.Message{
+ Payload: byteutil.B2S(b),
+ Event: stream.EventTypeUpdate,
+ Stream: []string{streamType},
+ })
}
diff --git a/internal/processing/workers/fromclientapi_test.go b/internal/processing/workers/fromclientapi_test.go
index 05526f437..3d3630b11 100644
--- a/internal/processing/workers/fromclientapi_test.go
+++ b/internal/processing/workers/fromclientapi_test.go
@@ -116,23 +116,20 @@ func (suite *FromClientAPITestSuite) checkStreamed(
expectPayload string,
expectEventType string,
) {
- var msg *stream.Message
-streamLoop:
- for {
- select {
- case msg = <-str.Messages:
- break streamLoop // Got it.
- case <-time.After(5 * time.Second):
- break streamLoop // Didn't get it.
- }
- }
- if expectMessage && msg == nil {
- suite.FailNow("expected a message but message was nil")
+ // Set a 5s timeout on context.
+ ctx := context.Background()
+ ctx, cncl := context.WithTimeout(ctx, time.Second*5)
+ defer cncl()
+
+ msg, ok := str.Recv(ctx)
+
+ if expectMessage && !ok {
+ suite.FailNow("expected a message but message was not received")
}
- if !expectMessage && msg != nil {
- suite.FailNow("expected no message but message was not nil")
+ if !expectMessage && ok {
+ suite.FailNow("expected no message but message was received")
}
if expectPayload != "" && msg.Payload != expectPayload {
diff --git a/internal/processing/workers/fromfediapi_test.go b/internal/processing/workers/fromfediapi_test.go
index 799eaf2dc..446355628 100644
--- a/internal/processing/workers/fromfediapi_test.go
+++ b/internal/processing/workers/fromfediapi_test.go
@@ -130,14 +130,9 @@ func (suite *FromFediAPITestSuite) TestProcessReplyMention() {
suite.Equal(replyingStatus.ID, notif.StatusID)
suite.False(*notif.Read)
- // the notification should be streamed
- var msg *stream.Message
- select {
- case msg = <-wssStream.Messages:
- // fine
- case <-time.After(5 * time.Second):
- suite.FailNow("no message from wssStream")
- }
+ ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
+ msg, ok := wssStream.Recv(ctx)
+ suite.True(ok)
suite.Equal(stream.EventTypeNotification, msg.Event)
suite.NotEmpty(msg.Payload)
@@ -203,14 +198,10 @@ func (suite *FromFediAPITestSuite) TestProcessFave() {
suite.Equal(fave.StatusID, notif.StatusID)
suite.False(*notif.Read)
- // 2. a notification should be streamed
- var msg *stream.Message
- select {
- case msg = <-wssStream.Messages:
- // fine
- case <-time.After(5 * time.Second):
- suite.FailNow("no message from wssStream")
- }
+ ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
+ msg, ok := wssStream.Recv(ctx)
+ suite.True(ok)
+
suite.Equal(stream.EventTypeNotification, msg.Event)
suite.NotEmpty(msg.Payload)
suite.EqualValues([]string{stream.TimelineNotifications}, msg.Stream)
@@ -277,7 +268,9 @@ func (suite *FromFediAPITestSuite) TestProcessFaveWithDifferentReceivingAccount(
suite.False(*notif.Read)
// 2. no notification should be streamed to the account that received the fave message, because they weren't the target
- suite.Empty(wssStream.Messages)
+ ctx, _ := context.WithTimeout(context.Background(), time.Second*5)
+ _, ok := wssStream.Recv(ctx)
+ suite.False(ok)
}
func (suite *FromFediAPITestSuite) TestProcessAccountDelete() {
@@ -405,14 +398,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() {
})
suite.NoError(err)
- // a notification should be streamed
- var msg *stream.Message
- select {
- case msg = <-wssStream.Messages:
- // fine
- case <-time.After(5 * time.Second):
- suite.FailNow("no message from wssStream")
- }
+ ctx, _ = context.WithTimeout(ctx, time.Second*5)
+ msg, ok := wssStream.Recv(context.Background())
+ suite.True(ok)
+
suite.Equal(stream.EventTypeNotification, msg.Event)
suite.NotEmpty(msg.Payload)
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
@@ -423,7 +412,7 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() {
suite.Equal(originAccount.ID, notif.Account.ID)
// no messages should have been sent out, since we didn't need to federate an accept
- suite.Empty(suite.httpClient.SentMessages)
+ suite.Empty(&suite.httpClient.SentMessages)
}
func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {
@@ -503,14 +492,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {
suite.Equal(originAccount.URI, accept.To)
suite.Equal("Accept", accept.Type)
- // a notification should be streamed
- var msg *stream.Message
- select {
- case msg = <-wssStream.Messages:
- // fine
- case <-time.After(5 * time.Second):
- suite.FailNow("no message from wssStream")
- }
+ ctx, _ = context.WithTimeout(ctx, time.Second*5)
+ msg, ok := wssStream.Recv(context.Background())
+ suite.True(ok)
+
suite.Equal(stream.EventTypeNotification, msg.Event)
suite.NotEmpty(msg.Payload)
suite.EqualValues([]string{stream.TimelineHome}, msg.Stream)
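
The test hunks above swap the old select-with-time.After pattern for a timeout context around Recv. Two details matter for that pattern: the CancelFunc returned by context.WithTimeout should not be discarded (go vet's lostcancel check flags "ctx, _ :=", and the underlying timer otherwise lives until the deadline), and the derived context is the one that must be passed to Recv for the timeout to actually bound the wait. A small helper in that shape, assuming the same Recv signature; the helper name is illustrative:

    import (
        "context"
        "testing"
        "time"

        "github.com/superseriousbusiness/gotosocial/internal/stream"
    )

    // recvWithTimeout waits up to d for a message on str.
    func recvWithTimeout(t *testing.T, str *stream.Stream, d time.Duration) (stream.Message, bool) {
        t.Helper()

        // Derive a context that expires after d, and always
        // release its timer once we are done waiting.
        ctx, cancel := context.WithTimeout(context.Background(), d)
        defer cancel()

        // Pass the derived context (not context.Background())
        // so the deadline really applies to the receive.
        return str.Recv(ctx)
    }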
diff --git a/internal/processing/workers/surfacenotify.go b/internal/processing/workers/surfacenotify.go
index 39798f45e..a8c36248c 100644
--- a/internal/processing/workers/surfacenotify.go
+++ b/internal/processing/workers/surfacenotify.go
@@ -394,10 +394,7 @@ func (s *surface) notify(
if err != nil {
return gtserror.Newf("error converting notification to api representation: %w", err)
}
-
- if err := s.stream.Notify(apiNotif, targetAccount); err != nil {
- return gtserror.Newf("error streaming notification to account: %w", err)
- }
+ s.stream.Notify(ctx, targetAccount, apiNotif)
return nil
}
diff --git a/internal/processing/workers/surfacetimeline.go b/internal/processing/workers/surfacetimeline.go
index e63b8a7c0..14634f846 100644
--- a/internal/processing/workers/surfacetimeline.go
+++ b/internal/processing/workers/surfacetimeline.go
@@ -348,11 +348,7 @@ func (s *surface) timelineStatus(
err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err)
return true, err
}
-
- if err := s.stream.Update(apiStatus, account, []string{streamType}); err != nil {
- err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err)
- return true, err
- }
+ s.stream.Update(ctx, account, apiStatus, streamType)
return true, nil
}
@@ -363,12 +359,11 @@ func (s *surface) deleteStatusFromTimelines(ctx context.Context, statusID string
if err := s.state.Timelines.Home.WipeItemFromAllTimelines(ctx, statusID); err != nil {
return err
}
-
if err := s.state.Timelines.List.WipeItemFromAllTimelines(ctx, statusID); err != nil {
return err
}
-
- return s.stream.Delete(statusID)
+ s.stream.Delete(ctx, statusID)
+ return nil
}
// invalidateStatusFromTimelines does cache invalidation on the given status by
@@ -555,11 +550,6 @@ func (s *surface) timelineStreamStatusUpdate(
err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err)
return err
}
-
- if err := s.stream.StatusUpdate(apiStatus, account, []string{streamType}); err != nil {
- err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err)
- return err
- }
-
+ s.stream.StatusUpdate(ctx, account, apiStatus, streamType)
return nil
}
diff --git a/internal/stream/stream.go b/internal/stream/stream.go
index da5647433..ec22464f5 100644
--- a/internal/stream/stream.go
+++ b/internal/stream/stream.go
@@ -17,36 +17,65 @@
package stream
-import "sync"
+import (
+ "context"
+ "maps"
+ "slices"
+ "sync"
+ "sync/atomic"
+)
const (
- // EventTypeNotification -- a user should be shown a notification
- EventTypeNotification string = "notification"
- // EventTypeUpdate -- a user should be shown an update in their timeline
- EventTypeUpdate string = "update"
- // EventTypeDelete -- something should be deleted from a user
- EventTypeDelete string = "delete"
- // EventTypeStatusUpdate -- something in the user's timeline has been edited
- // (yes this is a confusing name, blame Mastodon)
- EventTypeStatusUpdate string = "status.update"
+ // EventTypeNotification -- a user
+ // should be shown a notification.
+ EventTypeNotification = "notification"
+
+ // EventTypeUpdate -- a user should
+ // be shown an update in their timeline.
+ EventTypeUpdate = "update"
+
+ // EventTypeDelete -- something
+ // should be deleted from a user.
+ EventTypeDelete = "delete"
+
+ // EventTypeStatusUpdate -- something in the
+ // user's timeline has been edited (yes this
+ // is a confusing name, blame Mastodon ...).
+ EventTypeStatusUpdate = "status.update"
)
const (
- // TimelineLocal -- public statuses from the LOCAL timeline.
- TimelineLocal string = "public:local"
- // TimelinePublic -- public statuses, including federated ones.
- TimelinePublic string = "public"
- // TimelineHome -- statuses for a user's Home timeline.
- TimelineHome string = "user"
- // TimelineNotifications -- notification events.
- TimelineNotifications string = "user:notification"
- // TimelineDirect -- statuses sent to a user directly.
- TimelineDirect string = "direct"
- // TimelineList -- statuses for a user's list timeline.
- TimelineList string = "list"
+ // TimelineLocal:
+ // All public posts originating from this
+ // server. Analogous to the local timeline.
+ TimelineLocal = "public:local"
+
+ // TimelinePublic:
+ // All public posts known to the server.
+ // Analogous to the federated timeline.
+ TimelinePublic = "public"
+
+ // TimelineHome:
+ // Events related to the current user, such
+ // as home feed updates and notifications.
+ TimelineHome = "user"
+
+ // TimelineNotifications:
+ // Notifications for the current user.
+ TimelineNotifications = "user:notification"
+
+ // TimelineDirect:
+ // Updates to direct conversations.
+ TimelineDirect = "direct"
+
+ // TimelineList:
+ // Updates to a specific list.
+ TimelineList = "list"
)
-// AllStatusTimelines contains all Timelines that a status could conceivably be delivered to -- useful for doing deletes.
+// AllStatusTimelines contains all Timelines
+// that a status could conceivably be delivered
+// to, useful for sending out status deletes.
var AllStatusTimelines = []string{
TimelineLocal,
TimelinePublic,
@@ -55,38 +84,298 @@ var AllStatusTimelines = []string{
TimelineList,
}
-// StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time.
-// TODO: put a limit on this
-type StreamsForAccount struct {
- // The currently held streams for this account
- Streams []*Stream
- // Mutex to lock/unlock when modifying the slice of streams.
- sync.Mutex
+type Streams struct {
+ streams map[string][]*Stream
+ mutex sync.Mutex
+}
+
+// Open will open a new Stream for the given account ID, subscribed to the given stream types.
+func (s *Streams) Open(accountID string, streamTypes ...string) *Stream {
+ if len(streamTypes) == 0 {
+ panic("no stream types given")
+ }
+
+ // Prep new Stream.
+ str := new(Stream)
+ str.done = make(chan struct{})
+ str.msgCh = make(chan Message, 50) // TODO: make configurable
+ for _, streamType := range streamTypes {
+ str.Subscribe(streamType)
+ }
+
+ // TODO: add configurable
+ // max streams per account.
+
+ // Acquire lock.
+ s.mutex.Lock()
+
+ if s.streams == nil {
+ // Main stream-map needs allocating.
+ s.streams = make(map[string][]*Stream)
+ }
+
+ // Add new stream for account.
+ strs := s.streams[accountID]
+ strs = append(strs, str)
+ s.streams[accountID] = strs
+
+ // Register close callback
+ // to remove stream from our
+ // internal map for this account.
+ str.close = func() {
+ s.mutex.Lock()
+ strs := s.streams[accountID]
+ strs = slices.DeleteFunc(strs, func(s *Stream) bool {
+ return s == str // remove 'str' ptr
+ })
+ s.streams[accountID] = strs
+ s.mutex.Unlock()
+ }
+
+ // Done with lock.
+ s.mutex.Unlock()
+
+ return str
+}
+
+// Post will post the given message to all streams of given account ID matching type.
+func (s *Streams) Post(ctx context.Context, accountID string, msg Message) bool {
+ var deferred []func() bool
+
+ // Acquire lock.
+ s.mutex.Lock()
+
+ // Iterate all streams stored for account.
+ for _, str := range s.streams[accountID] {
+
+ // Check whether stream supports any of our message targets.
+ if stype := str.getStreamType(msg.Stream...); stype != "" {
+
+ // Rescope var
+ // to prevent
+ // ptr reuse.
+ stream := str
+
+ // Use a message copy to *only*
+ // include the supported stream.
+ msgCopy := Message{
+ Stream: []string{stype},
+ Event: msg.Event,
+ Payload: msg.Payload,
+ }
+
+ // Send message to supported stream
+ // DEFERRED (i.e. OUTSIDE OF MAIN MUTEX).
+ // This prevents deadlocks between each
+ // msg channel and main Streams{} mutex.
+ deferred = append(deferred, func() bool {
+ return stream.send(ctx, msgCopy)
+ })
+ }
+ }
+
+ // Done with lock.
+ s.mutex.Unlock()
+
+ var ok bool
+
+ // Execute deferred outside lock.
+ for _, deferfn := range deferred {
+ v := deferfn()
+ ok = ok && v
+ }
+
+ return ok
+}
+
+// PostAll will post the given message to all streams with matching types.
+func (s *Streams) PostAll(ctx context.Context, msg Message) bool {
+ var deferred []func() bool
+
+ // Acquire lock.
+ s.mutex.Lock()
+
+ // Iterate ALL stored streams.
+ for _, strs := range s.streams {
+ for _, str := range strs {
+
+ // Check whether stream supports any of our message targets.
+ if stype := str.getStreamType(msg.Stream...); stype != "" {
+
+ // Rescope var
+ // to prevent
+ // ptr reuse.
+ stream := str
+
+ // Use a message copy to *only*
+ // include the supported stream.
+ msgCopy := Message{
+ Stream: []string{stype},
+ Event: msg.Event,
+ Payload: msg.Payload,
+ }
+
+ // Send message to supported stream
+ // DEFERRED (i.e. OUTSIDE OF MAIN MUTEX).
+ // This prevents deadlocks between each
+ // msg channel and main Streams{} mutex.
+ deferred = append(deferred, func() bool {
+ return stream.send(ctx, msgCopy)
+ })
+ }
+ }
+ }
+
+ // Done with lock.
+ s.mutex.Unlock()
+
+ var ok bool
+
+ // Execute deferred outside lock.
+ for _, deferfn := range deferred {
+ v := deferfn()
+ ok = ok && v
+ }
+
+ return ok
}
-// Stream represents one open stream for a client.
+// Stream represents one
+// open stream for a client.
type Stream struct {
- // ID of this stream, generated during creation.
- ID string
- // A set of types subscribed to by this stream: user/public/etc.
- // It's a map to ensure no duplicates; the value is ignored.
- StreamTypes map[string]any
- // Channel of messages for the client to read from
- Messages chan *Message
- // Channel to close when the client drops away
- Hangup chan interface{}
- // Only put messages in the stream when Connected
- Connected bool
- // Mutex to lock/unlock when inserting messages, hanging up, changing the connected state etc.
- sync.Mutex
+
+ // atomically updated ptr to a read-only copy
+ // of supported stream types in a hashmap. this
+ // gets updated via CAS operations in .cas().
+ types atomic.Pointer[map[string]struct{}]
+
+ // protects stream close.
+ done chan struct{}
+
+ // inbound msg ch.
+ msgCh chan Message
+
+ // close hook to remove
+ // stream from Streams{}.
+ close func()
+}
+
+// Subscribe will add given type to given types this stream supports.
+func (s *Stream) Subscribe(streamType string) {
+ s.cas(func(m map[string]struct{}) bool {
+ if _, ok := m[streamType]; ok {
+ return false
+ }
+ m[streamType] = struct{}{}
+ return true
+ })
+}
+
+// Unsubscribe will remove given type (if found) from types this stream supports.
+func (s *Stream) Unsubscribe(streamType string) {
+ s.cas(func(m map[string]struct{}) bool {
+ if _, ok := m[streamType]; !ok {
+ return false
+ }
+ delete(m, streamType)
+ return true
+ })
}
-// Message represents one streamed message.
+// getStreamType returns the first stream type in given list that stream supports.
+func (s *Stream) getStreamType(streamTypes ...string) string {
+ if ptr := s.types.Load(); ptr != nil {
+ for _, streamType := range streamTypes {
+ if _, ok := (*ptr)[streamType]; ok {
+ return streamType
+ }
+ }
+ }
+ return ""
+}
+
+// send will block on posting a new Message{}, returning early with
+// a false value if provided context is canceled, or stream closed.
+func (s *Stream) send(ctx context.Context, msg Message) bool {
+ select {
+ case <-s.done:
+ return false
+ case <-ctx.Done():
+ return false
+ case s.msgCh <- msg:
+ return true
+ }
+}
+
+// Recv will block on receiving Message{}, returning early with a
+// false value if provided context is canceled, or stream closed.
+func (s *Stream) Recv(ctx context.Context) (Message, bool) {
+ select {
+ case <-s.done:
+ return Message{}, false
+ case <-ctx.Done():
+ return Message{}, false
+ case msg := <-s.msgCh:
+ return msg, true
+ }
+}
+
+// Close will close the stream's done channel, finally
+// removing it from the parent Streams' per-account map.
+func (s *Stream) Close() {
+ select {
+ case <-s.done:
+ default:
+ close(s.done)
+ s.close()
+ }
+}
+
+// cas will perform a Compare And Swap operation on s.types using modifier func.
+func (s *Stream) cas(fn func(map[string]struct{}) bool) {
+ if fn == nil {
+ panic("nil function")
+ }
+ for {
+ var m map[string]struct{}
+
+ // Get current value.
+ ptr := s.types.Load()
+
+ if ptr == nil {
+ // Allocate new types map.
+ m = make(map[string]struct{})
+ } else {
+ // Clone r-only map.
+ m = maps.Clone(*ptr)
+ }
+
+ // Apply
+ // changes.
+ if !fn(m) {
+ return
+ }
+
+ // Attempt to Compare And Swap ptr.
+ if s.types.CompareAndSwap(ptr, &m) {
+ return
+ }
+ }
+}
+
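
The cas helper above, together with the atomic types pointer, is a copy-on-write map: readers Load an immutable snapshot and never take a lock, while writers clone the snapshot, modify the clone, and CompareAndSwap the pointer in, retrying if another writer won the race. The same idea in isolation, as a sketch with illustrative names (cowSet), not code from the repository:

    import (
        "maps"
        "sync/atomic"
    )

    // cowSet is a tiny copy-on-write string set.
    type cowSet struct {
        ptr atomic.Pointer[map[string]struct{}]
    }

    // has is lock-free: it only reads the current snapshot.
    func (c *cowSet) has(key string) bool {
        if p := c.ptr.Load(); p != nil {
            _, ok := (*p)[key]
            return ok
        }
        return false
    }

    // add clones the snapshot, mutates the clone, then publishes it.
    func (c *cowSet) add(key string) {
        for {
            p := c.ptr.Load()

            var m map[string]struct{}
            if p != nil {
                m = maps.Clone(*p)
            } else {
                m = make(map[string]struct{})
            }
            m[key] = struct{}{}

            // Retry if another writer swapped in a newer snapshot
            // between our Load and this CompareAndSwap.
            if c.ptr.CompareAndSwap(p, &m) {
                return
            }
        }
    }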
+// Message represents
+// one streamed message.
type Message struct {
- // All the stream types this message should be delivered to.
+
+ // All the stream types this
+ // message should be delivered to.
Stream []string `json:"stream"`
- // The event type of the message (update/delete/notification etc)
+
+ // The event type of the message
+ // (update/delete/notification etc)
Event string `json:"event"`
- // The actual payload of the message. In case of an update or notification, this will be a JSON string.
+
+ // The actual payload of the message. In case of an
+ // update or notification, this will be a JSON string.
Payload string `json:"payload"`
}
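
The DEFERRED comments in Post and PostAll above describe the change that gives this commit its title: channel sends are collected as closures while the Streams mutex is held and only executed after it is released, so a full (blocking) message channel is never waited on while holding the lock that Open, Close and other posters also need. That shape in isolation, as a sketch with illustrative names (broker, publish), not code from the repository:

    import (
        "context"
        "sync"
    )

    // broker shows the collect-under-lock, send-outside-lock shape.
    type broker struct {
        mu   sync.Mutex
        subs []chan string
    }

    func (b *broker) publish(ctx context.Context, msg string) {
        var sends []func()

        b.mu.Lock()
        for _, ch := range b.subs {
            ch := ch // capture this subscriber's channel

            // Defer the potentially blocking channel send until
            // after the mutex is released, so one slow subscriber
            // cannot wedge everyone else who needs the lock.
            sends = append(sends, func() {
                select {
                case ch <- msg:
                case <-ctx.Done():
                }
            })
        }
        b.mu.Unlock()

        // Run the sends with no lock held.
        for _, send := range sends {
            send()
        }
    }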