From a72907bbf45dc78bebe6050714d31b3386b3d32e Mon Sep 17 00:00:00 2001 From: jerrywgz Date: Sun, 5 May 2019 15:44:49 +0800 Subject: [PATCH] Enhance concat op to support empty input. (#17015) * enhance_concat, test=develop --- paddle/fluid/operators/concat_op.h | 14 +++++++++++--- .../paddle/fluid/tests/unittests/test_concat_op.py | 11 +++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h index bd474be0fac..0414550dd18 100644 --- a/paddle/fluid/operators/concat_op.h +++ b/paddle/fluid/operators/concat_op.h @@ -37,6 +37,9 @@ class ConcatKernel : public framework::OpKernel { if (axis == 0 && ins.size() < 10) { size_t output_offset = 0; for (auto* in : ins) { + if (!in || in->numel() == 0UL) { + continue; + } auto in_stride = framework::stride_numel(in->dims()); auto out_stride = framework::stride_numel(out->dims()); StridedNumelCopyWithAxis(ctx.device_context(), axis, @@ -45,9 +48,13 @@ class ConcatKernel : public framework::OpKernel { output_offset += in_stride[axis]; } } else { - std::vector inputs(ins.size()); + std::vector inputs; for (size_t j = 0; j < ins.size(); ++j) { - inputs[j] = *ins[j]; + if (ins[j] && ins[j]->numel() > 0) { + inputs.push_back(*ins[j]); + } else { + continue; + } } auto& dev_ctx = ctx.template device_context(); paddle::operators::math::ConcatFunctor concat_functor; @@ -82,7 +89,8 @@ class ConcatGradKernel : public framework::OpKernel { // get output tensor that the name is not kEmptyVarName std::vector outputs; for (size_t j = 0; j < outs.size(); ++j) { - if (out_var_names[j] != framework::kEmptyVarName) { + if (out_var_names[j] != framework::kEmptyVarName && + outs[j]->numel() != 0UL) { outs[j]->mutable_data(ctx.GetPlace()); outputs.push_back(outs[j]); } else { diff --git a/python/paddle/fluid/tests/unittests/test_concat_op.py b/python/paddle/fluid/tests/unittests/test_concat_op.py index 436ab7d49f4..42276a0647d 100644 --- a/python/paddle/fluid/tests/unittests/test_concat_op.py +++ b/python/paddle/fluid/tests/unittests/test_concat_op.py @@ -64,5 +64,16 @@ class TestConcatOp3(TestConcatOp): pass +class TestConcatOp4(TestConcatOp): + def init_test_data(self): + self.x0 = np.random.random((2, 3, 4, 5)).astype('float32') + self.x1 = np.random.random((2, 3, 4, 5)).astype('float32') + self.x2 = np.random.random((0, 3, 4, 5)).astype('float32') + self.axis = 0 + + def test_check_grad(self): + pass + + if __name__ == '__main__': unittest.main() -- GitLab