diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc
index 0d0d92f42a1a80e022fc1fef15c71797637e9a46..7f249924f5b9a1092af725f2f9271ac3cdbd26f3 100644
--- a/paddle/fluid/operators/concat_op.cc
+++ b/paddle/fluid/operators/concat_op.cc
@@ -36,7 +36,10 @@ class ConcatOp : public framework::OperatorWithKernel {
                    "Output(Out) of ConcatOp should not be null.");
 
     auto ins = ctx->GetInputsDim("X");
-    size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
+    size_t axis =
+        ComputeAxis(static_cast<int64_t>(ctx->Attrs().Get<int>("axis")),
+                    static_cast<int64_t>(ins[0].size()));
+
     const size_t n = ins.size();
 
     PADDLE_ENFORCE_GT(n, 0, "Input tensors count should > 0.");
@@ -115,7 +118,10 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
         "(bool, default false) Indicates if MKL-DNN kernel will be used")
         .SetDefault(false);
     AddAttr<int>("axis",
-                 "The axis along which the input tensors will be concatenated.")
+                 "The axis along which the input tensors will be concatenated."
+                 "The axis could also be negative numbers. Negative axis is "
+                 "interpreted as counting from the end of the rank."
+                 "i.e., axis + rank(X) th dimension.")
         .SetDefault(0);
     AddAttr<bool>("use_quantizer",
                   "(bool, default false) "
diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h
index 0414550dd18f7818ff922dfd5113ede763299185..4a371de32354d196492a54dce47bf73bf644bad1 100644
--- a/paddle/fluid/operators/concat_op.h
+++ b/paddle/fluid/operators/concat_op.h
@@ -23,13 +23,22 @@ limitations under the License. */
 
 namespace paddle {
 namespace operators {
 
+static inline int64_t ComputeAxis(int64_t axis, int64_t rank) {
+  if (axis < 0) {
+    axis = axis + rank;
+  }
+  return axis > 0 ? axis : 0;
+}
+
 template <typename DeviceContext, typename T>
 class ConcatKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
     framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
-    int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
+    PADDLE_ENFORCE(ins[0], "The input should not be null.");
+    auto axis = ComputeAxis(static_cast<int64_t>(ctx.Attr<int>("axis")),
+                            static_cast<int64_t>(ins[0]->dims().size()));
     auto place = ctx.GetPlace();
     out->mutable_data<T>(place);
@@ -83,8 +92,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
         }
       }
     }
-
-    int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
+    PADDLE_ENFORCE(ins[0], "The input should not be null.");
+    auto axis = ComputeAxis(static_cast<int64_t>(ctx.Attr<int>("axis")),
+                            static_cast<int64_t>(ins[0]->dims().size()));
 
     // get output tensor that the name is not kEmptyVarName
     std::vector<framework::Tensor*> outputs;
diff --git a/python/paddle/fluid/tests/unittests/test_concat_op.py b/python/paddle/fluid/tests/unittests/test_concat_op.py
index 42276a0647d95173d064bd1609ce743d7933ab79..b5d1115723e350f56e0d3e04d191886e43a15667 100644
--- a/python/paddle/fluid/tests/unittests/test_concat_op.py
+++ b/python/paddle/fluid/tests/unittests/test_concat_op.py
@@ -25,9 +25,15 @@ class TestConcatOp(OpTest):
         self.init_test_data()
         self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
         self.attrs = {'axis': self.axis}
+        if self.axis < 0:
+            self.actual_axis = self.axis + len(self.x0.shape)
+            self.actual_axis = self.actual_axis if self.actual_axis > 0 else 0
+        else:
+            self.actual_axis = self.axis
+
         self.outputs = {
             'Out': np.concatenate(
-                (self.x0, self.x1, self.x2), axis=self.axis)
+                (self.x0, self.x1, self.x2), axis=self.actual_axis)
         }
 
     def test_check_output(self):
@@ -75,5 +81,13 @@ class TestConcatOp4(TestConcatOp):
         pass
 
 
+class TestConcatOp5(TestConcatOp):
+    def init_test_data(self):
+        self.x0 = np.random.random((2, 1, 4, 5)).astype('float32')
+        self.x1 = np.random.random((2, 2, 4, 5)).astype('float32')
+        self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.axis = -3
+
+
 if __name__ == '__main__':
     unittest.main()