Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,86 @@
package org.apache.doris.nereids.trees.expressions.functions;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.coercion.Int32OrLessType;
import org.apache.doris.qe.ConnectContext;

import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List;

/**
* if argument 0 is float or double, we should return double signature for round like function.
* Signature search for round-like functions (round, round_bankers, ceil, floor, truncate).
*/
public interface SearchSignatureForRound extends ExplicitlyCastableSignature {

int DOUBLE_DECIMAL_RESULT_MAX_SCALE = 15;

@Override
default FunctionSignature searchSignature(List<FunctionSignature> signatures) {
List<Expression> arguments = getArguments();
if (arguments.get(0).getDataType().isFloatLikeType()) {
if (arguments.size() == 1) {
return FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE);
} else if (arguments.size() == 2) {
Comment thread
xy720 marked this conversation as resolved.
if (arguments.get(0).getDataType().isDoubleType()
&& isOptedIntoDecimalReroute()
&& isNonNegativeIntegerLiteralAtMost(arguments.get(1),
DOUBLE_DECIMAL_RESULT_MAX_SCALE)) {
Comment thread
xy720 marked this conversation as resolved.
return ExplicitlyCastableSignature.super.searchSignature(
withoutFloatLikeReturnTypes(signatures));
}
return FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, IntegerType.INSTANCE);
}
}
return ExplicitlyCastableSignature.super.searchSignature(signatures);
}

static boolean isOptedIntoDecimalReroute() {
ConnectContext ctx = ConnectContext.get();
return ctx != null && ctx.getSessionVariable().roundDoubleReturnsDecimalForConstScale;
}

/**
* True iff scale folds to an integer literal whose value lies in the closed range
* [0, maxValue].
*/
static boolean isNonNegativeIntegerLiteralAtMost(Expression scale, int maxValue) {
Expression folded = scale;
if (!folded.isLiteral() && !folded.isSlot()) {
ExpressionRewriteContext ctx = new ExpressionRewriteContext(CascadesContext.initTempContext());
folded = FoldConstantRuleOnFE.evaluate(folded, ctx);
}
Expression unwrapped = folded;
if (unwrapped instanceof Cast && unwrapped.child(0).isLiteral()
&& unwrapped.child(0).getDataType() instanceof Int32OrLessType) {
unwrapped = unwrapped.child(0);
}
if (!(unwrapped instanceof IntegerLikeLiteral)) {
return false;
}
Number number = ((IntegerLikeLiteral) unwrapped).getNumber();
BigInteger value = (number instanceof BigInteger)
? (BigInteger) number
: BigInteger.valueOf(number.longValue());
return value.signum() >= 0 && value.compareTo(BigInteger.valueOf(maxValue)) <= 0;
}
Comment thread
xy720 marked this conversation as resolved.

/** Drop signatures whose return type is a float-like type, so the search falls onto decimal. */
static List<FunctionSignature> withoutFloatLikeReturnTypes(List<FunctionSignature> signatures) {
List<FunctionSignature> result = new ArrayList<>(signatures.size());
for (FunctionSignature signature : signatures) {
if (!signature.returnType.isFloatLikeType()) {
result.add(signature);
}
}
return result;
}
}
20 changes: 20 additions & 0 deletions fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,9 @@ public String toString() {

public static final String DECIMAL_OVERFLOW_SCALE = "decimal_overflow_scale";

public static final String ROUND_DOUBLE_RETURNS_DECIMAL_FOR_CONST_SCALE
= "round_double_returns_decimal_for_const_scale";

public static final String TRIM_TAILING_SPACES_FOR_EXTERNAL_TABLE_QUERY
= "trim_tailing_spaces_for_external_table_query";

Expand Down Expand Up @@ -1915,6 +1918,23 @@ public void setMaxJoinNumberOfReorder(int maxJoinNumberOfReorder) {
)
public int decimalOverflowScale = 6;

@VarAttrDef.VarAttr(name = ROUND_DOUBLE_RETURNS_DECIMAL_FOR_CONST_SCALE,
needForward = true, affectQueryResultInPlan = true,
description = {
"当为 true 时,round/round_bankers/ceil/floor/truncate 在第一个参数为 DOUBLE 且第二个参数"
+ "为非负整数字面量(且不超过 15)时,返回类型从 DOUBLE 改为 DECIMAL,避免出现"
+ " round(23900/293, 2) 显示为 81.56999999999999 这类 IEEE-754 残尾。注意启用后,"
+ " |x| >= 1e15 的 DOUBLE 输入以及 Inf/NaN 在隐式 cast 至 decimal(30, 15) 时会变 NULL"
+ "(非严格模式)或抛 ARITHMETIC_OVERFLOW(严格模式),故默认关闭。",
"When true, round/round_bankers/ceil/floor/truncate return DECIMAL instead of DOUBLE"
+ " when the first argument is a DOUBLE and the second is a non-negative integer literal"
+ " no greater than 15. This avoids IEEE-754 residual tails such as round(23900/293, 2)"
+ " rendering as 81.56999999999999. Enabling it makes DOUBLE inputs with |x| >= 1e15,"
+ " Inf, or NaN turn into NULL (non-strict mode) or raise ARITHMETIC_OVERFLOW"
+ " (strict mode) due to the implicit cast to decimal(30, 15); off by default."}
)
public boolean roundDoubleReturnsDecimalForConstScale = false;

@VarAttrDef.VarAttr(name = ENABLE_DPHYP_OPTIMIZER)
public boolean enableDPHypOptimizer = false;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions;

import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Ceil;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Floor;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Round;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RoundBankers;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.qe.ConnectContext;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;

import java.math.BigInteger;

public class SearchSignatureForRoundTest {

private static final DoubleLiteral DOUBLE_VAL = new DoubleLiteral(81.56996587030717);

/** Run {@code body} with a fresh ConnectContext whose new opt-in var is set to {@code optIn}. */
private static void withOptIn(boolean optIn, Runnable body) {
try (MockedStatic<ConnectContext> mockedContext = Mockito.mockStatic(ConnectContext.class)) {
ConnectContext ctx = new ConnectContext();
ctx.getSessionVariable().roundDoubleReturnsDecimalForConstScale = optIn;
mockedContext.when(ConnectContext::get).thenReturn(ctx);
body.run();
}
}

private static void assertDecimalReturn(int expectedScale, Expression expr) {
Assertions.assertTrue(expr.getDataType() instanceof DecimalV3Type,
() -> "expected DecimalV3Type, got " + expr.getDataType());
DecimalV3Type t = (DecimalV3Type) expr.getDataType();
Assertions.assertEquals(expectedScale, t.getScale(),
() -> "expected scale=" + expectedScale + ", got " + t);
}

private static void assertDoubleReturn(Expression expr) {
Assertions.assertTrue(expr.getDataType() instanceof DoubleType,
() -> "expected DoubleType, got " + expr.getDataType());
}

// ---- opt-in ON: (DOUBLE, non-negative int literal <= 15) routes to DECIMAL ----

@Test
void roundDoubleWithConstScaleReturnsDecimal() {
withOptIn(true, () ->
assertDecimalReturn(2, new Round(DOUBLE_VAL, new IntegerLiteral(2))));
}

@Test
void roundBankersDoubleWithConstScaleReturnsDecimal() {
withOptIn(true, () ->
assertDecimalReturn(2, new RoundBankers(DOUBLE_VAL, new IntegerLiteral(2))));
}

@Test
void ceilDoubleWithConstScaleReturnsDecimal() {
withOptIn(true, () ->
assertDecimalReturn(2, new Ceil(DOUBLE_VAL, new IntegerLiteral(2))));
}

@Test
void floorDoubleWithConstScaleReturnsDecimal() {
withOptIn(true, () ->
assertDecimalReturn(2, new Floor(DOUBLE_VAL, new IntegerLiteral(2))));
}

@Test
void truncateDoubleWithConstScaleReturnsDecimal() {
withOptIn(true, () ->
assertDecimalReturn(2, new Truncate(DOUBLE_VAL, new IntegerLiteral(2))));
}

@Test
void zeroScaleAlsoReturnsDecimal() {
withOptIn(true, () ->
assertDecimalReturn(0, new Round(DOUBLE_VAL, new IntegerLiteral(0))));
}

@Test
void roundDoubleAtMaxPreservableScaleReturnsDecimal() {
// scale 15 == DOUBLE_DECIMAL.scale.
withOptIn(true, () ->
assertDecimalReturn(15, new Round(DOUBLE_VAL, new IntegerLiteral(15))));
}

@Test
void roundDoubleWithCastIntLiteralReturnsDecimal() {
withOptIn(true, () -> {
Cast wrapped = new Cast(new IntegerLiteral(3), IntegerType.INSTANCE);
assertDecimalReturn(3, new Round(DOUBLE_VAL, wrapped));
});
}

// ---- opt-in ON but shape doesn't match: stays DOUBLE ----

@Test
void roundDoubleSingleArgStaysDouble() {
withOptIn(true, () -> assertDoubleReturn(new Round(DOUBLE_VAL)));
}

@Test
void roundDoubleNegativeScaleStaysDouble() {
withOptIn(true, () ->
assertDoubleReturn(new Round(DOUBLE_VAL, new IntegerLiteral(-1))));
}

@Test
void roundDoubleScaleFromColumnStaysDouble() {
// When the scale comes from a column (non-literal), we cannot pick a
// fixed decimal scale at plan time, so we keep the original behavior.
withOptIn(true, () -> {
SlotReference scaleCol = new SlotReference("n", IntegerType.INSTANCE);
assertDoubleReturn(new Round(DOUBLE_VAL, scaleCol));
});
}

@Test
void roundFloatWithConstScaleStaysDouble() {
// FLOAT input keeps the original DOUBLE return path.
withOptIn(true, () -> {
SlotReference floatCol = new SlotReference("f", FloatType.INSTANCE);
assertDoubleReturn(new Round(floatCol, new IntegerLiteral(2)));
});
}

@Test
void roundDoubleScaleAboveMaxPreservableStaysDouble() {
// scale 17 exceeds DOUBLE_DECIMAL.scale (15).
withOptIn(true, () ->
assertDoubleReturn(new Round(DOUBLE_VAL, new IntegerLiteral(17))));
}

@Test
void roundDecimalKeepsExistingDecimalSignature() {
// The decimal-input path is independent of the new var.
withOptIn(true, () -> {
DecimalV3Type t = DecimalV3Type.createDecimalV3Type(10, 5);
SlotReference dec = new SlotReference("d", t);
assertDecimalReturn(2, new Round(dec, new IntegerLiteral(2)));
});
}

@Test
void roundDoubleBigIntScaleAboveIntRangeStaysDouble() {
withOptIn(true, () ->
assertDoubleReturn(new Round(DOUBLE_VAL, new BigIntLiteral(4294967298L))));
}

@Test
void roundDoubleBigIntScaleAtIntMaxPlusOneStaysDouble() {
withOptIn(true, () ->
assertDoubleReturn(new Round(DOUBLE_VAL, new BigIntLiteral(2147483648L))));
}

@Test
void roundDoubleLargeIntScaleAboveLongRangeStaysDouble() {
withOptIn(true, () -> {
BigInteger huge = new BigInteger("99999999999999999999"); // 20 digits
assertDoubleReturn(new Round(DOUBLE_VAL, new LargeIntLiteral(huge)));
});
}

@Test
void roundDoubleBigIntNegativeScaleStaysDouble() {
withOptIn(true, () ->
assertDoubleReturn(new Round(DOUBLE_VAL, new BigIntLiteral(-1L))));
}

// ---- opt-in OFF: DOUBLE shape that would otherwise reroute stays DOUBLE ----

@Test
void roundDoubleStaysDoubleWhenOptInIsOff() {
withOptIn(false, () ->
assertDoubleReturn(new Round(DOUBLE_VAL, new IntegerLiteral(2))));
}

@Test
void truncateDoubleStaysDoubleWhenOptInIsOff() {
withOptIn(false, () ->
assertDoubleReturn(new Truncate(DOUBLE_VAL, new IntegerLiteral(2))));
}

@Test
void roundDecimalIsUnaffectedByOptInOff() {
// Decimal input is independent of the new var.
withOptIn(false, () -> {
DecimalV3Type t = DecimalV3Type.createDecimalV3Type(10, 5);
SlotReference dec = new SlotReference("d", t);
assertDecimalReturn(2, new Round(dec, new IntegerLiteral(2)));
});
}
}
Loading
Loading