提交 91d61607 编写于 作者: M Megvii Engine Team

feat(dnn/common): add tensor format for low-bits tensor layout

GitOrigin-RevId: 0aa3753f37ae73b338fcd5bbe758d432b9e66261
上级 19a554d6
......@@ -505,9 +505,9 @@ class DType {
return std::numeric_limits<size_t>::max() >> m_trait->size_log;
}
bool is_low_bit() const {
return m_trait->low_bit != 0;
}
//! number of bits per element for low-bit dtypes, 0 for ordinary dtypes
size_t low_bit() const { return m_trait->low_bit; }
//! whether elements of this dtype are narrower than one byte
bool is_low_bit() const { return low_bit() != 0; }
/*!
* \brief size of this data type, in bytes
......
......@@ -20,12 +20,15 @@ namespace megdnn {
enum class TensorFormat::Type {
    DEFAULT = 0,                   //!< see DefaultTensorFormat
    IMAGE2D_PACK4 = 1,             //!< see Image2DPack4TensorFormat
    FOURBITS_ALIGNED_TO_BYTE = 2,  //!< see FourBitsAlignedToBytesTensorFormat
};
class TensorFormat::ImplBase {
public:
using Type = TensorFormat::Type;
virtual void assert_valid(const TensorLayout& layout) const = 0;
virtual size_t init_contiguous_stride(TensorLayout& layout) const = 0;
virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0;
......@@ -63,6 +66,8 @@ public:
DefaultTensorFormat() : ImplBase(TYPE) {}
void assert_valid(const TensorLayout& layout) const override;
size_t init_contiguous_stride(TensorLayout& layout) const override;
/*!
......@@ -180,11 +185,11 @@ public:
*/
size_t image_width(const TensorLayout& layout) const;
//! raise exception if preconditions violated
void assert_valid(const TensorLayout& layout) const;
size_t image_row_pitch(const TensorLayout& layout) const;
//! raise exception if preconditions violated
void assert_valid(const TensorLayout& layout) const override;
//! span for image must include the padding at the last row
TensorLayout::Span span_spec(const TensorLayout& layout) const override;
......@@ -197,31 +202,48 @@ public:
};
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
///*!
// * \brief used for tensors with lowbit data type
// *
// * \p SIZE_NBITS is the size in bits of element of the tensor.
// *
// */
//template <size_t SIZE_NBITS_>
//class LowbitTensorFormat : public TensorFormat::ImplBase {
// static constexpr size_t SIZE_NBITS = SIZE_NBITS_;
// size_t m_align_size_in_bits;
//
//protected: //?
// LowbitTensorFormat(Type type, size_t m_align_size_in_bits);
//
//public:
// size_t align_size_in_bits() const {
// return m_align_size_in_bits;
// }
//
// std::string to_string() const override;
//
// void serialize_append(
//
//
//};
/*!
 * \brief base class for formats of tensors storing low-bit data
 *
 * \p SIZE_NBITS_ is the size in bits of one element of the tensor.
 * Strides other than the unity innermost stride must be multiples of the
 * alignment (given in bits, translated to elements at construction).
 */
template <size_t SIZE_NBITS_>
class LowbitsTensorFormatBase : public TensorFormat::ImplBase {
    static constexpr size_t SIZE_NBITS = SIZE_NBITS_;
    //! alignment expressed in bits, and the same value in elements
    size_t m_align_size_in_bits, m_align_size_in_elements;

protected:
    LowbitsTensorFormatBase(Type type, size_t align_size_in_bits);
    virtual ~LowbitsTensorFormatBase() = default;

public:
    size_t align_size_in_bits() const { return m_align_size_in_bits; }

    std::string to_string() const override;

    //! raise exception if given layout is illegal
    void assert_valid(const TensorLayout& layout) const override;

    void serialize_append(std::string& result) const override;

    //! span for lowbit tensor must include the padding at the innermost
    //! dimension that makes the lowbit tensor aligned to bytes
    TensorLayout::Span span_spec(const TensorLayout& layout) const override;

    size_t init_contiguous_stride(TensorLayout& layout) const override;

    bool is_contiguous_spec(const TensorLayout& layout) const override;

    TensorLayout collapse_contiguous_spec(
            const TensorLayout& layout) const override;

protected:
    //! on-wire representation used by serialize_append/deserialize
    struct SerializePack {
        uint8_t align_size_in_bits;
    };
};
using FourBitsAlignedToBytesTensorFormatBase = LowbitsTensorFormatBase<4>;
} // namespace detail
/*!
......@@ -270,6 +292,34 @@ private:
TYPE, align_axis, align_size_in_elements, vendor_type) {}
};
/*!
 * \brief Tensor format for storing 4bit data that requires stride
 * corresponding to non-innermost dimensions to be aligned to bytes, packing
 * 2 elems into a byte
 */
class FourBitsAlignedToBytesTensorFormat final
        : public detail::FourBitsAlignedToBytesTensorFormatBase {
public:
    static constexpr Type TYPE = Type::FOURBITS_ALIGNED_TO_BYTE;

    //! get a (cached) format instance; \p align_size_in_bits must be a
    //! multiple of 4
    static TensorFormat make(size_t align_size_in_bits);

    static TensorFormat deserialize(const Handle* handle, const void* buf,
                                    size_t size);

    //! return whether \p layout uses this format type; if it does but
    //! violates the format's constraints, an exception is raised
    static bool is_valid_layout(const TensorLayout& layout) {
        if (layout.format.type() == TYPE) {
            layout.format.as_impl<FourBitsAlignedToBytesTensorFormat>()
                    .assert_valid(layout);
            return true;
        }
        return false;
    }

private:
    FourBitsAlignedToBytesTensorFormat(size_t align_size_in_bits)
            : detail::FourBitsAlignedToBytesTensorFormatBase(
                      TYPE, align_size_in_bits) {}
};
} // namespace megdnn
#include "megdnn/internal/visibility_epilogue.h"
......
......@@ -201,7 +201,15 @@ TensorLayout::TensorLayout(DType dtype_, Format format_)
: dtype{dtype_}, format{format_} {}
//! construct a layout from a shape, selecting a format from the dtype:
//! 4-bit dtypes get the byte-aligned 4-bit format (2 elems per byte);
//! other low-bit widths are rejected; everything else uses the default
TensorLayout::TensorLayout(const TensorShape& shape, DType dtype)
        : TensorShape(shape), dtype{dtype} {
    if (dtype.low_bit() == 4_z) {
        format = FourBitsAlignedToBytesTensorFormat::make(8_z);
    } else {
        megdnn_assert(!dtype.is_low_bit(), "Unsupported data type(%s)",
                      dtype.name());
        format = DefaultTensorFormat::make();
    }
}
TensorLayout::TensorLayout(const TensorShape& shape, DType dtype,
TensorFormat format_)
......
......@@ -35,6 +35,9 @@ TensorFormat TensorFormat::deserialize(const std::string& bin,
case Type::IMAGE2D_PACK4:
return Image2DPack4TensorFormat::deserialize(
handle, type + 1, bin.size() - sizeof(Type));
case Type::FOURBITS_ALIGNED_TO_BYTE:
return FourBitsAlignedToBytesTensorFormat::deserialize(
handle, type + 1, bin.size() - sizeof(Type));
default:
megdnn_throw("invalid tensor format type in deserialize");
}
......@@ -67,7 +70,15 @@ bool TensorFormat::is_default() const {
}
/* ===================== DefaultFormat ===================== */
void DefaultTensorFormat::assert_valid(const TensorLayout& layout) const {
    // the default format has no notion of padded/aligned strides, so
    // low-bit dtypes are rejected here; an invalid (empty) dtype is allowed
    megdnn_assert(
            !layout.dtype.valid() || !layout.dtype.is_low_bit(),
            "DefaultTensorFormat does not support low-bits tensor(dtype:%s)",
            layout.dtype.name());
}
size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const {
assert_valid(layout);
if (!layout.ndim)
return 0;
megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM);
......@@ -81,11 +92,13 @@ size_t DefaultTensorFormat::init_contiguous_stride(TensorLayout& layout) const {
}
bool DefaultTensorFormat::is_contiguous_spec(const TensorLayout& layout) const {
    // no padding in the default format: contiguity under this format is
    // exactly physical contiguity
    assert_valid(layout);
    return layout.is_physical_contiguous();
}
TensorLayout DefaultTensorFormat::collapse_contiguous_spec(
const TensorLayout& layout) const {
assert_valid(layout);
megdnn_assert(layout.ndim);
TensorLayout res{layout};
......@@ -126,6 +139,7 @@ TensorLayout DefaultTensorFormat::collapse_contiguous_spec(
TensorLayout::Span DefaultTensorFormat::span_spec(
const TensorLayout& layout) const {
assert_valid(layout);
if (layout.ndim == 0)
return {0, 0, 0, 0};
......@@ -146,9 +160,6 @@ TensorLayout::Span DefaultTensorFormat::span_spec(
++high_elem;
ptrdiff_t low_byte;
if (low_elem < 0) {
megdnn_assert(!layout.dtype.is_low_bit(),
"tensors with low-bit dytes shouldn't have negative "
"strides");
low_byte = low_elem * layout.dtype.size();
} else {
low_byte = 0;
......@@ -422,12 +433,151 @@ TensorLayout Image2DPackedTensorFormatBase<PIXEL_SIZE>::collapse_contiguous_spec
return res;
}
namespace megdnn {
namespace detail {
template class Image2DPackedTensorFormatBase<4>;
} // namespace detail
} // namespace megdnn
/* =============== LowbitsTensorFormatBase ============== */
template <size_t SIZE_NBITS>
LowbitsTensorFormatBase<SIZE_NBITS>::LowbitsTensorFormatBase(
        Type type, size_t align_size_in_bits)
        : ImplBase(type), m_align_size_in_bits(align_size_in_bits) {
    // alignment is configured in bits while strides are counted in elements,
    // so the alignment must correspond to a whole number of elements
    megdnn_assert(!(m_align_size_in_bits % SIZE_NBITS),
                  "align size(%zu) must be a multiple of element size(%zu)",
                  m_align_size_in_bits, SIZE_NBITS);
    m_align_size_in_elements = m_align_size_in_bits / SIZE_NBITS;
}
template <size_t SIZE_NBITS>
std::string LowbitsTensorFormatBase<SIZE_NBITS>::to_string() const {
    // human-readable form: element size in bits, then alignment in bits
    return ssprintf("LOWBITS{%zu,%zu}", SIZE_NBITS, m_align_size_in_bits);
}
template <size_t SIZE_NBITS>
void LowbitsTensorFormatBase<SIZE_NBITS>::assert_valid(
        const TensorLayout& layout) const {
    // the layout's dtype must be a low-bit dtype of exactly SIZE_NBITS bits
    megdnn_assert(layout.dtype.valid() && layout.dtype.is_low_bit() &&
                  layout.dtype.low_bit() == SIZE_NBITS);
    bool has_dim_unity_stride = false;
    for (int i = static_cast<int>(layout.ndim) - 1; i >= 0; --i) {
        if (!has_dim_unity_stride && layout.stride[i] == 1)
            has_dim_unity_stride = true;
        // every stride must be non-negative and either unity (innermost) or
        // aligned to the configured element alignment
        // note: stride is ptrdiff_t, so %td (not %zu) is the correct
        // conversion specifier
        megdnn_assert(
                layout.stride[i] >= 0 &&
                        (layout.stride[i] % m_align_size_in_elements == 0 ||
                         layout.stride[i] == 1),
                "bad stride: %td", layout.stride[i]);
    }
    megdnn_assert(has_dim_unity_stride, "innermost dim not contiguous");
}
template <size_t SIZE_NBITS>
void LowbitsTensorFormatBase<SIZE_NBITS>::serialize_append(
        std::string& result) const {
    SerializePack pack;
    pack.align_size_in_bits = m_align_size_in_bits;
    // SerializePack narrows the value to uint8_t; comparing after the
    // round-trip catches truncation of oversized alignment values
    megdnn_assert(pack.align_size_in_bits ==
                  m_align_size_in_bits);  // detect overflow;
    result.append(reinterpret_cast<char*>(&pack), sizeof(pack));
}
template <size_t SIZE_NBITS>
TensorLayout::Span LowbitsTensorFormatBase<SIZE_NBITS>::span_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    if (layout.ndim == 0)
        return {0, 0, 0, 0};
    // highest element offset touched: sum of (shape - 1) * stride per axis
    size_t high_elem = 0;
    for (size_t i = 0; i < layout.ndim; ++i) {
        auto shape_val = layout.shape[i];
        if (!shape_val) {
            // any zero-sized axis makes the tensor empty
            return {0, 0, 0, 0};
        }
        auto stride_val = layout.stride[i];
        megdnn_assert(stride_val >= 0,
                      "lowbit tensors shouldn't have negative strides");
        high_elem += (shape_val - 1) * stride_val;
    }
    ++high_elem;
    // the dtype converts the element count to bytes, accounting for
    // sub-byte element packing
    size_t high_byte = layout.dtype.size(high_elem);
    return TensorLayout::Span(0, 0, high_elem, high_byte);
}
//! fill in contiguous strides for \p layout under this format and return
//! the total number of elements, with the innermost dimension padded up to
//! the alignment granularity
template <size_t SIZE_NBITS>
size_t LowbitsTensorFormatBase<SIZE_NBITS>::init_contiguous_stride(
        TensorLayout& layout) const {
    if (!layout.ndim)
        return 0;
    megdnn_assert(layout.ndim <= TensorLayout::MAX_NDIM);
    const size_t innermost = layout.ndim - 1;
    size_t total = 1;
    SafeMultiplies<size_t> safe_mul;
    for (size_t j = layout.ndim; j > 0; --j) {
        size_t axis = j - 1;
        layout.stride[axis] = total;
        size_t extent = layout.shape[axis];
        // only the innermost axis is padded to the alignment boundary
        if (axis == innermost)
            extent = round_up(extent, m_align_size_in_elements);
        total = safe_mul(total, extent);
    }
    return total;
}
template <size_t SIZE_NBITS>
bool LowbitsTensorFormatBase<SIZE_NBITS>::is_contiguous_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    // walk from innermost dim outwards, tracking the stride each dim would
    // have if the layout were contiguous under this format
    ptrdiff_t expected = 1;
    for (int i = static_cast<int>(layout.ndim) - 1; i >= 0; --i) {
        if (layout.shape[i] != 1 && layout.stride[i] != expected)
            return false;
        auto multiplier = layout.shape[i];
        // innermost dim contributes its padded extent to outer strides
        if (i == layout.ndim - 1)
            multiplier = round_up(multiplier, m_align_size_in_elements);
        expected *= multiplier;
    }
    // expected becomes 0 only when some axis has shape 0 (empty tensor)
    return expected != 0;
}
template <size_t SIZE_NBITS>
TensorLayout LowbitsTensorFormatBase<SIZE_NBITS>::collapse_contiguous_spec(
        const TensorLayout& layout) const {
    assert_valid(layout);
    TensorLayout res{layout};
    // first pass: drop shape-1 axes; bail out early on empty tensors
    for (int i = static_cast<int>(res.ndim) - 1; i >= 0; --i) {
        if (!res.shape[i]) {
            // empty tensor: canonicalize to a single zero-sized axis
            res.ndim = 1;
            res.shape[0] = 0;
            res.stride[0] = 1;
            return res;
        }
        if (res.shape[i] == 1) {
            res.remove_axis_inplace(i);
        }
    }
    megdnn_assert(res.ndim && res.shape[res.ndim - 1]);
    // second pass: merge adjacent axes that are jointly contiguous, i.e.
    // stride[i] == stride[i+1] * shape[i+1]
    for (int i = static_cast<int>(res.ndim) - 2; i >= 0; --i) {
        megdnn_assert(res.shape[i]);
        if (res.stride[i] ==
            res.stride[i + 1] * static_cast<ptrdiff_t>(res.shape[i + 1])) {
            res.shape[i] *= res.shape[i + 1];
            res.stride[i] = res.stride[i + 1];
            res.remove_axis_inplace(i + 1);
        }
    }
    return res;
}
namespace megdnn {
namespace detail {
template class LowbitsTensorFormatBase<4>;
} // namespace detail
} // namespace megdnn
/* ===================== Image2DPack4TensorFormat ===================== */
TensorFormat Image2DPack4TensorFormat::make_raw(
size_t align_axis, size_t align_size_in_elements,
......@@ -466,4 +616,29 @@ TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const {
return make_raw(axis, align_size_in_elements(), vendor());
}
/* ===================== FourBitsAlignedToBytesTensorFormat
 * ===================== */
TensorFormat FourBitsAlignedToBytesTensorFormat::make(
        size_t align_size_in_bits) {
    // one impl instance per alignment value, created lazily and cached for
    // the lifetime of the process; the mutex guards concurrent make() calls
    static std::mutex mtx;
    static std::unordered_map<
            uint32_t, std::unique_ptr<FourBitsAlignedToBytesTensorFormat>>
            cache;
    // alignment must be a multiple of the 4-bit element size
    megdnn_assert(!(align_size_in_bits % 4));
    MEGDNN_LOCK_GUARD(mtx);
    auto&& ptr = cache[static_cast<uint32_t>(align_size_in_bits)];
    if (!ptr) {
        ptr.reset(new FourBitsAlignedToBytesTensorFormat{align_size_in_bits});
    }
    return impl_to_tensor_format(ptr.get());
}
//! reconstruct a format from a buffer written by serialize_append; the
//! buffer holds exactly one SerializePack
TensorFormat FourBitsAlignedToBytesTensorFormat::deserialize(const Handle*,
                                                             const void* buf,
                                                             size_t size) {
    megdnn_assert(size == sizeof(SerializePack));
    auto pack = *static_cast<const SerializePack*>(buf);
    return make(pack.align_size_in_bits);
}
// vim: syntax=cpp.doxygen
......@@ -128,9 +128,11 @@ public:
for (size_t i = 0; i < shapes.size(); ++i) {
DType dt = (m_dtype.find(i) != m_dtype.end() ? m_dtype[i]
: dtype::Float32());
TensorFormat fmt =
(m_fmt.find(i) != m_fmt.end() ? m_fmt[i] : TensorFormat{});
layouts[i] = TensorLayout(shapes[i], dt, fmt);
if (m_fmt.find(i) == m_fmt.end()) {
layouts[i] = TensorLayout(shapes[i], dt);
layouts[i].init_contiguous_stride();
} else
layouts[i] = TensorLayout(shapes[i], dt, m_fmt[i]);
}
return layouts;
}
......
......@@ -302,4 +302,50 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) {
}
}
TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS) {
    // QuantizedS4 selects the 4-bit byte-aligned format automatically: the
    // innermost dim (7) is padded to 8 elements, so strides are 8-based
    TensorLayout layout{{16, 32, 7, 7}, dtype::QuantizedS4{1.2f}};
    layout.init_contiguous_stride();
    ASSERT_EQ(layout.stride[0], 1792);
    ASSERT_EQ(layout.stride[1], 56);
    ASSERT_EQ(layout.stride[2], 8);
    ASSERT_EQ(layout.stride[3], 1);
    auto span = layout.span();
    ASSERT_EQ(0, span.low_elem);
    ASSERT_EQ(28671, span.high_elem);
    ASSERT_EQ(0, span.low_byte);
    // 28672 elements packed 2 per byte -> 14336 bytes
    ASSERT_EQ(14336, span.high_byte);
    EXPECT_EQ(make_layout({3584, 7}, {8, 1}, dtype::QuantizedS4{1.2f}),
              layout.collapse_contiguous());
    // a layout whose strides carry the padding is contiguous under the
    // 4-bit format even though it is not physically contiguous
    layout = make_layout({16, 32, 7, 7}, {1792, 56, 8, 1},
                         dtype::QuantizedS4{1.3f});
    layout.format = FourBitsAlignedToBytesTensorFormat::make(8_z);
    EXPECT_TRUE(layout.is_contiguous());
    layout = TensorLayout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}};
    layout = layout.broadcast({16, 32, 7, 7});
    EXPECT_EQ(make_layout({16, 32, 49}, {0, 1, 0}, dtype::QuantizedS4{1.2}),
              layout.collapse_contiguous());
    // after init_contiguous_stride the innermost stride is padded (not 1);
    // broadcasting then zeroes the remaining strides, leaving no unity
    // stride, so span() is expected to throw
    layout = TensorLayout{{1, 32, 1, 1}, dtype::QuantizedS4{1.2f}};
    layout.init_contiguous_stride();
    layout = layout.broadcast({16, 32, 7, 7});
    ASSERT_THROW(layout.span(), MegDNNError);
}
TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_LOW_BITS_VALID) {
    // the default format rejects low-bit dtypes
    ASSERT_THROW(TensorLayout({1, 32, 1, 1}, dtype::QuantizedS4{1.2f},
                              DefaultTensorFormat::make()),
                 MegDNNError);
    // the 4-bit format rejects dtypes that are not low-bit ...
    ASSERT_THROW(TensorLayout({1, 32, 1, 1}, dtype::QuantizedS32{1.2f},
                              FourBitsAlignedToBytesTensorFormat::make(8_z))
                         .span(),
                 MegDNNError);
    // ... and low-bit dtypes of a different width (2-bit here)
    ASSERT_THROW(TensorLayout({16, 32, 7, 7}, dtype::IntB2{},
                              FourBitsAlignedToBytesTensorFormat::make(8_z))
                         .span(),
                 MegDNNError);
}
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册