feat(dns): tighten action argument parsing

Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
Virgil 2026-04-03 23:00:07 +00:00
parent 8e87a2c7be
commit 1195dbb596
2 changed files with 112 additions and 8 deletions

View file

@ -2,8 +2,11 @@ package dns
import (
"context"
"encoding/json"
"errors"
"fmt"
"math"
"strings"
)
const (
@ -21,6 +24,13 @@ var (
errActionMissingValue = errors.New("dns action missing required value")
)
const (
actionArgBind = "bind"
actionArgIP = "ip"
actionArgName = "name"
actionArgPort = "port"
)
type ActionDefinition struct {
Name string
Invoke func(map[string]any) (any, bool, error)
@ -221,7 +231,7 @@ func (service *Service) HandleActionContext(ctx context.Context, name string, va
func (service *Service) handleResolveAddress(ctx context.Context, values map[string]any) (any, bool, error) {
_ = ctx
host, err := stringActionValue(values, "name")
host, err := stringActionValue(values, actionArgName)
if err != nil {
return nil, false, err
}
@ -234,7 +244,7 @@ func (service *Service) handleResolveAddress(ctx context.Context, values map[str
func (service *Service) handleResolveTXTRecords(ctx context.Context, values map[string]any) (any, bool, error) {
_ = ctx
host, err := stringActionValue(values, "name")
host, err := stringActionValue(values, actionArgName)
if err != nil {
return nil, false, err
}
@ -247,7 +257,7 @@ func (service *Service) handleResolveTXTRecords(ctx context.Context, values map[
func (service *Service) handleResolveAll(ctx context.Context, values map[string]any) (any, bool, error) {
_ = ctx
host, err := stringActionValue(values, "name")
host, err := stringActionValue(values, actionArgName)
if err != nil {
return nil, false, err
}
@ -260,7 +270,7 @@ func (service *Service) handleResolveAll(ctx context.Context, values map[string]
func (service *Service) handleReverseLookup(ctx context.Context, values map[string]any) (any, bool, error) {
_ = ctx
ip, err := stringActionValue(values, "ip")
ip, err := stringActionValue(values, actionArgIP)
if err != nil {
return nil, false, err
}
@ -273,8 +283,8 @@ func (service *Service) handleReverseLookup(ctx context.Context, values map[stri
func (service *Service) handleServe(ctx context.Context, values map[string]any) (any, bool, error) {
_ = ctx
bind, _ := stringActionValueOptional(values, "bind")
port, portProvided, err := intActionValueOptional(values, "port")
bind, _ := stringActionValueOptional(values, actionArgBind)
port, portProvided, err := intActionValueOptional(values, actionArgPort)
if err != nil {
return nil, false, err
}
@ -297,6 +307,10 @@ func stringActionValue(values map[string]any, key string) (string, error) {
return "", errActionMissingValue
}
if value, ok := raw.(string); ok {
value = strings.TrimSpace(value)
if value == "" {
return "", errActionMissingValue
}
return value, nil
}
return "", errActionMissingValue
@ -314,7 +328,7 @@ func stringActionValueOptional(values map[string]any, key string) (string, error
if !ok {
return "", fmt.Errorf("%w: %s", errActionMissingValue, key)
}
return value, nil
return strings.TrimSpace(value), nil
}
func intActionValue(values map[string]any, key string) (int, error) {
@ -328,14 +342,62 @@ func intActionValue(values map[string]any, key string) (int, error) {
switch value := raw.(type) {
case int:
return value, nil
case uint:
return int(value), nil
case uint8:
return int(value), nil
case uint16:
return int(value), nil
case uint32:
return int(value), nil
case uint64:
if value > uint64(^uint(0)>>1) {
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
}
return int(value), nil
case int32:
return int(value), nil
case int64:
if value > int64(int(^uint(0)>>1)) || value < int64(^uint(0)>>1)*-1-1 {
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
}
return int(value), nil
case float64:
if math.Trunc(value) != value {
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
}
if value < 0 || value > float64(int(^uint(0)>>1)) {
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
}
return int(value), nil
case float32:
return int(value), nil
floating := float64(value)
if math.Trunc(floating) != floating {
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
}
if floating < 0 || floating > float64(int(^uint(0)>>1)) {
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
}
return int(floating), nil
case json.Number:
parsed, err := value.Int64()
if err == nil {
if parsed > int64(int(^uint(0)>>1)) || parsed < int64(^uint(0)>>1)*-1-1 {
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
}
return int(parsed), nil
}
floating, parseErr := value.Float64()
if parseErr != nil {
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
}
if math.Trunc(floating) != floating {
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
}
if floating < 0 || floating > float64(int(^uint(0)>>1)) {
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
}
return int(floating), nil
default:
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
}

View file

@ -2499,6 +2499,48 @@ func TestServiceHandleActionContextPassesThroughToDiscover(t *testing.T) {
}
}
func TestStringActionValueTrimsWhitespaceForRequiredArgument(t *testing.T) {
value, err := stringActionValue(map[string]any{
actionArgName: " gateway.charon.lthn ",
}, actionArgName)
if err != nil {
t.Fatalf("expected trimmed string value, got error: %v", err)
}
if value != "gateway.charon.lthn" {
t.Fatalf("expected trimmed value, got %q", value)
}
}
func TestStringActionValueRejectsWhitespaceOnlyArgument(t *testing.T) {
_, err := stringActionValue(map[string]any{
actionArgName: " ",
}, actionArgName)
if err == nil {
t.Fatal("expected whitespace-only argument to be rejected")
}
}
func TestIntActionValueRejectsNonIntegerFloat(t *testing.T) {
_, err := intActionValue(map[string]any{
actionArgPort: 53.9,
}, actionArgPort)
if err == nil {
t.Fatal("expected non-integer float value to be rejected")
}
}
func TestIntActionValueAcceptsWholeFloat(t *testing.T) {
value, err := intActionValue(map[string]any{
actionArgPort: float64(53),
}, actionArgPort)
if err != nil {
t.Fatalf("expected whole float to be accepted: %v", err)
}
if value != 53 {
t.Fatalf("expected value 53, got %d", value)
}
}
type actionRecorder struct {
names []string
handlers map[string]func(map[string]any) (any, bool, error)