Commit 44376f70 authored by Megvii Engine Team, committed by huangxinda

refactor(mgb): make conv-backward-data handle noncontiguous tensors

GitOrigin-RevId: 0a8f66f9d378b6466bc383a94c57ec80bcc5cb74
Parent 7b2a76d1
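For reference, the new test cases below construct noncontiguous diff/grad layouts by giving the {2, 16, 7, 7} tensors a batch stride of 1568 instead of the dense value 16 * 7 * 7 = 784, while the {16, 16, 3, 3} filter keeps dense strides {144, 9, 3, 1}. The following is a minimal standalone sketch of such a dense-stride check (an illustration only, not MegDNN's actual TensorLayout::is_contiguous implementation; the is_contiguous helper here is hypothetical):

#include <cstdio>
#include <vector>

// Simplified dense-layout check: a layout is contiguous when the innermost
// stride is 1 and stride[i] == shape[i + 1] * stride[i + 1] for every axis.
static bool is_contiguous(const std::vector<size_t>& shape,
                          const std::vector<size_t>& stride) {
    size_t expected = 1;
    for (size_t i = shape.size(); i-- > 0;) {
        if (stride[i] != expected)
            return false;
        expected *= shape[i];
    }
    return true;
}

int main() {
    // Filter layout from the tests: dense strides, so it is contiguous.
    std::printf("%d\n", is_contiguous({16, 16, 3, 3}, {144, 9, 3, 1}));   // 1
    // diff/grad layout: batch stride 1568 instead of 784, i.e. a strided view,
    // which the refactored algorithms must now either handle or reject.
    std::printf("%d\n", is_contiguous({2, 16, 7, 7}, {1568, 49, 7, 1}));  // 0
}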
@@ -23,6 +23,14 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
     if (args.z_layout->ndim > 0)
         return false;
+    if (args.filter_meta.format != Param::Format::NCHW &&
+        args.filter_meta.format != Param::Format::NHWC) {
+        if (!args.src_layout->is_contiguous() ||
+            !args.dst_layout->is_contiguous()) {
+            return false;
+        }
+    }
     auto dst_layout = *args.dst_layout;
     if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
         dst_layout.dtype = DType();
...
@@ -24,9 +24,12 @@ using namespace conv_bias;
 bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
         const SizeArgs& args) const {
-    if (!args.src_layout->is_contiguous() ||
-        !args.dst_layout->is_contiguous()) {
-        return false;
+    if (args.filter_meta.format != Param::Format::NCHW &&
+        args.filter_meta.format != Param::Format::NHWC) {
+        if (!args.src_layout->is_contiguous() ||
+            !args.dst_layout->is_contiguous()) {
+            return false;
+        }
     }
     if ((args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
          args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
...
@@ -82,8 +82,8 @@ ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack;
 ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
         ConvolutionBackwardDataImpl* o, const TensorLayout& filter,
         const TensorLayout& diff, const TensorLayout& grad)
-        : SizeArgs(o, filter, o->check_layout_fwd(grad, filter, diff), diff,
-                   grad) {}
+        : SizeArgs(o, filter, o->make_canonized_filter_meta(grad.ndim, filter),
+                   diff, grad) {}
 
 ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
         ConvolutionBackwardDataImpl* o, const TensorLayout& filter,
...
@@ -21,6 +21,14 @@ using namespace convolution;
 bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
         const SizeArgs &args) const {
+    if (args.filter_meta.format != Param::Format::NCHW &&
+        args.filter_meta.format != Param::Format::NHWC) {
+        if (!args.grad_layout->is_contiguous() ||
+            !args.diff_layout->is_contiguous()) {
+            return false;
+        }
+    }
     CUDNNBwdDataDescs D;
     if (!is_cudnn_supported(args.as_fwd_args()))
...
@@ -25,6 +25,11 @@ bool ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::
     if (fm.format != Param::Format::NCHW4)
         return false;
+    if (!args.grad_layout->is_contiguous() ||
+        !args.diff_layout->is_contiguous()) {
+        return false;
+    }
     bool available = true;
     auto src_dtype = args.diff_layout->dtype,
...
@@ -25,6 +25,11 @@ bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::
     if (fm.format != Param::Format::NCHW)
         return false;
+    if (!args.grad_layout->is_contiguous() ||
+        !args.diff_layout->is_contiguous()) {
+        return false;
+    }
     bool available = true;
     auto src_dtype = args.diff_layout->dtype,
...
@@ -64,8 +64,8 @@ ConvolutionBackwardDataImpl::AlgoMatmul::get_subopr_list(
         const TensorLayoutArray& layouts, const OperatorBase* opr) const {
     const ConvolutionBackwardDataImpl* conv_backward_data_opr =
             static_cast<const ConvolutionBackwardDataImpl*>(opr);
-    CanonizedFilterMeta fm = conv_backward_data_opr->check_layout_fwd(
-            layouts[2], layouts[0], layouts[1]);
+    CanonizedFilterMeta fm = conv_backward_data_opr->make_canonized_filter_meta(
+            layouts[2].ndim, layouts[0]);
     auto&& config = sub_opr_config(fm, layouts[0], layouts[1], layouts[2],
                                    conv_backward_data_opr);
...
@@ -661,7 +661,6 @@ template <typename ftype, typename dtype, typename gtype>
 void backward_data(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
                    _megdnn_tensor_out grad,
                    const Convolution::CanonizedFilterMeta& filter_meta) {
-    megdnn_assert(grad.layout.is_contiguous());
     memset(grad.raw_ptr, 0, grad.layout.span().dist_byte());
     megdnn_assert(filter_meta.spatial_ndim == 2);
     if (filter_meta.format == param::Convolution::Format::NHWCD4) {
...
@@ -676,7 +675,6 @@ template <typename stype, typename dtype, typename gtype>
 void backward_filter(_megdnn_tensor_in src, _megdnn_tensor_in diff,
                      _megdnn_tensor_out grad,
                      const Convolution::CanonizedFilterMeta& filter_meta) {
-    megdnn_assert(grad.layout.is_contiguous());
     memset(grad.raw_ptr, 0, grad.layout.span().dist_byte());
     megdnn_assert(filter_meta.spatial_ndim == 2);
     compute2d<stype, gtype, dtype, dtype, StrategyBwdFlt>(
...
@@ -238,6 +238,25 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA) {
     }
 }
 
+TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_CUDNN) {
+    if (cuda::is_compute_capability_required(7, 0))
+        return;
+    using namespace convolution;
+    Checker<ConvolutionBackwardData> checker(handle_cuda());
+    checker.set_before_exec_callback(AlgoChecker<ConvolutionBackwardData>(
+            "CUDNN_CONVOLUTION"));
+    //! noncontiguous case
+    {
+        param::Convolution param;
+        param.pad_h = param.pad_w = 1;
+        checker.set_param(param).execl(TensorLayoutArray{
+                {{16, 16, 3, 3}, {144, 9, 3, 1}, dtype::Float32()},
+                {{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
+                {{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
+        });
+    }
+}
+
 TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_MATMUL) {
     using namespace convolution;
     std::vector<TestArg> args = get_args_cuda_conv_bwd_data();
...
@@ -265,6 +284,16 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_MATMUL) {
                 .set_param(arg.param)
                 .exec(TensorLayoutArray{filter, dst, src});
     }
+    //! noncontiguous case
+    {
+        param::Convolution param;
+        param.pad_h = param.pad_w = 1;
+        checker.set_param(param).execl(TensorLayoutArray{
+                {{16, 16, 3, 3}, {144, 9, 3, 1}, dtype::Float32()},
+                {{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
+                {{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
+        });
+    }
 }
 
 TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NCHW4_DP4A) {
...
@@ -355,6 +384,16 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NCHW_DP4A) {
         }
         checker.set_rng(0, &rng).set_rng(1, &rng).set_param(arg.param).exec(
                 TensorLayoutArray{filter, dst, src});
+        //! noncontiguous case
+        {
+            param::Convolution param;
+            param.pad_h = param.pad_w = 1;
+            checker.set_param(param).execl(TensorLayoutArray{
+                    {{16, 16, 3, 3}, {144, 9, 3, 1}, dtype::QuantizedS8{1.3f}},
+                    {{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::QuantizedS8{1.2f}},
+                    {{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::QuantizedS8{1.2f}}
+            });
+        }
     }
 }
...