提交 8ba8c11d 编写于 作者: M Megvii Engine Team

feat(dnn): add nchw44 layout

GitOrigin-RevId: d92672b88a48a2de396532ccbc6bd7e467d5eab9
上级 a744b3cb
...@@ -35,9 +35,10 @@ pdef('Axis').add_fields('int32', 'axis', 0) ...@@ -35,9 +35,10 @@ pdef('Axis').add_fields('int32', 'axis', 0)
). ).
add_enum(Doc('Format', 'convolution data/filter/output format; see ' add_enum(Doc('Format', 'convolution data/filter/output format; see '
':class:`RelayoutFormat` for more details'), ':class:`RelayoutFormat` for more details'),
'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88', 'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88', 'NCHW44',
Doc('NCHW_WINOGRAD', 'NCHW layout with weights transformed by winograd'), Doc('NCHW_WINOGRAD', 'NCHW layout with weights transformed by winograd'),
Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights transformed by winograd'), Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights transformed by winograd'),
Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights transformed by winograd'),
Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.')) 'of convolution using CUDA/SASS. The channels are split into groups of 4 channels.'))
) )
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/common/conv_bias.h" #include "src/common/conv_bias.h"
...@@ -33,7 +34,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( ...@@ -33,7 +34,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
const TensorLayout& bias, const TensorLayout& z, const TensorLayout& bias, const TensorLayout& z,
const TensorLayout& dst, size_t workspace_in_bytes) { const TensorLayout& dst, size_t workspace_in_bytes) {
if ((param().format == param::ConvBias::Format::NCHW_WINOGRAD || if ((param().format == param::ConvBias::Format::NCHW_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD) && param().format == param::ConvBias::Format::NCHW88_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) &&
src.dtype.category() == DTypeCategory::QUANTIZED) { src.dtype.category() == DTypeCategory::QUANTIZED) {
megdnn_assert(filter.dtype.enumv() == DTypeEnum::QuantizedS16); megdnn_assert(filter.dtype.enumv() == DTypeEnum::QuantizedS16);
megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8 || megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8 ||
...@@ -45,7 +47,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( ...@@ -45,7 +47,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
float scale_src = src.dtype.param<dtype::QuantizedS8>().scale; float scale_src = src.dtype.param<dtype::QuantizedS8>().scale;
float scale_filter = 0.f; float scale_filter = 0.f;
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || if (param().format == param::ConvBias::Format::NCHW_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD) { param().format == param::ConvBias::Format::NCHW88_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) {
scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale; scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale;
} else { } else {
scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale; scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale;
...@@ -58,7 +61,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( ...@@ -58,7 +61,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
float scale_src = src.dtype.param<dtype::Quantized8Asymm>().scale; float scale_src = src.dtype.param<dtype::Quantized8Asymm>().scale;
float scale_filter = 0.f; float scale_filter = 0.f;
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || if (param().format == param::ConvBias::Format::NCHW_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD) { param().format == param::ConvBias::Format::NCHW88_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) {
scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale; scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale;
} else { } else {
scale_filter = filter.dtype.param<dtype::Quantized8Asymm>().scale; scale_filter = filter.dtype.param<dtype::Quantized8Asymm>().scale;
...@@ -98,7 +102,9 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( ...@@ -98,7 +102,9 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[2] == 1);
megdnn_assert(bias.shape[3] == dst.shape[3], "bias:%s, dst:%s", megdnn_assert(bias.shape[3] == dst.shape[3], "bias:%s, dst:%s",
bias.to_string().c_str(), dst.to_string().c_str()); bias.to_string().c_str(), dst.to_string().c_str());
} else if (param().format == param::ConvBias::Format::NCHW4) { } else if (param().format == param::ConvBias::Format::NCHW4 ||
param().format == param::ConvBias::Format::NCHW44 ||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) {
megdnn_assert(bias.shape[0] == 1); megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
bias.to_string().c_str(), dst.to_string().c_str()); bias.to_string().c_str(), dst.to_string().c_str());
...@@ -141,7 +147,10 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( ...@@ -141,7 +147,10 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
if (z.ndim != 0) { if (z.ndim != 0) {
megdnn_assert(param().format != param::ConvBias::Format::NCHW_WINOGRAD); megdnn_assert(param().format != param::ConvBias::Format::NCHW_WINOGRAD);
megdnn_assert(param().format != param::ConvBias::Format::NCHW88_WINOGRAD); megdnn_assert(param().format !=
param::ConvBias::Format::NCHW88_WINOGRAD);
megdnn_assert(param().format !=
param::ConvBias::Format::NCHW44_WINOGRAD);
megdnn_assert(z.dtype.enumv() == dst.dtype.enumv()); megdnn_assert(z.dtype.enumv() == dst.dtype.enumv());
megdnn_assert(z.eq_shape(dst)); megdnn_assert(z.eq_shape(dst));
} }
...@@ -163,10 +172,7 @@ std::string ConvBias::algo_name(const std::string& base, const T& p) { ...@@ -163,10 +172,7 @@ std::string ConvBias::algo_name(const std::string& base, const T& p) {
} }
#define FOREACH_CONV_BIAS_PARAM(cb) \ #define FOREACH_CONV_BIAS_PARAM(cb) \
cb(WinogradParam) \ cb(WinogradParam) cb(DirectParam) cb(MatmulParam) cb(DefaultParam)
cb(DirectParam) \
cb(MatmulParam) \
cb(DefaultParam)
#define cb(pt) \ #define cb(pt) \
template <> \ template <> \
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "megdnn/oprs/nn.h" #include "megdnn/oprs/nn.h"
...@@ -55,7 +56,13 @@ spatial_getter<param::ConvBias, param::ConvBias::Format::NCHW88_WINOGRAD>( ...@@ -55,7 +56,13 @@ spatial_getter<param::ConvBias, param::ConvBias::Format::NCHW88_WINOGRAD>(
//! f = m + r - 1 -> r = f + 1 - m //! f = m + r - 1 -> r = f + 1 - m
return filter - param.output_block_size + 1; return filter - param.output_block_size + 1;
} }
template <>
uint32_t
spatial_getter<param::ConvBias, param::ConvBias::Format::NCHW44_WINOGRAD>(
uint32_t filter, const param::ConvBias& param) {
//! f = m + r - 1 -> r = f + 1 - m
return filter - param.output_block_size + 1;
}
template <typename Parameter, typename Param> template <typename Parameter, typename Param>
void make_canonized_filter_meta_nchw_nhwc( void make_canonized_filter_meta_nchw_nhwc(
...@@ -273,7 +280,7 @@ void make_canonized_filter_meta_nchwxx( ...@@ -273,7 +280,7 @@ void make_canonized_filter_meta_nchwxx(
/** /**
* input: N IC/pack_size, H, W, pack_size * input: N IC/pack_size, H, W, pack_size
* *
* NCHW88 mode * NCHW88 and NCHW44 mode
* filter: * filter:
* {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)} * {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)}
* [dense] * [dense]
...@@ -281,7 +288,7 @@ void make_canonized_filter_meta_nchwxx( ...@@ -281,7 +288,7 @@ void make_canonized_filter_meta_nchwxx(
* FH, FW, pack_size(IC), pack_size(OC)} [group] * FH, FW, pack_size(IC), pack_size(OC)} [group]
* {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan] * {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan]
* *
** NCHW88_WINOGRAD mode ** NCHW88_WINOGRAD and NCHW44_WINOGRAD mode
* filter: * filter:
* {alpha, alpha, OC/pack_size, IC/pack_size, pack_size(IC), * {alpha, alpha, OC/pack_size, IC/pack_size, pack_size(IC),
*pack_size(OC)} [dense] *pack_size(OC)} [dense]
...@@ -291,6 +298,7 @@ void make_canonized_filter_meta_nchwxx( ...@@ -291,6 +298,7 @@ void make_canonized_filter_meta_nchwxx(
*/ */
megdnn_assert(param.format == Param::Format::NCHW88 || megdnn_assert(param.format == Param::Format::NCHW88 ||
param.format == Param::Format::NCHW44 ||
param.format == Param::Format::NCHW88_WINOGRAD); param.format == Param::Format::NCHW88_WINOGRAD);
size_t img_ndim = 2; size_t img_ndim = 2;
size_t flt_start = 0; size_t flt_start = 0;
...@@ -305,7 +313,8 @@ void make_canonized_filter_meta_nchwxx( ...@@ -305,7 +313,8 @@ void make_canonized_filter_meta_nchwxx(
filter[filter.ndim - 1]); filter[filter.ndim - 1]);
ret.group = 1; ret.group = 1;
flt_start = 0; flt_start = 0;
if (param.format == Param::Format::NCHW88_WINOGRAD) { if (param.format == Param::Format::NCHW88_WINOGRAD ||
param.format == Param::Format::NCHW44_WINOGRAD) {
flt_start = 2; flt_start = 2;
} }
ret.ocpg = filter[flt_start] * pack_size; ret.ocpg = filter[flt_start] * pack_size;
...@@ -314,6 +323,8 @@ void make_canonized_filter_meta_nchwxx( ...@@ -314,6 +323,8 @@ void make_canonized_filter_meta_nchwxx(
// ohwi8o // ohwi8o
megdnn_assert(param.format != Param::Format::NCHW88_WINOGRAD, megdnn_assert(param.format != Param::Format::NCHW88_WINOGRAD,
"Hybrid nchw88 mode does not support winograd"); "Hybrid nchw88 mode does not support winograd");
megdnn_assert(param.format != Param::Format::NCHW44_WINOGRAD,
"Hybrid nchw44 mode does not support winograd");
flt_start = 0; flt_start = 0;
flt_spatial_start = 1; flt_spatial_start = 1;
ret.group = 1; ret.group = 1;
...@@ -321,20 +332,22 @@ void make_canonized_filter_meta_nchwxx( ...@@ -321,20 +332,22 @@ void make_canonized_filter_meta_nchwxx(
ret.icpg = filter[flt_start + 3]; ret.icpg = filter[flt_start + 3];
} else { } else {
megdnn_assert(0, "not support nchw88 filter dim = %zu", megdnn_assert(0, "not support nchwxx filter dim = %zu",
filter.ndim); filter.ndim);
} }
} else { } else {
megdnn_assert(param.sparse == Param::Sparse::GROUP, megdnn_assert(param.sparse == Param::Sparse::GROUP,
"invalid convolution sparse type"); "invalid convolution sparse type");
flt_start = 1; flt_start = 1;
if (param.format == Param::Format::NCHW88_WINOGRAD) { if (param.format == Param::Format::NCHW88_WINOGRAD ||
param.format == Param::Format::NCHW44_WINOGRAD) {
flt_start = 3; flt_start = 3;
} }
auto filter_oc = filter[flt_start]; auto filter_oc = filter[flt_start];
auto filter_ic = filter[flt_start + 1]; auto filter_ic = filter[flt_start + 1];
if (filter_oc == 1 && filter_ic == 1 && filter.ndim == (img_ndim + 4) && if (filter_oc == 1 && filter_ic == 1 && filter.ndim == (img_ndim + 4) &&
param.format != Param::Format::NCHW88_WINOGRAD) { param.format != Param::Format::NCHW88_WINOGRAD &&
param.format != Param::Format::NCHW44_WINOGRAD) {
// Depthwise case goihw8g // Depthwise case goihw8g
megdnn_assert(filter.ndim == img_ndim + 4, megdnn_assert(filter.ndim == img_ndim + 4,
"bad filter ndim for group convolution: " "bad filter ndim for group convolution: "
...@@ -343,7 +356,7 @@ void make_canonized_filter_meta_nchwxx( ...@@ -343,7 +356,7 @@ void make_canonized_filter_meta_nchwxx(
megdnn_assert(filter[filter.ndim - 1] == pack_size, megdnn_assert(filter[filter.ndim - 1] == pack_size,
"last dim of filter must be %zu, but %zu", pack_size, "last dim of filter must be %zu, but %zu", pack_size,
filter[filter.ndim - 1]); filter[filter.ndim - 1]);
ret.group = filter[0] * 8; ret.group = filter[0] * pack_size;
ret.ocpg = filter_oc; ret.ocpg = filter_oc;
ret.icpg = filter_ic; ret.icpg = filter_ic;
...@@ -381,6 +394,10 @@ void make_canonized_filter_meta_nchwxx( ...@@ -381,6 +394,10 @@ void make_canonized_filter_meta_nchwxx(
ret.spatial[i] = ret.spatial[i] =
spatial_getter<Param, Param::Format::NCHW88_WINOGRAD>( spatial_getter<Param, Param::Format::NCHW88_WINOGRAD>(
filter[i + flt_start - 2], param); filter[i + flt_start - 2], param);
} else if (param.format == Param::Format::NCHW44_WINOGRAD) {
ret.spatial[i] =
spatial_getter<Param, Param::Format::NCHW44_WINOGRAD>(
filter[i + flt_start - 2], param);
} else { } else {
ret.spatial[i] = filter[i + flt_start + flt_spatial_start]; ret.spatial[i] = filter[i + flt_start + flt_spatial_start];
} }
...@@ -535,6 +552,10 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta( ...@@ -535,6 +552,10 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta(
param().format == Param::Format::NCHW88_WINOGRAD) { param().format == Param::Format::NCHW88_WINOGRAD) {
make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter,
param(), ret); param(), ret);
} else if (param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_WINOGRAD) {
make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter,
param(), ret);
} else if (param().format == Param::Format::NCHW32) { } else if (param().format == Param::Format::NCHW32) {
make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter,
param(), ret); param(), ret);
...@@ -629,18 +650,22 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, ...@@ -629,18 +650,22 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
} else { } else {
megdnn_assert(param().format == Param::Format::NHWCD4 || megdnn_assert(param().format == Param::Format::NHWCD4 ||
param().format == Param::Format::NCHW4 || param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW8 || param().format == Param::Format::NCHW8 ||
param().format == Param::Format::NCHW32 || param().format == Param::Format::NCHW32 ||
param().format == Param::Format::NCHW88 || param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW88_WINOGRAD || param().format == Param::Format::NCHW88_WINOGRAD ||
param().format == Param::Format::CHWN4); param().format == Param::Format::CHWN4);
img_dim = src.ndim - 3; img_dim = src.ndim - 3;
if (param().format == Param::Format::NCHW88 && filter.ndim == 5) { if ((param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW44) &&
filter.ndim == 5) {
img_dim = src.ndim - 2; img_dim = src.ndim - 2;
} }
megdnn_assert(filter.ndim == img_dim + 3 || megdnn_assert(filter.ndim == img_dim + 3 ||
(filter.ndim == img_dim + 2 && (filter.ndim == img_dim + 2 &&
param().format == Param::Format::NCHW88) || (param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW44)) ||
filter.ndim == img_dim + 4 || filter.ndim == img_dim + 4 ||
filter.ndim == img_dim + 5, filter.ndim == img_dim + 5,
"%s", errmsg().c_str()); "%s", errmsg().c_str());
...@@ -691,6 +716,21 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, ...@@ -691,6 +716,21 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
", and last shape two is 8 but got src %s, filter %s", ", and last shape two is 8 but got src %s, filter %s",
src.to_string().c_str(), filter.to_string().c_str()); src.to_string().c_str(), filter.to_string().c_str());
} }
if (param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_WINOGRAD) {
megdnn_assert((src.ndim == 4 && filter.ndim == 5 &&
filter[filter.ndim - 1] == 4) ||
(src.ndim == 5 &&
((filter.ndim == 6 &&
filter[filter.ndim - 1] == 4) ||
(filter.ndim == 7 &&
filter[filter.ndim - 1] == 4 &&
filter[filter.ndim - 2] == 4)) &&
src[src.ndim - 1] == 4),
"NCHW44 require src ndim is 5 and filter's ndim is 6 "
", and last shape two is 4 but got src %s, filter %s",
src.to_string().c_str(), filter.to_string().c_str());
}
if (param().format == Param::Format::CHWN4) { if (param().format == Param::Format::CHWN4) {
megdnn_assert( megdnn_assert(
src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) && src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
...@@ -808,6 +848,27 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, ...@@ -808,6 +848,27 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
cflt.group); cflt.group);
} }
} else if (param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_WINOGRAD) {
megdnn_assert(src.ndim == 5 || (src.ndim == 4 && src[1] <= 8),
"invalid src ndim for NCHW44, expected=5 or 4, got=%zu",
src.ndim);
dst.ndim = 5;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
megdnn_assert(oc % 4 == 0);
dst[1] = oc / 4;
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0],
cflt.stride[0], cflt.padding[0]);
dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
cflt.stride[1], cflt.padding[1]);
dst[4] = 4;
if (cflt.group == 1) {
megdnn_assert(cflt.icpg * cflt.group == src[1] * 4 ||
(cflt.icpg * cflt.group == src[1]),
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
cflt.group);
}
} else if (param().format == Param::Format::CHWN4) { } else if (param().format == Param::Format::CHWN4) {
megdnn_assert(src.ndim == 5, megdnn_assert(src.ndim == 5,
"invalid src ndim for CHWN4, expected=5, got=%zu", "invalid src ndim for CHWN4, expected=5, got=%zu",
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
...@@ -47,6 +48,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, ...@@ -47,6 +48,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
spatial_pos = 1; spatial_pos = 1;
c_pos = 3; c_pos = 3;
} else if (param().format == Param::Format::NCHW4 || } else if (param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW88 || param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW32) { param().format == Param::Format::NCHW32) {
megdnn_assert(src.ndim == 5_z, "%s", errmsg_c); megdnn_assert(src.ndim == 5_z, "%s", errmsg_c);
...@@ -73,6 +75,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, ...@@ -73,6 +75,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
iw = src[spatial_pos + 2]; iw = src[spatial_pos + 2];
} }
if (param().format == Param::Format::NCHW4 || if (param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::CHWN4) { param().format == Param::Format::CHWN4) {
c *= 4; c *= 4;
} }
...@@ -96,7 +99,8 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, ...@@ -96,7 +99,8 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src,
megdnn_assert(param().format == Param::Format::NHWC, megdnn_assert(param().format == Param::Format::NHWC,
"invalid pooling format"); "invalid pooling format");
dst = TensorLayout({n, oh, ow, c}, src.dtype, src.format); dst = TensorLayout({n, oh, ow, c}, src.dtype, src.format);
} else if (param().format == Param::Format::NCHW4) { } else if (param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44) {
dst = TensorLayout{{n, c / 4, oh, ow, 4}, src.dtype, src.format}; dst = TensorLayout{{n, c / 4, oh, ow, 4}, src.dtype, src.format};
} else if (param().format == Param::Format::NCHW88) { } else if (param().format == Param::Format::NCHW88) {
dst = TensorLayout{{n, c / 8, oh, ow, 8}, src.dtype, src.format}; dst = TensorLayout{{n, c / 8, oh, ow, 8}, src.dtype, src.format};
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/fallback/convolution/opr_impl.h" #include "src/fallback/convolution/opr_impl.h"
#include "src/common/algo_chooser.h" #include "src/common/algo_chooser.h"
...@@ -157,9 +158,11 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param( ...@@ -157,9 +158,11 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
if (param().format == Param::Format::NCHW88 || if (param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW8 || param().format == Param::Format::NCHW8 ||
param().format == Param::Format::NCHW4 || param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW || param().format == Param::Format::NCHW ||
param().format == Param::Format::NCHW_WINOGRAD || param().format == Param::Format::NCHW_WINOGRAD ||
param().format == Param::Format::NCHW88_WINOGRAD) { param().format == Param::Format::NCHW88_WINOGRAD ||
param().format == Param::Format::NCHW44_WINOGRAD) {
spatial_pos = 2; spatial_pos = 2;
} else if (param().format == Param::Format::NHWC) { } else if (param().format == Param::Format::NHWC) {
spatial_pos = 1; spatial_pos = 1;
...@@ -188,7 +191,8 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param( ...@@ -188,7 +191,8 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT; param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT;
if (param().format == Param::Format::NCHW_WINOGRAD || if (param().format == Param::Format::NCHW_WINOGRAD ||
param().format == Param::Format::NCHW88_WINOGRAD) { param().format == Param::Format::NCHW88_WINOGRAD ||
param().format == Param::Format::NCHW44_WINOGRAD) {
size_t flt_start = 0; size_t flt_start = 0;
if (param().sparse == Param::Sparse::GROUP) { if (param().sparse == Param::Sparse::GROUP) {
flt_start = 1; flt_start = 1;
...@@ -325,7 +329,7 @@ const char* ConvBiasImpl::get_algorithm_set_name() const { ...@@ -325,7 +329,7 @@ const char* ConvBiasImpl::get_algorithm_set_name() const {
return "F0"; return "F0";
} }
namespace megdnn{ namespace megdnn {
namespace fallback { namespace fallback {
template <typename T> template <typename T>
...@@ -342,7 +346,6 @@ const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_pack_id, ...@@ -342,7 +346,6 @@ const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_pack_id,
batch_offset + group_offset + channel_offset); batch_offset + group_offset + channel_offset);
} }
template <typename T> template <typename T>
const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id, const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id,
size_t pack_group_size) const { size_t pack_group_size) const {
...@@ -453,5 +456,4 @@ INST(void) ...@@ -453,5 +456,4 @@ INST(void)
} // namespace fallback } // namespace fallback
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once
...@@ -87,7 +88,9 @@ class ConvBias { ...@@ -87,7 +88,9 @@ class ConvBias {
if (param.filter_meta.format != if (param.filter_meta.format !=
param::ConvBias::Format::NCHW_WINOGRAD && param::ConvBias::Format::NCHW_WINOGRAD &&
param.filter_meta.format != param.filter_meta.format !=
param::ConvBias::Format::NCHW88_WINOGRAD) { param::ConvBias::Format::NCHW88_WINOGRAD &&
param.filter_meta.format !=
param::ConvBias::Format::NCHW44_WINOGRAD) {
filter_transform_buf_size = Strategy::ALPHA * Strategy::ALPHA * OC * filter_transform_buf_size = Strategy::ALPHA * Strategy::ALPHA * OC *
IC * sizeof(input_filter_compute_type); IC * sizeof(input_filter_compute_type);
} }
...@@ -95,7 +98,8 @@ class ConvBias { ...@@ -95,7 +98,8 @@ class ConvBias {
get_wbundle_compute(param, matmul_algo).total_size_in_bytes() * get_wbundle_compute(param, matmul_algo).total_size_in_bytes() *
nr_threads; nr_threads;
if (param.filter_meta.format == param::ConvBias::Format::NCHW || if (param.filter_meta.format == param::ConvBias::Format::NCHW ||
param.filter_meta.format == param::ConvBias::Format::NCHW88) { param.filter_meta.format == param::ConvBias::Format::NCHW88 ||
param.filter_meta.format == param::ConvBias::Format::NCHW44) {
return WorkspaceBundle( return WorkspaceBundle(
nullptr, nullptr,
{winograd_comput_size, filter_transform_buf_size * GROUP}); {winograd_comput_size, filter_transform_buf_size * GROUP});
...@@ -103,7 +107,9 @@ class ConvBias { ...@@ -103,7 +107,9 @@ class ConvBias {
megdnn_assert(param.filter_meta.format == megdnn_assert(param.filter_meta.format ==
param::ConvBias::Format::NCHW_WINOGRAD || param::ConvBias::Format::NCHW_WINOGRAD ||
param.filter_meta.format == param.filter_meta.format ==
param::ConvBias::Format::NCHW88_WINOGRAD); param::ConvBias::Format::NCHW88_WINOGRAD ||
param.filter_meta.format ==
param::ConvBias::Format::NCHW44_WINOGRAD);
return WorkspaceBundle(nullptr, {winograd_comput_size}); return WorkspaceBundle(nullptr, {winograd_comput_size});
} }
} }
...@@ -210,11 +216,17 @@ public: ...@@ -210,11 +216,17 @@ public:
reinterpret_cast<input_filter_compute_type*>( reinterpret_cast<input_filter_compute_type*>(
reinterpret_cast<uintptr_t>(bundle_compute.get(2)) + reinterpret_cast<uintptr_t>(bundle_compute.get(2)) +
compute_workspace_size_per_thread * thread_id); compute_workspace_size_per_thread * thread_id);
const stype* filter_ptr = kern_param.filter<stype>(group_id); const stype* filter_ptr = kern_param.filter<stype>(group_id);
size_t oc_start = oc_id, oc_end = oc_id+1; size_t oc_start = oc_id, oc_end = oc_id + 1;
if (kern_param.filter_meta.format == param::ConvBias::Format::NCHW88) { if (kern_param.filter_meta.format == param::ConvBias::Format::NCHW88) {
oc_start = 8 * oc_id; oc_start = 8 * oc_id;
oc_end = oc_start + 8; oc_end = oc_start + 8;
} else if (kern_param.filter_meta.format ==
param::ConvBias::Format::NCHW44) {
oc_start = 4 * oc_id;
oc_end = oc_start + 4;
} }
strategy.filter(filter_ptr, filter_transform_buf, transform_mid_buf, OC, strategy.filter(filter_ptr, filter_transform_buf, transform_mid_buf, OC,
IC, oc_start, oc_end); IC, oc_start, oc_end);
...@@ -279,7 +291,8 @@ public: ...@@ -279,7 +291,8 @@ public:
static_cast<const input_filter_compute_type*>( static_cast<const input_filter_compute_type*>(
ncb_param.filter<input_filter_compute_type>(group_id)); ncb_param.filter<input_filter_compute_type>(group_id));
if (ncb_param.filter_meta.format == param::ConvBias::Format::NCHW || if (ncb_param.filter_meta.format == param::ConvBias::Format::NCHW ||
ncb_param.filter_meta.format == param::ConvBias::Format::NCHW88) { ncb_param.filter_meta.format == param::ConvBias::Format::NCHW88 ||
ncb_param.filter_meta.format == param::ConvBias::Format::NCHW44) {
filter_transform_buf = reinterpret_cast<input_filter_compute_type*>( filter_transform_buf = reinterpret_cast<input_filter_compute_type*>(
reinterpret_cast<uintptr_t>(bundle_top.get(1)) + reinterpret_cast<uintptr_t>(bundle_top.get(1)) +
group_id * filter_group_size); group_id * filter_group_size);
...@@ -404,14 +417,18 @@ public: ...@@ -404,14 +417,18 @@ public:
param.filter_meta.stride[1] == 1 && param.filter_meta.stride[1] == 1 &&
(param.filter_meta.format == param::ConvBias::Format::NCHW || (param.filter_meta.format == param::ConvBias::Format::NCHW ||
param.filter_meta.format == param::ConvBias::Format::NCHW88 || param.filter_meta.format == param::ConvBias::Format::NCHW88 ||
param.filter_meta.format == param::ConvBias::Format::NCHW44 ||
param.filter_meta.format == param.filter_meta.format ==
param::ConvBias::Format::NCHW_WINOGRAD || param::ConvBias::Format::NCHW_WINOGRAD ||
param.filter_meta.format == param.filter_meta.format ==
param::ConvBias::Format::NCHW88_WINOGRAD)); param::ConvBias::Format::NCHW88_WINOGRAD ||
param.filter_meta.format ==
param::ConvBias::Format::NCHW44_WINOGRAD));
SmallVector<NCBKern> kerns; SmallVector<NCBKern> kerns;
if (param.filter_meta.format == param::ConvBias::Format::NCHW || if (param.filter_meta.format == param::ConvBias::Format::NCHW ||
param.filter_meta.format == param::ConvBias::Format::NCHW88) { param.filter_meta.format == param::ConvBias::Format::NCHW88 ||
param.filter_meta.format == param::ConvBias::Format::NCHW44) {
//! probably a gcc bug, labmda require capturing 'this' to call //! probably a gcc bug, labmda require capturing 'this' to call
//! static member function //! static member function
auto filter_process_kern = [this, strategy, bundle_top, auto filter_process_kern = [this, strategy, bundle_top,
...@@ -426,6 +443,10 @@ public: ...@@ -426,6 +443,10 @@ public:
if (param.filter_meta.format == param::ConvBias::Format::NCHW88) { if (param.filter_meta.format == param::ConvBias::Format::NCHW88) {
megdnn_assert(OC % 8 == 0); megdnn_assert(OC % 8 == 0);
oc_parallelism = OC / 8; oc_parallelism = OC / 8;
} else if (param.filter_meta.format ==
param::ConvBias::Format::NCHW44) {
megdnn_assert(OC % 4 == 0);
oc_parallelism = OC / 4;
} }
kerns.push_back({filter_process_kern, {GROUP, 1, oc_parallelism}}); kerns.push_back({filter_process_kern, {GROUP, 1, oc_parallelism}});
} }
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/fallback/convolution/opr_impl.h" #include "src/fallback/convolution/opr_impl.h"
...@@ -142,7 +143,8 @@ ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param( ...@@ -142,7 +143,8 @@ ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param(
size_t spatial_pos; size_t spatial_pos;
if (param().format == Param::Format::NCHW88 || if (param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW8 || param().format == Param::Format::NCHW8 ||
param().format == Param::Format::NCHW4) { param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44) {
spatial_pos = 2; spatial_pos = 2;
} else if (param().format == Param::Format::NCHW || } else if (param().format == Param::Format::NCHW ||
param().format == Param::Format::NCHW_WINOGRAD) { param().format == Param::Format::NCHW_WINOGRAD) {
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once
...@@ -145,6 +146,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, ...@@ -145,6 +146,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
using Format = param::Convolution::Format; using Format = param::Convolution::Format;
if (filter_meta.format == Format::NCHW || if (filter_meta.format == Format::NCHW ||
filter_meta.format == Format::NCHW88 || filter_meta.format == Format::NCHW88 ||
filter_meta.format == Format::NCHW44 ||
filter_meta.format == Format::NCHW4 || filter_meta.format == Format::NCHW4 ||
filter_meta.format == Format::NCHW8 || filter_meta.format == Format::NCHW8 ||
filter_meta.format == Format::NCHW32) { filter_meta.format == Format::NCHW32) {
...@@ -171,7 +173,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, ...@@ -171,7 +173,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
OW = dst.layout.shape[spatial_start + 1]; OW = dst.layout.shape[spatial_start + 1];
if (filter_meta.format == Format::NCHW4 || if (filter_meta.format == Format::NCHW4 ||
filter_meta.format == Format::CHWN4) { filter_meta.format == Format::CHWN4 ||
filter_meta.format == Format::NCHW44) {
OC *= 4; OC *= 4;
} else if (filter_meta.format == Format::NCHW8 || } else if (filter_meta.format == Format::NCHW8 ||
filter_meta.format == Format::NCHW88) { filter_meta.format == Format::NCHW88) {
...@@ -216,6 +219,26 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, ...@@ -216,6 +219,26 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
FS_G = FS_OC * filter_meta.ocpg / 8; FS_G = FS_OC * filter_meta.ocpg / 8;
} }
} }
} else if (filter_meta.format == Format::NCHW44) {
if (filter_meta.group > 1 && filter_meta.icpg == 1 &&
src.layout.ndim == 5 && filter_meta.ocpg == 1) {
FS_SPATIAL = 4;
FS_IC = FH * FW * FS_SPATIAL;
FS_OC = FS_IC * filter_meta.icpg;
FS_G = FS_OC * filter_meta.ocpg;
} else {
if (src.layout.ndim == 4 && dst.layout.ndim == 5) {
FS_IC = 4;
FS_SPATIAL = filter_meta.icpg * FS_IC;
FS_OC = FH * FW * FS_SPATIAL;
FS_G = FS_OC * filter_meta.ocpg / 4;
} else {
FS_SPATIAL = 4 * 4;
FS_IC = FH * FW * FS_SPATIAL;
FS_OC = FS_IC * filter_meta.icpg / 4;
FS_G = FS_OC * filter_meta.ocpg / 4;
}
}
} else { } else {
// g, oc, fh, fw, ic // g, oc, fh, fw, ic
megdnn_assert(filter_meta.format == Format::NHWC); megdnn_assert(filter_meta.format == Format::NHWC);
...@@ -259,6 +282,16 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, ...@@ -259,6 +282,16 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
h * layout.stride[2] + w * layout.stride[3] + h * layout.stride[2] + w * layout.stride[3] +
(c & 0b111) * layout.stride[4]; (c & 0b111) * layout.stride[4];
} }
} else if (filter_meta.format == Format::NCHW44) {
if (filter_meta.format == Format::NCHW44 && !is_output &&
src.layout.ndim == 4) {
return n * layout.stride[0] + c * layout.stride[1] +
h * layout.stride[2] + w * layout.stride[3];
} else {
return n * layout.stride[0] + (c / 4) * layout.stride[1] +
h * layout.stride[2] + w * layout.stride[3] +
(c % 4) * layout.stride[4];
}
} else if (filter_meta.format == Format::NCHW32) { } else if (filter_meta.format == Format::NCHW32) {
return n * layout.stride[0] + (c >> 5) * layout.stride[1] + return n * layout.stride[0] + (c >> 5) * layout.stride[1] +
h * layout.stride[2] + w * layout.stride[3] + h * layout.stride[2] + w * layout.stride[3] +
...@@ -315,6 +348,27 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, ...@@ -315,6 +348,27 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
megdnn_assert( megdnn_assert(
0, "nchw88 naive not support this input and output\n"); 0, "nchw88 naive not support this input and output\n");
} }
} else if (filter_meta.format == Format::NCHW44) {
if (src.layout.ndim == 4) {
// ic < 8, input is nchw
return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC +
(fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC +
gc_out.cur_off % 4;
} else if (filter_meta.group > 1 && filter_meta.icpg == 1 &&
filter_meta.ocpg == 1 && src.layout.ndim == 5) {
// dw case
return gc_out.cur_grp / 4 * FS_G + gc_out.cur_off * FS_OC +
(ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL +
gc_out.cur_grp % 4;
} else if (src.layout.ndim == 5) {
// normal case
return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC +
(ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL +
((ic - ic0) % 4) * 4 + gc_out.cur_off % 4;
} else {
megdnn_assert(
0, "nchw44 naive not support this input and output\n");
}
} else { } else {
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
(ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL; (ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL;
...@@ -504,6 +558,7 @@ void forward(_megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst, ...@@ -504,6 +558,7 @@ void forward(_megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst,
megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW || megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW ||
filter_meta.format == param::Convolution::Format::NHWC || filter_meta.format == param::Convolution::Format::NHWC ||
filter_meta.format == param::Convolution::Format::NCHW88 || filter_meta.format == param::Convolution::Format::NCHW88 ||
filter_meta.format == param::Convolution::Format::NCHW44 ||
filter_meta.format == param::Convolution::Format::NCHW4); filter_meta.format == param::Convolution::Format::NCHW4);
compute2d<stype, ftype, dtype, comp_type, StrategyFwd>( compute2d<stype, ftype, dtype, comp_type, StrategyFwd>(
src, const_cast<ftype*>(fptr), dst, filter_meta); src, const_cast<ftype*>(fptr), dst, filter_meta);
...@@ -557,6 +612,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, ...@@ -557,6 +612,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
switch (filter_meta.format) { switch (filter_meta.format) {
case param::Convolution::Format::NCHW: case param::Convolution::Format::NCHW:
case param::Convolution::Format::NCHW88: case param::Convolution::Format::NCHW88:
case param::Convolution::Format::NCHW44:
case param::Convolution::Format::NHWC: case param::Convolution::Format::NHWC:
case param::Convolution::Format::NCHW4: case param::Convolution::Format::NCHW4:
case param::Convolution::Format::NCHW8: case param::Convolution::Format::NCHW8:
...@@ -633,6 +689,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, ...@@ -633,6 +689,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
} \ } \
} \ } \
} while (0) } while (0)
case Format::NCHW44:
case Format::NCHW4: { case Format::NCHW4: {
BIAS_ADD_NCHWx(4); BIAS_ADD_NCHWx(4);
break; break;
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/naive/pooling/opr_impl.h" #include "src/naive/pooling/opr_impl.h"
...@@ -168,6 +169,13 @@ struct NCHW88IdxGetter { ...@@ -168,6 +169,13 @@ struct NCHW88IdxGetter {
return id; return id;
} }
}; };
struct NCHW44IdxGetter {
static size_t get_idx(size_t n, size_t c, size_t h, size_t w, size_t,
size_t C, size_t H, size_t W) {
size_t id = (((n * (C >> 2) + (c >> 2)) * H + h) * W + w) * 4 + (c % 4);
return id;
}
};
struct CHWN4IdxGetter { struct CHWN4IdxGetter {
static size_t get_idx(size_t n, size_t c, size_t h, size_t w, size_t N, static size_t get_idx(size_t n, size_t c, size_t h, size_t w, size_t N,
...@@ -375,6 +383,7 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, ...@@ -375,6 +383,7 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
if (param().format == Param::Format::NCHW || if (param().format == Param::Format::NCHW ||
param().format == Param::Format::NCHW4 || param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW88 || param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW32) { param().format == Param::Format::NCHW32) {
c_pos = 1; c_pos = 1;
spatial_pos = 2; spatial_pos = 2;
...@@ -401,6 +410,7 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, ...@@ -401,6 +410,7 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
OW = dst.layout.shape[spatial_pos + 2]; OW = dst.layout.shape[spatial_pos + 2];
} }
if (param().format == Param::Format::NCHW4 || if (param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::CHWN4) { param().format == Param::Format::CHWN4) {
C *= 4; C *= 4;
} }
...@@ -437,6 +447,9 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, ...@@ -437,6 +447,9 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
case Param::Format::NCHW88: \ case Param::Format::NCHW88: \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW88IdxGetter); \ DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW88IdxGetter); \
break; \ break; \
case Param::Format::NCHW44: \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW44IdxGetter); \
break; \
case Param::Format::NCHW32: \ case Param::Format::NCHW32: \
DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW32IdxGetter); \ DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW32IdxGetter); \
break; \ break; \
......
...@@ -6,13 +6,14 @@ ...@@ -6,13 +6,14 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "test/naive/fixture.h"
#include "megdnn/oprs/nn.h" #include "megdnn/oprs/nn.h"
#include "test/common/checker.h" #include "test/common/checker.h"
#include "test/common/workspace_wrapper.h" #include "test/common/workspace_wrapper.h"
#include "test/naive/fixture.h"
using namespace megdnn; using namespace megdnn;
using namespace test; using namespace test;
...@@ -35,55 +36,39 @@ private: ...@@ -35,55 +36,39 @@ private:
} // namespace } // namespace
TEST_F(NAIVE, CONV_BIAS_QUANTIZED8x8x32) { TEST_F(NAIVE, CONV_BIAS_QUANTIZED8x8x32) {
Checker<ConvBias> checker(handle(), /* check_dispatch */false); Checker<ConvBias> checker(handle(), /* check_dispatch */ false);
ConvBias::Param param; ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW; param.format = ConvBias::Param::Format::NCHW;
checker.set_param(param).exect( checker.set_param(param).exect(
Testcase{ Testcase{TensorValue({1, 1, 4, 4}, dtype::QuantizedS8(0.1f),
TensorValue({1, 1, 4, 4}, dtype::QuantizedS8(0.1f), {90 - 128, 136 - 128, 85 - 128, 204 - 128,
{90-128, 136-128, 85-128, 204-128, 48 - 128, 9 - 128, 226 - 128, 25 - 128,
48-128, 9-128, 226-128, 25-128, 118 - 128, 109 - 128, 87 - 128, 132 - 128,
118-128, 109-128, 87-128, 132-128, 104 - 128, 163 - 128, 25 - 128, 90 - 128}),
104-128, 163-128, 25-128, 90-128}), TensorValue({3, 1, 3, 3}, dtype::QuantizedS8(0.2f),
TensorValue({3, 1, 3, 3}, dtype::QuantizedS8(0.2f), {153 - 124, 170 - 124, 102 - 124, 103 - 124,
{153-124, 170-124, 102-124, 23 - 124, 213 - 124, 116 - 124, 195 - 124,
103-124, 23-124, 213-124, 191 - 124, 44 - 124, 50 - 124, 247 - 124,
116-124, 195-124, 191-124, 172 - 124, 42 - 124, 32 - 124, 233 - 124,
163 - 124, 247 - 124, 120 - 124, 241 - 124,
44-124, 50-124, 247-124, 209 - 124, 83 - 124, 201 - 124, 115 - 124,
172-124, 42-124, 32-124, 32 - 124, 140 - 124, 147 - 124}),
233-124, 163-124, 247-124, TensorValue({1, 3, 1, 1}, dtype::QuantizedS32(0.02f),
{0, 0, 0}),
120-124, 241-124, 209-124, TensorValue({1, 3, 2, 2}, dtype::QuantizedS32(0.3f),
83-124, 201-124, 115-124, {1234, 0, 0, 0, 0, 0, 0, 0, 0, -234, 0, 0}),
32-124, 140-124, 147-124}), {}},
TensorValue({1, 3, 1, 1}, dtype::QuantizedS32(0.02f), Testcase{{},
{0, 0, 0}), {},
TensorValue({1, 3, 2, 2}, dtype::QuantizedS32(0.3f), {},
{1234, 0, {},
0, 0, TensorValue({1, 3, 2, 2}, dtype::QuantizedS32(0.1f * 0.2f),
{37127, -22475, -15694, -1920,
0, 0,
0, 0, -12813, 4440, 18190, -13195,
0, -234, -9659, 12423, -5558, -4969})});
0, 0}),
{}},
Testcase{
{},
{},
{},
{},
TensorValue({1, 3, 2, 2}, dtype::QuantizedS32(0.1f * 0.2f),
{37127, -22475,
-15694, -1920,
-12813, 4440,
18190, -13195,
-9659, 12423,
-5558, -4969})});
} }
TEST_F(NAIVE, CONV_BIAS_QUANTIZED4x4x32) { TEST_F(NAIVE, CONV_BIAS_QUANTIZED4x4x32) {
...@@ -175,10 +160,8 @@ TEST_F(NAIVE, CONV_BIAS_QUANTIZED4x4x32) { ...@@ -175,10 +160,8 @@ TEST_F(NAIVE, CONV_BIAS_QUANTIZED4x4x32) {
{0, 0, 0, 0, 0, 0, 0, 0}), {0, 0, 0, 0, 0, 0, 0, 0}),
TensorValue( TensorValue(
{1, 1, 2, 2, 8}, dtype::QuantizedS32(0.3f), {1, 1, 2, 2, 8}, dtype::QuantizedS32(0.3f),
{0, 0, 0, 0, 0, 0, 0, 0, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -87, 0, 0, 0, 0, 0, 0, 0, 0, 0}),
0, 0, 0, 0, 0, 0, -87, 0,
0, 0, 0, 0, 0, 0, 0, 0}),
{}}, {}},
Testcase{ Testcase{
{}, {},
...@@ -316,8 +299,221 @@ TEST_F(NAIVE, CONV_BIAS_QUANTIZED8x8x32_NCHW32) { ...@@ -316,8 +299,221 @@ TEST_F(NAIVE, CONV_BIAS_QUANTIZED8x8x32_NCHW32) {
TensorNDArray{src_ts_32.tensornd(), TensorNDArray{src_ts_32.tensornd(),
filter_ts_32.tensornd(), filter_ts_32.tensornd(),
bias_ts_32.tensornd(), bias_ts_32.tensornd(),
z_ts_32.tensornd(), {}}, z_ts_32.tensornd(),
{}},
TensorNDArray{{}, {}, {}, {}, dst_ts_32.tensornd()}); TensorNDArray{{}, {}, {}, {}, dst_ts_32.tensornd()});
} }
TEST_F(NAIVE, CONV_BIAS_NCHW44) {
Checker<ConvBias> checker(handle(), /* check_dispatch */ false);
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW44;
size_t n = 1;
size_t ic = 4;
size_t oc = 8;
size_t h = 2;
size_t w = 2;
size_t filter_size = 3;
size_t pad = 1;
auto src_tensor_shape = TensorShape{n, ic / 4, h, w, 4};
auto weight_tensor_shape =
TensorShape{oc / 4, ic / 4, filter_size, filter_size, 4, 4};
auto bias_tensor_shape = TensorShape{1, oc / 4, 1, 1, 4};
param.pad_h = pad;
param.pad_w = pad;
UniformIntRNG rng{-127, 127};
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32())
.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng)
.set_epsilon(1e-3)
.set_param(param)
.execs({src_tensor_shape,
weight_tensor_shape,
bias_tensor_shape,
{},
{}});
checker.set_dtype(0, dtype::QuantizedS8(2.f))
.set_dtype(1, dtype::QuantizedS8(3.f))
.set_dtype(2, dtype::QuantizedS32(6.f))
.set_dtype(4, dtype::QuantizedS32(6.f))
.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng)
.set_epsilon(1e-3)
.set_param(param)
.execs({src_tensor_shape,
weight_tensor_shape,
bias_tensor_shape,
{},
{}});
{
// test normal conv
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW44;
param.sparse = ConvBias::Param::Sparse::DENSE;
param.pad_h = 1;
param.pad_w = 1;
checker.set_param(param).exect(
Testcase{TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
{7, 2, 2, 1, 7, 5, 6, 3, 1, 2, 8, 3, 7, 7,
6, 4}),
TensorValue(
{1, 1, 3, 3, 4, 4}, dtype::Float32(),
{3, 5, 5, 2, 0, 1, 4, 8, 3, 5, 0, 7, 1, 7, 0,
7, 6, 4, 7, 7, 5, 2, 2, 4, 7, 6, 6, 3, 3, 2,
2, 8, 5, 0, 4, 4, 0, 5, 1, 0, 0, 4, 8, 4, 7,
7, 2, 0, 4, 8, 7, 3, 6, 2, 3, 0, 0, 6, 4, 4,
1, 4, 3, 8, 8, 8, 7, 2, 2, 5, 5, 1, 3, 2, 8,
1, 7, 0, 2, 7, 1, 6, 1, 5, 0, 6, 3, 0, 2, 4,
1, 1, 4, 2, 7, 5, 7, 8, 4, 5, 5, 7, 0, 3, 3,
2, 8, 6, 0, 1, 4, 6, 6, 6, 0, 1, 2, 4, 4, 1,
1, 7, 8, 2, 5, 2, 8, 3, 8, 3, 5, 0, 6, 3, 4,
3, 3, 7, 2, 8, 1, 1, 1, 4}),
TensorValue({1, 1, 1, 1, 4}, dtype::Float32(),
{7, 2, 8, 1}),
TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0}),
{}},
Testcase{
{},
{},
{},
{},
TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
{264, 338, 309, 195, 276, 332, 390, 199,
224, 268, 311, 218, 288, 311, 346, 277})});
}
{
// test dw conv
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW44;
param.sparse = ConvBias::Param::Sparse::GROUP;
param.pad_h = 1;
param.pad_w = 1;
checker.set_param(param).exect(
Testcase{TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
{5, 8, 3, 2, 4, 6, 1, 5, 0, 8, 2, 6, 8, 6,
5, 7}),
TensorValue({1, 1, 1, 3, 3, 4}, dtype::Float32(),
{3, 0, 3, 1, 6, 5, 7, 3, 5, 0, 0, 7,
4, 6, 0, 1, 8, 2, 3, 7, 1, 0, 2, 4,
7, 5, 3, 0, 6, 2, 1, 5, 8, 6, 3, 1}),
TensorValue({1, 1, 1, 1, 4}, dtype::Float32(),
{4, 3, 5, 6}),
TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0}),
{}},
Testcase{{},
{},
{},
{},
TensorValue({1, 1, 2, 2, 4}, dtype::Float32(),
{112, 71, 33, 77, 104, 115, 19, 78, 62, 59,
42, 117, 107, 93, 36, 78})});
}
{
// test group conv
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW44;
param.sparse = ConvBias::Param::Sparse::GROUP;
param.pad_h = 1;
param.pad_w = 1;
checker.set_param(param).exect(
Testcase{TensorValue({1, 2, 2, 2, 4}, dtype::Float32(),
{6, 3, 2, 7, 7, 6, 4, 5, 8, 6, 3,
1, 1, 2, 8, 3, 1, 0, 6, 1, 3, 3,
6, 0, 0, 5, 6, 7, 2, 2, 4, 4}),
TensorValue(
{2, 1, 1, 3, 3, 4, 4}, dtype::Float32(),
{3, 5, 5, 2, 0, 1, 4, 8, 3, 5, 0, 7, 1, 7, 0,
7, 6, 4, 7, 7, 5, 2, 2, 4, 7, 6, 6, 3, 3, 2,
2, 8, 5, 0, 4, 4, 0, 5, 1, 0, 0, 4, 8, 4, 7,
7, 2, 0, 4, 8, 7, 3, 6, 2, 3, 0, 0, 6, 4, 4,
1, 4, 3, 8, 8, 8, 7, 2, 2, 5, 5, 1, 3, 2, 8,
1, 7, 0, 2, 7, 1, 6, 1, 5, 0, 6, 3, 0, 2, 4,
1, 1, 4, 2, 7, 5, 7, 8, 4, 5, 5, 7, 0, 3, 3,
2, 8, 6, 0, 1, 4, 6, 6, 6, 0, 1, 2, 4, 4, 1,
1, 7, 8, 2, 5, 2, 8, 3, 8, 3, 5, 0, 6, 3, 4,
3, 3, 7, 2, 8, 1, 1, 1, 4, 7, 4, 5, 0, 6, 8,
7, 4, 8, 1, 3, 5, 3, 0, 0, 3, 7, 7, 7, 3, 8,
1, 2, 0, 1, 1, 2, 1, 3, 0, 0, 1, 1, 3, 0, 5,
6, 3, 0, 5, 4, 1, 4, 7, 0, 2, 1, 6, 7, 8, 0,
2, 1, 6, 7, 6, 3, 2, 7, 6, 5, 1, 1, 1, 2, 4,
6, 3, 3, 8, 0, 7, 1, 3, 7, 3, 2, 2, 4, 3, 5,
5, 6, 3, 3, 1, 2, 3, 0, 4, 0, 3, 3, 5, 5, 5,
2, 3, 1, 5, 4, 5, 8, 1, 7, 2, 1, 0, 1, 8, 2,
6, 7, 8, 4, 4, 7, 8, 4, 5, 8, 1, 1, 0, 7, 8,
4, 2, 2, 8, 6, 5, 2, 4, 8, 4, 0, 4, 0, 2, 1,
7, 1, 6}),
TensorValue({1, 2, 1, 1, 4}, dtype::Float32(),
{1, 8, 5, 6, 2, 8, 7, 7}),
TensorValue({1, 2, 2, 2, 4}, dtype::Float32(),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0}),
{}},
Testcase{
{},
{},
{},
{},
TensorValue({1, 2, 2, 2, 4}, dtype::Float32(),
{260, 342, 244, 241, 293, 385, 362, 257,
278, 301, 303, 226, 273, 306, 318, 307,
180, 244, 169, 156, 210, 244, 206, 167,
126, 165, 156, 207, 191, 141, 209, 172})});
}
{
// test normal conv
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW44;
param.sparse = ConvBias::Param::Sparse::DENSE;
param.pad_h = 1;
param.pad_w = 1;
checker.set_param(param).exect(
Testcase{TensorValue({1, 1, 2, 2, 4}, dtype::Int8(),
{7, 2, 2, 1, 7, 5, 6, 3, 1, 2, 8, 3, 7, 7,
6, 4}),
TensorValue(
{1, 1, 3, 3, 4, 4}, dtype::Int8(),
{3, 5, 5, 2, 0, 1, 4, 8, 3, 5, 0, 7, 1, 7, 0,
7, 6, 4, 7, 7, 5, 2, 2, 4, 7, 6, 6, 3, 3, 2,
2, 8, 5, 0, 4, 4, 0, 5, 1, 0, 0, 4, 8, 4, 7,
7, 2, 0, 4, 8, 7, 3, 6, 2, 3, 0, 0, 6, 4, 4,
1, 4, 3, 8, 8, 8, 7, 2, 2, 5, 5, 1, 3, 2, 8,
1, 7, 0, 2, 7, 1, 6, 1, 5, 0, 6, 3, 0, 2, 4,
1, 1, 4, 2, 7, 5, 7, 8, 4, 5, 5, 7, 0, 3, 3,
2, 8, 6, 0, 1, 4, 6, 6, 6, 0, 1, 2, 4, 4, 1,
1, 7, 8, 2, 5, 2, 8, 3, 8, 3, 5, 0, 6, 3, 4,
3, 3, 7, 2, 8, 1, 1, 1, 4}),
TensorValue({1, 1, 1, 1, 4}, dtype::Int32(),
{7, 2, 8, 1}),
TensorValue({1, 1, 2, 2, 4}, dtype::Int32(),
{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0}),
{}},
Testcase{
{},
{},
{},
{},
TensorValue({1, 1, 2, 2, 4}, dtype::Int32(),
{264, 338, 309, 195, 276, 332, 390, 199,
224, 268, 311, 218, 288, 311, 346, 277})});
}
}
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册