diff --git a/action.go b/action.go index 6bff52f..74486a4 100644 --- a/action.go +++ b/action.go @@ -2,8 +2,11 @@ package dns import ( "context" + "encoding/json" "errors" "fmt" + "math" + "strings" ) const ( @@ -21,6 +24,13 @@ var ( errActionMissingValue = errors.New("dns action missing required value") ) +const ( + actionArgBind = "bind" + actionArgIP = "ip" + actionArgName = "name" + actionArgPort = "port" +) + type ActionDefinition struct { Name string Invoke func(map[string]any) (any, bool, error) @@ -221,7 +231,7 @@ func (service *Service) HandleActionContext(ctx context.Context, name string, va func (service *Service) handleResolveAddress(ctx context.Context, values map[string]any) (any, bool, error) { _ = ctx - host, err := stringActionValue(values, "name") + host, err := stringActionValue(values, actionArgName) if err != nil { return nil, false, err } @@ -234,7 +244,7 @@ func (service *Service) handleResolveAddress(ctx context.Context, values map[str func (service *Service) handleResolveTXTRecords(ctx context.Context, values map[string]any) (any, bool, error) { _ = ctx - host, err := stringActionValue(values, "name") + host, err := stringActionValue(values, actionArgName) if err != nil { return nil, false, err } @@ -247,7 +257,7 @@ func (service *Service) handleResolveTXTRecords(ctx context.Context, values map[ func (service *Service) handleResolveAll(ctx context.Context, values map[string]any) (any, bool, error) { _ = ctx - host, err := stringActionValue(values, "name") + host, err := stringActionValue(values, actionArgName) if err != nil { return nil, false, err } @@ -260,7 +270,7 @@ func (service *Service) handleResolveAll(ctx context.Context, values map[string] func (service *Service) handleReverseLookup(ctx context.Context, values map[string]any) (any, bool, error) { _ = ctx - ip, err := stringActionValue(values, "ip") + ip, err := stringActionValue(values, actionArgIP) if err != nil { return nil, false, err } @@ -273,8 +283,8 @@ func (service *Service) handleReverseLookup(ctx context.Context, values map[stri func (service *Service) handleServe(ctx context.Context, values map[string]any) (any, bool, error) { _ = ctx - bind, _ := stringActionValueOptional(values, "bind") - port, portProvided, err := intActionValueOptional(values, "port") + bind, _ := stringActionValueOptional(values, actionArgBind) + port, portProvided, err := intActionValueOptional(values, actionArgPort) if err != nil { return nil, false, err } @@ -297,6 +307,10 @@ func stringActionValue(values map[string]any, key string) (string, error) { return "", errActionMissingValue } if value, ok := raw.(string); ok { + value = strings.TrimSpace(value) + if value == "" { + return "", errActionMissingValue + } return value, nil } return "", errActionMissingValue @@ -314,7 +328,7 @@ func stringActionValueOptional(values map[string]any, key string) (string, error if !ok { return "", fmt.Errorf("%w: %s", errActionMissingValue, key) } - return value, nil + return strings.TrimSpace(value), nil } func intActionValue(values map[string]any, key string) (int, error) { @@ -328,14 +342,62 @@ func intActionValue(values map[string]any, key string) (int, error) { switch value := raw.(type) { case int: return value, nil + case uint: + return int(value), nil + case uint8: + return int(value), nil + case uint16: + return int(value), nil + case uint32: + return int(value), nil + case uint64: + if value > uint64(^uint(0)>>1) { + return 0, fmt.Errorf("%w: %s", errActionMissingValue, key) + } + return int(value), nil case int32: return int(value), nil case int64: + if value > int64(int(^uint(0)>>1)) || value < int64(^uint(0)>>1)*-1-1 { + return 0, fmt.Errorf("%w: %s", errActionMissingValue, key) + } return int(value), nil case float64: + if math.Trunc(value) != value { + return 0, fmt.Errorf("%w: %s", errActionMissingValue, key) + } + if value < 0 || value > float64(int(^uint(0)>>1)) { + return 0, fmt.Errorf("%w: %s", errActionMissingValue, key) + } return int(value), nil case float32: - return int(value), nil + floating := float64(value) + if math.Trunc(floating) != floating { + return 0, fmt.Errorf("%w: %s", errActionMissingValue, key) + } + if floating < 0 || floating > float64(int(^uint(0)>>1)) { + return 0, fmt.Errorf("%w: %s", errActionMissingValue, key) + } + return int(floating), nil + case json.Number: + parsed, err := value.Int64() + if err == nil { + if parsed > int64(int(^uint(0)>>1)) || parsed < int64(^uint(0)>>1)*-1-1 { + return 0, fmt.Errorf("%w: %s", errActionMissingValue, key) + } + return int(parsed), nil + } + floating, parseErr := value.Float64() + if parseErr != nil { + return 0, fmt.Errorf("%w: %s", errActionMissingValue, key) + } + if math.Trunc(floating) != floating { + return 0, fmt.Errorf("%w: %s", errActionMissingValue, key) + } + if floating < 0 || floating > float64(int(^uint(0)>>1)) { + return 0, fmt.Errorf("%w: %s", errActionMissingValue, key) + } + return int(floating), nil default: return 0, fmt.Errorf("%w: %s", errActionMissingValue, key) } diff --git a/service_test.go b/service_test.go index c61abb1..9f52e56 100644 --- a/service_test.go +++ b/service_test.go @@ -2499,6 +2499,48 @@ func TestServiceHandleActionContextPassesThroughToDiscover(t *testing.T) { } } +func TestStringActionValueTrimsWhitespaceForRequiredArgument(t *testing.T) { + value, err := stringActionValue(map[string]any{ + actionArgName: " gateway.charon.lthn ", + }, actionArgName) + if err != nil { + t.Fatalf("expected trimmed string value, got error: %v", err) + } + if value != "gateway.charon.lthn" { + t.Fatalf("expected trimmed value, got %q", value) + } +} + +func TestStringActionValueRejectsWhitespaceOnlyArgument(t *testing.T) { + _, err := stringActionValue(map[string]any{ + actionArgName: " ", + }, actionArgName) + if err == nil { + t.Fatal("expected whitespace-only argument to be rejected") + } +} + +func TestIntActionValueRejectsNonIntegerFloat(t *testing.T) { + _, err := intActionValue(map[string]any{ + actionArgPort: 53.9, + }, actionArgPort) + if err == nil { + t.Fatal("expected non-integer float value to be rejected") + } +} + +func TestIntActionValueAcceptsWholeFloat(t *testing.T) { + value, err := intActionValue(map[string]any{ + actionArgPort: float64(53), + }, actionArgPort) + if err != nil { + t.Fatalf("expected whole float to be accepted: %v", err) + } + if value != 53 { + t.Fatalf("expected value 53, got %d", value) + } +} + type actionRecorder struct { names []string handlers map[string]func(map[string]any) (any, bool, error)