diff options
Diffstat (limited to 'internal/db/bundb/media.go')
-rw-r--r-- | internal/db/bundb/media.go | 140 |
1 files changed, 134 insertions, 6 deletions
diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go index b64447beb..a9b60e3ae 100644 --- a/internal/db/bundb/media.go +++ b/internal/db/bundb/media.go @@ -24,6 +24,7 @@ import ( "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtscontext" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/state" @@ -110,7 +111,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { // Load media into cache before attempting a delete, // as we need it cached in order to trigger the invalidate // callback. This in turn invalidates others. - _, err := m.GetAttachmentByID(gtscontext.SetBarebones(ctx), id) + media, err := m.GetAttachmentByID(gtscontext.SetBarebones(ctx), id) if err != nil { if errors.Is(err, db.ErrNoEntries) { // not an issue. @@ -119,11 +120,115 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error { return err } - // Finally delete media from DB. - _, err = m.conn.NewDelete(). - TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")). - Where("? = ?", bun.Ident("media_attachment.id"), id). - Exec(ctx) + var ( + invalidateAccount bool + invalidateStatus bool + ) + + // Delete media attachment in new transaction. + err = m.conn.RunInTx(ctx, func(tx bun.Tx) error { + if media.AccountID != "" { + var account gtsmodel.Account + + // Get related account model. + if _, err := tx.NewSelect(). + Model(&account). + Where("? = ?", bun.Ident("id"), media.AccountID). + Exec(ctx); err != nil && !errors.Is(err, db.ErrNoEntries) { + return gtserror.Newf("error selecting account: %w", err) + } + + var set func(*bun.UpdateQuery) *bun.UpdateQuery + + switch { + case *media.Avatar && account.AvatarMediaAttachmentID == id: + set = func(q *bun.UpdateQuery) *bun.UpdateQuery { + return q.Set("? = NULL", bun.Ident("avatar_media_attachment_id")) + } + case *media.Header && account.HeaderMediaAttachmentID == id: + set = func(q *bun.UpdateQuery) *bun.UpdateQuery { + return q.Set("? = NULL", bun.Ident("header_media_attachment_id")) + } + } + + if set != nil { + // Note: this handles not found. + // + // Update the account model. + q := tx.NewUpdate(). + Table("accounts"). + Where("? = ?", bun.Ident("id"), account.ID) + if _, err := set(q).Exec(ctx); err != nil { + return gtserror.Newf("error updating account: %w", err) + } + + // Mark as needing invalidate. + invalidateAccount = true + } + } + + if media.StatusID != "" { + var status gtsmodel.Status + + // Get related status model. + if _, err := tx.NewSelect(). + Model(&status). + Where("? = ?", bun.Ident("id"), media.StatusID). + Exec(ctx); err != nil && !errors.Is(err, db.ErrNoEntries) { + return gtserror.Newf("error selecting status: %w", err) + } + + // Get length of attachments beforehand. + before := len(status.AttachmentIDs) + + for i := 0; i < len(status.AttachmentIDs); { + if status.AttachmentIDs[i] == id { + // Remove this reference to deleted attachment ID. + copy(status.AttachmentIDs[i:], status.AttachmentIDs[i+1:]) + status.AttachmentIDs = status.AttachmentIDs[:len(status.AttachmentIDs)-1] + continue + } + i++ + } + + if before != len(status.AttachmentIDs) { + // Note: this accounts for status not found. + // + // Attachments changed, update the status. + if _, err := tx.NewUpdate(). + Table("statuses"). + Where("? = ?", bun.Ident("id"), status.ID). + Set("? = ?", bun.Ident("attachment_ids"), status.AttachmentIDs). + Exec(ctx); err != nil { + return gtserror.Newf("error updating status: %w", err) + } + + // Mark as needing invalidate. + invalidateStatus = true + } + } + + // Finally delete this media. + if _, err := tx.NewDelete(). + Table("media_attachments"). + Where("? = ?", bun.Ident("id"), id). + Exec(ctx); err != nil { + return gtserror.Newf("error deleting media: %w", err) + } + + return nil + }) + + if invalidateAccount { + // The account for given ID will have been updated in transaction. + m.state.Caches.GTS.Account().Invalidate("ID", media.AccountID) + } + + if invalidateStatus { + // The status for given ID will have been updated in transaction. + m.state.Caches.GTS.Status().Invalidate("ID", media.StatusID) + } + return m.conn.ProcessError(err) } @@ -167,6 +272,29 @@ func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) return count, nil } +func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) { + attachmentIDs := []string{} + + q := m.conn.NewSelect(). + Table("media_attachments"). + Column("id"). + Order("id DESC") + + if maxID != "" { + q = q.Where("? < ?", bun.Ident("id"), maxID) + } + + if limit != 0 { + q = q.Limit(limit) + } + + if err := q.Scan(ctx, &attachmentIDs); err != nil { + return nil, m.conn.ProcessError(err) + } + + return m.GetAttachmentsByIDs(ctx, attachmentIDs) +} + func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, db.Error) { attachmentIDs := []string{} |