/**
 * \file src/core/impl/dtype.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megbrain/dtype.h"
#include "megbrain/common.h"
#include "megbrain/exception.h"
#include "megbrain/utils/arith_helper.h"
#include "megbrain/utils/metahelper.h"

#include <cmath>
#include <cstring>

using namespace mgb;

namespace {

//! Dispatch tag: when the bool parameter is true, check() verifies that an
//! int<->float conversion is lossless; when false, check() is a no-op.
template <bool check>
struct SafeCastFloatCheck;

template <>
struct SafeCastFloatCheck<false> {
    template <typename U>
    static void check(U val) {
        MGB_MARK_USED_VAR(val);
    }
};

template <>
struct SafeCastFloatCheck<true> {
    // 16777216 == 2^24: the largest magnitude whose integer neighborhood is
    // exactly representable in a float32 mantissa; beyond it (or for a
    // non-integral float) the conversion would silently lose precision
    static void check(float val) {
        mgb_throw_if(
                fabs(val) > 16777216 || ceilf(val) != val, ConversionError,
                "can not convert float value %g to int "
                "without precision loss",
                val);
    }

    static void check(int val) {
        mgb_throw_if(
                abs(val) > 16777216, ConversionError,
                "can not convert int value %d to float "
                "without precision loss",
                val);
    }
};

//! static_cast that throws ConversionError when exactly one of T/U is
//! integral and the value cannot round-trip; bool destinations are exempt
template <typename T, typename U>
T static_cast_safe(U from) {
    constexpr bool integral_diff =
            (std::is_integral<T>::value ^ std::is_integral<U>::value) &&
            !(std::is_same<T, bool>::value);
    SafeCastFloatCheck<integral_diff>::check(from);
    return static_cast<T>(from);
}

//! alias enabled only for ctypes whose DType category is QUANTIZED; used to
//! select the dequantizing overloads below
template <typename T>
using QuantizedCType =
        std::enable_if_t<DTypeTrait<T>::category == DTypeCategory::QUANTIZED, T>;

//! elementwise static_cast from src (plain computing ctype) into dest
template <typename T, typename U>
void batched_static_cast(T* dest, const U* src, size_t nr, DType src_dtype) {
    for (size_t i = 0; i < nr; ++i)
        dest[i] = static_cast<T>(src[i]);
}

//! quantized overload: dequantize through the dtype's param before casting
template <typename T, typename U>
void batched_static_cast(
        T* dest, const QuantizedCType<U>* src, size_t nr, DType src_dtype) {
    const auto& param = src_dtype.param<typename DTypeTrait<U>::dtype>();
    for (size_t i = 0; i < nr; ++i) {
        dest[i] = static_cast<T>(param.dequantize(src[i]));
    }
}

// low-bit overloads: unpack the compact bit representation into one int8_t
// per element first, then cast element by element
#define cb(_name, _bits)                                                      \
    template <typename T>                                                     \
    void batched_static_cast(                                                 \
            T* dest, const megdnn::dt_##_name##_bits* src, size_t nr,         \
            DType src_dtype) {                                                \
        std::unique_ptr<int8_t[]> unpacked_byte(new int8_t[nr]);              \
        lowbit_memcpy_compact2byte(                                           \
                megdnn::dtype::_name##_bits(), unpacked_byte.get(), src, nr); \
        for (size_t i = 0; i < nr; ++i)                                       \
            dest[i] = static_cast<T>(unpacked_byte[i]);                       \
    }
MEGDNN_FOREACH_LOWBIT_DTYPE(cb)
#undef cb

//! elementwise static_cast_safe (throws on precision loss) from src into dest
template <typename T, typename U>
void batched_static_cast_safe(T* dest, const U* src, size_t nr, DType src_dtype) {
    for (size_t i = 0; i < nr; ++i)
        dest[i] = static_cast_safe<T>(src[i]);
}

//! quantized overload of the safe variant: dequantize, then checked cast
template <typename T, typename U>
void batched_static_cast_safe(
        T* dest, const QuantizedCType<U>* src, size_t nr, DType src_dtype) {
    const auto& param = src_dtype.param<typename DTypeTrait<U>::dtype>();
    for (size_t i = 0; i < nr; ++i) {
        dest[i] = static_cast_safe<T>(param.dequantize(src[i]));
    }
}

// low-bit overloads of the safe variant; unpack first as above
#define cb(_name, _bits)                                                      \
    template <typename T>                                                     \
    void batched_static_cast_safe(                                            \
            T* dest, const megdnn::dt_##_name##_bits* src, size_t nr,         \
            DType src_dtype) {                                                \
        std::unique_ptr<int8_t[]> unpacked_byte(new int8_t[nr]);              \
        lowbit_memcpy_compact2byte(                                           \
                megdnn::dtype::_name##_bits(), unpacked_byte.get(), src, nr); \
        for (size_t i = 0; i < nr; ++i)                                       \
            dest[i] = static_cast_safe<T>(unpacked_byte[i]);                  \
    }
MEGDNN_FOREACH_LOWBIT_DTYPE(cb)
#undef cb

}  // anonymous namespace

/*!
 * \brief cast \p nr_elem elements stored in \p storage (interpreted per
 *        \p src_type) into the plain C type T, without precision checking
 * \throws ConversionError if \p src_type is not a supported source dtype
 */
template <typename T>
void mgb::static_cast_dtype(
        T* dest, DType src_type, const void* storage, size_t nr_elem) {
    switch (src_type.enumv()) {
#define cb(_dt)                                                            \
    case DTypeTrait<_dt>::enumv:                                           \
        return batched_static_cast<T, DTypeTrait<_dt>::ctype>(             \
                dest, static_cast<const DTypeTrait<_dt>::ctype*>(storage), \
                nr_elem, src_type);
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
        cb(::megdnn::dtype::Bool)
#undef cb
#define cb(_name, _bits)                                                     \
    case DTypeTrait<dtype::_name##_bits>::enumv:                             \
        return batched_static_cast<T>(                                       \
                dest,                                                        \
                static_cast<const DTypeTrait<dtype::_name##_bits>::ctype*>(  \
                        storage),                                            \
                nr_elem, src_type);
        MEGDNN_FOREACH_LOWBIT_DTYPE(cb)
#undef cb
        default:
            mgb_throw(
                    ConversionError, "can not convert from dtype %s",
                    src_type.name());
    }
}

/*!
 * \brief like static_cast_dtype(), but throws ConversionError when a value
 *        cannot be converted without precision loss
 *
 * Note: unlike the unchecked variant, Bool is intentionally not accepted here.
 */
template <typename T>
void mgb::static_cast_dtype_safe(
        T* dest, DType src_type, const void* storage, size_t nr_elem) {
    switch (src_type.enumv()) {
#define cb(_dt)                                                            \
    case DTypeTrait<_dt>::enumv:                                           \
        return batched_static_cast_safe<T, DTypeTrait<_dt>::ctype>(        \
                dest, static_cast<const DTypeTrait<_dt>::ctype*>(storage), \
                nr_elem, src_type);
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
#undef cb
#define cb(_name, _bits)                                                     \
    case DTypeTrait<dtype::_name##_bits>::enumv:                             \
        return batched_static_cast_safe<T>(                                  \
                dest,                                                        \
                static_cast<const DTypeTrait<dtype::_name##_bits>::ctype*>(  \
                        storage),                                            \
                nr_elem, src_type);
        MEGDNN_FOREACH_LOWBIT_DTYPE(cb)
#undef cb
        default:
            mgb_throw(
                    ConversionError, "can not convert from dtype %s",
                    src_type.name());
    }
}

namespace mgb {

// explicit instantiations for every destination C type callers may use
#define INST(t)                                                          \
    template void static_cast_dtype(t*, DType, const void*, size_t);     \
    template void static_cast_dtype_safe(t*, DType, const void*, size_t)
INST(bool);
INST(unsigned);
INST(int);
INST(unsigned long);
INST(long);
INST(float);
INST(double);
INST(long long);
INST(unsigned long long);
#undef INST

/*!
 * \brief store \p val into this scalar, converted to the currently held dtype
 * \throws ConversionError if m_dtype is not an assignable dtype
 */
template <typename ctype>
typename ctype_enable_if<ctype>::type DTypeScalar::set_retain_dtype(ctype val) {
    switch (m_dtype.enumv()) {
#define cb(_dt)                                                         \
    case DTypeTrait<_dt>::enumv: {                                      \
        using mct = DTypeTrait<_dt>::ctype;                             \
        static_assert(sizeof(mct) <= sizeof(m_storage), "large ctype"); \
        visit<mct>() = static_cast<mct>(val);                           \
        return;                                                         \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
        cb(dt_bool);
        cb(dt_uint16);
#undef cb
        default:
            mgb_throw(ConversionError, "can not assign to dtype %s", m_dtype.name());
    }
}

#define INST(t) template void DTypeScalar::set_retain_dtype(t);
INST(int);
INST(float);
#undef INST

}  // namespace mgb

//! assign from raw storage interpreted per \p dtype; storage must fit in
//! m_storage (checked)
DTypeScalar& DTypeScalar::set_raw(DType dtype, const void* storage) {
    mgb_assert(dtype.valid() && dtype.size(1) <= sizeof(m_storage));
    m_dtype = dtype;
    memcpy(&m_storage, storage, dtype.size(1));
    return *this;
}

/*!
 * \brief compute the result dtype when combining operands of dtype \p t0 and
 *        \p t1
 *
 * Rules implemented below:
 *  - identical dtypes promote to themselves;
 *  - two quantized dtypes must be the same enum with matching params;
 *  - quantized beats non-quantized;
 *  - Float16 is widened to Float32 (unless float16 is disabled);
 *  - mixed categories promote to Float32;
 *  - same-category integers must share signedness and promote to the wider.
 */
DType mgb::dtype_promotion(DType t0, DType t1) {
    mgb_assert(t0 != dtype::Byte() && t1 != dtype::Byte());

    if (t0 == t1)
        return t0;

    // Now t0 != t1.
    if (t0.category() == DTypeCategory::QUANTIZED &&
        t1.category() == DTypeCategory::QUANTIZED) {
        mgb_assert(
                t0.enumv() == t1.enumv(),
                "promoting unexpected quantized DType: %s and %s", t0.name(),
                t1.name());
        if (t0.enumv() == DTypeEnum::Quantized8Asymm) {
            auto& param0 = t0.param<dtype::Quantized8Asymm>();
            auto& param1 = t1.param<dtype::Quantized8Asymm>();
            mgb_assert(
                    param0.zero_point == param1.zero_point &&
                            fabs(param0.scale - param1.scale) < 1e-6,
                    "trying to promote two Quantized8Asymm with different scale "
                    "or zero_point, this usually does not make sense: (%f, %u) "
                    "vs (%f, %u)",
                    param0.scale, param0.zero_point, param1.scale,
                    param1.zero_point);
            return t0;
        } else if (t0.enumv() == DTypeEnum::QuantizedS8) {
            auto& param0 = t0.param<dtype::QuantizedS8>();
            auto& param1 = t1.param<dtype::QuantizedS8>();
            mgb_assert(
                    fabs(param0.scale - param1.scale) < 1e-6,
                    "trying to promote two QuantizedS8 with different "
                    "scale, this usually does not make sense: %f vs %f",
                    param0.scale, param1.scale);
            return t0;
        } else {
            mgb_assert(
                    t0.enumv() == DTypeEnum::QuantizedS32,
                    "promoting unsupported quantized DType: %s", t0.name());
            auto& param0 = t0.param<dtype::QuantizedS32>();
            auto& param1 = t1.param<dtype::QuantizedS32>();
            mgb_assert(
                    fabs(param0.scale - param1.scale) < 1e-6,
                    "trying to promote two QuantizedS32 with different "
                    "scale, this usually does not make sense: %f vs %f",
                    param0.scale, param1.scale);
            return t0;
        }
    } else if (t0.category() == DTypeCategory::QUANTIZED) {
        return t0;
    } else if (t1.category() == DTypeCategory::QUANTIZED) {
        return t1;
    }

#if !MEGDNN_DISABLE_FLOAT16
    if (t0 == dtype::Float16())
        t0 = dtype::Float32();
    if (t1 == dtype::Float16())
        t1 = dtype::Float32();
#endif

    if (t0.category() != t1.category()) {
        return dtype::Float32();
    }

    mgb_throw_if(
            t0.signedness() != t1.signedness(), ConversionError,
            "dtype promotion rule between different signedness is undefined: "
            "%s %s",
            t0.name(), t1.name());
    if (t0.size() > t1.size())
        return t0;
    return t1;
}

/* ================== lowbit memcpy ================== */
namespace {

template <int bits>
struct LowbitMemcpy;

//! maps an intb bit width to the affine (SHIFT, STEP) encoding of its values
template <int bits>
struct LowbitTrait;

template <>
struct LowbitTrait<1> {
    // intb1: -1, 1
    static constexpr int8_t SHIFT = 1, STEP = 2;
};

template <>
struct LowbitTrait<2> {
    // intb2: -3, -1, 1, 3
    static constexpr int8_t SHIFT = 3, STEP = 2;
};

template <>
struct LowbitTrait<4> {
    // intb4: -15 to 15
    static constexpr int8_t SHIFT = 15, STEP = 2;
};

//! pack/unpack intb values between one-int8-per-element form and the compact
//! bit-packed form; requires 8 % bits == 0 so elements never straddle a byte
template <int bits>
struct LowbitMemcpy {
    // cast with bits that 8 % bits == 0
    static constexpr uint8_t MASK = (1 << bits) - 1;
    using Trait = LowbitTrait<bits>;

    static void byte2compact(void* dest_raw, const void* src_raw, size_t n) {
        auto dest = static_cast<uint8_t*>(dest_raw);
        auto src = static_cast<const int8_t*>(src_raw);
        memset(dest, 0, divup(n * bits, 8));
        for (size_t i = 0; i < n; ++i) {
            int8_t val = src[i];
            // value must be representable: (val + SHIFT) / STEP in [0, 2^bits)
            mgb_assert(
                    val + Trait::SHIFT >= 0 &&
                    ((val + Trait::SHIFT) % Trait::STEP) == 0);
            val = (val + Trait::SHIFT) / Trait::STEP;
            mgb_assert(val >= 0 && val < (1 << bits));
            dest[i * bits / 8] |= val << (i * bits % 8);
        }
    }

    static void compact2byte(void* dest_raw, const void* src_raw, size_t n) {
        auto dest = static_cast<int8_t*>(dest_raw);
        auto src = static_cast<const uint8_t*>(src_raw);
        for (size_t i = 0; i < n; ++i) {
            int8_t val = ((src[i * bits / 8] >> (i * bits % 8)) & MASK);
            dest[i] = val * Trait::STEP - Trait::SHIFT;
        }
    }
};

//! pack/unpack for quantized low-bit dtypes; enabled (div_byte == true) only
//! when the dtype is quantized and its bit width divides a byte
template <
        typename DT,
        bool div_byte = (DTypeTrait<DT>::category == DTypeCategory::QUANTIZED) &&
                        (8 % DTypeTrait<DT>::low_bit == 0)>
struct QuantizedLowbitMemcpy;

template <typename DT>
struct QuantizedLowbitMemcpy<DT, true> {
    // cast with bits that 8 % bits == 0
    static constexpr uint16_t bits = DTypeTrait<DT>::low_bit;
    static constexpr uint8_t MASK = (1 << bits) - 1;
    // NOTE(review): assumes QuantizedS4 is the only signed dtype produced by
    // MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE -- confirm against megdnn dtype.h
    static constexpr bool signedness = std::is_same<DT, dtype::QuantizedS4>::value;

    static void byte2compact(void* dest_raw, const void* src_raw, size_t n) {
        auto dest = static_cast<uint8_t*>(dest_raw);
        auto src = static_cast<const int8_t*>(src_raw);
        memset(dest, 0, divup(n * bits, 8));
        for (size_t i = 0; i < n; ++i) {
            int8_t val = src[i];
            static const auto min_val = DTypeTrait<DT>::min();
            static const auto max_val = DTypeTrait<DT>::max();
            MGB_MARK_USED_VAR(min_val);
            MGB_MARK_USED_VAR(max_val);
            mgb_assert(
                    val >= static_cast<int8_t>(min_val) &&
                            val <= static_cast<int8_t>(max_val),
                    "data exceeds range(%d,%d) of data type", min_val, max_val);
            dest[i * bits / 8] |= (val & MASK) << (i * bits % 8);
        }
    }

    static void compact2byte(void* dest_raw, const void* src_raw, size_t n) {
        auto dest = reinterpret_cast<int8_t*>(dest_raw);
        auto src = static_cast<const uint8_t*>(src_raw);
        for (size_t i = 0; i < n; ++i) {
            uint8_t intermediate = ((src[i * bits / 8] >> (i * bits % 8)) & MASK);
            if (signedness) {
                // sign-extend: if the top bit of the bits-wide value is set,
                // fill the high bits with ones
                int val = (intermediate & uint8_t(1 << (bits - 1)))
                                ? ((int)(intermediate) | ~(int)(MASK))
                                : (int)(intermediate);
                dest[i] = static_cast<int8_t>(val);
            } else {
                dest[i] = static_cast<int8_t>(intermediate);
            }
        }
    }
};

}  // anonymous namespace

//! pack n one-byte elements at \p src into the compact low-bit layout at
//! \p dest, dispatching on \p dtype
void mgb::lowbit_memcpy_byte2compact(
        DType dtype, void* dest, const void* src, size_t n) {
#define cb(name, bits)                                         \
    if (dtype == mgb::dtype::name##bits())                     \
        return LowbitMemcpy<bits>::byte2compact(dest, src, n);
    MEGDNN_FOREACH_LOWBIT_DTYPE(cb)
#undef cb
#define cb(dt)                                                          \
    if (dtype.enumv() == DTypeTrait<dt>::enumv)                         \
        return QuantizedLowbitMemcpy<dt>::byte2compact(dest, src, n);
    MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
#undef cb
    mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name());
}

//! unpack n compact low-bit elements at \p src into one byte each at \p dest,
//! dispatching on \p dtype
void mgb::lowbit_memcpy_compact2byte(
        DType dtype, void* dest, const void* src, size_t n) {
#define cb(name, bits)                                         \
    if (dtype == mgb::dtype::name##bits())                     \
        return LowbitMemcpy<bits>::compact2byte(dest, src, n);
    MEGDNN_FOREACH_LOWBIT_DTYPE(cb)
#undef cb
#define cb(dt)                                                          \
    if (dtype.enumv() == DTypeTrait<dt>::enumv)                         \
        return QuantizedLowbitMemcpy<dt>::compact2byte(dest, src, n);
    MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
#undef cb
    mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name());
}

//! pack bytes into the compact layout, padding each innermost row to a byte
//! boundary when its bit size is not a multiple of 8
void mgb::lowbit_memcpy_byte2aligned(
        void* dest, const void* src, const ::megdnn::TensorLayout& layout) {
    size_t low_bit = layout.dtype.low_bit();
    size_t dim = layout.shape[layout.ndim - 1];
    if ((dim * low_bit) % 8) {  // padding
        size_t n = layout.total_nr_elems();
        size_t stride = divup(dim * low_bit, 8);
        dt_byte* dest_ptr = reinterpret_cast<dt_byte*>(dest);
        const dt_byte* src_ptr = reinterpret_cast<const dt_byte*>(src);
        for (size_t i = 0; i < n / dim; ++i) {
            lowbit_memcpy_byte2compact(layout.dtype, dest_ptr, src_ptr, dim);
            dest_ptr += stride;
            src_ptr += dim;
        }
    } else {
        lowbit_memcpy_byte2compact(layout.dtype, dest, src, layout.total_nr_elems());
    }
}

//! inverse of lowbit_memcpy_byte2aligned: unpack a possibly row-padded
//! compact buffer back into one byte per element
void mgb::lowbit_memcpy_aligned2byte(
        void* dest, const void* src, const ::megdnn::TensorLayout& layout) {
    size_t low_bit = layout.dtype.low_bit();
    size_t dim = layout.shape[layout.ndim - 1];
    if ((dim * low_bit) % 8) {  // padding
        size_t n = layout.total_nr_elems();
        size_t stride = divup(dim * low_bit, 8);
        dt_byte* dest_ptr = reinterpret_cast<dt_byte*>(dest);
        const dt_byte* src_ptr = reinterpret_cast<const dt_byte*>(src);
        for (size_t i = 0; i < n / dim; ++i) {
            lowbit_memcpy_compact2byte(layout.dtype, dest_ptr, src_ptr, dim);
            dest_ptr += dim;
            src_ptr += stride;
        }
    } else {
        lowbit_memcpy_compact2byte(layout.dtype, dest, src, layout.total_nr_elems());
    }
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}