提交 6856ce9c 编写于 作者: M Megvii Engine Team

feat(dnn): support conv bias activation for nchw4 input tensor format and nchw output tensor format

GitOrigin-RevId: 29cd73f87b57b0ee6bfa608d629b75aed1491df0
上级 85368643
......@@ -39,7 +39,10 @@ pdef('Axis').add_fields('int32', 'axis', 0)
'NCHW44','NCHW44_DOT',
Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'),
Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'),
Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'),
Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'),
Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'),
Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'),
Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'),
Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation '
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.'))
)
......
......@@ -48,38 +48,52 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
megdnn_assert(src.dtype.enumv() == filter.dtype.enumv());
}
if (src.dtype.enumv() == DTypeEnum::QuantizedS8) {
float scale_src = src.dtype.param<dtype::QuantizedS8>().scale;
float scale_filter = 0.f;
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) {
if (filter.dtype.enumv() == DTypeEnum::QuantizedS32) {
//!int8 winogradf23_44 using float,QuantizedS32 take the scale
scale_filter = filter.dtype.param<dtype::QuantizedS32>().scale;
if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) {
float scale_src = src.dtype.param<dtype::QuantizedS8>().scale;
float scale_filter = 0.f;
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) {
if (filter.dtype.enumv() == DTypeEnum::QuantizedS32) {
//! int8 winogradf23_44 using float,QuantizedS32 take the
//! scale
scale_filter =
filter.dtype.param<dtype::QuantizedS32>().scale;
} else {
scale_filter =
filter.dtype.param<dtype::QuantizedS16>().scale;
}
} else {
scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale;
scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale;
}
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale;
megdnn_assert(
std::abs(scale_src * scale_filter - scale_bias) < 1e-6,
"scale_src: %f scale_filter: %f scale_bias: %f", scale_src,
scale_filter, scale_bias);
} else {
scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale;
megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32);
}
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale;
megdnn_assert(std::abs(scale_src * scale_filter - scale_bias) < 1e-6,
"scale_src: %f scale_filter: %f scale_bias: %f",
scale_src, scale_filter, scale_bias);
} else if (src.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
float scale_src = src.dtype.param<dtype::Quantized8Asymm>().scale;
float scale_filter = 0.f;
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) {
scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale;
if (bias.dtype.enumv() == DTypeEnum::QuantizedS32) {
float scale_src = src.dtype.param<dtype::Quantized8Asymm>().scale;
float scale_filter = 0.f;
if (param().format == param::ConvBias::Format::NCHW_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW88_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) {
scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale;
} else {
scale_filter =
filter.dtype.param<dtype::Quantized8Asymm>().scale;
}
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale;
megdnn_assert(
std::abs(scale_src * scale_filter - scale_bias) < 1e-6,
"scale_src: %f scale_filter: %f scale_bias: %f", scale_src,
scale_filter, scale_bias);
} else {
scale_filter = filter.dtype.param<dtype::Quantized8Asymm>().scale;
megdnn_assert(bias.dtype.enumv() == DTypeEnum::Float32);
}
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale;
megdnn_assert(std::abs(scale_src * scale_filter - scale_bias) < 1e-6,
"scale_src: %f scale_filter: %f scale_bias: %f",
scale_src, scale_filter, scale_bias);
}
auto ret = check_layout_fwd(src, filter, dst);
......@@ -101,7 +115,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
if (check_eq(bias, dst))
return ret;
if (param().format == param::ConvBias::Format::NCHW ||
param().format == param::ConvBias::Format::NCHW_WINOGRAD) {
param().format == param::ConvBias::Format::NCHW_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW4_NCHW) {
megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
bias.to_string().c_str(), dst.to_string().c_str());
......@@ -116,7 +131,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
} else if (param().format == param::ConvBias::Format::NCHW4 ||
param().format == param::ConvBias::Format::NCHW44 ||
param().format == param::ConvBias::Format::NCHW44_DOT ||
param().format == param::ConvBias::Format::NCHW44_WINOGRAD) {
param().format == param::ConvBias::Format::NCHW44_WINOGRAD ||
param().format == param::ConvBias::Format::NCHW32_NCHW4) {
megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
bias.to_string().c_str(), dst.to_string().c_str());
......@@ -132,7 +148,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
megdnn_assert(bias.shape[2] == 1);
megdnn_assert(bias.shape[3] == 1);
megdnn_assert(bias.shape[4] == 8);
} else if (param().format == param::ConvBias::Format::NCHW32) {
} else if (param().format == param::ConvBias::Format::NCHW32 ||
param().format == param::ConvBias::Format::NCHW4_NCHW32) {
megdnn_assert(bias.shape[0] == 1);
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s",
bias.to_string().c_str(), dst.to_string().c_str());
......@@ -163,6 +180,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec(
param::ConvBias::Format::NCHW88_WINOGRAD);
megdnn_assert(param().format !=
param::ConvBias::Format::NCHW44_WINOGRAD);
megdnn_assert(param().format != param::ConvBias::Format::NCHW4_NCHW32);
megdnn_assert(param().format != param::ConvBias::Format::NCHW32_NCHW4);
megdnn_assert(z.dtype.enumv() == dst.dtype.enumv());
megdnn_assert(z.eq_shape(dst));
}
......
......@@ -443,7 +443,10 @@ void make_canonized_filter_meta_nchwx(
*/
megdnn_assert(param.format == Param::Format::NCHW4 ||
param.format == Param::Format::NCHW8 ||
param.format == Param::Format::NCHW32);
param.format == Param::Format::NCHW32 ||
param.format == Param::Format::NCHW4_NCHW ||
param.format == Param::Format::NCHW4_NCHW32 ||
param.format == Param::Format::NCHW32_NCHW4);
auto img_ndim = src_ndim - 3;
size_t flt_start = 0, flt_spatial_start = 2;
if (param.sparse == Param::Sparse::DENSE) {
......@@ -568,7 +571,9 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta(
make_canonized_filter_meta_nhwcd4<Parameter>(src_ndim, filter,
param(), ret);
}
} else if (param().format == Param::Format::NCHW4) {
} else if (param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW4_NCHW ||
param().format == Param::Format::NCHW4_NCHW32) {
make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter,
param(), ret);
} else if (param().format == Param::Format::NCHW8) {
......@@ -583,7 +588,8 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta(
param().format == Param::Format::NCHW44_WINOGRAD) {
make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter,
param(), ret);
} else if (param().format == Param::Format::NCHW32) {
} else if (param().format == Param::Format::NCHW32 ||
param().format == Param::Format::NCHW32_NCHW4) {
make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter,
param(), ret);
} else if (param().format == Param::Format::CHWN4) {
......@@ -627,6 +633,9 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src,
if (dst.valid() && dst.enumv() == src.enumv()) {
supported_dst_dtype.push_back(dst);
}
if (src.enumv() == DTypeEnum::QuantizedS8) {
supported_dst_dtype.push_back(dtype::Float32());
}
} else if (src.enumv() == DTypeEnum::QuantizedS32) {
//! ConvolutionBackwardData: s8(filter) + s8(dst) -> s32(src)
megdnn_assert(filter.enumv() == DTypeEnum::QuantizedS8);
......@@ -697,10 +706,13 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
} else {
megdnn_assert(param().format == Param::Format::NHWCD4 ||
param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW4_NCHW ||
param().format == Param::Format::NCHW4_NCHW32 ||
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_DOT ||
param().format == Param::Format::NCHW8 ||
param().format == Param::Format::NCHW32 ||
param().format == Param::Format::NCHW32_NCHW4 ||
param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW88_WINOGRAD ||
param().format == Param::Format::NCHW44_WINOGRAD ||
......@@ -720,13 +732,17 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
filter.ndim == img_dim + 4 ||
filter.ndim == img_dim + 5,
"%s", errmsg().c_str());
if (param().format == Param::Format::NCHW4) {
if (param().format == Param::Format::NCHW4 ||
param().format == Param::Format::NCHW4_NCHW ||
param().format == Param::Format::NCHW4_NCHW32) {
megdnn_assert(src.ndim == 5 &&
(filter.ndim == 5 || filter.ndim == 6 ||
filter.ndim == 7) &&
src[src.ndim - 1] == 4 &&
filter[filter.ndim - 1] == 4,
"NCHW4 require src and filter's ndim is 5 or 6, and "
"NCHW4/NCHW4_NCHW/NCHW4_NCHW32 require src and "
"filter's ndim is "
"5 or 6, and "
"last shape "
"is 4 "
"but got src %s, filter %s",
......@@ -742,15 +758,17 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
"but got src %s, filter %s",
src.to_string().c_str(), filter.to_string().c_str());
}
if (param().format == Param::Format::NCHW32) {
megdnn_assert(
src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) &&
src[src.ndim - 1] == 32 &&
filter[filter.ndim - 1] == 32,
"NCHW32 require src and filter's ndim is 5 or 6, and last "
"shape is 32 "
"but got src %s, filter %s",
src.to_string().c_str(), filter.to_string().c_str());
if (param().format == Param::Format::NCHW32 ||
param().format == Param::Format::NCHW32_NCHW4) {
megdnn_assert(src.ndim == 5 &&
(filter.ndim == 5 || filter.ndim == 6) &&
src[src.ndim - 1] == 32 &&
filter[filter.ndim - 1] == 32,
"NCHW32/NCHW32_NCHW4 require src and filter's ndim "
"is 5 or 6, and last "
"shape is 32 "
"but got src %s, filter %s",
src.to_string().c_str(), filter.to_string().c_str());
}
if (param().format == Param::Format::NCHW88 ||
param().format == Param::Format::NCHW88_WINOGRAD) {
......@@ -943,6 +961,55 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src,
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[1],
cflt.stride[1], cflt.padding[1]);
dst[4] = 4;
} else if (param().format == Param::Format::NCHW4_NCHW) {
megdnn_assert(src.ndim == 5,
"invalid src ndim for NCHW4_NCHW, expected=5, got=%zu",
src.ndim);
megdnn_assert(cflt.icpg * cflt.group == src[1] * 4,
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
cflt.group);
dst.ndim = 4;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
dst[1] = oc;
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0],
cflt.stride[0], cflt.padding[0]);
dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
cflt.stride[1], cflt.padding[1]);
} else if (param().format == Param::Format::NCHW4_NCHW32) {
megdnn_assert(src.ndim == 5,
"invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu",
src.ndim);
megdnn_assert(cflt.icpg * cflt.group == src[1] * 4,
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
cflt.group);
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
megdnn_assert(oc % 32 == 0);
dst[1] = oc / 32;
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0],
cflt.stride[0], cflt.padding[0]);
dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
cflt.stride[1], cflt.padding[1]);
dst[4] = 32;
} else if (param().format == Param::Format::NCHW32_NCHW4) {
megdnn_assert(src.ndim == 5,
"invalid src ndim for NCHW32_NCHW4, expected=5, got=%zu",
src.ndim);
megdnn_assert(cflt.icpg * cflt.group == src[1] * 32,
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg,
cflt.group);
dst.ndim = src.ndim;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
megdnn_assert(oc % 4 == 0);
dst[1] = oc / 4;
dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0],
cflt.stride[0], cflt.padding[0]);
dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1],
cflt.stride[1], cflt.padding[1]);
dst[4] = 4;
} else {
megdnn_assert(param().format == Param::Format::NHWCD4);
megdnn_assert(src.ndim == 5,
......
......@@ -31,6 +31,9 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
args.bias_layout->eq_shape(*args.dst_layout))
return false;
auto&& param = args.opr->param();
if (param.format == param::ConvBias::Format::NCHW4_NCHW32 ||
param.format == param::ConvBias::Format::NCHW32_NCHW4)
return false;
if (param.format == param::ConvBias::Format::NCHW &&
(param.dilate_h != 1 || param.dilate_w != 1) &&
m_cudnn_enum == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
......@@ -152,16 +155,24 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec(
}
};
megdnn_assert(args.src_layout->dtype.category() ==
args.dst_layout->dtype.category() &&
args.src_tensor->layout.dtype.category() ==
args.filter_layout->dtype.category());
auto src_dtype = args.src_layout->dtype,
filter_dtype = args.filter_layout->dtype,
dst_dtype = args.dst_layout->dtype;
megdnn_assert(
(src_dtype.category() == dst_dtype.category()) ||
(args.opr->param().format == param::ConvBias::Format::NCHW4_NCHW &&
src_dtype.enumv() == DTypeEnum::QuantizedS8 &&
dst_dtype.enumv() == DTypeEnum::Float32));
megdnn_assert(src_dtype.category() == filter_dtype.category());
if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED) {
auto expected_bias_scale = get_scale(args.src_layout->dtype) *
get_scale(args.filter_layout->dtype);
alpha = expected_bias_scale / get_scale(args.dst_layout->dtype);
if (args.z_layout->ndim > 0) {
alpha = expected_bias_scale;
if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED)
alpha /= get_scale(args.dst_layout->dtype);
if (args.z_layout->ndim > 0 &&
args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) {
beta = get_scale(args.z_layout->dtype) /
get_scale(args.dst_layout->dtype);
}
......@@ -232,10 +243,23 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec(
break;
case param::ConvBias::NonlineMode::H_SWISH: {
megdnn_assert(args.dst_layout->dtype.category() ==
DTypeCategory::QUANTIZED);
auto&& elem_opr = args.handle->create_operator<ElemwiseMultiType>();
elem_opr->param().mode = ElemwiseMultiType::Param::Mode::QH_SWISH;
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
DTypeCategory::QUANTIZED ||
(args.dst_layout->dtype.category() ==
DTypeCategory::FLOAT &&
args.opr->param().format ==
param::ConvBias::Format::NCHW4_NCHW));
if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED) {
auto&& elem_opr =
args.handle->create_operator<ElemwiseMultiType>();
elem_opr->param().mode =
ElemwiseMultiType::Param::Mode::QH_SWISH;
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
} else {
auto&& elem_opr =
args.handle->create_operator<ElemwiseForward>();
elem_opr->param().mode = ElemwiseForward::Param::Mode::H_SWISH;
elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
}
break;
}
default:
......
......@@ -171,7 +171,8 @@ bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
bool check_bias_share_in_channel(const TensorLayout& bias,
const param::ConvBias::Format format) {
bool share_in_channel = false;
if (format == param::ConvBias::Format::NCHW) {
if (format == param::ConvBias::Format::NCHW ||
format == param::ConvBias::Format::NCHW4_NCHW) {
share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 &&
bias[3] == 1);
} else if (format == param::ConvBias::Format::NHWC) {
......@@ -179,7 +180,9 @@ bool check_bias_share_in_channel(const TensorLayout& bias,
bias[2] == 1);
} else if (format == param::ConvBias::Format::NCHW4 ||
format == param::ConvBias::Format::NCHW8 ||
format == param::ConvBias::Format::NCHW32) {
format == param::ConvBias::Format::NCHW32 ||
format == param::ConvBias::Format::NCHW4_NCHW32 ||
format == param::ConvBias::Format::NCHW32_NCHW4) {
share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 &&
bias[3] == 1);
} else if (format == param::ConvBias::Format::NHWCD4) {
......
......@@ -72,12 +72,19 @@ namespace conv_bias {
const TensorLayout& dst, const TensorLayout& bias,
const TensorLayout& z,
const param::ConvBias& param) {
src_desc.set(src, param.format);
using Format = param::ConvBias::Format;
Format src_format, dst_format;
src_format = dst_format = param.format;
if (param.format == Format::NCHW4_NCHW) {
src_format = Format::NCHW4;
dst_format = Format::NCHW;
}
src_desc.set(src, src_format);
filter_desc.set(filter);
if (z.ndim > 0) {
z_desc.set(z, param.format);
z_desc.set(z, dst_format);
}
dst_desc.set(dst, param.format);
dst_desc.set(dst, dst_format);
conv_desc.set_conv_bias(src.dtype, param, filter.group);
// cudnn requires the bias to be float tensor.
......@@ -91,6 +98,12 @@ namespace conv_bias {
float_bias_layout[1] * float_bias_layout[4],
float_bias_layout[2], float_bias_layout[3]});
bias_desc.set(float_bias_layout);
} else if (param.format == param::ConvBias::Format::NCHW4_NCHW) {
megdnn_assert(float_bias_layout.ndim == 4,
"NCHW4_NCHW format assumes bias tensor is stored "
"in NCHW layout, ndim(expected:4,got:%zu)",
float_bias_layout.ndim);
bias_desc.set(float_bias_layout);
} else {
bias_desc.set(float_bias_layout, param.format);
}
......@@ -99,9 +112,16 @@ namespace conv_bias {
void set_conv(const TensorLayout& src,
const CanonizedFilterMeta& filter,
const TensorLayout& dst, const param::ConvBias& param) {
src_desc.set(src, param.format);
using Format = param::ConvBias::Format;
Format src_format, dst_format;
src_format = dst_format = param.format;
if (param.format == Format::NCHW4_NCHW) {
src_format = Format::NCHW4;
dst_format = Format::NCHW;
}
src_desc.set(src, src_format);
filter_desc.set(filter);
dst_desc.set(dst, param.format);
dst_desc.set(dst, dst_format);
conv_desc.set_conv(src.dtype, param, filter.group);
}
};
......
......@@ -187,11 +187,15 @@ void FilterDesc<Param>::set(
megdnn_assert(filter_meta.group == 1);
#endif
auto filter_format = filter_meta.format;
if (filter_format == param::ConvBias::Format::NCHW4_NCHW) {
filter_format = param::ConvBias::Format::NCHW4;
}
// cuDNN version 6 or below filter_meta.group always is 1.
// So it is compatible for all cuDNN versions.
cudnn_check(cudnnSetFilter4dDescriptor(
desc, to_cudnn_dtype(filter_meta.dtype, filter_meta.format),
to_cudnn_format(filter_meta.format),
desc, to_cudnn_dtype(filter_meta.dtype, filter_format),
to_cudnn_format(filter_format),
filter_meta.ocpg * filter_meta.group, // cudnn 6 group always be 1
filter_meta.icpg, filter_meta.spatial[0], filter_meta.spatial[1]));
}
......
......@@ -203,6 +203,7 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
DISPATCH(Int8, Int16)
DISPATCH(Int8, Int32)
DISPATCH(QuantizedS8, QuantizedS32)
DISPATCH(QuantizedS8, Float32)
DISPATCH(Quantized8Asymm, QuantizedS32)
DISPATCH(Quantized4Asymm, QuantizedS32)
DISPATCH_RAW(QuantizedS8, QuantizedS32, QuantizedS32, FLOAT32,
......
......@@ -65,6 +65,15 @@ inline void StrategyFwd::on(dt_quint8& s, dt_quint8& f, dt_qint32& d,
d += cast(s, src_dt) * cast(f, filt_dt);
}
/*!
 * \brief multiply-accumulate strategy for QuantizedS8 operands with a
 *        Float32 accumulator.
 *
 * Each int8 operand is dequantized with its own per-tensor scale before
 * the product is added to \p d. NOTE(review): presumably used for formats
 * whose output tensor is float (e.g. NCHW4_NCHW) — confirm against callers.
 */
template <>
inline void StrategyFwd::on(dt_qint8& s, dt_qint8& f, dt_float32& d,
                            DType src_dt, DType filt_dt, DType) {
    const float src_val = src_dt.param<dtype::QuantizedS8>().dequantize(s);
    const float flt_val = filt_dt.param<dtype::QuantizedS8>().dequantize(f);
    d += src_val * flt_val;
}
template <>
inline void StrategyFwd::on(dt_qint8& s, dt_qint8& f, dt_qint32& d, DType,
DType, DType) {
......@@ -149,8 +158,11 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
filter_meta.format == Format::NCHW44 ||
filter_meta.format == Format::NCHW44_DOT ||
filter_meta.format == Format::NCHW4 ||
filter_meta.format == Format::NCHW4_NCHW ||
filter_meta.format == Format::NCHW4_NCHW32 ||
filter_meta.format == Format::NCHW8 ||
filter_meta.format == Format::NCHW32) {
filter_meta.format == Format::NCHW32 ||
filter_meta.format == Format::NCHW32_NCHW4) {
spatial_start = 2;
channel_pos = 1;
batch_pos = 0;
......@@ -176,20 +188,25 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
if (filter_meta.format == Format::NCHW4 ||
filter_meta.format == Format::CHWN4 ||
filter_meta.format == Format::NCHW44_DOT ||
filter_meta.format == Format::NCHW44) {
filter_meta.format == Format::NCHW44 ||
filter_meta.format == Format::NCHW32_NCHW4) {
OC *= 4;
} else if (filter_meta.format == Format::NCHW8 ||
filter_meta.format == Format::NCHW88) {
OC *= 8;
} else if (filter_meta.format == Format::NCHW32) {
} else if (filter_meta.format == Format::NCHW32 ||
filter_meta.format == Format::NCHW4_NCHW32) {
OC *= 32;
}
size_t FS_G, FS_OC, FS_IC, FS_SPATIAL;
if (filter_meta.format == Format::NCHW ||
filter_meta.format == Format::NCHW4 ||
filter_meta.format == Format::NCHW4_NCHW ||
filter_meta.format == Format::NCHW4_NCHW32 ||
filter_meta.format == Format::NCHW8 ||
filter_meta.format == Format::NCHW32) {
filter_meta.format == Format::NCHW32 ||
filter_meta.format == Format::NCHW32_NCHW4) {
// g, oc, ic, fh, fw
FS_SPATIAL = 1;
FS_IC = FH * FW;
......@@ -299,10 +316,39 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
return n * layout.stride[0] + (c >> 5) * layout.stride[1] +
h * layout.stride[2] + w * layout.stride[3] +
(c & 0x1F) * layout.stride[4];
} else if (filter_meta.format == Format::NCHW32_NCHW4) {
if (is_output) {
return n * layout.stride[0] + (c / 4) * layout.stride[1] +
h * layout.stride[2] + w * layout.stride[3] +
(c & 0b11) * layout.stride[4];
} else {
return n * layout.stride[0] + (c >> 5) * layout.stride[1] +
h * layout.stride[2] + w * layout.stride[3] +
(c & 0x1F) * layout.stride[4];
}
} else if (filter_meta.format == Format::CHWN4) {
return (c / 4) * layout.stride[0] + h * layout.stride[1] +
w * layout.stride[2] + n * layout.stride[3] +
(c % 4) * layout.stride[4];
} else if (filter_meta.format == Format::NCHW4_NCHW) {
if (is_output) {
return n * layout.stride[0] + c * layout.stride[1] +
h * layout.stride[2] + w * layout.stride[3];
} else {
return n * layout.stride[0] + (c / 4) * layout.stride[1] +
h * layout.stride[2] + w * layout.stride[3] +
(c & 0b11) * layout.stride[4];
}
} else if (filter_meta.format == Format::NCHW4_NCHW32) {
if (is_output) {
return n * layout.stride[0] + (c >> 5) * layout.stride[1] +
h * layout.stride[2] + w * layout.stride[3] +
(c & 0x1F) * layout.stride[4];
} else {
return n * layout.stride[0] + (c / 4) * layout.stride[1] +
h * layout.stride[2] + w * layout.stride[3] +
(c & 0b11) * layout.stride[4];
}
} else {
megdnn_assert(filter_meta.format == Format::NCHW4,
"invalid conv format");
......@@ -314,7 +360,9 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
auto get_filter_addr = [&](GroupCounter& gc_out, size_t ic, size_t ic0,
size_t fh, size_t fw) {
if (filter_meta.format == Format::NCHW4) {
if (filter_meta.format == Format::NCHW4 ||
filter_meta.format == Format::NCHW4_NCHW ||
filter_meta.format == Format::NCHW4_NCHW32) {
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
(ic - ic0) / 4 * FS_IC * 4 +
(fh * FW + fw) * FS_SPATIAL * 4 + ((ic - ic0) & 0b11);
......@@ -322,7 +370,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr,
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
(ic - ic0) / 8 * FS_IC * 8 +
(fh * FW + fw) * FS_SPATIAL * 8 + ((ic - ic0) & 0b111);
} else if (filter_meta.format == Format::NCHW32) {
} else if (filter_meta.format == Format::NCHW32 ||
filter_meta.format == Format::NCHW32_NCHW4) {
return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC +
(ic - ic0) / 32 * FS_IC * 32 +
(fh * FW + fw) * FS_SPATIAL * 32 + ((ic - ic0) & 0x1F);
......@@ -569,12 +618,16 @@ template <typename stype, typename ftype, typename dtype, typename comp_type>
void forward(_megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst,
const Convolution::CanonizedFilterMeta& filter_meta) {
megdnn_assert(filter_meta.spatial_ndim == 2);
megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW ||
filter_meta.format == param::Convolution::Format::NHWC ||
filter_meta.format == param::Convolution::Format::NCHW88 ||
filter_meta.format == param::Convolution::Format::NCHW44 ||
filter_meta.format == param::Convolution::Format::NCHW44_DOT ||
filter_meta.format == param::Convolution::Format::NCHW4);
megdnn_assert(
filter_meta.format == param::Convolution::Format::NCHW ||
filter_meta.format == param::Convolution::Format::NHWC ||
filter_meta.format == param::Convolution::Format::NCHW88 ||
filter_meta.format == param::Convolution::Format::NCHW44 ||
filter_meta.format == param::Convolution::Format::NCHW44_DOT ||
filter_meta.format == param::Convolution::Format::NCHW4 ||
filter_meta.format == param::Convolution::Format::NCHW4_NCHW ||
filter_meta.format == param::Convolution::Format::NCHW4_NCHW32 ||
filter_meta.format == param::Convolution::Format::NCHW32_NCHW4);
compute2d<stype, ftype, dtype, comp_type, StrategyFwd>(
src, const_cast<ftype*>(fptr), dst, filter_meta);
}
......@@ -631,8 +684,11 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
case param::Convolution::Format::NCHW44_DOT:
case param::Convolution::Format::NHWC:
case param::Convolution::Format::NCHW4:
case param::Convolution::Format::NCHW4_NCHW:
case param::Convolution::Format::NCHW4_NCHW32:
case param::Convolution::Format::NCHW8:
case param::Convolution::Format::NCHW32:
case param::Convolution::Format::NCHW32_NCHW4:
case param::Convolution::Format::CHWN4:
compute2d<stype, ftype, dtype, comp_type, StrategyFwd, FilterMeta,
FilterVisitor>(src, filter.compatible_ptr<ftype>(), dst,
......@@ -666,7 +722,8 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
using Format = param::ConvBias::Format;
switch (filter_meta.format) {
case Format::NCHW: {
case Format::NCHW:
case Format::NCHW4_NCHW: {
int dst_batch = dst.layout.shape[0];
int dst_channel = dst.layout.shape[1];
int chann_stride = dst.layout.shape[2] * dst.layout.shape[3];
......@@ -707,6 +764,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
} while (0)
case Format::NCHW44:
case Format::NCHW44_DOT:
case Format::NCHW32_NCHW4:
case Format::NCHW4: {
BIAS_ADD_NCHWx(4);
break;
......@@ -715,6 +773,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter,
BIAS_ADD_NCHWx(8);
break;
};
case Format::NCHW4_NCHW32:
case Format::NCHW32: {
BIAS_ADD_NCHWx(32);
break;
......
......@@ -429,6 +429,62 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_NCHW4) {
checker.exec({{1, 4, 2, 2, 4}, {16, 4, 3, 3, 4}, {1, 4, 1, 1, 4}, {}, {}});
}
//! Checks conv_bias with QuantizedS8 src/filter in NCHW4 layout producing a
//! Float32 dst in plain NCHW layout (tensor format NCHW4_NCHW).
TEST_F(CUDA, CONV_BIAS_FORWARD_NCHW4_NCHW) {
    // NOTE(review): 6.1 is presumably required for int8 dp4a support — confirm.
    require_compute_capability(6, 1);
    using namespace conv_bias;
    Checker<ConvBiasForward> checker(handle_cuda());
    // Small int range for the quantized src/filter; wider float range for the
    // float bias/z inputs.
    UniformIntRNG int_rng{-3, 3};
    UniformFloatRNG float_rng{-50, 50};
    ConvBias::Param param;
    param.format = ConvBias::Param::Format::NCHW4_NCHW;
    param.nonlineMode = ConvBias::Param::NonlineMode::IDENTITY;
    // Tensor slots 0..4 are src, filter, bias, z, dst: quantized inputs,
    // float bias/z/dst (the defining property of the NCHW4_NCHW format).
    checker.set_dtype(0, dtype::QuantizedS8(1.9980618f))
            .set_dtype(1, dtype::QuantizedS8(1.9980927f))
            .set_dtype(2, dtype::Float32())
            .set_dtype(3, dtype::Float32())
            .set_dtype(4, dtype::Float32())
            .set_rng(0, &int_rng)
            .set_rng(1, &int_rng)
            .set_rng(2, &float_rng)
            .set_rng(3, &float_rng)
            .set_param(param);
    auto opr = handle_cuda()->create_operator<ConvBias>();
    // Deduce the NCHW dst layout from src/filter shapes, then run the checker
    // with that explicit dst. Float32 is passed only to satisfy deduce_layout;
    // NOTE(review): only the shapes appear to matter here — confirm.
    auto run = [&](const TensorShapeArray& shapes) {
        opr->param() = param;
        TensorLayout dst_layout;
        opr->deduce_layout({shapes[0], dtype::Float32()},
                           {shapes[1], dtype::Float32()}, {}, {}, dst_layout);
        checker.execs({shapes[0], shapes[1], shapes[2], dst_layout, {}});
    };
    // Dense convolutions: src {N, C/4, H, W, 4}, filter {OC, IC/4, FH, FW, 4},
    // bias broadcast over channels in NCHW layout {1, OC, 1, 1}.
    run({{1, 4, 4, 4, 4}, {4, 4, 3, 3, 4}, {1, 4, 1, 1}});
    run({{20, 1, 24, 24, 4}, {24, 1, 2, 2, 4}, {1, 24, 1, 1}});
    run({{20, 2, 24, 24, 4}, {24, 2, 3, 3, 4}, {1, 24, 1, 1}});
    // Grouped convolutions with RELU: filter gains a leading group dim.
    param.sparse = ConvBias::Param::Sparse::GROUP;
    param.nonlineMode = ConvBias::Param::NonlineMode::RELU;
    checker.set_param(param);
    run({{1, 4, 24, 24, 4}, {4, 4, 1, 1, 1, 4}, {1, 16, 1, 1}});
    run({{20, 8, 24, 24, 4}, {4, 24, 2, 2, 2, 4}, {1, 96, 1, 1}});
    run({{1, 3, 24, 24, 4}, {3, 8, 1, 3, 3, 4}, {1, 24, 1, 1}});
    // Strided/padded grouped case.
    param.pad_h = param.pad_w = 1;
    param.stride_h = param.stride_w = 2;
    checker.set_param(param);
    run({{10, 16, 28, 28, 4}, {8, 8, 2, 3, 3, 4}, {1, 64, 1, 1}});
    // case which cudnn not supported
    // (H_SWISH on a float NCHW dst exercises the non-cuDNN fallback path.)
    param.sparse = ConvBias::Param::Sparse::DENSE;
    param.pad_h = param.pad_w = 1;
    param.stride_h = param.stride_w = 1;
    param.nonlineMode = ConvBias::Param::NonlineMode::H_SWISH;
    checker.set_param(param);
    checker.exec({{1, 4, 2, 2, 4}, {16, 4, 3, 3, 4}, {1, 16, 1, 1}, {}, {}});
}
#endif
TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册