security: prevent path traversal in cache key resolution

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Snider 2026-03-09 08:28:55 +00:00
parent dbf7719032
commit 32ede3b495

View file

@ -4,8 +4,10 @@ package cache
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"forge.lthn.ai/core/go-io"
@ -62,13 +64,33 @@ func New(medium io.Medium, baseDir string, ttl time.Duration) (*Cache, error) {
}
// Path returns the full path for a cache key.
func (c *Cache) Path(key string) string {
return filepath.Join(c.baseDir, key+".json")
// Returns an error if the key attempts path traversal.
func (c *Cache) Path(key string) (string, error) {
path := filepath.Join(c.baseDir, key+".json")
// Ensure the resulting path is still within baseDir to prevent traversal attacks
absBase, err := filepath.Abs(c.baseDir)
if err != nil {
return "", fmt.Errorf("failed to get absolute path for baseDir: %w", err)
}
absPath, err := filepath.Abs(path)
if err != nil {
return "", fmt.Errorf("failed to get absolute path for key: %w", err)
}
if !strings.HasPrefix(absPath, absBase) {
return "", fmt.Errorf("invalid cache key: path traversal attempt")
}
return path, nil
}
// Get retrieves a cached item if it exists and hasn't expired.
func (c *Cache) Get(key string, dest any) (bool, error) {
path := c.Path(key)
path, err := c.Path(key)
if err != nil {
return false, err
}
dataStr, err := c.medium.Read(path)
if err != nil {
@ -99,7 +121,10 @@ func (c *Cache) Get(key string, dest any) (bool, error) {
// Set stores an item in the cache.
func (c *Cache) Set(key string, data any) error {
path := c.Path(key)
path, err := c.Path(key)
if err != nil {
return err
}
// Ensure parent directory exists
if err := c.medium.EnsureDir(filepath.Dir(path)); err != nil {
@ -128,8 +153,12 @@ func (c *Cache) Set(key string, data any) error {
// Delete removes an item from the cache.
func (c *Cache) Delete(key string) error {
path := c.Path(key)
err := c.medium.Delete(path)
path, err := c.Path(key)
if err != nil {
return err
}
err = c.medium.Delete(path)
if errors.Is(err, os.ErrNotExist) {
return nil
}
@ -143,7 +172,10 @@ func (c *Cache) Clear() error {
// Age returns how old a cached item is, or -1 if not cached.
func (c *Cache) Age(key string) time.Duration {
path := c.Path(key)
path, err := c.Path(key)
if err != nil {
return -1
}
dataStr, err := c.medium.Read(path)
if err != nil {