提交 b71af29f 编写于 作者: M minqiyang 提交者: ceci3

Remove var op deps in imperative mode

test=develop
上级 690be0bb
...@@ -159,6 +159,7 @@ void BlockDesc::RemoveOpInternal(const OpDesc *op_desc) { ...@@ -159,6 +159,7 @@ void BlockDesc::RemoveOpInternal(const OpDesc *op_desc) {
for (auto it = ops_.begin(); it != ops_.end(); ++it) { for (auto it = ops_.begin(); it != ops_.end(); ++it) {
if (it->get() == op_desc) { if (it->get() == op_desc) {
ops_.erase(it); ops_.erase(it);
break;
} }
} }
} }
......
...@@ -158,8 +158,9 @@ class Autograd { ...@@ -158,8 +158,9 @@ class Autograd {
for (auto it : candidate->pre_ops_) { for (auto it : candidate->pre_ops_) {
for (OpBase* pre_op : it.second) { for (OpBase* pre_op : it.second) {
if (!pre_op) continue; if (!pre_op) continue;
VLOG(5) << "op dep " << candidate->op_desc_->Type() << " <---- " VLOG(5) << "op dep " << candidate->op_desc_->Type() << " "
<< it.first << " <---- " << pre_op->op_desc_->Type(); << candidate->trace_id_ << " <---- " << it.first << " <---- "
<< pre_op->op_desc_->Type() << " " << pre_op->trace_id_;
if (visited.find(pre_op) == visited.end()) { if (visited.find(pre_op) == visited.end()) {
visited.insert(pre_op); visited.insert(pre_op);
queue.push_back(pre_op); queue.push_back(pre_op);
......
...@@ -723,6 +723,8 @@ class Operator(object): ...@@ -723,6 +723,8 @@ class Operator(object):
out_arg_names = [] out_arg_names = []
for arg in out_args: for arg in out_args:
out_arg_names.append(cpt.to_text(arg.name)) out_arg_names.append(cpt.to_text(arg.name))
# TODO(minqiyang): could we remove variable's op in static mode?
if not _in_imperative_mode():
arg.op = self arg.op = self
self.desc.set_output(out_proto.name, out_arg_names) self.desc.set_output(out_proto.name, out_arg_names)
......
...@@ -24,6 +24,10 @@ __all__ = ['Tracer'] ...@@ -24,6 +24,10 @@ __all__ = ['Tracer']
def release_op(op): def release_op(op):
import gc
assert len(
gc.get_referrers(framework._imperative_tracer()._ops[
op._trace_id])) == 1
del framework._imperative_tracer()._ops[op._trace_id] del framework._imperative_tracer()._ops[op._trace_id]
...@@ -41,7 +45,6 @@ class Tracer(core.Tracer): ...@@ -41,7 +45,6 @@ class Tracer(core.Tracer):
def trace_op(self, op, stop_gradient=False): def trace_op(self, op, stop_gradient=False):
# record op's trace id # record op's trace id
op.iop._trace_id = self._trace_id op.iop._trace_id = self._trace_id
self._trace_id += 1
# trace op and save it # trace op and save it
backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.block.desc, backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.block.desc,
...@@ -49,6 +52,7 @@ class Tracer(core.Tracer): ...@@ -49,6 +52,7 @@ class Tracer(core.Tracer):
stop_gradient) stop_gradient)
if not stop_gradient: if not stop_gradient:
self._trace_id += 1
self._ops[op.iop._trace_id] = op self._ops[op.iop._trace_id] = op
# register backward hooks and variables if needed # register backward hooks and variables if needed
......
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
from .wrapped_decorator import signature_safe_contextmanager from .wrapped_decorator import signature_safe_contextmanager
from .core import VarDesc from .core import VarDesc
from . import unique_name from . import unique_name
from .imperative import base
__all__ = [ __all__ = [
'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear', 'Constant', 'Uniform', 'Normal', 'TruncatedNormal', 'Xavier', 'Bilinear',
...@@ -165,6 +166,7 @@ class ConstantInitializer(Initializer): ...@@ -165,6 +166,7 @@ class ConstantInitializer(Initializer):
'force_cpu': self._force_cpu or force_init_on_cpu() 'force_cpu': self._force_cpu or force_init_on_cpu()
}, },
stop_gradient=True) stop_gradient=True)
if not base.enabled():
var.op = op var.op = op
return op return op
...@@ -244,6 +246,7 @@ class UniformInitializer(Initializer): ...@@ -244,6 +246,7 @@ class UniformInitializer(Initializer):
attrs={"in_dtype": out_var.dtype, attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype}) "out_dtype": var.dtype})
if not base.enabled():
var.op = op var.op = op
return op return op
...@@ -322,6 +325,7 @@ class NormalInitializer(Initializer): ...@@ -322,6 +325,7 @@ class NormalInitializer(Initializer):
outputs={"Out": var}, outputs={"Out": var},
attrs={"in_dtype": out_var.dtype, attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype}) "out_dtype": var.dtype})
if not base.enabled():
var.op = op var.op = op
return op return op
...@@ -400,6 +404,7 @@ class TruncatedNormalInitializer(Initializer): ...@@ -400,6 +404,7 @@ class TruncatedNormalInitializer(Initializer):
outputs={"Out": var}, outputs={"Out": var},
attrs={"in_dtype": out_var.dtype, attrs={"in_dtype": out_var.dtype,
"out_dtype": var.dtype}) "out_dtype": var.dtype})
if not base.enabled():
var.op = op var.op = op
return op return op
...@@ -505,6 +510,7 @@ class XavierInitializer(Initializer): ...@@ -505,6 +510,7 @@ class XavierInitializer(Initializer):
"seed": self._seed "seed": self._seed
}, },
stop_gradient=True) stop_gradient=True)
if not base.enabled():
var.op = op var.op = op
return op return op
...@@ -605,6 +611,7 @@ class MSRAInitializer(Initializer): ...@@ -605,6 +611,7 @@ class MSRAInitializer(Initializer):
"seed": self._seed "seed": self._seed
}, },
stop_gradient=True) stop_gradient=True)
if not base.enabled():
var.op = op var.op = op
return op return op
...@@ -703,6 +710,7 @@ class BilinearInitializer(Initializer): ...@@ -703,6 +710,7 @@ class BilinearInitializer(Initializer):
'shape': list(shape), 'shape': list(shape),
value_name: values value_name: values
}) })
if not base.enabled():
var.op = op var.op = op
return op return op
...@@ -761,6 +769,7 @@ class NumpyArrayInitializer(Initializer): ...@@ -761,6 +769,7 @@ class NumpyArrayInitializer(Initializer):
value_name: values value_name: values
}, },
stop_gradient=True) stop_gradient=True)
if not base.enabled():
var.op = op var.op = op
return op return op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册