From 899c7d6b353c04565ebaa46d85de57348631f2e1 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Thu, 7 Sep 2017 04:16:32 -0700 Subject: [PATCH] pass unit test --- paddle/operators/reshape_op.cc | 3 ++- paddle/operators/reshape_op.h | 7 +++---- .../paddle/v2/framework/tests/test_reshape_op.py | 15 ++++++--------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index 1b073a79bca..d75ec766325 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -38,6 +38,7 @@ class ReshapeOp : public framework::OperatorWithKernel { size_t in_size = framework::product(in->dims()); PADDLE_ENFORCE_EQ(shape_size, in_size, "The size of Input(X) mismatches with Attr(shape)."); + ctx.Output("Out")->Resize(in->dims()); } }; @@ -51,7 +52,7 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("shape", "Target shape of reshape operator."); AddComment(R"DOC(Reshape operator -The input tensor will be reshaped with Attr(shape). +Reshape Input(X) into the shape specified by Attr(shape). )DOC"); } }; diff --git a/paddle/operators/reshape_op.h b/paddle/operators/reshape_op.h index 22ede88b12c..61d502c836c 100644 --- a/paddle/operators/reshape_op.h +++ b/paddle/operators/reshape_op.h @@ -23,13 +23,13 @@ namespace operators { using Tensor = framework::Tensor; -template +template class ReshapeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { auto* out = ctx.Output("Out"); auto* in = ctx.Input("X"); - out->mutable_data(in->place()); + out->mutable_data(ctx.GetPlace()); auto shape = ctx.Attr>("shape"); std::vector tmp; @@ -42,7 +42,7 @@ class ReshapeKernel : public framework::OpKernel { } }; -template +template class ReshapeGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { @@ -51,7 +51,6 @@ class ReshapeGradKernel : public framework::OpKernel { d_x->mutable_data(ctx.GetPlace()); auto in_dims = d_x->dims(); - d_x->CopyFrom(*d_out, ctx.GetPlace()); d_x->Resize(in_dims); } diff --git a/python/paddle/v2/framework/tests/test_reshape_op.py b/python/paddle/v2/framework/tests/test_reshape_op.py index c101b0df9ad..47970194353 100644 --- a/python/paddle/v2/framework/tests/test_reshape_op.py +++ b/python/paddle/v2/framework/tests/test_reshape_op.py @@ -1,6 +1,6 @@ import unittest import numpy as np -from gradient_checker import GradientChecker, create_op +from gradient_checker import GradientChecker, Operator from op_test_util import OpTestMeta @@ -9,19 +9,16 @@ class TestReshapeOp(unittest.TestCase): def setUp(self): self.type = "reshape" - self.inputs = {'X': np.random.random((2, 4)).astype("float32"), } - print self.inputs - self.attrs = {'shape': [4, 2]} + self.inputs = {'X': np.random.random((37, 51)).astype("float32"), } + self.attrs = {'shape': [51, 37]} self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} - print self.outputs class ReshapeGradOpTest(GradientChecker): def test_normal(self): - op = create_op("reshape") - inputs = {"X": np.random.random((2, 4)).astype("float32")} - attrs = {'shape': [4, 2]} - self.check_grad(op, inputs, attrs, set("X"), "Out") + op = Operator("reshape", X='X', Out='Out', shape=[5, 40]) + inputs = {"X": np.random.random((10, 20)).astype("float32")} + self.check_grad(op, inputs, set("X"), "Out") if __name__ == '__main__': -- GitLab