未验证 提交 1ba67ea8 编写于 作者: A Artem Zuikov 提交者: GitHub

Improve DecimalBinaryOperation specializations (#14743)

上级 51ba12c2
......@@ -65,7 +65,7 @@ namespace ErrorCodes
*/
template <typename A, typename B, typename Op, typename ResultType_ = typename Op::ResultType>
struct BinaryOperationImplBase
struct BinaryOperation
{
using ResultType = ResultType_;
static const constexpr bool allow_fixed_string = false;
......@@ -167,16 +167,24 @@ struct FixedStringOperationImpl
template <typename A, typename B, typename Op, typename ResultType = typename Op::ResultType>
struct BinaryOperationImpl : BinaryOperationImplBase<A, B, Op, ResultType>
struct BinaryOperationImpl : BinaryOperation<A, B, Op, ResultType>
{
};
template <typename T>
inline constexpr const auto & undec(const T & x)
{
if constexpr (IsDecimalNumber<T>)
return x.value;
else
return x;
}
/// Binary operations for Decimals need scale args
/// +|- scale one of args (which scale factor is not 1). ScaleR = oneof(Scale1, Scale2);
/// * no agrs scale. ScaleR = Scale1 + Scale2;
/// / first arg scale. ScaleR = Scale1 (scale_a = DecimalType<B>::getScale()).
template <typename A, typename B, template <typename, typename> typename Operation, typename ResultType_, bool _check_overflow = true>
template <template <typename, typename> typename Operation, typename ResultType_, bool check_overflow = true>
struct DecimalBinaryOperation
{
static constexpr bool is_plus_minus = IsOperation<Operation>::plus ||
......@@ -196,48 +204,10 @@ struct DecimalBinaryOperation
using Op = std::conditional_t<is_float_division,
DivideIntegralImpl<NativeResultType, NativeResultType>, /// substitute divide by intDiv (throw on division by zero)
Operation<NativeResultType, NativeResultType>>;
using ColVecA = std::conditional_t<IsDecimalNumber<A>, ColumnDecimal<A>, ColumnVector<A>>;
using ColVecB = std::conditional_t<IsDecimalNumber<B>, ColumnDecimal<B>, ColumnVector<B>>;
using ArrayA = typename ColVecA::Container;
using ArrayB = typename ColVecB::Container;
using ArrayC = typename ColumnDecimal<ResultType>::Container;
using SelfNoOverflow = DecimalBinaryOperation<A, B, Operation, ResultType_, false>;
static void vectorVector(const ArrayA & a, const ArrayB & b, ArrayC & c,
NativeResultType scale_a, NativeResultType scale_b, bool check_overflow)
{
if (check_overflow)
vectorVector(a, b, c, scale_a, scale_b);
else
SelfNoOverflow::vectorVector(a, b, c, scale_a, scale_b);
}
static void vectorConstant(const ArrayA & a, B b, ArrayC & c,
NativeResultType scale_a, NativeResultType scale_b, bool check_overflow)
{
if (check_overflow)
vectorConstant(a, b, c, scale_a, scale_b);
else
SelfNoOverflow::vectorConstant(a, b, c, scale_a, scale_b);
}
static void constantVector(A a, const ArrayB & b, ArrayC & c,
NativeResultType scale_a, NativeResultType scale_b, bool check_overflow)
{
if (check_overflow)
constantVector(a, b, c, scale_a, scale_b);
else
SelfNoOverflow::constantVector(a, b, c, scale_a, scale_b);
}
static ResultType constantConstant(A a, B b, NativeResultType scale_a, NativeResultType scale_b, bool check_overflow)
{
if (check_overflow)
return constantConstant(a, b, scale_a, scale_b);
else
return SelfNoOverflow::constantConstant(a, b, scale_a, scale_b);
}
using ArrayC = typename ColumnDecimal<ResultType>::Container;
template <bool is_decimal_a, bool is_decimal_b, typename ArrayA, typename ArrayB>
static void NO_INLINE vectorVector(const ArrayA & a, const ArrayB & b, ArrayC & c,
NativeResultType scale_a [[maybe_unused]], NativeResultType scale_b [[maybe_unused]])
{
......@@ -247,92 +217,102 @@ struct DecimalBinaryOperation
if (scale_a != 1)
{
for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<true>(a[i], b[i], scale_a);
c[i] = applyScaled<true>(undec(a[i]), undec(b[i]), scale_a);
return;
}
else if (scale_b != 1)
{
for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<false>(a[i], b[i], scale_b);
c[i] = applyScaled<false>(undec(a[i]), undec(b[i]), scale_b);
return;
}
}
else if constexpr (is_division && IsDecimalNumber<B>)
else if constexpr (is_division && is_decimal_b)
{
for (size_t i = 0; i < size; ++i)
c[i] = applyScaledDiv(a[i], b[i], scale_a);
c[i] = applyScaledDiv<is_decimal_a>(undec(a[i]), undec(b[i]), scale_a);
return;
}
/// default: use it if no return before
for (size_t i = 0; i < size; ++i)
c[i] = apply(a[i], b[i]);
c[i] = apply(undec(a[i]), undec(b[i]));
}
template <bool is_decimal_a, bool is_decimal_b, typename ArrayA, typename B>
static void NO_INLINE vectorConstant(const ArrayA & a, B b, ArrayC & c,
NativeResultType scale_a [[maybe_unused]], NativeResultType scale_b [[maybe_unused]])
{
static_assert(!IsDecimalNumber<B>);
size_t size = a.size();
if constexpr (is_plus_minus_compare)
{
if (scale_a != 1)
{
for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<true>(a[i], b, scale_a);
c[i] = applyScaled<true>(undec(a[i]), b, scale_a);
return;
}
else if (scale_b != 1)
{
for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<false>(a[i], b, scale_b);
c[i] = applyScaled<false>(undec(a[i]), b, scale_b);
return;
}
}
else if constexpr (is_division && IsDecimalNumber<B>)
else if constexpr (is_division && is_decimal_b)
{
for (size_t i = 0; i < size; ++i)
c[i] = applyScaledDiv(a[i], b, scale_a);
c[i] = applyScaledDiv<is_decimal_a>(undec(a[i]), b, scale_a);
return;
}
/// default: use it if no return before
for (size_t i = 0; i < size; ++i)
c[i] = apply(a[i], b);
c[i] = apply(undec(a[i]), b);
}
template <bool is_decimal_a, bool is_decimal_b, typename A, typename ArrayB>
static void NO_INLINE constantVector(A a, const ArrayB & b, ArrayC & c,
NativeResultType scale_a [[maybe_unused]], NativeResultType scale_b [[maybe_unused]])
{
static_assert(!IsDecimalNumber<A>);
size_t size = b.size();
if constexpr (is_plus_minus_compare)
{
if (scale_a != 1)
{
for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<true>(a, b[i], scale_a);
c[i] = applyScaled<true>(a, undec(b[i]), scale_a);
return;
}
else if (scale_b != 1)
{
for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<false>(a, b[i], scale_b);
c[i] = applyScaled<false>(a, undec(b[i]), scale_b);
return;
}
}
else if constexpr (is_division && IsDecimalNumber<B>)
else if constexpr (is_division && is_decimal_b)
{
for (size_t i = 0; i < size; ++i)
c[i] = applyScaledDiv(a, b[i], scale_a);
c[i] = applyScaledDiv<is_decimal_a>(a, undec(b[i]), scale_a);
return;
}
/// default: use it if no return before
for (size_t i = 0; i < size; ++i)
c[i] = apply(a, b[i]);
c[i] = apply(a, undec(b[i]));
}
template <bool is_decimal_a, bool is_decimal_b, typename A, typename B>
static ResultType constantConstant(A a, B b, NativeResultType scale_a [[maybe_unused]], NativeResultType scale_b [[maybe_unused]])
{
static_assert(!IsDecimalNumber<A>);
static_assert(!IsDecimalNumber<B>);
if constexpr (is_plus_minus_compare)
{
if (scale_a != 1)
......@@ -340,64 +320,16 @@ struct DecimalBinaryOperation
else if (scale_b != 1)
return applyScaled<false>(a, b, scale_b);
}
else if constexpr (is_division && IsDecimalNumber<B>)
return applyScaledDiv(a, b, scale_a);
else if constexpr (is_division && is_decimal_b)
return applyScaledDiv<is_decimal_a>(a, b, scale_a);
return apply(a, b);
}
private:
template <typename T, typename U>
static NativeResultType apply(const T & a, const U & b)
{
if constexpr (OverBigInt<T> || OverBigInt<U>)
{
if constexpr (IsDecimalNumber<T>)
return apply(a.value, b);
else if constexpr (IsDecimalNumber<U>)
return apply(a, b.value);
else
return applyNative(bigint_cast<NativeResultType>(a), bigint_cast<NativeResultType>(b));
}
else
return applyNative(a, b);
}
template <bool scale_left, typename T, typename U>
static NativeResultType applyScaled(const T & a, const U & b, NativeResultType scale)
{
if constexpr (OverBigInt<T> || OverBigInt<U>)
{
if constexpr (IsDecimalNumber<T>)
return applyScaled<scale_left>(a.value, b, scale);
else if constexpr (IsDecimalNumber<U>)
return applyScaled<scale_left>(a, b.value, scale);
else
return applyNativeScaled<scale_left>(bigint_cast<NativeResultType>(a), bigint_cast<NativeResultType>(b), scale);
}
else
return applyNativeScaled<scale_left>(a, b, scale);
}
template <typename T, typename U>
static NativeResultType applyScaledDiv(const T & a, const U & b, NativeResultType scale)
{
if constexpr (OverBigInt<T> || OverBigInt<U>)
{
if constexpr (IsDecimalNumber<T>)
return applyScaledDiv(a.value, b, scale);
else if constexpr (IsDecimalNumber<U>)
return applyScaledDiv(a, b.value, scale);
else
return applyNativeScaledDiv(bigint_cast<NativeResultType>(a), bigint_cast<NativeResultType>(b), scale);
}
else
return applyNativeScaledDiv(a, b, scale);
}
/// there's implicit type convertion here
static NativeResultType applyNative(NativeResultType a, NativeResultType b)
static NativeResultType apply(NativeResultType a, NativeResultType b)
{
if constexpr (can_overflow && _check_overflow)
if constexpr (can_overflow && check_overflow)
{
NativeResultType res;
if (Op::template apply<NativeResultType>(a, b, res))
......@@ -409,13 +341,13 @@ private:
}
template <bool scale_left>
static NO_SANITIZE_UNDEFINED NativeResultType applyNativeScaled(NativeResultType a, NativeResultType b, NativeResultType scale)
static NO_SANITIZE_UNDEFINED NativeResultType applyScaled(NativeResultType a, NativeResultType b, NativeResultType scale)
{
if constexpr (is_plus_minus_compare)
{
NativeResultType res;
if constexpr (_check_overflow)
if constexpr (check_overflow)
{
bool overflow = false;
if constexpr (scale_left)
......@@ -444,14 +376,15 @@ private:
}
}
static NO_SANITIZE_UNDEFINED NativeResultType applyNativeScaledDiv(NativeResultType a, NativeResultType b, NativeResultType scale)
template <bool is_decimal_a>
static NO_SANITIZE_UNDEFINED NativeResultType applyScaledDiv(NativeResultType a, NativeResultType b, NativeResultType scale)
{
if constexpr (is_division)
{
if constexpr (_check_overflow)
if constexpr (check_overflow)
{
bool overflow = false;
if constexpr (!IsDecimalNumber<A>)
if constexpr (!is_decimal_a)
overflow |= common::mulOverflow(scale, scale, scale);
overflow |= common::mulOverflow(a, scale, a);
if (overflow)
......@@ -459,7 +392,7 @@ private:
}
else
{
if constexpr (!IsDecimalNumber<A>)
if constexpr (!is_decimal_a)
scale *= scale;
a *= scale;
}
......@@ -1024,10 +957,15 @@ public:
if constexpr (IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>)
{
using OpImpl = DecimalBinaryOperation<T0, T1, Op, ResultType>;
using NativeResultType = typename NativeType<ResultType>::Type;
using OpImpl = DecimalBinaryOperation<Op, ResultType, false>;
using OpImplCheck = DecimalBinaryOperation<Op, ResultType, true>;
ResultDataType type = decimalResultType<is_multiply, is_division>(left, right);
static constexpr const bool dec_a = IsDecimalNumber<T0>;
static constexpr const bool dec_b = IsDecimalNumber<T1>;
typename ResultDataType::FieldType scale_a = type.scaleFactorFor(left, is_multiply);
typename ResultDataType::FieldType scale_b = type.scaleFactorFor(right, is_multiply || is_division);
if constexpr (IsDataTypeDecimal<RightDataType> && is_division)
......@@ -1036,8 +974,12 @@ public:
/// non-vector result
if (col_left_const && col_right_const)
{
auto res = OpImpl::constantConstant(col_left_const->template getValue<T0>(), col_right_const->template getValue<T1>(),
scale_a, scale_b, check_decimal_overflow);
NativeResultType const_a = col_left_const->template getValue<T0>();
NativeResultType const_b = col_right_const->template getValue<T1>();
auto res = check_decimal_overflow ?
OpImplCheck::template constantConstant<dec_a, dec_b>(const_a, const_b, scale_a, scale_b) :
OpImpl::template constantConstant<dec_a, dec_b>(const_a, const_b, scale_a, scale_b);
block.getByPosition(result).column = ResultDataType(type.getPrecision(), type.getScale()).createColumnConst(
col_left_const->size(), toField(res, type.getScale()));
......@@ -1050,17 +992,28 @@ public:
if (col_left && col_right)
{
OpImpl::vectorVector(col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b, check_decimal_overflow);
if (check_decimal_overflow)
OpImplCheck::template vectorVector<dec_a, dec_b>(col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b);
else
OpImpl::template vectorVector<dec_a, dec_b>(col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b);
}
else if (col_left_const && col_right)
{
OpImpl::constantVector(col_left_const->template getValue<T0>(), col_right->getData(), vec_res,
scale_a, scale_b, check_decimal_overflow);
NativeResultType const_a = col_left_const->template getValue<T0>();
if (check_decimal_overflow)
OpImplCheck::template constantVector<dec_a, dec_b>(const_a, col_right->getData(), vec_res, scale_a, scale_b);
else
OpImpl::template constantVector<dec_a, dec_b>(const_a, col_right->getData(), vec_res, scale_a, scale_b);
}
else if (col_left && col_right_const)
{
OpImpl::vectorConstant(col_left->getData(), col_right_const->template getValue<T1>(), vec_res,
scale_a, scale_b, check_decimal_overflow);
NativeResultType const_b = col_right_const->template getValue<T1>();
if (check_decimal_overflow)
OpImplCheck::template vectorConstant<dec_a, dec_b>(col_left->getData(), const_b, vec_res, scale_a, scale_b);
else
OpImpl::template vectorConstant<dec_a, dec_b>(col_left->getData(), const_b, vec_res, scale_a, scale_b);
}
else
return false;
......
......@@ -22,7 +22,7 @@ namespace
template <typename A, typename B>
struct DivideIntegralByConstantImpl
: BinaryOperationImplBase<A, B, DivideIntegralImpl<A, B>>
: BinaryOperation<A, B, DivideIntegralImpl<A, B>>
{
using ResultType = typename DivideIntegralImpl<A, B>::ResultType;
static const constexpr bool allow_fixed_string = false;
......
......@@ -22,7 +22,7 @@ namespace
template <typename A, typename B>
struct ModuloByConstantImpl
: BinaryOperationImplBase<A, B, ModuloImpl<A, B>>
: BinaryOperation<A, B, ModuloImpl<A, B>>
{
using ResultType = typename ModuloImpl<A, B>::ResultType;
static const constexpr bool allow_fixed_string = false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册