提交 9e4a1045 编写于 作者: W weishengyu

dbg

上级 81c864b9
......@@ -96,15 +96,14 @@ class Trainer(object):
def train(self):
# build train loss and metric info
if self.train_loss_func is None:
loss_info = self.config.get("Loss", None)
if loss_info is not None:
loss_info = loss_info["Train"]
loss_info = self.config["Loss"]["Train"]
self.train_loss_func = build_loss(loss_info)
if self.train_metric_func is None:
metric_config = self.config.get("Metric", None)
metric_config = self.config.get("Metric")
if metric_config is not None:
metric_config = metric_config["Train"]
self.train_metric_func = build_metrics(metric_config)
metric_config = metric_config.get("Train")
if metric_config is not None:
self.train_metric_func = build_metrics(metric_config)
if self.train_dataloader is None:
self.train_dataloader = build_dataloader(self.config["DataLoader"],
......@@ -223,20 +222,22 @@ class Trainer(object):
def eval(self, epoch_id=0):
self.model.eval()
if self.eval_loss_func is None:
loss_info = self.config.get("Loss", None)
if loss_info is not None:
loss_info = loss_info["Eval"]
self.eval_loss_func = build_loss(loss_info)
loss_config = self.config.get("Loss", None)
if loss_config is not None:
loss_config = loss_config.get("Eval")
if loss_config is not None:
self.eval_loss_func = build_loss(loss_config)
if self.eval_mode == "classification":
if self.eval_dataloader is None:
self.eval_dataloader = build_dataloader(
self.config["DataLoader"], "Eval", self.device)
if self.eval_metric_func is None:
metric_config = self.config.get("Metric", None)
metric_config = self.config.get("Metric")
if metric_config is not None:
metric_config = metric_config["Eval"]
self.eval_metric_func = build_metrics(metric_config)
metric_config = metric_config.get("Eval")
if metric_config is not None:
self.eval_metric_func = build_metrics(metric_config)
eval_result = self.eval_cls(epoch_id)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册