From 62deb32932f9a1e025a6e304eaea14a67a3231ac Mon Sep 17 00:00:00 2001 From: darshanime Date: Sat, 23 Aug 2025 16:13:38 +0530 Subject: [PATCH] add constant folding Signed-off-by: darshanime --- promql/engine.go | 4 + promql/engine_test.go | 55 --------- promql/folding.go | 118 +++++++++++++++++++ promql/folding_test.go | 259 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 381 insertions(+), 55 deletions(-) create mode 100644 promql/folding.go create mode 100644 promql/folding_test.go diff --git a/promql/engine.go b/promql/engine.go index d476e28cf4..130792f750 100644 --- a/promql/engine.go +++ b/promql/engine.go @@ -3733,6 +3733,10 @@ func PreprocessExpr(expr parser.Expr, start, end time.Time, step time.Duration) return nil, err } + expr, err := ConstantFoldExpr(expr) + if err != nil { + return nil, err + } _, shouldWrap := preprocessExprHelper(expr, start, end) if shouldWrap { return newStepInvariantExpr(expr), nil diff --git a/promql/engine_test.go b/promql/engine_test.go index 536f4cac62..c0dd37a458 100644 --- a/promql/engine_test.go +++ b/promql/engine_test.go @@ -3029,61 +3029,6 @@ func TestPreprocessAndWrapWithStepInvariantExpr(t *testing.T) { }, }, }, - { - input: `floor(some_metric / (3 * 1024))`, - outputTest: true, - expected: &parser.Call{ - Func: &parser.Function{ - Name: "floor", - ArgTypes: []parser.ValueType{parser.ValueTypeVector}, - ReturnType: parser.ValueTypeVector, - }, - Args: parser.Expressions{ - &parser.BinaryExpr{ - Op: parser.DIV, - LHS: &parser.VectorSelector{ - Name: "some_metric", - LabelMatchers: []*labels.Matcher{ - parser.MustLabelMatcher(labels.MatchEqual, "__name__", "some_metric"), - }, - PosRange: posrange.PositionRange{ - Start: 6, - End: 17, - }, - }, - RHS: &parser.StepInvariantExpr{ - Expr: &parser.ParenExpr{ - Expr: &parser.BinaryExpr{ - Op: parser.MUL, - LHS: &parser.NumberLiteral{ - Val: 3, - PosRange: posrange.PositionRange{ - Start: 21, - End: 22, - }, - }, - RHS: &parser.NumberLiteral{ - Val: 1024, - PosRange: posrange.PositionRange{ - Start: 25, - End: 29, - }, - }, - }, - PosRange: posrange.PositionRange{ - Start: 20, - End: 30, - }, - }, - }, - }, - }, - PosRange: posrange.PositionRange{ - Start: 0, - End: 31, - }, - }, - }, } for _, test := range testCases { diff --git a/promql/folding.go b/promql/folding.go new file mode 100644 index 0000000000..c38235a09e --- /dev/null +++ b/promql/folding.go @@ -0,0 +1,118 @@ +// Copyright 2025 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package promql + +import ( + "math" + + "github.com/prometheus/prometheus/promql/parser" +) + +func ConstantFoldExpr(expr parser.Expr) (parser.Expr, error) { + var err error + switch n := expr.(type) { + case *parser.BinaryExpr: + lhs, err := ConstantFoldExpr(n.LHS) + if err != nil { + return expr, err + } + rhs, err := ConstantFoldExpr(n.RHS) + if err != nil { + return expr, err + } + n.LHS, n.RHS = lhs, rhs + + unwrapParenExpr(&lhs) + unwrapParenExpr(&rhs) + + if lhs.Type() == parser.ValueTypeScalar && rhs.Type() == parser.ValueTypeScalar { + lhsVal, okLHS := lhs.(*parser.NumberLiteral) + rhsVal, okRHS := rhs.(*parser.NumberLiteral) + if okLHS && okRHS { + val := scalarBinop(n.Op, lhsVal.Val, rhsVal.Val) + return &parser.NumberLiteral{ + Val: val, + PosRange: n.PositionRange(), + }, nil + } + } + return n, nil + + case *parser.Call: + if n.Func.Name == "pi" { + return &parser.NumberLiteral{ + Val: math.Pi, + PosRange: n.PositionRange(), + }, nil + } + for i := range n.Args { + n.Args[i], err = ConstantFoldExpr(n.Args[i]) + if err != nil { + return expr, err + } + } + return n, nil + + case *parser.MatrixSelector: + n.VectorSelector, err = ConstantFoldExpr(n.VectorSelector) + if err != nil { + return expr, err + } + return n, nil + + case *parser.AggregateExpr: + n.Expr, err = ConstantFoldExpr(n.Expr) + if err != nil { + return expr, err + } + return n, nil + + case *parser.SubqueryExpr: + n.Expr, err = ConstantFoldExpr(n.Expr) + if err != nil { + return expr, err + } + return n, nil + + case *parser.ParenExpr: + n.Expr, err = ConstantFoldExpr(n.Expr) + if err != nil { + return expr, err + } + return n, nil + + case *parser.UnaryExpr: + n.Expr, err = ConstantFoldExpr(n.Expr) + if err != nil { + return expr, err + } + + unwrapParenExpr(&n.Expr) + if n.Expr.Type() == parser.ValueTypeScalar { + if val, ok := n.Expr.(*parser.NumberLiteral); ok { + if n.Op == parser.SUB { + val.Val = -val.Val + } + return &parser.NumberLiteral{ + Val: val.Val, + PosRange: n.PositionRange(), + }, nil + } + } + return n, nil + + default: + return n, nil + } +} diff --git a/promql/folding_test.go b/promql/folding_test.go new file mode 100644 index 0000000000..978bee071a --- /dev/null +++ b/promql/folding_test.go @@ -0,0 +1,259 @@ +package promql + +import ( + "math" + "testing" + + "github.com/prometheus/prometheus/promql/parser" + "github.com/prometheus/prometheus/promql/parser/posrange" + "github.com/stretchr/testify/require" +) + +var testExpr = []struct { + input string + expected parser.Expr +}{ + { + input: "1", + expected: &parser.NumberLiteral{ + Val: 1, + PosRange: posrange.PositionRange{Start: 0, End: 1}, + }, + }, + { + input: "+Inf", + expected: &parser.NumberLiteral{ + Val: math.Inf(1), + PosRange: posrange.PositionRange{Start: 0, End: 4}, + }, + }, + { + input: "123.4567", + expected: &parser.NumberLiteral{ + Val: 123.4567, + PosRange: posrange.PositionRange{Start: 0, End: 8}, + }, + }, + { + input: "5e-3", + expected: &parser.NumberLiteral{ + Val: 0.005, + PosRange: posrange.PositionRange{Start: 0, End: 4}, + }, + }, + { + input: "1 + 1", + expected: &parser.NumberLiteral{ + Val: 2, + PosRange: posrange.PositionRange{Start: 0, End: 5}, + }, + }, + { + input: "pi()", + expected: &parser.NumberLiteral{ + Val: math.Pi, + PosRange: posrange.PositionRange{Start: 0, End: 4}, + }, + }, + { + input: "1 - 1", + expected: &parser.NumberLiteral{ + Val: 0, + PosRange: posrange.PositionRange{Start: 0, End: 5}, + }, + }, + { + input: "1 * 1", + expected: &parser.NumberLiteral{ + Val: 1, + PosRange: posrange.PositionRange{Start: 0, End: 5}, + }, + }, + { + input: "1 % 1", + expected: &parser.NumberLiteral{ + Val: 0, + PosRange: posrange.PositionRange{Start: 0, End: 5}, + }, + }, + { + input: "1 / 1", + expected: &parser.NumberLiteral{ + Val: 1, + PosRange: posrange.PositionRange{Start: 0, End: 5}, + }, + }, + { + input: "1 == bool 1", + expected: &parser.NumberLiteral{ + Val: 1, + PosRange: posrange.PositionRange{Start: 0, End: 11}, + }, + }, + { + input: "1 != bool 1", + expected: &parser.NumberLiteral{ + Val: 0, + PosRange: posrange.PositionRange{Start: 0, End: 11}, + }, + }, + { + input: "1 > bool 1", + expected: &parser.NumberLiteral{ + Val: 0, + PosRange: posrange.PositionRange{Start: 0, End: 10}, + }, + }, + { + input: "1 >= bool 1", + expected: &parser.NumberLiteral{ + Val: 1, + PosRange: posrange.PositionRange{Start: 0, End: 11}, + }, + }, + { + input: "1 < bool 1", + expected: &parser.NumberLiteral{ + Val: 0, + PosRange: posrange.PositionRange{Start: 0, End: 10}, + }, + }, + { + input: "1 <= bool 1", + expected: &parser.NumberLiteral{ + Val: 1, + PosRange: posrange.PositionRange{Start: 0, End: 11}, + }, + }, + { + input: "(-1)^2", + expected: &parser.NumberLiteral{ + Val: 1, + PosRange: posrange.PositionRange{Start: 0, End: 6}, + }, + }, + { + input: "-1*2", + expected: &parser.NumberLiteral{ + Val: -2, + PosRange: posrange.PositionRange{Start: 0, End: 4}, + }, + }, + { + input: "-1+2", + expected: &parser.NumberLiteral{ + Val: 1, + PosRange: posrange.PositionRange{Start: 0, End: 4}, + }, + }, + { + input: "(-1)^-2", + expected: &parser.NumberLiteral{ + Val: 1, + PosRange: posrange.PositionRange{Start: 0, End: 7}, + }, + }, + { + input: "+1 + -2 * 1", + expected: &parser.NumberLiteral{ + Val: -1, + PosRange: posrange.PositionRange{Start: 0, End: 11}, + }, + }, + { + input: "1 + 2/(3*1)", + expected: &parser.NumberLiteral{ + Val: 1.6666666666666665, + PosRange: posrange.PositionRange{Start: 0, End: 11}, + }, + }, + { + input: "1 < bool 2 - 1 * 2", + expected: &parser.NumberLiteral{ + Val: 0, + PosRange: posrange.PositionRange{Start: 0, End: 18}, + }, + }, + { + input: "((1+2)*(3-1))/(2+2)", + expected: &parser.NumberLiteral{ + Val: 1.5, + PosRange: posrange.PositionRange{Start: 0, End: 19}, + }, + }, + { + input: "(((-1)^2) + (2 % 3) * (4 - 1))", + expected: &parser.ParenExpr{ + Expr: &parser.NumberLiteral{ + Val: 7, + PosRange: posrange.PositionRange{Start: 1, End: 29}, + }, + PosRange: posrange.PositionRange{Start: 0, End: 30}, + }, + }, + { + input: "-(1 + 2) * +((3 - 5) ^ 2)", + expected: &parser.NumberLiteral{ + Val: -12, + PosRange: posrange.PositionRange{Start: 0, End: 24}, + }, + }, + { + input: "((1+1) == bool (2))", + expected: &parser.ParenExpr{ + Expr: &parser.NumberLiteral{ + Val: 1, + PosRange: posrange.PositionRange{Start: 1, End: 18}, + }, + PosRange: posrange.PositionRange{Start: 0, End: 19}, + }, + }, + { + input: "((1+2) <= bool (2-1))", + expected: &parser.ParenExpr{ + Expr: &parser.NumberLiteral{ + Val: 0, + PosRange: posrange.PositionRange{Start: 1, End: 20}, + }, + PosRange: posrange.PositionRange{Start: 0, End: 21}, + }, + }, + { + input: "1 + 2/(3*1) + (4-2)*(7%5)", + expected: &parser.NumberLiteral{ + Val: 5.666666666666666, + PosRange: posrange.PositionRange{Start: 0, End: 25}, + }, + }, + { + input: "pi() * (1 + 1) - (3 - 3)", + expected: &parser.NumberLiteral{ + Val: 2 * math.Pi, + PosRange: posrange.PositionRange{Start: 0, End: 24}, + }, + }, + { + input: "((1.5 + 2.25) * 2) / (7 - 3)", + expected: &parser.NumberLiteral{ + Val: ((1.5 + 2.25) * 2) / 4, + PosRange: posrange.PositionRange{Start: 0, End: 28}, + }, + }, + { + input: "((1 < bool 2) + (3 == bool 3)) * (4 % 3)", + expected: &parser.NumberLiteral{ + Val: 2, + PosRange: posrange.PositionRange{Start: 0, End: 40}, + }, + }, +} + +func TestConstantFolding(t *testing.T) { + for _, test := range testExpr { + expr, err := parser.ParseExpr(test.input) + require.NoError(t, err, "cannot parse original query") + expr, err = ConstantFoldExpr(expr) + require.NoError(t, err, "cannot fold query") + require.Equal(t, test.expected, expr, "error on input '%s'", test.input) + } +}