From de8a3c92149a8f1e5b6341bfeffe867b3e0e794f Mon Sep 17 00:00:00 2001 From: Virgil Date: Thu, 2 Apr 2026 01:25:43 +0000 Subject: [PATCH] feat(notifications): add notification list filters Co-Authored-By: Virgil --- notifications.go | 127 +++++++++++++++++++++++++++++++++++++++--- notifications_test.go | 87 +++++++++++++++++++++++++++++ 2 files changed, 206 insertions(+), 8 deletions(-) diff --git a/notifications.go b/notifications.go index ff50476..09cd4cb 100644 --- a/notifications.go +++ b/notifications.go @@ -3,12 +3,45 @@ package forge import ( "context" "iter" + "net/http" "net/url" + "strconv" "time" "dappco.re/go/core/forge/types" ) +// NotificationListOptions controls filtering for notification listings. +type NotificationListOptions struct { + All bool + StatusTypes []string + SubjectTypes []string + Since *time.Time + Before *time.Time +} + +func (o NotificationListOptions) addQuery(values url.Values) { + if o.All { + values.Set("all", "true") + } + for _, status := range o.StatusTypes { + if status != "" { + values.Add("status-types", status) + } + } + for _, subjectType := range o.SubjectTypes { + if subjectType != "" { + values.Add("subject-type", subjectType) + } + } + if o.Since != nil { + values.Set("since", o.Since.Format(time.RFC3339)) + } + if o.Before != nil { + values.Set("before", o.Before.Format(time.RFC3339)) + } +} + // NotificationService handles notification operations via the Forgejo API. // No Resource embedding — varied endpoint shapes. // @@ -52,13 +85,13 @@ func (o NotificationRepoMarkOptions) queryString() string { } // List returns all notifications for the authenticated user. -func (s *NotificationService) List(ctx context.Context) ([]types.NotificationThread, error) { - return ListAll[types.NotificationThread](ctx, s.client, "/api/v1/notifications", nil) +func (s *NotificationService) List(ctx context.Context, filters ...NotificationListOptions) ([]types.NotificationThread, error) { + return s.listAll(ctx, "/api/v1/notifications", filters...) } // Iter returns an iterator over all notifications for the authenticated user. -func (s *NotificationService) Iter(ctx context.Context) iter.Seq2[types.NotificationThread, error] { - return ListIter[types.NotificationThread](ctx, s.client, "/api/v1/notifications", nil) +func (s *NotificationService) Iter(ctx context.Context, filters ...NotificationListOptions) iter.Seq2[types.NotificationThread, error] { + return s.listIter(ctx, "/api/v1/notifications", filters...) } // NewAvailable returns the count of unread notifications for the authenticated user. @@ -71,15 +104,15 @@ func (s *NotificationService) NewAvailable(ctx context.Context) (*types.Notifica } // ListRepo returns all notifications for a specific repository. -func (s *NotificationService) ListRepo(ctx context.Context, owner, repo string) ([]types.NotificationThread, error) { +func (s *NotificationService) ListRepo(ctx context.Context, owner, repo string, filters ...NotificationListOptions) ([]types.NotificationThread, error) { path := ResolvePath("/api/v1/repos/{owner}/{repo}/notifications", pathParams("owner", owner, "repo", repo)) - return ListAll[types.NotificationThread](ctx, s.client, path, nil) + return s.listAll(ctx, path, filters...) } // IterRepo returns an iterator over all notifications for a specific repository. -func (s *NotificationService) IterRepo(ctx context.Context, owner, repo string) iter.Seq2[types.NotificationThread, error] { +func (s *NotificationService) IterRepo(ctx context.Context, owner, repo string, filters ...NotificationListOptions) iter.Seq2[types.NotificationThread, error] { path := ResolvePath("/api/v1/repos/{owner}/{repo}/notifications", pathParams("owner", owner, "repo", repo)) - return ListIter[types.NotificationThread](ctx, s.client, path, nil) + return s.listIter(ctx, path, filters...) } // MarkRepoNotifications marks repository notification threads as read, unread, or pinned. @@ -117,3 +150,81 @@ func (s *NotificationService) MarkThreadRead(ctx context.Context, id int64) erro path := ResolvePath("/api/v1/notifications/threads/{id}", pathParams("id", int64String(id))) return s.client.Patch(ctx, path, nil, nil) } + +func (s *NotificationService) listAll(ctx context.Context, path string, filters ...NotificationListOptions) ([]types.NotificationThread, error) { + var all []types.NotificationThread + page := 1 + + for { + result, err := s.listPage(ctx, path, ListOptions{Page: page, Limit: 50}, filters...) + if err != nil { + return nil, err + } + all = append(all, result.Items...) + if !result.HasMore { + break + } + page++ + } + + return all, nil +} + +func (s *NotificationService) listIter(ctx context.Context, path string, filters ...NotificationListOptions) iter.Seq2[types.NotificationThread, error] { + return func(yield func(types.NotificationThread, error) bool) { + page := 1 + for { + result, err := s.listPage(ctx, path, ListOptions{Page: page, Limit: 50}, filters...) + if err != nil { + yield(*new(types.NotificationThread), err) + return + } + for _, item := range result.Items { + if !yield(item, nil) { + return + } + } + if !result.HasMore { + break + } + page++ + } + } +} + +func (s *NotificationService) listPage(ctx context.Context, path string, opts ListOptions, filters ...NotificationListOptions) (*PagedResult[types.NotificationThread], error) { + if opts.Page < 1 { + opts.Page = 1 + } + if opts.Limit < 1 { + opts.Limit = 50 + } + + u, err := url.Parse(path) + if err != nil { + return nil, err + } + + values := u.Query() + values.Set("page", strconv.Itoa(opts.Page)) + values.Set("limit", strconv.Itoa(opts.Limit)) + for _, filter := range filters { + filter.addQuery(values) + } + u.RawQuery = values.Encode() + + var items []types.NotificationThread + resp, err := s.client.doJSON(ctx, http.MethodGet, u.String(), nil, &items) + if err != nil { + return nil, err + } + + totalCount, _ := strconv.Atoi(resp.Header.Get("X-Total-Count")) + return &PagedResult[types.NotificationThread]{ + Items: items, + TotalCount: totalCount, + Page: opts.Page, + HasMore: (totalCount > 0 && (opts.Page-1)*opts.Limit+len(items) < totalCount) || + (totalCount == 0 && len(items) >= opts.Limit), + }, nil +} diff --git a/notifications_test.go b/notifications_test.go index b9718b9..79ee76b 100644 --- a/notifications_test.go +++ b/notifications_test.go @@ -46,6 +46,54 @@ func TestNotificationService_List_Good(t *testing.T) { } } +func TestNotificationService_List_Filters(t *testing.T) { + since := time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC) + before := time.Date(2026, time.April, 2, 12, 0, 0, 0, time.UTC) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("expected GET, got %s", r.Method) + } + if r.URL.Path != "/api/v1/notifications" { + t.Errorf("wrong path: %s", r.URL.Path) + } + if got := r.URL.Query().Get("all"); got != "true" { + t.Errorf("got all=%q, want true", got) + } + if got := r.URL.Query()["status-types"]; len(got) != 2 || got[0] != "unread" || got[1] != "pinned" { + t.Errorf("got status-types=%v, want [unread pinned]", got) + } + if got := r.URL.Query()["subject-type"]; len(got) != 2 || got[0] != "issue" || got[1] != "pull" { + t.Errorf("got subject-type=%v, want [issue pull]", got) + } + if got := r.URL.Query().Get("since"); got != since.Format(time.RFC3339) { + t.Errorf("got since=%q, want %q", got, since.Format(time.RFC3339)) + } + if got := r.URL.Query().Get("before"); got != before.Format(time.RFC3339) { + t.Errorf("got before=%q, want %q", got, before.Format(time.RFC3339)) + } + w.Header().Set("X-Total-Count", "1") + json.NewEncoder(w).Encode([]types.NotificationThread{ + {ID: 11, Unread: true, Subject: &types.NotificationSubject{Title: "Filtered"}}, + }) + })) + defer srv.Close() + + f := NewForge(srv.URL, "tok") + threads, err := f.Notifications.List(context.Background(), NotificationListOptions{ + All: true, + StatusTypes: []string{"unread", "pinned"}, + SubjectTypes: []string{"issue", "pull"}, + Since: &since, + Before: &before, + }) + if err != nil { + t.Fatal(err) + } + if len(threads) != 1 || threads[0].ID != 11 { + t.Fatalf("got threads=%+v", threads) + } +} + func TestNotificationService_ListRepo_Good(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { @@ -74,6 +122,45 @@ func TestNotificationService_ListRepo_Good(t *testing.T) { } } +func TestNotificationService_ListRepo_Filters(t *testing.T) { + since := time.Date(2026, time.April, 1, 12, 0, 0, 0, time.UTC) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("expected GET, got %s", r.Method) + } + if r.URL.Path != "/api/v1/repos/core/go-forge/notifications" { + t.Errorf("wrong path: %s", r.URL.Path) + } + if got := r.URL.Query()["status-types"]; len(got) != 1 || got[0] != "read" { + t.Errorf("got status-types=%v, want [read]", got) + } + if got := r.URL.Query()["subject-type"]; len(got) != 1 || got[0] != "repository" { + t.Errorf("got subject-type=%v, want [repository]", got) + } + if got := r.URL.Query().Get("since"); got != since.Format(time.RFC3339) { + t.Errorf("got since=%q, want %q", got, since.Format(time.RFC3339)) + } + w.Header().Set("X-Total-Count", "1") + json.NewEncoder(w).Encode([]types.NotificationThread{ + {ID: 12, Unread: false, Subject: &types.NotificationSubject{Title: "Repo filtered"}}, + }) + })) + defer srv.Close() + + f := NewForge(srv.URL, "tok") + threads, err := f.Notifications.ListRepo(context.Background(), "core", "go-forge", NotificationListOptions{ + StatusTypes: []string{"read"}, + SubjectTypes: []string{"repository"}, + Since: &since, + }) + if err != nil { + t.Fatal(err) + } + if len(threads) != 1 || threads[0].ID != 12 { + t.Fatalf("got threads=%+v", threads) + } +} + func TestNotificationService_NewAvailable_Good(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet {