提交 48526abb 编写于 作者: M Megvii Engine Team

fix(mgb): fix invalid tensor size check when concatenating CD4 tensors

GitOrigin-RevId: 065e0b4be06fef049ebec215fee8e9b789d579db
上级 c87d998e
......@@ -140,20 +140,30 @@ struct TensorLayout : public TensorShape {
/*!
 * \brief Describes min and max offsets of tensor elements with respect to
 *      its first element, so all tensor elements are guaranteed to be in
 *      the range [elem[0]+low, elem[0]+last). Besides, a separate high
 *      bound describes the range including the row pitch when an image2D
 *      format is used.
 */
struct Span {
    //! Lowest element / byte offset relative to the first element; may be
    //! negative when some strides are negative (hence ptrdiff_t).
    ptrdiff_t low_elem, low_byte;
    //! One-past-the-highest element / byte offset, including any row-pitch
    //! padding at the last row for image2D formats.
    size_t high_elem, high_byte;
    //! The difference between high_elem and last_elem is that last_elem
    //! describes the last element of a tensor regardless of the row pitch at
    //! the last row. It will be useful when copying into a part of an image.
    size_t last_elem, last_byte;
    Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, size_t high_byte,
         size_t last_elem, size_t last_byte)
            : low_elem(low_elem),
              low_byte(low_byte),
              high_elem(high_elem),
              high_byte(high_byte),
              last_elem(last_elem),
              last_byte(last_byte) {}
    //! number of elements covered by [low_elem, high_elem)
    size_t dist_elem() const { return high_elem - low_elem; }
    //! number of bytes covered by [low_byte, high_byte)
    size_t dist_byte() const { return high_byte - low_byte; }
    //! number of bytes covered by [low_byte, last_byte), i.e. excluding the
    //! trailing row-pitch padding counted in dist_byte()
    size_t dist_last_byte() const { return last_byte - low_byte; }
};
/*!
......
......@@ -157,14 +157,14 @@ TensorLayout DefaultTensorFormat::collapse_contiguous_spec(
TensorLayout::Span DefaultTensorFormat::span_spec(const TensorLayout& layout) const {
assert_valid(layout);
if (layout.ndim == 0)
return {0, 0, 0, 0};
return {0, 0, 0, 0, 0, 0};
ptrdiff_t low_elem = 0;
size_t high_elem = 0;
for (size_t i = 0; i < layout.ndim; ++i) {
auto shape_val = layout.shape[i];
if (!shape_val) {
return {0, 0, 0, 0};
return {0, 0, 0, 0, 0, 0};
}
auto stride_val = layout.stride[i];
if (stride_val > 0) {
......@@ -181,7 +181,8 @@ TensorLayout::Span DefaultTensorFormat::span_spec(const TensorLayout& layout) co
low_byte = 0;
}
size_t high_byte = layout.dtype.size(high_elem);
return TensorLayout::Span(low_elem, low_byte, high_elem, high_byte);
return TensorLayout::Span(
low_elem, low_byte, high_elem, high_byte, high_elem, high_byte);
}
std::string DefaultTensorFormat::to_string() const {
......@@ -274,7 +275,12 @@ void Image2DPackedTensorFormatBase<PIXEL_SIZE>::assert_valid(
layout.ndim > m_align_axis);
ptrdiff_t first_non_zero_stride = 0;
for (int i = layout.ndim - 1; i >= 0; --i) {
megdnn_assert(layout.shape[i] && layout.stride[i] >= 0);
megdnn_assert(layout.shape[i]);
megdnn_assert(
layout.stride[i] >= 0,
"stride in Image2D format does not support negative stride {%s}. Use "
"NCHW format instead.",
layout.to_string().c_str());
if (i < static_cast<int>(m_align_axis) && !first_non_zero_stride) {
first_non_zero_stride = layout.stride[i];
}
......@@ -322,7 +328,14 @@ TensorLayout::Span Image2DPackedTensorFormatBase<PIXEL_SIZE>::span_spec(
size_t size = image_height(layout) * image_row_pitch(layout);
auto mask = (1 << layout.dtype.size_log()) - 1;
megdnn_assert(!(size & mask), "unaligned size: %zu", size);
return {0, 0, size >> layout.dtype.size_log(), size};
auto collapse_layout = layout.collapse_contiguous();
size_t last_elem = 0;
for (size_t i = 0; i < 2; ++i) {
last_elem += (collapse_layout.shape[i] - 1) * collapse_layout.stride[i];
}
last_elem++;
size_t last_byte = last_elem * layout.dtype.size();
return {0, 0, size >> layout.dtype.size_log(), size, last_elem, last_byte};
}
template <size_t PIXEL_SIZE>
......@@ -507,13 +520,13 @@ TensorLayout::Span LowbitsAlignedTensorFormatBase::span_spec(
const TensorLayout& layout) const {
assert_valid(layout);
if (layout.ndim == 0)
return {0, 0, 0, 0};
return {0, 0, 0, 0, 0, 0};
size_t high_elem = 0;
for (size_t i = 0; i < layout.ndim; ++i) {
auto shape_val = layout.shape[i];
if (!shape_val) {
return {0, 0, 0, 0};
return {0, 0, 0, 0, 0, 0};
}
auto stride_val = layout.stride[i];
megdnn_assert(
......@@ -522,7 +535,7 @@ TensorLayout::Span LowbitsAlignedTensorFormatBase::span_spec(
}
++high_elem;
size_t high_byte = layout.dtype.size(high_elem);
return TensorLayout::Span(0, 0, high_elem, high_byte);
return TensorLayout::Span(0, 0, high_elem, high_byte, high_elem, high_byte);
}
size_t LowbitsAlignedTensorFormatBase::init_contiguous_stride(
......
......@@ -500,7 +500,23 @@ DEF(reset, &)(TensorStorage storage, const TensorLayout& layout) {
//! The storage to be reset is either satisfy the layout or empty.
//! Empty storage is used after weight preprocess for saving memory and
//! checking layout when running
mgb_assert(!layout.ndim || storage.valid_span(layout.span()) || storage.empty());
auto span = layout.span();
if (span.last_elem == span.high_elem) {
mgb_assert(!layout.ndim || storage.valid_span(span) || storage.empty());
} else {
size_t start_pos = span.low_byte + static_cast<ptrdiff_t>(storage.offset());
bool enough_size = span.last_byte <= storage.size();
bool valid_size = storage.comp_node().valid() && start_pos >= 0 && enough_size;
mgb_assert(!layout.ndim || valid_size || storage.empty());
if (valid_size && !storage.valid_span(span)) {
mgb_log_warn(
"storage size %zu can not hold the whole layout %s, but holds all "
"elements. Only accepted when copying one CD4 Tensor into another "
"CD4 Tensor\n",
storage.size(), layout.to_string().c_str());
}
}
m_storage = std::move(storage);
m_layout = layout;
return static_cast<ChainReturnType&>(*this);
......@@ -686,8 +702,8 @@ const typename TensorND<TensorStorage>::ChainReturnType& TensorND<
if (should_check_overlap(*this, src)) {
check_overlapped(
this->raw_ptr() + dst_span.low_byte,
this->raw_ptr() + dst_span.high_byte, src.raw_ptr() + src_span.low_byte,
src.raw_ptr() + src_span.high_byte);
this->raw_ptr() + dst_span.last_byte, src.raw_ptr() + src_span.low_byte,
src.raw_ptr() + src_span.last_byte);
}
bool self_contig =
......@@ -702,12 +718,12 @@ const typename TensorND<TensorStorage>::ChainReturnType& TensorND<
src.layout().format.is_lowbit_aligned())) {
mgb_assert(
src_span.low_byte == 0 && dst_span.low_byte == 0 &&
src_span.high_byte == dst_span.high_byte);
m_storage.copy_from(src.storage(), src_span.high_byte);
src_span.last_byte == dst_span.last_byte);
m_storage.copy_from(src.storage(), src_span.last_byte);
} else {
mgb_assert(src_span.low_byte == 0 && dst_span.low_byte == 0);
m_storage.copy_from(
src.storage(), std::min(src_span.high_byte, dst_span.high_byte));
src.storage(), std::min(src_span.last_byte, dst_span.last_byte));
}
return static_cast<const ChainReturnType&>(*this);
}
......
......@@ -1446,6 +1446,48 @@ TEST(TestTensorManip, ConcatEmpty2) {
ASSERT_EQ(TensorShape({2, 0, 11}), host_z.shape());
}
#if MGB_OPENCL
#include "megcore_opencl.h"
//! Early-return from the enclosing test when no OpenCL device is present,
//! so OpenCL-only tests are silently skipped on machines without one.
#define REQUIRE_OPENCL() \
do { \
if (!CompNode::get_device_count(CompNode::DeviceType::OPENCL)) { \
return; \
} \
} while (0)
// Regression test for concat on CD4 (NHWCD4) tensors: concatenating two
// tensors already in CD4 layout must produce the same result as
// concatenating them in NCHW and then relayouting to CD4.
TEST(TestTensorManip, ConcatCD4) {
REQUIRE_OPENCL();
auto cn = CompNode::load("openclx");
HostTensorGenerator<> gen;
auto host_x = gen({1, 4, 2, 2}, cn), host_y = gen({1, 4, 2, 2}, cn);
// graph0: relayout each NCHW input to NHWCD4 first, then concat along
// axis 2 (the C/4 axis of the 5-D CD4 layout, per the shape check below —
// equivalent to a channel concat in NCHW).
auto graph0 = ComputingGraph::make();
auto x = opr::Host2DeviceCopy::make(*graph0, host_x);
auto y = opr::Host2DeviceCopy::make(*graph0, host_y);
x = opr::RelayoutFormat::make(x, {opr::RelayoutFormat::Param::Mode::NCHW_NHWCD4I});
y = opr::RelayoutFormat::make(y, {opr::RelayoutFormat::Param::Mode::NCHW_NHWCD4I});
auto z = opr::Concat::make({x, y}, 2);
HostTensorND host_z0;
auto func = graph0->compile({make_callback_copy(z, host_z0)});
func->execute();
// Two {1, 4, 2, 2} NCHW inputs become {1, 2, 1, 2, 4} in CD4; concat on
// axis 2 doubles the C/4 dimension.
ASSERT_EQ(TensorShape({1, 2, 2, 2, 4}), host_z0.shape());
// graph1: reference path — concat on the NCHW channel axis (1), then a
// single relayout to NHWCD4.
auto graph1 = ComputingGraph::make();
x = opr::Host2DeviceCopy::make(*graph1, host_x);
y = opr::Host2DeviceCopy::make(*graph1, host_y);
z = opr::RelayoutFormat::make(
opr::Concat::make({x, y}, 1),
{opr::RelayoutFormat::Param::Mode::NCHW_NHWCD4I});
HostTensorND host_z1;
func = graph1->compile({make_callback_copy(z, host_z1)});
func->execute();
// Both paths must agree element-wise.
MGB_ASSERT_TENSOR_EQ(host_z0, host_z1);
}
#endif
TEST(TestTensorManip, AxisAddRemove) {
HostTensorGenerator<> gen;
for (bool dyn_shape : {false, true}) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册