summaryrefslogtreecommitdiff
path: root/internal/stream/stream.go
diff options
context:
space:
mode:
authorLibravatar kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com>2024-02-20 18:07:49 +0000
committerLibravatar GitHub <noreply@github.com>2024-02-20 18:07:49 +0000
commit291e18099050ff9e19b8ee25c2ffad68d9baafef (patch)
tree0ad1be36b4c958830d1371f3b9a32f017c5dcff0 /internal/stream/stream.go
parent[feature] Add `requested_by` to relationship model (#2672) (diff)
downloadgotosocial-291e18099050ff9e19b8ee25c2ffad68d9baafef.tar.xz
[bugfix] fix possible mutex lockup during streaming code (#2633)
* rewrite Stream{} to use much less mutex locking, update related code * use new context for the stream context * ensure stream gets closed on return of writeTo / readFrom WSConn() * ensure stream write timeout gets cancelled * remove embedded context type from Stream{}, reformat log messages for consistency * use c.Request.Context() for context passed into Stream().Open() * only return 1 boolean, fix tests to expect multiple stream types in messages * changes to ping logic * further improved ping logic * don't export unused function types, update message sending to only include relevant stream type * ensure stream gets closed :facepalm: * update to error log on failed json marshal (instead of panic) * inverse websocket read error checking to _ignore_ expected close errors
Diffstat (limited to 'internal/stream/stream.go')
-rw-r--r--internal/stream/stream.go385
1 files changed, 337 insertions, 48 deletions
diff --git a/internal/stream/stream.go b/internal/stream/stream.go
index da5647433..ec22464f5 100644
--- a/internal/stream/stream.go
+++ b/internal/stream/stream.go
@@ -17,36 +17,65 @@
package stream
-import "sync"
+import (
+ "context"
+ "maps"
+ "slices"
+ "sync"
+ "sync/atomic"
+)
const (
- // EventTypeNotification -- a user should be shown a notification
- EventTypeNotification string = "notification"
- // EventTypeUpdate -- a user should be shown an update in their timeline
- EventTypeUpdate string = "update"
- // EventTypeDelete -- something should be deleted from a user
- EventTypeDelete string = "delete"
- // EventTypeStatusUpdate -- something in the user's timeline has been edited
- // (yes this is a confusing name, blame Mastodon)
- EventTypeStatusUpdate string = "status.update"
+ // EventTypeNotification -- a user
+ // should be shown a notification.
+ EventTypeNotification = "notification"
+
+ // EventTypeUpdate -- a user should
+ // be shown an update in their timeline.
+ EventTypeUpdate = "update"
+
+ // EventTypeDelete -- something
+ // should be deleted from a user.
+ EventTypeDelete = "delete"
+
+ // EventTypeStatusUpdate -- something in the
+ // user's timeline has been edited (yes this
+ // is a confusing name, blame Mastodon ...).
+ EventTypeStatusUpdate = "status.update"
)
const (
- // TimelineLocal -- public statuses from the LOCAL timeline.
- TimelineLocal string = "public:local"
- // TimelinePublic -- public statuses, including federated ones.
- TimelinePublic string = "public"
- // TimelineHome -- statuses for a user's Home timeline.
- TimelineHome string = "user"
- // TimelineNotifications -- notification events.
- TimelineNotifications string = "user:notification"
- // TimelineDirect -- statuses sent to a user directly.
- TimelineDirect string = "direct"
- // TimelineList -- statuses for a user's list timeline.
- TimelineList string = "list"
+ // TimelineLocal:
+ // All public posts originating from this
+ // server. Analogous to the local timeline.
+ TimelineLocal = "public:local"
+
+ // TimelinePublic:
+ // All public posts known to the server.
+ // Analogous to the federated timeline.
+ TimelinePublic = "public"
+
+ // TimelineHome:
+ // Events related to the current user, such
+ // as home feed updates and notifications.
+ TimelineHome = "user"
+
+ // TimelineNotifications:
+ // Notifications for the current user.
+ TimelineNotifications = "user:notification"
+
+ // TimelineDirect:
+ // Updates to direct conversations.
+ TimelineDirect = "direct"
+
+ // TimelineList:
+ // Updates to a specific list.
+ TimelineList = "list"
)
-// AllStatusTimelines contains all Timelines that a status could conceivably be delivered to -- useful for doing deletes.
+// AllStatusTimelines contains all Timelines
+// that a status could conceivably be delivered
+// to, useful for sending out status deletes.
var AllStatusTimelines = []string{
TimelineLocal,
TimelinePublic,
@@ -55,38 +84,298 @@ var AllStatusTimelines = []string{
TimelineList,
}
-// StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time.
-// TODO: put a limit on this
-type StreamsForAccount struct {
- // The currently held streams for this account
- Streams []*Stream
- // Mutex to lock/unlock when modifying the slice of streams.
- sync.Mutex
+type Streams struct {
+ streams map[string][]*Stream
+ mutex sync.Mutex
+}
+
+// Open will open open a new Stream for given account ID and stream types, the given context will be passed to Stream.
+func (s *Streams) Open(accountID string, streamTypes ...string) *Stream {
+ if len(streamTypes) == 0 {
+ panic("no stream types given")
+ }
+
+ // Prep new Stream.
+ str := new(Stream)
+ str.done = make(chan struct{})
+ str.msgCh = make(chan Message, 50) // TODO: make configurable
+ for _, streamType := range streamTypes {
+ str.Subscribe(streamType)
+ }
+
+ // TODO: add configurable
+ // max streams per account.
+
+ // Acquire lock.
+ s.mutex.Lock()
+
+ if s.streams == nil {
+ // Main stream-map needs allocating.
+ s.streams = make(map[string][]*Stream)
+ }
+
+ // Add new stream for account.
+ strs := s.streams[accountID]
+ strs = append(strs, str)
+ s.streams[accountID] = strs
+
+ // Register close callback
+ // to remove stream from our
+ // internal map for this account.
+ str.close = func() {
+ s.mutex.Lock()
+ strs := s.streams[accountID]
+ strs = slices.DeleteFunc(strs, func(s *Stream) bool {
+ return s == str // remove 'str' ptr
+ })
+ s.streams[accountID] = strs
+ s.mutex.Unlock()
+ }
+
+ // Done with lock.
+ s.mutex.Unlock()
+
+ return str
+}
+
+// Post will post the given message to all streams of given account ID matching type.
+func (s *Streams) Post(ctx context.Context, accountID string, msg Message) bool {
+ var deferred []func() bool
+
+ // Acquire lock.
+ s.mutex.Lock()
+
+ // Iterate all streams stored for account.
+ for _, str := range s.streams[accountID] {
+
+ // Check whether stream supports any of our message targets.
+ if stype := str.getStreamType(msg.Stream...); stype != "" {
+
+ // Rescope var
+ // to prevent
+ // ptr reuse.
+ stream := str
+
+ // Use a message copy to *only*
+ // include the supported stream.
+ msgCopy := Message{
+ Stream: []string{stype},
+ Event: msg.Event,
+ Payload: msg.Payload,
+ }
+
+ // Send message to supported stream
+ // DEFERRED (i.e. OUTSIDE OF MAIN MUTEX).
+ // This prevents deadlocks between each
+ // msg channel and main Streams{} mutex.
+ deferred = append(deferred, func() bool {
+ return stream.send(ctx, msgCopy)
+ })
+ }
+ }
+
+ // Done with lock.
+ s.mutex.Unlock()
+
+ var ok bool
+
+ // Execute deferred outside lock.
+ for _, deferfn := range deferred {
+ v := deferfn()
+ ok = ok && v
+ }
+
+ return ok
+}
+
+// PostAll will post the given message to all streams with matching types.
+func (s *Streams) PostAll(ctx context.Context, msg Message) bool {
+ var deferred []func() bool
+
+ // Acquire lock.
+ s.mutex.Lock()
+
+ // Iterate ALL stored streams.
+ for _, strs := range s.streams {
+ for _, str := range strs {
+
+ // Check whether stream supports any of our message targets.
+ if stype := str.getStreamType(msg.Stream...); stype != "" {
+
+ // Rescope var
+ // to prevent
+ // ptr reuse.
+ stream := str
+
+ // Use a message copy to *only*
+ // include the supported stream.
+ msgCopy := Message{
+ Stream: []string{stype},
+ Event: msg.Event,
+ Payload: msg.Payload,
+ }
+
+ // Send message to supported stream
+ // DEFERRED (i.e. OUTSIDE OF MAIN MUTEX).
+ // This prevents deadlocks between each
+ // msg channel and main Streams{} mutex.
+ deferred = append(deferred, func() bool {
+ return stream.send(ctx, msgCopy)
+ })
+ }
+ }
+ }
+
+ // Done with lock.
+ s.mutex.Unlock()
+
+ var ok bool
+
+ // Execute deferred outside lock.
+ for _, deferfn := range deferred {
+ v := deferfn()
+ ok = ok && v
+ }
+
+ return ok
}
-// Stream represents one open stream for a client.
+// Stream represents one
+// open stream for a client.
type Stream struct {
- // ID of this stream, generated during creation.
- ID string
- // A set of types subscribed to by this stream: user/public/etc.
- // It's a map to ensure no duplicates; the value is ignored.
- StreamTypes map[string]any
- // Channel of messages for the client to read from
- Messages chan *Message
- // Channel to close when the client drops away
- Hangup chan interface{}
- // Only put messages in the stream when Connected
- Connected bool
- // Mutex to lock/unlock when inserting messages, hanging up, changing the connected state etc.
- sync.Mutex
+
+ // atomically updated ptr to a read-only copy
+ // of supported stream types in a hashmap. this
+ // gets updated via CAS operations in .cas().
+ types atomic.Pointer[map[string]struct{}]
+
+ // protects stream close.
+ done chan struct{}
+
+ // inbound msg ch.
+ msgCh chan Message
+
+ // close hook to remove
+ // stream from Streams{}.
+ close func()
+}
+
+// Subscribe will add given type to given types this stream supports.
+func (s *Stream) Subscribe(streamType string) {
+ s.cas(func(m map[string]struct{}) bool {
+ if _, ok := m[streamType]; ok {
+ return false
+ }
+ m[streamType] = struct{}{}
+ return true
+ })
+}
+
+// Unsubscribe will remove given type (if found) from types this stream supports.
+func (s *Stream) Unsubscribe(streamType string) {
+ s.cas(func(m map[string]struct{}) bool {
+ if _, ok := m[streamType]; !ok {
+ return false
+ }
+ delete(m, streamType)
+ return true
+ })
}
-// Message represents one streamed message.
+// getStreamType returns the first stream type in given list that stream supports.
+func (s *Stream) getStreamType(streamTypes ...string) string {
+ if ptr := s.types.Load(); ptr != nil {
+ for _, streamType := range streamTypes {
+ if _, ok := (*ptr)[streamType]; ok {
+ return streamType
+ }
+ }
+ }
+ return ""
+}
+
+// send will block on posting a new Message{}, returning early with
+// a false value if provided context is canceled, or stream closed.
+func (s *Stream) send(ctx context.Context, msg Message) bool {
+ select {
+ case <-s.done:
+ return false
+ case <-ctx.Done():
+ return false
+ case s.msgCh <- msg:
+ return true
+ }
+}
+
+// Recv will block on receiving Message{}, returning early with a
+// false value if provided context is canceled, or stream closed.
+func (s *Stream) Recv(ctx context.Context) (Message, bool) {
+ select {
+ case <-s.done:
+ return Message{}, false
+ case <-ctx.Done():
+ return Message{}, false
+ case msg := <-s.msgCh:
+ return msg, true
+ }
+}
+
+// Close will close the underlying context, finally
+// removing it from the parent Streams per-account-map.
+func (s *Stream) Close() {
+ select {
+ case <-s.done:
+ default:
+ close(s.done)
+ s.close()
+ }
+}
+
+// cas will perform a Compare And Swap operation on s.types using modifier func.
+func (s *Stream) cas(fn func(map[string]struct{}) bool) {
+ if fn == nil {
+ panic("nil function")
+ }
+ for {
+ var m map[string]struct{}
+
+ // Get current value.
+ ptr := s.types.Load()
+
+ if ptr == nil {
+ // Allocate new types map.
+ m = make(map[string]struct{})
+ } else {
+ // Clone r-only map.
+ m = maps.Clone(*ptr)
+ }
+
+ // Apply
+ // changes.
+ if !fn(m) {
+ return
+ }
+
+ // Attempt to Compare And Swap ptr.
+ if s.types.CompareAndSwap(ptr, &m) {
+ return
+ }
+ }
+}
+
+// Message represents
+// one streamed message.
type Message struct {
- // All the stream types this message should be delivered to.
+
+ // All the stream types this
+ // message should be delivered to.
Stream []string `json:"stream"`
- // The event type of the message (update/delete/notification etc)
+
+ // The event type of the message
+ // (update/delete/notification etc)
Event string `json:"event"`
- // The actual payload of the message. In case of an update or notification, this will be a JSON string.
+
+ // The actual payload of the message. In case of an
+ // update or notification, this will be a JSON string.
Payload string `json:"payload"`
}