提交 cd7090ac 编写于 作者: M Megvii Engine Team

fix(opencl): enable image on mali(cl2.1)

GitOrigin-RevId: 0c670fba807e9bf25e7825e7de5ce8c04d30dae8
上级 fc8d13cd
......@@ -38,6 +38,17 @@ class Handle {
CAMBRICON = 12,
};
//! Device vendor
enum class HandleVendorType : uint32_t {
NOT_SPEC = 0,
MALI = 1,
ADRENO = 2,
CUDA = 3,
INTEL = 4,
POWERVR = 5,
AMD = 6,
};
protected:
Handle(megcoreComputingHandle_t computing_handle, HandleType type);
......@@ -130,6 +141,9 @@ class Handle {
//! get alignment in bytes for rows of image 2D tensor format
virtual size_t image2d_pitch_alignment() const;
//! get vendor type
virtual HandleVendorType vendor_type() const;
HandleType type() const {
return m_handle_type;
}
......
......@@ -12,6 +12,7 @@
#pragma once
#include "megdnn/basic_types.h"
#include "megdnn/handle.h"
#include "megdnn/internal/visibility_prologue.h"
namespace megdnn {
......@@ -43,7 +44,7 @@ public:
protected:
ImplBase(Type type) : m_type{type} {}
~ImplBase() = default;
virtual ~ImplBase() = default;
static TensorFormat impl_to_tensor_format(ImplBase* impl) { return {impl}; }
......@@ -93,8 +94,8 @@ namespace detail {
*
* \p align_axis is the axis to be aligned, also the first axis of image width.
* More precisely speaking, `stride[align_axis-1] * dtype.size()` must divide \p
* align_size_in_byte. Axes from 0 to align_axis-1 would be considered as the
* height of the image, and other axes are the width.
* align_size_in_elements. Axes from 0 to align_axis-1 would be considered as
* the height of the image, and other axes are the width.
*
* Empty tensors and negative strides are not allowed. Only contiguous or
* broadcasted cases are allowed.
......@@ -103,41 +104,32 @@ namespace detail {
* considered as contiguous.
*/
class Image2DTensorFormatBase : public TensorFormat::ImplBase {
size_t m_align_axis, m_align_size_in_byte_log2;
size_t m_align_axis, m_align_size_in_elements_log2;
protected:
Image2DTensorFormatBase(Type type, size_t align_axis,
size_t align_size_in_byte);
~Image2DTensorFormatBase() = default;
size_t align_size_in_elements);
virtual ~Image2DTensorFormatBase() = default;
public:
/*!
* \brief get alignment requirement in bytes
* \brief get alignment requirement in elements
* \param div_log2 the result would be divided by `(1 << div_log2)`
*/
size_t align_size_in_byte(size_t div_log2 = 0) const {
return 1 << (m_align_size_in_byte_log2 > div_log2
? m_align_size_in_byte_log2 - div_log2
size_t align_size_in_elements(size_t div_log2 = 0) const {
return 1 << (m_align_size_in_elements_log2 > div_log2
? m_align_size_in_elements_log2 - div_log2
: 0);
}
size_t align_axis() const { return m_align_axis; }
size_t init_contiguous_stride(TensorLayout& layout) const override;
bool is_contiguous_spec(const TensorLayout& layout) const override;
TensorLayout collapse_contiguous_spec(
const TensorLayout& layout) const override;
//! span for image must include the padding at the last row
TensorLayout::Span span_spec(const TensorLayout& layout) const override;
size_t align_size_in_elements_log2() const {
return m_align_size_in_elements_log2;
}
std::string to_string() const override;
//! raise exception if preconditions violated
virtual void assert_valid(const TensorLayout& layout) const;
//! modify the align axis and return a new TensorFormat
virtual TensorFormat change_axis(size_t axis) const = 0;
......@@ -147,9 +139,6 @@ public:
//! number of rows
size_t image_height(const TensorLayout& layout) const;
//! delta of addresses of consecutive rows (in bytes)
size_t image_row_pitch(const TensorLayout& layout) const;
void serialize_append(std::string& result) const override;
protected:
struct SerializePack {
......@@ -159,9 +148,27 @@ protected:
template <size_t PIXEL_SIZE>
class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase {
Handle::HandleVendorType m_vendor_type = Handle::HandleVendorType::NOT_SPEC;
/*!
* \brief get fix alignment requirement in bytes, consider m_vendor_type,
* for example on MALI, CL_DEVICE_IMAGE_PITCH_ALIGNMENT means image_width
* align COUNT, but mdl needs align size in byte, which equal to
* (image_width algin count) * sizeof(data_type) * pixel_size
*/
size_t image_pitch_alignment_in_bytes(size_t align_size_in_elements,
const TensorLayout& layout) const;
protected:
using Image2DTensorFormatBase::Image2DTensorFormatBase;
~Image2DPackedTensorFormatBase() = default;
Image2DPackedTensorFormatBase(Type type, size_t align_axis,
size_t align_size_in_elements,
Handle::HandleVendorType vendor_type)
: detail::Image2DTensorFormatBase(type, align_axis,
align_size_in_elements),
m_vendor_type(vendor_type) {}
virtual ~Image2DPackedTensorFormatBase() = default;
Handle::HandleVendorType vendor() const { return m_vendor_type; }
public:
/*!
......@@ -173,7 +180,20 @@ public:
*/
size_t image_width(const TensorLayout& layout) const;
void assert_valid(const TensorLayout& layout) const override;
//! raise exception if preconditions violated
void assert_valid(const TensorLayout& layout) const;
size_t image_row_pitch(const TensorLayout& layout) const;
//! span for image must include the padding at the last row
TensorLayout::Span span_spec(const TensorLayout& layout) const override;
size_t init_contiguous_stride(TensorLayout& layout) const override;
bool is_contiguous_spec(const TensorLayout& layout) const override;
TensorLayout collapse_contiguous_spec(
const TensorLayout& layout) const override;
};
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>;
} // namespace detail
......@@ -190,7 +210,10 @@ public:
static constexpr Type TYPE = Type::IMAGE2D_PACK4;
//! for internal usage or test purposes
static TensorFormat make_raw(size_t align_axis, size_t align_size_in_byte);
static TensorFormat make_raw(size_t align_axis,
size_t align_size_in_elements,
Handle::HandleVendorType vendor_type =
Handle::HandleVendorType::NOT_SPEC);
static TensorFormat make(size_t align_axis, const Handle* handle);
......@@ -215,9 +238,10 @@ public:
TensorFormat change_axis(size_t axis) const override;
private:
Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_byte)
: detail::Image2DPack4TensorFormatBase(TYPE, align_axis,
align_size_in_byte) {}
Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_elements,
Handle::HandleVendorType vendor_type)
: detail::Image2DPack4TensorFormatBase(
TYPE, align_axis, align_size_in_elements, vendor_type) {}
};
} // namespace megdnn
......
......@@ -147,6 +147,10 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle,
megdnn_throw("image2d tensor format not supported on this handle");
}
megdnn::HandleImplHelper::HandleVendorType Handle::vendor_type() const {
return HandleVendorType::NOT_SPEC;
}
bool Handle::check_cross_dev_copy_constraint(const TensorLayout& src) {
return src.is_contiguous();
}
......
......@@ -236,6 +236,7 @@ void RelayoutFormat::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
size_t align = handle()->image2d_pitch_alignment();
auto vendor_type = handle()->vendor_type();
using Param = param::RelayoutFormat;
#define CHECK_SRC(_expect) \
megdnn_assert(src == _expect, "invalid src format: expect=%s got=%s", \
......@@ -251,7 +252,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
break;
case Param::Mode::NHWC_NHWCD4I:
CHECK_SRC(DefaultTensorFormat::make());
dst = Image2DPack4TensorFormat::make_raw(2, align);
dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type);
break;
case Param::Mode::NCHW_NHWCD4:
CHECK_SRC(DefaultTensorFormat::make());
......@@ -263,10 +264,10 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
break;
case Param::Mode::NCHW_NHWCD4I:
CHECK_SRC(DefaultTensorFormat::make());
dst = Image2DPack4TensorFormat::make_raw(2, align);
dst = Image2DPack4TensorFormat::make_raw(2, align, vendor_type);
break;
case Param::Mode::NHWCD4I_NCHW:
CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align));
CHECK_SRC(Image2DPack4TensorFormat::make_raw(2, align, vendor_type));
dst = DefaultTensorFormat::make();
break;
case Param::Mode::NHWCD4_NCHW:
......@@ -280,7 +281,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
case Param::Mode::INTER_WEIGHT_DENSEI:
case Param::Mode::INTER_WEIGHT_DENSEI_DOT:
CHECK_SRC(DefaultTensorFormat::make());
dst = Image2DPack4TensorFormat::make_raw(3, align);
dst = Image2DPack4TensorFormat::make_raw(3, align, vendor_type);
break;
case Param::Mode::INTER_WEIGHT_GROUP:
CHECK_SRC(DefaultTensorFormat::make());
......@@ -289,7 +290,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
case Param::Mode::INTER_WEIGHT_GROUPI:
case Param::Mode::INTER_WEIGHT_GROUPI_DOT:
CHECK_SRC(DefaultTensorFormat::make());
dst = Image2DPack4TensorFormat::make_raw(4, align);
dst = Image2DPack4TensorFormat::make_raw(4, align, vendor_type);
break;
case Param::Mode::INTER_WEIGHT_CHAN:
CHECK_SRC(DefaultTensorFormat::make());
......@@ -297,7 +298,7 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
break;
case Param::Mode::INTER_WEIGHT_CHANI:
CHECK_SRC(DefaultTensorFormat::make());
dst = Image2DPack4TensorFormat::make_raw(1, align);
dst = Image2DPack4TensorFormat::make_raw(1, align, vendor_type);
break;
case Param::Mode::NCHW4_CHWN4:
CHECK_SRC(DefaultTensorFormat::make());
......
......@@ -185,23 +185,134 @@ TensorFormat DefaultTensorFormat::make() {
/* ===================== Image2DTensorFormatBase ===================== */
Image2DTensorFormatBase::Image2DTensorFormatBase(Type type, size_t align_axis,
size_t align_size_in_byte)
: ImplBase(type) {
megdnn_assert(align_size_in_byte && align_axis);
m_align_axis = align_axis;
m_align_size_in_byte_log2 = __builtin_ctz(align_size_in_byte);
megdnn_assert((1u << m_align_size_in_byte_log2) == align_size_in_byte,
"align size not power of 2: %zu", align_size_in_byte);
size_t align_size_in_elements)
: ImplBase(type), m_align_axis(align_axis) {
megdnn_assert(align_size_in_elements && align_axis);
m_align_size_in_elements_log2 = __builtin_ctz(align_size_in_elements);
megdnn_assert(
(1u << m_align_size_in_elements_log2) == align_size_in_elements,
"align size not power of 2: %zu", align_size_in_elements);
}
size_t Image2DTensorFormatBase::init_contiguous_stride(
void Image2DTensorFormatBase::serialize_append(std::string& result) const {
SerializePack pack;
pack.align_axis = m_align_axis;
megdnn_assert(pack.align_axis == m_align_axis); // detect overflow
result.append(reinterpret_cast<char*>(&pack), sizeof(pack));
}
size_t Image2DTensorFormatBase::image_height(const TensorLayout& layout) const {
size_t accum = 1;
for (int i = m_align_axis - 1; i >= 0; --i) {
if (layout.stride[i] == 0) {
// this dimension is broadcasted
} else {
accum *= layout.shape[i];
}
}
return accum;
}
size_t Image2DTensorFormatBase::image_width_elems(
const TensorLayout& layout) const {
size_t high_elem = 0;
for (size_t i = m_align_axis; i < layout.ndim; ++i) {
high_elem += (layout.shape[i] - 1) * layout.stride[i];
}
return high_elem + 1;
}
std::string Image2DTensorFormatBase::to_string() const {
return ssprintf("I2D{%zu,%d}", m_align_axis,
1 << m_align_size_in_elements_log2);
}
/* ===================== Image2DPackedTensorFormatBase ===================== */
template <size_t PIXEL_SIZE>
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_width(
const TensorLayout& layout) const {
auto ret = image_width_elems(layout);
megdnn_assert(ret % PIXEL_SIZE == 0);
return ret / PIXEL_SIZE;
}
template <size_t PIXEL_SIZE>
void Image2DPackedTensorFormatBase<PIXEL_SIZE>::assert_valid(
const TensorLayout& layout) const {
auto m_align_axis = align_axis();
megdnn_assert(!(layout.shape[layout.ndim - 1] % PIXEL_SIZE),
"bad shape: %zu", layout.shape[layout.ndim - 1]);
megdnn_assert(layout.dtype.valid() && layout.ndim > m_align_axis);
ptrdiff_t first_non_zero_stride = 0;
for (int i = layout.ndim - 1; i >= 0; --i) {
megdnn_assert(layout.shape[i] && layout.stride[i] >= 0);
if (i < static_cast<int>(m_align_axis) && !first_non_zero_stride) {
first_non_zero_stride = layout.stride[i];
}
}
size_t mask =
image_pitch_alignment_in_bytes(
align_size_in_elements(layout.dtype.size_log()), layout) -
1;
megdnn_assert(!(first_non_zero_stride & mask),
"first stride is %d, but alignment is %zu",
static_cast<int>(first_non_zero_stride), mask + 1);
}
template <size_t PIXEL_SIZE>
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_row_pitch(
const TensorLayout& layout) const {
for (int i = align_axis() - 1; i >= 0; --i) {
// find a non-broadcast axis
if (auto s = layout.stride[i]) {
return layout.dtype.size(s);
}
}
// use width for all broadcasted case
size_t alignment_in_bytes_log2 = align_size_in_elements_log2();
if (m_vendor_type == Handle::HandleVendorType::MALI) {
alignment_in_bytes_log2 +=
__builtin_ctz(layout.dtype.size() * PIXEL_SIZE);
}
return get_aligned_power2<size_t>(
layout.dtype.size(image_width_elems(layout)),
1 << alignment_in_bytes_log2);
}
template <size_t PIXEL_SIZE>
size_t
Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_pitch_alignment_in_bytes(
size_t align_size_in_elements, const TensorLayout& layout) const {
return m_vendor_type == Handle::HandleVendorType::MALI
? (align_size_in_elements * layout.dtype.size() * PIXEL_SIZE)
: align_size_in_elements;
}
template <size_t PIXEL_SIZE>
TensorLayout::Span Image2DPackedTensorFormatBase<PIXEL_SIZE>::span_spec(
const TensorLayout& layout) const {
assert_valid(layout);
size_t size = image_height(layout) * image_row_pitch(layout);
auto mask = (1 << layout.dtype.size_log()) - 1;
megdnn_assert(!(size & mask), "unaligned size: %zu", size);
return {0, 0, size >> layout.dtype.size_log(), size};
}
template <size_t PIXEL_SIZE>
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::init_contiguous_stride(
TensorLayout& layout) const {
auto m_align_axis = align_axis();
if (!layout.ndim)
return 0;
megdnn_assert(layout.dtype.valid() && layout.ndim > m_align_axis,
"dtype=%s ndim=%zu align=%zu", layout.dtype.name(),
layout.ndim, m_align_axis);
size_t align_size = align_size_in_byte(layout.dtype.size_log());
size_t align_size = image_pitch_alignment_in_bytes(
align_size_in_elements(layout.dtype.size_log()), layout);
size_t accum = 1;
SafeMultiplies<size_t> mul;
for (size_t i = layout.ndim; i; --i) {
......@@ -216,12 +327,15 @@ size_t Image2DTensorFormatBase::init_contiguous_stride(
return accum;
};
bool Image2DTensorFormatBase::is_contiguous_spec(
template <size_t PIXEL_SIZE>
bool Image2DPackedTensorFormatBase<PIXEL_SIZE>::is_contiguous_spec(
const TensorLayout& layout) const {
megdnn_assert(layout.dtype.valid());
size_t align_size = align_size_in_byte(layout.dtype.size_log());
size_t align_size = image_pitch_alignment_in_bytes(
align_size_in_elements(layout.dtype.size_log()), layout);
ptrdiff_t expected = 1;
int height_axis = static_cast<int>(m_align_axis - 1);
int height_axis = static_cast<int>(align_axis() - 1);
for (int i = layout.ndim - 1; i >= 0; --i) {
if (i == height_axis) {
expected = megdnn::get_aligned_power2<size_t>(expected, align_size);
......@@ -235,7 +349,12 @@ bool Image2DTensorFormatBase::is_contiguous_spec(
return false;
}
size_t mask = align_size_in_byte(layout.dtype.size_log()) - 1;
size_t mask =
image_pitch_alignment_in_bytes(
align_size_in_elements(layout.dtype.size_log()),
layout) -
1;
megdnn_assert(s > expected && !(s & mask),
"invalid row pitch: %d; layout: %s",
static_cast<int>(s), layout.to_string().c_str());
......@@ -250,11 +369,12 @@ bool Image2DTensorFormatBase::is_contiguous_spec(
return expected != 0;
}
TensorLayout Image2DTensorFormatBase::collapse_contiguous_spec(
template <size_t PIXEL_SIZE>
TensorLayout Image2DPackedTensorFormatBase<PIXEL_SIZE>::collapse_contiguous_spec(
const TensorLayout& layout) const {
assert_valid(layout);
TensorLayout res{layout};
int new_axis = m_align_axis;
int new_axis = align_axis();
// remove all dims with shape 1
for (int i = static_cast<int>(res.ndim) - 1; i >= 0 && res.ndim >= 3; --i) {
if (i == new_axis && static_cast<int>(res.ndim) == new_axis + 1) {
......@@ -302,95 +422,6 @@ TensorLayout Image2DTensorFormatBase::collapse_contiguous_spec(
return res;
}
TensorLayout::Span Image2DTensorFormatBase::span_spec(
const TensorLayout& layout) const {
assert_valid(layout);
size_t size = image_height(layout) * image_row_pitch(layout);
auto mask = (1 << layout.dtype.size_log()) - 1;
megdnn_assert(!(size & mask), "unaligned size: %zu", size);
return {0, 0, size >> layout.dtype.size_log(), size};
}
void Image2DTensorFormatBase::serialize_append(std::string& result) const {
SerializePack pack;
pack.align_axis = m_align_axis;
megdnn_assert(pack.align_axis == m_align_axis); // detect overflow
result.append(reinterpret_cast<char*>(&pack), sizeof(pack));
}
size_t Image2DTensorFormatBase::image_height(const TensorLayout& layout) const {
size_t accum = 1;
for (int i = m_align_axis - 1; i >= 0; --i) {
if (layout.stride[i] == 0) {
// this dimension is broadcasted
} else {
accum *= layout.shape[i];
}
}
return accum;
}
size_t Image2DTensorFormatBase::image_row_pitch(
const TensorLayout& layout) const {
for (int i = m_align_axis - 1; i >= 0; --i) {
// find a non-broadcast axis
if (auto s = layout.stride[i]) {
return layout.dtype.size(s);
}
}
// use width for all broadcasted case
return get_aligned_power2<size_t>(
layout.dtype.size(image_width_elems(layout)),
1 << m_align_size_in_byte_log2);
}
void Image2DTensorFormatBase::assert_valid(const TensorLayout& layout) const {
megdnn_assert(layout.dtype.valid() && layout.ndim > m_align_axis);
ptrdiff_t first_non_zero_stride = 0;
for (int i = layout.ndim - 1; i >= 0; --i) {
megdnn_assert(layout.shape[i] && layout.stride[i] >= 0);
if (i < static_cast<int>(m_align_axis) && !first_non_zero_stride) {
first_non_zero_stride = layout.stride[i];
}
}
size_t mask = align_size_in_byte(layout.dtype.size_log()) - 1;
megdnn_assert(!(first_non_zero_stride & mask),
"first stride is %d, but alignment is %zu",
static_cast<int>(first_non_zero_stride), mask + 1);
}
size_t Image2DTensorFormatBase::image_width_elems(
const TensorLayout& layout) const {
size_t high_elem = 0;
for (size_t i = m_align_axis; i < layout.ndim; ++i) {
high_elem += (layout.shape[i] - 1) * layout.stride[i];
}
return high_elem + 1;
}
std::string Image2DTensorFormatBase::to_string() const {
return ssprintf("I2D{%zu,%d}", m_align_axis,
1 << m_align_size_in_byte_log2);
}
/* ===================== Image2DPackedTensorFormatBase ===================== */
template <size_t PIXEL_SIZE>
size_t Image2DPackedTensorFormatBase<PIXEL_SIZE>::image_width(
const TensorLayout& layout) const {
auto ret = image_width_elems(layout);
megdnn_assert(ret % PIXEL_SIZE == 0);
return ret / PIXEL_SIZE;
}
template <size_t PIXEL_SIZE>
void Image2DPackedTensorFormatBase<PIXEL_SIZE>::assert_valid(
const TensorLayout& layout) const {
Image2DTensorFormatBase::assert_valid(layout);
megdnn_assert(!(layout.shape[layout.ndim - 1] % PIXEL_SIZE),
"bad shape: %zu", layout.shape[layout.ndim - 1]);
}
namespace megdnn {
namespace detail {
template class Image2DPackedTensorFormatBase<4>;
......@@ -398,26 +429,29 @@ template class Image2DPackedTensorFormatBase<4>;
} // namespace megdnn
/* ===================== Image2DPack4TensorFormat ===================== */
TensorFormat Image2DPack4TensorFormat::make_raw(size_t align_axis,
size_t align_size_in_byte) {
TensorFormat Image2DPack4TensorFormat::make_raw(
size_t align_axis, size_t align_size_in_elements,
Handle::HandleVendorType vendor_type) {
static std::mutex mtx;
static std::unordered_map<uint64_t,
std::unique_ptr<Image2DPack4TensorFormat>>
cache;
megdnn_assert(std::max(align_axis, align_size_in_byte) <=
megdnn_assert(std::max(align_axis, align_size_in_elements) <=
std::numeric_limits<uint32_t>::max());
MEGDNN_LOCK_GUARD(mtx);
auto&& ptr = cache[(static_cast<uint64_t>(align_axis) << 32) |
align_size_in_byte];
align_size_in_elements];
if (!ptr) {
ptr.reset(new Image2DPack4TensorFormat{align_axis, align_size_in_byte});
ptr.reset(new Image2DPack4TensorFormat{
align_axis, align_size_in_elements, vendor_type});
}
return impl_to_tensor_format(ptr.get());
}
TensorFormat Image2DPack4TensorFormat::make(size_t align_axis,
const Handle* handle) {
return make_raw(align_axis, handle->image2d_pitch_alignment());
return make_raw(align_axis, handle->image2d_pitch_alignment(),
handle->vendor_type());
}
TensorFormat Image2DPack4TensorFormat::deserialize(const Handle* handle,
......@@ -429,7 +463,7 @@ TensorFormat Image2DPack4TensorFormat::deserialize(const Handle* handle,
}
TensorFormat Image2DPack4TensorFormat::change_axis(size_t axis) const {
return make_raw(axis, align_size_in_byte());
return make_raw(axis, align_size_in_elements(), vendor());
}
// vim: syntax=cpp.doxygen
......@@ -123,6 +123,10 @@ size_t HandleImpl::image2d_pitch_alignment() const {
return align;
}
HandleImpl::HandleVendorType HandleImpl::vendor_type() const {
return HandleVendorType::CUDA;
}
} // namespace cuda
} // namespace megdnn
......
......@@ -123,6 +123,7 @@ class HandleImpl: public HandleImplHelper {
TypeCvt* typecvt_opr() { return get_helper_opr<TypeCvt, 0>(this); }
size_t image2d_pitch_alignment() const override;
HandleVendorType vendor_type() const override;
private:
bool m_is_tegra_k1;
int m_device_id;
......
......@@ -118,6 +118,10 @@ size_t HandleImpl::image2d_pitch_alignment() const {
return g_image2d_pitch_alignment;
}
HandleImpl::HandleVendorType HandleImpl::vendor_type() const {
return HandleVendorType::NOT_SPEC;
}
size_t HandleImpl::exchange_image2d_pitch_alignment(size_t alignment) {
auto ret = g_image2d_pitch_alignment;
g_image2d_pitch_alignment = alignment;
......
......@@ -169,6 +169,7 @@ public:
* \param alignment the new alignment value to set
*/
static size_t exchange_image2d_pitch_alignment(size_t alignment);
HandleVendorType vendor_type() const override;
};
} // namespace naive
......
......@@ -175,6 +175,30 @@ namespace {
}
}
TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_WITH_VENDOR_MALI) {
TensorFormat fmt = Image2DPack4TensorFormat::make_raw(
1, 512, Handle::HandleVendorType::MALI);
TensorLayout layout{{5, 3, 8}, dtype::Float32{}, fmt};
ASSERT_EQ(layout.stride[2], 1);
ASSERT_EQ(layout.stride[1], 8);
ASSERT_EQ(layout.stride[0], 2048);
ASSERT_EQ(8192u, image_row_pitch(layout));
ASSERT_EQ(6u, image_width(layout));
ASSERT_EQ(5u, image_height(layout));
fmt = Image2DPack4TensorFormat::make_raw(1, 512,
Handle::HandleVendorType::MALI);
TensorLayout layout_s{{5, 3, 8}, dtype::Float16{}, fmt};
ASSERT_EQ(layout_s.stride[2], 1);
ASSERT_EQ(layout_s.stride[1], 8);
ASSERT_EQ(layout_s.stride[0], 2048);
ASSERT_EQ(4096u, image_row_pitch(layout_s));
ASSERT_EQ(6u, image_width(layout_s));
ASSERT_EQ(5u, image_height(layout_s));
}
TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT) {
TensorFormat fmt = Image2DPack4TensorFormat::make_raw(1, 1024);
ASSERT_FALSE(fmt.is_default());
......@@ -233,7 +257,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT) {
auto&& impl = contig.format.as_impl<Image2DPack4TensorFormat>();
ASSERT_EQ(make_layout({1, 8}, {32, 1}, layout.dtype), contig);
ASSERT_EQ(1u, impl.align_axis());
ASSERT_EQ(64u, impl.align_size_in_byte());
ASSERT_EQ(64u, impl.align_size_in_elements());
}
}
......@@ -258,7 +282,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_H) {
auto&& impl = contig.format.as_impl<Image2DPack4TensorFormat>();
ASSERT_EQ(make_layout({v0, 8}, {32, 1}, layout.dtype), contig);
ASSERT_EQ(1u, impl.align_axis());
ASSERT_EQ(64u, impl.align_size_in_byte());
ASSERT_EQ(64u, impl.align_size_in_elements());
}
}
......@@ -274,7 +298,7 @@ TEST(BASIC_TYPES, TENSOR_LAYOUT_FMT_COLLAPSE_W) {
layout.dtype),
contig);
ASSERT_EQ(1u, impl.align_axis());
ASSERT_EQ(64u, impl.align_size_in_byte());
ASSERT_EQ(64u, impl.align_size_in_elements());
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册