diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index 1b073a79bcaf48663a80c7619af4a334d32cbfc3..d75ec766325ce8ae30918b7efd7171c17b3b8be2 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -38,6 +38,7 @@ class ReshapeOp : public framework::OperatorWithKernel { size_t in_size = framework::product(in->dims()); PADDLE_ENFORCE_EQ(shape_size, in_size, "The size of Input(X) mismatches with Attr(shape)."); + ctx.Output("Out")->Resize(in->dims()); } }; @@ -51,7 +52,7 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("shape", "Target shape of reshape operator."); AddComment(R"DOC(Reshape operator -The input tensor will be reshaped with Attr(shape). +Reshape Input(X) into the shape specified by Attr(shape). )DOC"); } }; diff --git a/paddle/operators/reshape_op.h b/paddle/operators/reshape_op.h index 22ede88b12c7704e6637b87fd9cf792851a2302f..61d502c836cb06fc741cca2e1702c0a39fd400df 100644 --- a/paddle/operators/reshape_op.h +++ b/paddle/operators/reshape_op.h @@ -23,13 +23,13 @@ namespace operators { using Tensor = framework::Tensor; -template +template class ReshapeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto* out = ctx.Output("Out"); auto* in = ctx.Input("X"); - out->mutable_data(in->place()); + out->mutable_data(ctx.GetPlace()); auto shape = ctx.Attr>("shape"); std::vector tmp; @@ -42,7 +42,7 @@ class ReshapeKernel : public framework::OpKernel { } }; -template +template class ReshapeGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { @@ -51,7 +51,6 @@ class ReshapeGradKernel : public framework::OpKernel { d_x->mutable_data(ctx.GetPlace()); auto in_dims = d_x->dims(); - d_x->CopyFrom(*d_out, ctx.GetPlace()); d_x->Resize(in_dims); } diff --git a/python/paddle/v2/framework/tests/test_reshape_op.py b/python/paddle/v2/framework/tests/test_reshape_op.py index c101b0df9ad4ef1ab919d385a08f877736faab76..47970194353f4f64613f97c1deaed1f52c47af3c 100644 --- a/python/paddle/v2/framework/tests/test_reshape_op.py +++ b/python/paddle/v2/framework/tests/test_reshape_op.py @@ -1,6 +1,6 @@ import unittest import numpy as np -from gradient_checker import GradientChecker, create_op +from gradient_checker import GradientChecker, Operator from op_test_util import OpTestMeta @@ -9,19 +9,16 @@ class TestReshapeOp(unittest.TestCase): def setUp(self): self.type = "reshape" - self.inputs = {'X': np.random.random((2, 4)).astype("float32"), } - print self.inputs - self.attrs = {'shape': [4, 2]} + self.inputs = {'X': np.random.random((37, 51)).astype("float32"), } + self.attrs = {'shape': [51, 37]} self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} - print self.outputs class ReshapeGradOpTest(GradientChecker): def test_normal(self): - op = create_op("reshape") - inputs = {"X": np.random.random((2, 4)).astype("float32")} - attrs = {'shape': [4, 2]} - self.check_grad(op, inputs, attrs, set("X"), "Out") + op = Operator("reshape", X='X', Out='Out', shape=[5, 40]) + inputs = {"X": np.random.random((10, 20)).astype("float32")} + self.check_grad(op, inputs, set("X"), "Out") if __name__ == '__main__':