From 666c3bb9b07746ecb03ad86b637ad27214ec072c Mon Sep 17 00:00:00 2001
From: Wojciech Uss
Date: Thu, 19 Dec 2019 09:54:55 +0100
Subject: [PATCH] handle multi-inputs with empty inputs for mkldnn_concat_op
 (#21827)

test=develop
---
 .../operators/mkldnn/concat_mkldnn_op.cc      | 12 +++++++++-
 paddle/fluid/operators/roi_align_op.h         |  2 --
 .../unittests/mkldnn/test_concat_mkldnn_op.py | 23 ++++++++++++++++---
 3 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
index adad812b838..91e9d7bbafe 100644
--- a/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc
@@ -59,6 +59,16 @@ static const mkldnn::engine& GetMKLDNNEngine(
   return dev_ctx.GetEngine();
 }
 
+// From a multi-input, gather only nonempty inputs
+static const std::vector<const Tensor*> ReduceMultiInput(
+    const std::vector<const Tensor*>& inputs) {
+  std::vector<const Tensor*> reduced(inputs.size());
+  auto end_it = std::copy_if(inputs.begin(), inputs.end(), reduced.begin(),
+                             [](const Tensor* t) { return t->numel() > 0; });
+  reduced.resize(std::distance(reduced.begin(), end_it));
+  return reduced;
+}
+
 template <typename T>
 class ConcatPrimitiveFactory {
  public:
@@ -120,7 +130,7 @@ template <typename T>
 class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    auto multi_input = ctx.MultiInput<Tensor>("X");
+    auto multi_input = ReduceMultiInput(ctx.MultiInput<Tensor>("X"));
     EnforceLayouts(multi_input);
     Tensor* output = ctx.Output<Tensor>("Out");
     int concat_axis = ctx.Attr<int>("axis");
diff --git a/paddle/fluid/operators/roi_align_op.h b/paddle/fluid/operators/roi_align_op.h
index 4ed680f0c38..8c6b7cfe5d1 100644
--- a/paddle/fluid/operators/roi_align_op.h
+++ b/paddle/fluid/operators/roi_align_op.h
@@ -154,8 +154,6 @@ class CPUROIAlignOpKernel : public framework::OpKernel<T> {
     int width = in_dims[3];
     int rois_num = rois->dims()[0];
 
-    if (rois_num == 0) return;
-
     auto in_stride = framework::stride(in_dims);
     auto roi_stride = framework::stride(rois->dims());
     auto out_stride = framework::stride(out->dims());
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py
index 70897ed53e0..4f3dece5be3 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py
@@ -15,7 +15,7 @@
 from __future__ import print_function
 
 import unittest
-from paddle.fluid.tests.unittests.test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3
+from paddle.fluid.tests.unittests.test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3, TestConcatOp4
 
 
 class TestMKLDNNConcatOp(TestConcatOp):
@@ -42,7 +42,7 @@ class TestMKLDNNConcatOp2(TestConcatOp2):
         self._cpu_only = True
 
     def test_check_output(self):
-        # TODO(wangzhongpu): support mkldnn op in dygraph mode 
+        # TODO(wangzhongpu): support mkldnn op in dygraph mode
         self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
 
     def test_check_grad(self):
@@ -59,7 +59,24 @@ class TestMKLDNNConcatOp3(TestConcatOp3):
         self._cpu_only = True
 
     def test_check_output(self):
-        # TODO(wangzhongpu): support mkldnn op in dygraph mode 
+        # TODO(wangzhongpu): support mkldnn op in dygraph mode
+        self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
+
+    def test_check_grad(self):
+        pass
+
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+
+
+class TestMKLDNNConcatOp4(TestConcatOp4):
+    def setUp(self):
+        super(TestMKLDNNConcatOp4, self).setUp()
+        self.attrs["use_mkldnn"] = True
+        self._cpu_only = True
+
+    def test_check_output(self):
+        # TODO(wangzhongpu): support mkldnn op in dygraph mode
         self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
 
     def test_check_grad(self):
--
GitLab
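
For orientation, a minimal NumPy sketch of the empty-input case that the new
TestMKLDNNConcatOp4 exercises and that ReduceMultiInput handles on the C++
side; the shapes below are illustrative assumptions, not copied from
test_concat_op.TestConcatOp4.

import numpy as np

# Illustrative shapes (assumed): the middle input is empty along the concat axis.
x0 = np.random.random((2, 3, 4, 5)).astype('float32')
x1 = np.random.random((0, 3, 4, 5)).astype('float32')  # numel() == 0
x2 = np.random.random((3, 3, 4, 5)).astype('float32')

# Reference result the MKL-DNN concat kernel must reproduce: the empty
# input contributes nothing along axis 0.
ref = np.concatenate((x0, x1, x2), axis=0)
print(ref.shape)  # (5, 3, 4, 5)

# ReduceMultiInput() does the analogous filtering in the kernel: inputs with
# numel() == 0 are dropped before the MKL-DNN concat primitive is built.
kept = [x for x in (x0, x1, x2) if x.size > 0]
assert np.array_equal(np.concatenate(kept, axis=0), ref)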