From e595200f6ee039226b55bbe2a9c634bb621dcf93 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Mon, 6 Jan 2020 10:44:21 +0800 Subject: [PATCH] Fix a bug where `train()` is called for all mode in dygraph --- model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model.py b/model.py index b258be1..0b3fed8 100644 --- a/model.py +++ b/model.py @@ -369,7 +369,7 @@ class DynamicGraphAdapter(object): def eval(self, inputs, labels, device='CPU', device_ids=None): assert self.model._loss_functions, \ "model not ready, please call `model.prepare()` first" - super(Model, self.model).train() + super(Model, self.model).eval() self.mode = 'eval' inputs = to_list(inputs) labels = to_list(labels) @@ -379,7 +379,7 @@ class DynamicGraphAdapter(object): [to_numpy(l) for l in losses] def test(self, inputs, device='CPU', device_ids=None): - super(Model, self.model).train() + super(Model, self.model).eval() self.mode = 'test' inputs = [to_variable(x) for x in to_list(inputs)] outputs = self.model.forward(*inputs) -- GitLab