From dd64349a9213b419c6a50c81e06e2d6a8fa9ebd5 Mon Sep 17 00:00:00 2001
From: Yibing Liu
Date: Mon, 11 Sep 2017 00:06:06 -0700
Subject: [PATCH] refine reshape operator

---
 paddle/operators/reshape_op.cc                     | 15 +++++++++------
 paddle/operators/reshape_op.h                      | 10 ++++------
 .../paddle/v2/framework/tests/test_reshape_op.py   | 16 ++++++++++++++--
 3 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc
index d75ec76632..37cbecbf25 100644
--- a/paddle/operators/reshape_op.cc
+++ b/paddle/operators/reshape_op.cc
@@ -29,14 +29,17 @@ class ReshapeOp : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext &ctx) const override {
     auto *in = ctx.Input<framework::Tensor>("X");
     auto shape = ctx.Attr<std::vector<int>>("shape");
-    PADDLE_ENFORCE_EQ((unsigned)shape.size(), in->dims().size(),
-                      "The dimension of Input(X) mismatches with Attr(shape).");
-    size_t shape_size = 1;
+    int64_t capacity = -1;
     for (auto dim : shape) {
-      shape_size *= dim;
+      PADDLE_ENFORCE(dim > 0, "Each dimension of shape must be positive.");
+      if (capacity < 0) {
+        capacity = dim;
+      } else {
+        capacity *= dim;
+      }
     }
-    size_t in_size = framework::product(in->dims());
-    PADDLE_ENFORCE_EQ(shape_size, in_size,
+    int64_t in_size = framework::product(in->dims());
+    PADDLE_ENFORCE_EQ(capacity, in_size,
                       "The size of Input(X) mismatches with Attr(shape).");
     ctx.Output<framework::Tensor>("Out")->Resize(in->dims());
   }
diff --git a/paddle/operators/reshape_op.h b/paddle/operators/reshape_op.h
index 61d502c836..0e920329d9 100644
--- a/paddle/operators/reshape_op.h
+++ b/paddle/operators/reshape_op.h
@@ -21,14 +21,12 @@
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
-
 template <typename Place, typename T>
 class ReshapeKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
-    auto* out = ctx.Output<Tensor>("Out");
-    auto* in = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+    auto* in = ctx.Input<framework::Tensor>("X");
     out->mutable_data<T>(ctx.GetPlace());
 
     auto shape = ctx.Attr<std::vector<int>>("shape");
@@ -46,8 +44,8 @@ template <typename Place, typename T>
 class ReshapeGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
-    auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
     d_x->mutable_data<T>(ctx.GetPlace());
 
     auto in_dims = d_x->dims();
diff --git a/python/paddle/v2/framework/tests/test_reshape_op.py b/python/paddle/v2/framework/tests/test_reshape_op.py
index 4797019435..df7d913ba4 100644
--- a/python/paddle/v2/framework/tests/test_reshape_op.py
+++ b/python/paddle/v2/framework/tests/test_reshape_op.py
@@ -10,15 +10,27 @@ class TestReshapeOp(unittest.TestCase):
     def setUp(self):
         self.type = "reshape"
         self.inputs = {'X': np.random.random((37, 51)).astype("float32"), }
-        self.attrs = {'shape': [51, 37]}
+        self.attrs = {'shape': [51 * 37]}
         self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
 
 
-class ReshapeGradOpTest(GradientChecker):
+class TestReshapeGradOp(GradientChecker):
+    """
     def test_normal(self):
         op = Operator("reshape", X='X', Out='Out', shape=[5, 40])
         inputs = {"X": np.random.random((10, 20)).astype("float32")}
         self.check_grad(op, inputs, set("X"), "Out")
+    """
+
+    def setUp(self):
+        self.op = Operator("reshape", X='X', Out='Out', shape=[5, 40])
+        self.inputs = {"X": np.random.random((10, 20)).astype("float32")}
+
+    def test_normal(self):
+        self.check_grad(self.op, self.inputs, ["X"], "Out")
+
+    def test_dev_compare(self):
+        self.compare_grad(self.op, self.inputs)
 
 
 if __name__ == '__main__':
--
GitLab
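
Note on the new InferShape check: it requires every entry of Attr(shape) to be positive and their product ("capacity") to equal the element count of Input(X). A minimal NumPy sketch of that same validation follows; it is not part of the patch, and the helper name check_and_reshape is hypothetical.

import numpy as np

def check_and_reshape(x, shape):
    # Every target dimension must be positive, as in PADDLE_ENFORCE(dim > 0, ...).
    assert all(d > 0 for d in shape), "Each dimension of shape must be positive."
    capacity = int(np.prod(shape))
    # The total element count must match, as in PADDLE_ENFORCE_EQ(capacity, in_size, ...).
    assert capacity == x.size, "The size of Input(X) mismatches with Attr(shape)."
    return x.reshape(shape)

# The flattening case exercised by TestReshapeOp: (37, 51) -> [51 * 37].
out = check_and_reshape(np.random.random((37, 51)).astype("float32"), [51 * 37])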