From 1dc1f9270e7872c73fdfb93d3e8ff9d83e8286cf Mon Sep 17 00:00:00 2001
From: GaoWei8 <53294385+GaoWei8@users.noreply.github.com>
Date: Tue, 17 Mar 2020 14:57:11 +0800
Subject: [PATCH] Fix lod error of concat op for axis = 0 (#22538)

---
 paddle/fluid/operators/concat_op.h            | 30 ++++++++++++++-
 .../fluid/tests/unittests/test_concat_op.py   | 37 +++++++++++++++++++
 2 files changed, 65 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h
index 7c0fe3b635..cf0ae2ec8a 100644
--- a/paddle/fluid/operators/concat_op.h
+++ b/paddle/fluid/operators/concat_op.h
@@ -76,8 +76,8 @@ template <typename DeviceContext, typename T>
 class ConcatKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
-    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
+    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
+    framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
     PADDLE_ENFORCE_EQ(ins[0] != nullptr, true, "The input should not be null.");
     auto axis = ctx.Attr<int>("axis");
     bool need_resize_out_dims = false;
@@ -102,6 +102,32 @@ class ConcatKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     out->mutable_data<T>(place);
 
+    // If axis is 0, the lod of the output is not the same as inputs.
+    if (axis == 0 && ins[0]->lod().size() > 0) {
+      size_t lod_size_0 = ins[0]->lod().size();
+      size_t lod_size = lod_size_0;
+      for (size_t i = 1; i < ins.size(); ++i) {
+        if (ins[i]->lod().size() > 0) {
+          PADDLE_ENFORCE_EQ(
+              ins[i]->lod().size(), lod_size_0,
+              platform::errors::Unimplemented(
+                  "The lod level of all input LoDTensors should be same. "
+                  "Maybe different lod level of input LoDTensors can concat,"
+                  " it is not supported currently."));
+        } else {
+          lod_size = 0;
+          break;
+        }
+      }
+      if (lod_size) {
+        auto* out_lod = out->mutable_lod();
+        for (size_t i = 1; i < ins.size(); ++i) {
+          auto in_lod = ConvertToLengthBasedLoD(ins[i]->lod());
+          AppendLoD(out_lod, in_lod);
+        }
+      }
+    }
+
     // Sometimes direct copies will be faster, this maybe need deeply analysis.
     if (axis == 0 && ins.size() < 10) {
       size_t output_offset = 0;
diff --git a/python/paddle/fluid/tests/unittests/test_concat_op.py b/python/paddle/fluid/tests/unittests/test_concat_op.py
index 48fd0f56b1..b84608889e 100644
--- a/python/paddle/fluid/tests/unittests/test_concat_op.py
+++ b/python/paddle/fluid/tests/unittests/test_concat_op.py
@@ -100,6 +100,41 @@ class TestConcatOp5(TestConcatOp):
         self.axis = -3
 
 
+class TestConcatOp6(TestConcatOp):
+    def setUp(self):
+        self.op_type = "concat"
+        self.dtype = self.get_dtype()
+        self.init_test_data()
+        self.lod = [[20, 80]]
+        self.out_lod = [[20, 80, 20, 80, 20, 80]]
+        self.inputs = {
+            'X': [('x0', (self.x0, self.lod)), ('x1', (self.x1, self.lod)),
+                  ('x2', (self.x2, self.lod))]
+        }
+        self.attrs = {'axis': self.axis}
+        if self.axis < 0:
+            self.actual_axis = self.axis + len(self.x0.shape)
+            self.actual_axis = self.actual_axis if self.actual_axis > 0 else 0
+        else:
+            self.actual_axis = self.axis
+        out = np.concatenate((self.x0, self.x1, self.x2), axis=self.actual_axis)
+        self.outputs = {'Out': (out, self.out_lod)}
+
+    def test_check_output(self):
+        self.check_output(check_dygraph=False)
+
+    def test_check_grad(self):
+        self.check_grad(['x0'], 'Out', check_dygraph=False)
+        self.check_grad(['x1'], 'Out', check_dygraph=False)
+        self.check_grad(['x2'], 'Out', check_dygraph=False)
+
+    def init_test_data(self):
+        self.x0 = np.random.random([100]).astype(self.dtype)
+        self.x1 = np.random.random([100]).astype(self.dtype)
+        self.x2 = np.random.random([100]).astype(self.dtype)
+        self.axis = 0
+
+
 def create_test_AxisTensor(parent):
     class TestConcatAxisTensor(parent):
         def setUp(self):
@@ -134,6 +169,7 @@ create_test_AxisTensor(TestConcatOp2)
 create_test_AxisTensor(TestConcatOp3)
 create_test_AxisTensor(TestConcatOp4)
 create_test_AxisTensor(TestConcatOp5)
+create_test_AxisTensor(TestConcatOp6)
 
 #----------------Concat Fp16----------------
 
@@ -155,6 +191,7 @@ create_test_fp16(TestConcatOp2)
 create_test_fp16(TestConcatOp3)
 create_test_fp16(TestConcatOp4)
 create_test_fp16(TestConcatOp5)
+create_test_fp16(TestConcatOp6)
 
 
 class TestConcatOpError(unittest.TestCase):
-- 
GitLab
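Note on the LoD merge in the concat_op.h hunk above: a LoDTensor stores sequence
boundaries as cumulative offsets per level, so concatenating along axis 0 cannot
simply reuse any single input's LoD; each subsequent input's offsets have to be
converted to lengths and appended onto the output's offset table, which is what
the ConvertToLengthBasedLoD / AppendLoD pair does. The sketch below re-implements
that merge in plain Python purely for illustration. The two helper functions
mirror the C++ helpers in name and behavior but are assumptions of this note,
not Paddle APIs, and the numbers reuse the LoD from TestConcatOp6.

def convert_to_length_based(offset_lod):
    # Offset-based LoD [[0, 20, 100]] -> length-based LoD [[20, 80]].
    return [[level[i + 1] - level[i] for i in range(len(level) - 1)]
            for level in offset_lod]

def append_lod(offset_lod, length_lod):
    # Append per-level lengths onto an offset-based LoD, as AppendLoD does.
    for level, lengths in zip(offset_lod, length_lod):
        if not level:
            level.append(0)
        for n in lengths:
            level.append(level[-1] + n)

# Three inputs, each holding two sequences of lengths 20 and 80 -- the same
# shape TestConcatOp6 uses (lod = [[20, 80]], i.e. offsets [[0, 20, 100]]).
in_lods = [[[0, 20, 100]], [[0, 20, 100]], [[0, 20, 100]]]

# The output starts from input 0's LoD; inputs 1..n are appended in length form,
# matching the i = 1 starting index of the C++ loops.
out_lod = [list(level) for level in in_lods[0]]
for lod in in_lods[1:]:
    append_lod(out_lod, convert_to_length_based(lod))

print(out_lod)  # [[0, 20, 100, 120, 200, 220, 300]]
# In length form that is [[20, 80, 20, 80, 20, 80]] -- the test's out_lod.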
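One design detail worth flagging: the kernel's loops start at i = 1, which
presumably relies on the output tensor already sharing the first input's LoD
from shape inference rather than setting it explicitly. And if any input
carries no LoD at all, the lod_size = 0 branch skips the merge entirely, so
the output keeps whatever LoD it already had.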