提交 06382519 编写于 作者: D dongshuilong

fix train without eval bug

上级 fc6d2114
...@@ -116,7 +116,8 @@ class Engine(object): ...@@ -116,7 +116,8 @@ class Engine(object):
if self.mode == 'train': if self.mode == 'train':
self.train_dataloader = build_dataloader( self.train_dataloader = build_dataloader(
self.config["DataLoader"], "Train", self.device, self.use_dali) self.config["DataLoader"], "Train", self.device, self.use_dali)
if self.mode in ["train", "eval"]: if self.mode == "eval" or (self.mode == "train" and
self.config["Global"]["eval_during_train"]):
if self.eval_mode == "classification": if self.eval_mode == "classification":
self.eval_dataloader = build_dataloader( self.eval_dataloader = build_dataloader(
self.config["DataLoader"], "Eval", self.device, self.config["DataLoader"], "Eval", self.device,
...@@ -140,7 +141,8 @@ class Engine(object): ...@@ -140,7 +141,8 @@ class Engine(object):
if self.mode == "train": if self.mode == "train":
loss_info = self.config["Loss"]["Train"] loss_info = self.config["Loss"]["Train"]
self.train_loss_func = build_loss(loss_info) self.train_loss_func = build_loss(loss_info)
if self.mode in ["train", "eval"]: if self.mode == "eval" or (self.mode == "train" and
self.config["Global"]["eval_during_train"]):
loss_config = self.config.get("Loss", None) loss_config = self.config.get("Loss", None)
if loss_config is not None: if loss_config is not None:
loss_config = loss_config.get("Eval") loss_config = loss_config.get("Eval")
...@@ -163,7 +165,8 @@ class Engine(object): ...@@ -163,7 +165,8 @@ class Engine(object):
else: else:
self.train_metric_func = None self.train_metric_func = None
if self.mode in ["train", "eval"]: if self.mode == "eval" or (self.mode == "train" and
self.config["Global"]["eval_during_train"]):
metric_config = self.config.get("Metric") metric_config = self.config.get("Metric")
if self.eval_mode == "classification": if self.eval_mode == "classification":
if metric_config is not None: if metric_config is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册