未验证 提交 1ba67ea8 编写于 作者: A Artem Zuikov 提交者: GitHub

Improve DecimalBinaryOperation specializations (#14743)

上级 51ba12c2
...@@ -65,7 +65,7 @@ namespace ErrorCodes ...@@ -65,7 +65,7 @@ namespace ErrorCodes
*/ */
template <typename A, typename B, typename Op, typename ResultType_ = typename Op::ResultType> template <typename A, typename B, typename Op, typename ResultType_ = typename Op::ResultType>
struct BinaryOperationImplBase struct BinaryOperation
{ {
using ResultType = ResultType_; using ResultType = ResultType_;
static const constexpr bool allow_fixed_string = false; static const constexpr bool allow_fixed_string = false;
...@@ -167,16 +167,24 @@ struct FixedStringOperationImpl ...@@ -167,16 +167,24 @@ struct FixedStringOperationImpl
template <typename A, typename B, typename Op, typename ResultType = typename Op::ResultType> template <typename A, typename B, typename Op, typename ResultType = typename Op::ResultType>
struct BinaryOperationImpl : BinaryOperationImplBase<A, B, Op, ResultType> struct BinaryOperationImpl : BinaryOperation<A, B, Op, ResultType>
{ {
}; };
template <typename T>
inline constexpr const auto & undec(const T & x)
{
if constexpr (IsDecimalNumber<T>)
return x.value;
else
return x;
}
/// Binary operations for Decimals need scale args /// Binary operations for Decimals need scale args
/// +|- scale one of args (which scale factor is not 1). ScaleR = oneof(Scale1, Scale2); /// +|- scale one of args (which scale factor is not 1). ScaleR = oneof(Scale1, Scale2);
/// * no agrs scale. ScaleR = Scale1 + Scale2; /// * no agrs scale. ScaleR = Scale1 + Scale2;
/// / first arg scale. ScaleR = Scale1 (scale_a = DecimalType<B>::getScale()). /// / first arg scale. ScaleR = Scale1 (scale_a = DecimalType<B>::getScale()).
template <typename A, typename B, template <typename, typename> typename Operation, typename ResultType_, bool _check_overflow = true> template <template <typename, typename> typename Operation, typename ResultType_, bool check_overflow = true>
struct DecimalBinaryOperation struct DecimalBinaryOperation
{ {
static constexpr bool is_plus_minus = IsOperation<Operation>::plus || static constexpr bool is_plus_minus = IsOperation<Operation>::plus ||
...@@ -196,48 +204,10 @@ struct DecimalBinaryOperation ...@@ -196,48 +204,10 @@ struct DecimalBinaryOperation
using Op = std::conditional_t<is_float_division, using Op = std::conditional_t<is_float_division,
DivideIntegralImpl<NativeResultType, NativeResultType>, /// substitute divide by intDiv (throw on division by zero) DivideIntegralImpl<NativeResultType, NativeResultType>, /// substitute divide by intDiv (throw on division by zero)
Operation<NativeResultType, NativeResultType>>; Operation<NativeResultType, NativeResultType>>;
using ColVecA = std::conditional_t<IsDecimalNumber<A>, ColumnDecimal<A>, ColumnVector<A>>;
using ColVecB = std::conditional_t<IsDecimalNumber<B>, ColumnDecimal<B>, ColumnVector<B>>;
using ArrayA = typename ColVecA::Container;
using ArrayB = typename ColVecB::Container;
using ArrayC = typename ColumnDecimal<ResultType>::Container;
using SelfNoOverflow = DecimalBinaryOperation<A, B, Operation, ResultType_, false>;
static void vectorVector(const ArrayA & a, const ArrayB & b, ArrayC & c,
NativeResultType scale_a, NativeResultType scale_b, bool check_overflow)
{
if (check_overflow)
vectorVector(a, b, c, scale_a, scale_b);
else
SelfNoOverflow::vectorVector(a, b, c, scale_a, scale_b);
}
static void vectorConstant(const ArrayA & a, B b, ArrayC & c, using ArrayC = typename ColumnDecimal<ResultType>::Container;
NativeResultType scale_a, NativeResultType scale_b, bool check_overflow)
{
if (check_overflow)
vectorConstant(a, b, c, scale_a, scale_b);
else
SelfNoOverflow::vectorConstant(a, b, c, scale_a, scale_b);
}
static void constantVector(A a, const ArrayB & b, ArrayC & c,
NativeResultType scale_a, NativeResultType scale_b, bool check_overflow)
{
if (check_overflow)
constantVector(a, b, c, scale_a, scale_b);
else
SelfNoOverflow::constantVector(a, b, c, scale_a, scale_b);
}
static ResultType constantConstant(A a, B b, NativeResultType scale_a, NativeResultType scale_b, bool check_overflow)
{
if (check_overflow)
return constantConstant(a, b, scale_a, scale_b);
else
return SelfNoOverflow::constantConstant(a, b, scale_a, scale_b);
}
template <bool is_decimal_a, bool is_decimal_b, typename ArrayA, typename ArrayB>
static void NO_INLINE vectorVector(const ArrayA & a, const ArrayB & b, ArrayC & c, static void NO_INLINE vectorVector(const ArrayA & a, const ArrayB & b, ArrayC & c,
NativeResultType scale_a [[maybe_unused]], NativeResultType scale_b [[maybe_unused]]) NativeResultType scale_a [[maybe_unused]], NativeResultType scale_b [[maybe_unused]])
{ {
...@@ -247,92 +217,102 @@ struct DecimalBinaryOperation ...@@ -247,92 +217,102 @@ struct DecimalBinaryOperation
if (scale_a != 1) if (scale_a != 1)
{ {
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<true>(a[i], b[i], scale_a); c[i] = applyScaled<true>(undec(a[i]), undec(b[i]), scale_a);
return; return;
} }
else if (scale_b != 1) else if (scale_b != 1)
{ {
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<false>(a[i], b[i], scale_b); c[i] = applyScaled<false>(undec(a[i]), undec(b[i]), scale_b);
return; return;
} }
} }
else if constexpr (is_division && IsDecimalNumber<B>) else if constexpr (is_division && is_decimal_b)
{ {
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
c[i] = applyScaledDiv(a[i], b[i], scale_a); c[i] = applyScaledDiv<is_decimal_a>(undec(a[i]), undec(b[i]), scale_a);
return; return;
} }
/// default: use it if no return before /// default: use it if no return before
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
c[i] = apply(a[i], b[i]); c[i] = apply(undec(a[i]), undec(b[i]));
} }
template <bool is_decimal_a, bool is_decimal_b, typename ArrayA, typename B>
static void NO_INLINE vectorConstant(const ArrayA & a, B b, ArrayC & c, static void NO_INLINE vectorConstant(const ArrayA & a, B b, ArrayC & c,
NativeResultType scale_a [[maybe_unused]], NativeResultType scale_b [[maybe_unused]]) NativeResultType scale_a [[maybe_unused]], NativeResultType scale_b [[maybe_unused]])
{ {
static_assert(!IsDecimalNumber<B>);
size_t size = a.size(); size_t size = a.size();
if constexpr (is_plus_minus_compare) if constexpr (is_plus_minus_compare)
{ {
if (scale_a != 1) if (scale_a != 1)
{ {
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<true>(a[i], b, scale_a); c[i] = applyScaled<true>(undec(a[i]), b, scale_a);
return; return;
} }
else if (scale_b != 1) else if (scale_b != 1)
{ {
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<false>(a[i], b, scale_b); c[i] = applyScaled<false>(undec(a[i]), b, scale_b);
return; return;
} }
} }
else if constexpr (is_division && IsDecimalNumber<B>) else if constexpr (is_division && is_decimal_b)
{ {
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
c[i] = applyScaledDiv(a[i], b, scale_a); c[i] = applyScaledDiv<is_decimal_a>(undec(a[i]), b, scale_a);
return; return;
} }
/// default: use it if no return before /// default: use it if no return before
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
c[i] = apply(a[i], b); c[i] = apply(undec(a[i]), b);
} }
template <bool is_decimal_a, bool is_decimal_b, typename A, typename ArrayB>
static void NO_INLINE constantVector(A a, const ArrayB & b, ArrayC & c, static void NO_INLINE constantVector(A a, const ArrayB & b, ArrayC & c,
NativeResultType scale_a [[maybe_unused]], NativeResultType scale_b [[maybe_unused]]) NativeResultType scale_a [[maybe_unused]], NativeResultType scale_b [[maybe_unused]])
{ {
static_assert(!IsDecimalNumber<A>);
size_t size = b.size(); size_t size = b.size();
if constexpr (is_plus_minus_compare) if constexpr (is_plus_minus_compare)
{ {
if (scale_a != 1) if (scale_a != 1)
{ {
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<true>(a, b[i], scale_a); c[i] = applyScaled<true>(a, undec(b[i]), scale_a);
return; return;
} }
else if (scale_b != 1) else if (scale_b != 1)
{ {
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<false>(a, b[i], scale_b); c[i] = applyScaled<false>(a, undec(b[i]), scale_b);
return; return;
} }
} }
else if constexpr (is_division && IsDecimalNumber<B>) else if constexpr (is_division && is_decimal_b)
{ {
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
c[i] = applyScaledDiv(a, b[i], scale_a); c[i] = applyScaledDiv<is_decimal_a>(a, undec(b[i]), scale_a);
return; return;
} }
/// default: use it if no return before /// default: use it if no return before
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
c[i] = apply(a, b[i]); c[i] = apply(a, undec(b[i]));
} }
template <bool is_decimal_a, bool is_decimal_b, typename A, typename B>
static ResultType constantConstant(A a, B b, NativeResultType scale_a [[maybe_unused]], NativeResultType scale_b [[maybe_unused]]) static ResultType constantConstant(A a, B b, NativeResultType scale_a [[maybe_unused]], NativeResultType scale_b [[maybe_unused]])
{ {
static_assert(!IsDecimalNumber<A>);
static_assert(!IsDecimalNumber<B>);
if constexpr (is_plus_minus_compare) if constexpr (is_plus_minus_compare)
{ {
if (scale_a != 1) if (scale_a != 1)
...@@ -340,64 +320,16 @@ struct DecimalBinaryOperation ...@@ -340,64 +320,16 @@ struct DecimalBinaryOperation
else if (scale_b != 1) else if (scale_b != 1)
return applyScaled<false>(a, b, scale_b); return applyScaled<false>(a, b, scale_b);
} }
else if constexpr (is_division && IsDecimalNumber<B>) else if constexpr (is_division && is_decimal_b)
return applyScaledDiv(a, b, scale_a); return applyScaledDiv<is_decimal_a>(a, b, scale_a);
return apply(a, b); return apply(a, b);
} }
private: private:
template <typename T, typename U>
static NativeResultType apply(const T & a, const U & b)
{
if constexpr (OverBigInt<T> || OverBigInt<U>)
{
if constexpr (IsDecimalNumber<T>)
return apply(a.value, b);
else if constexpr (IsDecimalNumber<U>)
return apply(a, b.value);
else
return applyNative(bigint_cast<NativeResultType>(a), bigint_cast<NativeResultType>(b));
}
else
return applyNative(a, b);
}
template <bool scale_left, typename T, typename U>
static NativeResultType applyScaled(const T & a, const U & b, NativeResultType scale)
{
if constexpr (OverBigInt<T> || OverBigInt<U>)
{
if constexpr (IsDecimalNumber<T>)
return applyScaled<scale_left>(a.value, b, scale);
else if constexpr (IsDecimalNumber<U>)
return applyScaled<scale_left>(a, b.value, scale);
else
return applyNativeScaled<scale_left>(bigint_cast<NativeResultType>(a), bigint_cast<NativeResultType>(b), scale);
}
else
return applyNativeScaled<scale_left>(a, b, scale);
}
template <typename T, typename U>
static NativeResultType applyScaledDiv(const T & a, const U & b, NativeResultType scale)
{
if constexpr (OverBigInt<T> || OverBigInt<U>)
{
if constexpr (IsDecimalNumber<T>)
return applyScaledDiv(a.value, b, scale);
else if constexpr (IsDecimalNumber<U>)
return applyScaledDiv(a, b.value, scale);
else
return applyNativeScaledDiv(bigint_cast<NativeResultType>(a), bigint_cast<NativeResultType>(b), scale);
}
else
return applyNativeScaledDiv(a, b, scale);
}
/// there's implicit type convertion here /// there's implicit type convertion here
static NativeResultType applyNative(NativeResultType a, NativeResultType b) static NativeResultType apply(NativeResultType a, NativeResultType b)
{ {
if constexpr (can_overflow && _check_overflow) if constexpr (can_overflow && check_overflow)
{ {
NativeResultType res; NativeResultType res;
if (Op::template apply<NativeResultType>(a, b, res)) if (Op::template apply<NativeResultType>(a, b, res))
...@@ -409,13 +341,13 @@ private: ...@@ -409,13 +341,13 @@ private:
} }
template <bool scale_left> template <bool scale_left>
static NO_SANITIZE_UNDEFINED NativeResultType applyNativeScaled(NativeResultType a, NativeResultType b, NativeResultType scale) static NO_SANITIZE_UNDEFINED NativeResultType applyScaled(NativeResultType a, NativeResultType b, NativeResultType scale)
{ {
if constexpr (is_plus_minus_compare) if constexpr (is_plus_minus_compare)
{ {
NativeResultType res; NativeResultType res;
if constexpr (_check_overflow) if constexpr (check_overflow)
{ {
bool overflow = false; bool overflow = false;
if constexpr (scale_left) if constexpr (scale_left)
...@@ -444,14 +376,15 @@ private: ...@@ -444,14 +376,15 @@ private:
} }
} }
static NO_SANITIZE_UNDEFINED NativeResultType applyNativeScaledDiv(NativeResultType a, NativeResultType b, NativeResultType scale) template <bool is_decimal_a>
static NO_SANITIZE_UNDEFINED NativeResultType applyScaledDiv(NativeResultType a, NativeResultType b, NativeResultType scale)
{ {
if constexpr (is_division) if constexpr (is_division)
{ {
if constexpr (_check_overflow) if constexpr (check_overflow)
{ {
bool overflow = false; bool overflow = false;
if constexpr (!IsDecimalNumber<A>) if constexpr (!is_decimal_a)
overflow |= common::mulOverflow(scale, scale, scale); overflow |= common::mulOverflow(scale, scale, scale);
overflow |= common::mulOverflow(a, scale, a); overflow |= common::mulOverflow(a, scale, a);
if (overflow) if (overflow)
...@@ -459,7 +392,7 @@ private: ...@@ -459,7 +392,7 @@ private:
} }
else else
{ {
if constexpr (!IsDecimalNumber<A>) if constexpr (!is_decimal_a)
scale *= scale; scale *= scale;
a *= scale; a *= scale;
} }
...@@ -1024,10 +957,15 @@ public: ...@@ -1024,10 +957,15 @@ public:
if constexpr (IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>) if constexpr (IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>)
{ {
using OpImpl = DecimalBinaryOperation<T0, T1, Op, ResultType>; using NativeResultType = typename NativeType<ResultType>::Type;
using OpImpl = DecimalBinaryOperation<Op, ResultType, false>;
using OpImplCheck = DecimalBinaryOperation<Op, ResultType, true>;
ResultDataType type = decimalResultType<is_multiply, is_division>(left, right); ResultDataType type = decimalResultType<is_multiply, is_division>(left, right);
static constexpr const bool dec_a = IsDecimalNumber<T0>;
static constexpr const bool dec_b = IsDecimalNumber<T1>;
typename ResultDataType::FieldType scale_a = type.scaleFactorFor(left, is_multiply); typename ResultDataType::FieldType scale_a = type.scaleFactorFor(left, is_multiply);
typename ResultDataType::FieldType scale_b = type.scaleFactorFor(right, is_multiply || is_division); typename ResultDataType::FieldType scale_b = type.scaleFactorFor(right, is_multiply || is_division);
if constexpr (IsDataTypeDecimal<RightDataType> && is_division) if constexpr (IsDataTypeDecimal<RightDataType> && is_division)
...@@ -1036,8 +974,12 @@ public: ...@@ -1036,8 +974,12 @@ public:
/// non-vector result /// non-vector result
if (col_left_const && col_right_const) if (col_left_const && col_right_const)
{ {
auto res = OpImpl::constantConstant(col_left_const->template getValue<T0>(), col_right_const->template getValue<T1>(), NativeResultType const_a = col_left_const->template getValue<T0>();
scale_a, scale_b, check_decimal_overflow); NativeResultType const_b = col_right_const->template getValue<T1>();
auto res = check_decimal_overflow ?
OpImplCheck::template constantConstant<dec_a, dec_b>(const_a, const_b, scale_a, scale_b) :
OpImpl::template constantConstant<dec_a, dec_b>(const_a, const_b, scale_a, scale_b);
block.getByPosition(result).column = ResultDataType(type.getPrecision(), type.getScale()).createColumnConst( block.getByPosition(result).column = ResultDataType(type.getPrecision(), type.getScale()).createColumnConst(
col_left_const->size(), toField(res, type.getScale())); col_left_const->size(), toField(res, type.getScale()));
...@@ -1050,17 +992,28 @@ public: ...@@ -1050,17 +992,28 @@ public:
if (col_left && col_right) if (col_left && col_right)
{ {
OpImpl::vectorVector(col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b, check_decimal_overflow); if (check_decimal_overflow)
OpImplCheck::template vectorVector<dec_a, dec_b>(col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b);
else
OpImpl::template vectorVector<dec_a, dec_b>(col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b);
} }
else if (col_left_const && col_right) else if (col_left_const && col_right)
{ {
OpImpl::constantVector(col_left_const->template getValue<T0>(), col_right->getData(), vec_res, NativeResultType const_a = col_left_const->template getValue<T0>();
scale_a, scale_b, check_decimal_overflow);
if (check_decimal_overflow)
OpImplCheck::template constantVector<dec_a, dec_b>(const_a, col_right->getData(), vec_res, scale_a, scale_b);
else
OpImpl::template constantVector<dec_a, dec_b>(const_a, col_right->getData(), vec_res, scale_a, scale_b);
} }
else if (col_left && col_right_const) else if (col_left && col_right_const)
{ {
OpImpl::vectorConstant(col_left->getData(), col_right_const->template getValue<T1>(), vec_res, NativeResultType const_b = col_right_const->template getValue<T1>();
scale_a, scale_b, check_decimal_overflow);
if (check_decimal_overflow)
OpImplCheck::template vectorConstant<dec_a, dec_b>(col_left->getData(), const_b, vec_res, scale_a, scale_b);
else
OpImpl::template vectorConstant<dec_a, dec_b>(col_left->getData(), const_b, vec_res, scale_a, scale_b);
} }
else else
return false; return false;
......
...@@ -22,7 +22,7 @@ namespace ...@@ -22,7 +22,7 @@ namespace
template <typename A, typename B> template <typename A, typename B>
struct DivideIntegralByConstantImpl struct DivideIntegralByConstantImpl
: BinaryOperationImplBase<A, B, DivideIntegralImpl<A, B>> : BinaryOperation<A, B, DivideIntegralImpl<A, B>>
{ {
using ResultType = typename DivideIntegralImpl<A, B>::ResultType; using ResultType = typename DivideIntegralImpl<A, B>::ResultType;
static const constexpr bool allow_fixed_string = false; static const constexpr bool allow_fixed_string = false;
......
...@@ -22,7 +22,7 @@ namespace ...@@ -22,7 +22,7 @@ namespace
template <typename A, typename B> template <typename A, typename B>
struct ModuloByConstantImpl struct ModuloByConstantImpl
: BinaryOperationImplBase<A, B, ModuloImpl<A, B>> : BinaryOperation<A, B, ModuloImpl<A, B>>
{ {
using ResultType = typename ModuloImpl<A, B>::ResultType; using ResultType = typename ModuloImpl<A, B>::ResultType;
static const constexpr bool allow_fixed_string = false; static const constexpr bool allow_fixed_string = false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册