From 600f555d20cbaf64674b12e7ac2accb48e5f6c4e Mon Sep 17 00:00:00 2001 From: Corneliu Radu <262nos@gmail.com> Date: Wed, 17 Apr 2024 18:39:49 +0300 Subject: [PATCH] fixes panic/stackoverflow when unparseable values passed to coerceNumeric (#43) --- evaluator/modifiers/modifiers.go | 15 +++++++++- evaluator/modifiers/modifiers_test.go | 43 +++++++++++++++++---------- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/evaluator/modifiers/modifiers.go b/evaluator/modifiers/modifiers.go index 27f1bfa..58be472 100644 --- a/evaluator/modifiers/modifiers.go +++ b/evaluator/modifiers/modifiers.go @@ -3,11 +3,12 @@ package modifiers import ( "encoding/base64" "fmt" - "gopkg.in/yaml.v3" "net" "reflect" "regexp" "strings" + + "gopkg.in/yaml.v3" ) func GetComparator(modifiers ...string) (ComparatorFunc, error) { @@ -239,6 +240,10 @@ func coerceString(v interface{}) string { // coerceNumeric makes both operands into the widest possible number of the same type func coerceNumeric(left, right interface{}) (interface{}, interface{}, error) { + // Check for nil interface, otherwise the function panics + if left == nil || right == nil { + return nil, nil, fmt.Errorf("cannot coerce %T and %T to numeric", left, right) + } leftV := reflect.ValueOf(left) leftType := reflect.ValueOf(left).Type() rightV := reflect.ValueOf(right) @@ -265,12 +270,20 @@ func coerceNumeric(left, right interface{}) (interface{}, interface{}, error) { if err := yaml.Unmarshal([]byte(left.(string)), &leftParsed); err != nil { return nil, nil, err } + //Check the parsed type is the correct one, otherwise we get a stack overflow + if reflect.TypeOf(leftParsed).Kind() != reflect.Float64 && reflect.TypeOf(leftParsed).Kind() != reflect.Int { + return nil, nil, fmt.Errorf("cannot coerce %T and %T to numeric", left, right) + } return coerceNumeric(leftParsed, right) case rightType.Kind() == reflect.String: var rightParsed interface{} if err := yaml.Unmarshal([]byte(right.(string)), &rightParsed); err != nil { return nil, nil, err } + //Check the parsed type is the correct one, otherwise we get a stack overflow + if reflect.TypeOf(rightParsed).Kind() != reflect.Float64 && reflect.TypeOf(rightParsed).Kind() != reflect.Int { + return nil, nil, fmt.Errorf("cannot coerce %T and %T to numeric", left, right) + } return coerceNumeric(left, rightParsed) default: diff --git a/evaluator/modifiers/modifiers_test.go b/evaluator/modifiers/modifiers_test.go index 9e041d6..2070e35 100644 --- a/evaluator/modifiers/modifiers_test.go +++ b/evaluator/modifiers/modifiers_test.go @@ -8,27 +8,40 @@ import ( func Test_compareNumeric(t *testing.T) { tests := []struct { - left interface{} - right interface{} - wantGt bool - wantGte bool - wantLt bool - wantLte bool + left interface{} + right interface{} + wantGt bool + wantGte bool + wantLt bool + wantLte bool + shouldFail bool }{ - {1, 2, false, false, true, true}, - {1.1, 1.2, false, false, true, true}, - {1, 1.2, false, false, true, true}, - {1.1, 2, false, false, true, true}, - {1, "2", false, false, true, true}, - {"1.1", 1.2, false, false, true, true}, - {"1.1", 1.1, false, true, false, true}, + {1, 2, false, false, true, true, false}, + {1.1, 1.2, false, false, true, true, false}, + {1, 1.2, false, false, true, true, false}, + {1.1, 2, false, false, true, true, false}, + {1, "2", false, false, true, true, false}, + {"1.1", 1.2, false, false, true, true, false}, + {"1.1", 1.1, false, true, false, true, false}, + + // The function panics if it's interfaces are nil, this happens if it doesn't find the field in the event and it's compared to a int or float + {nil, 2, true, false, false, false, true}, + {nil, nil, true, false, false, false, true}, + {2, nil, true, false, false, false, true}, + // If we pass anything (like an ip address) other than an int or float, the functions recurses until it stack overflows + {"127.0.0.1", "127.0.0.1", true, false, false, false, true}, + {"127.0.0.1", 0.2, true, false, false, false, true}, } for _, tt := range tests { t.Run(fmt.Sprintf("%s_%s", tt.left, tt.right), func(t *testing.T) { gotGt, gotGte, gotLt, gotLte, err := compareNumeric(tt.left, tt.right) if err != nil { - t.Errorf("compareNumeric() error = %v", err) - return + if !tt.shouldFail { + t.Errorf("compareNumeric() error = %v", err) + return + } else { + return + } } if gotGt != tt.wantGt { t.Errorf("compareNumeric() gotGt = %v, want %v", gotGt, tt.wantGt)