Commit cd7090ac authored by Megvii Engine Team

fix(opencl): enable image on mali(cl2.1)

GitOrigin-RevId: 0c670fba807e9bf25e7825e7de5ce8c04d30dae8
Parent fc8d13cd
@@ -38,6 +38,17 @@ class Handle {
         CAMBRICON = 12,
     };
+    //! Device vendor
+    enum class HandleVendorType : uint32_t {
+        NOT_SPEC = 0,
+        MALI = 1,
+        ADRENO = 2,
+        CUDA = 3,
+        INTEL = 4,
+        POWERVR = 5,
+        AMD = 6,
+    };
 protected:
     Handle(megcoreComputingHandle_t computing_handle, HandleType type);
@@ -130,6 +141,9 @@ class Handle {
     //! get alignment in bytes for rows of image 2D tensor format
     virtual size_t image2d_pitch_alignment() const;
+    //! get vendor type
+    virtual HandleVendorType vendor_type() const;
     HandleType type() const {
         return m_handle_type;
     }
......
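The new HandleVendorType query lets tensor-format code specialize pitch handling per GPU vendor. A self-contained sketch of how a backend might map an OpenCL device name onto these values (the enum is copied for illustration; guess_vendor is hypothetical and not part of this patch):

    #include <cstdint>
    #include <string>

    // Illustrative copy of Handle::HandleVendorType from the hunk above.
    enum class HandleVendorType : uint32_t {
        NOT_SPEC = 0, MALI = 1, ADRENO = 2, CUDA = 3, INTEL = 4, POWERVR = 5, AMD = 6,
    };

    // Hypothetical helper: a real backend would inspect CL_DEVICE_NAME /
    // CL_DEVICE_VENDOR and decide from that string.
    inline HandleVendorType guess_vendor(const std::string& device_name) {
        if (device_name.find("Mali") != std::string::npos)
            return HandleVendorType::MALI;
        if (device_name.find("Adreno") != std::string::npos)
            return HandleVendorType::ADRENO;
        return HandleVendorType::NOT_SPEC;
    }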
@@ -12,6 +12,7 @@
 #pragma once
 #include "megdnn/basic_types.h"
+#include "megdnn/handle.h"
 #include "megdnn/internal/visibility_prologue.h"
 namespace megdnn {
@@ -43,7 +44,7 @@ public:
 protected:
     ImplBase(Type type) : m_type{type} {}
-    ~ImplBase() = default;
+    virtual ~ImplBase() = default;
     static TensorFormat impl_to_tensor_format(ImplBase* impl) { return {impl}; }
@@ -93,8 +94,8 @@ namespace detail {
  *
  * \p align_axis is the axis to be aligned, also the first axis of image width.
  * More precisely speaking, `stride[align_axis-1] * dtype.size()` must divide \p
- * align_size_in_byte. Axes from 0 to align_axis-1 would be considered as the
- * height of the image, and other axes are the width.
+ * align_size_in_elements. Axes from 0 to align_axis-1 would be considered as
+ * the height of the image, and other axes are the width.
  *
  * Empty tensors and negative strides are not allowed. Only contiguous or
 * broadcasted cases are allowed.
@@ -103,41 +104,32 @@ namespace detail {
  * considered as contiguous.
  */
 class Image2DTensorFormatBase : public TensorFormat::ImplBase {
-    size_t m_align_axis, m_align_size_in_byte_log2;
+    size_t m_align_axis, m_align_size_in_elements_log2;
 protected:
     Image2DTensorFormatBase(Type type, size_t align_axis,
-                            size_t align_size_in_byte);
-    ~Image2DTensorFormatBase() = default;
+                            size_t align_size_in_elements);
+    virtual ~Image2DTensorFormatBase() = default;
 public:
     /*!
-     * \brief get alignment requirement in bytes
+     * \brief get alignment requirement in elements
      * \param div_log2 the result would be divided by `(1 << div_log2)`
      */
-    size_t align_size_in_byte(size_t div_log2 = 0) const {
-        return 1 << (m_align_size_in_byte_log2 > div_log2
-                             ? m_align_size_in_byte_log2 - div_log2
+    size_t align_size_in_elements(size_t div_log2 = 0) const {
+        return 1 << (m_align_size_in_elements_log2 > div_log2
+                             ? m_align_size_in_elements_log2 - div_log2
                              : 0);
     }
     size_t align_axis() const { return m_align_axis; }
-    size_t init_contiguous_stride(TensorLayout& layout) const override;
-    bool is_contiguous_spec(const TensorLayout& layout) const override;
-    TensorLayout collapse_contiguous_spec(
-            const TensorLayout& layout) const override;
-    //! span for image must include the padding at the last row
-    TensorLayout::Span span_spec(const TensorLayout& layout) const override;
+    size_t align_size_in_elements_log2() const {
+        return m_align_size_in_elements_log2;
+    }
     std::string to_string() const override;
-    //! raise exception if preconditions violated
-    virtual void assert_valid(const TensorLayout& layout) const;
     //! modify the align axis and return a new TensorFormat
     virtual TensorFormat change_axis(size_t axis) const = 0;
@@ -147,9 +139,6 @@ public:
     //! number of rows
     size_t image_height(const TensorLayout& layout) const;
-    //! delta of addresses of consecutive rows (in bytes)
-    size_t image_row_pitch(const TensorLayout& layout) const;
     void serialize_append(std::string& result) const override;
 protected:
     struct SerializePack {
@@ -159,9 +148,27 @@ protected:
 template <size_t PIXEL_SIZE>
 class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase {
+    Handle::HandleVendorType m_vendor_type = Handle::HandleVendorType::NOT_SPEC;
+    /*!
+     * \brief get the alignment requirement in bytes, taking m_vendor_type into
+     * account; for example on MALI, CL_DEVICE_IMAGE_PITCH_ALIGNMENT is an
+     * image_width alignment COUNT, but mdl needs the alignment size in bytes,
+     * which equals (image_width align count) * sizeof(data_type) * pixel_size
+     */
+    size_t image_pitch_alignment_in_bytes(size_t align_size_in_elements,
+                                          const TensorLayout& layout) const;
 protected:
-    using Image2DTensorFormatBase::Image2DTensorFormatBase;
-    ~Image2DPackedTensorFormatBase() = default;
+    Image2DPackedTensorFormatBase(Type type, size_t align_axis,
+                                  size_t align_size_in_elements,
+                                  Handle::HandleVendorType vendor_type)
+            : detail::Image2DTensorFormatBase(type, align_axis,
+                                              align_size_in_elements),
+              m_vendor_type(vendor_type) {}
+    virtual ~Image2DPackedTensorFormatBase() = default;
+    Handle::HandleVendorType vendor() const { return m_vendor_type; }
 public:
     /*!
@@ -173,7 +180,20 @@ public:
      */
     size_t image_width(const TensorLayout& layout) const;
-    void assert_valid(const TensorLayout& layout) const override;
+    //! raise exception if preconditions violated
+    void assert_valid(const TensorLayout& layout) const;
+    size_t image_row_pitch(const TensorLayout& layout) const;
+    //! span for image must include the padding at the last row
+    TensorLayout::Span span_spec(const TensorLayout& layout) const override;
+    size_t init_contiguous_stride(TensorLayout& layout) const override;
+    bool is_contiguous_spec(const TensorLayout& layout) const override;
+    TensorLayout collapse_contiguous_spec(
+            const TensorLayout& layout) const override;
 };
 using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
 } // namespace detail
@@ -190,7 +210,10 @@ public:
     static constexpr Type TYPE = Type::IMAGE2D_PACK4;
     //! for internal usage or test purposes
-    static TensorFormat make_raw(size_t align_axis, size_t align_size_in_byte);
+    static TensorFormat make_raw(size_t align_axis,
+                                 size_t align_size_in_elements,
+                                 Handle::HandleVendorType vendor_type =
+                                         Handle::HandleVendorType::NOT_SPEC);
     static TensorFormat make(size_t align_axis, const Handle* handle);
@@ -215,9 +238,10 @@ public:
     TensorFormat change_axis(size_t axis) const override;
 private:
-    Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_byte)
-            : detail::Image2DPack4TensorFormatBase(TYPE, align_axis,
-                                                   align_size_in_byte) {}
+    Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_elements,
+                             Handle::HandleVendorType vendor_type)
+            : detail::Image2DPack4TensorFormatBase(
+                      TYPE, align_axis, align_size_in_elements, vendor_type) {}
 };
 } // namespace megdnn
......
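The core semantic change in this header: Image2DTensorFormatBase now stores the pitch alignment as an element count (align_size_in_elements) instead of a byte count, and the packed subclass converts it back to bytes in a vendor-dependent way, because on MALI CL_DEVICE_IMAGE_PITCH_ALIGNMENT is a pixel count rather than a byte count. A standalone sketch of that conversion with assumed inputs (this is not the library's actual signature, only the arithmetic it describes):

    #include <cstddef>
    #include <cstdio>

    // Sketch of the vendor-dependent conversion done by
    // image_pitch_alignment_in_bytes: MALI reports the alignment as a pixel
    // count, so it is scaled by dtype size and pixel size; other vendors are
    // treated as already reporting bytes (the pre-patch behaviour).
    size_t pitch_alignment_in_bytes(bool is_mali, size_t align_in_elements,
                                    size_t dtype_size, size_t pixel_size) {
        return is_mali ? align_in_elements * dtype_size * pixel_size
                       : align_in_elements;
    }

    int main() {
        // e.g. CL_DEVICE_IMAGE_PITCH_ALIGNMENT = 64 pixels, float32, PACK4:
        std::printf("%zu\n", pitch_alignment_in_bytes(true, 64, 4, 4));   // 1024
        std::printf("%zu\n", pitch_alignment_in_bytes(false, 64, 4, 4));  // 64
        return 0;
    }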
@@ -147,6 +147,10 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle,
         megdnn_throw("image2d tensor format not supported on this handle");
 }
+megdnn::HandleImplHelper::HandleVendorType Handle::vendor_type() const {
+    return HandleVendorType::NOT_SPEC;
+}
 bool Handle::check_cross_dev_copy_constraint(const TensorLayout& src) {
     return src.is_contiguous();
 }
......
@@ -236,6 +236,7 @@ void RelayoutFormat::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
 void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
     size_t align = handle()->image2d_pitch_alignment();
+    auto vendor_type = handle()->vendor_type();
     using Param = param::RelayoutFormat;
 #define CHECK_SRC(_expect) \
     megdnn_assert(src == _expect, "invalid src format: expect=%s got=%s", \
@@ -251,7 +252,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
             break;
         case Param::Mode::NHWC_NHWCD4I:
             CHECK_SRC(DefaultTensorFormat::make());
-            dst = Image2DPack4TensorFormat::make_raw(2, align);
+            dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type);
             break;
         case Param::Mode::NCHW_NHWCD4:
             CHECK_SRC(DefaultTensorFormat::make());
@@ -263,10 +264,10 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
             break;
         case Param::Mode::NCHW_NHWCD4I:
             CHECK_SRC(DefaultTensorFormat::make());
-            dst = Image2DPack4TensorFormat::make_raw(2, align);
+            dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type);
             break;
         case Param::Mode::NHWCD4I_NCHW:
-            CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align));
+            CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align, vendor_type));
             dst = DefaultTensorFormat::make();
             break;
         case Param::Mode::NHWCD4_NCHW:
@@ -280,7 +281,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
         case Param::Mode::INTER_WEIGHT_DENSEI:
         case Param::Mode::INTER_WEIGHT_DENSEI_DOT:
             CHECK_SRC(DefaultTensorFormat::make());
-            dst = Image2DPack4TensorFormat::make_raw(3, align);
+            dst = Image2DPack4TensorFormat::make_raw(3, align, vendor_type);
             break;
         case Param::Mode::INTER_WEIGHT_GROUP:
             CHECK_SRC(DefaultTensorFormat::make());
@@ -289,7 +290,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
         case Param::Mode::INTER_WEIGHT_GROUPI:
         case Param::Mode::INTER_WEIGHT_GROUPI_DOT:
             CHECK_SRC(DefaultTensorFormat::make());
-            dst = Image2DPack4TensorFormat::make_raw(4, align);
+            dst = Image2DPack4TensorFormat::make_raw(4, align, vendor_type);
             break;
         case Param::Mode::INTER_WEIGHT_CHAN:
             CHECK_SRC(DefaultTensorFormat::make());
@@ -297,7 +298,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
             break;
         case Param::Mode::INTER_WEIGHT_CHANI:
             CHECK_SRC(DefaultTensorFormat::make());
-            dst = Image2DPack4TensorFormat::make_raw(1, align);
+            dst = Image2DPack4TensorFormat::make_raw(1, align, vendor_type);
             break;
         case Param::Mode::NCHW4_CHWN4:
             CHECK_SRC(DefaultTensorFormat::make());
......
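With the vendor threaded through deduce_format, every NHWCD4 image format now carries the handle's vendor. For callers, make(align_axis, handle) remains the convenient entry point; the two calls below are equivalent (usage fragment only, assuming a valid megdnn::Handle* named handle and the headers from this patch):

    auto fmt_a = megdnn::Image2DPack4TensorFormat::make(2, handle);
    auto fmt_b = megdnn::Image2DPack4TensorFormat::make_raw(
            2, handle->image2d_pitch_alignment(), handle->vendor_type());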
@@ -185,23 +185,134 @@ TensorFormat DefaultTensorFormat::make() {
 /* ===================== Image2DTensorFormatBase ===================== */
 Image2DTensorFormatBase::Image2DTensorFormatBase(Type type, size_t align_axis,
-                                                 size_t align_size_in_byte)
-        : ImplBase(type) {
-    megdnn_assert(align_size_in_byte && align_axis);
-    m_align_axis = align_axis;
-    m_align_size_in_byte_log2 = __builtin_ctz(align_size_in_byte);
-    megdnn_assert((1u << m_align_size_in_byte_log2) == align_size_in_byte,
-                  "align size not power of 2: %zu", align_size_in_byte);
+                                                 size_t align_size_in_elements)
+        : ImplBase(type), m_align_axis(align_axis) {
+    megdnn_assert(align_size_in_elements && align_axis);
+    m_align_size_in_elements_log2 = __builtin_ctz(align_size_in_elements);
+    megdnn_assert(
+            (1u << m_align_size_in_elements_log2) == align_size_in_elements,
+            "align size not power of 2: %zu", align_size_in_elements);
 }
-size_t Image2DTensorFormatBase::init_contiguous_stride(
+void Image2DTensorFormatBase::serialize_append(std::string& result) const {
+    SerializePack pack;
+    pack.align_axis = m_align_axis;
+    megdnn_assert(pack.align_axis == m_align_axis);  // detect overflow
+    result.append(reinterpret_cast<char*>(&pack), sizeof(pack));
+}
+size_t Image2DTensorFormatBase::image_height(const TensorLayout& layout) const {
+    size_t accum = 1;
+    for (int i = m_align_axis - 1; i >= 0; --i) {
+        if (layout.stride[i] == 0) {
+            // this dimension is broadcasted
+        } else {
+            accum *= layout.shape[i];
+        }
+    }
+    return accum;
+}
+size_t Image2DTensorFormatBase::image_width_elems(
+        const TensorLayout& layout) const {
+    size_t high_elem = 0;
+    for (size_t i = m_align_axis; i < layout.ndim; ++i) {
+        high_elem += (layout.shape[i] - 1) * layout.stride[i];
+    }
+    return high_elem + 1;
+}
+std::string Image2DTensorFormatBase::to_string() const {
+    return ssprintf("I2D{%zu,%d}", m_align_axis,
+                    1 << m_align_size_in_elements_log2);
+}
+/* ===================== Image2DPackedTensorFormatBase ===================== */
+template <size_t PIXEL_SIZE>
+size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_width(
+        const TensorLayout& layout) const {
+    auto ret = image_width_elems(layout);
+    megdnn_assert(ret % PIXEL_SIZE == 0);
+    return ret / PIXEL_SIZE;
+}
+template <size_t PIXEL_SIZE>
+void Image2DPackedTensorFormatBase<PIXEL_SIZE>::assert_valid(
+        const TensorLayout& layout) const {
+    auto m_align_axis = align_axis();
+    megdnn_assert(!(layout.shape[layout.ndim - 1] % PIXEL_SIZE),
+                  "bad shape: %zu", layout.shape[layout.ndim - 1]);
+    megdnn_assert(layout.dtype.valid() && layout.ndim > m_align_axis);
+    ptrdiff_t first_non_zero_stride = 0;
+    for (int i = layout.ndim - 1; i >= 0; --i) {
+        megdnn_assert(layout.shape[i] && layout.stride[i] >= 0);
+        if (i < static_cast<int>(m_align_axis) && !first_non_zero_stride) {
+            first_non_zero_stride = layout.stride[i];
+        }
+    }
+    size_t mask =
+            image_pitch_alignment_in_bytes(
+                    align_size_in_elements(layout.dtype.size_log()), layout) -
+            1;
+    megdnn_assert(!(first_non_zero_stride & mask),
+                  "first stride is %d, but alignment is %zu",
+                  static_cast<int>(first_non_zero_stride), mask + 1);
+}
+template <size_t PIXEL_SIZE>
+size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_row_pitch(
+        const TensorLayout& layout) const {
+    for (int i = align_axis() - 1; i >= 0; --i) {
+        // find a non-broadcast axis
+        if (auto s = layout.stride[i]) {
+            return layout.dtype.size(s);
+        }
+    }
+    // use width for all broadcasted case
+    size_t alignment_in_bytes_log2 = align_size_in_elements_log2();
+    if (m_vendor_type == Handle::HandleVendorType::MALI) {
+        alignment_in_bytes_log2 +=
+                __builtin_ctz(layout.dtype.size() * PIXEL_SIZE);
+    }
+    return get_aligned_power2<size_t>(
+            layout.dtype.size(image_width_elems(layout)),
+            1 << alignment_in_bytes_log2);
+}
+template <size_t PIXEL_SIZE>
+size_t
+Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_pitch_alignment_in_bytes(
+        size_t align_size_in_elements, const TensorLayout& layout) const {
+    return m_vendor_type == Handle::HandleVendorType::MALI
+                   ? (align_size_in_elements * layout.dtype.size() * PIXEL_SIZE)
+                   : align_size_in_elements;
+}
+template <size_t PIXEL_SIZE>
+TensorLayout::Span Image2DPackedTensorFormatBase<PIXEL_SIZE>::span_spec(
+        const TensorLayout& layout) const {
+    assert_valid(layout);
+    size_t size = image_height(layout) * image_row_pitch(layout);
+    auto mask = (1 << layout.dtype.size_log()) - 1;
+    megdnn_assert(!(size & mask), "unaligned size: %zu", size);
+    return {0, 0, size >> layout.dtype.size_log(), size};
+}
+template <size_t PIXEL_SIZE>
+size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::init_contiguous_stride(
         TensorLayout& layout) const {
+    auto m_align_axis = align_axis();
     if (!layout.ndim)
         return 0;
     megdnn_assert(layout.dtype.valid() && layout.ndim > m_align_axis,
                   "dtype=%s ndim=%zu align=%zu", layout.dtype.name(),
                   layout.ndim, m_align_axis);
-    size_t align_size = align_size_in_byte(layout.dtype.size_log());
+    size_t align_size = image_pitch_alignment_in_bytes(
+            align_size_in_elements(layout.dtype.size_log()), layout);
     size_t accum = 1;
     SafeMultiplies<size_t> mul;
     for (size_t i = layout.ndim; i; --i) {
@@ -216,12 +327,15 @@ size_t Image2DTensorFormatBase::init_contiguous_stride(
     return accum;
 };
-bool Image2DTensorFormatBase::is_contiguous_spec(
+template <size_t PIXEL_SIZE>
+bool Image2DPackedTensorFormatBase<PIXEL_SIZE>::is_contiguous_spec(
         const TensorLayout& layout) const {
     megdnn_assert(layout.dtype.valid());
-    size_t align_size = align_size_in_byte(layout.dtype.size_log());
+    size_t align_size = image_pitch_alignment_in_bytes(
+            align_size_in_elements(layout.dtype.size_log()), layout);
     ptrdiff_t expected = 1;
-    int height_axis = static_cast<int>(m_align_axis - 1);
+    int height_axis = static_cast<int>(align_axis() - 1);
     for (int i = layout.ndim - 1; i >= 0; --i) {
         if (i == height_axis) {
             expected = megdnn::get_aligned_power2<size_t>(expected, align_size);
@@ -235,7 +349,12 @@ bool Image2DTensorFormatBase::is_contiguous_spec(
                 return false;
             }
-            size_t mask = align_size_in_byte(layout.dtype.size_log()) - 1;
+            size_t mask =
+                    image_pitch_alignment_in_bytes(
+                            align_size_in_elements(layout.dtype.size_log()),
+                            layout) -
+                    1;
             megdnn_assert(s > expected && !(s & mask),
                           "invalid row pitch: %d; layout: %s",
                           static_cast<int>(s), layout.to_string().c_str());
@@ -250,11 +369,12 @@ bool Image2DTensorFormatBase::is_contiguous_spec(
     return expected != 0;
 }
-TensorLayout Image2DTensorFormatBase::collapse_contiguous_spec(
+template <size_t PIXEL_SIZE>
+TensorLayout Image2DPackedTensorFormatBase<PIXEL_SIZE>::collapse_contiguous_spec(
         const TensorLayout& layout) const {
     assert_valid(layout);
     TensorLayout res{layout};
-    int new_axis = m_align_axis;
+    int new_axis = align_axis();
     // remove all dims with shape 1
     for (int i = static_cast<int>(res.ndim) - 1; i >= 0 && res.ndim >= 3; --i) {
         if (i == new_axis && static_cast<int>(res.ndim) == new_axis + 1) {
@@ -302,95 +422,6 @@ TensorLayout Image2DTensorFormatBase::collapse_contiguous_spec(
     return res;
 }
-TensorLayout::Span Image2DTensorFormatBase::span_spec(
-        const TensorLayout& layout) const {
-    assert_valid(layout);
-    size_t size = image_height(layout) * image_row_pitch(layout);
-    auto mask = (1 << layout.dtype.size_log()) - 1;
-    megdnn_assert(!(size & mask), "unaligned size: %zu", size);
-    return {0, 0, size >> layout.dtype.size_log(), size};
-}
-void Image2DTensorFormatBase::serialize_append(std::string& result) const {
-    SerializePack pack;
-    pack.align_axis = m_align_axis;
-    megdnn_assert(pack.align_axis == m_align_axis);  // detect overflow
-    result.append(reinterpret_cast<char*>(&pack), sizeof(pack));
-}
-size_t Image2DTensorFormatBase::image_height(const TensorLayout& layout) const {
-    size_t accum = 1;
-    for (int i = m_align_axis - 1; i >= 0; --i) {
-        if (layout.stride[i] == 0) {
-            // this dimension is broadcasted
-        } else {
-            accum *= layout.shape[i];
-        }
-    }
-    return accum;
-}
-size_t Image2DTensorFormatBase::image_row_pitch(
-        const TensorLayout& layout) const {
-    for (int i = m_align_axis - 1; i >= 0; --i) {
-        // find a non-broadcast axis
-        if (auto s = layout.stride[i]) {
-            return layout.dtype.size(s);
-        }
-    }
-    // use width for all broadcasted case
-    return get_aligned_power2<size_t>(
-            layout.dtype.size(image_width_elems(layout)),
-            1 << m_align_size_in_byte_log2);
-}
-void Image2DTensorFormatBase::assert_valid(const TensorLayout& layout) const {
-    megdnn_assert(layout.dtype.valid() && layout.ndim > m_align_axis);
-    ptrdiff_t first_non_zero_stride = 0;
-    for (int i = layout.ndim - 1; i >= 0; --i) {
-        megdnn_assert(layout.shape[i] && layout.stride[i] >= 0);
-        if (i < static_cast<int>(m_align_axis) && !first_non_zero_stride) {
-            first_non_zero_stride = layout.stride[i];
-        }
-    }
-    size_t mask = align_size_in_byte(layout.dtype.size_log()) - 1;
-    megdnn_assert(!(first_non_zero_stride & mask),
-                  "first stride is %d, but alignment is %zu",
-                  static_cast<int>(first_non_zero_stride), mask + 1);
-}
-size_t Image2DTensorFormatBase::image_width_elems(
-        const TensorLayout& layout) const {
-    size_t high_elem = 0;
-    for (size_t i = m_align_axis; i < layout.ndim; ++i) {
-        high_elem += (layout.shape[i] - 1) * layout.stride[i];
-    }
-    return high_elem + 1;
-}
-std::string Image2DTensorFormatBase::to_string() const {
-    return ssprintf("I2D{%zu,%d}", m_align_axis,
-                    1 << m_align_size_in_byte_log2);
-}
-/* ===================== Image2DPackedTensorFormatBase ===================== */
-template <size_t PIXEL_SIZE>
-size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_width(
-        const TensorLayout& layout) const {
-    auto ret = image_width_elems(layout);
-    megdnn_assert(ret % PIXEL_SIZE == 0);
-    return ret / PIXEL_SIZE;
-}
-template <size_t PIXEL_SIZE>
-void Image2DPackedTensorFormatBase<PIXEL_SIZE>::assert_valid(
-        const TensorLayout& layout) const {
-    Image2DTensorFormatBase::assert_valid(layout);
-    megdnn_assert(!(layout.shape[layout.ndim - 1] % PIXEL_SIZE),
-                  "bad shape: %zu", layout.shape[layout.ndim - 1]);
-}
 namespace megdnn {
 namespace detail {
 template class Image2DPackedTensorFormatBase<4>;
@@ -398,26 +429,29 @@ template class Image2DPackedTensorFormatBase<4>;
 } // namespace megdnn
 /* ===================== Image2DPack4TensorFormat ===================== */
-TensorFormat Image2DPack4TensorFormat::make_raw(size_t align_axis,
-                                                size_t align_size_in_byte) {
+TensorFormat Image2DPack4TensorFormat::make_raw(
+        size_t align_axis, size_t align_size_in_elements,
+        Handle::HandleVendorType vendor_type) {
     static std::mutex mtx;
     static std::unordered_map<uint64_t,
                               std::unique_ptr<Image2DPack4TensorFormat>>
             cache;
-    megdnn_assert(std::max(align_axis, align_size_in_byte) <=
+    megdnn_assert(std::max(align_axis, align_size_in_elements) <=
                   std::numeric_limits<uint32_t>::max());
     MEGDNN_LOCK_GUARD(mtx);
     auto&& ptr = cache[(static_cast<uint64_t>(align_axis) << 32) |
-                       align_size_in_byte];
+                       align_size_in_elements];
     if (!ptr) {
-        ptr.reset(new Image2DPack4TensorFormat{align_axis, align_size_in_byte});
+        ptr.reset(new Image2DPack4TensorFormat{
+                align_axis, align_size_in_elements, vendor_type});
     }
     return impl_to_tensor_format(ptr.get());
 }
 TensorFormat Image2DPack4TensorFormat::make(size_t align_axis,
                                             const Handle* handle) {
-    return make_raw(align_axis, handle->image2d_pitch_alignment());
+    return make_raw(align_axis, handle->image2d_pitch_alignment(),
+                    handle->vendor_type());
 }
 TensorFormat Image2DPack4TensorFormat::deserialize(const Handle* handle,
@@ -429,7 +463,7 @@ TensorFormat Image2DPack4TensorFormat::deserialize(const Handle* handle,
 }
 TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const {
-    return make_raw(axis, align_size_in_byte());
+    return make_raw(axis, align_size_in_elements(), vendor());
 }
 // vim: syntax=cpp.doxygen
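To make the MALI branch concrete with the numbers used by the new test further below (align count 512, PIXEL_SIZE 4, a {5, 3, 8} layout with align_axis 1, i.e. 24 elements per row): the byte alignment becomes 512 * sizeof(dtype) * 4, and the 96-byte (float32) or 48-byte (float16) row is padded up to it. A minimal standalone check of that arithmetic (align_up_pow2 stands in for get_aligned_power2 and is not the library function):

    #include <cstddef>
    #include <cstdio>

    // Round x up to the next multiple of a power-of-two alignment.
    static size_t align_up_pow2(size_t x, size_t alignment) {
        return (x + alignment - 1) & ~(alignment - 1);
    }

    int main() {
        const size_t pixel_size = 4;     // Image2DPack4TensorFormat
        const size_t align_count = 512;  // MALI pitch alignment, in pixels
        const size_t width_elems = 3 * 8;
        const size_t dtype_sizes[] = {4, 2};  // float32, float16
        for (size_t dtype_size : dtype_sizes) {
            size_t align_bytes = align_count * dtype_size * pixel_size;
            size_t row_pitch = align_up_pow2(width_elems * dtype_size, align_bytes);
            // float32: align_bytes = 8192, row_pitch = 8192
            // float16: align_bytes = 4096, row_pitch = 4096
            std::printf("dtype_size=%zu align_bytes=%zu row_pitch=%zu\n",
                        dtype_size, align_bytes, row_pitch);
        }
        return 0;
    }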
@@ -123,6 +123,10 @@ size_t HandleImpl::image2d_pitch_alignment() const {
     return align;
 }
+HandleImpl::HandleVendorType HandleImpl::vendor_type() const {
+    return HandleVendorType::CUDA;
+}
 } // namespace cuda
 } // namespace megdnn
......
@@ -123,6 +123,7 @@ class HandleImpl: public HandleImplHelper {
     TypeCvt* typecvt_opr() { return get_helper_opr<TypeCvt, 0>(this); }
     size_t image2d_pitch_alignment() const override;
+    HandleVendorType vendor_type() const override;
 private:
     bool m_is_tegra_k1;
     int m_device_id;
......
@@ -118,6 +118,10 @@ size_t HandleImpl::image2d_pitch_alignment() const {
     return g_image2d_pitch_alignment;
 }
+HandleImpl::HandleVendorType HandleImpl::vendor_type() const {
+    return HandleVendorType::NOT_SPEC;
+}
 size_t HandleImpl::exchange_image2d_pitch_alignment(size_t alignment) {
     auto ret = g_image2d_pitch_alignment;
     g_image2d_pitch_alignment = alignment;
......
@@ -169,6 +169,7 @@ public:
      * \param alignment the new alignment value to set
      */
     static size_t exchange_image2d_pitch_alignment(size_t alignment);
+    HandleVendorType vendor_type() const override;
 };
 } // namespace naive
......
@@ -175,6 +175,30 @@ namespace {
     }
 }
+TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_WITH_VENDOR_MALI) {
+    TensorFormat fmt = Image2DPack4TensorFormat::make_raw(
+            1, 512, Handle::HandleVendorType::MALI);
+    TensorLayout layout{{5, 3, 8}, dtype::Float32{}, fmt};
+    ASSERT_EQ(layout.stride[2], 1);
+    ASSERT_EQ(layout.stride[1], 8);
+    ASSERT_EQ(layout.stride[0], 2048);
+    ASSERT_EQ(8192u, image_row_pitch(layout));
+    ASSERT_EQ(6u, image_width(layout));
+    ASSERT_EQ(5u, image_height(layout));
+    fmt = Image2DPack4TensorFormat::make_raw(1, 512,
+                                             Handle::HandleVendorType::MALI);
+    TensorLayout layout_s{{5, 3, 8}, dtype::Float16{}, fmt};
+    ASSERT_EQ(layout_s.stride[2], 1);
+    ASSERT_EQ(layout_s.stride[1], 8);
+    ASSERT_EQ(layout_s.stride[0], 2048);
+    ASSERT_EQ(4096u, image_row_pitch(layout_s));
+    ASSERT_EQ(6u, image_width(layout_s));
+    ASSERT_EQ(5u, image_height(layout_s));
+}
 TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT) {
     TensorFormat fmt = Image2DPack4TensorFormat::make_raw(1, 1024);
     ASSERT_FALSE(fmt.is_default());
@@ -233,7 +257,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT) {
         auto&& impl = contig.format.as_impl<Image2DPack4TensorFormat>();
         ASSERT_EQ(make_layout({1, 8}, {32, 1}, layout.dtype), contig);
         ASSERT_EQ(1u, impl.align_axis());
-        ASSERT_EQ(64u, impl.align_size_in_byte());
+        ASSERT_EQ(64u, impl.align_size_in_elements());
     }
 }
@@ -258,7 +282,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_H) {
         auto&& impl = contig.format.as_impl<Image2DPack4TensorFormat>();
         ASSERT_EQ(make_layout({v0, 8}, {32, 1}, layout.dtype), contig);
         ASSERT_EQ(1u, impl.align_axis());
-        ASSERT_EQ(64u, impl.align_size_in_byte());
+        ASSERT_EQ(64u, impl.align_size_in_elements());
     }
 }
@@ -274,7 +298,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) {
                               layout.dtype),
                   contig);
         ASSERT_EQ(1u, impl.align_axis());
-        ASSERT_EQ(64u, impl.align_size_in_byte());
+        ASSERT_EQ(64u, impl.align_size_in_elements());
     }
 }
......
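A note on the expected strides in the MALI test above: with align_axis 1 the row holds 3 * 8 = 24 elements, and the MALI pitch alignment of 512 pixels translates to 512 * 4 = 2048 elements regardless of dtype (the byte alignment 512 * sizeof(dtype) * 4 divided back by sizeof(dtype)). Hence stride[0] is padded from 24 up to 2048 for both Float32 and Float16, while the row pitch in bytes differs: 2048 * 4 = 8192 versus 2048 * 2 = 4096.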