diff options
Diffstat (limited to 'internal/db')
-rw-r--r-- | internal/db/bundb/bundb.go | 5 | ||||
-rw-r--r-- | internal/db/bundb/bundb_test.go | 2 | ||||
-rw-r--r-- | internal/db/bundb/migrations/20230105171144_report_model.go | 66 | ||||
-rw-r--r-- | internal/db/bundb/report.go | 138 | ||||
-rw-r--r-- | internal/db/bundb/report_test.go | 147 | ||||
-rw-r--r-- | internal/db/bundb/status.go | 18 | ||||
-rw-r--r-- | internal/db/bundb/status_test.go | 42 | ||||
-rw-r--r-- | internal/db/db.go | 1 | ||||
-rw-r--r-- | internal/db/report.go | 41 | ||||
-rw-r--r-- | internal/db/status.go | 3 |
10 files changed, 463 insertions, 0 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index e749484a8..1225b2bb0 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -83,6 +83,7 @@ type DBService struct { db.Mention db.Notification db.Relationship + db.Report db.Session db.Status db.Timeline @@ -197,6 +198,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { conn: conn, state: state, }, + Report: &reportDB{ + conn: conn, + state: state, + }, Session: &sessionDB{ conn: conn, }, diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go index 45d2e70a7..e050c2b5d 100644 --- a/internal/db/bundb/bundb_test.go +++ b/internal/db/bundb/bundb_test.go @@ -42,6 +42,7 @@ type BunDBStandardTestSuite struct { testMentions map[string]*gtsmodel.Mention testFollows map[string]*gtsmodel.Follow testEmojis map[string]*gtsmodel.Emoji + testReports map[string]*gtsmodel.Report } func (suite *BunDBStandardTestSuite) SetupSuite() { @@ -56,6 +57,7 @@ func (suite *BunDBStandardTestSuite) SetupSuite() { suite.testMentions = testrig.NewTestMentions() suite.testFollows = testrig.NewTestFollows() suite.testEmojis = testrig.NewTestEmojis() + suite.testReports = testrig.NewTestReports() } func (suite *BunDBStandardTestSuite) SetupTest() { diff --git a/internal/db/bundb/migrations/20230105171144_report_model.go b/internal/db/bundb/migrations/20230105171144_report_model.go new file mode 100644 index 000000000..b175e2995 --- /dev/null +++ b/internal/db/bundb/migrations/20230105171144_report_model.go @@ -0,0 +1,66 @@ +/* + GoToSocial + Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package migrations + +import ( + "context" + + gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/uptrace/bun" +) + +func init() { + up := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + if _, err := tx.NewCreateTable().Model(>smodel.Report{}).IfNotExists().Exec(ctx); err != nil { + return err + } + + if _, err := tx. + NewCreateIndex(). + Model(>smodel.Report{}). + Index("report_account_id_idx"). + Column("account_id"). + Exec(ctx); err != nil { + return err + } + + if _, err := tx. + NewCreateIndex(). + Model(>smodel.Report{}). + Index("report_target_account_id_idx"). + Column("target_account_id"). + Exec(ctx); err != nil { + return err + } + + return nil + }) + } + + down := func(ctx context.Context, db *bun.DB) error { + return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { + return nil + }) + } + + if err := Migrations.Register(up, down); err != nil { + panic(err) + } +} diff --git a/internal/db/bundb/report.go b/internal/db/bundb/report.go new file mode 100644 index 000000000..8cc1d8de9 --- /dev/null +++ b/internal/db/bundb/report.go @@ -0,0 +1,138 @@ +/* + GoToSocial + Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( + "context" + "fmt" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/uptrace/bun" +) + +type reportDB struct { + conn *DBConn + state *state.State +} + +func (r *reportDB) newReportQ(report interface{}) *bun.SelectQuery { + return r.conn.NewSelect().Model(report) +} + +func (r *reportDB) GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, db.Error) { + return r.getReport( + ctx, + "ID", + func(report *gtsmodel.Report) error { + return r.newReportQ(report).Where("? = ?", bun.Ident("report.id"), id).Scan(ctx) + }, + id, + ) +} + +func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, db.Error) { + // Fetch report from database cache with loader callback + report, err := r.state.Caches.GTS.Report().Load(lookup, func() (*gtsmodel.Report, error) { + var report gtsmodel.Report + + // Not cached! Perform database query + if err := dbQuery(&report); err != nil { + return nil, r.conn.ProcessError(err) + } + + return &report, nil + }, keyParts...) + if err != nil { + // error already processed + return nil, err + } + + // Set the report author account + report.Account, err = r.state.DB.GetAccountByID(ctx, report.AccountID) + if err != nil { + return nil, fmt.Errorf("error getting report account: %w", err) + } + + // Set the report target account + report.TargetAccount, err = r.state.DB.GetAccountByID(ctx, report.TargetAccountID) + if err != nil { + return nil, fmt.Errorf("error getting report target account: %w", err) + } + + if len(report.StatusIDs) > 0 { + // Fetch reported statuses + report.Statuses, err = r.state.DB.GetStatuses(ctx, report.StatusIDs) + if err != nil { + return nil, fmt.Errorf("error getting status mentions: %w", err) + } + } + + if report.ActionTakenByAccountID != "" { + // Set the report action taken by account + report.ActionTakenByAccount, err = r.state.DB.GetAccountByID(ctx, report.ActionTakenByAccountID) + if err != nil { + return nil, fmt.Errorf("error getting report action taken by account: %w", err) + } + } + + return report, nil +} + +func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) db.Error { + return r.state.Caches.GTS.Report().Store(report, func() error { + _, err := r.conn.NewInsert().Model(report).Exec(ctx) + return r.conn.ProcessError(err) + }) +} + +func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, db.Error) { + // Update the report's last-updated + report.UpdatedAt = time.Now() + if len(columns) != 0 { + columns = append(columns, "updated_at") + } + + if _, err := r.conn. + NewUpdate(). + Model(report). + Where("? = ?", bun.Ident("report.id"), report.ID). + Column(columns...). + Exec(ctx); err != nil { + return nil, r.conn.ProcessError(err) + } + + r.state.Caches.GTS.Report().Invalidate("ID", report.ID) + return report, nil +} + +func (r *reportDB) DeleteReportByID(ctx context.Context, id string) db.Error { + if _, err := r.conn. + NewDelete(). + TableExpr("? AS ?", bun.Ident("reports"), bun.Ident("report")). + Where("? = ?", bun.Ident("report.id"), id). + Exec(ctx); err != nil { + return r.conn.ProcessError(err) + } + + r.state.Caches.GTS.Report().Invalidate("ID", id) + return nil +} diff --git a/internal/db/bundb/report_test.go b/internal/db/bundb/report_test.go new file mode 100644 index 000000000..85bc4b36f --- /dev/null +++ b/internal/db/bundb/report_test.go @@ -0,0 +1,147 @@ +/* + GoToSocial + Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/testrig" +) + +type ReportTestSuite struct { + BunDBStandardTestSuite +} + +func (suite *ReportTestSuite) TestGetReportByID() { + report, err := suite.db.GetReportByID(context.Background(), suite.testReports["local_account_2_report_remote_account_1"].ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(report) + suite.NotNil(report.Account) + suite.NotNil(report.TargetAccount) + suite.Zero(report.ActionTakenAt) + suite.Nil(report.ActionTakenByAccount) + suite.Empty(report.ActionTakenByAccountID) + suite.NotEmpty(report.URI) +} + +func (suite *ReportTestSuite) TestGetReportByURI() { + report, err := suite.db.GetReportByID(context.Background(), suite.testReports["remote_account_1_report_local_account_2"].ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(report) + suite.NotNil(report.Account) + suite.NotNil(report.TargetAccount) + suite.NotZero(report.ActionTakenAt) + suite.NotNil(report.ActionTakenByAccount) + suite.NotEmpty(report.ActionTakenByAccountID) + suite.NotEmpty(report.URI) +} + +func (suite *ReportTestSuite) TestPutReport() { + ctx := context.Background() + + reportID := "01GP3ECY8QJD8DBJSS8B1CR0AX" + report := >smodel.Report{ + ID: reportID, + CreatedAt: testrig.TimeMustParse("2022-05-14T12:20:03+02:00"), + UpdatedAt: testrig.TimeMustParse("2022-05-14T12:20:03+02:00"), + URI: "http://localhost:8080/01GP3ECY8QJD8DBJSS8B1CR0AX", + AccountID: "01F8MH5NBDF2MV7CTC4Q5128HF", + TargetAccountID: "01F8MH5ZK5VRH73AKHQM6Y9VNX", + Comment: "another report", + StatusIDs: []string{"01FVW7JHQFSFK166WWKR8CBA6M"}, + Forwarded: testrig.TrueBool(), + } + + err := suite.db.PutReport(ctx, report) + suite.NoError(err) +} + +func (suite *ReportTestSuite) TestUpdateReport() { + ctx := context.Background() + + report := >smodel.Report{} + *report = *suite.testReports["local_account_2_report_remote_account_1"] + report.ActionTaken = "nothing" + report.ActionTakenByAccountID = suite.testAccounts["admin_account"].ID + report.ActionTakenAt = testrig.TimeMustParse("2022-05-14T12:20:03+02:00") + + if _, err := suite.db.UpdateReport(ctx, report, "action_taken", "action_taken_by_account_id", "action_taken_at"); err != nil { + suite.FailNow(err.Error()) + } + + dbReport, err := suite.db.GetReportByID(ctx, report.ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(dbReport) + suite.NotNil(dbReport.Account) + suite.NotNil(dbReport.TargetAccount) + suite.NotZero(dbReport.ActionTakenAt) + suite.NotNil(dbReport.ActionTakenByAccount) + suite.NotEmpty(dbReport.ActionTakenByAccountID) + suite.NotEmpty(dbReport.URI) +} + +func (suite *ReportTestSuite) TestUpdateReportAllColumns() { + ctx := context.Background() + + report := >smodel.Report{} + *report = *suite.testReports["local_account_2_report_remote_account_1"] + report.ActionTaken = "nothing" + report.ActionTakenByAccountID = suite.testAccounts["admin_account"].ID + report.ActionTakenAt = testrig.TimeMustParse("2022-05-14T12:20:03+02:00") + + if _, err := suite.db.UpdateReport(ctx, report); err != nil { + suite.FailNow(err.Error()) + } + + dbReport, err := suite.db.GetReportByID(ctx, report.ID) + if err != nil { + suite.FailNow(err.Error()) + } + suite.NotNil(dbReport) + suite.NotNil(dbReport.Account) + suite.NotNil(dbReport.TargetAccount) + suite.NotZero(dbReport.ActionTakenAt) + suite.NotNil(dbReport.ActionTakenByAccount) + suite.NotEmpty(dbReport.ActionTakenByAccountID) + suite.NotEmpty(dbReport.URI) +} + +func (suite *ReportTestSuite) TestDeleteReport() { + if err := suite.db.DeleteReportByID(context.Background(), suite.testReports["remote_account_1_report_local_account_2"].ID); err != nil { + suite.FailNow(err.Error()) + } + + report, err := suite.db.GetReportByID(context.Background(), suite.testReports["remote_account_1_report_local_account_2"].ID) + suite.ErrorIs(err, db.ErrNoEntries) + suite.Nil(report) +} + +func TestReportTestSuite(t *testing.T) { + suite.Run(t, new(ReportTestSuite)) +} diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index b52c06978..709105f72 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -67,6 +67,24 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat ) } +func (s *statusDB) GetStatuses(ctx context.Context, ids []string) ([]*gtsmodel.Status, db.Error) { + statuses := make([]*gtsmodel.Status, 0, len(ids)) + + for _, id := range ids { + // Attempt fetch from DB + status, err := s.GetStatusByID(ctx, id) + if err != nil { + log.Errorf("GetStatuses: error getting status %q: %v", id, err) + continue + } + + // Append status + statuses = append(statuses, status) + } + + return statuses, nil +} + func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) { return s.getStatus( ctx, diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go index bef8c7912..d86e0bcf9 100644 --- a/internal/db/bundb/status_test.go +++ b/internal/db/bundb/status_test.go @@ -50,6 +50,48 @@ func (suite *StatusTestSuite) TestGetStatusByID() { suite.True(*status.Likeable) } +func (suite *StatusTestSuite) TestGetStatusesByID() { + ids := []string{ + suite.testStatuses["local_account_1_status_1"].ID, + suite.testStatuses["local_account_2_status_3"].ID, + } + + statuses, err := suite.db.GetStatuses(context.Background(), ids) + if err != nil { + suite.FailNow(err.Error()) + } + + if len(statuses) != 2 { + suite.FailNow("expected 2 statuses in slice") + } + + status1 := statuses[0] + suite.NotNil(status1) + suite.NotNil(status1.Account) + suite.NotNil(status1.CreatedWithApplication) + suite.Nil(status1.BoostOf) + suite.Nil(status1.BoostOfAccount) + suite.Nil(status1.InReplyTo) + suite.Nil(status1.InReplyToAccount) + suite.True(*status1.Federated) + suite.True(*status1.Boostable) + suite.True(*status1.Replyable) + suite.True(*status1.Likeable) + + status2 := statuses[1] + suite.NotNil(status2) + suite.NotNil(status2.Account) + suite.NotNil(status2.CreatedWithApplication) + suite.Nil(status2.BoostOf) + suite.Nil(status2.BoostOfAccount) + suite.Nil(status2.InReplyTo) + suite.Nil(status2.InReplyToAccount) + suite.True(*status2.Federated) + suite.True(*status2.Boostable) + suite.False(*status2.Replyable) + suite.False(*status2.Likeable) +} + func (suite *StatusTestSuite) TestGetStatusByURI() { status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_2_status_3"].URI) if err != nil { diff --git a/internal/db/db.go b/internal/db/db.go index efe867e3e..aa1929da9 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -41,6 +41,7 @@ type DB interface { Mention Notification Relationship + Report Session Status Timeline diff --git a/internal/db/report.go b/internal/db/report.go new file mode 100644 index 000000000..216e10fdd --- /dev/null +++ b/internal/db/report.go @@ -0,0 +1,41 @@ +/* + GoToSocial + Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. +*/ + +package db + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// Report handles getting/creation/deletion/updating of user reports/flags. +type Report interface { + // GetReportByID gets one report by its db id + GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, Error) + // PutReport puts the given report in the database. + PutReport(ctx context.Context, report *gtsmodel.Report) Error + // UpdateReport updates one report by its db id. + // The given columns will be updated; if no columns are + // provided, then all columns will be updated. + // updated_at will also be updated, no need to pass this + // as a specific column. + UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, Error) + // DeleteReportByID deletes report with the given id. + DeleteReportByID(ctx context.Context, id string) Error +} diff --git a/internal/db/status.go b/internal/db/status.go index f854664c8..15d1362f5 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -29,6 +29,9 @@ type Status interface { // GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error) + // GetStatuses gets a slice of statuses corresponding to the given status IDs. + GetStatuses(ctx context.Context, ids []string) ([]*gtsmodel.Status, Error) + // GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error) |