diff options
| author | 2024-02-20 18:07:49 +0000 | |
|---|---|---|
| committer | 2024-02-20 18:07:49 +0000 | |
| commit | 291e18099050ff9e19b8ee25c2ffad68d9baafef (patch) | |
| tree | 0ad1be36b4c958830d1371f3b9a32f017c5dcff0 /internal/stream | |
| parent | [feature] Add `requested_by` to relationship model (#2672) (diff) | |
| download | gotosocial-291e18099050ff9e19b8ee25c2ffad68d9baafef.tar.xz | |
[bugfix] fix possible mutex lockup during streaming code (#2633)
* rewrite Stream{} to use much less mutex locking, update related code
* use new context for the stream context
* ensure stream gets closed on return of writeTo / readFrom WSConn()
* ensure stream write timeout gets cancelled
* remove embedded context type from Stream{}, reformat log messages for consistency
* use c.Request.Context() for context passed into Stream().Open()
* only return 1 boolean, fix tests to expect multiple stream types in messages
* changes to ping logic
* further improved ping logic
* don't export unused function types, update message sending to only include relevant stream type
* ensure stream gets closed :facepalm:
* update to error log on failed json marshal (instead of panic)
* inverse websocket read error checking to _ignore_ expected close errors
Diffstat (limited to 'internal/stream')
| -rw-r--r-- | internal/stream/stream.go | 385 | 
1 files changed, 337 insertions, 48 deletions
| diff --git a/internal/stream/stream.go b/internal/stream/stream.go index da5647433..ec22464f5 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -17,36 +17,65 @@  package stream -import "sync" +import ( +	"context" +	"maps" +	"slices" +	"sync" +	"sync/atomic" +)  const ( -	// EventTypeNotification -- a user should be shown a notification -	EventTypeNotification string = "notification" -	// EventTypeUpdate -- a user should be shown an update in their timeline -	EventTypeUpdate string = "update" -	// EventTypeDelete -- something should be deleted from a user -	EventTypeDelete string = "delete" -	// EventTypeStatusUpdate -- something in the user's timeline has been edited -	// (yes this is a confusing name, blame Mastodon) -	EventTypeStatusUpdate string = "status.update" +	// EventTypeNotification -- a user +	// should be shown a notification. +	EventTypeNotification = "notification" + +	// EventTypeUpdate -- a user should +	// be shown an update in their timeline. +	EventTypeUpdate = "update" + +	// EventTypeDelete -- something +	// should be deleted from a user. +	EventTypeDelete = "delete" + +	// EventTypeStatusUpdate -- something in the +	// user's timeline has been edited (yes this +	// is a confusing name, blame Mastodon ...). +	EventTypeStatusUpdate = "status.update"  )  const ( -	// TimelineLocal -- public statuses from the LOCAL timeline. -	TimelineLocal string = "public:local" -	// TimelinePublic -- public statuses, including federated ones. -	TimelinePublic string = "public" -	// TimelineHome -- statuses for a user's Home timeline. -	TimelineHome string = "user" -	// TimelineNotifications -- notification events. -	TimelineNotifications string = "user:notification" -	// TimelineDirect -- statuses sent to a user directly. -	TimelineDirect string = "direct" -	// TimelineList -- statuses for a user's list timeline. -	TimelineList string = "list" +	// TimelineLocal: +	// All public posts originating from this +	// server. Analogous to the local timeline. +	TimelineLocal = "public:local" + +	// TimelinePublic: +	// All public posts known to the server. +	// Analogous to the federated timeline. +	TimelinePublic = "public" + +	// TimelineHome: +	// Events related to the current user, such +	// as home feed updates and notifications. +	TimelineHome = "user" + +	// TimelineNotifications: +	// Notifications for the current user. +	TimelineNotifications = "user:notification" + +	// TimelineDirect: +	// Updates to direct conversations. +	TimelineDirect = "direct" + +	// TimelineList: +	// Updates to a specific list. +	TimelineList = "list"  ) -// AllStatusTimelines contains all Timelines that a status could conceivably be delivered to -- useful for doing deletes. +// AllStatusTimelines contains all Timelines +// that a status could conceivably be delivered +// to, useful for sending out status deletes.  var AllStatusTimelines = []string{  	TimelineLocal,  	TimelinePublic, @@ -55,38 +84,298 @@ var AllStatusTimelines = []string{  	TimelineList,  } -// StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time. -// TODO: put a limit on this -type StreamsForAccount struct { -	// The currently held streams for this account -	Streams []*Stream -	// Mutex to lock/unlock when modifying the slice of streams. -	sync.Mutex +type Streams struct { +	streams map[string][]*Stream +	mutex   sync.Mutex +} + +// Open will open open a new Stream for given account ID and stream types, the given context will be passed to Stream. +func (s *Streams) Open(accountID string, streamTypes ...string) *Stream { +	if len(streamTypes) == 0 { +		panic("no stream types given") +	} + +	// Prep new Stream. +	str := new(Stream) +	str.done = make(chan struct{}) +	str.msgCh = make(chan Message, 50) // TODO: make configurable +	for _, streamType := range streamTypes { +		str.Subscribe(streamType) +	} + +	// TODO: add configurable +	// max streams per account. + +	// Acquire lock. +	s.mutex.Lock() + +	if s.streams == nil { +		// Main stream-map needs allocating. +		s.streams = make(map[string][]*Stream) +	} + +	// Add new stream for account. +	strs := s.streams[accountID] +	strs = append(strs, str) +	s.streams[accountID] = strs + +	// Register close callback +	// to remove stream from our +	// internal map for this account. +	str.close = func() { +		s.mutex.Lock() +		strs := s.streams[accountID] +		strs = slices.DeleteFunc(strs, func(s *Stream) bool { +			return s == str // remove 'str' ptr +		}) +		s.streams[accountID] = strs +		s.mutex.Unlock() +	} + +	// Done with lock. +	s.mutex.Unlock() + +	return str +} + +// Post will post the given message to all streams of given account ID matching type. +func (s *Streams) Post(ctx context.Context, accountID string, msg Message) bool { +	var deferred []func() bool + +	// Acquire lock. +	s.mutex.Lock() + +	// Iterate all streams stored for account. +	for _, str := range s.streams[accountID] { + +		// Check whether stream supports any of our message targets. +		if stype := str.getStreamType(msg.Stream...); stype != "" { + +			// Rescope var +			// to prevent +			// ptr reuse. +			stream := str + +			// Use a message copy to *only* +			// include the supported stream. +			msgCopy := Message{ +				Stream:  []string{stype}, +				Event:   msg.Event, +				Payload: msg.Payload, +			} + +			// Send message to supported stream +			// DEFERRED (i.e. OUTSIDE OF MAIN MUTEX). +			// This prevents deadlocks between each +			// msg channel and main Streams{} mutex. +			deferred = append(deferred, func() bool { +				return stream.send(ctx, msgCopy) +			}) +		} +	} + +	// Done with lock. +	s.mutex.Unlock() + +	var ok bool + +	// Execute deferred outside lock. +	for _, deferfn := range deferred { +		v := deferfn() +		ok = ok && v +	} + +	return ok +} + +// PostAll will post the given message to all streams with matching types. +func (s *Streams) PostAll(ctx context.Context, msg Message) bool { +	var deferred []func() bool + +	// Acquire lock. +	s.mutex.Lock() + +	// Iterate ALL stored streams. +	for _, strs := range s.streams { +		for _, str := range strs { + +			// Check whether stream supports any of our message targets. +			if stype := str.getStreamType(msg.Stream...); stype != "" { + +				// Rescope var +				// to prevent +				// ptr reuse. +				stream := str + +				// Use a message copy to *only* +				// include the supported stream. +				msgCopy := Message{ +					Stream:  []string{stype}, +					Event:   msg.Event, +					Payload: msg.Payload, +				} + +				// Send message to supported stream +				// DEFERRED (i.e. OUTSIDE OF MAIN MUTEX). +				// This prevents deadlocks between each +				// msg channel and main Streams{} mutex. +				deferred = append(deferred, func() bool { +					return stream.send(ctx, msgCopy) +				}) +			} +		} +	} + +	// Done with lock. +	s.mutex.Unlock() + +	var ok bool + +	// Execute deferred outside lock. +	for _, deferfn := range deferred { +		v := deferfn() +		ok = ok && v +	} + +	return ok  } -// Stream represents one open stream for a client. +// Stream represents one +// open stream for a client.  type Stream struct { -	// ID of this stream, generated during creation. -	ID string -	// A set of types subscribed to by this stream: user/public/etc. -	// It's a map to ensure no duplicates; the value is ignored. -	StreamTypes map[string]any -	// Channel of messages for the client to read from -	Messages chan *Message -	// Channel to close when the client drops away -	Hangup chan interface{} -	// Only put messages in the stream when Connected -	Connected bool -	// Mutex to lock/unlock when inserting messages, hanging up, changing the connected state etc. -	sync.Mutex + +	// atomically updated ptr to a read-only copy +	// of supported stream types in a hashmap. this +	// gets updated via CAS operations in .cas(). +	types atomic.Pointer[map[string]struct{}] + +	// protects stream close. +	done chan struct{} + +	// inbound msg ch. +	msgCh chan Message + +	// close hook to remove +	// stream from Streams{}. +	close func() +} + +// Subscribe will add given type to given types this stream supports. +func (s *Stream) Subscribe(streamType string) { +	s.cas(func(m map[string]struct{}) bool { +		if _, ok := m[streamType]; ok { +			return false +		} +		m[streamType] = struct{}{} +		return true +	}) +} + +// Unsubscribe will remove given type (if found) from types this stream supports. +func (s *Stream) Unsubscribe(streamType string) { +	s.cas(func(m map[string]struct{}) bool { +		if _, ok := m[streamType]; !ok { +			return false +		} +		delete(m, streamType) +		return true +	})  } -// Message represents one streamed message. +// getStreamType returns the first stream type in given list that stream supports. +func (s *Stream) getStreamType(streamTypes ...string) string { +	if ptr := s.types.Load(); ptr != nil { +		for _, streamType := range streamTypes { +			if _, ok := (*ptr)[streamType]; ok { +				return streamType +			} +		} +	} +	return "" +} + +// send will block on posting a new Message{}, returning early with +// a false value if provided context is canceled, or stream closed. +func (s *Stream) send(ctx context.Context, msg Message) bool { +	select { +	case <-s.done: +		return false +	case <-ctx.Done(): +		return false +	case s.msgCh <- msg: +		return true +	} +} + +// Recv will block on receiving Message{}, returning early with a +// false value if provided context is canceled, or stream closed. +func (s *Stream) Recv(ctx context.Context) (Message, bool) { +	select { +	case <-s.done: +		return Message{}, false +	case <-ctx.Done(): +		return Message{}, false +	case msg := <-s.msgCh: +		return msg, true +	} +} + +// Close will close the underlying context, finally +// removing it from the parent Streams per-account-map. +func (s *Stream) Close() { +	select { +	case <-s.done: +	default: +		close(s.done) +		s.close() +	} +} + +// cas will perform a Compare And Swap operation on s.types using modifier func. +func (s *Stream) cas(fn func(map[string]struct{}) bool) { +	if fn == nil { +		panic("nil function") +	} +	for { +		var m map[string]struct{} + +		// Get current value. +		ptr := s.types.Load() + +		if ptr == nil { +			// Allocate new types map. +			m = make(map[string]struct{}) +		} else { +			// Clone r-only map. +			m = maps.Clone(*ptr) +		} + +		// Apply +		// changes. +		if !fn(m) { +			return +		} + +		// Attempt to Compare And Swap ptr. +		if s.types.CompareAndSwap(ptr, &m) { +			return +		} +	} +} + +// Message represents +// one streamed message.  type Message struct { -	// All the stream types this message should be delivered to. + +	// All the stream types this +	// message should be delivered to.  	Stream []string `json:"stream"` -	// The event type of the message (update/delete/notification etc) + +	// The event type of the message +	// (update/delete/notification etc)  	Event string `json:"event"` -	// The actual payload of the message. In case of an update or notification, this will be a JSON string. + +	// The actual payload of the message. In case of an +	// update or notification, this will be a JSON string.  	Payload string `json:"payload"`  } | 
