diff --git a/model.py b/model.py index b258be1537197417016f80e3740a2fec92396b1f..0b3fed861948f603553d4c5bf08a4aa490a05566 100644 --- a/model.py +++ b/model.py @@ -369,7 +369,7 @@ class DynamicGraphAdapter(object): def eval(self, inputs, labels, device='CPU', device_ids=None): assert self.model._loss_functions, \ "model not ready, please call `model.prepare()` first" - super(Model, self.model).train() + super(Model, self.model).eval() self.mode = 'eval' inputs = to_list(inputs) labels = to_list(labels) @@ -379,7 +379,7 @@ class DynamicGraphAdapter(object): [to_numpy(l) for l in losses] def test(self, inputs, device='CPU', device_ids=None): - super(Model, self.model).train() + super(Model, self.model).eval() self.mode = 'test' inputs = [to_variable(x) for x in to_list(inputs)] outputs = self.model.forward(*inputs)