diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index c47df734056ca0daf072c85064eb50ee23c61476..2ad49437a97be5c69f81f0242b6ee23b90df955b 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -32,7 +32,6 @@ class ReshapeOp : public framework::OperatorWithKernel { "Output(Out) of ReshapeOp should not be null."); const std::vector &shape = ctx->Attrs().Get>("shape"); - PADDLE_ENFORCE_EQ(shape.empty(), ctx->HasInput("Shape"), "The shape information can only be set by Attr(shape) or " "by Input(Shape). Attr(shape) and Input(Shape) cannot be " @@ -41,27 +40,29 @@ class ReshapeOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); if (ctx->HasInput("Shape")) { + // The shape information is given by Input(Shape). auto shape_dims = ctx->GetInputDim("Shape"); PADDLE_ENFORCE(shape_dims.size() == 2UL && shape_dims[0] == 1UL, "The Input(Label) should be a 2-D tensor with the 1st " "dimensions fixed to 1 (a row vector)."); - // The actual output shape will be set at runtime, here temporially the + // The actual output shape will be set at runtime, here temporarily set // the shape of output the same as the shape of input. ctx->SetOutputDim("Out", x_dims); } else { + // The shape information is given by Attr(shape). std::vector output_shape; ValidateShape(shape, framework::product(x_dims), output_shape); auto out_dims = framework::make_ddim(output_shape); ctx->SetOutputDim("Out", out_dims); - } - if (shape[0] == x_dims[0]) { - // Only pass LoD when the first dimension of output and input are the - // same. - ctx->ShareLoD("X", /*->*/ "Out"); + if (shape[0] == x_dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. 
+ ctx->ShareLoD("X", /*->*/ "Out"); + } } } @@ -94,6 +95,14 @@ class ReshapeOp : public framework::OperatorWithKernel { [](int a) { return static_cast(a); }); if (neg_dims_idx.size()) output_shape[neg_dims_idx[0]] = inferred_dim; } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } }; class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { @@ -101,11 +110,13 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { ReshapeOpMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input tensor of reshape operator."); - AddInput("Shape", "a 1-D tensor that provides the shape information.") + AddInput( + "Shape", + "Tensor, a 1-D tensor that provides the shape information.") .AsDispensable(); AddOutput("Out", "The output tensor of reshape operator."); - AddAttr>("shape", - "(vector) Target shape of reshape operator.") + AddAttr>( + "shape", "(std::vector) Target shape of reshape operator.") .SetDefault(std::vector()); AddComment(R"DOC( Reshape Operator. 
@@ -139,6 +150,14 @@ class ReshapeGradOp : public framework::OperatorWithKernel { "Input(Out@GRAD) shouldn't be null."); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } }; } // namespace operators diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h index fc0885c1494ccbcc52f5e0ce89bb3fc01778d283..0c97dc639f07522a5ffd5e5cb27426d65bc9ab75 100644 --- a/paddle/fluid/operators/reshape_op.h +++ b/paddle/fluid/operators/reshape_op.h @@ -33,9 +33,6 @@ class ReshapeKernel : public framework::OpKernel { std::vector output_shape; ValidateShape(*shape, framework::product(in->dims()), output_shape); - for (auto d : output_shape) std::cout << d << " "; - std::cout << std::endl; - out_dims = framework::make_ddim(output_shape); } else { out_dims = out->dims(); @@ -85,11 +82,18 @@ class ReshapeGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const { auto* d_out = ctx.Input(framework::GradVarName("Out")); auto* d_x = ctx.Output(framework::GradVarName("X")); + d_x->mutable_data(ctx.GetPlace()); + bool inplace = ctx.Attr("inplace"); auto in_dims = d_x->dims(); - framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); - d_x->Resize(in_dims); + if (!inplace) { + framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); + d_x->Resize(in_dims); + } else { + d_x->ShareDataWith(*d_out); + d_x->Resize(in_dims); + } } }; } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py index ae1cca0c3ef00baef0c0b5eb29001ec33133c6b5..dc96aed8dbe55f2f40d9ffb21f569a2e00ac6425 100644 --- a/python/paddle/fluid/tests/unittests/test_reshape_op.py +++ 
b/python/paddle/fluid/tests/unittests/test_reshape_op.py @@ -33,7 +33,8 @@ from op_test import OpTest # # def test_check_grad(self): # self.check_grad(["X"], "Out") - +# +# # class TestReshapeOpDimInfer1(OpTest): # def setUp(self): # self.op_type = "reshape" @@ -56,7 +57,8 @@ class TestReshapeOp2(OpTest): self.op_type = "reshape" self.inputs = { "X": np.random.random(ori_shape).astype("float32"), - "Shape": np.array(new_shape) + "Shape": np.array( + new_shape, dtype="int64") } self.outputs = {"Out": self.inputs["X"].reshape(new_shape[0])} @@ -67,5 +69,32 @@ class TestReshapeOp2(OpTest): self.check_grad(["X"], "Out") +# class TestReshapeOpInplace(OpTest): +# def setUp(self): +# self.op_type = "reshape" +# self.inputs = {'X': np.random.random((10, 20)).astype("float32")} +# self.attrs = {'shape': [10 * 20], 'inplace': True} +# self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} +# +# def test_check_output(self): +# self.check_output() +# +# def test_check_grad(self): +# self.check_grad(["X"], "Out") +# +# +# class TestReshapeOpDimInferInplace(OpTest): +# def setUp(self): +# self.op_type = "reshape" +# self.inputs = {'X': np.random.random((10, 20)).astype("float32")} +# self.attrs = {'shape': [4, -1, 5], 'inplace': True} +# self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} +# +# def test_check_output(self): +# self.check_output() +# +# def test_check_grad(self): +# self.check_grad(["X"], "Out") + if __name__ == "__main__": unittest.main()