提交 28d85838 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(dnn): add relayout_format for nchw to nchw4 and ic <=4

GitOrigin-RevId: 07f2ee6c5be69adaf796d4073f16f88241480330
上级 3a53872f
......@@ -858,7 +858,10 @@ when the ``I`` suffix is present.
'NCHW_NCHW88_CONV_CHAN_WEIGHT',
'NCHW_NCHW88_CONV_GROUP_WEIGHT',
'NCHW_NCHW88',
'NCHW88_NCHW')
'NCHW88_NCHW',
'NCHW_NCHW4_IC_SMALL',
'NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT',
)
)
......
......@@ -28,6 +28,26 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src,
dst[3] = src[3];
dst[4] = 4;
break;
case Param::Mode::NCHW_NCHW4_IC_SMALL:
dst.ndim = 5;
megdnn_assert(src[1] <= 4_z, "ic should be less equal 4");
dst[0] = src[0];
dst[1] = div_ceil(src[1], 4_z);
dst[2] = src[2];
dst[3] = src[3];
dst[4] = 4;
break;
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 4");
megdnn_assert(src[1] <= 4_z, "ic should be less equal 4");
dst.ndim = 5;
dst[0] = src[0];
dst[1] = div_ceil(src[1], 4_z);
dst[2] = src[2];
dst[3] = src[3];
dst[4] = 4;
break;
case Param::Mode::NCHW_NCHW88:
dst.ndim = 5;
dst[0] = src[0];
......@@ -276,6 +296,8 @@ void RelayoutFormat::deduce_format(TensorFormat src, TensorFormat& dst) {
case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT:
case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT:
case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT:
case Param::Mode::NCHW_NCHW4_IC_SMALL:
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
CHECK_SRC(DefaultTensorFormat::make());
dst = src;
break;
......@@ -374,6 +396,23 @@ void RelayoutFormat::deduce_exec_layout(const TensorLayout& src,
exec_dst = dst;
}
break;
case Param::Mode::NCHW_NCHW4_IC_SMALL:
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT:
// nchw to nchw4c or oihw to oihw4i
{
TensorLayout work_space_layout(
{src[0], round_up(src[1], 4_z), src[2], src[3]},
src.dtype, src.format);
exec_src = work_space_layout
.reshape({src[0], div_ceil(src[1], 4_z), 4,
src[2], src[3]})
.dimshuffle({0, 1, 3, 4, 2});
exec_dst = dst;
}
break;
case Param::Mode::NCHW_NHWCD4:
case Param::Mode::NCHW_NHWCD4I:
// src is {N, C, H, W}
......
......@@ -11,6 +11,7 @@
#include "src/cuda/relayout_format/opr_impl.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
......@@ -20,15 +21,22 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
auto src_dtype = src.layout.dtype;
megdnn_assert(
param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 ||
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4,
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 ||
param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL ||
param().mode ==
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT,
"relayout format of cuda only support NCHW4->CHWN4 or "
"CHWN4->NCHW4");
if (src_dtype.enumv() == DTypeEnum::QuantizedS8) {
"CHWN4->NCHW4 or NCHW->NCHW4");
if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 ||
param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4) &&
src_dtype.enumv() == DTypeEnum::QuantizedS8) {
size_t row = 0, col = 0;
if (param().mode == Param::RelayoutFormat::Mode::NCHW4_CHWN4) {
row = src.layout[0],
col = src.layout[1] * src.layout[2] * src.layout[3];
} else {
megdnn_assert(param().mode ==
param::RelayoutFormat::Mode::CHWN4_NCHW4);
row = src.layout[0] * src.layout[1] * src.layout[2],
col = src.layout[3];
}
......@@ -43,6 +51,27 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
return handle()->create_operator<RelayoutForward>()->exec(trans_in,
trans_out);
}
if ((param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL ||
param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) &&
src.layout[1] % 4 != 0) {
megdnn_assert(src.raw_ptr != dst.raw_ptr && src.layout.ndim == 4,
"The mode of NCHW_NCHW4 and NCHW_NCHW4_CONV_DENSE_WEIGHT "
"of RelayoutFormat opr(cuda backend) does not support "
"src.ptr == dst.ptr");
megdnn_assert(src.layout[1] <= 4);
cuda_check(cudaMemsetAsync(dst.raw_ptr, 0,
dst.layout.span().dist_byte(),
cuda_stream(this->handle())));
TensorLayout exec_dst_layout = dst.layout;
exec_dst_layout[4] = src.layout[1];
TensorLayout exec_src_layout =
src.layout
.reshape({src.layout[0], src.layout[1], 1,
src.layout[2], src.layout[3]})
.dimshuffle({0, 2, 3, 4, 1});
return handle()->create_operator<RelayoutForward>()->exec(
{src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout});
}
TensorLayout exec_src, exec_dst;
deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst);
TensorND exec_src_nd{src.raw_ptr, exec_src};
......
......@@ -79,6 +79,7 @@ void padding_to_workspace(_megdnn_tensor_out dst, _megdnn_tensor_in src,
}
cb(Float32, dt_float32);
cb(QuantizedS8, dt_qint8);
default:
megdnn_assert(0);
#undef cb
......@@ -138,7 +139,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
return n * c * h * w * src.dtype.size();
}
case Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT: {
megdnn_assert(src.ndim == 4, "src must be oihw ,nmdim == 5");
megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5");
megdnn_assert(src[0] % 8 == 0,
"NCHW_NCHW88_CONV_DENSE_WEIGHT oc must align to 8");
if (src[1] % 8 == 0)
......@@ -150,7 +151,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
return oc * ic * h * w * src.dtype.size();
}
case Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT: {
megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5");
megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5");
megdnn_assert(src[1] % 8 == 0,
"NCHW_NCHW88_CONV_CHAN_WEIGHT oc per group must "
"align to 8");
......@@ -164,7 +165,7 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
return group * ocpg * icpg * h * w * src.dtype.size();
}
case Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT: {
megdnn_assert(src.ndim == 5, "src must be goihw ,nmdim == 5");
megdnn_assert(src.ndim == 5, "src must be goihw, ndim == 5");
if (src[0] % 8 == 0)
return 0;
size_t group = round_up(src[0], 8_z);
......@@ -174,6 +175,27 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(const TensorLayout& src,
size_t w = src[4];
return group * ocpg * icpg * h * w * src.dtype.size();
}
case Param::Mode::NCHW_NCHW4_IC_SMALL: {
if (src[1] % 4 == 0)
return 0;
size_t n = src[0];
size_t c = round_up(src[1], 4_z);
size_t h = src[2];
size_t w = src[3];
return n * c * h * w * src.dtype.size();
}
case Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT: {
megdnn_assert(src.ndim == 4, "src must be oihw, ndim == 5");
if (src[1] % 4 == 0)
return 0;
size_t oc = src[0];
size_t ic = round_up(src[1], 4_z);
size_t h = src[2];
size_t w = src[3];
return oc * ic * h * w * src.dtype.size();
}
default:
return 0;
}
......@@ -244,31 +266,28 @@ void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
exec_src_nd.raw_ptr = workspace.raw_ptr;
}
} else if (param().mode == Param::Mode::NCHW_NCHW88) {
size_t ic = src.layout[1];
if (ic % 8 != 0) {
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8);
exec_src_nd.raw_ptr = workspace.raw_ptr;
}
#define cb(_idx, _pack_size) \
size_t val = src.layout[_idx]; \
if (val % _pack_size != 0) { \
padding_to_workspace({workspace.raw_ptr, exec_src}, src, _idx, \
_pack_size); \
exec_src_nd.raw_ptr = workspace.raw_ptr; \
}
cb(1, 8);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_DENSE_WEIGHT) {
megdnn_assert(src.layout[0] % 8 == 0);
size_t ic = src.layout[1];
if (ic % 8 != 0) {
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 1, 8_z);
exec_src_nd.raw_ptr = workspace.raw_ptr;
}
cb(1, 8);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_CHAN_WEIGHT) {
size_t group = src.layout[0];
if (group % 8 != 0) {
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 0, 8_z);
exec_src_nd.raw_ptr = workspace.raw_ptr;
}
cb(0, 8);
} else if (param().mode == Param::Mode::NCHW_NCHW88_CONV_GROUP_WEIGHT) {
megdnn_assert(src.layout[1] % 8 == 0);
size_t ic = src.layout[2];
if (ic % 8 != 0) {
padding_to_workspace({workspace.raw_ptr, exec_src}, src, 2, 8_z);
exec_src_nd.raw_ptr = workspace.raw_ptr;
}
cb(2, 8);
} else if (param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL) {
cb(1, 4);
} else if (param().mode ==
Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) {
cb(1, 4);
}
m_handle->relayout_opr()->exec(exec_src_nd, exec_dst_nd, handle());
}
......
......@@ -8,6 +8,7 @@
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
......@@ -30,4 +31,25 @@ TEST_F(CUDA, RELAYOUT_FORMAT) {
checker.execs({{22, 23, 24, 25, 4}, {}});
}
TEST_F(CUDA, RELAYOUT_FORMAT_NCHW4) {
Checker<RelayoutFormat> checker(handle_cuda());
UniformIntRNG rng{-50, 50};
param::RelayoutFormat param;
param.mode = param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL;
for (DType dtype :
std::vector<DType>({dtype::QuantizedS8{0.1f}, dtype::Float32{}})) {
checker.set_dtype(0, dtype).set_rng(0, &rng);
checker.set_param(param).execs({{2, 4, 35, 36}, {}});
checker.set_param(param).execs({{2, 3, 35, 36}, {}});
checker.set_param(param).execs({{2, 1, 35, 36}, {}});
param.mode = param::RelayoutFormat::Mode::
NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT;
checker.set_param(param).execs({{4, 3, 3, 3}, {}});
checker.set_param(param).execs({{4, 4, 3, 3}, {}});
checker.set_param(param).execs({{1, 4, 3, 3}, {}});
}
}
// vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册