提交 e595200f 编写于 作者: Y Yang Zhang

Fix a bug where `train()` is called for all mode in dygraph

上级 0935e730
......@@ -369,7 +369,7 @@ class DynamicGraphAdapter(object):
def eval(self, inputs, labels, device='CPU', device_ids=None):
assert self.model._loss_functions, \
"model not ready, please call `model.prepare()` first"
super(Model, self.model).train()
super(Model, self.model).eval()
self.mode = 'eval'
inputs = to_list(inputs)
labels = to_list(labels)
......@@ -379,7 +379,7 @@ class DynamicGraphAdapter(object):
[to_numpy(l) for l in losses]
def test(self, inputs, device='CPU', device_ids=None):
super(Model, self.model).train()
super(Model, self.model).eval()
self.mode = 'test'
inputs = [to_variable(x) for x in to_list(inputs)]
outputs = self.model.forward(*inputs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册