diff options
Diffstat (limited to 'internal/processing')
21 files changed, 1123 insertions, 189 deletions
diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index e89ebf13f..bf1ea2f44 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -387,12 +387,12 @@ statusLoop:  func (p *Processor) deleteAccountNotifications(ctx context.Context, account *gtsmodel.Account) error {  	// Delete all notifications of all types targeting given account.  	if err := p.state.DB.DeleteNotifications(ctx, nil, account.ID, ""); err != nil && !errors.Is(err, db.ErrNoEntries) { -		return err +		return gtserror.Newf("error deleting notifications targeting account: %w", err)  	}  	// Delete all notifications of all types originating from given account.  	if err := p.state.DB.DeleteNotifications(ctx, nil, "", account.ID); err != nil && !errors.Is(err, db.ErrNoEntries) { -		return err +		return gtserror.Newf("error deleting notifications by account: %w", err)  	}  	return nil @@ -402,29 +402,35 @@ func (p *Processor) deleteAccountPeripheral(ctx context.Context, account *gtsmod  	// Delete all bookmarks owned by given account.  	if err := p.state.DB.DeleteStatusBookmarks(ctx, account.ID, ""); // nocollapse  	err != nil && !errors.Is(err, db.ErrNoEntries) { -		return err +		return gtserror.Newf("error deleting bookmarks by account: %w", err)  	}  	// Delete all bookmarks targeting given account.  	if err := p.state.DB.DeleteStatusBookmarks(ctx, "", account.ID); // nocollapse  	err != nil && !errors.Is(err, db.ErrNoEntries) { -		return err +		return gtserror.Newf("error deleting bookmarks targeting account: %w", err)  	}  	// Delete all faves owned by given account.  	if err := p.state.DB.DeleteStatusFaves(ctx, account.ID, ""); // nocollapse  	err != nil && !errors.Is(err, db.ErrNoEntries) { -		return err +		return gtserror.Newf("error deleting faves by account: %w", err)  	}  	// Delete all faves targeting given account.  	if err := p.state.DB.DeleteStatusFaves(ctx, "", account.ID); // nocollapse  	err != nil && !errors.Is(err, db.ErrNoEntries) { -		return err +		return gtserror.Newf("error deleting faves targeting account: %w", err)  	}  	// TODO: add status mutes here when they're implemented. +	// Delete all poll votes owned by given account. +	if err := p.state.DB.DeletePollVotesByAccountID(ctx, account.ID); // nocollapse +	err != nil && !errors.Is(err, db.ErrNoEntries) { +		return gtserror.Newf("error deleting poll votes by account: %w", err) +	} +  	return nil  } diff --git a/internal/processing/common/account.go.go b/internal/processing/common/account.go.go index 06e87fa0e..425f23483 100644 --- a/internal/processing/common/account.go.go +++ b/internal/processing/common/account.go.go @@ -47,8 +47,11 @@ func (p *Processor) GetTargetAccountBy(  	if target == nil {  		// DB loader could not find account in database. -		err := errors.New("target account not found") -		return nil, false, gtserror.NewErrorNotFound(err) +		const text = "target account not found" +		return nil, false, gtserror.NewErrorNotFound( +			errors.New(text), +			text, +		)  	}  	// Check whether target account is visible to requesting account. @@ -106,8 +109,11 @@ func (p *Processor) GetVisibleTargetAccount(  	if !visible {  		// Pretend account doesn't exist if not visible. -		err := errors.New("target account not found") -		return nil, gtserror.NewErrorNotFound(err) +		const text = "target account not found" +		return nil, gtserror.NewErrorNotFound( +			errors.New(text), +			text, +		)  	}  	return target, nil diff --git a/internal/processing/common/status.go b/internal/processing/common/status.go index fb480ec7e..233c1c867 100644 --- a/internal/processing/common/status.go +++ b/internal/processing/common/status.go @@ -47,8 +47,11 @@ func (p *Processor) GetTargetStatusBy(  	if target == nil {  		// DB loader could not find status in database. -		err := errors.New("target status not found") -		return nil, false, gtserror.NewErrorNotFound(err) +		const text = "target status not found" +		return nil, false, gtserror.NewErrorNotFound( +			errors.New(text), +			text, +		)  	}  	// Check whether target status is visible to requesting account. @@ -106,8 +109,11 @@ func (p *Processor) GetVisibleTargetStatus(  	if !visible {  		// Target should not be seen by requester. -		err := errors.New("target status not found") -		return nil, gtserror.NewErrorNotFound(err) +		const text = "target status not found" +		return nil, gtserror.NewErrorNotFound( +			errors.New(text), +			text, +		)  	}  	return target, nil diff --git a/internal/processing/fedi/status.go b/internal/processing/fedi/status.go index 9b8c7dd95..c8534eb5e 100644 --- a/internal/processing/fedi/status.go +++ b/internal/processing/fedi/status.go @@ -56,12 +56,12 @@ func (p *Processor) StatusGet(ctx context.Context, requestedUsername string, req  		return nil, gtserror.NewErrorNotFound(err)  	} -	asStatus, err := p.converter.StatusToAS(ctx, status) +	statusable, err := p.converter.StatusToAS(ctx, status)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} -	data, err := ap.Serialize(asStatus) +	data, err := ap.Serialize(statusable)  	if err != nil {  		return nil, gtserror.NewErrorInternalError(err)  	} diff --git a/internal/processing/polls/expiry.go b/internal/processing/polls/expiry.go new file mode 100644 index 000000000..59d0f17fe --- /dev/null +++ b/internal/processing/polls/expiry.go @@ -0,0 +1,126 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package polls + +import ( +	"context" +	"time" + +	"github.com/superseriousbusiness/gotosocial/internal/ap" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/log" +	"github.com/superseriousbusiness/gotosocial/internal/messages" +) + +func (p *Processor) ScheduleAll(ctx context.Context) error { +	// Fetch all open polls from the database (barebones models are enough). +	polls, err := p.state.DB.GetOpenPolls(gtscontext.SetBarebones(ctx)) +	if err != nil { +		return gtserror.Newf("error getting open polls from db: %w", err) +	} + +	var errs gtserror.MultiError + +	for _, poll := range polls { +		// Schedule each of the polls and catch any errors. +		if err := p.ScheduleExpiry(ctx, poll); err != nil { +			errs.Append(err) +		} +	} + +	return errs.Combine() +} + +func (p *Processor) ScheduleExpiry(ctx context.Context, poll *gtsmodel.Poll) error { +	// Ensure has a valid expiry. +	if !poll.ClosedAt.IsZero() { +		return gtserror.Newf("poll %s already expired", poll.ID) +	} + +	// Add the given poll to the scheduler. +	ok := p.state.Workers.Scheduler.AddOnce( +		poll.ID, +		poll.ExpiresAt, +		p.onExpiry(poll.ID), +	) + +	if !ok { +		// Failed to add the poll to the scheduler, either it was +		// starting / stopping or there already exists a task for poll. +		return gtserror.Newf("failed adding poll %s to scheduler", poll.ID) +	} + +	atStr := poll.ExpiresAt.Local().Format("Jan _2 2006 15:04:05") +	log.Infof(ctx, "scheduled poll expiry for %s at '%s'", poll.ID, atStr) +	return nil +} + +// onExpiry returns a callback function to be used by the scheduler when the given poll expires. +func (p *Processor) onExpiry(pollID string) func(context.Context, time.Time) { +	return func(ctx context.Context, now time.Time) { +		// Get the latest version of poll from database. +		poll, err := p.state.DB.GetPollByID(ctx, pollID) +		if err != nil { +			log.Errorf(ctx, "error getting poll %s from db: %v", pollID, err) +			return +		} + +		if !poll.ClosedAt.IsZero() { +			// Expiry handler has already been run for this poll. +			log.Errorf(ctx, "poll %s already closed", pollID) +			return +		} + +		// Extract status and +		// set its Poll field. +		status := poll.Status +		status.Poll = poll + +		// Ensure the status is fully populated (we need the account) +		if err := p.state.DB.PopulateStatus(ctx, status); err != nil { +			log.Errorf(ctx, "error populating poll %s status: %v", pollID, err) + +			if status.Account == nil { +				// cannot continue without +				// status account author. +				return +			} +		} + +		// Set "closed" time. +		poll.ClosedAt = now +		poll.Closing = true + +		// Update the Poll to mark it as closed in the database. +		if err := p.state.DB.UpdatePoll(ctx, poll, "closed_at"); err != nil { +			log.Errorf(ctx, "error updating poll %s in db: %v", pollID, err) +			return +		} + +		// Enqueue a status update operation to the client API worker, +		// this will asynchronously send an update with the Poll close time. +		p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ +			APActivityType: ap.ActivityUpdate, +			APObjectType:   ap.ObjectNote, +			GTSModel:       status, +			OriginAccount:  status.Account, +		}) +	} +} diff --git a/internal/processing/polls/get.go b/internal/processing/polls/get.go new file mode 100644 index 000000000..42fecbd43 --- /dev/null +++ b/internal/processing/polls/get.go @@ -0,0 +1,37 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package polls + +import ( +	"context" + +	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +func (p *Processor) PollGet(ctx context.Context, requester *gtsmodel.Account, pollID string) (*apimodel.Poll, gtserror.WithCode) { +	// Get (+ check visibility of) requested poll with ID. +	poll, errWithCode := p.getTargetPoll(ctx, requester, pollID) +	if errWithCode != nil { +		return nil, errWithCode +	} + +	// Return converted API model poll. +	return p.toAPIPoll(ctx, requester, poll) +} diff --git a/internal/processing/polls/poll.go b/internal/processing/polls/poll.go new file mode 100644 index 000000000..3b258b76c --- /dev/null +++ b/internal/processing/polls/poll.go @@ -0,0 +1,91 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package polls + +import ( +	"context" +	"errors" + +	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtscontext" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/processing/common" +	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/typeutils" +) + +type Processor struct { +	// common processor logic +	c *common.Processor + +	state     *state.State +	converter *typeutils.Converter +} + +func New(common *common.Processor, state *state.State, converter *typeutils.Converter) Processor { +	return Processor{ +		c:         common, +		state:     state, +		converter: converter, +	} +} + +// getTargetPoll fetches a target poll ID for requesting account, taking visibility of the poll's originating status into account. +func (p *Processor) getTargetPoll(ctx context.Context, requestingAccount *gtsmodel.Account, targetID string) (*gtsmodel.Poll, gtserror.WithCode) { +	// Load the requested poll with ID. +	// (barebones as we fetch status below) +	poll, err := p.state.DB.GetPollByID( +		gtscontext.SetBarebones(ctx), +		targetID, +	) +	if err != nil && !errors.Is(err, db.ErrNoEntries) { +		return nil, gtserror.NewErrorInternalError(err) +	} + +	if poll == nil { +		// No poll could be found for given ID. +		const text = "target poll not found" +		return nil, gtserror.NewErrorNotFound( +			errors.New(text), +			text, +		) +	} + +	// Check that we can see + fetch the originating status for requesting account. +	status, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, poll.StatusID) +	if errWithCode != nil { +		return nil, errWithCode +	} + +	// Update poll status. +	poll.Status = status + +	return poll, nil +} + +// toAPIPoll converrts a given Poll to frontend API model, returning an appropriate error with HTTP code on failure. +func (p *Processor) toAPIPoll(ctx context.Context, requester *gtsmodel.Account, poll *gtsmodel.Poll) (*apimodel.Poll, gtserror.WithCode) { +	apiPoll, err := p.converter.PollToAPIPoll(ctx, requester, poll) +	if err != nil { +		err := gtserror.Newf("error converting to api model: %w", err) +		return nil, gtserror.NewErrorInternalError(err) +	} +	return apiPoll, nil +} diff --git a/internal/processing/polls/poll_test.go b/internal/processing/polls/poll_test.go new file mode 100644 index 000000000..15a1938a8 --- /dev/null +++ b/internal/processing/polls/poll_test.go @@ -0,0 +1,234 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package polls_test + +import ( +	"context" +	"math/rand" +	"net/http" +	"testing" + +	"github.com/stretchr/testify/suite" +	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/media" +	"github.com/superseriousbusiness/gotosocial/internal/processing/common" +	"github.com/superseriousbusiness/gotosocial/internal/processing/polls" +	"github.com/superseriousbusiness/gotosocial/internal/state" +	"github.com/superseriousbusiness/gotosocial/internal/typeutils" +	"github.com/superseriousbusiness/gotosocial/internal/visibility" +	"github.com/superseriousbusiness/gotosocial/testrig" +) + +type PollTestSuite struct { +	suite.Suite +	state  state.State +	filter *visibility.Filter +	polls  polls.Processor + +	testAccounts map[string]*gtsmodel.Account +	testPolls    map[string]*gtsmodel.Poll +} + +func (suite *PollTestSuite) SetupTest() { +	testrig.InitTestConfig() +	testrig.InitTestLog() +	suite.state.Caches.Init() +	testrig.StartWorkers(&suite.state) +	testrig.NewTestDB(&suite.state) +	converter := typeutils.NewConverter(&suite.state) +	controller := testrig.NewTestTransportController(&suite.state, nil) +	mediaMgr := media.NewManager(&suite.state) +	federator := testrig.NewTestFederator(&suite.state, controller, mediaMgr) +	suite.filter = visibility.NewFilter(&suite.state) +	common := common.New(&suite.state, converter, federator, suite.filter) +	suite.polls = polls.New(&common, &suite.state, converter) +} + +func (suite *PollTestSuite) TearDownTest() { +	testrig.StopWorkers(&suite.state) +	testrig.StandardDBTeardown(suite.state.DB) +} + +func (suite *PollTestSuite) TestPollGet() { +	// Create a new context for this test. +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	// Perform test for all requester + poll combos. +	for _, account := range suite.testAccounts { +		for _, poll := range suite.testPolls { +			suite.testPollGet(ctx, account, poll) +		} +	} +} + +func (suite *PollTestSuite) testPollGet(ctx context.Context, requester *gtsmodel.Account, poll *gtsmodel.Poll) { +	// Ensure poll model is fully populated before anything. +	if err := suite.state.DB.PopulatePoll(ctx, poll); err != nil { +		suite.T().Fatalf("error populating poll: %v", err) +	} + +	var check func(*apimodel.Poll, gtserror.WithCode) bool + +	switch { +	case !pollIsVisible(suite.filter, ctx, requester, poll): +		// Poll should not be visible to requester, this should +		// return an error code 404 (to prevent info leak). +		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { +			return poll == nil && err.Code() == http.StatusNotFound +		} + +	default: +		// All other cases should succeed! i.e. no error and poll returned. +		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { +			return poll != nil && err == nil +		} +	} + +	// Perform the poll vote and check the expected response. +	if !check(suite.polls.PollGet(ctx, requester, poll.ID)) { +		suite.T().Errorf("unexpected response for poll get by %s", requester.DisplayName) +	} + +} + +func (suite *PollTestSuite) TestPollVote() { +	// Create a new context for this test. +	ctx, cncl := context.WithCancel(context.Background()) +	defer cncl() + +	// randomChoices generates random vote choices in poll. +	randomChoices := func(poll *gtsmodel.Poll) []int { +		var max int +		if *poll.Multiple { +			max = len(poll.Options) +		} else { +			max = 1 +		} +		count := 1 + rand.Intn(max) +		choices := make([]int, count) +		for i := range choices { +			choices[i] = rand.Intn(len(poll.Options)) +		} +		return choices +	} + +	// Perform test for all requester + poll combos. +	for _, account := range suite.testAccounts { +		for _, poll := range suite.testPolls { +			// Generate some valid choices and test. +			choices := randomChoices(poll) +			suite.testPollVote(ctx, +				account, +				poll, +				choices, +			) + +			// Test with empty choices. +			suite.testPollVote(ctx, +				account, +				poll, +				nil, +			) + +			// Test with out of range choice. +			suite.testPollVote(ctx, +				account, +				poll, +				[]int{len(poll.Options)}, +			) +		} +	} +} + +func (suite *PollTestSuite) testPollVote(ctx context.Context, requester *gtsmodel.Account, poll *gtsmodel.Poll, choices []int) { +	// Ensure poll model is fully populated before anything. +	if err := suite.state.DB.PopulatePoll(ctx, poll); err != nil { +		suite.T().Fatalf("error populating poll: %v", err) +	} + +	var check func(*apimodel.Poll, gtserror.WithCode) bool + +	switch { +	case !poll.ClosedAt.IsZero(): +		// Poll is already closed, i.e. no new votes allowed! +		// This should return an error 422 (unprocessable entity). +		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { +			return poll == nil && err.Code() == http.StatusUnprocessableEntity +		} + +	case !voteChoicesAreValid(poll, choices): +		// These are invalid vote choices, this should return +		// an error code 400 to indicate invalid request data. +		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { +			return poll == nil && err.Code() == http.StatusBadRequest +		} + +	case poll.Status.AccountID == requester.ID: +		// Immediately we know that poll owner cannot vote in +		// their own poll. this should return an error 422. +		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { +			return poll == nil && err.Code() == http.StatusUnprocessableEntity +		} + +	case !pollIsVisible(suite.filter, ctx, requester, poll): +		// Poll should not be visible to requester, this should +		// return an error code 404 (to prevent info leak). +		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { +			return poll == nil && err.Code() == http.StatusNotFound +		} + +	default: +		// All other cases should succeed! i.e. no error and poll returned. +		check = func(poll *apimodel.Poll, err gtserror.WithCode) bool { +			return poll != nil && err == nil +		} +	} + +	// Perform the poll vote and check the expected response. +	if !check(suite.polls.PollVote(ctx, requester, poll.ID, choices)) { +		suite.T().Errorf("unexpected response for poll vote by %s with %v", requester.DisplayName, choices) +	} +} + +// voteChoicesAreValid is a utility function to check whether choices are valid for poll. +func voteChoicesAreValid(poll *gtsmodel.Poll, choices []int) bool { +	if len(choices) == 0 || !*poll.Multiple && len(choices) > 1 { +		// Invalid number of vote choices. +		return false +	} +	for _, choice := range choices { +		if choice < 0 || choice >= len(poll.Options) { +			// Choice index out of range. +			return false +		} +	} +	return true +} + +// pollIsVisible is a short-hand function to return only a single boolean value for a visibility check on poll source status to account. +func pollIsVisible(filter *visibility.Filter, ctx context.Context, to *gtsmodel.Account, poll *gtsmodel.Poll) bool { +	visible, _ := filter.StatusVisible(ctx, to, poll.Status) +	return visible +} + +func TestPollTestSuite(t *testing.T) { +	suite.Run(t, new(PollTestSuite)) +} diff --git a/internal/processing/polls/vote.go b/internal/processing/polls/vote.go new file mode 100644 index 000000000..8c8f22225 --- /dev/null +++ b/internal/processing/polls/vote.go @@ -0,0 +1,108 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program.  If not, see <http://www.gnu.org/licenses/>. + +package polls + +import ( +	"context" +	"errors" + +	"github.com/superseriousbusiness/gotosocial/internal/ap" +	apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" +	"github.com/superseriousbusiness/gotosocial/internal/db" +	"github.com/superseriousbusiness/gotosocial/internal/gtserror" +	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +	"github.com/superseriousbusiness/gotosocial/internal/id" +	"github.com/superseriousbusiness/gotosocial/internal/messages" +) + +func (p *Processor) PollVote(ctx context.Context, requester *gtsmodel.Account, pollID string, choices []int) (*apimodel.Poll, gtserror.WithCode) { +	// Get (+ check visibility of) requested poll with ID. +	poll, errWithCode := p.getTargetPoll(ctx, requester, pollID) +	if errWithCode != nil { +		return nil, errWithCode +	} + +	switch { +	// Poll author isn't allowed to vote in their own poll. +	case requester.ID == poll.Status.AccountID: +		const text = "you can't vote in your own poll" +		return nil, gtserror.NewErrorUnprocessableEntity(errors.New(text), text) + +	// Poll has already closed, no more voting! +	case !poll.ClosedAt.IsZero(): +		const text = "poll already closed" +		return nil, gtserror.NewErrorUnprocessableEntity(errors.New(text), text) + +	// No choices given, or multiple given for single-choice poll. +	case len(choices) == 0 || (!*poll.Multiple && len(choices) > 1): +		const text = "invalid number of choices for poll" +		return nil, gtserror.NewErrorBadRequest(errors.New(text), text) +	} + +	for _, choice := range choices { +		if choice < 0 || choice >= len(poll.Options) { +			// This is an invalid choice (index out of range). +			const text = "invalid option index for poll" +			return nil, gtserror.NewErrorBadRequest(errors.New(text), text) +		} +	} + +	// Wrap the choices in a PollVote model. +	vote := >smodel.PollVote{ +		ID:        id.NewULID(), +		Choices:   choices, +		AccountID: requester.ID, +		Account:   requester, +		PollID:    pollID, +		Poll:      poll, +	} + +	// Insert the new poll votes into the database. +	err := p.state.DB.PutPollVote(ctx, vote) +	switch { + +	case err == nil: +		// no issue. + +	case errors.Is(err, db.ErrAlreadyExists): +		// Users cannot vote multiple *times* (not choices). +		const text = "you have already voted in poll" +		return nil, gtserror.NewErrorUnprocessableEntity(err, text) + +	default: +		// Any other irrecoverable database error. +		err := gtserror.Newf("error inserting poll vote: %w", err) +		return nil, gtserror.NewErrorInternalError(err) +	} + +	// Enqueue worker task to handle side-effects of user poll vote(s). +	p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ +		APActivityType: ap.ActivityCreate, +		APObjectType:   ap.ActivityQuestion, +		GTSModel:       vote, // the vote choices +		OriginAccount:  requester, +	}) + +	// Before returning the converted poll model, +	// increment the vote counts on our local copy +	// to get latest, instead of another db query. +	poll.IncrementVotes(choices) + +	// Return converted API model poll. +	return p.toAPIPoll(ctx, requester, poll) +} diff --git a/internal/processing/processor.go b/internal/processing/processor.go index b571ff499..65f05f49e 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -30,6 +30,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/processing/list"  	"github.com/superseriousbusiness/gotosocial/internal/processing/markers"  	"github.com/superseriousbusiness/gotosocial/internal/processing/media" +	"github.com/superseriousbusiness/gotosocial/internal/processing/polls"  	"github.com/superseriousbusiness/gotosocial/internal/processing/report"  	"github.com/superseriousbusiness/gotosocial/internal/processing/search"  	"github.com/superseriousbusiness/gotosocial/internal/processing/status" @@ -64,6 +65,7 @@ type Processor struct {  	list     list.Processor  	markers  markers.Processor  	media    media.Processor +	polls    polls.Processor  	report   report.Processor  	search   search.Processor  	status   status.Processor @@ -97,6 +99,10 @@ func (p *Processor) Media() *media.Processor {  	return &p.media  } +func (p *Processor) Polls() *polls.Processor { +	return &p.polls +} +  func (p *Processor) Report() *report.Processor {  	return &p.report  } @@ -151,23 +157,22 @@ func NewProcessor(  	// Start with sub processors that will  	// be required by the workers processor.  	commonProcessor := common.New(state, converter, federator, filter) -	accountProcessor := account.New(&commonProcessor, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) -	mediaProcessor := media.New(state, converter, mediaManager, federator.TransportController()) -	streamProcessor := stream.New(state, oauthServer) +	processor.account = account.New(&commonProcessor, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc) +	processor.media = media.New(state, converter, mediaManager, federator.TransportController()) +	processor.stream = stream.New(state, oauthServer)  	// Instantiate the rest of the sub  	// processors + pin them to this struct. -	processor.account = accountProcessor +	processor.account = account.New(&commonProcessor, state, converter, mediaManager, oauthServer, federator, filter, parseMentionFunc)  	processor.admin = admin.New(state, cleaner, converter, mediaManager, federator.TransportController(), emailSender)  	processor.fedi = fedi.New(state, converter, federator, filter)  	processor.list = list.New(state, converter)  	processor.markers = markers.New(state, converter) -	processor.media = mediaProcessor +	processor.polls = polls.New(&commonProcessor, state, converter)  	processor.report = report.New(state, converter)  	processor.timeline = timeline.New(state, converter, filter)  	processor.search = search.New(state, federator, converter, filter) -	processor.status = status.New(&commonProcessor, state, federator, converter, filter, parseMentionFunc) -	processor.stream = streamProcessor +	processor.status = status.New(state, &commonProcessor, &processor.polls, federator, converter, filter, parseMentionFunc)  	processor.user = user.New(state, emailSender)  	// Workers processor handles asynchronous @@ -179,9 +184,9 @@ func NewProcessor(  		converter,  		filter,  		emailSender, -		&accountProcessor, -		&mediaProcessor, -		&streamProcessor, +		&processor.account, +		&processor.media, +		&processor.stream,  	)  	return processor diff --git a/internal/processing/status/create.go b/internal/processing/status/create.go index 40b3f2df2..fbe1fbd64 100644 --- a/internal/processing/status/create.go +++ b/internal/processing/status/create.go @@ -66,6 +66,26 @@ func (p *Processor) Create(ctx context.Context, requestingAccount *gtsmodel.Acco  		Text:                     form.Status,  	} +	if form.Poll != nil { +		// Update the status AS type to "Question". +		status.ActivityStreamsType = ap.ActivityQuestion + +		// Create new poll for status from form. +		secs := time.Duration(form.Poll.ExpiresIn) +		status.Poll = >smodel.Poll{ +			ID:         id.NewULID(), +			Multiple:   &form.Poll.Multiple, +			HideCounts: &form.Poll.HideTotals, +			Options:    form.Poll.Options, +			StatusID:   statusID, +			Status:     status, +			ExpiresAt:  now.Add(secs * time.Second), +		} + +		// Set poll ID on the status. +		status.PollID = status.Poll.ID +	} +  	if errWithCode := p.processReplyToID(ctx, form, requestingAccount.ID, status); errWithCode != nil {  		return nil, errWithCode  	} @@ -90,6 +110,14 @@ func (p *Processor) Create(ctx context.Context, requestingAccount *gtsmodel.Acco  		return nil, gtserror.NewErrorInternalError(err)  	} +	if status.Poll != nil { +		// Try to insert the new status poll in the database. +		if err := p.state.DB.PutPoll(ctx, status.Poll); err != nil { +			err := gtserror.Newf("error inserting poll in db: %w", err) +			return nil, gtserror.NewErrorInternalError(err) +		} +	} +  	// Insert this new status in the database.  	if err := p.state.DB.PutStatus(ctx, status); err != nil {  		return nil, gtserror.NewErrorInternalError(err) @@ -103,6 +131,15 @@ func (p *Processor) Create(ctx context.Context, requestingAccount *gtsmodel.Acco  		OriginAccount:  requestingAccount,  	}) +	if status.Poll != nil { +		// Now that the status is inserted, and side effects queued, +		// attempt to schedule an expiry handler for the status poll. +		if err := p.polls.ScheduleExpiry(ctx, status.Poll); err != nil { +			err := gtserror.Newf("error scheduling poll expiry: %w", err) +			return nil, gtserror.NewErrorInternalError(err) +		} +	} +  	return p.c.GetAPIStatus(ctx, requestingAccount, status)  } @@ -370,6 +407,18 @@ func (p *Processor) processContent(ctx context.Context, parseMention gtsmodel.Pa  	status.ContentWarning = warningRes.HTML  	status.Emojis = append(status.Emojis, warningRes.Emojis...) +	if status.Poll != nil { +		for i := range status.Poll.Options { +			// Sanitize each option title name and format. +			option := text.SanitizeToPlaintext(status.Poll.Options[i]) +			optionRes := formatInput(format, option) + +			// Collect each formatted result. +			status.Poll.Options[i] = optionRes.HTML +			status.Emojis = append(status.Emojis, optionRes.Emojis...) +		} +	} +  	// Gather all the database IDs from each of the gathered status mentions, tags, and emojis.  	status.MentionIDs = gatherIDs(status.Mentions, func(mention *gtsmodel.Mention) string { return mention.ID })  	status.TagIDs = gatherIDs(status.Tags, func(tag *gtsmodel.Tag) string { return tag.ID }) diff --git a/internal/processing/status/status.go b/internal/processing/status/status.go index b45b1651e..eaeb12b39 100644 --- a/internal/processing/status/status.go +++ b/internal/processing/status/status.go @@ -21,6 +21,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/federation"  	"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"  	"github.com/superseriousbusiness/gotosocial/internal/processing/common" +	"github.com/superseriousbusiness/gotosocial/internal/processing/polls"  	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/text"  	"github.com/superseriousbusiness/gotosocial/internal/typeutils" @@ -28,7 +29,7 @@ import (  )  type Processor struct { -	// common processor logic +	// embedded common logic  	c *common.Processor  	state        *state.State @@ -37,12 +38,16 @@ type Processor struct {  	filter       *visibility.Filter  	formatter    *text.Formatter  	parseMention gtsmodel.ParseMentionFunc + +	// other processors +	polls *polls.Processor  }  // New returns a new status processor.  func New( -	common *common.Processor,  	state *state.State, +	common *common.Processor, +	polls *polls.Processor,  	federator *federation.Federator,  	converter *typeutils.Converter,  	filter *visibility.Filter, @@ -56,5 +61,6 @@ func New(  		filter:       filter,  		formatter:    text.NewFormatter(state.DB),  		parseMention: parseMention, +		polls:        polls,  	}  } diff --git a/internal/processing/status/status_test.go b/internal/processing/status/status_test.go index 22486ecf2..dd9ad00f8 100644 --- a/internal/processing/status/status_test.go +++ b/internal/processing/status/status_test.go @@ -25,6 +25,7 @@ import (  	"github.com/superseriousbusiness/gotosocial/internal/media"  	"github.com/superseriousbusiness/gotosocial/internal/processing"  	"github.com/superseriousbusiness/gotosocial/internal/processing/common" +	"github.com/superseriousbusiness/gotosocial/internal/processing/polls"  	"github.com/superseriousbusiness/gotosocial/internal/processing/status"  	"github.com/superseriousbusiness/gotosocial/internal/state"  	"github.com/superseriousbusiness/gotosocial/internal/storage" @@ -96,8 +97,8 @@ func (suite *StatusStandardTestSuite) SetupTest() {  	)  	common := common.New(&suite.state, suite.typeConverter, suite.federator, filter) - -	suite.status = status.New(&common, &suite.state, suite.federator, suite.typeConverter, filter, processing.GetParseMentionFunc(suite.db, suite.federator)) +	polls := polls.New(&common, &suite.state, suite.typeConverter) +	suite.status = status.New(&suite.state, &common, &polls, suite.federator, suite.typeConverter, filter, processing.GetParseMentionFunc(suite.db, suite.federator))  	testrig.StandardDBSetup(suite.db, suite.testAccounts)  	testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") diff --git a/internal/processing/stream/notification_test.go b/internal/processing/stream/notification_test.go index 7ce8b95d5..8a66b68a4 100644 --- a/internal/processing/stream/notification_test.go +++ b/internal/processing/stream/notification_test.go @@ -77,8 +77,8 @@ func (suite *NotificationTestSuite) TestStreamNotification() {      "header_static": "http://localhost:8080/assets/default_header.png",      "followers_count": 0,      "following_count": 0, -    "statuses_count": 1, -    "last_status_at": "2021-09-20T10:40:37.000Z", +    "statuses_count": 2, +    "last_status_at": "2021-09-11T09:40:37.000Z",      "emojis": [],      "fields": []    } diff --git a/internal/processing/workers/federate.go b/internal/processing/workers/federate.go index 80b01ca40..44432998d 100644 --- a/internal/processing/workers/federate.go +++ b/internal/processing/workers/federate.go @@ -158,26 +158,52 @@ func (f *federate) CreateStatus(ctx context.Context, status *gtsmodel.Status) er  		return err  	} -	// Convert status to ActivityStreams Statusable implementing type. +	// Convert status to AS Statusable implementing type.  	statusable, err := f.converter.StatusToAS(ctx, status)  	if err != nil {  		return gtserror.Newf("error converting status to Statusable: %w", err)  	} -	// Use ActivityStreams Statusable type as Object of Create. -	create, err := f.converter.WrapStatusableInCreate(statusable, false) +	// Send a Create activity with Statusable via the Actor's outbox. +	create := typeutils.WrapStatusableInCreate(statusable, false) +	if _, err := f.FederatingActor().Send(ctx, outboxIRI, create); err != nil { +		return gtserror.Newf("error sending Create activity via outbox %s: %w", outboxIRI, err) +	} +	return nil +} + +func (f *federate) CreatePollVote(ctx context.Context, poll *gtsmodel.Poll, vote *gtsmodel.PollVote) error { +	// Extract status from poll. +	status := poll.Status + +	// Do nothing if the status +	// shouldn't be federated. +	if !*status.Federated { +		return nil +	} + +	// Do nothing if this is +	// a vote in our status. +	if *status.Local { +		return nil +	} + +	// Parse the outbox URI of the poll vote author. +	outboxIRI, err := parseURI(vote.Account.OutboxURI)  	if err != nil { -		return gtserror.Newf("error wrapping Statusable in Create: %w", err) +		return err  	} -	// Send the Create via the Actor's outbox. -	if _, err := f.FederatingActor().Send( -		ctx, outboxIRI, create, -	); err != nil { -		return gtserror.Newf( -			"error sending activity %T via outbox %s: %w", -			create, outboxIRI, err, -		) +	// Convert votes to AS PollOptionable implementing type. +	notes, err := f.converter.PollVoteToASOptions(ctx, vote) +	if err != nil { +		return gtserror.Newf("error converting to notes: %w", err) +	} + +	// Send a Create activity with PollOptionables via the Actor's outbox. +	create := typeutils.WrapPollOptionablesInCreate(notes...) +	if _, err := f.FederatingActor().Send(ctx, outboxIRI, create); err != nil { +		return gtserror.Newf("error sending Create activity via outbox %s: %w", outboxIRI, err)  	}  	return nil @@ -256,13 +282,8 @@ func (f *federate) UpdateStatus(ctx context.Context, status *gtsmodel.Status) er  		return gtserror.Newf("error converting status to Statusable: %w", err)  	} -	// Use ActivityStreams Statusable type as Object of Update. -	update, err := f.converter.WrapStatusableInUpdate(statusable, false) -	if err != nil { -		return gtserror.Newf("error wrapping Statusable in Update: %w", err) -	} - -	// Send the Update activity with Statusable via the Actor's outbox. +	// Send an Update activity with Statusable via the Actor's outbox. +	update := typeutils.WrapStatusableInUpdate(statusable, false)  	if _, err := f.FederatingActor().Send(ctx, outboxIRI, update); err != nil {  		return gtserror.Newf("error sending Update activity via outbox %s: %w", outboxIRI, err)  	} diff --git a/internal/processing/workers/fromclientapi.go b/internal/processing/workers/fromclientapi.go index 789145226..e3f1e2d76 100644 --- a/internal/processing/workers/fromclientapi.go +++ b/internal/processing/workers/fromclientapi.go @@ -93,6 +93,13 @@ func (p *Processor) ProcessFromClientAPI(ctx context.Context, cMsg messages.From  		case ap.ObjectNote:  			return p.clientAPI.CreateStatus(ctx, cMsg) +		// CREATE QUESTION +		// (note we don't handle poll *votes* as AS +		// question type when federating (just notes), +		// but it makes for a nicer type switch here. +		case ap.ActivityQuestion: +			return p.clientAPI.CreatePollVote(ctx, cMsg) +  		// CREATE FOLLOW (request)  		case ap.ActivityFollow:  			return p.clientAPI.CreateFollowReq(ctx, cMsg) @@ -189,7 +196,7 @@ func (p *Processor) ProcessFromClientAPI(ctx context.Context, cMsg messages.From  		}  	} -	return nil +	return gtserror.Newf("unhandled: %s %s", cMsg.APActivityType, cMsg.APObjectType)  }  func (p *clientAPI) CreateAccount(ctx context.Context, cMsg messages.FromClientAPI) error { @@ -205,7 +212,7 @@ func (p *clientAPI) CreateAccount(ctx context.Context, cMsg messages.FromClientA  	}  	if err := p.surface.emailPleaseConfirm(ctx, user, account.Username); err != nil { -		return gtserror.Newf("error emailing %s: %w", account.Username, err) +		log.Errorf(ctx, "error emailing confirm: %v", err)  	}  	return nil @@ -218,7 +225,7 @@ func (p *clientAPI) CreateStatus(ctx context.Context, cMsg messages.FromClientAP  	}  	if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil { -		return gtserror.Newf("error timelining status: %w", err) +		log.Errorf(ctx, "error timelining and notifying status: %v", err)  	}  	if status.InReplyToID != "" { @@ -228,7 +235,48 @@ func (p *clientAPI) CreateStatus(ctx context.Context, cMsg messages.FromClientAP  	}  	if err := p.federate.CreateStatus(ctx, status); err != nil { -		return gtserror.Newf("error federating status: %w", err) +		log.Errorf(ctx, "error federating status: %v", err) +	} + +	return nil +} + +func (p *clientAPI) CreatePollVote(ctx context.Context, cMsg messages.FromClientAPI) error { +	// Cast the create poll vote attached to message. +	vote, ok := cMsg.GTSModel.(*gtsmodel.PollVote) +	if !ok { +		return gtserror.Newf("cannot cast %T -> *gtsmodel.Pollvote", cMsg.GTSModel) +	} + +	// Ensure the vote is fully populated in order to get original poll. +	if err := p.state.DB.PopulatePollVote(ctx, vote); err != nil { +		return gtserror.Newf("error populating poll vote from db: %w", err) +	} + +	// Ensure the poll on the vote is fully populated to get origin status. +	if err := p.state.DB.PopulatePoll(ctx, vote.Poll); err != nil { +		return gtserror.Newf("error populating poll from db: %w", err) +	} + +	// Get the origin status, +	// (also set the poll on it). +	status := vote.Poll.Status +	status.Poll = vote.Poll + +	// Interaction counts changed on the source status, uncache from timelines. +	p.surface.invalidateStatusFromTimelines(ctx, vote.Poll.StatusID) + +	if *status.Local { +		// These are poll votes in a local status, we only need to +		// federate the updated status model with latest vote counts. +		if err := p.federate.UpdateStatus(ctx, status); err != nil { +			log.Errorf(ctx, "error federating status update: %v", err) +		} +	} else { +		// These are votes in a remote poll, federate to origin the new poll vote(s). +		if err := p.federate.CreatePollVote(ctx, vote.Poll, vote); err != nil { +			log.Errorf(ctx, "error federating poll vote: %v", err) +		}  	}  	return nil @@ -241,14 +289,17 @@ func (p *clientAPI) CreateFollowReq(ctx context.Context, cMsg messages.FromClien  	}  	if err := p.surface.notifyFollowRequest(ctx, followRequest); err != nil { -		return gtserror.Newf("error notifying follow request: %w", err) +		log.Errorf(ctx, "error notifying follow request: %v", err)  	} +	// Convert the follow request to follow model (requests are sent as follows). +	follow := p.converter.FollowRequestToFollow(ctx, followRequest) +  	if err := p.federate.Follow(  		ctx, -		p.converter.FollowRequestToFollow(ctx, followRequest), +		follow,  	); err != nil { -		return gtserror.Newf("error federating follow: %w", err) +		log.Errorf(ctx, "error federating follow request: %v", err)  	}  	return nil @@ -266,7 +317,7 @@ func (p *clientAPI) CreateLike(ctx context.Context, cMsg messages.FromClientAPI)  	}  	if err := p.surface.notifyFave(ctx, fave); err != nil { -		return gtserror.Newf("error notifying fave: %w", err) +		log.Errorf(ctx, "error notifying fave: %v", err)  	}  	// Interaction counts changed on the faved status; @@ -274,7 +325,7 @@ func (p *clientAPI) CreateLike(ctx context.Context, cMsg messages.FromClientAPI)  	p.surface.invalidateStatusFromTimelines(ctx, fave.StatusID)  	if err := p.federate.Like(ctx, fave); err != nil { -		return gtserror.Newf("error federating like: %w", err) +		log.Errorf(ctx, "error federating like: %v", err)  	}  	return nil @@ -288,12 +339,12 @@ func (p *clientAPI) CreateAnnounce(ctx context.Context, cMsg messages.FromClient  	// Timeline and notify the boost wrapper status.  	if err := p.surface.timelineAndNotifyStatus(ctx, boost); err != nil { -		return gtserror.Newf("error timelining boost: %w", err) +		log.Errorf(ctx, "error timelining and notifying status: %v", err)  	}  	// Notify the boost target account.  	if err := p.surface.notifyAnnounce(ctx, boost); err != nil { -		return gtserror.Newf("error notifying boost: %w", err) +		log.Errorf(ctx, "error notifying boost: %v", err)  	}  	// Interaction counts changed on the boosted status; @@ -301,7 +352,7 @@ func (p *clientAPI) CreateAnnounce(ctx context.Context, cMsg messages.FromClient  	p.surface.invalidateStatusFromTimelines(ctx, boost.BoostOfID)  	if err := p.federate.Announce(ctx, boost); err != nil { -		return gtserror.Newf("error federating announce: %w", err) +		log.Errorf(ctx, "error federating announce: %v", err)  	}  	return nil @@ -335,7 +386,7 @@ func (p *clientAPI) CreateBlock(ctx context.Context, cMsg messages.FromClientAPI  	// TODO: same with bookmarks?  	if err := p.federate.Block(ctx, block); err != nil { -		return gtserror.Newf("error federating block: %w", err) +		log.Errorf(ctx, "error federating block: %v", err)  	}  	return nil @@ -350,7 +401,19 @@ func (p *clientAPI) UpdateStatus(ctx context.Context, cMsg messages.FromClientAP  	// Federate the updated status changes out remotely.  	if err := p.federate.UpdateStatus(ctx, status); err != nil { -		return gtserror.Newf("error federating status update: %w", err) +		log.Errorf(ctx, "error federating status update: %v", err) +	} + +	// Status representation has changed, invalidate from timelines. +	p.surface.invalidateStatusFromTimelines(ctx, status.ID) + +	if status.Poll != nil && status.Poll.Closing { + +		// If the latest status has a newly closed poll, at least compared +		// to the existing version, then notify poll close to all voters. +		if err := p.surface.notifyPollClose(ctx, status); err != nil { +			log.Errorf(ctx, "error notifying poll close: %v", err) +		}  	}  	return nil @@ -363,7 +426,7 @@ func (p *clientAPI) UpdateAccount(ctx context.Context, cMsg messages.FromClientA  	}  	if err := p.federate.UpdateAccount(ctx, account); err != nil { -		return gtserror.Newf("error federating account update: %w", err) +		log.Errorf(ctx, "error federating account update: %v", err)  	}  	return nil @@ -382,7 +445,7 @@ func (p *clientAPI) UpdateReport(ctx context.Context, cMsg messages.FromClientAP  	}  	if err := p.surface.emailReportClosed(ctx, report); err != nil { -		return gtserror.Newf("error sending report closed email: %w", err) +		log.Errorf(ctx, "error emailing report closed: %v", err)  	}  	return nil @@ -395,11 +458,11 @@ func (p *clientAPI) AcceptFollow(ctx context.Context, cMsg messages.FromClientAP  	}  	if err := p.surface.notifyFollow(ctx, follow); err != nil { -		return gtserror.Newf("error notifying follow: %w", err) +		log.Errorf(ctx, "error notifying follow: %v", err)  	}  	if err := p.federate.AcceptFollow(ctx, follow); err != nil { -		return gtserror.Newf("error federating follow request accept: %w", err) +		log.Errorf(ctx, "error federating follow accept: %v", err)  	}  	return nil @@ -415,7 +478,7 @@ func (p *clientAPI) RejectFollowRequest(ctx context.Context, cMsg messages.FromC  		ctx,  		p.converter.FollowRequestToFollow(ctx, followReq),  	); err != nil { -		return gtserror.Newf("error federating reject follow: %w", err) +		log.Errorf(ctx, "error federating follow reject: %v", err)  	}  	return nil @@ -428,7 +491,7 @@ func (p *clientAPI) UndoFollow(ctx context.Context, cMsg messages.FromClientAPI)  	}  	if err := p.federate.UndoFollow(ctx, follow); err != nil { -		return gtserror.Newf("error federating undo follow: %w", err) +		log.Errorf(ctx, "error federating follow undo: %v", err)  	}  	return nil @@ -441,7 +504,7 @@ func (p *clientAPI) UndoBlock(ctx context.Context, cMsg messages.FromClientAPI)  	}  	if err := p.federate.UndoBlock(ctx, block); err != nil { -		return gtserror.Newf("error federating undo block: %w", err) +		log.Errorf(ctx, "error federating block undo: %v", err)  	}  	return nil @@ -458,7 +521,7 @@ func (p *clientAPI) UndoFave(ctx context.Context, cMsg messages.FromClientAPI) e  	p.surface.invalidateStatusFromTimelines(ctx, statusFave.StatusID)  	if err := p.federate.UndoLike(ctx, statusFave); err != nil { -		return gtserror.Newf("error federating undo like: %w", err) +		log.Errorf(ctx, "error federating like undo: %v", err)  	}  	return nil @@ -475,7 +538,7 @@ func (p *clientAPI) UndoAnnounce(ctx context.Context, cMsg messages.FromClientAP  	}  	if err := p.surface.deleteStatusFromTimelines(ctx, status.ID); err != nil { -		return gtserror.Newf("error removing status from timelines: %w", err) +		log.Errorf(ctx, "error removing timelined status: %v", err)  	}  	// Interaction counts changed on the boosted status; @@ -483,7 +546,7 @@ func (p *clientAPI) UndoAnnounce(ctx context.Context, cMsg messages.FromClientAP  	p.surface.invalidateStatusFromTimelines(ctx, status.BoostOfID)  	if err := p.federate.UndoAnnounce(ctx, status); err != nil { -		return gtserror.Newf("error federating undo announce: %w", err) +		log.Errorf(ctx, "error federating announce undo: %v", err)  	}  	return nil @@ -509,7 +572,7 @@ func (p *clientAPI) DeleteStatus(ctx context.Context, cMsg messages.FromClientAP  	}  	if err := p.wipeStatus(ctx, status, deleteAttachments); err != nil { -		return gtserror.Newf("error wiping status: %w", err) +		log.Errorf(ctx, "error wiping status: %v", err)  	}  	if status.InReplyToID != "" { @@ -519,7 +582,7 @@ func (p *clientAPI) DeleteStatus(ctx context.Context, cMsg messages.FromClientAP  	}  	if err := p.federate.DeleteStatus(ctx, status); err != nil { -		return gtserror.Newf("error federating status delete: %w", err) +		log.Errorf(ctx, "error federating status delete: %v", err)  	}  	return nil @@ -543,11 +606,11 @@ func (p *clientAPI) DeleteAccount(ctx context.Context, cMsg messages.FromClientA  	}  	if err := p.federate.DeleteAccount(ctx, cMsg.TargetAccount); err != nil { -		return gtserror.Newf("error federating account delete: %w", err) +		log.Errorf(ctx, "error federating account delete: %v", err)  	}  	if err := p.account.Delete(ctx, cMsg.TargetAccount, originID); err != nil { -		return gtserror.Newf("error deleting account: %w", err) +		log.Errorf(ctx, "error deleting account: %v", err)  	}  	return nil @@ -563,12 +626,12 @@ func (p *clientAPI) ReportAccount(ctx context.Context, cMsg messages.FromClientA  	// remote instance if desired.  	if *report.Forwarded {  		if err := p.federate.Flag(ctx, report); err != nil { -			return gtserror.Newf("error federating report: %w", err) +			log.Errorf(ctx, "error federating flag: %v", err)  		}  	}  	if err := p.surface.emailReportOpened(ctx, report); err != nil { -		return gtserror.Newf("error sending report opened email: %w", err) +		log.Errorf(ctx, "error emailing report opened: %v", err)  	}  	return nil diff --git a/internal/processing/workers/fromfediapi.go b/internal/processing/workers/fromfediapi.go index 1ce3b6076..2b0bfa9fa 100644 --- a/internal/processing/workers/fromfediapi.go +++ b/internal/processing/workers/fromfediapi.go @@ -114,6 +114,10 @@ func (p *Processor) ProcessFromFediAPI(ctx context.Context, fMsg messages.FromFe  		// CREATE FLAG/REPORT  		case ap.ActivityFlag:  			return p.fediAPI.CreateFlag(ctx, fMsg) + +		// CREATE QUESTION +		case ap.ActivityQuestion: +			return p.fediAPI.CreatePollVote(ctx, fMsg)  		}  	// UPDATE SOMETHING @@ -170,7 +174,7 @@ func (p *fediAPI) CreateStatus(ctx context.Context, fMsg messages.FromFediAPI) e  		// Both situations we need to parse account URI to fetch it.  		accountURI, err := url.Parse(status.AccountURI)  		if err != nil { -			return err +			return gtserror.Newf("error parsing account uri: %w", err)  		}  		// Ensure that account for this status has been deref'd. @@ -180,7 +184,7 @@ func (p *fediAPI) CreateStatus(ctx context.Context, fMsg messages.FromFediAPI) e  			accountURI,  		)  		if err != nil { -			return err +			return gtserror.Newf("error getting account by uri: %w", err)  		}  	} @@ -192,7 +196,48 @@ func (p *fediAPI) CreateStatus(ctx context.Context, fMsg messages.FromFediAPI) e  	}  	if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil { -		return gtserror.Newf("error timelining status: %w", err) +		log.Errorf(ctx, "error timelining and notifying status: %v", err) +	} + +	return nil +} + +func (p *fediAPI) CreatePollVote(ctx context.Context, fMsg messages.FromFediAPI) error { +	// Cast poll vote type from the worker message. +	vote, ok := fMsg.GTSModel.(*gtsmodel.PollVote) +	if !ok { +		return gtserror.Newf("cannot cast %T -> *gtsmodel.PollVote", fMsg.GTSModel) +	} + +	// Insert the new poll vote in the database. +	if err := p.state.DB.PutPollVote(ctx, vote); err != nil { +		return gtserror.Newf("error inserting poll vote in db: %w", err) +	} + +	// Ensure the poll vote is fully populated at this point. +	if err := p.state.DB.PopulatePollVote(ctx, vote); err != nil { +		return gtserror.Newf("error populating poll vote from db: %w", err) +	} + +	// Ensure the poll on the vote is fully populated to get origin status. +	if err := p.state.DB.PopulatePoll(ctx, vote.Poll); err != nil { +		return gtserror.Newf("error populating poll from db: %w", err) +	} + +	// Get the origin status, +	// (also set the poll on it). +	status := vote.Poll.Status +	status.Poll = vote.Poll + +	// Interaction counts changed on the source status, uncache from timelines. +	p.surface.invalidateStatusFromTimelines(ctx, vote.Poll.StatusID) + +	if *status.Local { +		// These were poll votes in a local status, we need to +		// federate the updated status model with latest vote counts. +		if err := p.federate.UpdateStatus(ctx, status); err != nil { +			log.Errorf(ctx, "error federating status update: %v", err) +		}  	}  	return nil @@ -269,12 +314,10 @@ func (p *fediAPI) CreateFollowReq(ctx context.Context, fMsg messages.FromFediAPI  	}  	if *followRequest.TargetAccount.Locked { -		// Account on our instance is locked: -		// just notify the follow request. +		// Account on our instance is locked: just notify the follow request.  		if err := p.surface.notifyFollowRequest(ctx, followRequest); err != nil { -			return gtserror.Newf("error notifying follow request: %w", err) +			log.Errorf(ctx, "error notifying follow request: %v", err)  		} -  		return nil  	} @@ -291,11 +334,11 @@ func (p *fediAPI) CreateFollowReq(ctx context.Context, fMsg messages.FromFediAPI  	}  	if err := p.federate.AcceptFollow(ctx, follow); err != nil { -		return gtserror.Newf("error federating accept follow request: %w", err) +		log.Errorf(ctx, "error federating follow request accept: %v", err)  	}  	if err := p.surface.notifyFollow(ctx, follow); err != nil { -		return gtserror.Newf("error notifying follow: %w", err) +		log.Errorf(ctx, "error notifying follow: %v", err)  	}  	return nil @@ -313,7 +356,7 @@ func (p *fediAPI) CreateLike(ctx context.Context, fMsg messages.FromFediAPI) err  	}  	if err := p.surface.notifyFave(ctx, fave); err != nil { -		return gtserror.Newf("error notifying fave: %w", err) +		log.Errorf(ctx, "error notifying fave: %v", err)  	}  	// Interaction counts changed on the faved status; @@ -354,11 +397,11 @@ func (p *fediAPI) CreateAnnounce(ctx context.Context, fMsg messages.FromFediAPI)  	// Timeline and notify the announce.  	if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil { -		return gtserror.Newf("error timelining status: %w", err) +		log.Errorf(ctx, "error timelining and notifying status: %v", err)  	}  	if err := p.surface.notifyAnnounce(ctx, status); err != nil { -		return gtserror.Newf("error notifying status: %w", err) +		log.Errorf(ctx, "error notifying announce: %v", err)  	}  	// Interaction counts changed on the boosted status; @@ -382,7 +425,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er  		block.AccountID,  		block.TargetAccountID,  	); err != nil { -		return gtserror.Newf("%w", err) +		log.Errorf(ctx, "error wiping items from block -> target's home timeline: %v", err)  	}  	if err := p.state.Timelines.Home.WipeItemsFromAccountID( @@ -390,7 +433,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er  		block.TargetAccountID,  		block.AccountID,  	); err != nil { -		return gtserror.Newf("%w", err) +		log.Errorf(ctx, "error wiping items from target -> block's home timeline: %v", err)  	}  	// Now list timelines. @@ -399,7 +442,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er  		block.AccountID,  		block.TargetAccountID,  	); err != nil { -		return gtserror.Newf("%w", err) +		log.Errorf(ctx, "error wiping items from block -> target's list timeline(s): %v", err)  	}  	if err := p.state.Timelines.List.WipeItemsFromAccountID( @@ -407,7 +450,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er  		block.TargetAccountID,  		block.AccountID,  	); err != nil { -		return gtserror.Newf("%w", err) +		log.Errorf(ctx, "error wiping items from target -> block's list timeline(s): %v", err)  	}  	// Remove any follows that existed between blocker + blockee. @@ -416,10 +459,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er  		block.AccountID,  		block.TargetAccountID,  	); err != nil { -		return gtserror.Newf( -			"db error deleting follow from %s targeting %s: %w", -			block.AccountID, block.TargetAccountID, err, -		) +		log.Errorf(ctx, "error deleting follow from block -> target: %v", err)  	}  	if err := p.state.DB.DeleteFollow( @@ -427,10 +467,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er  		block.TargetAccountID,  		block.AccountID,  	); err != nil { -		return gtserror.Newf( -			"db error deleting follow from %s targeting %s: %w", -			block.TargetAccountID, block.AccountID, err, -		) +		log.Errorf(ctx, "error deleting follow from target -> block: %v", err)  	}  	// Remove any follow requests that existed between blocker + blockee. @@ -439,10 +476,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er  		block.AccountID,  		block.TargetAccountID,  	); err != nil { -		return gtserror.Newf( -			"db error deleting follow request from %s targeting %s: %w", -			block.AccountID, block.TargetAccountID, err, -		) +		log.Errorf(ctx, "error deleting follow request from block -> target: %v", err)  	}  	if err := p.state.DB.DeleteFollowRequest( @@ -450,10 +484,7 @@ func (p *fediAPI) CreateBlock(ctx context.Context, fMsg messages.FromFediAPI) er  		block.TargetAccountID,  		block.AccountID,  	); err != nil { -		return gtserror.Newf( -			"db error deleting follow request from %s targeting %s: %w", -			block.TargetAccountID, block.AccountID, err, -		) +		log.Errorf(ctx, "error deleting follow request from target -> block: %v", err)  	}  	return nil @@ -469,7 +500,7 @@ func (p *fediAPI) CreateFlag(ctx context.Context, fMsg messages.FromFediAPI) err  	// - notify admins by dm / notification  	if err := p.surface.emailReportOpened(ctx, incomingReport); err != nil { -		return gtserror.Newf("error sending report opened email: %w", err) +		log.Errorf(ctx, "error emailing report opened: %v", err)  	}  	return nil @@ -497,7 +528,7 @@ func (p *fediAPI) UpdateAccount(ctx context.Context, fMsg messages.FromFediAPI)  		true, // Force refresh.  	)  	if err != nil { -		return gtserror.Newf("error refreshing updated account: %w", err) +		log.Errorf(ctx, "error refreshing account: %v", err)  	}  	return nil @@ -514,7 +545,7 @@ func (p *fediAPI) UpdateStatus(ctx context.Context, fMsg messages.FromFediAPI) e  	apStatus, _ := fMsg.APObjectModel.(ap.Statusable)  	// Fetch up-to-date attach status attachments, etc. -	_, statusable, err := p.federate.RefreshStatus( +	status, _, err := p.federate.RefreshStatus(  		ctx,  		fMsg.ReceivingAccount.Username,  		existing, @@ -522,12 +553,19 @@ func (p *fediAPI) UpdateStatus(ctx context.Context, fMsg messages.FromFediAPI) e  		true,  	)  	if err != nil { -		return gtserror.Newf("error refreshing updated status: %w", err) +		log.Errorf(ctx, "error refreshing status: %v", err)  	} -	if statusable != nil { -		// Status representation was refetched, uncache from timelines. -		p.surface.invalidateStatusFromTimelines(ctx, existing.ID) +	// Status representation was refetched, uncache from timelines. +	p.surface.invalidateStatusFromTimelines(ctx, status.ID) + +	if status.Poll != nil && status.Poll.Closing { + +		// If the latest status has a newly closed poll, at least compared +		// to the existing version, then notify poll close to all voters. +		if err := p.surface.notifyPollClose(ctx, status); err != nil { +			log.Errorf(ctx, "error sending poll notification: %v", err) +		}  	}  	return nil @@ -545,7 +583,7 @@ func (p *fediAPI) DeleteStatus(ctx context.Context, fMsg messages.FromFediAPI) e  	}  	if err := p.wipeStatus(ctx, status, deleteAttachments); err != nil { -		return gtserror.Newf("error wiping status: %w", err) +		log.Errorf(ctx, "error wiping status: %v", err)  	}  	if status.InReplyToID != "" { @@ -564,7 +602,7 @@ func (p *fediAPI) DeleteAccount(ctx context.Context, fMsg messages.FromFediAPI)  	}  	if err := p.account.Delete(ctx, account, account.ID); err != nil { -		return gtserror.Newf("error deleting account: %w", err) +		log.Errorf(ctx, "error deleting account: %v", err)  	}  	return nil diff --git a/internal/processing/workers/fromfediapi_test.go b/internal/processing/workers/fromfediapi_test.go index b8d86ac45..952c008cc 100644 --- a/internal/processing/workers/fromfediapi_test.go +++ b/internal/processing/workers/fromfediapi_test.go @@ -347,8 +347,15 @@ func (suite *FromFediAPITestSuite) TestProcessAccountDelete() {  		suite.FailNow("timeout waiting for statuses to be deleted")  	} -	dbAccount, err := suite.db.GetAccountByID(ctx, deletedAccount.ID) -	suite.NoError(err) +	var dbAccount *gtsmodel.Account + +	// account data should be zeroed. +	if !testrig.WaitFor(func() bool { +		dbAccount, err = suite.db.GetAccountByID(ctx, deletedAccount.ID) +		return err == nil && dbAccount.DisplayName == "" +	}) { +		suite.FailNow("timeout waiting for statuses to be deleted") +	}  	suite.Empty(dbAccount.Note)  	suite.Empty(dbAccount.DisplayName) diff --git a/internal/processing/workers/surfacenotify.go b/internal/processing/workers/surfacenotify.go index b99fa3ad3..2dc60023c 100644 --- a/internal/processing/workers/surfacenotify.go +++ b/internal/processing/workers/surfacenotify.go @@ -35,12 +35,25 @@ func (s *surface) notifyMentions(  	ctx context.Context,  	status *gtsmodel.Status,  ) error { -	var ( -		mentions = status.Mentions -		errs     = gtserror.NewMultiError(len(mentions)) -	) +	var errs gtserror.MultiError + +	for _, mention := range status.Mentions { +		// Set status on the mention (stops +		// the below function populating it). +		mention.Status = status + +		// Beforehand, ensure the passed mention is fully populated. +		if err := s.state.DB.PopulateMention(ctx, mention); err != nil { +			errs.Appendf("error populating mention %s: %w", mention.ID, err) +			continue +		} + +		if mention.TargetAccount.IsRemote() { +			// no need to notify +			// remote accounts. +			continue +		} -	for _, mention := range mentions {  		// Ensure thread not muted  		// by mentioned account.  		muted, err := s.state.DB.IsThreadMutedByAccount( @@ -48,9 +61,8 @@ func (s *surface) notifyMentions(  			status.ThreadID,  			mention.TargetAccountID,  		) -  		if err != nil { -			errs.Append(err) +			errs.Appendf("error checking status thread mute %s: %w", status.ThreadID, err)  			continue  		} @@ -61,14 +73,16 @@ func (s *surface) notifyMentions(  			continue  		} -		if err := s.notify( -			ctx, +		// notify mentioned +		// by status author. +		if err := s.notify(ctx,  			gtsmodel.NotificationMention, -			mention.TargetAccountID, -			mention.OriginAccountID, +			mention.TargetAccount, +			mention.OriginAccount,  			mention.StatusID,  		); err != nil { -			errs.Append(err) +			errs.Appendf("error notifying mention target %s: %w", mention.TargetAccountID, err) +			continue  		}  	} @@ -79,15 +93,30 @@ func (s *surface) notifyMentions(  // follow request that they have a new follow request.  func (s *surface) notifyFollowRequest(  	ctx context.Context, -	followRequest *gtsmodel.FollowRequest, +	followReq *gtsmodel.FollowRequest,  ) error { -	return s.notify( -		ctx, +	// Beforehand, ensure the passed follow request is fully populated. +	if err := s.state.DB.PopulateFollowRequest(ctx, followReq); err != nil { +		return gtserror.Newf("error populating follow request %s: %w", followReq.ID, err) +	} + +	if followReq.TargetAccount.IsRemote() { +		// no need to notify +		// remote accounts. +		return nil +	} + +	// Now notify the follow request itself. +	if err := s.notify(ctx,  		gtsmodel.NotificationFollowRequest, -		followRequest.TargetAccountID, -		followRequest.AccountID, +		followReq.TargetAccount, +		followReq.Account,  		"", -	) +	); err != nil { +		return gtserror.Newf("error notifying follow target %s: %w", followReq.TargetAccountID, err) +	} + +	return nil  }  // notifyFollow notifies the target of the given follow that @@ -98,6 +127,17 @@ func (s *surface) notifyFollow(  	ctx context.Context,  	follow *gtsmodel.Follow,  ) error { +	// Beforehand, ensure the passed follow is fully populated. +	if err := s.state.DB.PopulateFollow(ctx, follow); err != nil { +		return gtserror.Newf("error populating follow %s: %w", follow.ID, err) +	} + +	if follow.TargetAccount.IsRemote() { +		// no need to notify +		// remote accounts. +		return nil +	} +  	// Check if previous follow req notif exists.  	prevNotif, err := s.state.DB.GetNotification(  		gtscontext.SetBarebones(ctx), @@ -107,24 +147,28 @@ func (s *surface) notifyFollow(  		"",  	)  	if err != nil && !errors.Is(err, db.ErrNoEntries) { -		return gtserror.Newf("db error checking for previous follow request notification: %w", err) +		return gtserror.Newf("error getting notification: %w", err)  	}  	if prevNotif != nil { -		// Previous notif existed, delete it. -		if err := s.state.DB.DeleteNotificationByID(ctx, prevNotif.ID); err != nil { -			return gtserror.Newf("db error removing previous follow request notification %s: %w", prevNotif.ID, err) +		// Previous follow request notif existed, delete it before creating new. +		if err := s.state.DB.DeleteNotificationByID(ctx, prevNotif.ID); // nocollapse +		err != nil && !errors.Is(err, db.ErrNoEntries) { +			return gtserror.Newf("error deleting notification %s: %w", prevNotif.ID, err)  		}  	}  	// Now notify the follow itself. -	return s.notify( -		ctx, +	if err := s.notify(ctx,  		gtsmodel.NotificationFollow, -		follow.TargetAccountID, -		follow.AccountID, +		follow.TargetAccount, +		follow.Account,  		"", -	) +	); err != nil { +		return gtserror.Newf("error notifying follow target %s: %w", follow.TargetAccountID, err) +	} + +	return nil  }  // notifyFave notifies the target of the given @@ -138,6 +182,17 @@ func (s *surface) notifyFave(  		return nil  	} +	// Beforehand, ensure the passed status fave is fully populated. +	if err := s.state.DB.PopulateStatusFave(ctx, fave); err != nil { +		return gtserror.Newf("error populating fave %s: %w", fave.ID, err) +	} + +	if fave.TargetAccount.IsRemote() { +		// no need to notify +		// remote accounts. +		return nil +	} +  	// Ensure favee hasn't  	// muted the thread.  	muted, err := s.state.DB.IsThreadMutedByAccount( @@ -145,24 +200,28 @@ func (s *surface) notifyFave(  		fave.Status.ThreadID,  		fave.TargetAccountID,  	) -  	if err != nil { -		return err +		return gtserror.Newf("error checking status thread mute %s: %w", fave.StatusID, err)  	}  	if muted { -		// Boostee doesn't want +		// Favee doesn't want  		// notifs for this thread.  		return nil  	} -	return s.notify( -		ctx, +	// notify status author +	// of fave by account. +	if err := s.notify(ctx,  		gtsmodel.NotificationFave, -		fave.TargetAccountID, -		fave.AccountID, +		fave.TargetAccount, +		fave.Account,  		fave.StatusID, -	) +	); err != nil { +		return gtserror.Newf("error notifying status author %s: %w", fave.TargetAccountID, err) +	} + +	return nil  }  // notifyAnnounce notifies the status boost target @@ -176,14 +235,19 @@ func (s *surface) notifyAnnounce(  		return nil  	} -	if status.BoostOf == nil { -		// No boosted status -		// set, nothing to do. +	if status.BoostOfAccountID == status.AccountID { +		// Self-boost, nothing to do.  		return nil  	} -	if status.BoostOfAccountID == status.AccountID { -		// Self-boost, nothing to do. +	// Beforehand, ensure the passed status is fully populated. +	if err := s.state.DB.PopulateStatus(ctx, status); err != nil { +		return gtserror.Newf("error populating status %s: %w", status.ID, err) +	} + +	if status.BoostOfAccount.IsRemote() { +		// no need to notify +		// remote accounts.  		return nil  	} @@ -196,7 +260,7 @@ func (s *surface) notifyAnnounce(  	)  	if err != nil { -		return err +		return gtserror.Newf("error checking status thread mute %s: %w", status.BoostOfID, err)  	}  	if muted { @@ -205,13 +269,68 @@ func (s *surface) notifyAnnounce(  		return nil  	} -	return s.notify( -		ctx, +	// notify status author +	// of boost by account. +	if err := s.notify(ctx,  		gtsmodel.NotificationReblog, -		status.BoostOfAccountID, -		status.AccountID, +		status.BoostOfAccount, +		status.Account,  		status.ID, -	) +	); err != nil { +		return gtserror.Newf("error notifying status author %s: %w", status.BoostOfAccountID, err) +	} + +	return nil +} + +func (s *surface) notifyPollClose(ctx context.Context, status *gtsmodel.Status) error { +	// Beforehand, ensure the passed status is fully populated. +	if err := s.state.DB.PopulateStatus(ctx, status); err != nil { +		return gtserror.Newf("error populating status %s: %w", status.ID, err) +	} + +	// Fetch all votes in the attached status poll. +	votes, err := s.state.DB.GetPollVotes(ctx, status.PollID) +	if err != nil { +		return gtserror.Newf("error getting poll %s votes: %w", status.PollID, err) +	} + +	var errs gtserror.MultiError + +	if status.Account.IsLocal() { +		// Send a notification to the status +		// author that their poll has closed! +		if err := s.notify(ctx, +			gtsmodel.NotificationPoll, +			status.Account, +			status.Account, +			status.ID, +		); err != nil { +			errs.Appendf("error notifying poll author: %w", err) +		} +	} + +	for _, vote := range votes { +		if vote.Account.IsRemote() { +			// no need to notify +			// remote accounts. +			continue +		} + +		// notify voter that +		// poll has been closed. +		if err := s.notify(ctx, +			gtsmodel.NotificationMention, +			vote.Account, +			status.Account, +			status.ID, +		); err != nil { +			errs.Appendf("error notifying poll voter %s: %w", vote.AccountID, err) +			continue +		} +	} + +	return errs.Combine()  }  // notify creates, inserts, and streams a new @@ -228,17 +347,12 @@ func (s *surface) notifyAnnounce(  func (s *surface) notify(  	ctx context.Context,  	notificationType gtsmodel.NotificationType, -	targetAccountID string, -	originAccountID string, +	targetAccount *gtsmodel.Account, +	originAccount *gtsmodel.Account,  	statusID string,  ) error { -	targetAccount, err := s.state.DB.GetAccountByID(ctx, targetAccountID) -	if err != nil { -		return gtserror.Newf("error getting target account %s: %w", targetAccountID, err) -	} - -	if !targetAccount.IsLocal() { -		// Nothing to do. +	if targetAccount.IsRemote() { +		// nothing to do.  		return nil  	} @@ -247,8 +361,8 @@ func (s *surface) notify(  	if _, err := s.state.DB.GetNotification(  		gtscontext.SetBarebones(ctx),  		notificationType, -		targetAccountID, -		originAccountID, +		targetAccount.ID, +		originAccount.ID,  		statusID,  	); err == nil {  		// Notification exists; @@ -264,8 +378,10 @@ func (s *surface) notify(  	notif := >smodel.Notification{  		ID:               id.NewULID(),  		NotificationType: notificationType, -		TargetAccountID:  targetAccountID, -		OriginAccountID:  originAccountID, +		TargetAccountID:  targetAccount.ID, +		TargetAccount:    targetAccount, +		OriginAccountID:  originAccount.ID, +		OriginAccount:    originAccount,  		StatusID:         statusID,  	} diff --git a/internal/processing/workers/surfacetimeline.go b/internal/processing/workers/surfacetimeline.go index 15263cf78..baebdbc66 100644 --- a/internal/processing/workers/surfacetimeline.go +++ b/internal/processing/workers/surfacetimeline.go @@ -85,7 +85,7 @@ func (s *surface) timelineAndNotifyStatusForFollowers(  	follows []*gtsmodel.Follow,  ) error {  	var ( -		errs  = new(gtserror.MultiError) +		errs  gtserror.MultiError  		boost = status.BoostOfID != ""  		reply = status.InReplyToURI != ""  	) @@ -117,7 +117,7 @@ func (s *surface) timelineAndNotifyStatusForFollowers(  			ctx,  			status,  			follow, -			errs, +			&errs,  		)  		// Add status to home timeline for owner @@ -160,11 +160,10 @@ func (s *surface) timelineAndNotifyStatusForFollowers(  		//   - This is a top-level post (not a reply or boost).  		//  		// That means we can officially notify this one. -		if err := s.notify( -			ctx, +		if err := s.notify(ctx,  			gtsmodel.NotificationStatus, -			follow.AccountID, -			status.AccountID, +			follow.Account, +			status.Account,  			status.ID,  		); err != nil {  			errs.Appendf("error notifying account %s about new status: %w", follow.AccountID, err) diff --git a/internal/processing/workers/wipestatus.go b/internal/processing/workers/wipestatus.go index ab59f14be..90a037928 100644 --- a/internal/processing/workers/wipestatus.go +++ b/internal/processing/workers/wipestatus.go @@ -85,6 +85,21 @@ func wipeStatusF(state *state.State, media *media.Processor, surface *surface) w  			errs.Appendf("error deleting status faves: %w", err)  		} +		if pollID := statusToDelete.PollID; pollID != "" { +			// Delete this poll by ID from the database. +			if err := state.DB.DeletePollByID(ctx, pollID); err != nil { +				errs.Appendf("error deleting status poll: %w", err) +			} + +			// Delete any poll votes pointing to this poll ID. +			if err := state.DB.DeletePollVotes(ctx, pollID); err != nil { +				errs.Appendf("error deleting status poll votes: %w", err) +			} + +			// Cancel any scheduled expiry task for poll. +			_ = state.Workers.Scheduler.Cancel(pollID) +		} +  		// delete all boosts for this status + remove them from timelines  		boosts, err := state.DB.GetStatusBoosts(  			// we MUST set a barebones context here,  | 
