提交 8b94f493 编写于 作者: M Megvii Engine Team

fix(dnn/cuda): fix elemwise and relayout int4 bug when last shape is 1

GitOrigin-RevId: e7d64c49871032deeda4176289f0457d4b9d85b8
上级 694aa1bd
...@@ -424,12 +424,20 @@ size_t TensorLayout::access_bytes() const { ...@@ -424,12 +424,20 @@ size_t TensorLayout::access_bytes() const {
if (dtype.is_low_bit()) { if (dtype.is_low_bit()) {
ret = 1; ret = 1;
int align_size_in_elements = 8 / dtype.low_bit(); int align_size_in_elements = 8 / dtype.low_bit();
auto min_stride = contig.stride[0];
for (size_t i = 0; i < contig.ndim; ++i) { for (size_t i = 0; i < contig.ndim; ++i) {
if (contig.stride[i] == 1) { if (contig.stride[i] == 1) {
ret *= round_up((int)contig.shape[i], align_size_in_elements); ret *= round_up((int)contig.shape[i], align_size_in_elements);
} else { } else {
ret *= contig.shape[i]; ret *= contig.shape[i];
} }
if (min_stride > contig.stride[i]) {
min_stride = contig.stride[i];
}
}
if (min_stride != 1) {
megdnn_assert(min_stride == align_size_in_elements);
ret *= min_stride;
} }
ret /= align_size_in_elements; ret /= align_size_in_elements;
} else { } else {
......
...@@ -240,6 +240,7 @@ template <int ndim> ...@@ -240,6 +240,7 @@ template <int ndim>
void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init( void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init(
const TensorND& rv, int /*grid_size*/, int /*block_size*/) { const TensorND& rv, int /*grid_size*/, int /*block_size*/) {
m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr); m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr);
auto min_stride = rv.layout.stride[0];
for (size_t i = 0; i < rv.layout.ndim; ++i) { for (size_t i = 0; i < rv.layout.ndim; ++i) {
m_stride[i] = rv.layout.stride[i]; m_stride[i] = rv.layout.stride[i];
m_shape[i] = rv.layout.shape[i]; m_shape[i] = rv.layout.shape[i];
...@@ -251,7 +252,12 @@ void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init( ...@@ -251,7 +252,12 @@ void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init(
else else
m_align_shape_highdim[i] = rv.layout.shape[i + 1]; m_align_shape_highdim[i] = rv.layout.shape[i + 1];
} }
if (min_stride > rv.layout.stride[i]) {
min_stride = rv.layout.stride[i];
} }
}
megdnn_assert(min_stride == 1 || min_stride == 2);
m_is_min_stride_2 = (min_stride == 2);
for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) { for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) {
m_shape_highdim[i] = 1; m_shape_highdim[i] = 1;
m_align_shape_highdim[i] = 1; m_align_shape_highdim[i] = 1;
......
...@@ -542,6 +542,7 @@ protected: ...@@ -542,6 +542,7 @@ protected:
int m_stride[ndim]; int m_stride[ndim];
int m_shape[ndim]; int m_shape[ndim];
bool m_is_physical_contiguous; bool m_is_physical_contiguous;
bool m_is_min_stride_2;
//! m_shape_highdim[i] = original_shape[i + 1] //! m_shape_highdim[i] = original_shape[i + 1]
#ifdef _MSC_VER #ifdef _MSC_VER
...@@ -592,7 +593,7 @@ public: ...@@ -592,7 +593,7 @@ public:
int idx = 0; int idx = 0;
if (m_is_physical_contiguous) { if (m_is_physical_contiguous) {
idx = access_idx; idx = access_idx;
} else { } else if (!m_is_min_stride_2) {
int shape_idx[ndim]; int shape_idx[ndim];
bool valid = true; bool valid = true;
get_shape_from_access(access_idx, shape_idx); get_shape_from_access(access_idx, shape_idx);
...@@ -605,6 +606,8 @@ public: ...@@ -605,6 +606,8 @@ public:
idx = (idx + shape_idx[i]) * m_shape[i + 1]; idx = (idx + shape_idx[i]) * m_shape[i + 1];
} }
idx = valid ? idx + shape_idx[ndim - 1] : -1; idx = valid ? idx + shape_idx[ndim - 1] : -1;
} else { // min_stride == 2
idx = ((access_idx & 0x1) == 0) ? ((int)access_idx >> 1) : -1;
} }
return idx; return idx;
} }
......
...@@ -70,6 +70,7 @@ void ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER>::host_init( ...@@ -70,6 +70,7 @@ void ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER>::host_init(
const TensorND& rv, int /*grid_size*/, int /*block_size*/) { const TensorND& rv, int /*grid_size*/, int /*block_size*/) {
megdnn_assert(rv.layout.ndim && rv.layout.ndim <= ndim); megdnn_assert(rv.layout.ndim && rv.layout.ndim <= ndim);
m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr); m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr);
auto min_stride = rv.layout.stride[0];
for (size_t i = 0; i < rv.layout.ndim; ++i) { for (size_t i = 0; i < rv.layout.ndim; ++i) {
m_stride[i] = rv.layout.stride[i]; m_stride[i] = rv.layout.stride[i];
m_shape[i] = rv.layout.shape[i]; m_shape[i] = rv.layout.shape[i];
...@@ -81,7 +82,12 @@ void ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER>::host_init( ...@@ -81,7 +82,12 @@ void ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER>::host_init(
else else
m_align_shape_highdim[i] = rv.layout.shape[i + 1]; m_align_shape_highdim[i] = rv.layout.shape[i + 1];
} }
if (min_stride > rv.layout.stride[i]) {
min_stride = rv.layout.stride[i];
} }
}
megdnn_assert(min_stride == 1 || min_stride == 2);
m_is_min_stride_2 = (min_stride == 2);
for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) { for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) {
m_shape_highdim[i] = 1; m_shape_highdim[i] = 1;
m_align_shape_highdim[i] = 1; m_align_shape_highdim[i] = 1;
......
...@@ -132,6 +132,7 @@ class ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER> { ...@@ -132,6 +132,7 @@ class ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER> {
int m_shape[ndim]; int m_shape[ndim];
bool m_is_contiguous; bool m_is_contiguous;
bool m_is_physical_contiguous; bool m_is_physical_contiguous;
bool m_is_min_stride_2;
//! m_shape_highdim[i] = original_shape[i + 1] //! m_shape_highdim[i] = original_shape[i + 1]
#ifdef _MSC_VER #ifdef _MSC_VER
...@@ -197,7 +198,7 @@ public: ...@@ -197,7 +198,7 @@ public:
int idx = 0; int idx = 0;
if (m_is_physical_contiguous) { if (m_is_physical_contiguous) {
idx = access_idx; idx = access_idx;
} else { } else if (!m_is_min_stride_2) {
int shape_idx[ndim]; int shape_idx[ndim];
bool valid = true; bool valid = true;
get_shape_from_access(access_idx, shape_idx); get_shape_from_access(access_idx, shape_idx);
...@@ -209,6 +210,8 @@ public: ...@@ -209,6 +210,8 @@ public:
idx = (idx + shape_idx[i]) * m_shape[i + 1]; idx = (idx + shape_idx[i]) * m_shape[i + 1];
} }
idx = valid ? idx + shape_idx[ndim - 1] : -1; idx = valid ? idx + shape_idx[ndim - 1] : -1;
} else { // min_stride == 2
idx = ((access_idx & 0x1) == 0) ? ((int)access_idx >> 1) : -1;
} }
return idx; return idx;
} }
......
...@@ -152,7 +152,8 @@ static void run_test_q4(int arity, Checker<ElemwiseMultiType>& checker, ...@@ -152,7 +152,8 @@ static void run_test_q4(int arity, Checker<ElemwiseMultiType>& checker,
.execs({{1, 4, 5, 5}, {1, 4, 5, 5}}); .execs({{1, 4, 5, 5}, {1, 4, 5, 5}});
} else if (arity == 2) { } else if (arity == 2) {
checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}}) checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}})
.execs({{1, 4, 5, 5}, {1, 4, 5, 5}, {1, 4, 5, 5}}); .execs({{1, 4, 5, 5}, {1, 4, 5, 5}, {1, 4, 5, 5}})
.execs({{2, 2, 3, 1}, {2, 2, 3, 1}, {2, 2, 3, 1}});
} else { } else {
megdnn_assert(0); megdnn_assert(0);
} }
......
...@@ -925,6 +925,7 @@ TEST_F(CUDA, RELAYOUT_Q4) { ...@@ -925,6 +925,7 @@ TEST_F(CUDA, RELAYOUT_Q4) {
.set_rng(1, &rng_int4) .set_rng(1, &rng_int4)
.set_dtype(0, dtype::QuantizedS4(1.f)) .set_dtype(0, dtype::QuantizedS4(1.f))
.set_dtype(1, dtype::QuantizedS4(1.f)) .set_dtype(1, dtype::QuantizedS4(1.f))
.execs({{2, 2, 1, 1}, {1, 1, 2, 2}})
.execs({{1, 64, 15, 15}, {1, 15, 15, 64}}) .execs({{1, 64, 15, 15}, {1, 15, 15, 64}})
.execs({{1, 5, 9, 32}, {1, 5, 32, 9}}) .execs({{1, 5, 9, 32}, {1, 5, 32, 9}})
.execl(TensorLayoutArray{ .execl(TensorLayoutArray{
......
...@@ -123,11 +123,13 @@ TEST_F(CUDA, QUANTIZED_TYPECVT_4BIT) { ...@@ -123,11 +123,13 @@ TEST_F(CUDA, QUANTIZED_TYPECVT_4BIT) {
set_err(dst_dtype); set_err(dst_dtype);
checker.set_dtype(0, src_dtype) checker.set_dtype(0, src_dtype)
.set_dtype(1, dst_dtype) .set_dtype(1, dst_dtype)
.execs({{16, 3, 224, 223}, {16, 3, 224, 223}}); .execs({{16, 3, 224, 223}, {16, 3, 224, 223}})
.execs({{16, 3, 224, 1}, {16, 3, 224, 1}});
set_err(src_dtype); set_err(src_dtype);
checker.set_dtype(0, dst_dtype) checker.set_dtype(0, dst_dtype)
.set_dtype(1, src_dtype) .set_dtype(1, src_dtype)
.execs({{16, 3, 224, 223}, {16, 3, 224, 223}}); .execs({{16, 3, 224, 223}, {16, 3, 224, 223}})
.execs({{16, 3, 224, 1}, {16, 3, 224, 1}});
}; };
run(dtype::Quantized4Asymm{1.19990518f, 8}, run(dtype::Quantized4Asymm{1.19990518f, 8},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册