diff --git a/paddle/operators/transpose_op.cc b/paddle/operators/transpose_op.cc index 107d80cde7a3e956ee35cf8373f774c565adc8e5..babf2f561c31d5436fe1611c576e6e7fc04401db 100644 --- a/paddle/operators/transpose_op.cc +++ b/paddle/operators/transpose_op.cc @@ -25,19 +25,18 @@ class TransposeOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"), - "Input(Input) should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"), - "Output(Output) should not be null"); - auto input_dim = ctx.Input("Input")->dims(); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) should not be null"); + auto x_dims = ctx.Input("X")->dims(); std::vector axis = ctx.Attr>("axis"); - size_t input_rank = input_dim.size(); + size_t x_rank = x_dims.size(); size_t axis_size = axis.size(); - PADDLE_ENFORCE_EQ(input_rank, axis_size, + PADDLE_ENFORCE_EQ(x_rank, axis_size, "the input tensor's rank(%d) " "should be equal to the axis's size(%d)", - input_rank, axis_size); + x_rank, axis_size); std::vector count(axis_size, 0); for (size_t i = 0; i < axis_size; i++) { @@ -48,11 +47,11 @@ class TransposeOp : public framework::OperatorWithKernel { "where the dims is the axis's size"); } - framework::DDim output_dim(input_dim); + framework::DDim out_dims(x_dims); for (size_t i = 0; i < axis_size; i++) { - output_dim[i] = input_dim[axis[i]]; + out_dims[i] = x_dims[axis[i]]; } - ctx.Output("Output")->Resize(output_dim); + ctx.Output("Out")->Resize(out_dims); } }; @@ -62,9 +61,9 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput( - "Input", + "X", "(Tensor)The input tensor, tensors with rank at most 6 are supported"); - AddOutput("Output", "(Tensor)The output tensor"); + AddOutput("Out", "(Tensor)The output tensor"); AddAttr>( "axis", "(vector)a list of values, and the size of the list should be " @@ -96,15 +95,14 @@ class TransposeOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"), - "Input(Input) should not be null"); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Output")), - "Input(Output@GRAD) should not be null"); - auto input_dim = ctx.Input("Input")->dims(); - auto *input_grad = - ctx.Output(framework::GradVarName("Input")); - - if (input_grad) input_grad->Resize(input_dim); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx.Input("X")->dims(); + auto *x_grad = + ctx.Output(framework::GradVarName("X")); + + if (x_grad) x_grad->Resize(x_dims); } }; diff --git a/paddle/operators/transpose_op.h b/paddle/operators/transpose_op.h index 731b6a77016ca8d033b78e0b708cfb0ec5feb5ee..ea299dce72ad340b0a65ee50582dc156b5ad7abb 100644 --- a/paddle/operators/transpose_op.h +++ b/paddle/operators/transpose_op.h @@ -41,30 +41,30 @@ template class TransposeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* input = context.Input("Input"); - auto* output = context.Output("Output"); - output->mutable_data(context.GetPlace()); + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + out->mutable_data(context.GetPlace()); std::vector axis = context.Attr>("axis"); int ndims = axis.size(); switch (ndims) { case 1: - EigenTranspose(context, *input, *output, axis); + EigenTranspose(context, *x, *out, axis); break; case 2: - EigenTranspose(context, *input, *output, axis); + EigenTranspose(context, *x, *out, axis); break; case 3: - EigenTranspose(context, *input, *output, axis); + EigenTranspose(context, *x, *out, axis); break; case 4: - EigenTranspose(context, *input, *output, axis); + EigenTranspose(context, *x, *out, axis); break; case 5: - EigenTranspose(context, *input, *output, axis); + EigenTranspose(context, *x, *out, axis); break; case 6: - EigenTranspose(context, *input, *output, axis); + EigenTranspose(context, *x, *out, axis); break; default: PADDLE_THROW("Tensors with rank at most 6 are supported"); @@ -76,12 +76,12 @@ template class TransposeGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* output_grad = - context.Input(framework::GradVarName("Output")); - auto* input_grad = - context.Output(framework::GradVarName("Input")); - if (input_grad) { - input_grad->mutable_data(context.GetPlace()); + auto* out_grad = + context.Input(framework::GradVarName("Out")); + auto* x_grad = + context.Output(framework::GradVarName("X")); + if (x_grad) { + x_grad->mutable_data(context.GetPlace()); std::vector axis = context.Attr>("axis"); std::vector reversed_axis(axis); @@ -94,27 +94,27 @@ class TransposeGradKernel : public framework::OpKernel { switch (ndims) { case 1: - EigenTranspose(context, *output_grad, *input_grad, + EigenTranspose(context, *out_grad, *x_grad, reversed_axis); break; case 2: - EigenTranspose(context, *output_grad, *input_grad, + EigenTranspose(context, *out_grad, *x_grad, reversed_axis); break; case 3: - EigenTranspose(context, *output_grad, *input_grad, + EigenTranspose(context, *out_grad, *x_grad, reversed_axis); break; case 4: - EigenTranspose(context, *output_grad, *input_grad, + EigenTranspose(context, *out_grad, *x_grad, reversed_axis); break; case 5: - EigenTranspose(context, *output_grad, *input_grad, + EigenTranspose(context, *out_grad, *x_grad, reversed_axis); break; case 6: - EigenTranspose(context, *output_grad, *input_grad, + EigenTranspose(context, *out_grad, *x_grad, reversed_axis); break; default: diff --git a/python/paddle/v2/framework/tests/test_transpose_op.py b/python/paddle/v2/framework/tests/test_transpose_op.py index 373a988f5f8c2a01d481b3149d9a00e7d536078b..9409cbaa00f792b60d5950556b869108aa732478 100644 --- a/python/paddle/v2/framework/tests/test_transpose_op.py +++ b/python/paddle/v2/framework/tests/test_transpose_op.py @@ -7,15 +7,15 @@ class TestTransposeOp(OpTest): def setUp(self): self.initTestCase() self.op_type = "transpose" - self.inputs = {'Input': np.random.random(self.shape).astype("float32")} + self.inputs = {'X': np.random.random(self.shape).astype("float32")} self.attrs = {'axis': list(self.axis)} - self.outputs = {'Output': self.inputs['Input'].transpose(self.axis)} + self.outputs = {'Out': self.inputs['X'].transpose(self.axis)} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['Input'], 'Output') + self.check_grad(['X'], 'Out') def initTestCase(self): self.shape = (3, 4)