feat(dns): tighten action argument parsing
Co-Authored-By: Virgil <virgil@lethean.io>
This commit is contained in:
parent
8e87a2c7be
commit
1195dbb596
2 changed files with 112 additions and 8 deletions
78
action.go
78
action.go
|
|
@ -2,8 +2,11 @@ package dns
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -21,6 +24,13 @@ var (
|
|||
errActionMissingValue = errors.New("dns action missing required value")
|
||||
)
|
||||
|
||||
const (
|
||||
actionArgBind = "bind"
|
||||
actionArgIP = "ip"
|
||||
actionArgName = "name"
|
||||
actionArgPort = "port"
|
||||
)
|
||||
|
||||
type ActionDefinition struct {
|
||||
Name string
|
||||
Invoke func(map[string]any) (any, bool, error)
|
||||
|
|
@ -221,7 +231,7 @@ func (service *Service) HandleActionContext(ctx context.Context, name string, va
|
|||
|
||||
func (service *Service) handleResolveAddress(ctx context.Context, values map[string]any) (any, bool, error) {
|
||||
_ = ctx
|
||||
host, err := stringActionValue(values, "name")
|
||||
host, err := stringActionValue(values, actionArgName)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
|
@ -234,7 +244,7 @@ func (service *Service) handleResolveAddress(ctx context.Context, values map[str
|
|||
|
||||
func (service *Service) handleResolveTXTRecords(ctx context.Context, values map[string]any) (any, bool, error) {
|
||||
_ = ctx
|
||||
host, err := stringActionValue(values, "name")
|
||||
host, err := stringActionValue(values, actionArgName)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
|
@ -247,7 +257,7 @@ func (service *Service) handleResolveTXTRecords(ctx context.Context, values map[
|
|||
|
||||
func (service *Service) handleResolveAll(ctx context.Context, values map[string]any) (any, bool, error) {
|
||||
_ = ctx
|
||||
host, err := stringActionValue(values, "name")
|
||||
host, err := stringActionValue(values, actionArgName)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
|
@ -260,7 +270,7 @@ func (service *Service) handleResolveAll(ctx context.Context, values map[string]
|
|||
|
||||
func (service *Service) handleReverseLookup(ctx context.Context, values map[string]any) (any, bool, error) {
|
||||
_ = ctx
|
||||
ip, err := stringActionValue(values, "ip")
|
||||
ip, err := stringActionValue(values, actionArgIP)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
|
@ -273,8 +283,8 @@ func (service *Service) handleReverseLookup(ctx context.Context, values map[stri
|
|||
|
||||
func (service *Service) handleServe(ctx context.Context, values map[string]any) (any, bool, error) {
|
||||
_ = ctx
|
||||
bind, _ := stringActionValueOptional(values, "bind")
|
||||
port, portProvided, err := intActionValueOptional(values, "port")
|
||||
bind, _ := stringActionValueOptional(values, actionArgBind)
|
||||
port, portProvided, err := intActionValueOptional(values, actionArgPort)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
|
@ -297,6 +307,10 @@ func stringActionValue(values map[string]any, key string) (string, error) {
|
|||
return "", errActionMissingValue
|
||||
}
|
||||
if value, ok := raw.(string); ok {
|
||||
value = strings.TrimSpace(value)
|
||||
if value == "" {
|
||||
return "", errActionMissingValue
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
return "", errActionMissingValue
|
||||
|
|
@ -314,7 +328,7 @@ func stringActionValueOptional(values map[string]any, key string) (string, error
|
|||
if !ok {
|
||||
return "", fmt.Errorf("%w: %s", errActionMissingValue, key)
|
||||
}
|
||||
return value, nil
|
||||
return strings.TrimSpace(value), nil
|
||||
}
|
||||
|
||||
func intActionValue(values map[string]any, key string) (int, error) {
|
||||
|
|
@ -328,14 +342,62 @@ func intActionValue(values map[string]any, key string) (int, error) {
|
|||
switch value := raw.(type) {
|
||||
case int:
|
||||
return value, nil
|
||||
case uint:
|
||||
return int(value), nil
|
||||
case uint8:
|
||||
return int(value), nil
|
||||
case uint16:
|
||||
return int(value), nil
|
||||
case uint32:
|
||||
return int(value), nil
|
||||
case uint64:
|
||||
if value > uint64(^uint(0)>>1) {
|
||||
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
|
||||
}
|
||||
return int(value), nil
|
||||
case int32:
|
||||
return int(value), nil
|
||||
case int64:
|
||||
if value > int64(int(^uint(0)>>1)) || value < int64(^uint(0)>>1)*-1-1 {
|
||||
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
|
||||
}
|
||||
return int(value), nil
|
||||
case float64:
|
||||
if math.Trunc(value) != value {
|
||||
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
|
||||
}
|
||||
if value < 0 || value > float64(int(^uint(0)>>1)) {
|
||||
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
|
||||
}
|
||||
return int(value), nil
|
||||
case float32:
|
||||
return int(value), nil
|
||||
floating := float64(value)
|
||||
if math.Trunc(floating) != floating {
|
||||
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
|
||||
}
|
||||
if floating < 0 || floating > float64(int(^uint(0)>>1)) {
|
||||
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
|
||||
}
|
||||
return int(floating), nil
|
||||
case json.Number:
|
||||
parsed, err := value.Int64()
|
||||
if err == nil {
|
||||
if parsed > int64(int(^uint(0)>>1)) || parsed < int64(^uint(0)>>1)*-1-1 {
|
||||
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
|
||||
}
|
||||
return int(parsed), nil
|
||||
}
|
||||
floating, parseErr := value.Float64()
|
||||
if parseErr != nil {
|
||||
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
|
||||
}
|
||||
if math.Trunc(floating) != floating {
|
||||
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
|
||||
}
|
||||
if floating < 0 || floating > float64(int(^uint(0)>>1)) {
|
||||
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
|
||||
}
|
||||
return int(floating), nil
|
||||
default:
|
||||
return 0, fmt.Errorf("%w: %s", errActionMissingValue, key)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2499,6 +2499,48 @@ func TestServiceHandleActionContextPassesThroughToDiscover(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestStringActionValueTrimsWhitespaceForRequiredArgument(t *testing.T) {
|
||||
value, err := stringActionValue(map[string]any{
|
||||
actionArgName: " gateway.charon.lthn ",
|
||||
}, actionArgName)
|
||||
if err != nil {
|
||||
t.Fatalf("expected trimmed string value, got error: %v", err)
|
||||
}
|
||||
if value != "gateway.charon.lthn" {
|
||||
t.Fatalf("expected trimmed value, got %q", value)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringActionValueRejectsWhitespaceOnlyArgument(t *testing.T) {
|
||||
_, err := stringActionValue(map[string]any{
|
||||
actionArgName: " ",
|
||||
}, actionArgName)
|
||||
if err == nil {
|
||||
t.Fatal("expected whitespace-only argument to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntActionValueRejectsNonIntegerFloat(t *testing.T) {
|
||||
_, err := intActionValue(map[string]any{
|
||||
actionArgPort: 53.9,
|
||||
}, actionArgPort)
|
||||
if err == nil {
|
||||
t.Fatal("expected non-integer float value to be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntActionValueAcceptsWholeFloat(t *testing.T) {
|
||||
value, err := intActionValue(map[string]any{
|
||||
actionArgPort: float64(53),
|
||||
}, actionArgPort)
|
||||
if err != nil {
|
||||
t.Fatalf("expected whole float to be accepted: %v", err)
|
||||
}
|
||||
if value != 53 {
|
||||
t.Fatalf("expected value 53, got %d", value)
|
||||
}
|
||||
}
|
||||
|
||||
type actionRecorder struct {
|
||||
names []string
|
||||
handlers map[string]func(map[string]any) (any, bool, error)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue