diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 014ee41f4c5aa280fb5b366d8f1704290cc067d4..d564ac6e4a9714ebe28643c6f36a749e23f4824e 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -48,6 +48,12 @@ class Layer(core.Layer): self._helper = LayerObjectHelper(self._full_name) + def train(self): + framework._dygraph_tracer()._train_mode() + + def eval(self): + framework._dygraph_tracer()._eval_mode() + def full_name(self): """Full name for this layers. @@ -254,6 +260,12 @@ class PyLayer(core.PyLayer): def __init__(self): super(PyLayer, self).__init__() + def train(self): + framework._dygraph_tracer()._train_mode() + + def eval(self): + framework._dygraph_tracer()._eval_mode() + @classmethod def _do_forward(cls, inputs): return cls._to_tuple(cls.forward(inputs)) diff --git a/python/paddle/fluid/dygraph/tracer.py b/python/paddle/fluid/dygraph/tracer.py index 94e212b139b2b375aa9f5252d396e90235ba33c1..ee37ffab2cb7521b83108a40febcfe88cab28633 100644 --- a/python/paddle/fluid/dygraph/tracer.py +++ b/python/paddle/fluid/dygraph/tracer.py @@ -24,7 +24,9 @@ __all__ = ['Tracer'] def release_op(op): - del framework._dygraph_tracer()._ops[op._trace_id] + del framework._dygraph_tracer()._ops[op._trace_id].inputs + del framework._dygraph_tracer()._ops[op._trace_id].outputs + del framework._dygraph_tracer()._ops[op._trace_id].backward_refs class Tracer(core.Tracer): @@ -38,6 +40,7 @@ class Tracer(core.Tracer): self._ops = defaultdict() self._vars = defaultdict() self._trace_id = 0 + self._train_mode = True def trace_var(self, name, var): self._vars[name] = var @@ -46,15 +49,57 @@ class Tracer(core.Tracer): return list((item for name, item in six.iteritems(self._vars) if isinstance(item, framework.Parameter))) - def trace_op(self, op, stop_gradient=False): + def trace_op(self, op, inputs, outputs, stop_gradient=False): + # TODO(minqiyang): remove this line after we take apart all + # backward grads and forward variables + if self._train_mode: + op.inputs = inputs + inps = defaultdict(list) + for k, vars in six.iteritems(inputs): + if isinstance(vars, framework.Variable): + inps[k].append(vars._ivar) + elif isinstance(vars, list) or isinstance(vars, tuple): + for var in vars: + inps[k].append(var._ivar) + + op.outputs = outputs + outs = defaultdict(list) + for k, vars in six.iteritems(outputs): + if isinstance(vars, framework.Variable): + outs[k].append(vars._ivar) + elif isinstance(vars, list) or isinstance(vars, tuple): + for var in vars: + outs[k].append(var._ivar) + else: + inps = defaultdict(list) + for k, vars in six.iteritems(inputs): + if isinstance(vars, framework.Variable): + op.previous_ops.append(vars.op) + inps[k].append(vars._ivar) + elif isinstance(vars, list) or isinstance(vars, tuple): + for var in vars: + op.previous_ops.append(var.op) + inps[k].append(var._ivar) + + op.outputs = outputs + outs = defaultdict(list) + for k, vars in six.iteritems(outputs): + if isinstance(vars, framework.Variable): + vars.op = op + outs[k].append(vars._ivar) + elif isinstance(vars, list) or isinstance(vars, tuple): + for var in vars: + var.op = op + outs[k].append(var._ivar) + # record op's trace id op.iop._trace_id = self._trace_id - backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.attrs, + backward_refs = self.trace(op.iop, inps, outs, op.attrs, framework._current_expected_place(), stop_gradient) - if not stop_gradient: + if not stop_gradient and self._train_mode: self._trace_id += 1 self._ops[op.iop._trace_id] = op @@ -65,10 +110,16 @@ class Tracer(core.Tracer): # TODO(minqiyang): remove all inputs and outputs after separate # var and grad op.backward_refs = defaultdict(list) - for k, v in six.iteritems(op.inputs): + for k, v in six.iteritems(inputs): if k in backward_refs: - op.backward_refs[k] = op.inputs[k] + op.backward_refs[k] = inputs[k] - for k, v in six.iteritems(op.outputs): + for k, v in six.iteritems(outputs): if k in backward_refs: - op.backward_refs[k] = op.outputs[k] + op.backward_refs[k] = outputs[k] + + def _train_mode(self): + self._train_mode = True + + def _eval_mode(self): + self._train_mode = False diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 7953d98bcbb826267fa21f6503e55049c8aff5ba..efba771f27aa670a99fd310f76d0eb378fc04e4f 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -407,6 +407,7 @@ class Variable(object): if persistable else False) if persistable: _dygraph_tracer().trace_var(name, self) + self.op = None else: self.error_clip = error_clip @@ -935,26 +936,9 @@ class Operator(object): raise ValueError( "`type` to initialized an Operator can not be None.") self.iop = core.OpBase(type) + self.previous_ops = [] - # TODO(minqiyang): remove these lines after we take apart all - # backward grads and forward variables - self.inputs = defaultdict(list) - if inputs is not None: - for k, v in six.iteritems(inputs): - if isinstance(v, Variable): - self.inputs[k].append(v._ivar) - elif isinstance(v, list) or isinstance(v, tuple): - self.inputs[k].extend([var._ivar for var in v]) - - self.outputs = defaultdict(list) - if outputs is not None: - for k, v in six.iteritems(outputs): - if isinstance(v, Variable): - self.outputs[k].append(v._ivar) - elif isinstance(v, list) or isinstance(v, tuple): - self.outputs[k].extend([var._ivar for var in v]) - - self.attrs = attrs if attrs else {} + self.attrs = attrs else: self.block = block self.desc = desc @@ -1643,15 +1627,18 @@ class Block(object): block=self, desc=None, type=kwargs.get("type", None), - inputs=kwargs.get("inputs", None), - outputs=kwargs.get("outputs", None), - attrs=kwargs.get("attrs", None)) + inputs=None, + outputs=None, + attrs=kwargs.get("attrs", {})) # record ops in tracer rather than blocks # # TODO(minqiyang): add op stop_gradient support in static mode too. # currently, we only support stop_gradient in dygraph mode. - _dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False)) + _dygraph_tracer().trace_op(op, + kwargs.get("inputs", {}), + kwargs.get("outputs", {}), + kwargs.get("stop_gradient", False)) else: op_desc = self.desc.append_op() op = Operator( @@ -1715,10 +1702,14 @@ class Block(object): self, None, type=kwargs.get("type", None), - inputs=kwargs.get("inputs", None), - outputs=kwargs.get("outputs", None), - attrs=kwargs.get("attrs", None)) - _dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False)) + inputs=None, + outputs=None, + attrs=kwargs.get("attrs", {})) + + _dygraph_tracer().trace_op(op, + kwargs.get("inputs", {}), + kwargs.get("outputs", {}), + kwargs.get("stop_gradient", False)) else: op_desc = self.desc._prepend_op() op = Operator(