未验证 提交 7ca77a90 编写于 作者: Z Zeng Jinle 提交者: GitHub

add Tensor::IsSharedBufferWith method, test=develop (#23175)

上级 27870412
......@@ -104,7 +104,7 @@ void ShareTensorBufferFunctor::operator()(Scope *exec_scope) {
// If in_var is inplaced in the previous batch and we want to fetch
// in_var in the current batch, we have to reset memory of out_var
// to avoid wrong calculation result.
if (in_tensor.Holder() == out_tensor->Holder()) {
if (out_tensor->IsSharedBufferWith(in_tensor)) {
VLOG(1) << "Clear " << out_var_names_[i]
<< " because you may want to fetch an inplaced variable "
<< in_var_info->Name()
......
......@@ -160,6 +160,10 @@ class Tensor {
offset_ = tensor.offset_;
}
// Returns true iff this tensor currently holds an allocation and that
// allocation is the very same object as the one held by `src`, i.e. the
// two tensors share the underlying memory buffer (identity comparison on
// the shared_ptr, not a content comparison).
bool IsSharedBufferWith(const Tensor& src) const {
  if (!holder_) {
    return false;  // an unallocated tensor shares nothing
  }
  return src.Holder() == holder_;
}
const std::shared_ptr<memory::Allocation>& Holder() const { return holder_; }
size_t offset() const { return offset_; }
......
......@@ -1100,7 +1100,7 @@ void CommonElementwiseBroadcastBackward(
// for inplace strategy. memset will make dx and dout clear and get wrong
// result.
if (dx && dout.Holder() == dx->Holder()) {
if (dx && dx->IsSharedBufferWith(dout)) {
dx->clear();
dx->mutable_data<T>(x_dims, ctx.GetPlace());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册