提交 44376f70 编写于 作者: M Megvii Engine Team 提交者: huangxinda

refactor(mgb): make conv-backward-data handle noncontiguous tensors

GitOrigin-RevId: 0a8f66f9d378b6466bc383a94c57ec80bcc5cb74
上级 7b2a76d1
......@@ -23,6 +23,14 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
if (args.z_layout->ndim > 0)
return false;
if (args.filter_meta.format != Param::Format::NCHW &&
args.filter_meta.format != Param::Format::NHWC) {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
}
auto dst_layout = *args.dst_layout;
if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
dst_layout.dtype = DType();
......
......@@ -24,9 +24,12 @@ using namespace conv_bias;
bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
if (args.filter_meta.format != Param::Format::NCHW &&
args.filter_meta.format != Param::Format::NHWC) {
if (!args.src_layout->is_contiguous() ||
!args.dst_layout->is_contiguous()) {
return false;
}
}
if ((args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
......
......@@ -82,8 +82,8 @@ ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack;
ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
ConvolutionBackwardDataImpl* o, const TensorLayout& filter,
const TensorLayout& diff, const TensorLayout& grad)
: SizeArgs(o, filter, o->check_layout_fwd(grad, filter, diff), diff,
grad) {}
: SizeArgs(o, filter, o->make_canonized_filter_meta(grad.ndim, filter),
diff, grad) {}
ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
ConvolutionBackwardDataImpl* o, const TensorLayout& filter,
......
......@@ -21,6 +21,14 @@ using namespace convolution;
bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
const SizeArgs &args) const {
if (args.filter_meta.format != Param::Format::NCHW &&
args.filter_meta.format != Param::Format::NHWC) {
if (!args.grad_layout->is_contiguous() ||
!args.diff_layout->is_contiguous()) {
return false;
}
}
CUDNNBwdDataDescs D;
if (!is_cudnn_supported(args.as_fwd_args()))
......
......@@ -25,6 +25,11 @@ bool ConvolutionBackwardDataImpl::AlgoInt8NCHW4DotProdImplicitGemm::
if (fm.format != Param::Format::NCHW4)
return false;
if (!args.grad_layout->is_contiguous() ||
!args.diff_layout->is_contiguous()) {
return false;
}
bool available = true;
auto src_dtype = args.diff_layout->dtype,
......
......@@ -25,6 +25,11 @@ bool ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::
if (fm.format != Param::Format::NCHW)
return false;
if (!args.grad_layout->is_contiguous() ||
!args.diff_layout->is_contiguous()) {
return false;
}
bool available = true;
auto src_dtype = args.diff_layout->dtype,
......
......@@ -64,8 +64,8 @@ ConvolutionBackwardDataImpl::AlgoMatmul::get_subopr_list(
const TensorLayoutArray& layouts, const OperatorBase* opr) const {
const ConvolutionBackwardDataImpl* conv_backward_data_opr =
static_cast<const ConvolutionBackwardDataImpl*>(opr);
CanonizedFilterMeta fm = conv_backward_data_opr->check_layout_fwd(
layouts[2], layouts[0], layouts[1]);
CanonizedFilterMeta fm = conv_backward_data_opr->make_canonized_filter_meta(
layouts[2].ndim, layouts[0]);
auto&& config = sub_opr_config(fm, layouts[0], layouts[1], layouts[2],
conv_backward_data_opr);
......
......@@ -661,7 +661,6 @@ template <typename ftype, typename dtype, typename gtype>
void backward_data(_megdnn_tensor_in filter, _megdnn_tensor_in diff,
_megdnn_tensor_out grad,
const Convolution::CanonizedFilterMeta& filter_meta) {
megdnn_assert(grad.layout.is_contiguous());
memset(grad.raw_ptr, 0, grad.layout.span().dist_byte());
megdnn_assert(filter_meta.spatial_ndim == 2);
if (filter_meta.format == param::Convolution::Format::NHWCD4) {
......@@ -676,7 +675,6 @@ template <typename stype, typename dtype, typename gtype>
void backward_filter(_megdnn_tensor_in src, _megdnn_tensor_in diff,
_megdnn_tensor_out grad,
const Convolution::CanonizedFilterMeta& filter_meta) {
megdnn_assert(grad.layout.is_contiguous());
memset(grad.raw_ptr, 0, grad.layout.span().dist_byte());
megdnn_assert(filter_meta.spatial_ndim == 2);
compute2d<stype, gtype, dtype, dtype, StrategyBwdFlt>(
......
......@@ -238,6 +238,25 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA) {
}
}
TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_CUDNN) {
if (cuda::is_compute_capability_required(7, 0))
return;
using namespace convolution;
Checker<ConvolutionBackwardData> checker(handle_cuda());
checker.set_before_exec_callback(AlgoChecker<ConvolutionBackwardData>(
"CUDNN_CONVOLUTION"));
//! noncontiguous case
{
param::Convolution param;
param.pad_h = param.pad_w = 1;
checker.set_param(param).execl(TensorLayoutArray{
{{16, 16, 3, 3}, {144, 9, 3, 1}, dtype::Float32()},
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
});
}
}
TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_MATMUL) {
using namespace convolution;
std::vector<TestArg> args = get_args_cuda_conv_bwd_data();
......@@ -265,6 +284,16 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_MATMUL) {
.set_param(arg.param)
.exec(TensorLayoutArray{filter, dst, src});
}
//! noncontiguous case
{
param::Convolution param;
param.pad_h = param.pad_w = 1;
checker.set_param(param).execl(TensorLayoutArray{
{{16, 16, 3, 3}, {144, 9, 3, 1}, dtype::Float32()},
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::Float32()},
});
}
}
TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NCHW4_DP4A) {
......@@ -355,6 +384,16 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_INT8_NCHW_DP4A) {
}
checker.set_rng(0, &rng).set_rng(1, &rng).set_param(arg.param).exec(
TensorLayoutArray{filter, dst, src});
//! noncontiguous case
{
param::Convolution param;
param.pad_h = param.pad_w = 1;
checker.set_param(param).execl(TensorLayoutArray{
{{16, 16, 3, 3}, {144, 9, 3, 1}, dtype::QuantizedS8{1.3f}},
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::QuantizedS8{1.2f}},
{{2, 16, 7, 7}, {1568, 49, 7, 1}, dtype::QuantizedS8{1.2f}}
});
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册