提交 05734355 编写于 作者: M Megvii Engine Team

fix(opencl): support OpenCL NHWC_NHWCD4I when the channel count is

not a multiple of 4

GitOrigin-RevId: 1e85cfc8589219a5339cd4a3b60fe50a7f20230a
上级 485e56ca
...@@ -91,7 +91,6 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, TensorLayout& ds ...@@ -91,7 +91,6 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, TensorLayout& ds
dst[6] = 8; dst[6] = 8;
break; break;
case Param::Mode::NHWC_NHWCD4: case Param::Mode::NHWC_NHWCD4:
case Param::Mode::NHWC_NHWCD4I:
megdnn_assert(src.ndim == 4); megdnn_assert(src.ndim == 4);
//! channel mod 4 should == 0 //! channel mod 4 should == 0
megdnn_assert(src[3] % 4 == 0); megdnn_assert(src[3] % 4 == 0);
...@@ -102,6 +101,15 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, TensorLayout& ds ...@@ -102,6 +101,15 @@ void RelayoutFormat::deduce_layout_fwd(const TensorLayout& src, TensorLayout& ds
dst[3] = src[2]; dst[3] = src[2];
dst[4] = 4; dst[4] = 4;
break; break;
case Param::Mode::NHWC_NHWCD4I:
megdnn_assert(src.ndim == 4);
dst.ndim = 5;
dst[0] = src[0];
dst[1] = src[1];
dst[2] = (src[3] + 3) / 4;
dst[3] = src[2];
dst[4] = 4;
break;
case Param::Mode::NHWCD4I_NHWC: case Param::Mode::NHWCD4I_NHWC:
case Param::Mode::NHWCD4_NHWC: case Param::Mode::NHWCD4_NHWC:
megdnn_assert(src.ndim == 5); megdnn_assert(src.ndim == 5);
...@@ -587,13 +595,24 @@ void RelayoutFormat::deduce_exec_layout( ...@@ -587,13 +595,24 @@ void RelayoutFormat::deduce_exec_layout(
exec_dst = dst; exec_dst = dst;
break; break;
case Param::Mode::NHWC_NHWCD4: case Param::Mode::NHWC_NHWCD4:
case Param::Mode::NHWC_NHWCD4I:
// src is {N, H, W, C}, // src is {N, H, W, C},
// dst is {N, H, CB, W, 4} // dst is {N, H, CB, W, 4}
exec_src = src.reshape({src[0], src[1], src[2], src[3] / 4, 4}) exec_src = src.reshape({src[0], src[1], src[2], src[3] / 4, 4})
.dimshuffle({0, 1, 3, 2, 4}); .dimshuffle({0, 1, 3, 2, 4});
exec_dst = dst; exec_dst = dst;
break; break;
case Param::Mode::NHWC_NHWCD4I:
// src is {N, H, W, C},
// dst is {N, H, CB, W, 4}
exec_src = src;
exec_src[3] = (exec_src[3] + 3) / 4 * 4;
exec_src.stride[2] = exec_src[3] * exec_src.stride[3];
exec_src.stride[1] = exec_src[2] * exec_src.stride[2];
exec_src.stride[0] = exec_src[1] * exec_src.stride[1];
exec_src = exec_src.reshape({src[0], src[1], src[2], (src[3] + 3) / 4, 4})
.dimshuffle({0, 1, 3, 2, 4});
exec_dst = dst;
break;
case Param::Mode::NHWCD4I_NHWC: case Param::Mode::NHWCD4I_NHWC:
case Param::Mode::NHWCD4_NHWC: case Param::Mode::NHWCD4_NHWC:
// src is {N, H, CB, W, 4} // src is {N, H, CB, W, 4}
......
...@@ -134,6 +134,25 @@ void padding_src_to_workspace( ...@@ -134,6 +134,25 @@ void padding_src_to_workspace(
} }
} }
/*!
 * \brief Copy an NHWC buffer into \p dptr while rounding the channel count up
 *      to the next multiple of 4; the added trailing channels are set to 0.
 *
 * \param dptr destination buffer; N * IH * IW * ((IC + 3) / 4 * 4) elements
 *      are written consecutively
 * \param sptr source buffer, read as a densely packed {N, IH, IW, IC} array
 * \param N batch size
 * \param IH input height
 * \param IW input width
 * \param IC number of source channels (need not be a multiple of 4)
 */
template <typename dtype>
void padding_nhwc_src_to_workspace(
        dtype* dptr, const dtype* sptr, size_t N, size_t IH, size_t IW, size_t IC) {
    const size_t padded_ic = (IC + 3) / 4 * 4;
    const size_t pixels_per_batch = IH * IW;
    // The source is dense, so walking it with a running cursor visits exactly
    // the elements at offset (n * IH * IW + pixel) * IC + channel, in order.
    const dtype* src_cursor = sptr;
    dtype* dst_cursor = dptr;
    for (size_t batch = 0; batch < N; ++batch) {
        for (size_t pixel = 0; pixel < pixels_per_batch; ++pixel) {
            size_t channel = 0;
            // real channels: straight copy
            for (; channel < IC; ++channel) {
                *dst_cursor = *src_cursor;
                ++dst_cursor;
                ++src_cursor;
            }
            // padding channels: zero fill up to the 4-aligned channel count
            for (; channel < padded_ic; ++channel) {
                *dst_cursor = 0;
                ++dst_cursor;
            }
        }
    }
}
template <typename dtype> template <typename dtype>
void padding_to_workspace( void padding_to_workspace(
dtype* dptr, const dtype* sptr, const TensorLayout& src_layout, dtype* dptr, const dtype* sptr, const TensorLayout& src_layout,
...@@ -318,6 +337,15 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes( ...@@ -318,6 +337,15 @@ size_t RelayoutFormatImpl::get_workspace_in_bytes(
size_t IW = src[3]; size_t IW = src[3];
return N * IC4 * IH * IW * src.dtype.size(); return N * IC4 * IH * IW * src.dtype.size();
} }
case Param::Mode::NHWC_NHWCD4I: {
if (src[3] % 4 == 0)
return 0;
size_t IC4 = dst[2] * 4;
size_t N = src[0];
size_t IH = src[1];
size_t IW = src[2];
return N * IC4 * IH * IW * src.dtype.size();
}
case Param::Mode::INTER_WEIGHT_DENSEI_DOT: { case Param::Mode::INTER_WEIGHT_DENSEI_DOT: {
if (src[1] % 4 == 0) if (src[1] % 4 == 0)
return 0; return 0;
...@@ -509,7 +537,43 @@ void RelayoutFormatImpl::exec( ...@@ -509,7 +537,43 @@ void RelayoutFormatImpl::exec(
cb(Uint8, dt_uint8); cb(Uint8, dt_uint8);
#undef cb #undef cb
default: default:
megdnn_assert(0); megdnn_assert(
0, "NCHW_NHWCD4I not support dtype %s",
src.layout.dtype.name());
}
exec_src_nd = TensorND{workspace.raw_ptr, exec_src_nd.layout};
}
} else if (param().mode == Param::Mode::NHWC_NHWCD4I) {
size_t N = src.layout[0];
size_t IC = src.layout[3];
size_t IH = src.layout[1];
size_t IW = src.layout[2];
//! ic % 4 != 0
if ((IC & 0x3)) {
switch (src.layout.dtype.enumv()) {
#define cb(name, ctype) \
case (DTypeEnum::name): { \
MIDOUT_BEGIN( \
megdnn_naive_relayout_format, ctype, \
midout_iv(Param::Mode::NHWC_NHWCD4I)) { \
MEGDNN_DISPATCH_CPU_KERN( \
m_handle, padding_nhwc_src_to_workspace<ctype>( \
workspace.ptr<ctype>(), \
src.compatible_ptr<ctype>(), N, IH, IW, IC);); \
} \
MIDOUT_END(); \
break; \
}
cb(Float32, dt_float32);
DNN_INC_FLOAT16(cb(Float16, dt_float16));
cb(Quantized8Asymm, dt_uint8);
cb(QuantizedS8, dt_int8);
cb(Uint8, dt_uint8);
#undef cb
default:
megdnn_assert(
0, "NHWC_NHWCD4I not support dtype %s",
src.layout.dtype.name());
} }
exec_src_nd = TensorND{workspace.raw_ptr, exec_src_nd.layout}; exec_src_nd = TensorND{workspace.raw_ptr, exec_src_nd.layout};
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册