diff --git a/cache.go b/cache.go index 030fe1b..31a08d2 100644 --- a/cache.go +++ b/cache.go @@ -4,8 +4,10 @@ package cache import ( "encoding/json" "errors" + "fmt" "os" "path/filepath" + "strings" "time" "forge.lthn.ai/core/go-io" @@ -62,13 +64,33 @@ func New(medium io.Medium, baseDir string, ttl time.Duration) (*Cache, error) { } // Path returns the full path for a cache key. -func (c *Cache) Path(key string) string { - return filepath.Join(c.baseDir, key+".json") +// Returns an error if the key attempts path traversal. +func (c *Cache) Path(key string) (string, error) { + path := filepath.Join(c.baseDir, key+".json") + + // Ensure the resulting path is still within baseDir to prevent traversal attacks + absBase, err := filepath.Abs(c.baseDir) + if err != nil { + return "", fmt.Errorf("failed to get absolute path for baseDir: %w", err) + } + absPath, err := filepath.Abs(path) + if err != nil { + return "", fmt.Errorf("failed to get absolute path for key: %w", err) + } + + if !strings.HasPrefix(absPath, absBase) { + return "", fmt.Errorf("invalid cache key: path traversal attempt") + } + + return path, nil } // Get retrieves a cached item if it exists and hasn't expired. func (c *Cache) Get(key string, dest any) (bool, error) { - path := c.Path(key) + path, err := c.Path(key) + if err != nil { + return false, err + } dataStr, err := c.medium.Read(path) if err != nil { @@ -99,7 +121,10 @@ func (c *Cache) Get(key string, dest any) (bool, error) { // Set stores an item in the cache. func (c *Cache) Set(key string, data any) error { - path := c.Path(key) + path, err := c.Path(key) + if err != nil { + return err + } // Ensure parent directory exists if err := c.medium.EnsureDir(filepath.Dir(path)); err != nil { @@ -128,8 +153,12 @@ func (c *Cache) Set(key string, data any) error { // Delete removes an item from the cache. func (c *Cache) Delete(key string) error { - path := c.Path(key) - err := c.medium.Delete(path) + path, err := c.Path(key) + if err != nil { + return err + } + + err = c.medium.Delete(path) if errors.Is(err, os.ErrNotExist) { return nil } @@ -143,7 +172,10 @@ func (c *Cache) Clear() error { // Age returns how old a cached item is, or -1 if not cached. func (c *Cache) Age(key string) time.Duration { - path := c.Path(key) + path, err := c.Path(key) + if err != nil { + return -1 + } dataStr, err := c.medium.Read(path) if err != nil {