diff options
Diffstat (limited to 'internal/db/bundb')
| -rw-r--r-- | internal/db/bundb/bundb.go | 5 | ||||
| -rw-r--r-- | internal/db/bundb/bundb_test.go | 2 | ||||
| -rw-r--r-- | internal/db/bundb/migrations/20230105171144_report_model.go | 66 | ||||
| -rw-r--r-- | internal/db/bundb/report.go | 138 | ||||
| -rw-r--r-- | internal/db/bundb/report_test.go | 147 | ||||
| -rw-r--r-- | internal/db/bundb/status.go | 18 | ||||
| -rw-r--r-- | internal/db/bundb/status_test.go | 42 | 
7 files changed, 418 insertions, 0 deletions
| diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index e749484a8..1225b2bb0 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -83,6 +83,7 @@ type DBService struct {  	db.Mention  	db.Notification  	db.Relationship +	db.Report  	db.Session  	db.Status  	db.Timeline @@ -197,6 +198,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {  			conn:  conn,  			state: state,  		}, +		Report: &reportDB{ +			conn:  conn, +			state: state, +		},  		Session: &sessionDB{  			conn: conn,  		}, diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go index 45d2e70a7..e050c2b5d 100644 --- a/internal/db/bundb/bundb_test.go +++ b/internal/db/bundb/bundb_test.go @@ -42,6 +42,7 @@ type BunDBStandardTestSuite struct {  	testMentions     map[string]*gtsmodel.Mention  	testFollows      map[string]*gtsmodel.Follow  	testEmojis       map[string]*gtsmodel.Emoji +	testReports      map[string]*gtsmodel.Report  }  func (suite *BunDBStandardTestSuite) SetupSuite() { @@ -56,6 +57,7 @@ func (suite *BunDBStandardTestSuite) SetupSuite() {  	suite.testMentions = testrig.NewTestMentions()  	suite.testFollows = testrig.NewTestFollows()  	suite.testEmojis = testrig.NewTestEmojis() +	suite.testReports = testrig.NewTestReports()  }  func (suite *BunDBStandardTestSuite) SetupTest() { diff --git a/internal/db/bundb/migrations/20230105171144_report_model.go b/internal/db/bundb/migrations/20230105171144_report_model.go new file mode 100644 index 000000000..b175e2995 --- /dev/null +++ b/internal/db/bundb/migrations/20230105171144_report_model.go @@ -0,0 +1,66 @@ +/* +   GoToSocial +   Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package migrations + +import ( +	"context" + +	gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/uptrace/bun" +) + +func init() { +	up := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			if _, err := tx.NewCreateTable().Model(>smodel.Report{}).IfNotExists().Exec(ctx); err != nil { +				return err +			} + +			if _, err := tx. +				NewCreateIndex(). +				Model(>smodel.Report{}). +				Index("report_account_id_idx"). +				Column("account_id"). +				Exec(ctx); err != nil { +				return err +			} + +			if _, err := tx. +				NewCreateIndex(). +				Model(>smodel.Report{}). +				Index("report_target_account_id_idx"). +				Column("target_account_id"). +				Exec(ctx); err != nil { +				return err +			} + +			return nil +		}) +	} + +	down := func(ctx context.Context, db *bun.DB) error { +		return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { +			return nil +		}) +	} + +	if err := Migrations.Register(up, down); err != nil { +		panic(err) +	} +} diff --git a/internal/db/bundb/report.go b/internal/db/bundb/report.go new file mode 100644 index 000000000..8cc1d8de9 --- /dev/null +++ b/internal/db/bundb/report.go @@ -0,0 +1,138 @@ +/* +   GoToSocial +   Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb + +import ( +	"context" +	"fmt" +	"time" + +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/uptrace/bun" +) + +type reportDB struct { +	conn  *DBConn +	state *state.State +} + +func (r *reportDB) newReportQ(report interface{}) *bun.SelectQuery { +	return r.conn.NewSelect().Model(report) +} + +func (r *reportDB) GetReportByID(ctx context.Context, id string) (*gtsmodel.Report, db.Error) { +	return r.getReport( +		ctx, +		"ID", +		func(report *gtsmodel.Report) error { +			return r.newReportQ(report).Where("? = ?", bun.Ident("report.id"), id).Scan(ctx) +		}, +		id, +	) +} + +func (r *reportDB) getReport(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Report) error, keyParts ...any) (*gtsmodel.Report, db.Error) { +	// Fetch report from database cache with loader callback +	report, err := r.state.Caches.GTS.Report().Load(lookup, func() (*gtsmodel.Report, error) { +		var report gtsmodel.Report + +		// Not cached! Perform database query +		if err := dbQuery(&report); err != nil { +			return nil, r.conn.ProcessError(err) +		} + +		return &report, nil +	}, keyParts...) +	if err != nil { +		// error already processed +		return nil, err +	} + +	// Set the report author account +	report.Account, err = r.state.DB.GetAccountByID(ctx, report.AccountID) +	if err != nil { +		return nil, fmt.Errorf("error getting report account: %w", err) +	} + +	// Set the report target account +	report.TargetAccount, err = r.state.DB.GetAccountByID(ctx, report.TargetAccountID) +	if err != nil { +		return nil, fmt.Errorf("error getting report target account: %w", err) +	} + +	if len(report.StatusIDs) > 0 { +		// Fetch reported statuses +		report.Statuses, err = r.state.DB.GetStatuses(ctx, report.StatusIDs) +		if err != nil { +			return nil, fmt.Errorf("error getting status mentions: %w", err) +		} +	} + +	if report.ActionTakenByAccountID != "" { +		// Set the report action taken by account +		report.ActionTakenByAccount, err = r.state.DB.GetAccountByID(ctx, report.ActionTakenByAccountID) +		if err != nil { +			return nil, fmt.Errorf("error getting report action taken by account: %w", err) +		} +	} + +	return report, nil +} + +func (r *reportDB) PutReport(ctx context.Context, report *gtsmodel.Report) db.Error { +	return r.state.Caches.GTS.Report().Store(report, func() error { +		_, err := r.conn.NewInsert().Model(report).Exec(ctx) +		return r.conn.ProcessError(err) +	}) +} + +func (r *reportDB) UpdateReport(ctx context.Context, report *gtsmodel.Report, columns ...string) (*gtsmodel.Report, db.Error) { +	// Update the report's last-updated +	report.UpdatedAt = time.Now() +	if len(columns) != 0 { +		columns = append(columns, "updated_at") +	} + +	if _, err := r.conn. +		NewUpdate(). +		Model(report). +		Where("? = ?", bun.Ident("report.id"), report.ID). +		Column(columns...). +		Exec(ctx); err != nil { +		return nil, r.conn.ProcessError(err) +	} + +	r.state.Caches.GTS.Report().Invalidate("ID", report.ID) +	return report, nil +} + +func (r *reportDB) DeleteReportByID(ctx context.Context, id string) db.Error { +	if _, err := r.conn. +		NewDelete(). +		TableExpr("? AS ?", bun.Ident("reports"), bun.Ident("report")). +		Where("? = ?", bun.Ident("report.id"), id). +		Exec(ctx); err != nil { +		return r.conn.ProcessError(err) +	} + +	r.state.Caches.GTS.Report().Invalidate("ID", id) +	return nil +} diff --git a/internal/db/bundb/report_test.go b/internal/db/bundb/report_test.go new file mode 100644 index 000000000..85bc4b36f --- /dev/null +++ b/internal/db/bundb/report_test.go @@ -0,0 +1,147 @@ +/* +   GoToSocial +   Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org + +   This program is free software: you can redistribute it and/or modify +   it under the terms of the GNU Affero General Public License as published by +   the Free Software Foundation, either version 3 of the License, or +   (at your option) any later version. + +   This program is distributed in the hope that it will be useful, +   but WITHOUT ANY WARRANTY; without even the implied warranty of +   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +   GNU Affero General Public License for more details. + +   You should have received a copy of the GNU Affero General Public License +   along with this program.  If not, see <http://www.gnu.org/licenses/>. +*/ + +package bundb_test + +import ( +	"context" +	"testing" + +	"github.com/stretchr/testify/suite" +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/testrig" +) + +type ReportTestSuite struct { +	BunDBStandardTestSuite +} + +func (suite *ReportTestSuite) TestGetReportByID() { +	report, err := suite.db.GetReportByID(context.Background(), suite.testReports["local_account_2_report_remote_account_1"].ID) +	if err != nil { +		suite.FailNow(err.Error()) +	} +	suite.NotNil(report) +	suite.NotNil(report.Account) +	suite.NotNil(report.TargetAccount) +	suite.Zero(report.ActionTakenAt) +	suite.Nil(report.ActionTakenByAccount) +	suite.Empty(report.ActionTakenByAccountID) +	suite.NotEmpty(report.URI) +} + +func (suite *ReportTestSuite) TestGetReportByURI() { +	report, err := suite.db.GetReportByID(context.Background(), suite.testReports["remote_account_1_report_local_account_2"].ID) +	if err != nil { +		suite.FailNow(err.Error()) +	} +	suite.NotNil(report) +	suite.NotNil(report.Account) +	suite.NotNil(report.TargetAccount) +	suite.NotZero(report.ActionTakenAt) +	suite.NotNil(report.ActionTakenByAccount) +	suite.NotEmpty(report.ActionTakenByAccountID) +	suite.NotEmpty(report.URI) +} + +func (suite *ReportTestSuite) TestPutReport() { +	ctx := context.Background() + +	reportID := "01GP3ECY8QJD8DBJSS8B1CR0AX" +	report := >smodel.Report{ +		ID:              reportID, +		CreatedAt:       testrig.TimeMustParse("2022-05-14T12:20:03+02:00"), +		UpdatedAt:       testrig.TimeMustParse("2022-05-14T12:20:03+02:00"), +		URI:             "http://localhost:8080/01GP3ECY8QJD8DBJSS8B1CR0AX", +		AccountID:       "01F8MH5NBDF2MV7CTC4Q5128HF", +		TargetAccountID: "01F8MH5ZK5VRH73AKHQM6Y9VNX", +		Comment:         "another report", +		StatusIDs:       []string{"01FVW7JHQFSFK166WWKR8CBA6M"}, +		Forwarded:       testrig.TrueBool(), +	} + +	err := suite.db.PutReport(ctx, report) +	suite.NoError(err) +} + +func (suite *ReportTestSuite) TestUpdateReport() { +	ctx := context.Background() + +	report := >smodel.Report{} +	*report = *suite.testReports["local_account_2_report_remote_account_1"] +	report.ActionTaken = "nothing" +	report.ActionTakenByAccountID = suite.testAccounts["admin_account"].ID +	report.ActionTakenAt = testrig.TimeMustParse("2022-05-14T12:20:03+02:00") + +	if _, err := suite.db.UpdateReport(ctx, report, "action_taken", "action_taken_by_account_id", "action_taken_at"); err != nil { +		suite.FailNow(err.Error()) +	} + +	dbReport, err := suite.db.GetReportByID(ctx, report.ID) +	if err != nil { +		suite.FailNow(err.Error()) +	} +	suite.NotNil(dbReport) +	suite.NotNil(dbReport.Account) +	suite.NotNil(dbReport.TargetAccount) +	suite.NotZero(dbReport.ActionTakenAt) +	suite.NotNil(dbReport.ActionTakenByAccount) +	suite.NotEmpty(dbReport.ActionTakenByAccountID) +	suite.NotEmpty(dbReport.URI) +} + +func (suite *ReportTestSuite) TestUpdateReportAllColumns() { +	ctx := context.Background() + +	report := >smodel.Report{} +	*report = *suite.testReports["local_account_2_report_remote_account_1"] +	report.ActionTaken = "nothing" +	report.ActionTakenByAccountID = suite.testAccounts["admin_account"].ID +	report.ActionTakenAt = testrig.TimeMustParse("2022-05-14T12:20:03+02:00") + +	if _, err := suite.db.UpdateReport(ctx, report); err != nil { +		suite.FailNow(err.Error()) +	} + +	dbReport, err := suite.db.GetReportByID(ctx, report.ID) +	if err != nil { +		suite.FailNow(err.Error()) +	} +	suite.NotNil(dbReport) +	suite.NotNil(dbReport.Account) +	suite.NotNil(dbReport.TargetAccount) +	suite.NotZero(dbReport.ActionTakenAt) +	suite.NotNil(dbReport.ActionTakenByAccount) +	suite.NotEmpty(dbReport.ActionTakenByAccountID) +	suite.NotEmpty(dbReport.URI) +} + +func (suite *ReportTestSuite) TestDeleteReport() { +	if err := suite.db.DeleteReportByID(context.Background(), suite.testReports["remote_account_1_report_local_account_2"].ID); err != nil { +		suite.FailNow(err.Error()) +	} + +	report, err := suite.db.GetReportByID(context.Background(), suite.testReports["remote_account_1_report_local_account_2"].ID) +	suite.ErrorIs(err, db.ErrNoEntries) +	suite.Nil(report) +} + +func TestReportTestSuite(t *testing.T) { +	suite.Run(t, new(ReportTestSuite)) +} diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index b52c06978..709105f72 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -67,6 +67,24 @@ func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Stat  	)  } +func (s *statusDB) GetStatuses(ctx context.Context, ids []string) ([]*gtsmodel.Status, db.Error) { +	statuses := make([]*gtsmodel.Status, 0, len(ids)) + +	for _, id := range ids { +		// Attempt fetch from DB +		status, err := s.GetStatusByID(ctx, id) +		if err != nil { +			log.Errorf("GetStatuses: error getting status %q: %v", id, err) +			continue +		} + +		// Append status +		statuses = append(statuses, status) +	} + +	return statuses, nil +} +  func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.Status, db.Error) {  	return s.getStatus(  		ctx, diff --git a/internal/db/bundb/status_test.go b/internal/db/bundb/status_test.go index bef8c7912..d86e0bcf9 100644 --- a/internal/db/bundb/status_test.go +++ b/internal/db/bundb/status_test.go @@ -50,6 +50,48 @@ func (suite *StatusTestSuite) TestGetStatusByID() {  	suite.True(*status.Likeable)  } +func (suite *StatusTestSuite) TestGetStatusesByID() { +	ids := []string{ +		suite.testStatuses["local_account_1_status_1"].ID, +		suite.testStatuses["local_account_2_status_3"].ID, +	} + +	statuses, err := suite.db.GetStatuses(context.Background(), ids) +	if err != nil { +		suite.FailNow(err.Error()) +	} + +	if len(statuses) != 2 { +		suite.FailNow("expected 2 statuses in slice") +	} + +	status1 := statuses[0] +	suite.NotNil(status1) +	suite.NotNil(status1.Account) +	suite.NotNil(status1.CreatedWithApplication) +	suite.Nil(status1.BoostOf) +	suite.Nil(status1.BoostOfAccount) +	suite.Nil(status1.InReplyTo) +	suite.Nil(status1.InReplyToAccount) +	suite.True(*status1.Federated) +	suite.True(*status1.Boostable) +	suite.True(*status1.Replyable) +	suite.True(*status1.Likeable) + +	status2 := statuses[1] +	suite.NotNil(status2) +	suite.NotNil(status2.Account) +	suite.NotNil(status2.CreatedWithApplication) +	suite.Nil(status2.BoostOf) +	suite.Nil(status2.BoostOfAccount) +	suite.Nil(status2.InReplyTo) +	suite.Nil(status2.InReplyToAccount) +	suite.True(*status2.Federated) +	suite.True(*status2.Boostable) +	suite.False(*status2.Replyable) +	suite.False(*status2.Likeable) +} +  func (suite *StatusTestSuite) TestGetStatusByURI() {  	status, err := suite.db.GetStatusByURI(context.Background(), suite.testStatuses["local_account_2_status_3"].URI)  	if err != nil { | 
