提交 e595200f 编写于 作者: Y Yang Zhang

Fix a bug where `train()` is called for all mode in dygraph

上级 0935e730
...@@ -369,7 +369,7 @@ class DynamicGraphAdapter(object): ...@@ -369,7 +369,7 @@ class DynamicGraphAdapter(object):
def eval(self, inputs, labels, device='CPU', device_ids=None): def eval(self, inputs, labels, device='CPU', device_ids=None):
assert self.model._loss_functions, \ assert self.model._loss_functions, \
"model not ready, please call `model.prepare()` first" "model not ready, please call `model.prepare()` first"
super(Model, self.model).train() super(Model, self.model).eval()
self.mode = 'eval' self.mode = 'eval'
inputs = to_list(inputs) inputs = to_list(inputs)
labels = to_list(labels) labels = to_list(labels)
...@@ -379,7 +379,7 @@ class DynamicGraphAdapter(object): ...@@ -379,7 +379,7 @@ class DynamicGraphAdapter(object):
[to_numpy(l) for l in losses] [to_numpy(l) for l in losses]
def test(self, inputs, device='CPU', device_ids=None): def test(self, inputs, device='CPU', device_ids=None):
super(Model, self.model).train() super(Model, self.model).eval()
self.mode = 'test' self.mode = 'test'
inputs = [to_variable(x) for x in to_list(inputs)] inputs = [to_variable(x) for x in to_list(inputs)]
outputs = self.model.forward(*inputs) outputs = self.model.forward(*inputs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册