Commit 8b94f493 authored by Megvii Engine Team

fix(dnn/cuda): fix elemwise and relayout int4 bug when last shape is 1

GitOrigin-RevId: e7d64c49871032deeda4176289f0457d4b9d85b8
Parent 694aa1bd
@@ -424,12 +424,20 @@ size_t TensorLayout::access_bytes() const {
     if (dtype.is_low_bit()) {
         ret = 1;
         int align_size_in_elements = 8 / dtype.low_bit();
+        auto min_stride = contig.stride[0];
         for (size_t i = 0; i < contig.ndim; ++i) {
             if (contig.stride[i] == 1) {
                 ret *= round_up((int)contig.shape[i], align_size_in_elements);
             } else {
                 ret *= contig.shape[i];
             }
+            if (min_stride > contig.stride[i]) {
+                min_stride = contig.stride[i];
+            }
         }
+        if (min_stride != 1) {
+            megdnn_assert(min_stride == align_size_in_elements);
+            ret *= min_stride;
+        }
         ret /= align_size_in_elements;
     } else {
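
The hunk above is the byte-accounting half of the fix. As a minimal stand-alone sketch of that arithmetic (illustrative only, not MegDNN source; the example shape/strides and the local `round_up` helper are assumptions): with 4-bit data two elements share one byte, and when no axis keeps stride 1 (e.g. the stride-1 axis was collapsed away), the plain product of shapes under-counts the accessed region by exactly the minimum stride.

```cpp
#include <cassert>
#include <cstdio>

// Local helper mirroring round_up semantics (assumed for this sketch).
static int round_up(int x, int align) { return (x + align - 1) / align * align; }

int main() {
    // Illustrative int4 layout with no stride-1 axis: shape {2, 2},
    // strides {4, 2} in element (nibble) units. The accessed elements sit
    // at nibbles 0, 2, 4, 6, i.e. 4 bytes, but shape product / 2 gives 2.
    const int ndim = 2;
    int shape[ndim] = {2, 2}, stride[ndim] = {4, 2};
    const int align_size_in_elements = 8 / 4;  // two int4 values per byte

    int ret = 1, min_stride = stride[0];
    for (int i = 0; i < ndim; ++i) {
        ret *= stride[i] == 1 ? round_up(shape[i], align_size_in_elements)
                              : shape[i];
        if (stride[i] < min_stride)
            min_stride = stride[i];
    }
    if (min_stride != 1) {  // the fix: no axis was rounded up above
        assert(min_stride == align_size_in_elements);
        ret *= min_stride;
    }
    printf("access bytes = %d\n", ret / align_size_in_elements);  // prints 4
    return 0;
}
```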
@@ -240,6 +240,7 @@ template <int ndim>
 void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init(
         const TensorND& rv, int /*grid_size*/, int /*block_size*/) {
     m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr);
+    auto min_stride = rv.layout.stride[0];
     for (size_t i = 0; i < rv.layout.ndim; ++i) {
         m_stride[i] = rv.layout.stride[i];
         m_shape[i] = rv.layout.shape[i];
@@ -251,7 +252,12 @@ void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init(
             else
                 m_align_shape_highdim[i] = rv.layout.shape[i + 1];
         }
+        if (min_stride > rv.layout.stride[i]) {
+            min_stride = rv.layout.stride[i];
+        }
     }
+    megdnn_assert(min_stride == 1 || min_stride == 2);
+    m_is_min_stride_2 = (min_stride == 2);
     for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) {
         m_shape_highdim[i] = 1;
         m_align_shape_highdim[i] = 1;
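
`host_init` now records whether the smallest stride of the visited layout is 2 rather than 1. A hedged reconstruction of how that case arises (an assumption inferred from the tests below, not taken from the source: strides here are in nibble units and the innermost row is padded to a whole byte):

```cpp
#include <cstdio>

int main() {
    // {16, 3, 224, 1} (a shape from the new tests): the innermost row is a
    // single nibble padded to one byte, so axis 2 gets stride 2; once the
    // extent-1 axis is collapsed away, the layout is shape {16, 3, 224}
    // with nibble strides {1344, 448, 2} -- no axis keeps stride 1.
    const int ndim = 3;
    int stride[ndim] = {3 * 224 * 2, 224 * 2, 2};
    int min_stride = stride[0];
    for (int i = 0; i < ndim; ++i)
        if (stride[i] < min_stride)
            min_stride = stride[i];
    // host_init() above allows only 1 or 2 here and records the 2 case.
    printf("min_stride = %d\n", min_stride);  // prints 2
    return 0;
}
```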
@@ -542,6 +542,7 @@ protected:
     int m_stride[ndim];
     int m_shape[ndim];
     bool m_is_physical_contiguous;
+    bool m_is_min_stride_2;
     //! m_shape_highdim[i] = original_shape[i + 1]
 #ifdef _MSC_VER
@@ -592,7 +593,7 @@ public:
         int idx = 0;
         if (m_is_physical_contiguous) {
             idx = access_idx;
-        } else {
+        } else if (!m_is_min_stride_2) {
             int shape_idx[ndim];
             bool valid = true;
             get_shape_from_access(access_idx, shape_idx);
@@ -605,6 +606,8 @@ public:
                 idx = (idx + shape_idx[i]) * m_shape[i + 1];
             }
             idx = valid ? idx + shape_idx[ndim - 1] : -1;
+        } else { // min_stride == 2
+            idx = ((access_idx & 0x1) == 0) ? ((int)access_idx >> 1) : -1;
         }
         return idx;
     }
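
When `m_is_min_stride_2` is set, the visitor can skip the general shape walk: logical element k sits at nibble offset 2k, so even access indices map to access_idx / 2 and odd indices are padding holes. A stand-alone rendering of just that branch (the function name is mine, for illustration):

```cpp
#include <cstdio>

// Even nibble slots hold data; odd slots are padding, reported as -1.
static int get_idx_min_stride_2(int access_idx) {
    return (access_idx & 0x1) == 0 ? access_idx >> 1 : -1;
}

int main() {
    for (int a = 0; a < 6; ++a)
        printf("access %d -> logical %d\n", a, get_idx_min_stride_2(a));
    // access 0 -> 0, access 1 -> -1, access 2 -> 1, access 3 -> -1, ...
    return 0;
}
```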
@@ -70,6 +70,7 @@ void ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER>::host_init(
         const TensorND& rv, int /*grid_size*/, int /*block_size*/) {
     megdnn_assert(rv.layout.ndim && rv.layout.ndim <= ndim);
     m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr);
+    auto min_stride = rv.layout.stride[0];
     for (size_t i = 0; i < rv.layout.ndim; ++i) {
         m_stride[i] = rv.layout.stride[i];
         m_shape[i] = rv.layout.shape[i];
@@ -81,7 +82,12 @@ void ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER>::host_init(
             else
                 m_align_shape_highdim[i] = rv.layout.shape[i + 1];
         }
+        if (min_stride > rv.layout.stride[i]) {
+            min_stride = rv.layout.stride[i];
+        }
     }
+    megdnn_assert(min_stride == 1 || min_stride == 2);
+    m_is_min_stride_2 = (min_stride == 2);
     for (size_t i = rv.layout.ndim - 1; i < ndim - 1; ++i) {
         m_shape_highdim[i] = 1;
         m_align_shape_highdim[i] = 1;
@@ -132,6 +132,7 @@ class ParamElemVisitor<ndim, dt_quint4, CONTIG_OTHER> {
     int m_shape[ndim];
     bool m_is_contiguous;
     bool m_is_physical_contiguous;
+    bool m_is_min_stride_2;
     //! m_shape_highdim[i] = original_shape[i + 1]
 #ifdef _MSC_VER
@@ -197,7 +198,7 @@ public:
         int idx = 0;
         if (m_is_physical_contiguous) {
             idx = access_idx;
-        } else {
+        } else if (!m_is_min_stride_2) {
            int shape_idx[ndim];
            bool valid = true;
            get_shape_from_access(access_idx, shape_idx);
@@ -209,6 +210,8 @@ public:
                idx = (idx + shape_idx[i]) * m_shape[i + 1];
            }
            idx = valid ? idx + shape_idx[ndim - 1] : -1;
+        } else { // min_stride == 2
+            idx = ((access_idx & 0x1) == 0) ? ((int)access_idx >> 1) : -1;
         }
         return idx;
     }

@@ -152,7 +152,8 @@ static void run_test_q4(int arity, Checker<ElemwiseMultiType>& checker,
                 .execs({{1, 4, 5, 5}, {1, 4, 5, 5}});
     } else if (arity == 2) {
         checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}})
-                .execs({{1, 4, 5, 5}, {1, 4, 5, 5}, {1, 4, 5, 5}});
+                .execs({{1, 4, 5, 5}, {1, 4, 5, 5}, {1, 4, 5, 5}})
+                .execs({{2, 2, 3, 1}, {2, 2, 3, 1}, {2, 2, 3, 1}});
     } else {
         megdnn_assert(0);
     }

@@ -925,6 +925,7 @@ TEST_F(CUDA, RELAYOUT_Q4) {
             .set_rng(1, &rng_int4)
             .set_dtype(0, dtype::QuantizedS4(1.f))
             .set_dtype(1, dtype::QuantizedS4(1.f))
+            .execs({{2, 2, 1, 1}, {1, 1, 2, 2}})
            .execs({{1, 64, 15, 15}, {1, 15, 15, 64}})
            .execs({{1, 5, 9, 32}, {1, 5, 32, 9}})
            .execl(TensorLayoutArray{

@@ -123,11 +123,13 @@ TEST_F(CUDA, QUANTIZED_TYPECVT_4BIT) {
         set_err(dst_dtype);
         checker.set_dtype(0, src_dtype)
                 .set_dtype(1, dst_dtype)
-                .execs({{16, 3, 224, 223}, {16, 3, 224, 223}});
+                .execs({{16, 3, 224, 223}, {16, 3, 224, 223}})
+                .execs({{16, 3, 224, 1}, {16, 3, 224, 1}});
         set_err(src_dtype);
         checker.set_dtype(0, dst_dtype)
                 .set_dtype(1, src_dtype)
-                .execs({{16, 3, 224, 223}, {16, 3, 224, 223}});
+                .execs({{16, 3, 224, 223}, {16, 3, 224, 223}})
+                .execs({{16, 3, 224, 1}, {16, 3, 224, 1}});
     };
     run(dtype::Quantized4Asymm{1.19990518f, 8},