diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 3580932356fd5f29d5e4d00a70e64c207c64e41e..832509641cc3d5178ff090e05437484d395bfe51 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -84,6 +84,9 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr>("shape", "(vector) " "Target shape of reshape operator."); + AddAttr("inplace", + "Change the source tensor's shape without copy memory.") + .SetDefault(true); AddComment(R"DOC( Reshape Operator. diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h index 1357bce4b7e7dcde616ec15df6617db5573df1d3..eacb0a0cf21a60ffbdef5787434859ac549388bc 100644 --- a/paddle/fluid/operators/reshape_op.h +++ b/paddle/fluid/operators/reshape_op.h @@ -26,10 +26,16 @@ class ReshapeKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const { auto* out = ctx.Output("Out"); auto* in = ctx.Input("X"); + bool inplace = ctx.Attr("inplace"); auto out_dims = out->dims(); - out->mutable_data(ctx.GetPlace()); - framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out); - out->Resize(out_dims); + if (!inplace) { + out->mutable_data(ctx.GetPlace()); + framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out); + out->Resize(out_dims); + } else { + out->ShareDataWith(*in); + out->Resize(out_dims); + } } }; @@ -40,10 +46,16 @@ class ReshapeGradKernel : public framework::OpKernel { auto* d_out = ctx.Input(framework::GradVarName("Out")); auto* d_x = ctx.Output(framework::GradVarName("X")); d_x->mutable_data(ctx.GetPlace()); + bool inplace = ctx.Attr("inplace"); auto in_dims = d_x->dims(); - framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); - d_x->Resize(in_dims); + if (!inplace) { + framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); + d_x->Resize(in_dims); + } else { + d_x->ShareDataWith(*d_out); + d_x->Resize(in_dims); + } } }; } // namespace operators diff --git a/python/paddle/fluid/tests/unittests/test_reshape_op.py b/python/paddle/fluid/tests/unittests/test_reshape_op.py index 6d1aa549d592997359342bab9233685f5d6108fb..11f35c74d41146118525a5efa6c211d528e255fe 100644 --- a/python/paddle/fluid/tests/unittests/test_reshape_op.py +++ b/python/paddle/fluid/tests/unittests/test_reshape_op.py @@ -45,5 +45,33 @@ class TestReshapeOpDimInfer(OpTest): self.check_grad(["X"], "Out") +class TestReshapeOpInplace(OpTest): + def setUp(self): + self.op_type = "reshape" + self.inputs = {'X': np.random.random((10, 20)).astype("float32")} + self.attrs = {'shape': [10 * 20], 'inplace': True} + self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + +class TestReshapeOpDimInferInplace(OpTest): + def setUp(self): + self.op_type = "reshape" + self.inputs = {'X': np.random.random((10, 20)).astype("float32")} + self.attrs = {'shape': [4, -1, 5], 'inplace': True} + self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Out") + + if __name__ == '__main__': unittest.main()