Commit 28d85838 authored by Megvii Engine Team, committed by Xu Xinran

feat(dnn): add relayout_format for nchw to nchw4 and ic <=4

GitOrigin-RevId: 07f2ee6c5be69adaf796d4073f16f88241480330
Parent: 3a53872f
@@ -434,7 +434,7 @@ pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0)
          'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'),
     Doc('MK8', 'Split 8 from M and K, better for neon compute:'
         '(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the '
         'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))'),
     Doc('MK4_DOT', 'Split 4 from M and K, better for neon dotprod:'
         '(M/4, K/4, 4(m), 4(k)) x (K/4, N, 4(k)). if transposeA the '
         'layout is (K/4, M/4, 4(m), 4(k)) x (K/4, N, 4(k))'))
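As context for these docstrings: MK4 packs the operands in 4-wide blocks so a NEON kernel can load contiguous vectors. Below is a minimal sketch of the index mapping that the `(M/4, K/4, 4(k), 4(m))` layout implies for the A operand; `pack_mk4` is a hypothetical helper written for illustration, not MegDNN code.

```cpp
#include <cstddef>
#include <vector>

// Sketch of MK4 packing for an M x K matrix A (M, K divisible by 4):
// element A[m][k] lands in block (m/4, k/4) at position (k%4, m%4),
// i.e. the packed layout is (M/4, K/4, 4(k), 4(m)).
std::vector<float> pack_mk4(const std::vector<float>& a, size_t M, size_t K) {
    std::vector<float> packed(M * K);
    for (size_t m = 0; m < M; ++m)
        for (size_t k = 0; k < K; ++k)
            packed[(((m / 4) * (K / 4) + k / 4) * 4 + k % 4) * 4 + m % 4] =
                    a[m * K + k];
    return packed;
}
```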
@@ -858,7 +858,10 @@ when the ``I`` suffix is present.
     'NCHW_NCHW88_CONV_CHAN_WEIGHT',
     'NCHW_NCHW88_CONV_GROUP_WEIGHT',
     'NCHW_NCHW88',
-    'NCHW88_NCHW')
+    'NCHW88_NCHW',
+    'NCHW_NCHW4_IC_SMALL',
+    'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT',
+    )
 )
...
@@ -28,6 +28,26 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
             dst[3] = src[3];
             dst[4] = 4;
             break;
+        case Param::Mode::NCHW_NCHW4_IC_SMALL:
+            dst.ndim = 5;
+            megdnn_assert(src[1] <= 4_z, "ic should be <= 4");
+            dst[0] = src[0];
+            dst[1] = div_ceil(src[1], 4_z);
+            dst[2] = src[2];
+            dst[3] = src[3];
+            dst[4] = 4;
+            break;
+        case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
+            megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4");
+            megdnn_assert(src[1] <= 4_z, "ic should be <= 4");
+            dst.ndim = 5;
+            dst[0] = src[0];
+            dst[1] = div_ceil(src[1], 4_z);
+            dst[2] = src[2];
+            dst[3] = src[3];
+            dst[4] = 4;
+            break;
         case Param::Mode::NCHW_NCHW88:
             dst.ndim = 5;
             dst[0] = src[0];
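In other words, NCHW_NCHW4_IC_SMALL maps an (N, C, H, W) layout with C <= 4 to (N, ceil(C/4), H, W, 4), which for C <= 4 is always (N, 1, H, W, 4). The same shape arithmetic as a standalone sketch (plain C++, not the MegDNN TensorLayout API):

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Shape deduction for NCHW_NCHW4_IC_SMALL as plain arithmetic:
// (N, C, H, W) with C <= 4  ->  (N, ceil(C/4), H, W, 4).
std::array<size_t, 5> deduce_nchw4_ic_small(std::array<size_t, 4> src) {
    assert(src[1] <= 4 && "ic should be <= 4");
    size_t c_blocks = (src[1] + 3) / 4;  // always 1 when 1 <= C <= 4
    return {src[0], c_blocks, src[2], src[3], 4};
}
// e.g. (2, 3, 35, 36) -> (2, 1, 35, 36, 4): the channel dim is padded to 4.
```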
@@ -276,6 +296,8 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
         case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT:
         case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT:
         case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT:
+        case Param::Mode::NCHW_NCHW4_IC_SMALL:
+        case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
             CHECK_SRC(DefaultTensorFormat::make());
             dst = src;
             break;
@@ -374,6 +396,23 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
                 exec_dst = dst;
             }
             break;
+        case Param::Mode::NCHW_NCHW4_IC_SMALL:
+        case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
+            // nchw to nchw4c or oihw to oihw4i
+            {
+                TensorLayout work_space_layout(
+                        {src[0], round_up(src[1], 4_z), src[2], src[3]},
+                        src.dtype, src.format);
+                exec_src = work_space_layout
+                                   .reshape({src[0], div_ceil(src[1], 4_z), 4,
+                                             src[2], src[3]})
+                                   .dimshuffle({0, 1, 3, 4, 2});
+                exec_dst = dst;
+            }
+            break;
         case Param::Mode::NCHW_NHWCD4:
         case Param::Mode::NCHW_NHWCD4I:
             // src is {N, C, H, W}
...
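The exec-layout trick above is worth spelling out: the already-padded source is first viewed as (N, C'/4, 4, H, W), then dimshuffled with {0, 1, 3, 4, 2}, so axis 2 (the 4-channel sub-block) becomes the innermost axis of the 5-D destination. A minimal sketch of what that permutation does to a concrete index (hypothetical helper, not the TensorLayout API):

```cpp
#include <array>
#include <cstddef>

// After the reshape, a padded NCHW tensor is indexed as (n, c4, c, h, w)
// with c in [0, 4). dimshuffle({0, 1, 3, 4, 2}) reorders those axes so the
// destination NCHW4 tensor is indexed as (n, c4, h, w, c).
std::array<size_t, 5> nchw4_dst_index(size_t n, size_t channel, size_t h,
                                      size_t w) {
    size_t c4 = channel / 4;  // which 4-channel block
    size_t c = channel % 4;   // position inside the block
    return {n, c4, h, w, c};  // index into the (N, C/4, H, W, 4) tensor
}
```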
@@ -11,6 +11,7 @@
 #include "src/cuda/relayout_format/opr_impl.h"
 #include "src/cuda/handle.h"
+#include "src/cuda/utils.h"
 using namespace megdnn;
 using namespace cuda;
@@ -20,15 +21,22 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
     auto src_dtype = src.layout.dtype;
     megdnn_assert(
             param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 ||
-                    param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4,
+                    param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 ||
+                    param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL ||
+                    param().mode ==
+                            Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT,
             "relayout format of cuda only supports NCHW4->CHWN4 or "
-            "CHWN4->NCHW4");
-    if (src_dtype.enumv() == DTypeEnum::QuantizedS8) {
+            "CHWN4->NCHW4 or NCHW->NCHW4");
+    if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 ||
+         param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4) &&
+        src_dtype.enumv() == DTypeEnum::QuantizedS8) {
         size_t row = 0, col = 0;
         if (param().mode == Param::RelayoutFormat::Mode::NCHW4_CHWN4) {
             row = src.layout[0],
             col = src.layout[1] * src.layout[2] * src.layout[3];
         } else {
+            megdnn_assert(param().mode ==
+                          param::RelayoutFormat::Mode::CHWN4_NCHW4);
             row = src.layout[0] * src.layout[1] * src.layout[2],
             col = src.layout[3];
         }
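The row/col computation above treats the QuantizedS8 fast path as a plain 2-D transpose: an NCHW4 tensor is a row-major matrix with N rows and (C/4)*H*W columns whose elements are 4-byte channel groups, and CHWN4 is that matrix transposed. Sketched on the CPU for clarity (the real path dispatches a CUDA batched-transpose kernel; `int32_t` stands in for the 4 packed int8 values):

```cpp
#include <cstddef>
#include <cstdint>

// NCHW4 -> CHWN4 as a transpose of a row x col matrix of 4-byte elements,
// with row = N and col = (C/4) * H * W, matching the code above.
void transpose_nchw4_to_chwn4(const int32_t* src, int32_t* dst, size_t row,
                              size_t col) {
    for (size_t r = 0; r < row; ++r)
        for (size_t c = 0; c < col; ++c)
            dst[c * row + r] = src[r * col + c];
}
```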
@@ -43,6 +51,27 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
         return handle()->create_operator<RelayoutForward>()->exec(trans_in,
                                                                   trans_out);
     }
+    if ((param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL ||
+         param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) &&
+        src.layout[1] % 4 != 0) {
+        megdnn_assert(src.raw_ptr != dst.raw_ptr && src.layout.ndim == 4,
+                      "the NCHW_NCHW4_IC_SMALL and "
+                      "NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT modes of the "
+                      "RelayoutFormat opr (cuda backend) do not support "
+                      "src.ptr == dst.ptr");
+        megdnn_assert(src.layout[1] <= 4);
+        cuda_check(cudaMemsetAsync(dst.raw_ptr, 0,
+                                   dst.layout.span().dist_byte(),
+                                   cuda_stream(this->handle())));
+        TensorLayout exec_dst_layout = dst.layout;
+        exec_dst_layout[4] = src.layout[1];
+        TensorLayout exec_src_layout =
+                src.layout
+                        .reshape({src.layout[0], src.layout[1], 1,
+                                  src.layout[2], src.layout[3]})
+                        .dimshuffle({0, 2, 3, 4, 1});
+        return handle()->create_operator<RelayoutForward>()->exec(
+                {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout});
+    }
     TensorLayout exec_src, exec_dst;
     deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst);
     TensorND exec_src_nd{src.raw_ptr, exec_src};
...
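This branch handles C % 4 != 0 without a workspace: it zero-fills the whole destination, then builds a destination view whose last axis has length C (via exec_dst_layout[4] = src.layout[1]) so the relayout copy writes only the first C of every 4 channel slots, leaving the padding at zero. The same idea as a self-contained CPU sketch (hypothetical helper, float elements for brevity):

```cpp
#include <cstddef>
#include <cstring>

// NCHW (C <= 4, C % 4 != 0) -> NCHW4: zero the destination first, then
// scatter the C real channels into the first C slots of each 4-wide group.
void nchw_to_nchw4_ic_small(const float* src, float* dst, size_t N, size_t C,
                            size_t H, size_t W) {
    std::memset(dst, 0, N * 1 * H * W * 4 * sizeof(float));  // ceil(C/4) == 1
    for (size_t n = 0; n < N; ++n)
        for (size_t c = 0; c < C; ++c)
            for (size_t hw = 0; hw < H * W; ++hw)
                dst[(n * H * W + hw) * 4 + c] = src[(n * C + c) * H * W + hw];
}
```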
@@ -79,6 +79,7 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src,
         }
         cb(Float32, dt_float32);
+        cb(QuantizedS8, dt_qint8);
         default:
             megdnn_assert(0);
 #undef cb
@@ -138,7 +139,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
             return n * c * h * w * src.dtype.size();
         }
         case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: {
-            megdnn_assert(src.ndim == 4, "src must be oihw ,nmdim == 5");
+            megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4");
             megdnn_assert(src[0] % 8 == 0,
                           "NCHW_NCHW88_CONV_DENSE_WEIGHT oc must align to 8");
             if (src[1] % 8 == 0)
@@ -150,7 +151,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
             return oc * ic * h * w * src.dtype.size();
         }
         case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: {
-            megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5");
+            megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5");
             megdnn_assert(src[1] % 8 == 0,
                           "NCHW_NCHW88_CONV_GROUP_WEIGHT oc per group must "
                           "align to 8");
@@ -164,7 +165,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
             return group * ocpg * icpg * h * w * src.dtype.size();
         }
         case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: {
-            megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5");
+            megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5");
             if (src[0] % 8 == 0)
                 return 0;
             size_t group = round_up(src[0], 8_z);
@@ -174,6 +175,27 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
             size_t w = src[4];
             return group * ocpg * icpg * h * w * src.dtype.size();
         }
+        case Param::Mode::NCHW_NCHW4_IC_SMALL: {
+            if (src[1] % 4 == 0)
+                return 0;
+            size_t n = src[0];
+            size_t c = round_up(src[1], 4_z);
+            size_t h = src[2];
+            size_t w = src[3];
+            return n * c * h * w * src.dtype.size();
+        }
+        case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: {
+            megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4");
+            if (src[1] % 4 == 0)
+                return 0;
+            size_t oc = src[0];
+            size_t ic = round_up(src[1], 4_z);
+            size_t h = src[2];
+            size_t w = src[3];
+            return oc * ic * h * w * src.dtype.size();
+        }
         default:
             return 0;
     }
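So for the IC_SMALL modes the workspace is needed only when C is not already a multiple of 4, and it holds the zero-padded NCHW copy: N * round_up(C, 4) * H * W elements. A sketch of the same arithmetic with a worked example:

```cpp
#include <cstddef>

// Workspace size for NCHW_NCHW4_IC_SMALL, mirroring the case above:
// zero when C is already a multiple of 4, otherwise a full padded copy.
size_t ic_small_workspace_bytes(size_t N, size_t C, size_t H, size_t W,
                                size_t dtype_size) {
    if (C % 4 == 0)
        return 0;
    size_t c = (C + 3) / 4 * 4;  // round_up(C, 4)
    return N * c * H * W * dtype_size;
}
// e.g. N=2, C=3, H=35, W=36, float32: 2 * 4 * 35 * 36 * 4 = 40320 bytes.
```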
@@ -244,31 +266,28 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
             exec_src_nd.raw_ptr = workspace.raw_ptr;
         }
     } else if (param().mode == Param::Mode::NCHW_NCHW88) {
-        size_t ic = src.layout[1];
-        if (ic % 8 != 0) {
-            padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8);
-            exec_src_nd.raw_ptr = workspace.raw_ptr;
-        }
+#define cb(_idx, _pack_size)                                             \
+    size_t val = src.layout[_idx];                                       \
+    if (val % _pack_size != 0) {                                         \
+        padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx,   \
+                             _pack_size);                                \
+        exec_src_nd.raw_ptr = workspace.raw_ptr;                         \
+    }
+        cb(1, 8);
     } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) {
         megdnn_assert(src.layout[0] % 8 == 0);
-        size_t ic = src.layout[1];
-        if (ic % 8 != 0) {
-            padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8_z);
-            exec_src_nd.raw_ptr = workspace.raw_ptr;
-        }
+        cb(1, 8);
     } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) {
-        size_t group = src.layout[0];
-        if (group % 8 != 0) {
-            padding_to_workspace({workspace.raw_ptr, exec_src}, src, 0, 8_z);
-            exec_src_nd.raw_ptr = workspace.raw_ptr;
-        }
+        cb(0, 8);
     } else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) {
         megdnn_assert(src.layout[1] % 8 == 0);
-        size_t ic = src.layout[2];
-        if (ic % 8 != 0) {
-            padding_to_workspace({workspace.raw_ptr, exec_src}, src, 2, 8_z);
-            exec_src_nd.raw_ptr = workspace.raw_ptr;
-        }
+        cb(2, 8);
+    } else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) {
+        cb(1, 4);
+    } else if (param().mode ==
+               Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) {
+        cb(1, 4);
     }
     m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
 }
...
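The cb macro simply folds the repeated pad-into-workspace pattern into one place; each expansion sits inside its own else-if block, so the local val it declares never collides. Expanded by hand, cb(1, 4) in the NCHW_NCHW4_IC_SMALL branch reads as follows (a manual expansion for illustration, not extra source code):

```cpp
// cb(1, 4) expanded inside the NCHW_NCHW4_IC_SMALL branch:
size_t val = src.layout[1];  // channel count C
if (val % 4 != 0) {          // only pad when C isn't a multiple of 4
    padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 4);
    exec_src_nd.raw_ptr = workspace.raw_ptr;  // read from the padded copy
}
```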
@@ -8,6 +8,7 @@
  * software distributed under the License is distributed on an
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
+#include "megdnn/dtype.h"
 #include "megdnn/oprs.h"
 #include "test/common/checker.h"
 #include "test/common/rng.h"
@@ -30,4 +31,25 @@ TEST_F(CUDA, RELAYOUT_FORMAT) {
     checker.execs({{22, 23, 24, 25, 4}, {}});
 }
+
+TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4) {
+    Checker<RelayoutFormat> checker(handle_cuda());
+    UniformIntRNG rng{-50, 50};
+    param::RelayoutFormat param;
+    for (DType dtype :
+         std::vector<DType>({dtype::QuantizedS8{0.1f}, dtype::Float32{}})) {
+        checker.set_dtype(0, dtype).set_rng(0, &rng);
+        // reset the mode each iteration so both dtypes cover both modes
+        param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL;
+        checker.set_param(param).execs({{2, 4, 35, 36}, {}});
+        checker.set_param(param).execs({{2, 3, 35, 36}, {}});
+        checker.set_param(param).execs({{2, 1, 35, 36}, {}});
+        param.mode = param::RelayoutFormat::Mode::
+                NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT;
+        checker.set_param(param).execs({{4, 3, 3, 3}, {}});
+        checker.set_param(param).execs({{4, 4, 3, 3}, {}});
+        checker.set_param(param).execs({{1, 4, 3, 3}, {}});
+    }
+}
 // vim: syntax=cpp.doxygen