diff --git a/paddle/operators/transpose_op.cc b/paddle/operators/transpose_op.cc index 2fd86d900a0e27aa05714fe6a2f97a722c6bf710..107d80cde7a3e956ee35cf8373f774c565adc8e5 100644 --- a/paddle/operators/transpose_op.cc +++ b/paddle/operators/transpose_op.cc @@ -27,26 +27,29 @@ class TransposeOp : public framework::OperatorWithKernel { void InferShape(const framework::InferShapeContext &ctx) const override { PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"), "Input(Input) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"), + "Output(Output) should not be null"); auto input_dim = ctx.Input("Input")->dims(); - auto axis = ctx.Attr>("axis"); - size_t input_dim_size = input_dim.size(); + std::vector axis = ctx.Attr>("axis"); + size_t input_rank = input_dim.size(); size_t axis_size = axis.size(); - PADDLE_ENFORCE_EQ(input_dim_size, axis_size, - "the input tensor's dimension(%d) " + PADDLE_ENFORCE_EQ(input_rank, axis_size, + "the input tensor's rank(%d) " "should be equal to the axis's size(%d)", - input_dim_size, axis_size); - - std::vector axis_sorted(axis); - std::sort(axis_sorted.begin(), axis_sorted.end()); - for (size_t i = 0; i < axis_sorted.size(); i++) { - PADDLE_ENFORCE_EQ(axis_sorted[i], static_cast(i), - "the sorted axis should be [0, 1, ... dims - 1], " - "where the dims is the axis's size"); + input_rank, axis_size); + + std::vector count(axis_size, 0); + for (size_t i = 0; i < axis_size; i++) { + PADDLE_ENFORCE( + axis[i] < static_cast(axis_size) && ++count[axis[i]] == 1, + "Each element of Attribute axis should be a unique value " + "range from 0 to (dims - 1), " + "where the dims is the axis's size"); } framework::DDim output_dim(input_dim); - for (size_t i = 0; i < axis.size(); i++) { + for (size_t i = 0; i < axis_size; i++) { output_dim[i] = input_dim[axis[i]]; } ctx.Output("Output")->Resize(output_dim); @@ -60,12 +63,12 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( "Input", - "(Tensor)The input tensor, tensors with rank at most 7 are supported"); + "(Tensor)The input tensor, tensors with rank at most 6 are supported"); AddOutput("Output", "(Tensor)The output tensor"); AddAttr>( "axis", "(vector)a list of values, and the size of the list should be " - "the same with the input tensor dimensions, the tensor will " + "the same with the input tensor rank, the tensor will " "permute the axes according the the values given"); AddComment(R"DOC( The Tensor will be permuted according to the axis values given. @@ -97,18 +100,11 @@ class TransposeOpGrad : public framework::OperatorWithKernel { "Input(Input) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Output")), "Input(Output@GRAD) should not be null"); - auto input_dims = ctx.Input("Input")->dims(); + auto input_dim = ctx.Input("Input")->dims(); auto *input_grad = ctx.Output(framework::GradVarName("Input")); - auto output_grad_dims = - ctx.Input(framework::GradVarName("Output"))->dims(); - auto output_dims = ctx.Input("Output")->dims(); - - PADDLE_ENFORCE(output_grad_dims == output_dims, - "Output@GRAD dims must equal to Input(Input) dims"); - - input_grad->Resize(input_dims); + if (input_grad) input_grad->Resize(input_dim); } }; diff --git a/paddle/operators/transpose_op.h b/paddle/operators/transpose_op.h index 48d8c250a8c77c80d36626517dad3684bc59c86c..731b6a77016ca8d033b78e0b708cfb0ec5feb5ee 100644 --- a/paddle/operators/transpose_op.h +++ b/paddle/operators/transpose_op.h @@ -20,19 +20,19 @@ namespace paddle { namespace operators { -template +template void EigenTranspose(const framework::ExecutionContext& context, const framework::Tensor& in, framework::Tensor& out, std::vector axis) { - Eigen::array permute; - for (int i = 0; i < Dims; i++) { + Eigen::array permute; + for (int i = 0; i < Rank; i++) { permute[i] = axis[i]; } auto in_dim = in.dims(); auto out_dim = out.dims(); - auto eigen_in = framework::EigenTensor::From(in); - auto eigen_out = framework::EigenTensor::From(out); + auto eigen_in = framework::EigenTensor::From(in); + auto eigen_out = framework::EigenTensor::From(out); auto& dev = context.GetEigenDevice(); eigen_out.device(dev) = eigen_in.shuffle(permute); } @@ -45,10 +45,11 @@ class TransposeKernel : public framework::OpKernel { auto* output = context.Output("Output"); output->mutable_data(context.GetPlace()); - auto axis = context.Attr>("axis"); + std::vector axis = context.Attr>("axis"); int ndims = axis.size(); switch (ndims) { case 1: + EigenTranspose(context, *input, *output, axis); break; case 2: EigenTranspose(context, *input, *output, axis); @@ -79,37 +80,46 @@ class TransposeGradKernel : public framework::OpKernel { context.Input(framework::GradVarName("Output")); auto* input_grad = context.Output(framework::GradVarName("Input")); - input_grad->mutable_data(context.GetPlace()); - - auto axis_temp = context.Attr>("axis"); - std::vector axis(axis_temp); - - for (size_t i = 0; i < axis.size(); i++) { - axis[axis_temp[i]] = i; - } - - int ndims = axis.size(); - - switch (ndims) { - case 1: - break; - case 2: - EigenTranspose(context, *output_grad, *input_grad, axis); - break; - case 3: - EigenTranspose(context, *output_grad, *input_grad, axis); - break; - case 4: - EigenTranspose(context, *output_grad, *input_grad, axis); - break; - case 5: - EigenTranspose(context, *output_grad, *input_grad, axis); - break; - case 6: - EigenTranspose(context, *output_grad, *input_grad, axis); - break; - default: - PADDLE_THROW("Tensors with rank at most 6 are supported"); + if (input_grad) { + input_grad->mutable_data(context.GetPlace()); + + std::vector axis = context.Attr>("axis"); + std::vector reversed_axis(axis); + + for (size_t i = 0; i < axis.size(); i++) { + reversed_axis[axis[i]] = i; + } + + int ndims = axis.size(); + + switch (ndims) { + case 1: + EigenTranspose(context, *output_grad, *input_grad, + reversed_axis); + break; + case 2: + EigenTranspose(context, *output_grad, *input_grad, + reversed_axis); + break; + case 3: + EigenTranspose(context, *output_grad, *input_grad, + reversed_axis); + break; + case 4: + EigenTranspose(context, *output_grad, *input_grad, + reversed_axis); + break; + case 5: + EigenTranspose(context, *output_grad, *input_grad, + reversed_axis); + break; + case 6: + EigenTranspose(context, *output_grad, *input_grad, + reversed_axis); + break; + default: + PADDLE_THROW("Tensors with rank at most 6 are supported"); + } } } }; diff --git a/python/paddle/v2/framework/tests/test_transpose_op.py b/python/paddle/v2/framework/tests/test_transpose_op.py index 8e7e12910d44055b90ab955b4b2df61eac38c801..373a988f5f8c2a01d481b3149d9a00e7d536078b 100644 --- a/python/paddle/v2/framework/tests/test_transpose_op.py +++ b/python/paddle/v2/framework/tests/test_transpose_op.py @@ -22,6 +22,12 @@ class TestTransposeOp(OpTest): self.axis = (1, 0) +class TestCase0(TestTransposeOp): + def initTestCase(self): + self.shape = (3, ) + self.axis = (0, ) + + class TestCase1(TestTransposeOp): def initTestCase(self): self.shape = (3, 4, 5)