未验证 提交 049383c6 编写于 作者: Y Yan Chunwei 提交者: GitHub

add inplace to reshape (#8747)

上级 42e65a20
...@@ -84,6 +84,9 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -84,6 +84,9 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int>>("shape", AddAttr<std::vector<int>>("shape",
"(vector<int>) " "(vector<int>) "
"Target shape of reshape operator."); "Target shape of reshape operator.");
AddAttr<bool>("inplace",
"Change the source tensor's shape without copy memory.")
.SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
Reshape Operator. Reshape Operator.
......
...@@ -26,10 +26,16 @@ class ReshapeKernel : public framework::OpKernel<T> { ...@@ -26,10 +26,16 @@ class ReshapeKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = ctx.Input<framework::Tensor>("X");
bool inplace = ctx.Attr<bool>("inplace");
auto out_dims = out->dims(); auto out_dims = out->dims();
if (!inplace) {
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out); framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
out->Resize(out_dims); out->Resize(out_dims);
} else {
out->ShareDataWith(*in);
out->Resize(out_dims);
}
} }
}; };
...@@ -40,10 +46,16 @@ class ReshapeGradKernel : public framework::OpKernel<T> { ...@@ -40,10 +46,16 @@ class ReshapeGradKernel : public framework::OpKernel<T> {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
bool inplace = ctx.Attr<bool>("inplace");
auto in_dims = d_x->dims(); auto in_dims = d_x->dims();
if (!inplace) {
framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x);
d_x->Resize(in_dims); d_x->Resize(in_dims);
} else {
d_x->ShareDataWith(*d_out);
d_x->Resize(in_dims);
}
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -45,5 +45,33 @@ class TestReshapeOpDimInfer(OpTest): ...@@ -45,5 +45,33 @@ class TestReshapeOpDimInfer(OpTest):
self.check_grad(["X"], "Out") self.check_grad(["X"], "Out")
class TestReshapeOpInplace(OpTest):
def setUp(self):
self.op_type = "reshape"
self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
self.attrs = {'shape': [10 * 20], 'inplace': True}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
class TestReshapeOpDimInferInplace(OpTest):
def setUp(self):
self.op_type = "reshape"
self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
self.attrs = {'shape': [4, -1, 5], 'inplace': True}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册