Unverified commit eab34b2d, authored by Zeng Jinle, committed by GitHub

fix_dygraph_mem_leak, test=develop (#17396)

Parent 5d1ac41b
@@ -118,7 +118,7 @@ void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
     VLOG(2) << "added grad: " << var_pair.second->Name()
             << " trace id is: " << var_pair.first;
     AddTo(grad_to_add, origin_grad, current.first);
-    delete grad_to_add;
+    delete var_pair.second;
     var_pair.second = nullptr;
   }
 }
@@ -230,16 +230,14 @@ std::unique_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
       new_var->var_->GetMutable<framework::LoDTensor>();
   tensor->set_lod(var_->Get<framework::LoDTensor>().lod());
 
+  const auto& src_tensor = var_->Get<framework::LoDTensor>();
+  framework::TensorCopy(src_tensor, dst_place, tensor);
   if (blocking) {
-    platform::DeviceContext* dev_ctx =
-        platform::DeviceContextPool::Instance().Get(dst_place);
-
-    framework::TensorCopySync(var_->Get<framework::LoDTensor>(), dst_place,
-                              tensor);
-
-    dev_ctx->Wait();
-  } else {
-    framework::TensorCopy(var_->Get<framework::LoDTensor>(), dst_place, tensor);
+    platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
+    auto src_place = src_tensor.place();
+    if (!(src_place == dst_place)) {
+      platform::DeviceContextPool::Instance().Get(src_place)->Wait();
+    }
   }
 
   if (platform::is_gpu_place(dst_place)) {
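Note on the hunk above: the copy is now always issued asynchronously with framework::TensorCopy, and when `blocking` is true the code waits on the destination device context and, if the source place differs, on the source context as well. The snippet below is only an analogy in plain Python futures, not Paddle code; the `contexts` map, the `copy_tensor` helper, and the place strings are hypothetical stand-ins for device contexts.

```python
from concurrent.futures import ThreadPoolExecutor

# One single-worker executor per "place" stands in for a device context.
contexts = {"gpu:0": ThreadPoolExecutor(max_workers=1),
            "cpu": ThreadPoolExecutor(max_workers=1)}

def copy_tensor(src_data, src_place, dst_place, blocking):
    # Issue the copy asynchronously on the destination context, the role
    # framework::TensorCopy plays in the hunk above.
    done = contexts[dst_place].submit(list, src_data)
    if blocking:
        # Wait for the destination context to drain ...
        done.result()
        # ... and for the source context too when the places differ,
        # mirroring the extra Wait() this commit adds.
        if src_place != dst_place:
            contexts[src_place].submit(lambda: None).result()
    return done

print(copy_tensor([1.0, 2.0, 3.0], "gpu:0", "cpu", blocking=True).result())
```

The second wait corresponds to the `!(src_place == dst_place)` branch added in the C++ change.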
@@ -402,7 +400,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
               << origin_outputs[i]->name_ << " Grad to be added is "
               << outputs[i]->name_;
       AddTo(grad, orig_grad, place_);
-      delete grad;
+      delete outputs[i];
     }
   }
 }
......
@@ -77,8 +77,8 @@ void TemporaryAllocator::Free(alloc::Allocation *allocation) {
             << "wait_delete_mem: " << wait_delete_mem;
   }
 
-  if (FLAGS_limit_of_tmp_allocation > 0 &&
-      wait_delete_mem > static_cast<size_t>(FLAGS_limit_of_tmp_allocation)) {
+  if (FLAGS_limit_of_tmp_allocation >= 0 &&
+      wait_delete_mem >= static_cast<size_t>(FLAGS_limit_of_tmp_allocation)) {
    PADDLE_ENFORCE(callback_ != nullptr, "The callback is non-initialized.");
    Release(callback_);
  }
......
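The allocator hunk above relaxes both comparisons from `>` to `>=`. A minimal sketch of the resulting predicate, using a hypothetical `should_release` helper rather than any Paddle API, makes the boundary change visible:

```python
def should_release(limit_of_tmp_allocation, wait_delete_mem):
    # Mirrors the new condition in TemporaryAllocator::Free.
    return (limit_of_tmp_allocation >= 0 and
            wait_delete_mem >= limit_of_tmp_allocation)

print(should_release(0, 0))        # True: a limit of 0 now releases on every Free
print(should_release(1024, 1024))  # True: the boundary itself now triggers a release
print(should_release(-1, 4096))    # False: a negative flag still disables the cap
```

So a `FLAGS_limit_of_tmp_allocation` of 0 now triggers a release on every call that reaches this check, while a negative value still leaves the cap disabled.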
@@ -49,6 +49,10 @@ class Tracer(core.Tracer):
         return list((item for name, item in six.iteritems(self._vars)
                      if isinstance(item, framework.Parameter)))
 
+    def _clear_ops(self):
+        self._ops = defaultdict()
+        self._trace_id = 0
+
     def trace_op(self, op, inputs, outputs, stop_gradient=False):
         # TODO(minqiyang): remove this line after we take apart all
         # backward grads and forward variables
......
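`_clear_ops` rebinds the tracer's op map to an empty `defaultdict()` and resets the trace id, dropping the tracer's references to everything recorded during the forward pass. The generic sketch below (no Paddle imports; `RecordedOp`, `ops`, and `probe` are made-up names) shows why rebinding the map is enough for Python to reclaim the recorded objects:

```python
import gc
import weakref
from collections import defaultdict

class RecordedOp(object):
    """Stand-in for an op object a dygraph tracer would record."""
    pass

ops = defaultdict()              # plays the role of Tracer._ops
trace_id = 0
ops[trace_id] = RecordedOp()
probe = weakref.ref(ops[trace_id])

# What _clear_ops does: rebind the map and reset the trace id, which drops
# the last references to every recorded op.
ops = defaultdict()
trace_id = 0
gc.collect()

print(probe())   # None -> the recorded op has been reclaimed
```

The next hunk wires this into `Variable.backward`, so the traced graph of one iteration no longer survives into the next.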
@@ -531,15 +531,12 @@ class Variable(object):
     def backward(self, backward_strategy=None):
         from .dygraph import BackwardStrategy
-        if isinstance(backward_strategy, BackwardStrategy):
-            self._ivar._run_backward(backward_strategy)
-        elif backward_strategy is not None:
-            raise TypeError(
-                "only BackwardStrategy type should be passed in backward")
-        else:
+        if backward_strategy is None:
             backward_strategy = BackwardStrategy()
             backward_strategy.sort_sum_gradient = False
-            self._ivar._run_backward(backward_strategy)
+
+        self._ivar._run_backward(backward_strategy)
+        _dygraph_tracer()._clear_ops()
 
     def gradient(self):
         new_ivar = self._ivar._grad_ivar()._copy_to(core.CPUPlace(), True)
......
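For callers, `backward()` keeps its signature: with no argument it now builds a default `BackwardStrategy`, runs the backward pass, and then clears the tracer's recorded ops. A minimal usage sketch, assuming the paddle.fluid dygraph API of this era (`fluid.dygraph.guard`, `to_variable`, `fluid.layers` ops, and the `gradient()` method shown above) rather than a verified example from the repository:

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    for step in range(3):
        x = fluid.dygraph.to_variable(np.ones([2, 2], dtype="float32"))
        x.stop_gradient = False  # keep the gradient of this leaf input
        loss = fluid.layers.reduce_sum(fluid.layers.elementwise_mul(x, x))

        # No argument: a default BackwardStrategy is built internally, and the
        # tracer's recorded ops are cleared afterwards, so the traced graph
        # does not accumulate across iterations.
        loss.backward()
        print(x.gradient())

        # Explicit strategy, e.g. to enable sorted gradient summation:
        #   strategy = fluid.dygraph.BackwardStrategy()
        #   strategy.sort_sum_gradient = True
        #   loss.backward(strategy)
```

Passing a `fluid.dygraph.BackwardStrategy` explicitly still works, as sketched in the trailing comment.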