提交 20e304f2 编写于 作者: M minqiyang

Tracer does not hold op any more

test=develop
上级 112f1614
......@@ -24,7 +24,7 @@ __all__ = ['Tracer']
def release_op(op):
del framework._dygraph_tracer()._ops[op._trace_id]
del framework._dygraph_tracer()._ops[op._trace_id].inputs
class Tracer(core.Tracer):
......@@ -46,11 +46,34 @@ class Tracer(core.Tracer):
return list((item for name, item in six.iteritems(self._vars)
if isinstance(item, framework.Parameter)))
def trace_op(self, op, stop_gradient=False):
def trace_op(self, op, inputs, outputs, stop_gradient=False):
# TODO(minqiyang): remove this line after we take apart all
# backward grads and forward variables
op.inputs = inputs
inps = defaultdict(list)
for k, vars in six.iteritems(inputs):
if isinstance(vars, framework.Variable):
op.previous_ops.append(vars.op)
inps[k].append(vars._ivar)
elif isinstance(vars, list) or isinstance(vars, tuple):
for var in vars:
op.previous_ops.append(var.op)
inps[k].append(var._ivar)
outs = defaultdict(list)
for k, vars in six.iteritems(outputs):
if isinstance(vars, framework.Variable):
vars.op = op
outs[k].append(vars._ivar)
elif isinstance(vars, list) or isinstance(vars, tuple):
for var in vars:
var.op = op
outs[k].append(var._ivar)
# record op's trace id
op.iop._trace_id = self._trace_id
backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.attrs,
backward_refs = self.trace(op.iop, inps, outs, op.attrs,
framework._current_expected_place(),
stop_gradient)
......@@ -65,10 +88,10 @@ class Tracer(core.Tracer):
# TODO(minqiyang): remove all inputs and outputs after separate
# var and grad
op.backward_refs = defaultdict(list)
for k, v in six.iteritems(op.inputs):
for k, v in six.iteritems(inputs):
if k in backward_refs:
op.backward_refs[k] = op.inputs[k]
op.backward_refs[k] = inputs[k]
for k, v in six.iteritems(op.outputs):
for k, v in six.iteritems(outputs):
if k in backward_refs:
op.backward_refs[k] = op.outputs[k]
op.backward_refs[k] = outputs[k]
......@@ -411,6 +411,7 @@ class Variable(object):
if persistable else False)
if persistable:
_dygraph_tracer().trace_var(name, self)
self.op = None
else:
self.error_clip = error_clip
......@@ -939,26 +940,9 @@ class Operator(object):
raise ValueError(
"`type` to initialized an Operator can not be None.")
self.iop = core.OpBase(type)
self.previous_ops = []
# TODO(minqiyang): remove these lines after we take apart all
# backward grads and forward variables
self.inputs = defaultdict(list)
if inputs is not None:
for k, v in six.iteritems(inputs):
if isinstance(v, Variable):
self.inputs[k].append(v._ivar)
elif isinstance(v, list) or isinstance(v, tuple):
self.inputs[k].extend([var._ivar for var in v])
self.outputs = defaultdict(list)
if outputs is not None:
for k, v in six.iteritems(outputs):
if isinstance(v, Variable):
self.outputs[k].append(v._ivar)
elif isinstance(v, list) or isinstance(v, tuple):
self.outputs[k].extend([var._ivar for var in v])
self.attrs = attrs if attrs else {}
self.attrs = attrs
else:
self.block = block
self.desc = desc
......@@ -1647,15 +1631,18 @@ class Block(object):
block=self,
desc=None,
type=kwargs.get("type", None),
inputs=kwargs.get("inputs", None),
outputs=kwargs.get("outputs", None),
attrs=kwargs.get("attrs", None))
inputs=None,
outputs=None,
attrs=kwargs.get("attrs", {}))
# record ops in tracer rather than blocks
#
# TODO(minqiyang): add op stop_gradient support in static mode too.
# currently, we only support stop_gradient in dygraph mode.
_dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
_dygraph_tracer().trace_op(op,
kwargs.get("inputs", {}),
kwargs.get("outputs", {}),
kwargs.get("stop_gradient", False))
else:
op_desc = self.desc.append_op()
op = Operator(
......@@ -1719,10 +1706,14 @@ class Block(object):
self,
None,
type=kwargs.get("type", None),
inputs=kwargs.get("inputs", None),
outputs=kwargs.get("outputs", None),
attrs=kwargs.get("attrs", None))
_dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
inputs=None,
outputs=None,
attrs=kwargs.get("attrs", {}))
_dygraph_tracer().trace_op(op,
kwargs.get("inputs", {}),
kwargs.get("outputs", {}),
kwargs.get("stop_gradient", False))
else:
op_desc = self.desc._prepend_op()
op = Operator(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册