Commit 192b6d63 authored by minqiyang

Untrack op in eval mode

test=release/1.4
Parent 4914da1b
@@ -48,6 +48,12 @@ class Layer(core.Layer):
         self._helper = LayerObjectHelper(self._full_name)
 
+    def train(self):
+        framework._dygraph_tracer()._train_mode()
+
+    def eval(self):
+        framework._dygraph_tracer()._eval_mode()
+
     def full_name(self):
         """Full name for this layers.
@@ -254,6 +260,12 @@ class PyLayer(core.PyLayer):
     def __init__(self):
         super(PyLayer, self).__init__()
 
+    def train(self):
+        framework._dygraph_tracer()._train_mode()
+
+    def eval(self):
+        framework._dygraph_tracer()._eval_mode()
+
     @classmethod
     def _do_forward(cls, inputs):
         return cls._to_tuple(cls.forward(inputs))
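For context, a minimal sketch of how these switches are meant to be driven from user code, assuming the 1.4-era dygraph API. FCLayer and the input data are made up for illustration; only the train()/eval() calls come from this commit (and see the naming caveat flagged after the tracer diff below):

import numpy as np
import paddle.fluid as fluid

class FCLayer(fluid.dygraph.Layer):
    # Hypothetical Layer subclass, only for illustration.
    def __init__(self, name_scope):
        super(FCLayer, self).__init__(name_scope)
        self._fc = fluid.dygraph.FC(self.full_name(), size=4)

    def forward(self, x):
        return self._fc(x)

with fluid.dygraph.guard():
    layer = FCLayer("fc_layer")
    x = fluid.dygraph.to_variable(np.ones([2, 4], dtype="float32"))

    layer.train()   # tracer keeps op/input/output refs for backward
    y_train = layer(x)

    layer.eval()    # tracer untracks ops, so no backward state is retained
    y_eval = layer(x)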
@@ -24,7 +24,9 @@ __all__ = ['Tracer']
 
 
 def release_op(op):
-    del framework._dygraph_tracer()._ops[op._trace_id]
+    del framework._dygraph_tracer()._ops[op._trace_id].inputs
+    del framework._dygraph_tracer()._ops[op._trace_id].outputs
+    del framework._dygraph_tracer()._ops[op._trace_id].backward_refs
 
 
 class Tracer(core.Tracer):
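The point of the release_op change: deleting the inputs/outputs/backward_refs attributes drops the references to forward variables while the op entry itself stays registered in _ops. A tiny standalone illustration of that distinction (toy classes, not Paddle's):

class FakeOp(object):
    pass

ops = {}
op = FakeOp()
op.inputs = ["large", "forward", "tensors"]
ops[0] = op

del ops[0].inputs      # releases what the attribute referenced ...
assert 0 in ops        # ... but the op record is still in the dict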
@@ -38,6 +40,7 @@ class Tracer(core.Tracer):
         self._ops = defaultdict()
         self._vars = defaultdict()
         self._trace_id = 0
+        self._train_mode = True
 
     def trace_var(self, name, var):
         self._vars[name] = var
@@ -46,15 +49,57 @@ class Tracer(core.Tracer):
         return list((item for name, item in six.iteritems(self._vars)
                      if isinstance(item, framework.Parameter)))
 
-    def trace_op(self, op, stop_gradient=False):
+    def trace_op(self, op, inputs, outputs, stop_gradient=False):
+        # TODO(minqiyang): remove this line after we take apart all
+        # backward grads and forward variables
+        if self._train_mode:
+            op.inputs = inputs
+            inps = defaultdict(list)
+            for k, vars in six.iteritems(inputs):
+                if isinstance(vars, framework.Variable):
+                    inps[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        inps[k].append(var._ivar)
+
+            op.outputs = outputs
+            outs = defaultdict(list)
+            for k, vars in six.iteritems(outputs):
+                if isinstance(vars, framework.Variable):
+                    outs[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        outs[k].append(var._ivar)
+        else:
+            inps = defaultdict(list)
+            for k, vars in six.iteritems(inputs):
+                if isinstance(vars, framework.Variable):
+                    op.previous_ops.append(vars.op)
+                    inps[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        op.previous_ops.append(var.op)
+                        inps[k].append(var._ivar)
+
+            op.outputs = outputs
+            outs = defaultdict(list)
+            for k, vars in six.iteritems(outputs):
+                if isinstance(vars, framework.Variable):
+                    vars.op = op
+                    outs[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        var.op = op
+                        outs[k].append(var._ivar)
+
         # record op's trace id
         op.iop._trace_id = self._trace_id
-        backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.attrs,
+        backward_refs = self.trace(op.iop, inps, outs, op.attrs,
                                    framework._current_expected_place(),
                                    stop_gradient)
-        if not stop_gradient:
+        if not stop_gradient and self._train_mode:
             self._trace_id += 1
             self._ops[op.iop._trace_id] = op
@@ -65,10 +110,16 @@ class Tracer(core.Tracer):
             # TODO(minqiyang): remove all inputs and outputs after separate
             # var and grad
             op.backward_refs = defaultdict(list)
-            for k, v in six.iteritems(op.inputs):
+            for k, v in six.iteritems(inputs):
                 if k in backward_refs:
-                    op.backward_refs[k] = op.inputs[k]
+                    op.backward_refs[k] = inputs[k]
 
-            for k, v in six.iteritems(op.outputs):
+            for k, v in six.iteritems(outputs):
                 if k in backward_refs:
-                    op.backward_refs[k] = op.outputs[k]
+                    op.backward_refs[k] = outputs[k]
+
+    def _train_mode(self):
+        self._train_mode = True
+
+    def _eval_mode(self):
+        self._train_mode = False
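One caveat worth flagging: _train_mode now names both the boolean attribute set in __init__ and a method of Tracer. Once __init__ runs, the instance attribute shadows the method, so framework._dygraph_tracer()._train_mode() (as called from Layer.train()) would raise TypeError: 'bool' object is not callable. A minimal standalone sketch of the obvious disambiguation; the renamed methods are illustrative, not part of this commit:

class Tracer(object):
    def __init__(self):
        self._train_mode = True   # plain flag; no longer shadows a method

    def train_mode(self):         # distinct names for the mutators
        self._train_mode = True

    def eval_mode(self):
        self._train_mode = False

tracer = Tracer()
tracer.eval_mode()
assert tracer._train_mode is False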
@@ -407,6 +407,7 @@ class Variable(object):
                                         if persistable else False)
             if persistable:
                 _dygraph_tracer().trace_var(name, self)
+            self.op = None
         else:
             self.error_clip = error_clip
@@ -935,26 +936,9 @@ class Operator(object):
                 raise ValueError(
                     "`type` to initialized an Operator can not be None.")
             self.iop = core.OpBase(type)
+            self.previous_ops = []
 
-            # TODO(minqiyang): remove these lines after we take apart all
-            # backward grads and forward variables
-            self.inputs = defaultdict(list)
-            if inputs is not None:
-                for k, v in six.iteritems(inputs):
-                    if isinstance(v, Variable):
-                        self.inputs[k].append(v._ivar)
-                    elif isinstance(v, list) or isinstance(v, tuple):
-                        self.inputs[k].extend([var._ivar for var in v])
-
-            self.outputs = defaultdict(list)
-            if outputs is not None:
-                for k, v in six.iteritems(outputs):
-                    if isinstance(v, Variable):
-                        self.outputs[k].append(v._ivar)
-                    elif isinstance(v, list) or isinstance(v, tuple):
-                        self.outputs[k].extend([var._ivar for var in v])
-
-            self.attrs = attrs if attrs else {}
+            self.attrs = attrs
         else:
             self.block = block
             self.desc = desc
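In eval mode, trace_op now stitches ops together through Variable.op and Operator.previous_ops instead of registering them in the tracer's _ops dict. A toy sketch of the linkage the else-branch builds, using minimal stand-in classes rather than the Paddle ones:

class Var(object):
    def __init__(self):
        self.op = None             # producer op; mirrors `self.op = None` above

class Op(object):
    def __init__(self):
        self.previous_ops = []     # mirrors Operator.previous_ops

x = Var()
producer = Op()
x.op = producer                    # an output var remembers its producer

consumer = Op()
consumer.previous_ops.append(x.op) # a consumer links back through its inputs
assert consumer.previous_ops == [producer]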
@@ -1643,15 +1627,18 @@ class Block(object):
                 block=self,
                 desc=None,
                 type=kwargs.get("type", None),
-                inputs=kwargs.get("inputs", None),
-                outputs=kwargs.get("outputs", None),
-                attrs=kwargs.get("attrs", None))
+                inputs=None,
+                outputs=None,
+                attrs=kwargs.get("attrs", {}))
 
             # record ops in tracer rather than blocks
             #
             # TODO(minqiyang): add op stop_gradient support in static mode too.
             # currently, we only support stop_gradient in dygraph mode.
-            _dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
+            _dygraph_tracer().trace_op(op,
+                                       kwargs.get("inputs", {}),
+                                       kwargs.get("outputs", {}),
+                                       kwargs.get("stop_gradient", False))
         else:
             op_desc = self.desc.append_op()
             op = Operator(
@@ -1715,10 +1702,14 @@ class Block(object):
                 self,
                 None,
                 type=kwargs.get("type", None),
-                inputs=kwargs.get("inputs", None),
-                outputs=kwargs.get("outputs", None),
-                attrs=kwargs.get("attrs", None))
-            _dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
+                inputs=None,
+                outputs=None,
+                attrs=kwargs.get("attrs", {}))
+
+            _dygraph_tracer().trace_op(op,
+                                       kwargs.get("inputs", {}),
+                                       kwargs.get("outputs", {}),
+                                       kwargs.get("stop_gradient", False))
         else:
             op_desc = self.desc._prepend_op()
             op = Operator(
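The net effect on the call shape: Operator.__init__ no longer converts and stores inputs/outputs; append_op and _prepend_op forward the raw kwargs dicts so the tracer decides, per mode, what to retain. A toy, self-contained model of that flow (plain dicts and a toy tracer, not Paddle's classes):

class ToyTracer(object):
    def __init__(self):
        self._train_mode = True

    def trace_op(self, op, inputs, outputs, stop_gradient=False):
        if self._train_mode:            # only training mode keeps the refs
            op["inputs"], op["outputs"] = inputs, outputs

def append_op(tracer, **kwargs):
    # The op itself is built without inputs/outputs ...
    op = {"type": kwargs.get("type"), "attrs": kwargs.get("attrs", {})}
    # ... which travel separately to the tracer, as in the diff above.
    tracer.trace_op(op,
                    kwargs.get("inputs", {}),
                    kwargs.get("outputs", {}),
                    kwargs.get("stop_gradient", False))
    return op

op = append_op(ToyTracer(), type="relu", inputs={"X": ["x"]}, outputs={"Out": []})
assert op["inputs"] == {"X": ["x"]}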