Commit 0cd9b8c0 authored by xzl

Modify the input/output names to X/Out

Parent a9a7ba3c
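
In short, every caller that refers to this operator's variables by name moves from "Input"/"Output" to "X"/"Out"; the "axis" attribute is unchanged. A minimal NumPy sketch of the renamed contract, mirroring the updated unit test at the end of this diff (the dict keys are the operator's variable names; shapes are arbitrary):

    import numpy as np

    # After this commit the transpose op's input is named "X" and its output "Out".
    axis = [1, 0]                                              # attr: a permutation of the axes
    feeds = {'X': np.random.random((3, 4)).astype("float32")}  # was {'Input': ...}
    fetches = {'Out': feeds['X'].transpose(axis)}              # was {'Output': ...}
    assert fetches['Out'].shape == (4, 3)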
@@ -25,19 +25,18 @@ class TransposeOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
-                            "Input(Input) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"),
-                            "Output(Output) should not be null");
-    auto input_dim = ctx.Input<Tensor>("Input")->dims();
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
+                            "Output(Out) should not be null");
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
     std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
-    size_t input_rank = input_dim.size();
+    size_t x_rank = x_dims.size();
     size_t axis_size = axis.size();
-    PADDLE_ENFORCE_EQ(input_rank, axis_size,
+    PADDLE_ENFORCE_EQ(x_rank, axis_size,
                       "the input tensor's rank(%d) "
                       "should be equal to the axis's size(%d)",
-                      input_rank, axis_size);
+                      x_rank, axis_size);
     std::vector<int> count(axis_size, 0);
     for (size_t i = 0; i < axis_size; i++) {
@@ -48,11 +47,11 @@ class TransposeOp : public framework::OperatorWithKernel {
                     "where the dims is the axis's size");
     }
-    framework::DDim output_dim(input_dim);
+    framework::DDim out_dims(x_dims);
     for (size_t i = 0; i < axis_size; i++) {
-      output_dim[i] = input_dim[axis[i]];
+      out_dims[i] = x_dims[axis[i]];
     }
-    ctx.Output<framework::LoDTensor>("Output")->Resize(output_dim);
+    ctx.Output<framework::LoDTensor>("Out")->Resize(out_dims);
   }
 };
@@ -62,9 +61,9 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
                   framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput(
-        "Input",
+        "X",
         "(Tensor)The input tensor, tensors with rank at most 6 are supported");
-    AddOutput("Output", "(Tensor)The output tensor");
+    AddOutput("Out", "(Tensor)The output tensor");
     AddAttr<std::vector<int>>(
         "axis",
         "(vector<int>)a list of values, and the size of the list should be "
@@ -96,15 +95,14 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
-                            "Input(Input) should not be null");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Output")),
-                            "Input(Output@GRAD) should not be null");
-    auto input_dim = ctx.Input<Tensor>("Input")->dims();
-    auto *input_grad =
-        ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));
-
-    if (input_grad) input_grad->Resize(input_dim);
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@GRAD) should not be null");
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto *x_grad =
+        ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    if (x_grad) x_grad->Resize(x_dims);
   }
 };
......
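
As a reference for the InferShape logic above: the output shape follows out_dims[i] = x_dims[axis[i]], and the count check enforces that axis is a permutation of 0..rank-1. A quick NumPy check of that rule (shapes chosen arbitrarily):

    import numpy as np

    x_dims = (3, 4, 5)
    axis = [2, 0, 1]                            # must be a permutation of 0..rank-1
    out_dims = tuple(x_dims[a] for a in axis)   # InferShape: out_dims[i] = x_dims[axis[i]]
    assert np.empty(x_dims).transpose(axis).shape == out_dims == (5, 3, 4)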
@@ -41,30 +41,30 @@ template <typename Place, typename T>
 class TransposeKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* input = context.Input<framework::Tensor>("Input");
-    auto* output = context.Output<framework::Tensor>("Output");
-    output->mutable_data<T>(context.GetPlace());
+    auto* x = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
     std::vector<int> axis = context.Attr<std::vector<int>>("axis");
     int ndims = axis.size();
     switch (ndims) {
       case 1:
-        EigenTranspose<Place, T, 1>(context, *input, *output, axis);
+        EigenTranspose<Place, T, 1>(context, *x, *out, axis);
         break;
       case 2:
-        EigenTranspose<Place, T, 2>(context, *input, *output, axis);
+        EigenTranspose<Place, T, 2>(context, *x, *out, axis);
         break;
       case 3:
-        EigenTranspose<Place, T, 3>(context, *input, *output, axis);
+        EigenTranspose<Place, T, 3>(context, *x, *out, axis);
         break;
       case 4:
-        EigenTranspose<Place, T, 4>(context, *input, *output, axis);
+        EigenTranspose<Place, T, 4>(context, *x, *out, axis);
         break;
       case 5:
-        EigenTranspose<Place, T, 5>(context, *input, *output, axis);
+        EigenTranspose<Place, T, 5>(context, *x, *out, axis);
         break;
       case 6:
-        EigenTranspose<Place, T, 6>(context, *input, *output, axis);
+        EigenTranspose<Place, T, 6>(context, *x, *out, axis);
         break;
       default:
         PADDLE_THROW("Tensors with rank at most 6 are supported");
@@ -76,12 +76,12 @@ template <typename Place, typename T>
 class TransposeGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* output_grad =
-        context.Input<framework::Tensor>(framework::GradVarName("Output"));
-    auto* input_grad =
-        context.Output<framework::Tensor>(framework::GradVarName("Input"));
-    if (input_grad) {
-      input_grad->mutable_data<T>(context.GetPlace());
+    auto* out_grad =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* x_grad =
+        context.Output<framework::Tensor>(framework::GradVarName("X"));
+    if (x_grad) {
+      x_grad->mutable_data<T>(context.GetPlace());
       std::vector<int> axis = context.Attr<std::vector<int>>("axis");
       std::vector<int> reversed_axis(axis);
@@ -94,27 +94,27 @@ class TransposeGradKernel : public framework::OpKernel {
       switch (ndims) {
         case 1:
-          EigenTranspose<Place, T, 1>(context, *output_grad, *input_grad,
-                                      reversed_axis);
+          EigenTranspose<Place, T, 1>(context, *out_grad, *x_grad,
+                                      reversed_axis);
           break;
         case 2:
-          EigenTranspose<Place, T, 2>(context, *output_grad, *input_grad,
-                                      reversed_axis);
+          EigenTranspose<Place, T, 2>(context, *out_grad, *x_grad,
+                                      reversed_axis);
           break;
         case 3:
-          EigenTranspose<Place, T, 3>(context, *output_grad, *input_grad,
-                                      reversed_axis);
+          EigenTranspose<Place, T, 3>(context, *out_grad, *x_grad,
+                                      reversed_axis);
           break;
         case 4:
-          EigenTranspose<Place, T, 4>(context, *output_grad, *input_grad,
-                                      reversed_axis);
+          EigenTranspose<Place, T, 4>(context, *out_grad, *x_grad,
+                                      reversed_axis);
           break;
         case 5:
-          EigenTranspose<Place, T, 5>(context, *output_grad, *input_grad,
-                                      reversed_axis);
+          EigenTranspose<Place, T, 5>(context, *out_grad, *x_grad,
+                                      reversed_axis);
           break;
         case 6:
-          EigenTranspose<Place, T, 6>(context, *output_grad, *input_grad,
-                                      reversed_axis);
+          EigenTranspose<Place, T, 6>(context, *out_grad, *x_grad,
+                                      reversed_axis);
           break;
         default:
......
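
Note that the loop filling reversed_axis is elided from the hunk above. Assuming the usual inverse-permutation construction (reversed_axis[axis[i]] = i), here is a short NumPy sketch of why transposing Out@GRAD by the inverse permutation restores the layout of X; the helper name is illustrative, not from the source:

    import numpy as np

    def inverse_permutation(axis):
        # Hypothetical helper: builds reversed_axis with reversed_axis[axis[i]] = i.
        reversed_axis = [0] * len(axis)
        for i, a in enumerate(axis):
            reversed_axis[a] = i
        return reversed_axis

    axis = [2, 0, 1]
    x = np.random.random((3, 4, 5)).astype("float32")
    out = x.transpose(axis)                             # forward: Out = transpose(X, axis)
    x_again = out.transpose(inverse_permutation(axis))  # grad path undoes the permutation
    assert np.array_equal(x, x_again)                   # original layout recovered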
@@ -7,15 +7,15 @@ class TestTransposeOp(OpTest):
     def setUp(self):
         self.initTestCase()
         self.op_type = "transpose"
-        self.inputs = {'Input': np.random.random(self.shape).astype("float32")}
+        self.inputs = {'X': np.random.random(self.shape).astype("float32")}
         self.attrs = {'axis': list(self.axis)}
-        self.outputs = {'Output': self.inputs['Input'].transpose(self.axis)}
+        self.outputs = {'Out': self.inputs['X'].transpose(self.axis)}

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
-        self.check_grad(['Input'], 'Output')
+        self.check_grad(['X'], 'Out')

     def initTestCase(self):
         self.shape = (3, 4)
......