From 72fb4e3b8e0c55185c3cbdc8d3d4236882fc39ba Mon Sep 17 00:00:00 2001 From: Snider Date: Mon, 9 Mar 2026 08:31:16 +0000 Subject: [PATCH] fix: improve API client error handling and pagination Co-Authored-By: Claude Opus 4.6 --- admin.go | 11 +++---- client.go | 91 +++++++++++++++++++++++++++++++++++++++++---------- pagination.go | 42 +++++++----------------- 3 files changed, 89 insertions(+), 55 deletions(-) diff --git a/admin.go b/admin.go index 637d201..f621251 100644 --- a/admin.go +++ b/admin.go @@ -2,7 +2,6 @@ package forge import ( "context" - "fmt" "iter" "forge.lthn.ai/core/go-forge/types" @@ -40,19 +39,19 @@ func (s *AdminService) CreateUser(ctx context.Context, opts *types.CreateUserOpt // EditUser edits an existing user (admin only). func (s *AdminService) EditUser(ctx context.Context, username string, opts map[string]any) error { - path := fmt.Sprintf("/api/v1/admin/users/%s", username) + path := ResolvePath("/api/v1/admin/users/{username}", Params{"username": username}) return s.client.Patch(ctx, path, opts, nil) } // DeleteUser deletes a user (admin only). func (s *AdminService) DeleteUser(ctx context.Context, username string) error { - path := fmt.Sprintf("/api/v1/admin/users/%s", username) + path := ResolvePath("/api/v1/admin/users/{username}", Params{"username": username}) return s.client.Delete(ctx, path) } // RenameUser renames a user (admin only). func (s *AdminService) RenameUser(ctx context.Context, username, newName string) error { - path := fmt.Sprintf("/api/v1/admin/users/%s/rename", username) + path := ResolvePath("/api/v1/admin/users/{username}/rename", Params{"username": username}) return s.client.Post(ctx, path, &types.RenameUserOption{NewName: newName}, nil) } @@ -68,7 +67,7 @@ func (s *AdminService) IterOrgs(ctx context.Context) iter.Seq2[types.Organizatio // RunCron runs a cron task by name (admin only). func (s *AdminService) RunCron(ctx context.Context, task string) error { - path := fmt.Sprintf("/api/v1/admin/cron/%s", task) + path := ResolvePath("/api/v1/admin/cron/{task}", Params{"task": task}) return s.client.Post(ctx, path, nil, nil) } @@ -84,7 +83,7 @@ func (s *AdminService) IterCron(ctx context.Context) iter.Seq2[types.Cron, error // AdoptRepo adopts an unadopted repository (admin only). func (s *AdminService) AdoptRepo(ctx context.Context, owner, repo string) error { - path := fmt.Sprintf("/api/v1/admin/unadopted/%s/%s", owner, repo) + path := ResolvePath("/api/v1/admin/unadopted/{owner}/{repo}", Params{"owner": owner, "repo": repo}) return s.client.Post(ctx, path, nil, nil) } diff --git a/client.go b/client.go index 715cd72..24ec77e 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "strconv" "strings" ) @@ -53,21 +54,38 @@ func WithUserAgent(ua string) Option { return func(c *Client) { c.userAgent = ua } } +// RateLimit represents the rate limit information from the Forgejo API. +type RateLimit struct { + Limit int + Remaining int + Reset int64 +} + // Client is a low-level HTTP client for the Forgejo API. type Client struct { baseURL string token string httpClient *http.Client userAgent string + rateLimit RateLimit +} + +// RateLimit returns the last known rate limit information. +func (c *Client) RateLimit() RateLimit { + return c.rateLimit } // NewClient creates a new Forgejo API client. func NewClient(url, token string, opts ...Option) *Client { c := &Client{ - baseURL: strings.TrimRight(url, "/"), - token: token, - httpClient: http.DefaultClient, - userAgent: "go-forge/0.1", + baseURL: strings.TrimRight(url, "/"), + token: token, + httpClient: &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + }, + userAgent: "go-forge/0.1", } for _, opt := range opts { opt(c) @@ -77,32 +95,38 @@ func NewClient(url, token string, opts ...Option) *Client { // Get performs a GET request. func (c *Client) Get(ctx context.Context, path string, out any) error { - return c.do(ctx, http.MethodGet, path, nil, out) + _, err := c.doJSON(ctx, http.MethodGet, path, nil, out) + return err } // Post performs a POST request. func (c *Client) Post(ctx context.Context, path string, body, out any) error { - return c.do(ctx, http.MethodPost, path, body, out) + _, err := c.doJSON(ctx, http.MethodPost, path, body, out) + return err } // Patch performs a PATCH request. func (c *Client) Patch(ctx context.Context, path string, body, out any) error { - return c.do(ctx, http.MethodPatch, path, body, out) + _, err := c.doJSON(ctx, http.MethodPatch, path, body, out) + return err } // Put performs a PUT request. func (c *Client) Put(ctx context.Context, path string, body, out any) error { - return c.do(ctx, http.MethodPut, path, body, out) + _, err := c.doJSON(ctx, http.MethodPut, path, body, out) + return err } // Delete performs a DELETE request. func (c *Client) Delete(ctx context.Context, path string) error { - return c.do(ctx, http.MethodDelete, path, nil, nil) + _, err := c.doJSON(ctx, http.MethodDelete, path, nil, nil) + return err } // DeleteWithBody performs a DELETE request with a JSON body. func (c *Client) DeleteWithBody(ctx context.Context, path string, body any) error { - return c.do(ctx, http.MethodDelete, path, body, nil) + _, err := c.doJSON(ctx, http.MethodDelete, path, body, nil) + return err } // PostRaw performs a POST request with a JSON body and returns the raw @@ -183,20 +207,25 @@ func (c *Client) GetRaw(ctx context.Context, path string) ([]byte, error) { } func (c *Client) do(ctx context.Context, method, path string, body, out any) error { + _, err := c.doJSON(ctx, method, path, body, out) + return err +} + +func (c *Client) doJSON(ctx context.Context, method, path string, body, out any) (*http.Response, error) { url := c.baseURL + path var bodyReader io.Reader if body != nil { data, err := json.Marshal(body) if err != nil { - return fmt.Errorf("forge: marshal body: %w", err) + return nil, fmt.Errorf("forge: marshal body: %w", err) } bodyReader = bytes.NewReader(data) } req, err := http.NewRequestWithContext(ctx, method, url, bodyReader) if err != nil { - return fmt.Errorf("forge: create request: %w", err) + return nil, fmt.Errorf("forge: create request: %w", err) } req.Header.Set("Authorization", "token "+c.token) @@ -210,31 +239,57 @@ func (c *Client) do(ctx context.Context, method, path string, body, out any) err resp, err := c.httpClient.Do(req) if err != nil { - return fmt.Errorf("forge: request %s %s: %w", method, path, err) + return nil, fmt.Errorf("forge: request %s %s: %w", method, path, err) } defer resp.Body.Close() + c.updateRateLimit(resp) + if resp.StatusCode >= 400 { - return c.parseError(resp, path) + return nil, c.parseError(resp, path) } if out != nil && resp.StatusCode != http.StatusNoContent { if err := json.NewDecoder(resp.Body).Decode(out); err != nil { - return fmt.Errorf("forge: decode response: %w", err) + return nil, fmt.Errorf("forge: decode response: %w", err) } } - return nil + return resp, nil } func (c *Client) parseError(resp *http.Response, path string) error { var errBody struct { Message string `json:"message"` } - _ = json.NewDecoder(resp.Body).Decode(&errBody) + + // Read a bit of the body to see if we can get a message + data, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + _ = json.Unmarshal(data, &errBody) + + msg := errBody.Message + if msg == "" && len(data) > 0 { + msg = string(data) + } + if msg == "" { + msg = http.StatusText(resp.StatusCode) + } + return &APIError{ StatusCode: resp.StatusCode, - Message: errBody.Message, + Message: msg, URL: path, } } + +func (c *Client) updateRateLimit(resp *http.Response) { + if limit := resp.Header.Get("X-RateLimit-Limit"); limit != "" { + c.rateLimit.Limit, _ = strconv.Atoi(limit) + } + if remaining := resp.Header.Get("X-RateLimit-Remaining"); remaining != "" { + c.rateLimit.Remaining, _ = strconv.Atoi(remaining) + } + if reset := resp.Header.Get("X-RateLimit-Reset"); reset != "" { + c.rateLimit.Reset, _ = strconv.ParseInt(reset, 10, 64) + } +} diff --git a/pagination.go b/pagination.go index d5f9cc6..b978694 100644 --- a/pagination.go +++ b/pagination.go @@ -2,7 +2,6 @@ package forge import ( "context" - "encoding/json" "fmt" "iter" "net/http" @@ -37,9 +36,9 @@ func ListPage[T any](ctx context.Context, c *Client, path string, query map[stri opts.Limit = 50 } - u, err := url.Parse(c.baseURL + path) + u, err := url.Parse(path) if err != nil { - return nil, fmt.Errorf("forge: parse url: %w", err) + return nil, fmt.Errorf("forge: parse path: %w", err) } q := u.Query() @@ -50,30 +49,10 @@ func ListPage[T any](ctx context.Context, c *Client, path string, query map[stri } u.RawQuery = q.Encode() - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) - if err != nil { - return nil, fmt.Errorf("forge: create request: %w", err) - } - - req.Header.Set("Authorization", "token "+c.token) - req.Header.Set("Accept", "application/json") - if c.userAgent != "" { - req.Header.Set("User-Agent", c.userAgent) - } - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("forge: request GET %s: %w", path, err) - } - defer resp.Body.Close() - - if resp.StatusCode >= 400 { - return nil, c.parseError(resp, path) - } - var items []T - if err := json.NewDecoder(resp.Body).Decode(&items); err != nil { - return nil, fmt.Errorf("forge: decode response: %w", err) + resp, err := c.doJSON(ctx, http.MethodGet, u.String(), nil, &items) + if err != nil { + return nil, err } totalCount, _ := strconv.Atoi(resp.Header.Get("X-Total-Count")) @@ -82,7 +61,10 @@ func ListPage[T any](ctx context.Context, c *Client, path string, query map[stri Items: items, TotalCount: totalCount, Page: opts.Page, - HasMore: len(items) >= opts.Limit && opts.Page*opts.Limit < totalCount, + // If totalCount is provided, use it to determine if there are more items. + // Otherwise, assume there are more if we got a full page. + HasMore: (totalCount > 0 && (opts.Page-1)*opts.Limit+len(items) < totalCount) || + (totalCount == 0 && len(items) >= opts.Limit), }, nil } @@ -97,7 +79,7 @@ func ListAll[T any](ctx context.Context, c *Client, path string, query map[strin return nil, err } all = append(all, result.Items...) - if len(result.Items) == 0 || len(all) >= result.TotalCount { + if !result.HasMore { break } page++ @@ -110,7 +92,6 @@ func ListAll[T any](ctx context.Context, c *Client, path string, query map[strin func ListIter[T any](ctx context.Context, c *Client, path string, query map[string]string) iter.Seq2[T, error] { return func(yield func(T, error) bool) { page := 1 - count := 0 for { result, err := ListPage[T](ctx, c, path, query, ListOptions{Page: page, Limit: 50}) if err != nil { @@ -121,9 +102,8 @@ func ListIter[T any](ctx context.Context, c *Client, path string, query map[stri if !yield(item, nil) { return } - count++ } - if len(result.Items) == 0 || count >= result.TotalCount { + if !result.HasMore { break } page++