From 306eadcd3966b73b71b831791f09a6d68ebd80c2 Mon Sep 17 00:00:00 2001 From: Hongyu Liu <43953930+phlrain@users.noreply.github.com> Date: Tue, 21 May 2019 10:07:45 +0800 Subject: [PATCH] fix eval mode bug; test=develop (#17499) --- python/paddle/fluid/dygraph/tracer.py | 59 ++++++++------------------- python/paddle/fluid/framework.py | 7 +++- 2 files changed, 24 insertions(+), 42 deletions(-) diff --git a/python/paddle/fluid/dygraph/tracer.py b/python/paddle/fluid/dygraph/tracer.py index 4248e3c310f..92092458146 100644 --- a/python/paddle/fluid/dygraph/tracer.py +++ b/python/paddle/fluid/dygraph/tracer.py @@ -54,47 +54,24 @@ class Tracer(core.Tracer): self._trace_id = 0 def trace_op(self, op, inputs, outputs, stop_gradient=False): - # TODO(minqiyang): remove this line after we take apart all - # backward grads and forward variables - if self._train_mode: - op.inputs = inputs - inps = defaultdict(list) - for k, vars in six.iteritems(inputs): - if isinstance(vars, framework.Variable): - inps[k].append(vars._ivar) - elif isinstance(vars, list) or isinstance(vars, tuple): - for var in vars: - inps[k].append(var._ivar) - - op.outputs = outputs - outs = defaultdict(list) - for k, vars in six.iteritems(outputs): - if isinstance(vars, framework.Variable): - outs[k].append(vars._ivar) - elif isinstance(vars, list) or isinstance(vars, tuple): - for var in vars: - outs[k].append(var._ivar) - else: - inps = defaultdict(list) - for k, vars in six.iteritems(inputs): - if isinstance(vars, framework.Variable): - op.previous_ops.append(vars.op) - inps[k].append(vars._ivar) - elif isinstance(vars, list) or isinstance(vars, tuple): - for var in vars: - op.previous_ops.append(var.op) - inps[k].append(var._ivar) - - op.outputs = outputs - outs = defaultdict(list) - for k, vars in six.iteritems(outputs): - if isinstance(vars, framework.Variable): - vars.op = op - outs[k].append(vars._ivar) - elif isinstance(vars, list) or isinstance(vars, tuple): - for var in vars: - var.op = op - outs[k].append(var._ivar) + # TODO(hy): previous version will cause memory failed + op.inputs = inputs + inps = defaultdict(list) + for k, vars in six.iteritems(inputs): + if isinstance(vars, framework.Variable): + inps[k].append(vars._ivar) + elif isinstance(vars, list) or isinstance(vars, tuple): + for var in vars: + inps[k].append(var._ivar) + + op.outputs = outputs + outs = defaultdict(list) + for k, vars in six.iteritems(outputs): + if isinstance(vars, framework.Variable): + outs[k].append(vars._ivar) + elif isinstance(vars, list) or isinstance(vars, tuple): + for var in vars: + outs[k].append(var._ivar) # record op's trace id op.iop._trace_id = self._trace_id diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 43ba26248a4..3f160b71f44 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1670,13 +1670,18 @@ class Block(object): Operator: the append Operator. """ if in_dygraph_mode(): + attrs = kwargs.get("attrs", {}) + if _dygraph_tracer_._train_mode == False: + # eval mode + attrs['is_test'] = True + op = Operator( block=self, desc=None, type=kwargs.get("type", None), inputs=None, outputs=None, - attrs=kwargs.get("attrs", {})) + attrs=attrs) # record ops in tracer rather than blocks # -- GitLab