From 20e304f2aedd5a09b260f593408057217a806e94 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Thu, 11 Apr 2019 09:56:52 +0800 Subject: [PATCH] Tracer does not hold op any more test=develop --- python/paddle/fluid/dygraph/tracer.py | 37 +++++++++++++++++----- python/paddle/fluid/framework.py | 45 +++++++++++---------------- 2 files changed, 48 insertions(+), 34 deletions(-) diff --git a/python/paddle/fluid/dygraph/tracer.py b/python/paddle/fluid/dygraph/tracer.py index 94e212b139..e5e715bcdc 100644 --- a/python/paddle/fluid/dygraph/tracer.py +++ b/python/paddle/fluid/dygraph/tracer.py @@ -24,7 +24,7 @@ __all__ = ['Tracer'] def release_op(op): - del framework._dygraph_tracer()._ops[op._trace_id] + del framework._dygraph_tracer()._ops[op._trace_id].inputs class Tracer(core.Tracer): @@ -46,11 +46,34 @@ class Tracer(core.Tracer): return list((item for name, item in six.iteritems(self._vars) if isinstance(item, framework.Parameter))) - def trace_op(self, op, stop_gradient=False): + def trace_op(self, op, inputs, outputs, stop_gradient=False): + # TODO(minqiyang): remove this line after we take apart all + # backward grads and forward variables + op.inputs = inputs + inps = defaultdict(list) + for k, vars in six.iteritems(inputs): + if isinstance(vars, framework.Variable): + op.previous_ops.append(vars.op) + inps[k].append(vars._ivar) + elif isinstance(vars, list) or isinstance(vars, tuple): + for var in vars: + op.previous_ops.append(var.op) + inps[k].append(var._ivar) + + outs = defaultdict(list) + for k, vars in six.iteritems(outputs): + if isinstance(vars, framework.Variable): + vars.op = op + outs[k].append(vars._ivar) + elif isinstance(vars, list) or isinstance(vars, tuple): + for var in vars: + var.op = op + outs[k].append(var._ivar) + # record op's trace id op.iop._trace_id = self._trace_id - backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.attrs, + backward_refs = self.trace(op.iop, inps, outs, op.attrs, framework._current_expected_place(), stop_gradient) @@ -65,10 +88,10 @@ class Tracer(core.Tracer): # TODO(minqiyang): remove all inputs and outputs after separate # var and grad op.backward_refs = defaultdict(list) - for k, v in six.iteritems(op.inputs): + for k, v in six.iteritems(inputs): if k in backward_refs: - op.backward_refs[k] = op.inputs[k] + op.backward_refs[k] = inputs[k] - for k, v in six.iteritems(op.outputs): + for k, v in six.iteritems(outputs): if k in backward_refs: - op.backward_refs[k] = op.outputs[k] + op.backward_refs[k] = outputs[k] diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index c05e5fb9e3..a29db04900 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -411,6 +411,7 @@ class Variable(object): if persistable else False) if persistable: _dygraph_tracer().trace_var(name, self) + self.op = None else: self.error_clip = error_clip @@ -939,26 +940,9 @@ class Operator(object): raise ValueError( "`type` to initialized an Operator can not be None.") self.iop = core.OpBase(type) + self.previous_ops = [] - # TODO(minqiyang): remove these lines after we take apart all - # backward grads and forward variables - self.inputs = defaultdict(list) - if inputs is not None: - for k, v in six.iteritems(inputs): - if isinstance(v, Variable): - self.inputs[k].append(v._ivar) - elif isinstance(v, list) or isinstance(v, tuple): - self.inputs[k].extend([var._ivar for var in v]) - - self.outputs = defaultdict(list) - if outputs is not None: - for k, v in six.iteritems(outputs): - if isinstance(v, Variable): - self.outputs[k].append(v._ivar) - elif isinstance(v, list) or isinstance(v, tuple): - self.outputs[k].extend([var._ivar for var in v]) - - self.attrs = attrs if attrs else {} + self.attrs = attrs else: self.block = block self.desc = desc @@ -1647,15 +1631,18 @@ class Block(object): block=self, desc=None, type=kwargs.get("type", None), - inputs=kwargs.get("inputs", None), - outputs=kwargs.get("outputs", None), - attrs=kwargs.get("attrs", None)) + inputs=None, + outputs=None, + attrs=kwargs.get("attrs", {})) # record ops in tracer rather than blocks # # TODO(minqiyang): add op stop_gradient support in static mode too. # currently, we only support stop_gradient in dygraph mode. - _dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False)) + _dygraph_tracer().trace_op(op, + kwargs.get("inputs", {}), + kwargs.get("outputs", {}), + kwargs.get("stop_gradient", False)) else: op_desc = self.desc.append_op() op = Operator( @@ -1719,10 +1706,14 @@ class Block(object): self, None, type=kwargs.get("type", None), - inputs=kwargs.get("inputs", None), - outputs=kwargs.get("outputs", None), - attrs=kwargs.get("attrs", None)) - _dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False)) + inputs=None, + outputs=None, + attrs=kwargs.get("attrs", {})) + + _dygraph_tracer().trace_op(op, + kwargs.get("inputs", {}), + kwargs.get("outputs", {}), + kwargs.get("stop_gradient", False)) else: op_desc = self.desc._prepend_op() op = Operator( -- GitLab