Commit cb85ee98 authored by minqiyang

Remove var op deps in imperative mode

test=develop
Parent 50c2eb0e
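
The changes below share one pattern: when running in imperative mode, an output Variable no longer keeps a back-reference to the Operator that produced it, while the op continues to reference its outputs. A minimal sketch of that shape, with made-up `Var`/`Op` classes (only the gating idea mirrors the diff, nothing here is the framework's real API):

```python
# Minimal sketch, not the framework's classes: keep the Variable -> Operator
# back-reference only outside imperative mode, so a traced op's lifetime is
# controlled by the tracer alone (see the release_op change further down).
class Var(object):
    def __init__(self, name):
        self.name = name
        self.op = None                 # producer back-reference (static mode only)

class Op(object):
    def __init__(self, outputs, imperative_mode):
        self.outputs = outputs         # op -> outputs is always kept
        for v in outputs:
            if not imperative_mode:
                v.op = self            # outputs -> op only in static-graph mode
```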
@@ -159,6 +159,7 @@ void BlockDesc::RemoveOpInternal(const OpDesc *op_desc) {
   for (auto it = ops_.begin(); it != ops_.end(); ++it) {
     if (it->get() == op_desc) {
       ops_.erase(it);
+      break;
     }
   }
 }
......
@@ -158,8 +158,9 @@ class Autograd {
       for (auto it : candidate->pre_ops_) {
         for (OpBase* pre_op : it.second) {
           if (!pre_op) continue;
-          VLOG(5) << "op dep " << candidate->op_desc_->Type() << " <---- "
-                  << it.first << " <---- " << pre_op->op_desc_->Type();
+          VLOG(5) << "op dep " << candidate->op_desc_->Type() << " "
+                  << candidate->trace_id_ << " <---- " << it.first << " <---- "
+                  << pre_op->op_desc_->Type() << " " << pre_op->trace_id_;
           if (visited.find(pre_op) == visited.end()) {
             visited.insert(pre_op);
             queue.push_back(pre_op);
......
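
The only change in the hunk above is to the VLOG(5) dependency trace, which now prints each op's `trace_id_` next to its type. A made-up example of what a line in the new format would look like (op types, slot name, and ids are invented for illustration):

```
op dep mul_grad 7 <---- Out@GRAD <---- elementwise_add_grad 6
```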
@@ -714,7 +714,9 @@ class Operator(object):
                 out_arg_names = []
                 for arg in out_args:
                     out_arg_names.append(cpt.to_text(arg.name))
-                    arg.op = self
+                    # TODO(minqiyang): could we remove variable's op in static mode?
+                    if not _in_imperative_mode():
+                        arg.op = self
                 self.desc.set_output(out_proto.name, out_arg_names)
         if op_attrs is not None:
......
@@ -24,6 +24,10 @@ __all__ = ['Tracer']
 def release_op(op):
+    import gc
+    assert len(
+        gc.get_referrers(framework._imperative_tracer()._ops[
+            op._trace_id])) == 1
     del framework._imperative_tracer()._ops[op._trace_id]
@@ -41,7 +45,6 @@ class Tracer(core.Tracer):
     def trace_op(self, op, stop_gradient=False):
         # record op's trace id
         op.iop._trace_id = self._trace_id
-        self._trace_id += 1
         # trace op and save it
         backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.block.desc,
@@ -49,6 +52,7 @@ class Tracer(core.Tracer):
                                    stop_gradient)
         if not stop_gradient:
+            self._trace_id += 1
             self._ops[op.iop._trace_id] = op
             # register backward hooks and variables if needed
......
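
Taken together, the tracer changes above mean a trace id is only consumed, and the op only retained in `self._ops`, when gradients are being tracked (`stop_gradient=False`); `release_op` then asserts, via `gc.get_referrers`, that the tracer's `_ops` dict is the op's only remaining referrer before deleting the entry, which only holds once output Variables stop keeping `var.op` back-references. A rough sketch of that lifecycle with a hypothetical stand-in class (not the real `Tracer` API):

```python
# Hypothetical stand-in, for illustration only.
class MiniTracer(object):
    def __init__(self):
        self._trace_id = 0
        self._ops = {}

    def trace_op(self, op, stop_gradient=False):
        op._trace_id = self._trace_id        # stamp the current id
        # ... the actual trace call would run here ...
        if not stop_gradient:
            self._trace_id += 1              # consume an id
            self._ops[op._trace_id] = op     # retain the op under its id

    def release_op(self, trace_id):
        # With no var.op back-references left, this entry should be the op's
        # only referrer, so deleting it lets Python free the op immediately.
        del self._ops[trace_id]
```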
@@ -19,6 +19,7 @@ import numpy as np
 from .wrapped_decorator import signature_safe_contextmanager
 from .core import VarDesc
 from . import unique_name
+from .imperative import base
 __all__ = [
     'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear',
@@ -165,7 +166,8 @@ class ConstantInitializer(Initializer):
                 'force_cpu': self._force_cpu or force_init_on_cpu()
             },
             stop_gradient=True)
-        var.op = op
+        if not base.enabled():
+            var.op = op
         return op
@@ -244,7 +246,8 @@ class UniformInitializer(Initializer):
                 attrs={"in_dtype": out_var.dtype,
                        "out_dtype": var.dtype})
-        var.op = op
+        if not base.enabled():
+            var.op = op
         return op
@@ -322,7 +325,8 @@ class NormalInitializer(Initializer):
                 outputs={"Out": var},
                 attrs={"in_dtype": out_var.dtype,
                        "out_dtype": var.dtype})
-        var.op = op
+        if not base.enabled():
+            var.op = op
         return op
@@ -400,7 +404,8 @@ class TruncatedNormalInitializer(Initializer):
                 outputs={"Out": var},
                 attrs={"in_dtype": out_var.dtype,
                        "out_dtype": var.dtype})
-        var.op = op
+        if not base.enabled():
+            var.op = op
         return op
@@ -505,7 +510,8 @@ class XavierInitializer(Initializer):
                     "seed": self._seed
                 },
                 stop_gradient=True)
-        var.op = op
+        if not base.enabled():
+            var.op = op
         return op
@@ -605,7 +611,8 @@ class MSRAInitializer(Initializer):
                     "seed": self._seed
                 },
                 stop_gradient=True)
-        var.op = op
+        if not base.enabled():
+            var.op = op
         return op
@@ -703,7 +710,8 @@ class BilinearInitializer(Initializer):
                 'shape': list(shape),
                 value_name: values
             })
-        var.op = op
+        if not base.enabled():
+            var.op = op
         return op
@@ -761,7 +769,8 @@ class NumpyArrayInitializer(Initializer):
                 value_name: values
             },
             stop_gradient=True)
-        var.op = op
+        if not base.enabled():
+            var.op = op
         return op
......
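
Every initializer above receives the identical guard: `var.op = op` is skipped when `base.enabled()` reports imperative mode, and the created op is still returned to the caller. Purely as an illustration (this helper is not part of the commit), the repeated pattern is equivalent to:

```python
# Illustrative only; the commit repeats this inline in each initializer.
def _maybe_set_var_op(var, op, imperative_enabled):
    if not imperative_enabled:   # keep the var -> op back-reference in static mode only
        var.op = op
    return op
```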