diff options
| author | 2024-02-20 18:07:49 +0000 | |
|---|---|---|
| committer | 2024-02-20 18:07:49 +0000 | |
| commit | 291e18099050ff9e19b8ee25c2ffad68d9baafef (patch) | |
| tree | 0ad1be36b4c958830d1371f3b9a32f017c5dcff0 | |
| parent | [feature] Add `requested_by` to relationship model (#2672) (diff) | |
| download | gotosocial-291e18099050ff9e19b8ee25c2ffad68d9baafef.tar.xz | |
[bugfix] fix possible mutex lockup during streaming code (#2633)
* rewrite Stream{} to use much less mutex locking, update related code
* use new context for the stream context
* ensure stream gets closed on return of writeTo / readFrom WSConn()
* ensure stream write timeout gets cancelled
* remove embedded context type from Stream{}, reformat log messages for consistency
* use c.Request.Context() for context passed into Stream().Open()
* only return 1 boolean, fix tests to expect multiple stream types in messages
* changes to ping logic
* further improved ping logic
* don't export unused function types, update message sending to only include relevant stream type
* ensure stream gets closed :facepalm:
* update to error log on failed json marshal (instead of panic)
* inverse websocket read error checking to _ignore_ expected close errors
| -rw-r--r-- | internal/api/client/streaming/stream.go | 235 | ||||
| -rw-r--r-- | internal/processing/stream/delete.go | 34 | ||||
| -rw-r--r-- | internal/processing/stream/notification.go | 21 | ||||
| -rw-r--r-- | internal/processing/stream/notification_test.go | 7 | ||||
| -rw-r--r-- | internal/processing/stream/open.go | 97 | ||||
| -rw-r--r-- | internal/processing/stream/statusupdate.go | 21 | ||||
| -rw-r--r-- | internal/processing/stream/statusupdate_test.go | 7 | ||||
| -rw-r--r-- | internal/processing/stream/stream.go | 46 | ||||
| -rw-r--r-- | internal/processing/stream/update.go | 18 | ||||
| -rw-r--r-- | internal/processing/workers/fromclientapi_test.go | 25 | ||||
| -rw-r--r-- | internal/processing/workers/fromfediapi_test.go | 53 | ||||
| -rw-r--r-- | internal/processing/workers/surfacenotify.go | 5 | ||||
| -rw-r--r-- | internal/processing/workers/surfacetimeline.go | 18 | ||||
| -rw-r--r-- | internal/stream/stream.go | 385 | 
14 files changed, 528 insertions, 444 deletions
diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index 266b64976..8df4e9e76 100644 --- a/internal/api/client/streaming/stream.go +++ b/internal/api/client/streaming/stream.go @@ -22,10 +22,10 @@ import (  	"slices"  	"time" -	"codeberg.org/gruf/go-kv"  	apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/id"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	streampkg "github.com/superseriousbusiness/gotosocial/internal/stream" @@ -202,7 +202,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) {  	// functions pass messages into a channel, which we can  	// then read from and put into a websockets connection.  	stream, errWithCode := m.processor.Stream().Open( -		c.Request.Context(), +		c.Request.Context(), // this ctx is only used for logging  		account,  		streamType,  	) @@ -213,10 +213,8 @@ func (m *Module) StreamGETHandler(c *gin.Context) {  	l := log.  		WithContext(c.Request.Context()). -		WithFields(kv.Fields{ -			{"username", account.Username}, -			{"streamID", stream.ID}, -		}...) +		WithField("streamID", id.NewULID()). +		WithField("username", account.Username)  	// Upgrade the incoming HTTP request. This hijacks the  	// underlying connection and reuses it for the websocket @@ -227,18 +225,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) {  	wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)  	if err != nil {  		l.Errorf("error upgrading websocket connection: %v", err) -		close(stream.Hangup) +		stream.Close()  		return  	} -	l.Info("opened websocket connection") -  	// We perform the main websocket rw loops in a separate  	// goroutine in order to let the upgrade handler return.  	// This prevents the upgrade handler from holding open any  	// throttle / rate-limit request tokens which could become  	// problematic on instances with multiple users. -	go m.handleWSConn(account.Username, wsConn, stream) +	go m.handleWSConn(&l, wsConn, stream)  }  // handleWSConn handles a two-way websocket streaming connection. @@ -246,48 +242,39 @@ func (m *Module) StreamGETHandler(c *gin.Context) {  // into the connection. If any errors are encountered while reading  // or writing (including expected errors like clients leaving), the  // connection will be closed. -func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) { -	// Create new context for the lifetime of this connection. -	ctx, cancel := context.WithCancel(context.Background()) - -	l := log. -		WithContext(ctx). -		WithFields(kv.Fields{ -			{"username", username}, -			{"streamID", stream.ID}, -		}...) +func (m *Module) handleWSConn(l *log.Entry, wsConn *websocket.Conn, stream *streampkg.Stream) { +	l.Info("opened websocket connection") -	// Create ticker to send keepalive pings -	pinger := time.NewTicker(m.dTicker) +	// Create new async context with cancel. +	ctx, cncl := context.WithCancel(context.Background()) -	// Read messages coming from the Websocket client connection into the server.  	go func() { -		defer cancel() -		m.readFromWSConn(ctx, username, wsConn, stream) +		defer cncl() + +		// Read messages from websocket to server. +		m.readFromWSConn(ctx, wsConn, stream, l)  	}() -	// Write messages coming from the processor into the Websocket client connection.  	go func() { -		defer cancel() -		m.writeToWSConn(ctx, username, wsConn, stream, pinger) +		defer cncl() + +		// Write messages from processor in websocket conn. +		m.writeToWSConn(ctx, wsConn, stream, m.dTicker, l)  	}() -	// Wait for either the read or write functions to close, to indicate -	// that the client has left, or something else has gone wrong. +	// Wait for ctx +	// to be closed.  	<-ctx.Done() +	// Close stream +	// straightaway. +	stream.Close() +  	// Tidy up underlying websocket connection.  	if err := wsConn.Close(); err != nil {  		l.Errorf("error closing websocket connection: %v", err)  	} -	// Close processor channel so the processor knows -	// not to send any more messages to this stream. -	close(stream.Hangup) - -	// Stop ping ticker (tiny resource saving). -	pinger.Stop() -  	l.Info("closed websocket connection")  } @@ -299,89 +286,64 @@ func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *s  // if the given context is canceled.  func (m *Module) readFromWSConn(  	ctx context.Context, -	username string,  	wsConn *websocket.Conn,  	stream *streampkg.Stream, +	l *log.Entry,  ) { -	l := log. -		WithContext(ctx). -		WithFields(kv.Fields{ -			{"username", username}, -			{"streamID", stream.ID}, -		}...) -readLoop:  	for { -		select { -		case <-ctx.Done(): -			// Connection closed. -			break readLoop +		var msg struct { +			Type   string `json:"type"` +			Stream string `json:"stream"` +			List   string `json:"list,omitempty"` +		} -		default: -			// Read JSON objects from the client and act on them. -			var msg map[string]string -			if err := wsConn.ReadJSON(&msg); err != nil { -				// Only log an error if something weird happened. -				// See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7 -				if websocket.IsUnexpectedCloseError(err, []int{ -					websocket.CloseNormalClosure, -					websocket.CloseGoingAway, -					websocket.CloseNoStatusReceived, -				}...) { -					l.Errorf("error reading from websocket: %v", err) -				} - -				// The connection is gone; no -				// further streaming possible. -				break readLoop +		// Read JSON objects from the client and act on them. +		if err := wsConn.ReadJSON(&msg); err != nil { +			// Only log an error if something weird happened. +			// See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7 +			if !websocket.IsCloseError(err, []int{ +				websocket.CloseNormalClosure, +				websocket.CloseGoingAway, +				websocket.CloseNoStatusReceived, +			}...) { +				l.Errorf("error during websocket read: %v", err)  			} -			// Messages *from* the WS connection are infrequent -			// and usually interesting, so log this at info. -			l.Infof("received message from websocket: %v", msg) - -			// If the message contains 'stream' and 'type' fields, we can -			// update the set of timelines that are subscribed for events. -			updateType, ok := msg["type"] -			if !ok { -				l.Warn("'type' field not provided") -				continue -			} +			// The connection is gone; no +			// further streaming possible. +			break +		} -			updateStream, ok := msg["stream"] -			if !ok { -				l.Warn("'stream' field not provided") -				continue -			} +		// Messages *from* the WS connection are infrequent +		// and usually interesting, so log this at info. +		l.Infof("received websocket message: %+v", msg) -			// Ignore if the updateStreamType is unknown (or missing), -			// so a bad client can't cause extra memory allocations -			if !slices.Contains(streampkg.AllStatusTimelines, updateStream) { -				l.Warnf("unknown 'stream' field: %v", msg) -				continue -			} +		// Ignore if the updateStreamType is unknown (or missing), +		// so a bad client can't cause extra memory allocations +		if !slices.Contains(streampkg.AllStatusTimelines, msg.Stream) { +			l.Warnf("unknown 'stream' field: %v", msg) +			continue +		} -			updateList, ok := msg["list"] -			if ok { -				updateStream += ":" + updateList -			} +		if msg.List != "" { +			// If a list is given, add this to +			// the stream name as this is how we +			// we track stream types internally. +			msg.Stream += ":" + msg.List +		} -			switch updateType { -			case "subscribe": -				stream.Lock() -				stream.StreamTypes[updateStream] = true -				stream.Unlock() -			case "unsubscribe": -				stream.Lock() -				delete(stream.StreamTypes, updateStream) -				stream.Unlock() -			default: -				l.Warnf("invalid 'type' field: %v", msg) -			} +		switch msg.Type { +		case "subscribe": +			stream.Subscribe(msg.Stream) +		case "unsubscribe": +			stream.Unsubscribe(msg.Stream) +		default: +			l.Warnf("invalid 'type' field: %v", msg)  		}  	} -	l.Debug("finished reading from websocket connection") +	l.Debug("finished websocket read")  }  // writeToWSConn receives messages coming from the processor via the @@ -393,46 +355,47 @@ readLoop:  // if the given context is canceled.  func (m *Module) writeToWSConn(  	ctx context.Context, -	username string,  	wsConn *websocket.Conn,  	stream *streampkg.Stream, -	pinger *time.Ticker, +	ping time.Duration, +	l *log.Entry,  ) { -	l := log. -		WithContext(ctx). -		WithFields(kv.Fields{ -			{"username", username}, -			{"streamID", stream.ID}, -		}...) - -writeLoop:  	for { -		select { -		case <-ctx.Done(): -			// Connection closed. -			break writeLoop - -		case msg := <-stream.Messages: -			// Received a new message from the processor. -			l.Tracef("writing message to websocket: %+v", msg) -			if err := wsConn.WriteJSON(msg); err != nil { -				l.Debugf("error writing json to websocket: %v", err) -				break writeLoop -			} +		// Wrap context with timeout to send a ping. +		pingctx, cncl := context.WithTimeout(ctx, ping) + +		// Block on receipt of msg. +		msg, ok := stream.Recv(pingctx) -			// Reset pinger on successful send, since -			// we know the connection is still there. -			pinger.Reset(m.dTicker) +		// Check if cancel because ping. +		pinged := (pingctx.Err() != nil) +		cncl() -		case <-pinger.C: -			// Time to send a keep-alive "ping". -			l.Trace("writing ping control message to websocket") +		switch { +		case !ok && pinged: +			// The ping context timed out! +			l.Trace("writing websocket ping") + +			// Wrapped context time-out, send a keep-alive "ping".  			if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil { -				l.Debugf("error writing ping to websocket: %v", err) -				break writeLoop +				l.Debugf("error writing websocket ping: %v", err) +				break  			} + +		case !ok: +			// Stream was +			// closed. +			return +		} + +		l.Trace("writing websocket message: %+v", msg) + +		// Received a new message from the processor. +		if err := wsConn.WriteJSON(msg); err != nil { +			l.Debugf("error writing websocket message: %v", err) +			break  		}  	} -	l.Debug("finished writing to websocket connection") +	l.Debug("finished websocket write")  } diff --git a/internal/processing/stream/delete.go b/internal/processing/stream/delete.go index d7745eef8..1c61b98d3 100644 --- a/internal/processing/stream/delete.go +++ b/internal/processing/stream/delete.go @@ -18,38 +18,16 @@  package stream  import ( -	"fmt" -	"strings" +	"context"  	"github.com/superseriousbusiness/gotosocial/internal/stream"  )  // Delete streams the delete of the given statusID to *ALL* open streams. -func (p *Processor) Delete(statusID string) error { -	errs := []string{} - -	// get all account IDs with open streams -	accountIDs := []string{} -	p.streamMap.Range(func(k interface{}, _ interface{}) bool { -		key, ok := k.(string) -		if !ok { -			panic("streamMap key was not a string (account id)") -		} - -		accountIDs = append(accountIDs, key) -		return true +func (p *Processor) Delete(ctx context.Context, statusID string) { +	p.streams.PostAll(ctx, stream.Message{ +		Payload: statusID, +		Event:   stream.EventTypeDelete, +		Stream:  stream.AllStatusTimelines,  	}) - -	// stream the delete to every account -	for _, accountID := range accountIDs { -		if err := p.toAccount(statusID, stream.EventTypeDelete, stream.AllStatusTimelines, accountID); err != nil { -			errs = append(errs, err.Error()) -		} -	} - -	if len(errs) != 0 { -		return fmt.Errorf("one or more errors streaming status delete: %s", strings.Join(errs, ";")) -	} - -	return nil  } diff --git a/internal/processing/stream/notification.go b/internal/processing/stream/notification.go index 63d7c5d11..a16da11e6 100644 --- a/internal/processing/stream/notification.go +++ b/internal/processing/stream/notification.go @@ -18,20 +18,29 @@  package stream  import ( +	"context"  	"encoding/json" -	"fmt" +	"codeberg.org/gruf/go-byteutil"  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/stream"  )  // Notify streams the given notification to any open, appropriate streams belonging to the given account. -func (p *Processor) Notify(n *apimodel.Notification, account *gtsmodel.Account) error { -	bytes, err := json.Marshal(n) +func (p *Processor) Notify(ctx context.Context, account *gtsmodel.Account, notif *apimodel.Notification) { +	b, err := json.Marshal(notif)  	if err != nil { -		return fmt.Errorf("error marshalling notification to json: %s", err) +		log.Errorf(ctx, "error marshaling json: %v", err) +		return  	} - -	return p.toAccount(string(bytes), stream.EventTypeNotification, []string{stream.TimelineNotifications, stream.TimelineHome}, account.ID) +	p.streams.Post(ctx, account.ID, stream.Message{ +		Payload: byteutil.B2S(b), +		Event:   stream.EventTypeNotification, +		Stream: []string{ +			stream.TimelineNotifications, +			stream.TimelineHome, +		}, +	})  } diff --git a/internal/processing/stream/notification_test.go b/internal/processing/stream/notification_test.go index 2138f0025..e12f23abe 100644 --- a/internal/processing/stream/notification_test.go +++ b/internal/processing/stream/notification_test.go @@ -49,10 +49,11 @@ func (suite *NotificationTestSuite) TestStreamNotification() {  		Account:   followAccountAPIModel,  	} -	err = suite.streamProcessor.Notify(notification, account) -	suite.NoError(err) +	suite.streamProcessor.Notify(context.Background(), account, notification) + +	msg, ok := openStream.Recv(context.Background()) +	suite.True(ok) -	msg := <-openStream.Messages  	dst := new(bytes.Buffer)  	err = json.Indent(dst, []byte(msg.Payload), "", "  ")  	suite.NoError(err) diff --git a/internal/processing/stream/open.go b/internal/processing/stream/open.go index 1c041309f..2f2bbd4a3 100644 --- a/internal/processing/stream/open.go +++ b/internal/processing/stream/open.go @@ -19,13 +19,10 @@ package stream  import (  	"context" -	"errors" -	"fmt"  	"codeberg.org/gruf/go-kv"  	"github.com/superseriousbusiness/gotosocial/internal/gtserror"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" -	"github.com/superseriousbusiness/gotosocial/internal/id"  	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/stream"  ) @@ -37,97 +34,5 @@ func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamT  		{"streamType", streamType},  	}...)  	l.Debug("received open stream request") - -	var ( -		streamID string -		err      error -	) - -	// Each stream needs a unique ID so we know to close it. -	streamID, err = id.NewRandomULID() -	if err != nil { -		return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %w", err)) -	} - -	// Each stream can be subscibed to multiple types. -	// Record them in a set, and include the initial one -	// if it was given to us. -	streamTypes := map[string]any{} -	if streamType != "" { -		streamTypes[streamType] = true -	} - -	newStream := &stream.Stream{ -		ID:          streamID, -		StreamTypes: streamTypes, -		Messages:    make(chan *stream.Message, 100), -		Hangup:      make(chan interface{}, 1), -		Connected:   true, -	} -	go p.waitToCloseStream(account, newStream) - -	v, ok := p.streamMap.Load(account.ID) -	if ok { -		// There is an entry in the streamMap -		// for this account. Parse it out. -		streamsForAccount, ok := v.(*stream.StreamsForAccount) -		if !ok { -			return nil, gtserror.NewErrorInternalError(errors.New("stream map error")) -		} - -		// Append new stream to existing entry. -		streamsForAccount.Lock() -		streamsForAccount.Streams = append(streamsForAccount.Streams, newStream) -		streamsForAccount.Unlock() -	} else { -		// There is no entry in the streamMap for -		// this account yet. Create one and store it. -		p.streamMap.Store(account.ID, &stream.StreamsForAccount{ -			Streams: []*stream.Stream{ -				newStream, -			}, -		}) -	} - -	return newStream, nil -} - -// waitToCloseStream waits until the hangup channel is closed for the given stream. -// It then iterates through the map of streams stored by the processor, removes the stream from it, -// and then closes the messages channel of the stream to indicate that the channel should no longer be read from. -func (p *Processor) waitToCloseStream(account *gtsmodel.Account, thisStream *stream.Stream) { -	<-thisStream.Hangup // wait for a hangup message - -	// lock the stream to prevent more messages being put in it while we work -	thisStream.Lock() -	defer thisStream.Unlock() - -	// indicate the stream is no longer connected -	thisStream.Connected = false - -	// load and parse the entry for this account from the stream map -	v, ok := p.streamMap.Load(account.ID) -	if !ok || v == nil { -		return -	} -	streamsForAccount, ok := v.(*stream.StreamsForAccount) -	if !ok { -		return -	} - -	// lock the streams for account while we remove this stream from its slice -	streamsForAccount.Lock() -	defer streamsForAccount.Unlock() - -	// put everything into modified streams *except* the stream we're removing -	modifiedStreams := []*stream.Stream{} -	for _, s := range streamsForAccount.Streams { -		if s.ID != thisStream.ID { -			modifiedStreams = append(modifiedStreams, s) -		} -	} -	streamsForAccount.Streams = modifiedStreams - -	// finally close the messages channel so no more messages can be read from it -	close(thisStream.Messages) +	return p.streams.Open(account.ID, streamType), nil  } diff --git a/internal/processing/stream/statusupdate.go b/internal/processing/stream/statusupdate.go index fd8e388ce..bd4658873 100644 --- a/internal/processing/stream/statusupdate.go +++ b/internal/processing/stream/statusupdate.go @@ -18,21 +18,26 @@  package stream  import ( +	"context"  	"encoding/json" -	"fmt" +	"codeberg.org/gruf/go-byteutil"  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/stream"  ) -// StatusUpdate streams the given edited status to any open, appropriate -// streams belonging to the given account. -func (p *Processor) StatusUpdate(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error { -	bytes, err := json.Marshal(s) +// StatusUpdate streams the given edited status to any open, appropriate streams belonging to the given account. +func (p *Processor) StatusUpdate(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) { +	b, err := json.Marshal(status)  	if err != nil { -		return fmt.Errorf("error marshalling status to json: %s", err) +		log.Errorf(ctx, "error marshaling json: %v", err) +		return  	} - -	return p.toAccount(string(bytes), stream.EventTypeStatusUpdate, streamTypes, account.ID) +	p.streams.Post(ctx, account.ID, stream.Message{ +		Payload: byteutil.B2S(b), +		Event:   stream.EventTypeStatusUpdate, +		Stream:  []string{streamType}, +	})  } diff --git a/internal/processing/stream/statusupdate_test.go b/internal/processing/stream/statusupdate_test.go index 7b987b412..8814c966f 100644 --- a/internal/processing/stream/statusupdate_test.go +++ b/internal/processing/stream/statusupdate_test.go @@ -42,10 +42,11 @@ func (suite *StatusUpdateTestSuite) TestStreamNotification() {  	apiStatus, err := typeutils.NewConverter(&suite.state).StatusToAPIStatus(context.Background(), editedStatus, account)  	suite.NoError(err) -	err = suite.streamProcessor.StatusUpdate(apiStatus, account, []string{stream.TimelineHome}) -	suite.NoError(err) +	suite.streamProcessor.StatusUpdate(context.Background(), account, apiStatus, stream.TimelineHome) + +	msg, ok := openStream.Recv(context.Background()) +	suite.True(ok) -	msg := <-openStream.Messages  	dst := new(bytes.Buffer)  	err = json.Indent(dst, []byte(msg.Payload), "", "  ")  	suite.NoError(err) diff --git a/internal/processing/stream/stream.go b/internal/processing/stream/stream.go index a5b3b9386..0b7285b58 100644 --- a/internal/processing/stream/stream.go +++ b/internal/processing/stream/stream.go @@ -18,8 +18,6 @@  package stream  import ( -	"sync" -  	"github.com/superseriousbusiness/gotosocial/internal/oauth"  	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/stream" @@ -28,53 +26,13 @@ import (  type Processor struct {  	state       *state.State  	oauthServer oauth.Server -	streamMap   *sync.Map +	streams     stream.Streams  }  func New(state *state.State, oauthServer oauth.Server) Processor {  	return Processor{  		state:       state,  		oauthServer: oauthServer, -		streamMap:   &sync.Map{}, -	} -} - -// toAccount streams the given payload with the given event type to any streams currently open for the given account ID. -func (p *Processor) toAccount(payload string, event string, streamTypes []string, accountID string) error { -	// Load all streams open for this account. -	v, ok := p.streamMap.Load(accountID) -	if !ok { -		return nil // No entry = nothing to stream. +		streams:     stream.Streams{},  	} -	streamsForAccount := v.(*stream.StreamsForAccount) - -	streamsForAccount.Lock() -	defer streamsForAccount.Unlock() - -	for _, s := range streamsForAccount.Streams { -		s.Lock() -		defer s.Unlock() - -		if !s.Connected { -			continue -		} - -	typeLoop: -		for _, streamType := range streamTypes { -			if _, found := s.StreamTypes[streamType]; found { -				s.Messages <- &stream.Message{ -					Stream:  []string{streamType}, -					Event:   string(event), -					Payload: payload, -				} - -				// Break out to the outer loop, -				// to avoid sending duplicates of -				// the same event to the same stream. -				break typeLoop -			} -		} -	} - -	return nil  } diff --git a/internal/processing/stream/update.go b/internal/processing/stream/update.go index ee70bda11..a84763d51 100644 --- a/internal/processing/stream/update.go +++ b/internal/processing/stream/update.go @@ -18,20 +18,26 @@  package stream  import ( +	"context"  	"encoding/json" -	"fmt" +	"codeberg.org/gruf/go-byteutil"  	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/log"  	"github.com/superseriousbusiness/gotosocial/internal/stream"  )  // Update streams the given update to any open, appropriate streams belonging to the given account. -func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error { -	bytes, err := json.Marshal(s) +func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) { +	b, err := json.Marshal(status)  	if err != nil { -		return fmt.Errorf("error marshalling status to json: %s", err) +		log.Errorf(ctx, "error marshaling json: %v", err) +		return  	} - -	return p.toAccount(string(bytes), stream.EventTypeUpdate, streamTypes, account.ID) +	p.streams.Post(ctx, account.ID, stream.Message{ +		Payload: byteutil.B2S(b), +		Event:   stream.EventTypeUpdate, +		Stream:  []string{streamType}, +	})  } diff --git a/internal/processing/workers/fromclientapi_test.go b/internal/processing/workers/fromclientapi_test.go index 05526f437..3d3630b11 100644 --- a/internal/processing/workers/fromclientapi_test.go +++ b/internal/processing/workers/fromclientapi_test.go @@ -116,23 +116,20 @@ func (suite *FromClientAPITestSuite) checkStreamed(  	expectPayload string,  	expectEventType string,  ) { -	var msg *stream.Message -streamLoop: -	for { -		select { -		case msg = <-str.Messages: -			break streamLoop // Got it. -		case <-time.After(5 * time.Second): -			break streamLoop // Didn't get it. -		} -	} -	if expectMessage && msg == nil { -		suite.FailNow("expected a message but message was nil") +	// Set a 5s timeout on context. +	ctx := context.Background() +	ctx, cncl := context.WithTimeout(ctx, time.Second*5) +	defer cncl() + +	msg, ok := str.Recv(ctx) + +	if expectMessage && !ok { +		suite.FailNow("expected a message but message was not received")  	} -	if !expectMessage && msg != nil { -		suite.FailNow("expected no message but message was not nil") +	if !expectMessage && ok { +		suite.FailNow("expected no message but message was received")  	}  	if expectPayload != "" && msg.Payload != expectPayload { diff --git a/internal/processing/workers/fromfediapi_test.go b/internal/processing/workers/fromfediapi_test.go index 799eaf2dc..446355628 100644 --- a/internal/processing/workers/fromfediapi_test.go +++ b/internal/processing/workers/fromfediapi_test.go @@ -130,14 +130,9 @@ func (suite *FromFediAPITestSuite) TestProcessReplyMention() {  	suite.Equal(replyingStatus.ID, notif.StatusID)  	suite.False(*notif.Read) -	// the notification should be streamed -	var msg *stream.Message -	select { -	case msg = <-wssStream.Messages: -		// fine -	case <-time.After(5 * time.Second): -		suite.FailNow("no message from wssStream") -	} +	ctx, _ := context.WithTimeout(context.Background(), time.Second*5) +	msg, ok := wssStream.Recv(ctx) +	suite.True(ok)  	suite.Equal(stream.EventTypeNotification, msg.Event)  	suite.NotEmpty(msg.Payload) @@ -203,14 +198,10 @@ func (suite *FromFediAPITestSuite) TestProcessFave() {  	suite.Equal(fave.StatusID, notif.StatusID)  	suite.False(*notif.Read) -	// 2. a notification should be streamed -	var msg *stream.Message -	select { -	case msg = <-wssStream.Messages: -		// fine -	case <-time.After(5 * time.Second): -		suite.FailNow("no message from wssStream") -	} +	ctx, _ := context.WithTimeout(context.Background(), time.Second*5) +	msg, ok := wssStream.Recv(ctx) +	suite.True(ok) +  	suite.Equal(stream.EventTypeNotification, msg.Event)  	suite.NotEmpty(msg.Payload)  	suite.EqualValues([]string{stream.TimelineNotifications}, msg.Stream) @@ -277,7 +268,9 @@ func (suite *FromFediAPITestSuite) TestProcessFaveWithDifferentReceivingAccount(  	suite.False(*notif.Read)  	// 2. no notification should be streamed to the account that received the fave message, because they weren't the target -	suite.Empty(wssStream.Messages) +	ctx, _ := context.WithTimeout(context.Background(), time.Second*5) +	_, ok := wssStream.Recv(ctx) +	suite.False(ok)  }  func (suite *FromFediAPITestSuite) TestProcessAccountDelete() { @@ -405,14 +398,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() {  	})  	suite.NoError(err) -	// a notification should be streamed -	var msg *stream.Message -	select { -	case msg = <-wssStream.Messages: -		// fine -	case <-time.After(5 * time.Second): -		suite.FailNow("no message from wssStream") -	} +	ctx, _ = context.WithTimeout(ctx, time.Second*5) +	msg, ok := wssStream.Recv(context.Background()) +	suite.True(ok) +  	suite.Equal(stream.EventTypeNotification, msg.Event)  	suite.NotEmpty(msg.Payload)  	suite.EqualValues([]string{stream.TimelineHome}, msg.Stream) @@ -423,7 +412,7 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() {  	suite.Equal(originAccount.ID, notif.Account.ID)  	// no messages should have been sent out, since we didn't need to federate an accept -	suite.Empty(suite.httpClient.SentMessages) +	suite.Empty(&suite.httpClient.SentMessages)  }  func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() { @@ -503,14 +492,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() {  	suite.Equal(originAccount.URI, accept.To)  	suite.Equal("Accept", accept.Type) -	// a notification should be streamed -	var msg *stream.Message -	select { -	case msg = <-wssStream.Messages: -		// fine -	case <-time.After(5 * time.Second): -		suite.FailNow("no message from wssStream") -	} +	ctx, _ = context.WithTimeout(ctx, time.Second*5) +	msg, ok := wssStream.Recv(context.Background()) +	suite.True(ok) +  	suite.Equal(stream.EventTypeNotification, msg.Event)  	suite.NotEmpty(msg.Payload)  	suite.EqualValues([]string{stream.TimelineHome}, msg.Stream) diff --git a/internal/processing/workers/surfacenotify.go b/internal/processing/workers/surfacenotify.go index 39798f45e..a8c36248c 100644 --- a/internal/processing/workers/surfacenotify.go +++ b/internal/processing/workers/surfacenotify.go @@ -394,10 +394,7 @@ func (s *surface) notify(  	if err != nil {  		return gtserror.Newf("error converting notification to api representation: %w", err)  	} - -	if err := s.stream.Notify(apiNotif, targetAccount); err != nil { -		return gtserror.Newf("error streaming notification to account: %w", err) -	} +	s.stream.Notify(ctx, targetAccount, apiNotif)  	return nil  } diff --git a/internal/processing/workers/surfacetimeline.go b/internal/processing/workers/surfacetimeline.go index e63b8a7c0..14634f846 100644 --- a/internal/processing/workers/surfacetimeline.go +++ b/internal/processing/workers/surfacetimeline.go @@ -348,11 +348,7 @@ func (s *surface) timelineStatus(  		err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err)  		return true, err  	} - -	if err := s.stream.Update(apiStatus, account, []string{streamType}); err != nil { -		err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err) -		return true, err -	} +	s.stream.Update(ctx, account, apiStatus, streamType)  	return true, nil  } @@ -363,12 +359,11 @@ func (s *surface) deleteStatusFromTimelines(ctx context.Context, statusID string  	if err := s.state.Timelines.Home.WipeItemFromAllTimelines(ctx, statusID); err != nil {  		return err  	} -  	if err := s.state.Timelines.List.WipeItemFromAllTimelines(ctx, statusID); err != nil {  		return err  	} - -	return s.stream.Delete(statusID) +	s.stream.Delete(ctx, statusID) +	return nil  }  // invalidateStatusFromTimelines does cache invalidation on the given status by @@ -555,11 +550,6 @@ func (s *surface) timelineStreamStatusUpdate(  		err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err)  		return err  	} - -	if err := s.stream.StatusUpdate(apiStatus, account, []string{streamType}); err != nil { -		err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err) -		return err -	} - +	s.stream.StatusUpdate(ctx, account, apiStatus, streamType)  	return nil  } diff --git a/internal/stream/stream.go b/internal/stream/stream.go index da5647433..ec22464f5 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -17,36 +17,65 @@  package stream -import "sync" +import ( +	"context" +	"maps" +	"slices" +	"sync" +	"sync/atomic" +)  const ( -	// EventTypeNotification -- a user should be shown a notification -	EventTypeNotification string = "notification" -	// EventTypeUpdate -- a user should be shown an update in their timeline -	EventTypeUpdate string = "update" -	// EventTypeDelete -- something should be deleted from a user -	EventTypeDelete string = "delete" -	// EventTypeStatusUpdate -- something in the user's timeline has been edited -	// (yes this is a confusing name, blame Mastodon) -	EventTypeStatusUpdate string = "status.update" +	// EventTypeNotification -- a user +	// should be shown a notification. +	EventTypeNotification = "notification" + +	// EventTypeUpdate -- a user should +	// be shown an update in their timeline. +	EventTypeUpdate = "update" + +	// EventTypeDelete -- something +	// should be deleted from a user. +	EventTypeDelete = "delete" + +	// EventTypeStatusUpdate -- something in the +	// user's timeline has been edited (yes this +	// is a confusing name, blame Mastodon ...). +	EventTypeStatusUpdate = "status.update"  )  const ( -	// TimelineLocal -- public statuses from the LOCAL timeline. -	TimelineLocal string = "public:local" -	// TimelinePublic -- public statuses, including federated ones. -	TimelinePublic string = "public" -	// TimelineHome -- statuses for a user's Home timeline. -	TimelineHome string = "user" -	// TimelineNotifications -- notification events. -	TimelineNotifications string = "user:notification" -	// TimelineDirect -- statuses sent to a user directly. -	TimelineDirect string = "direct" -	// TimelineList -- statuses for a user's list timeline. -	TimelineList string = "list" +	// TimelineLocal: +	// All public posts originating from this +	// server. Analogous to the local timeline. +	TimelineLocal = "public:local" + +	// TimelinePublic: +	// All public posts known to the server. +	// Analogous to the federated timeline. +	TimelinePublic = "public" + +	// TimelineHome: +	// Events related to the current user, such +	// as home feed updates and notifications. +	TimelineHome = "user" + +	// TimelineNotifications: +	// Notifications for the current user. +	TimelineNotifications = "user:notification" + +	// TimelineDirect: +	// Updates to direct conversations. +	TimelineDirect = "direct" + +	// TimelineList: +	// Updates to a specific list. +	TimelineList = "list"  ) -// AllStatusTimelines contains all Timelines that a status could conceivably be delivered to -- useful for doing deletes. +// AllStatusTimelines contains all Timelines +// that a status could conceivably be delivered +// to, useful for sending out status deletes.  var AllStatusTimelines = []string{  	TimelineLocal,  	TimelinePublic, @@ -55,38 +84,298 @@ var AllStatusTimelines = []string{  	TimelineList,  } -// StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time. -// TODO: put a limit on this -type StreamsForAccount struct { -	// The currently held streams for this account -	Streams []*Stream -	// Mutex to lock/unlock when modifying the slice of streams. -	sync.Mutex +type Streams struct { +	streams map[string][]*Stream +	mutex   sync.Mutex +} + +// Open will open open a new Stream for given account ID and stream types, the given context will be passed to Stream. +func (s *Streams) Open(accountID string, streamTypes ...string) *Stream { +	if len(streamTypes) == 0 { +		panic("no stream types given") +	} + +	// Prep new Stream. +	str := new(Stream) +	str.done = make(chan struct{}) +	str.msgCh = make(chan Message, 50) // TODO: make configurable +	for _, streamType := range streamTypes { +		str.Subscribe(streamType) +	} + +	// TODO: add configurable +	// max streams per account. + +	// Acquire lock. +	s.mutex.Lock() + +	if s.streams == nil { +		// Main stream-map needs allocating. +		s.streams = make(map[string][]*Stream) +	} + +	// Add new stream for account. +	strs := s.streams[accountID] +	strs = append(strs, str) +	s.streams[accountID] = strs + +	// Register close callback +	// to remove stream from our +	// internal map for this account. +	str.close = func() { +		s.mutex.Lock() +		strs := s.streams[accountID] +		strs = slices.DeleteFunc(strs, func(s *Stream) bool { +			return s == str // remove 'str' ptr +		}) +		s.streams[accountID] = strs +		s.mutex.Unlock() +	} + +	// Done with lock. +	s.mutex.Unlock() + +	return str +} + +// Post will post the given message to all streams of given account ID matching type. +func (s *Streams) Post(ctx context.Context, accountID string, msg Message) bool { +	var deferred []func() bool + +	// Acquire lock. +	s.mutex.Lock() + +	// Iterate all streams stored for account. +	for _, str := range s.streams[accountID] { + +		// Check whether stream supports any of our message targets. +		if stype := str.getStreamType(msg.Stream...); stype != "" { + +			// Rescope var +			// to prevent +			// ptr reuse. +			stream := str + +			// Use a message copy to *only* +			// include the supported stream. +			msgCopy := Message{ +				Stream:  []string{stype}, +				Event:   msg.Event, +				Payload: msg.Payload, +			} + +			// Send message to supported stream +			// DEFERRED (i.e. OUTSIDE OF MAIN MUTEX). +			// This prevents deadlocks between each +			// msg channel and main Streams{} mutex. +			deferred = append(deferred, func() bool { +				return stream.send(ctx, msgCopy) +			}) +		} +	} + +	// Done with lock. +	s.mutex.Unlock() + +	var ok bool + +	// Execute deferred outside lock. +	for _, deferfn := range deferred { +		v := deferfn() +		ok = ok && v +	} + +	return ok +} + +// PostAll will post the given message to all streams with matching types. +func (s *Streams) PostAll(ctx context.Context, msg Message) bool { +	var deferred []func() bool + +	// Acquire lock. +	s.mutex.Lock() + +	// Iterate ALL stored streams. +	for _, strs := range s.streams { +		for _, str := range strs { + +			// Check whether stream supports any of our message targets. +			if stype := str.getStreamType(msg.Stream...); stype != "" { + +				// Rescope var +				// to prevent +				// ptr reuse. +				stream := str + +				// Use a message copy to *only* +				// include the supported stream. +				msgCopy := Message{ +					Stream:  []string{stype}, +					Event:   msg.Event, +					Payload: msg.Payload, +				} + +				// Send message to supported stream +				// DEFERRED (i.e. OUTSIDE OF MAIN MUTEX). +				// This prevents deadlocks between each +				// msg channel and main Streams{} mutex. +				deferred = append(deferred, func() bool { +					return stream.send(ctx, msgCopy) +				}) +			} +		} +	} + +	// Done with lock. +	s.mutex.Unlock() + +	var ok bool + +	// Execute deferred outside lock. +	for _, deferfn := range deferred { +		v := deferfn() +		ok = ok && v +	} + +	return ok  } -// Stream represents one open stream for a client. +// Stream represents one +// open stream for a client.  type Stream struct { -	// ID of this stream, generated during creation. -	ID string -	// A set of types subscribed to by this stream: user/public/etc. -	// It's a map to ensure no duplicates; the value is ignored. -	StreamTypes map[string]any -	// Channel of messages for the client to read from -	Messages chan *Message -	// Channel to close when the client drops away -	Hangup chan interface{} -	// Only put messages in the stream when Connected -	Connected bool -	// Mutex to lock/unlock when inserting messages, hanging up, changing the connected state etc. -	sync.Mutex + +	// atomically updated ptr to a read-only copy +	// of supported stream types in a hashmap. this +	// gets updated via CAS operations in .cas(). +	types atomic.Pointer[map[string]struct{}] + +	// protects stream close. +	done chan struct{} + +	// inbound msg ch. +	msgCh chan Message + +	// close hook to remove +	// stream from Streams{}. +	close func() +} + +// Subscribe will add given type to given types this stream supports. +func (s *Stream) Subscribe(streamType string) { +	s.cas(func(m map[string]struct{}) bool { +		if _, ok := m[streamType]; ok { +			return false +		} +		m[streamType] = struct{}{} +		return true +	}) +} + +// Unsubscribe will remove given type (if found) from types this stream supports. +func (s *Stream) Unsubscribe(streamType string) { +	s.cas(func(m map[string]struct{}) bool { +		if _, ok := m[streamType]; !ok { +			return false +		} +		delete(m, streamType) +		return true +	})  } -// Message represents one streamed message. +// getStreamType returns the first stream type in given list that stream supports. +func (s *Stream) getStreamType(streamTypes ...string) string { +	if ptr := s.types.Load(); ptr != nil { +		for _, streamType := range streamTypes { +			if _, ok := (*ptr)[streamType]; ok { +				return streamType +			} +		} +	} +	return "" +} + +// send will block on posting a new Message{}, returning early with +// a false value if provided context is canceled, or stream closed. +func (s *Stream) send(ctx context.Context, msg Message) bool { +	select { +	case <-s.done: +		return false +	case <-ctx.Done(): +		return false +	case s.msgCh <- msg: +		return true +	} +} + +// Recv will block on receiving Message{}, returning early with a +// false value if provided context is canceled, or stream closed. +func (s *Stream) Recv(ctx context.Context) (Message, bool) { +	select { +	case <-s.done: +		return Message{}, false +	case <-ctx.Done(): +		return Message{}, false +	case msg := <-s.msgCh: +		return msg, true +	} +} + +// Close will close the underlying context, finally +// removing it from the parent Streams per-account-map. +func (s *Stream) Close() { +	select { +	case <-s.done: +	default: +		close(s.done) +		s.close() +	} +} + +// cas will perform a Compare And Swap operation on s.types using modifier func. +func (s *Stream) cas(fn func(map[string]struct{}) bool) { +	if fn == nil { +		panic("nil function") +	} +	for { +		var m map[string]struct{} + +		// Get current value. +		ptr := s.types.Load() + +		if ptr == nil { +			// Allocate new types map. +			m = make(map[string]struct{}) +		} else { +			// Clone r-only map. +			m = maps.Clone(*ptr) +		} + +		// Apply +		// changes. +		if !fn(m) { +			return +		} + +		// Attempt to Compare And Swap ptr. +		if s.types.CompareAndSwap(ptr, &m) { +			return +		} +	} +} + +// Message represents +// one streamed message.  type Message struct { -	// All the stream types this message should be delivered to. + +	// All the stream types this +	// message should be delivered to.  	Stream []string `json:"stream"` -	// The event type of the message (update/delete/notification etc) + +	// The event type of the message +	// (update/delete/notification etc)  	Event string `json:"event"` -	// The actual payload of the message. In case of an update or notification, this will be a JSON string. + +	// The actual payload of the message. In case of an +	// update or notification, this will be a JSON string.  	Payload string `json:"payload"`  }  | 
