未验证 提交 306eadcd 编写于 作者: H Hongyu Liu 提交者: GitHub

fix eval mode bug; test=develop (#17499)

上级 287de41c
...@@ -54,9 +54,7 @@ class Tracer(core.Tracer): ...@@ -54,9 +54,7 @@ class Tracer(core.Tracer):
self._trace_id = 0 self._trace_id = 0
def trace_op(self, op, inputs, outputs, stop_gradient=False): def trace_op(self, op, inputs, outputs, stop_gradient=False):
# TODO(minqiyang): remove this line after we take apart all # TODO(hy): previous version will cause memory failed
# backward grads and forward variables
if self._train_mode:
op.inputs = inputs op.inputs = inputs
inps = defaultdict(list) inps = defaultdict(list)
for k, vars in six.iteritems(inputs): for k, vars in six.iteritems(inputs):
...@@ -74,27 +72,6 @@ class Tracer(core.Tracer): ...@@ -74,27 +72,6 @@ class Tracer(core.Tracer):
elif isinstance(vars, list) or isinstance(vars, tuple): elif isinstance(vars, list) or isinstance(vars, tuple):
for var in vars: for var in vars:
outs[k].append(var._ivar) outs[k].append(var._ivar)
else:
inps = defaultdict(list)
for k, vars in six.iteritems(inputs):
if isinstance(vars, framework.Variable):
op.previous_ops.append(vars.op)
inps[k].append(vars._ivar)
elif isinstance(vars, list) or isinstance(vars, tuple):
for var in vars:
op.previous_ops.append(var.op)
inps[k].append(var._ivar)
op.outputs = outputs
outs = defaultdict(list)
for k, vars in six.iteritems(outputs):
if isinstance(vars, framework.Variable):
vars.op = op
outs[k].append(vars._ivar)
elif isinstance(vars, list) or isinstance(vars, tuple):
for var in vars:
var.op = op
outs[k].append(var._ivar)
# record op's trace id # record op's trace id
op.iop._trace_id = self._trace_id op.iop._trace_id = self._trace_id
......
...@@ -1670,13 +1670,18 @@ class Block(object): ...@@ -1670,13 +1670,18 @@ class Block(object):
Operator: the append Operator. Operator: the append Operator.
""" """
if in_dygraph_mode(): if in_dygraph_mode():
attrs = kwargs.get("attrs", {})
if _dygraph_tracer_._train_mode == False:
# eval mode
attrs['is_test'] = True
op = Operator( op = Operator(
block=self, block=self,
desc=None, desc=None,
type=kwargs.get("type", None), type=kwargs.get("type", None),
inputs=None, inputs=None,
outputs=None, outputs=None,
attrs=kwargs.get("attrs", {})) attrs=attrs)
# record ops in tracer rather than blocks # record ops in tracer rather than blocks
# #
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册