From 7ca77a90ac3c835cb3a5fe3d97a9ac25eebd250e Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 25 Mar 2020 01:31:12 -0500 Subject: [PATCH] add Tensor::IsSharedBufferWith method, test=develop (#23175) --- paddle/fluid/framework/details/share_tensor_buffer_functor.cc | 2 +- paddle/fluid/framework/tensor.h | 4 ++++ paddle/fluid/operators/elementwise/elementwise_op_function.h | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/details/share_tensor_buffer_functor.cc b/paddle/fluid/framework/details/share_tensor_buffer_functor.cc index fb43bfbf34..6fdec553f3 100644 --- a/paddle/fluid/framework/details/share_tensor_buffer_functor.cc +++ b/paddle/fluid/framework/details/share_tensor_buffer_functor.cc @@ -104,7 +104,7 @@ void ShareTensorBufferFunctor::operator()(Scope *exec_scope) { // If in_var is inplaced in the previous batch and we want to fetch // in_var in the current batch, we have to reset memory of out_var // to avoid wrong calculation result. - if (in_tensor.Holder() == out_tensor->Holder()) { + if (out_tensor->IsSharedBufferWith(in_tensor)) { VLOG(1) << "Clear " << out_var_names_[i] << " because you may want to fetch an inplaced variable " << in_var_info->Name() diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h index f4d3457003..0b95e585d7 100644 --- a/paddle/fluid/framework/tensor.h +++ b/paddle/fluid/framework/tensor.h @@ -160,6 +160,10 @@ class Tensor { offset_ = tensor.offset_; } + bool IsSharedBufferWith(const Tensor& src) const { + return holder_ && holder_ == src.Holder(); + } + const std::shared_ptr& Holder() const { return holder_; } size_t offset() const { return offset_; } diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 3710e008ca..23afa75279 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -1100,7 +1100,7 @@ void CommonElementwiseBroadcastBackward( // for inplace strategy. memset will make dx and dout clear and get wrong // result. - if (dx && dout.Holder() == dx->Holder()) { + if (dx && dx->IsSharedBufferWith(dout)) { dx->clear(); dx->mutable_data(x_dims, ctx.GetPlace()); } -- GitLab