提交 192b6d63 编写于 作者: M minqiyang

Untrack op in eval mode

test=release/1.4
上级 4914da1b
......@@ -48,6 +48,12 @@ class Layer(core.Layer):
self._helper = LayerObjectHelper(self._full_name)
def train(self):
framework._dygraph_tracer()._train_mode()
def eval(self):
framework._dygraph_tracer()._eval_mode()
def full_name(self):
"""Full name for this layers.
......@@ -254,6 +260,12 @@ class PyLayer(core.PyLayer):
def __init__(self):
super(PyLayer, self).__init__()
def train(self):
framework._dygraph_tracer()._train_mode()
def eval(self):
framework._dygraph_tracer()._eval_mode()
@classmethod
def _do_forward(cls, inputs):
return cls._to_tuple(cls.forward(inputs))
......
......@@ -24,7 +24,9 @@ __all__ = ['Tracer']
def release_op(op):
del framework._dygraph_tracer()._ops[op._trace_id]
del framework._dygraph_tracer()._ops[op._trace_id].inputs
del framework._dygraph_tracer()._ops[op._trace_id].outputs
del framework._dygraph_tracer()._ops[op._trace_id].backward_refs
class Tracer(core.Tracer):
......@@ -38,6 +40,7 @@ class Tracer(core.Tracer):
self._ops = defaultdict()
self._vars = defaultdict()
self._trace_id = 0
self._train_mode = True
def trace_var(self, name, var):
self._vars[name] = var
......@@ -46,15 +49,57 @@ class Tracer(core.Tracer):
return list((item for name, item in six.iteritems(self._vars)
if isinstance(item, framework.Parameter)))
def trace_op(self, op, stop_gradient=False):
def trace_op(self, op, inputs, outputs, stop_gradient=False):
# TODO(minqiyang): remove this line after we take apart all
# backward grads and forward variables
if self._train_mode:
op.inputs = inputs
inps = defaultdict(list)
for k, vars in six.iteritems(inputs):
if isinstance(vars, framework.Variable):
inps[k].append(vars._ivar)
elif isinstance(vars, list) or isinstance(vars, tuple):
for var in vars:
inps[k].append(var._ivar)
op.outputs = outputs
outs = defaultdict(list)
for k, vars in six.iteritems(outputs):
if isinstance(vars, framework.Variable):
outs[k].append(vars._ivar)
elif isinstance(vars, list) or isinstance(vars, tuple):
for var in vars:
outs[k].append(var._ivar)
else:
inps = defaultdict(list)
for k, vars in six.iteritems(inputs):
if isinstance(vars, framework.Variable):
op.previous_ops.append(vars.op)
inps[k].append(vars._ivar)
elif isinstance(vars, list) or isinstance(vars, tuple):
for var in vars:
op.previous_ops.append(var.op)
inps[k].append(var._ivar)
op.outputs = outputs
outs = defaultdict(list)
for k, vars in six.iteritems(outputs):
if isinstance(vars, framework.Variable):
vars.op = op
outs[k].append(vars._ivar)
elif isinstance(vars, list) or isinstance(vars, tuple):
for var in vars:
var.op = op
outs[k].append(var._ivar)
# record op's trace id
op.iop._trace_id = self._trace_id
backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.attrs,
backward_refs = self.trace(op.iop, inps, outs, op.attrs,
framework._current_expected_place(),
stop_gradient)
if not stop_gradient:
if not stop_gradient and self._train_mode:
self._trace_id += 1
self._ops[op.iop._trace_id] = op
......@@ -65,10 +110,16 @@ class Tracer(core.Tracer):
# TODO(minqiyang): remove all inputs and outputs after separate
# var and grad
op.backward_refs = defaultdict(list)
for k, v in six.iteritems(op.inputs):
for k, v in six.iteritems(inputs):
if k in backward_refs:
op.backward_refs[k] = op.inputs[k]
op.backward_refs[k] = inputs[k]
for k, v in six.iteritems(op.outputs):
for k, v in six.iteritems(outputs):
if k in backward_refs:
op.backward_refs[k] = op.outputs[k]
op.backward_refs[k] = outputs[k]
def _train_mode(self):
self._train_mode = True
def _eval_mode(self):
self._train_mode = False
......@@ -407,6 +407,7 @@ class Variable(object):
if persistable else False)
if persistable:
_dygraph_tracer().trace_var(name, self)
self.op = None
else:
self.error_clip = error_clip
......@@ -935,26 +936,9 @@ class Operator(object):
raise ValueError(
"`type` to initialized an Operator can not be None.")
self.iop = core.OpBase(type)
self.previous_ops = []
# TODO(minqiyang): remove these lines after we take apart all
# backward grads and forward variables
self.inputs = defaultdict(list)
if inputs is not None:
for k, v in six.iteritems(inputs):
if isinstance(v, Variable):
self.inputs[k].append(v._ivar)
elif isinstance(v, list) or isinstance(v, tuple):
self.inputs[k].extend([var._ivar for var in v])
self.outputs = defaultdict(list)
if outputs is not None:
for k, v in six.iteritems(outputs):
if isinstance(v, Variable):
self.outputs[k].append(v._ivar)
elif isinstance(v, list) or isinstance(v, tuple):
self.outputs[k].extend([var._ivar for var in v])
self.attrs = attrs if attrs else {}
self.attrs = attrs
else:
self.block = block
self.desc = desc
......@@ -1643,15 +1627,18 @@ class Block(object):
block=self,
desc=None,
type=kwargs.get("type", None),
inputs=kwargs.get("inputs", None),
outputs=kwargs.get("outputs", None),
attrs=kwargs.get("attrs", None))
inputs=None,
outputs=None,
attrs=kwargs.get("attrs", {}))
# record ops in tracer rather than blocks
#
# TODO(minqiyang): add op stop_gradient support in static mode too.
# currently, we only support stop_gradient in dygraph mode.
_dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
_dygraph_tracer().trace_op(op,
kwargs.get("inputs", {}),
kwargs.get("outputs", {}),
kwargs.get("stop_gradient", False))
else:
op_desc = self.desc.append_op()
op = Operator(
......@@ -1715,10 +1702,14 @@ class Block(object):
self,
None,
type=kwargs.get("type", None),
inputs=kwargs.get("inputs", None),
outputs=kwargs.get("outputs", None),
attrs=kwargs.get("attrs", None))
_dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
inputs=None,
outputs=None,
attrs=kwargs.get("attrs", {}))
_dygraph_tracer().trace_op(op,
kwargs.get("inputs", {}),
kwargs.get("outputs", {}),
kwargs.get("stop_gradient", False))
else:
op_desc = self.desc._prepend_op()
op = Operator(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册