diff --git a/dnn/include/megdnn/dtype.h b/dnn/include/megdnn/dtype.h
index 153c777286b65df0aaedc7634a3bb6c2e3c82edf..539934ce9dfb3ddc7890d572cc9c26e96bff4356 100644
--- a/dnn/include/megdnn/dtype.h
+++ b/dnn/include/megdnn/dtype.h
@@ -505,9 +505,9 @@ class DType {
         return std::numeric_limits<size_t>::max() >> m_trait->size_log;
     }
 
-    bool is_low_bit() const {
-        return m_trait->low_bit != 0;
-    }
+    size_t low_bit() const { return m_trait->low_bit; }
+
+    bool is_low_bit() const { return low_bit() != 0; }
 
     /*!
      * \brief size of this data type, in bytes
diff --git a/dnn/include/megdnn/tensor_format.h b/dnn/include/megdnn/tensor_format.h
index 87f7065b8cddb6844cb4d42524d28d6efb395d19..f8dd374615373ed3e17716b398949387ebbf68ee 100644
--- a/dnn/include/megdnn/tensor_format.h
+++ b/dnn/include/megdnn/tensor_format.h
@@ -20,12 +20,15 @@ namespace megdnn {
 enum class TensorFormat::Type {
     DEFAULT = 0,                   //!< see DefaultTensorFormat
     IMAGE2D_PACK4 = 1,             //!< see Image2DPack4TensorFormat
+    FOURBITS_ALIGNED_TO_BYTE = 2,  //!< see FourBitsAlignedToBytesTensorFormat
 };
 
 class TensorFormat::ImplBase {
 public:
     using Type = TensorFormat::Type;
 
+    virtual void assert_valid(const TensorLayout& layout) const = 0;
+
     virtual size_t init_contiguous_stride(TensorLayout& layout) const = 0;
 
     virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0;
@@ -63,6 +66,8 @@ public:
 
     DefaultTensorFormat() : ImplBase(TYPE) {}
 
+    void assert_valid(const TensorLayout& layout) const override;
+
     size_t init_contiguous_stride(TensorLayout& layout) const override;
 
     /*!
@@ -180,11 +185,11 @@ public:
      */
    size_t image_width(const TensorLayout& layout) const;
 
-    //! raise exception if preconditions violated
-    void assert_valid(const TensorLayout& layout) const;
-
     size_t image_row_pitch(const TensorLayout& layout) const;
 
+    //! raise exception if preconditions violated
+    void assert_valid(const TensorLayout& layout) const override;
+
     //! span for image must include the padding at the last row
     TensorLayout::Span span_spec(const TensorLayout& layout) const override;
 
@@ -197,31 +202,48 @@ public:
 };
 using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
 
-///*!
-// * \brief used for tensors with lowbit data type
-// *
-// * \p SIZE_NBITS is the size in bits of element of the tensor.
-// *
-// */
-//template <size_t SIZE_NBITS_>
-//class LowbitTensorFormat : public TensorFormat::ImplBase {
-//    static constexpr size_t SIZE_NBITS = SIZE_NBITS_;
-//    size_t m_align_size_in_bits;
-//
-//protected: //?
-//    LowbitTensorFormat(Type type, size_t m_align_size_in_bits);
-//
-//public:
-//    size_t align_size_in_bits() const {
-//        return m_align_size_in_bits;
-//    }
-//
-//    std::string to_string() const override;
-//
-//    void serialize_append(
-//
-//
-//};
+/*!
+ * \brief used for tensors storing lowbit data
+ *
+ * \p SIZE_NBITS is the size in bits of an element of the tensor.
+ *
+ */
+template <size_t SIZE_NBITS_>
+class LowbitsTensorFormatBase : public TensorFormat::ImplBase {
+    static constexpr size_t SIZE_NBITS = SIZE_NBITS_;
+    size_t m_align_size_in_bits, m_align_size_in_elements;
+
+protected:
+    LowbitsTensorFormatBase(Type type, size_t align_size_in_bits);
+
+    virtual ~LowbitsTensorFormatBase() = default;
+
+public:
+    size_t align_size_in_bits() const { return m_align_size_in_bits; }
+
+    std::string to_string() const override;
+
+    //! raise exception if given layout is illegal
+    void assert_valid(const TensorLayout& layout) const override;
+
+    void serialize_append(std::string& result) const override;
+
+    //! span for lowbit tensor must include the padding at the innermost
+    //! dimension that makes the lowbit tensor aligned to bytes
+    TensorLayout::Span span_spec(const TensorLayout& layout) const override;
+
+    size_t init_contiguous_stride(TensorLayout& layout) const override;
+
+    bool is_contiguous_spec(const TensorLayout& layout) const override;
+
+    TensorLayout collapse_contiguous_spec(
+            const TensorLayout& layout) const override;
+protected:
+    struct SerializePack {
+        uint8_t align_size_in_bits;
+    };
+};
+using FourBitsAlignedToBytesTensorFormatBase = LowbitsTensorFormatBase<4>;
 } // namespace detail
 
 /*!
@@ -270,6 +292,34 @@ private:
                       TYPE, align_axis, align_size_in_elements, vendor_type) {}
 };
 
+/*!
+ * \brief tensor format for storing 4-bit data; strides of non-innermost
+ * dimensions must be byte-aligned, and 2 elements are packed into one byte
+ */
+class FourBitsAlignedToBytesTensorFormat final
+        : public detail::FourBitsAlignedToBytesTensorFormatBase {
+public:
+    static constexpr Type TYPE = Type::FOURBITS_ALIGNED_TO_BYTE;
+
+    static TensorFormat make(size_t align_size_in_bits);
+
+    static TensorFormat deserialize(const Handle* handle, const void* buf,
+                                    size_t size);
+
+    static bool is_valid_layout(const TensorLayout& layout) {
+        if (layout.format.type() == TYPE) {
+            layout.format.as_impl<FourBitsAlignedToBytesTensorFormat>()
+                    .assert_valid(layout);
+            return true;
+        }
+        return false;
+    }
+
+private:
+    FourBitsAlignedToBytesTensorFormat(size_t align_size_in_bits)
+            : detail::FourBitsAlignedToBytesTensorFormatBase(
+                      TYPE, align_size_in_bits) {}
+};
 } // namespace megdnn
 
 #include "megdnn/internal/visibility_epilogue.h"
diff --git a/dnn/src/common/basic_types.cpp b/dnn/src/common/basic_types.cpp
index 081178991e99c1894ddba656531b725960145348..78700e895ba515c98a51d0fa3e0b5820713fddfe 100644
--- a/dnn/src/common/basic_types.cpp
+++ b/dnn/src/common/basic_types.cpp
@@ -201,7 +201,15 @@ TensorLayout::TensorLayout(DType dtype_, Format format_)
         : dtype{dtype_}, format{format_} {}
 
 TensorLayout::TensorLayout(const TensorShape& shape, DType dtype)
-        : TensorLayout(shape, dtype, DefaultTensorFormat::make()) {}
+        : TensorShape(shape), dtype{dtype} {
+    if (dtype.low_bit() == 4_z) {
+        format = FourBitsAlignedToBytesTensorFormat::make(8_z);
+    } else {
+        megdnn_assert(!dtype.is_low_bit(), "Unsupported data type(%s)",
+                      dtype.name());
+        format = DefaultTensorFormat::make();
+    }
+}
 
 TensorLayout::TensorLayout(const TensorShape& shape, DType dtype,
                            TensorFormat format_)
diff --git a/dnn/src/common/tensor_format.cpp b/dnn/src/common/tensor_format.cpp
index e61755bda308e45f82269350c35c354e3da247d3..33a19b74680c8d5686517d94042abb94e9a51db6 100644
--- a/dnn/src/common/tensor_format.cpp
+++ b/dnn/src/common/tensor_format.cpp
@@ -35,6 +35,9 @@ TensorFormat TensorFormat::deserialize(const std::string& bin,
         case Type::IMAGE2D_PACK4:
             return Image2DPack4TensorFormat::deserialize(
                     handle, type + 1, bin.size() - sizeof(Type));
+        case Type::FOURBITS_ALIGNED_TO_BYTE:
+            return FourBitsAlignedToBytesTensorFormat::deserialize(
+                    handle, type + 1, bin.size() - sizeof(Type));
         default:
             megdnn_throw("invalid tensor format type in deserialize");
     }
@@ -67,7 +70,15 @@ bool TensorFormat::is_default() const {
 }
 
 /* ===================== DefaultFormat ===================== */
+void DefaultTensorFormat::assert_valid(const TensorLayout& layout) const {
+    megdnn_assert(
+            !layout.dtype.valid() || !layout.dtype.is_low_bit(),
+            "DefaultTensorFormat does not support low-bits tensor(dtype:%s)",
+            layout.dtype.name());
+}
+
 size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const {
+    assert_valid(layout);
     if (!layout.ndim)
         return 0;
     megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM);
@@ -81,11 +92,13 @@ size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const {
 }
 
 bool DefaultTensorFormat::is_contiguous_spec(const TensorLayout& layout) const {
+    assert_valid(layout);
     return layout.is_physical_contiguous();
 }
 
 TensorLayout DefaultTensorFormat::collapse_contiguous_spec(
         const TensorLayout& layout) const {
+    assert_valid(layout);
     megdnn_assert(layout.ndim);
     TensorLayout res{layout};
 
@@ -126,6 +139,7 @@ TensorLayout DefaultTensorFormat::collapse_contiguous_spec(
 
 TensorLayout::Span DefaultTensorFormat::span_spec(
         const TensorLayout& layout) const {
+    assert_valid(layout);
     if (layout.ndim == 0)
         return {0, 0, 0, 0};
 
@@ -146,9 +160,6 @@ TensorLayout::Span DefaultTensorFormat::span_spec(
     ++high_elem;
     ptrdiff_t low_byte;
     if (low_elem < 0) {
-        megdnn_assert(!layout.dtype.is_low_bit(),
-                      "tensors with low-bit dytes shouldn't have negative "
-                      "strides");
         low_byte = low_elem * layout.dtype.size();
     } else {
         low_byte = 0;
@@ -422,12 +433,151 @@ TensorLayout Image2DPackedTensorFormatBase<SIZE_NBITS>::collapse_contiguous_spec
     return res;
 }
 
+
 namespace megdnn {
 namespace detail {
 template class Image2DPackedTensorFormatBase<4>;
 } // namespace detail
 } // namespace megdnn
 
+/* =============== FourBitsAlignedToBytesTensorFormatBase ============== */
+template <size_t SIZE_NBITS_>
+LowbitsTensorFormatBase<SIZE_NBITS_>::LowbitsTensorFormatBase(
+        Type type, size_t align_size_in_bits)
+        : ImplBase(type), m_align_size_in_bits(align_size_in_bits) {
+    megdnn_assert(!(m_align_size_in_bits % SIZE_NBITS),
+                  "align size(%zu) must be a multiple of element size(%zu)",
+                  m_align_size_in_bits, SIZE_NBITS);
+    m_align_size_in_elements = m_align_size_in_bits / SIZE_NBITS;
+}
+
+template <size_t SIZE_NBITS_>
+std::string LowbitsTensorFormatBase<SIZE_NBITS_>::to_string() const {
+    return ssprintf("LOWBITS{%zu,%zu}", SIZE_NBITS, m_align_size_in_bits);
+}
+
+template <size_t SIZE_NBITS_>
+void LowbitsTensorFormatBase<SIZE_NBITS_>::assert_valid(
+        const TensorLayout& layout) const {
+    megdnn_assert(layout.dtype.valid() && layout.dtype.is_low_bit() &&
+                  layout.dtype.low_bit() == SIZE_NBITS);
+    bool has_dim_unity_stride = false;
+    for (int i = layout.ndim - 1; i >= 0; --i) {
+        if (!has_dim_unity_stride && layout.stride[i] == 1)
+            has_dim_unity_stride = true;
+        megdnn_assert(
+                layout.stride[i] >= 0 &&
+                        (layout.stride[i] % m_align_size_in_elements == 0 ||
+                         layout.stride[i] == 1),
+                "bad stride: %zu", layout.stride[i]);
+    }
+    megdnn_assert(has_dim_unity_stride, "innermost dim not contiguous");
+}
+
+template <size_t SIZE_NBITS_>
+void LowbitsTensorFormatBase<SIZE_NBITS_>::serialize_append(
+        std::string& result) const {
+    SerializePack pack;
+    pack.align_size_in_bits = m_align_size_in_bits;
+    megdnn_assert(pack.align_size_in_bits ==
+                  m_align_size_in_bits);  // detect overflow
+    result.append(reinterpret_cast<char*>(&pack), sizeof(pack));
+}
+
+template <size_t SIZE_NBITS_>
+TensorLayout::Span LowbitsTensorFormatBase<SIZE_NBITS_>::span_spec(
+        const TensorLayout& layout) const {
+    assert_valid(layout);
+    if (layout.ndim == 0)
+        return {0, 0, 0, 0};
+
+    size_t high_elem = 0;
+    for (size_t i = 0; i < layout.ndim; ++i) {
+        auto shape_val = layout.shape[i];
+        if (!shape_val) {
+            return {0, 0, 0, 0};
+        }
+        auto stride_val = layout.stride[i];
+        megdnn_assert(stride_val >= 0,
+                      "lowbit tensors shouldn't have negative strides");
+        high_elem += (shape_val - 1) * stride_val;
+    }
+    ++high_elem;
+    size_t high_byte = layout.dtype.size(high_elem);
+    return TensorLayout::Span(0, 0, high_elem, high_byte);
+}
+
+template <size_t SIZE_NBITS_>
+size_t LowbitsTensorFormatBase<SIZE_NBITS_>::init_contiguous_stride(
+        TensorLayout& layout) const {
+    if (!layout.ndim)
+        return 0;
+    megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM);
+    size_t accum = 1;
+    SafeMultiplies<size_t> mul;
+    for (size_t i = layout.ndim; i; --i) {
+        layout.stride[i - 1] = accum;
+        auto multiplier = layout.shape[i - 1];
+        if (i == layout.ndim)
+            multiplier = round_up(multiplier, m_align_size_in_elements);
+        accum = mul(accum, multiplier);
+    }
+    return accum;
+}
+
+template <size_t SIZE_NBITS_>
+bool LowbitsTensorFormatBase<SIZE_NBITS_>::is_contiguous_spec(
+        const TensorLayout& layout) const {
+    assert_valid(layout);
+    ptrdiff_t expected = 1;
+    for (int i = static_cast<int>(layout.ndim) - 1; i >= 0; --i) {
+        if (layout.shape[i] != 1 && layout.stride[i] != expected)
+            return false;
+        auto multiplier = layout.shape[i];
+        if (i == layout.ndim - 1)
+            multiplier = round_up(multiplier, m_align_size_in_elements);
+        expected *= multiplier;
+    }
+    return expected != 0;
+}
+
+template <size_t SIZE_NBITS_>
+TensorLayout LowbitsTensorFormatBase<SIZE_NBITS_>::collapse_contiguous_spec(
+        const TensorLayout& layout) const {
+    assert_valid(layout);
+    TensorLayout res{layout};
+    for (int i = static_cast<int>(res.ndim) - 1; i >= 0; --i) {
+        if (!res.shape[i]) {
+            // empty tensor
+            res.ndim = 1;
+            res.shape[0] = 0;
+            res.stride[0] = 1;
+            return res;
+        }
+        if (res.shape[i] == 1) {
+            res.remove_axis_inplace(i);
+        }
+    }
+
+    megdnn_assert(res.ndim && res.shape[res.ndim - 1]);
+    for (int i = static_cast<int>(res.ndim) - 2; i >= 0; --i) {
+        megdnn_assert(res.shape[i]);
+        if (res.stride[i] ==
+            res.stride[i + 1] * static_cast<ptrdiff_t>(res.shape[i + 1])) {
+            res.shape[i] *= res.shape[i + 1];
+            res.stride[i] = res.stride[i + 1];
+            res.remove_axis_inplace(i + 1);
+        }
+    }
+    return res;
+}
+
+namespace megdnn {
+namespace detail {
+template class LowbitsTensorFormatBase<4>;
+} // namespace detail
+} // namespace megdnn
+
 /* ===================== Image2DPack4TensorFormat ===================== */
 TensorFormat Image2DPack4TensorFormat::make_raw(
         size_t align_axis, size_t align_size_in_elements,
@@ -466,4 +616,29 @@ TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const {
     return make_raw(axis, align_size_in_elements(), vendor());
 }
 
+/* ===================== FourBitsAlignedToBytesTensorFormat
+ * ===================== */
+TensorFormat FourBitsAlignedToBytesTensorFormat::make(
+        size_t align_size_in_bits) {
+    static std::mutex mtx;
+    static std::unordered_map<
+            uint32_t, std::unique_ptr<FourBitsAlignedToBytesTensorFormat>>
+            cache;
+    megdnn_assert(!(align_size_in_bits % 4));
+    MEGDNN_LOCK_GUARD(mtx);
+    auto&& ptr = cache[static_cast<uint32_t>(align_size_in_bits)];
+    if (!ptr) {
+        ptr.reset(new FourBitsAlignedToBytesTensorFormat{align_size_in_bits});
+    }
+    return impl_to_tensor_format(ptr.get());
+}
+
+TensorFormat FourBitsAlignedToBytesTensorFormat::deserialize(const Handle*,
+                                                             const void* buf,
+                                                             size_t size) {
+    megdnn_assert(size == sizeof(SerializePack));
+    auto pack = *static_cast<const SerializePack*>(buf);
+    return make(pack.align_size_in_bits);
+}
+
 // vim: syntax=cpp.doxygen
diff --git a/dnn/test/common/checker.h b/dnn/test/common/checker.h
index eff092c5880353b01edb6c4c6547d6f5f1061eb4..044cdd9c4e14e3b49c7a60e7fedbb32ff1216107 100644
--- a/dnn/test/common/checker.h
+++ b/dnn/test/common/checker.h
@@ -128,9 +128,11 @@ public:
         for (size_t i = 0; i < shapes.size(); ++i) {
             DType dt = (m_dtype.find(i) != m_dtype.end() ? m_dtype[i]
                                                          : dtype::Float32());
-            TensorFormat fmt =
-                    (m_fmt.find(i) != m_fmt.end() ? m_fmt[i] : TensorFormat{});
-            layouts[i] = TensorLayout(shapes[i], dt, fmt);
+            if (m_fmt.find(i) == m_fmt.end()) {
+                layouts[i] = TensorLayout(shapes[i], dt);
+                layouts[i].init_contiguous_stride();
+            } else
+                layouts[i] = TensorLayout(shapes[i], dt, m_fmt[i]);
         }
         return layouts;
     }
diff --git a/dnn/test/common/test_basic_types.cpp b/dnn/test/common/test_basic_types.cpp
index afed66ef137fae950cc4bc52f7e95f5393a21d7b..d0b2b28539b61a1752fe8a714b4f96a4fade3534 100644
--- a/dnn/test/common/test_basic_types.cpp
+++ b/dnn/test/common/test_basic_types.cpp
@@ -302,4 +302,50 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) {
     }
 }
 
+TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS) {
+    TensorLayout layout{{16, 32, 7, 7}, dtype::QuantizedS4{1.2f}};
+    layout.init_contiguous_stride();
+    ASSERT_EQ(layout.stride[0], 1792);
+    ASSERT_EQ(layout.stride[1], 56);
+    ASSERT_EQ(layout.stride[2], 8);
+    ASSERT_EQ(layout.stride[3], 1);
+    auto span = layout.span();
+    ASSERT_EQ(0, span.low_elem);
+    ASSERT_EQ(28671, span.high_elem);
+    ASSERT_EQ(0, span.low_byte);
+    ASSERT_EQ(14336, span.high_byte);
+    EXPECT_EQ(make_layout({3584, 7}, {8, 1}, dtype::QuantizedS4{1.2f}),
+              layout.collapse_contiguous());
+
+
+    layout = make_layout({16, 32, 7, 7}, {1792, 56, 8, 1},
+                         dtype::QuantizedS4{1.3f});
+    layout.format = FourBitsAlignedToBytesTensorFormat::make(8_z);
+    EXPECT_TRUE(layout.is_contiguous());
+
+    layout = TensorLayout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}};
+    layout = layout.broadcast({16, 32, 7, 7});
+    EXPECT_EQ(make_layout({16, 32, 49}, {0, 1, 0}, dtype::QuantizedS4{1.2f}),
+              layout.collapse_contiguous());
+
+    layout = TensorLayout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}};
+    layout.init_contiguous_stride();
+    layout = layout.broadcast({16, 32, 7, 7});
+    ASSERT_THROW(layout.span(), MegDNNError);
+}
+
+TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS_VALID) {
+    ASSERT_THROW(TensorLayout({1, 32, 1, 1}, dtype::QuantizedS4{1.2f},
+                              DefaultTensorFormat::make()),
+                 MegDNNError);
+    ASSERT_THROW(TensorLayout({1, 32, 1, 1}, dtype::QuantizedS32{1.2f},
+                              FourBitsAlignedToBytesTensorFormat::make(8_z))
+                         .span(),
+                 MegDNNError);
+    ASSERT_THROW(TensorLayout({16, 32, 7, 7}, dtype::IntB2{},
+                              FourBitsAlignedToBytesTensorFormat::make(8_z))
+                         .span(),
+                 MegDNNError);
+}
+
 // vim: syntax=cpp.doxygen
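
Usage sketch (reviewer note, not part of the patch). A minimal example of the behaviour introduced above, assuming the patched MegDNN headers; the shape, strides, and byte counts mirror the new TENSOR_LAYOUT_FMT_LOW_BITS test:

    #include "megdnn/basic_types.h"
    #include "megdnn/dtype.h"
    #include "megdnn/tensor_format.h"

    using namespace megdnn;

    void lowbit_layout_demo() {
        // A 4-bit dtype now makes the TensorLayout(shape, dtype) constructor
        // pick FourBitsAlignedToBytesTensorFormat::make(8) automatically.
        TensorLayout layout{{16, 32, 7, 7}, dtype::QuantizedS4{1.2f}};
        layout.init_contiguous_stride();
        // The innermost dimension (7 elements) is rounded up to 8 elements,
        // i.e. one byte boundary, so the strides are {1792, 56, 8, 1} rather
        // than the dense {1568, 49, 7, 1}; layout.span().high_byte == 14336,
        // since 28671 4-bit elements are packed 2 per byte.
    }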