diff options
author | 2024-02-20 18:07:49 +0000 | |
---|---|---|
committer | 2024-02-20 18:07:49 +0000 | |
commit | 291e18099050ff9e19b8ee25c2ffad68d9baafef (patch) | |
tree | 0ad1be36b4c958830d1371f3b9a32f017c5dcff0 /internal/stream/stream.go | |
parent | [feature] Add `requested_by` to relationship model (#2672) (diff) | |
download | gotosocial-291e18099050ff9e19b8ee25c2ffad68d9baafef.tar.xz |
[bugfix] fix possible mutex lockup during streaming code (#2633)
* rewrite Stream{} to use much less mutex locking, update related code
* use new context for the stream context
* ensure stream gets closed on return of writeTo / readFrom WSConn()
* ensure stream write timeout gets cancelled
* remove embedded context type from Stream{}, reformat log messages for consistency
* use c.Request.Context() for context passed into Stream().Open()
* only return 1 boolean, fix tests to expect multiple stream types in messages
* changes to ping logic
* further improved ping logic
* don't export unused function types, update message sending to only include relevant stream type
* ensure stream gets closed :facepalm:
* update to error log on failed json marshal (instead of panic)
* inverse websocket read error checking to _ignore_ expected close errors
Diffstat (limited to 'internal/stream/stream.go')
-rw-r--r-- | internal/stream/stream.go | 385 |
1 files changed, 337 insertions, 48 deletions
diff --git a/internal/stream/stream.go b/internal/stream/stream.go index da5647433..ec22464f5 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -17,36 +17,65 @@ package stream -import "sync" +import ( + "context" + "maps" + "slices" + "sync" + "sync/atomic" +) const ( - // EventTypeNotification -- a user should be shown a notification - EventTypeNotification string = "notification" - // EventTypeUpdate -- a user should be shown an update in their timeline - EventTypeUpdate string = "update" - // EventTypeDelete -- something should be deleted from a user - EventTypeDelete string = "delete" - // EventTypeStatusUpdate -- something in the user's timeline has been edited - // (yes this is a confusing name, blame Mastodon) - EventTypeStatusUpdate string = "status.update" + // EventTypeNotification -- a user + // should be shown a notification. + EventTypeNotification = "notification" + + // EventTypeUpdate -- a user should + // be shown an update in their timeline. + EventTypeUpdate = "update" + + // EventTypeDelete -- something + // should be deleted from a user. + EventTypeDelete = "delete" + + // EventTypeStatusUpdate -- something in the + // user's timeline has been edited (yes this + // is a confusing name, blame Mastodon ...). + EventTypeStatusUpdate = "status.update" ) const ( - // TimelineLocal -- public statuses from the LOCAL timeline. - TimelineLocal string = "public:local" - // TimelinePublic -- public statuses, including federated ones. - TimelinePublic string = "public" - // TimelineHome -- statuses for a user's Home timeline. - TimelineHome string = "user" - // TimelineNotifications -- notification events. - TimelineNotifications string = "user:notification" - // TimelineDirect -- statuses sent to a user directly. - TimelineDirect string = "direct" - // TimelineList -- statuses for a user's list timeline. - TimelineList string = "list" + // TimelineLocal: + // All public posts originating from this + // server. Analogous to the local timeline. + TimelineLocal = "public:local" + + // TimelinePublic: + // All public posts known to the server. + // Analogous to the federated timeline. + TimelinePublic = "public" + + // TimelineHome: + // Events related to the current user, such + // as home feed updates and notifications. + TimelineHome = "user" + + // TimelineNotifications: + // Notifications for the current user. + TimelineNotifications = "user:notification" + + // TimelineDirect: + // Updates to direct conversations. + TimelineDirect = "direct" + + // TimelineList: + // Updates to a specific list. + TimelineList = "list" ) -// AllStatusTimelines contains all Timelines that a status could conceivably be delivered to -- useful for doing deletes. +// AllStatusTimelines contains all Timelines +// that a status could conceivably be delivered +// to, useful for sending out status deletes. var AllStatusTimelines = []string{ TimelineLocal, TimelinePublic, @@ -55,38 +84,298 @@ var AllStatusTimelines = []string{ TimelineList, } -// StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time. -// TODO: put a limit on this -type StreamsForAccount struct { - // The currently held streams for this account - Streams []*Stream - // Mutex to lock/unlock when modifying the slice of streams. - sync.Mutex +type Streams struct { + streams map[string][]*Stream + mutex sync.Mutex +} + +// Open will open open a new Stream for given account ID and stream types, the given context will be passed to Stream. +func (s *Streams) Open(accountID string, streamTypes ...string) *Stream { + if len(streamTypes) == 0 { + panic("no stream types given") + } + + // Prep new Stream. + str := new(Stream) + str.done = make(chan struct{}) + str.msgCh = make(chan Message, 50) // TODO: make configurable + for _, streamType := range streamTypes { + str.Subscribe(streamType) + } + + // TODO: add configurable + // max streams per account. + + // Acquire lock. + s.mutex.Lock() + + if s.streams == nil { + // Main stream-map needs allocating. + s.streams = make(map[string][]*Stream) + } + + // Add new stream for account. + strs := s.streams[accountID] + strs = append(strs, str) + s.streams[accountID] = strs + + // Register close callback + // to remove stream from our + // internal map for this account. + str.close = func() { + s.mutex.Lock() + strs := s.streams[accountID] + strs = slices.DeleteFunc(strs, func(s *Stream) bool { + return s == str // remove 'str' ptr + }) + s.streams[accountID] = strs + s.mutex.Unlock() + } + + // Done with lock. + s.mutex.Unlock() + + return str +} + +// Post will post the given message to all streams of given account ID matching type. +func (s *Streams) Post(ctx context.Context, accountID string, msg Message) bool { + var deferred []func() bool + + // Acquire lock. + s.mutex.Lock() + + // Iterate all streams stored for account. + for _, str := range s.streams[accountID] { + + // Check whether stream supports any of our message targets. + if stype := str.getStreamType(msg.Stream...); stype != "" { + + // Rescope var + // to prevent + // ptr reuse. + stream := str + + // Use a message copy to *only* + // include the supported stream. + msgCopy := Message{ + Stream: []string{stype}, + Event: msg.Event, + Payload: msg.Payload, + } + + // Send message to supported stream + // DEFERRED (i.e. OUTSIDE OF MAIN MUTEX). + // This prevents deadlocks between each + // msg channel and main Streams{} mutex. + deferred = append(deferred, func() bool { + return stream.send(ctx, msgCopy) + }) + } + } + + // Done with lock. + s.mutex.Unlock() + + var ok bool + + // Execute deferred outside lock. + for _, deferfn := range deferred { + v := deferfn() + ok = ok && v + } + + return ok +} + +// PostAll will post the given message to all streams with matching types. +func (s *Streams) PostAll(ctx context.Context, msg Message) bool { + var deferred []func() bool + + // Acquire lock. + s.mutex.Lock() + + // Iterate ALL stored streams. + for _, strs := range s.streams { + for _, str := range strs { + + // Check whether stream supports any of our message targets. + if stype := str.getStreamType(msg.Stream...); stype != "" { + + // Rescope var + // to prevent + // ptr reuse. + stream := str + + // Use a message copy to *only* + // include the supported stream. + msgCopy := Message{ + Stream: []string{stype}, + Event: msg.Event, + Payload: msg.Payload, + } + + // Send message to supported stream + // DEFERRED (i.e. OUTSIDE OF MAIN MUTEX). + // This prevents deadlocks between each + // msg channel and main Streams{} mutex. + deferred = append(deferred, func() bool { + return stream.send(ctx, msgCopy) + }) + } + } + } + + // Done with lock. + s.mutex.Unlock() + + var ok bool + + // Execute deferred outside lock. + for _, deferfn := range deferred { + v := deferfn() + ok = ok && v + } + + return ok } -// Stream represents one open stream for a client. +// Stream represents one +// open stream for a client. type Stream struct { - // ID of this stream, generated during creation. - ID string - // A set of types subscribed to by this stream: user/public/etc. - // It's a map to ensure no duplicates; the value is ignored. - StreamTypes map[string]any - // Channel of messages for the client to read from - Messages chan *Message - // Channel to close when the client drops away - Hangup chan interface{} - // Only put messages in the stream when Connected - Connected bool - // Mutex to lock/unlock when inserting messages, hanging up, changing the connected state etc. - sync.Mutex + + // atomically updated ptr to a read-only copy + // of supported stream types in a hashmap. this + // gets updated via CAS operations in .cas(). + types atomic.Pointer[map[string]struct{}] + + // protects stream close. + done chan struct{} + + // inbound msg ch. + msgCh chan Message + + // close hook to remove + // stream from Streams{}. + close func() +} + +// Subscribe will add given type to given types this stream supports. +func (s *Stream) Subscribe(streamType string) { + s.cas(func(m map[string]struct{}) bool { + if _, ok := m[streamType]; ok { + return false + } + m[streamType] = struct{}{} + return true + }) +} + +// Unsubscribe will remove given type (if found) from types this stream supports. +func (s *Stream) Unsubscribe(streamType string) { + s.cas(func(m map[string]struct{}) bool { + if _, ok := m[streamType]; !ok { + return false + } + delete(m, streamType) + return true + }) } -// Message represents one streamed message. +// getStreamType returns the first stream type in given list that stream supports. +func (s *Stream) getStreamType(streamTypes ...string) string { + if ptr := s.types.Load(); ptr != nil { + for _, streamType := range streamTypes { + if _, ok := (*ptr)[streamType]; ok { + return streamType + } + } + } + return "" +} + +// send will block on posting a new Message{}, returning early with +// a false value if provided context is canceled, or stream closed. +func (s *Stream) send(ctx context.Context, msg Message) bool { + select { + case <-s.done: + return false + case <-ctx.Done(): + return false + case s.msgCh <- msg: + return true + } +} + +// Recv will block on receiving Message{}, returning early with a +// false value if provided context is canceled, or stream closed. +func (s *Stream) Recv(ctx context.Context) (Message, bool) { + select { + case <-s.done: + return Message{}, false + case <-ctx.Done(): + return Message{}, false + case msg := <-s.msgCh: + return msg, true + } +} + +// Close will close the underlying context, finally +// removing it from the parent Streams per-account-map. +func (s *Stream) Close() { + select { + case <-s.done: + default: + close(s.done) + s.close() + } +} + +// cas will perform a Compare And Swap operation on s.types using modifier func. +func (s *Stream) cas(fn func(map[string]struct{}) bool) { + if fn == nil { + panic("nil function") + } + for { + var m map[string]struct{} + + // Get current value. + ptr := s.types.Load() + + if ptr == nil { + // Allocate new types map. + m = make(map[string]struct{}) + } else { + // Clone r-only map. + m = maps.Clone(*ptr) + } + + // Apply + // changes. + if !fn(m) { + return + } + + // Attempt to Compare And Swap ptr. + if s.types.CompareAndSwap(ptr, &m) { + return + } + } +} + +// Message represents +// one streamed message. type Message struct { - // All the stream types this message should be delivered to. + + // All the stream types this + // message should be delivered to. Stream []string `json:"stream"` - // The event type of the message (update/delete/notification etc) + + // The event type of the message + // (update/delete/notification etc) Event string `json:"event"` - // The actual payload of the message. In case of an update or notification, this will be a JSON string. + + // The actual payload of the message. In case of an + // update or notification, this will be a JSON string. Payload string `json:"payload"` } |