summaryrefslogtreecommitdiff
path: root/internal/db/bundb/media.go
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db/bundb/media.go')
-rw-r--r--internal/db/bundb/media.go140
1 files changed, 134 insertions, 6 deletions
diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go
index b64447beb..a9b60e3ae 100644
--- a/internal/db/bundb/media.go
+++ b/internal/db/bundb/media.go
@@ -24,6 +24,7 @@ import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@@ -110,7 +111,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
// Load media into cache before attempting a delete,
// as we need it cached in order to trigger the invalidate
// callback. This in turn invalidates others.
- _, err := m.GetAttachmentByID(gtscontext.SetBarebones(ctx), id)
+ media, err := m.GetAttachmentByID(gtscontext.SetBarebones(ctx), id)
if err != nil {
if errors.Is(err, db.ErrNoEntries) {
// not an issue.
@@ -119,11 +120,115 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
return err
}
- // Finally delete media from DB.
- _, err = m.conn.NewDelete().
- TableExpr("? AS ?", bun.Ident("media_attachments"), bun.Ident("media_attachment")).
- Where("? = ?", bun.Ident("media_attachment.id"), id).
- Exec(ctx)
+ var (
+ invalidateAccount bool
+ invalidateStatus bool
+ )
+
+ // Delete media attachment in new transaction.
+ err = m.conn.RunInTx(ctx, func(tx bun.Tx) error {
+ if media.AccountID != "" {
+ var account gtsmodel.Account
+
+ // Get related account model.
+ if _, err := tx.NewSelect().
+ Model(&account).
+ Where("? = ?", bun.Ident("id"), media.AccountID).
+ Exec(ctx); err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return gtserror.Newf("error selecting account: %w", err)
+ }
+
+ var set func(*bun.UpdateQuery) *bun.UpdateQuery
+
+ switch {
+ case *media.Avatar && account.AvatarMediaAttachmentID == id:
+ set = func(q *bun.UpdateQuery) *bun.UpdateQuery {
+ return q.Set("? = NULL", bun.Ident("avatar_media_attachment_id"))
+ }
+ case *media.Header && account.HeaderMediaAttachmentID == id:
+ set = func(q *bun.UpdateQuery) *bun.UpdateQuery {
+ return q.Set("? = NULL", bun.Ident("header_media_attachment_id"))
+ }
+ }
+
+ if set != nil {
+ // Note: this handles not found.
+ //
+ // Update the account model.
+ q := tx.NewUpdate().
+ Table("accounts").
+ Where("? = ?", bun.Ident("id"), account.ID)
+ if _, err := set(q).Exec(ctx); err != nil {
+ return gtserror.Newf("error updating account: %w", err)
+ }
+
+ // Mark as needing invalidate.
+ invalidateAccount = true
+ }
+ }
+
+ if media.StatusID != "" {
+ var status gtsmodel.Status
+
+ // Get related status model.
+ if _, err := tx.NewSelect().
+ Model(&status).
+ Where("? = ?", bun.Ident("id"), media.StatusID).
+ Exec(ctx); err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return gtserror.Newf("error selecting status: %w", err)
+ }
+
+ // Get length of attachments beforehand.
+ before := len(status.AttachmentIDs)
+
+ for i := 0; i < len(status.AttachmentIDs); {
+ if status.AttachmentIDs[i] == id {
+ // Remove this reference to deleted attachment ID.
+ copy(status.AttachmentIDs[i:], status.AttachmentIDs[i+1:])
+ status.AttachmentIDs = status.AttachmentIDs[:len(status.AttachmentIDs)-1]
+ continue
+ }
+ i++
+ }
+
+ if before != len(status.AttachmentIDs) {
+ // Note: this accounts for status not found.
+ //
+ // Attachments changed, update the status.
+ if _, err := tx.NewUpdate().
+ Table("statuses").
+ Where("? = ?", bun.Ident("id"), status.ID).
+ Set("? = ?", bun.Ident("attachment_ids"), status.AttachmentIDs).
+ Exec(ctx); err != nil {
+ return gtserror.Newf("error updating status: %w", err)
+ }
+
+ // Mark as needing invalidate.
+ invalidateStatus = true
+ }
+ }
+
+ // Finally delete this media.
+ if _, err := tx.NewDelete().
+ Table("media_attachments").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return gtserror.Newf("error deleting media: %w", err)
+ }
+
+ return nil
+ })
+
+ if invalidateAccount {
+ // The account for given ID will have been updated in transaction.
+ m.state.Caches.GTS.Account().Invalidate("ID", media.AccountID)
+ }
+
+ if invalidateStatus {
+ // The status for given ID will have been updated in transaction.
+ m.state.Caches.GTS.Status().Invalidate("ID", media.StatusID)
+ }
+
return m.conn.ProcessError(err)
}
@@ -167,6 +272,29 @@ func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time)
return count, nil
}
+func (m *mediaDB) GetAttachments(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, error) {
+ attachmentIDs := []string{}
+
+ q := m.conn.NewSelect().
+ Table("media_attachments").
+ Column("id").
+ Order("id DESC")
+
+ if maxID != "" {
+ q = q.Where("? < ?", bun.Ident("id"), maxID)
+ }
+
+ if limit != 0 {
+ q = q.Limit(limit)
+ }
+
+ if err := q.Scan(ctx, &attachmentIDs); err != nil {
+ return nil, m.conn.ProcessError(err)
+ }
+
+ return m.GetAttachmentsByIDs(ctx, attachmentIDs)
+}
+
func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit int) ([]*gtsmodel.MediaAttachment, db.Error) {
attachmentIDs := []string{}