未验证 提交 eab34b2d 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix_dygraph_mem_leak, test=develop (#17396)

上级 5d1ac41b
......@@ -118,7 +118,7 @@ void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
VLOG(2) << "added grad: " << var_pair.second->Name()
<< " trace id is: " << var_pair.first;
AddTo(grad_to_add, origin_grad, current.first);
delete grad_to_add;
delete var_pair.second;
var_pair.second = nullptr;
}
}
......@@ -230,16 +230,14 @@ std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
new_var->var_->GetMutable<framework::LoDTensor>();
tensor->set_lod(var_->Get<framework::LoDTensor>().lod());
const auto& src_tensor = var_->Get<framework::LoDTensor>();
framework::TensorCopy(src_tensor, dst_place, tensor);
if (blocking) {
platform::DeviceContext* dev_ctx =
platform::DeviceContextPool::Instance().Get(dst_place);
framework::TensorCopySync(var_->Get<framework::LoDTensor>(), dst_place,
tensor);
dev_ctx->Wait();
} else {
framework::TensorCopy(var_->Get<framework::LoDTensor>(), dst_place, tensor);
platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
auto src_place = src_tensor.place();
if (!(src_place == dst_place)) {
platform::DeviceContextPool::Instance().Get(src_place)->Wait();
}
}
if (platform::is_gpu_place(dst_place)) {
......@@ -402,7 +400,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
<< origin_outputs[i]->name_ << " Grad to be added is "
<< outputs[i]->name_;
AddTo(grad, orig_grad, place_);
delete grad;
delete outputs[i];
}
}
}
......
......@@ -77,8 +77,8 @@ void TemporaryAllocator::Free(alloc::Allocation *allocation) {
<< "wait_delete_mem: " << wait_delete_mem;
}
if (FLAGS_limit_of_tmp_allocation > 0 &&
wait_delete_mem > static_cast<size_t>(FLAGS_limit_of_tmp_allocation)) {
if (FLAGS_limit_of_tmp_allocation >= 0 &&
wait_delete_mem >= static_cast<size_t>(FLAGS_limit_of_tmp_allocation)) {
PADDLE_ENFORCE(callback_ != nullptr, "The callback is non-initialized.");
Release(callback_);
}
......
......@@ -49,6 +49,10 @@ class Tracer(core.Tracer):
return list((item for name, item in six.iteritems(self._vars)
if isinstance(item, framework.Parameter)))
def _clear_ops(self):
self._ops = defaultdict()
self._trace_id = 0
def trace_op(self, op, inputs, outputs, stop_gradient=False):
# TODO(minqiyang): remove this line after we take apart all
# backward grads and forward variables
......
......@@ -531,15 +531,12 @@ class Variable(object):
def backward(self, backward_strategy=None):
from .dygraph import BackwardStrategy
if isinstance(backward_strategy, BackwardStrategy):
self._ivar._run_backward(backward_strategy)
elif backward_strategy is not None:
raise TypeError(
"only BackwardStrategy type should be passed in backward")
else:
if backward_strategy is None:
backward_strategy = BackwardStrategy()
backward_strategy.sort_sum_gradient = False
self._ivar._run_backward(backward_strategy)
_dygraph_tracer()._clear_ops()
def gradient(self):
new_ivar = self._ivar._grad_ivar()._copy_to(core.CPUPlace(), True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册