summaryrefslogtreecommitdiff
path: root/internal/db
diff options
context:
space:
mode:
Diffstat (limited to 'internal/db')
-rw-r--r--internal/db/bundb/bundb.go5
-rw-r--r--internal/db/bundb/bundb_test.go2
-rw-r--r--internal/db/bundb/migrations/20230105171144_report_model.go66
-rw-r--r--internal/db/bundb/report.go138
-rw-r--r--internal/db/bundb/report_test.go147
-rw-r--r--internal/db/bundb/status.go18
-rw-r--r--internal/db/bundb/status_test.go42
-rw-r--r--internal/db/db.go1
-rw-r--r--internal/db/report.go41
-rw-r--r--internal/db/status.go3
10 files changed, 463 insertions, 0 deletions
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index e749484a8..1225b2bb0 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -83,6 +83,7 @@ type DBService struct {
db.Mention
db.Notification
db.Relationship
+ db.Report
db.Session
db.Status
db.Timeline
@@ -197,6 +198,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
conn: conn,
state: state,
},
+ Report: &reportDB{
+ conn: conn,
+ state: state,
+ },
Session: &sessionDB{
conn: conn,
},
diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go
index 45d2e70a7..e050c2b5d 100644
--- a/internal/db/bundb/bundb_test.go
+++ b/internal/db/bundb/bundb_test.go
@@ -42,6 +42,7 @@ type BunDBStandardTestSuite struct {
testMentions map[string]*gtsmodel.Mention
testFollows map[string]*gtsmodel.Follow
testEmojis map[string]*gtsmodel.Emoji
+ testReports map[string]*gtsmodel.Report
}
func (suite *BunDBStandardTestSuite) SetupSuite() {
@@ -56,6 +57,7 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {
suite.testMentions = testrig.NewTestMentions()
suite.testFollows = testrig.NewTestFollows()
suite.testEmojis = testrig.NewTestEmojis()
+ suite.testReports = testrig.NewTestReports()
}
func (suite *BunDBStandardTestSuite) SetupTest() {
diff --git a/internal/db/bundb/migrations/20230105171144_report_model.go b/internal/db/bundb/migrations/20230105171144_report_model.go
new file mode 100644
index 000000000..b175e2995
--- /dev/null
+++ b/internal/db/bundb/migrations/20230105171144_report_model.go
@@ -0,0 +1,66 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package migrations
+
+import (
+ "context"
+
+ gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ if _, err := tx.NewCreateTable().Model(&gtsmodel.Report{}).IfNotExists().Exec(ctx); err != nil {
+ return err
+ }
+
+ if _, err := tx.
+ NewCreateIndex().
+ Model(&gtsmodel.Report{}).
+ Index("report_account_id_idx").
+ Column("account_id").
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ if _, err := tx.
+ NewCreateIndex().
+ Model(&gtsmodel.Report{}).
+ Index("report_target_account_id_idx").
+ Column("target_account_id").
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ return nil
+ })
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/bundb/report.go b/internal/db/bundb/report.go
new file mode 100644
index 000000000..8cc1d8de9
--- /dev/null
+++ b/internal/db/bundb/report.go
@@ -0,0 +1,138 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
+ "github.com/uptrace/bun"
+)
+
+type reportDB struct {
+ conn *DBConn
+ state *state.State
+}
+
+func (r *reportDB) newReportQ(report interface{}) *bun.SelectQuery {
+ return r.conn.NewSelect().Model(report)
+}
+
+func (r *reportDB) GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, db.Error) {
+ return r.getReport(
+ ctx,
+ "ID",
+ func(report *gtsmodel.Report) error {
+ return r.newReportQ(report).Where("? = ?", bun.Ident("report.id"), id).Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, db.Error) {
+ // Fetch report from database cache with loader callback
+ report, err := r.state.Caches.GTS.Report().Load(lookup, func() (*gtsmodel.Report, error) {
+ var report gtsmodel.Report
+
+ // Not cached! Perform database query
+ if err := dbQuery(&report); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+
+ return &report, nil
+ }, keyParts...)
+ if err != nil {
+ // error already processed
+ return nil, err
+ }
+
+ // Set the report author account
+ report.Account, err = r.state.DB.GetAccountByID(ctx, report.AccountID)
+ if err != nil {
+ return nil, fmt.Errorf("error getting report account: %w", err)
+ }
+
+ // Set the report target account
+ report.TargetAccount, err = r.state.DB.GetAccountByID(ctx, report.TargetAccountID)
+ if err != nil {
+ return nil, fmt.Errorf("error getting report target account: %w", err)
+ }
+
+ if len(report.StatusIDs) > 0 {
+ // Fetch reported statuses
+ report.Statuses, err = r.state.DB.GetStatuses(ctx, report.StatusIDs)
+ if err != nil {
+ return nil, fmt.Errorf("error getting status mentions: %w", err)
+ }
+ }
+
+ if report.ActionTakenByAccountID != "" {
+ // Set the report action taken by account
+ report.ActionTakenByAccount, err = r.state.DB.GetAccountByID(ctx, report.ActionTakenByAccountID)
+ if err != nil {
+ return nil, fmt.Errorf("error getting report action taken by account: %w", err)
+ }
+ }
+
+ return report, nil
+}
+
+func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) db.Error {
+ return r.state.Caches.GTS.Report().Store(report, func() error {
+ _, err := r.conn.NewInsert().Model(report).Exec(ctx)
+ return r.conn.ProcessError(err)
+ })
+}
+
+func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, db.Error) {
+ // Update the report's last-updated
+ report.UpdatedAt = time.Now()
+ if len(columns) != 0 {
+ columns = append(columns, "updated_at")
+ }
+
+ if _, err := r.conn.
+ NewUpdate().
+ Model(report).
+ Where("? = ?", bun.Ident("report.id"), report.ID).
+ Column(columns...).
+ Exec(ctx); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+
+ r.state.Caches.GTS.Report().Invalidate("ID", report.ID)
+ return report, nil
+}
+
+func (r *reportDB) DeleteReportByID(ctx context.Context, id string) db.Error {
+ if _, err := r.conn.
+ NewDelete().
+ TableExpr("? AS ?", bun.Ident("reports"), bun.Ident("report")).
+ Where("? = ?", bun.Ident("report.id"), id).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ r.state.Caches.GTS.Report().Invalidate("ID", id)
+ return nil
+}
diff --git a/internal/db/bundb/report_test.go b/internal/db/bundb/report_test.go
new file mode 100644
index 000000000..85bc4b36f
--- /dev/null
+++ b/internal/db/bundb/report_test.go
@@ -0,0 +1,147 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package bundb_test
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type ReportTestSuite struct {
+ BunDBStandardTestSuite
+}
+
+func (suite *ReportTestSuite) TestGetReportByID() {
+ report, err := suite.db.GetReportByID(context.Background(), suite.testReports["local_account_2_report_remote_account_1"].ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(report)
+ suite.NotNil(report.Account)
+ suite.NotNil(report.TargetAccount)
+ suite.Zero(report.ActionTakenAt)
+ suite.Nil(report.ActionTakenByAccount)
+ suite.Empty(report.ActionTakenByAccountID)
+ suite.NotEmpty(report.URI)
+}
+
+func (suite *ReportTestSuite) TestGetReportByURI() {
+ report, err := suite.db.GetReportByID(context.Background(), suite.testReports["remote_account_1_report_local_account_2"].ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(report)
+ suite.NotNil(report.Account)
+ suite.NotNil(report.TargetAccount)
+ suite.NotZero(report.ActionTakenAt)
+ suite.NotNil(report.ActionTakenByAccount)
+ suite.NotEmpty(report.ActionTakenByAccountID)
+ suite.NotEmpty(report.URI)
+}
+
+func (suite *ReportTestSuite) TestPutReport() {
+ ctx := context.Background()
+
+ reportID := "01GP3ECY8QJD8DBJSS8B1CR0AX"
+ report := &gtsmodel.Report{
+ ID: reportID,
+ CreatedAt: testrig.TimeMustParse("2022-05-14T12:20:03+02:00"),
+ UpdatedAt: testrig.TimeMustParse("2022-05-14T12:20:03+02:00"),
+ URI: "http://localhost:8080/01GP3ECY8QJD8DBJSS8B1CR0AX",
+ AccountID: "01F8MH5NBDF2MV7CTC4Q5128HF",
+ TargetAccountID: "01F8MH5ZK5VRH73AKHQM6Y9VNX",
+ Comment: "another report",
+ StatusIDs: []string{"01FVW7JHQFSFK166WWKR8CBA6M"},
+ Forwarded: testrig.TrueBool(),
+ }
+
+ err := suite.db.PutReport(ctx, report)
+ suite.NoError(err)
+}
+
+func (suite *ReportTestSuite) TestUpdateReport() {
+ ctx := context.Background()
+
+ report := &gtsmodel.Report{}
+ *report = *suite.testReports["local_account_2_report_remote_account_1"]
+ report.ActionTaken = "nothing"
+ report.ActionTakenByAccountID = suite.testAccounts["admin_account"].ID
+ report.ActionTakenAt = testrig.TimeMustParse("2022-05-14T12:20:03+02:00")
+
+ if _, err := suite.db.UpdateReport(ctx, report, "action_taken", "action_taken_by_account_id", "action_taken_at"); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ dbReport, err := suite.db.GetReportByID(ctx, report.ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(dbReport)
+ suite.NotNil(dbReport.Account)
+ suite.NotNil(dbReport.TargetAccount)
+ suite.NotZero(dbReport.ActionTakenAt)
+ suite.NotNil(dbReport.ActionTakenByAccount)
+ suite.NotEmpty(dbReport.ActionTakenByAccountID)
+ suite.NotEmpty(dbReport.URI)
+}
+
+func (suite *ReportTestSuite) TestUpdateReportAllColumns() {
+ ctx := context.Background()
+
+ report := &gtsmodel.Report{}
+ *report = *suite.testReports["local_account_2_report_remote_account_1"]
+ report.ActionTaken = "nothing"
+ report.ActionTakenByAccountID = suite.testAccounts["admin_account"].ID
+ report.ActionTakenAt = testrig.TimeMustParse("2022-05-14T12:20:03+02:00")
+
+ if _, err := suite.db.UpdateReport(ctx, report); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ dbReport, err := suite.db.GetReportByID(ctx, report.ID)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+ suite.NotNil(dbReport)
+ suite.NotNil(dbReport.Account)
+ suite.NotNil(dbReport.TargetAccount)
+ suite.NotZero(dbReport.ActionTakenAt)
+ suite.NotNil(dbReport.ActionTakenByAccount)
+ suite.NotEmpty(dbReport.ActionTakenByAccountID)
+ suite.NotEmpty(dbReport.URI)
+}
+
+func (suite *ReportTestSuite) TestDeleteReport() {
+ if err := suite.db.DeleteReportByID(context.Background(), suite.testReports["remote_account_1_report_local_account_2"].ID); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ report, err := suite.db.GetReportByID(context.Background(), suite.testReports["remote_account_1_report_local_account_2"].ID)
+ suite.ErrorIs(err, db.ErrNoEntries)
+ suite.Nil(report)
+}
+
+func TestReportTestSuite(t *testing.T) {
+ suite.Run(t, new(ReportTestSuite))
+}
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index b52c06978..709105f72 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -67,6 +67,24 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat
)
}
+func (s *statusDB) GetStatuses(ctx context.Context, ids []string) ([]*gtsmodel.Status, db.Error) {
+ statuses := make([]*gtsmodel.Status, 0, len(ids))
+
+ for _, id := range ids {
+ // Attempt fetch from DB
+ status, err := s.GetStatusByID(ctx, id)
+ if err != nil {
+ log.Errorf("GetStatuses: error getting status %q: %v", id, err)
+ continue
+ }
+
+ // Append status
+ statuses = append(statuses, status)
+ }
+
+ return statuses, nil
+}
+
func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {
return s.getStatus(
ctx,
diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go
index bef8c7912..d86e0bcf9 100644
--- a/internal/db/bundb/status_test.go
+++ b/internal/db/bundb/status_test.go
@@ -50,6 +50,48 @@ func (suite *StatusTestSuite) TestGetStatusByID() {
suite.True(*status.Likeable)
}
+func (suite *StatusTestSuite) TestGetStatusesByID() {
+ ids := []string{
+ suite.testStatuses["local_account_1_status_1"].ID,
+ suite.testStatuses["local_account_2_status_3"].ID,
+ }
+
+ statuses, err := suite.db.GetStatuses(context.Background(), ids)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ if len(statuses) != 2 {
+ suite.FailNow("expected 2 statuses in slice")
+ }
+
+ status1 := statuses[0]
+ suite.NotNil(status1)
+ suite.NotNil(status1.Account)
+ suite.NotNil(status1.CreatedWithApplication)
+ suite.Nil(status1.BoostOf)
+ suite.Nil(status1.BoostOfAccount)
+ suite.Nil(status1.InReplyTo)
+ suite.Nil(status1.InReplyToAccount)
+ suite.True(*status1.Federated)
+ suite.True(*status1.Boostable)
+ suite.True(*status1.Replyable)
+ suite.True(*status1.Likeable)
+
+ status2 := statuses[1]
+ suite.NotNil(status2)
+ suite.NotNil(status2.Account)
+ suite.NotNil(status2.CreatedWithApplication)
+ suite.Nil(status2.BoostOf)
+ suite.Nil(status2.BoostOfAccount)
+ suite.Nil(status2.InReplyTo)
+ suite.Nil(status2.InReplyToAccount)
+ suite.True(*status2.Federated)
+ suite.True(*status2.Boostable)
+ suite.False(*status2.Replyable)
+ suite.False(*status2.Likeable)
+}
+
func (suite *StatusTestSuite) TestGetStatusByURI() {
status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_2_status_3"].URI)
if err != nil {
diff --git a/internal/db/db.go b/internal/db/db.go
index efe867e3e..aa1929da9 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -41,6 +41,7 @@ type DB interface {
Mention
Notification
Relationship
+ Report
Session
Status
Timeline
diff --git a/internal/db/report.go b/internal/db/report.go
new file mode 100644
index 000000000..216e10fdd
--- /dev/null
+++ b/internal/db/report.go
@@ -0,0 +1,41 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+*/
+
+package db
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+// Report handles getting/creation/deletion/updating of user reports/flags.
+type Report interface {
+ // GetReportByID gets one report by its db id
+ GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, Error)
+ // PutReport puts the given report in the database.
+ PutReport(ctx context.Context, report *gtsmodel.Report) Error
+ // UpdateReport updates one report by its db id.
+ // The given columns will be updated; if no columns are
+ // provided, then all columns will be updated.
+ // updated_at will also be updated, no need to pass this
+ // as a specific column.
+ UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, Error)
+ // DeleteReportByID deletes report with the given id.
+ DeleteReportByID(ctx context.Context, id string) Error
+}
diff --git a/internal/db/status.go b/internal/db/status.go
index f854664c8..15d1362f5 100644
--- a/internal/db/status.go
+++ b/internal/db/status.go
@@ -29,6 +29,9 @@ type Status interface {
// GetStatusByID returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, Error)
+ // GetStatuses gets a slice of statuses corresponding to the given status IDs.
+ GetStatuses(ctx context.Context, ids []string) ([]*gtsmodel.Status, Error)
+
// GetStatusByURI returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, Error)