提交 9e4a1045 编写于 作者: W weishengyu

dbg

上级 81c864b9
...@@ -96,15 +96,14 @@ class Trainer(object): ...@@ -96,15 +96,14 @@ class Trainer(object):
def train(self): def train(self):
# build train loss and metric info # build train loss and metric info
if self.train_loss_func is None: if self.train_loss_func is None:
loss_info = self.config.get("Loss", None) loss_info = self.config["Loss"]["Train"]
if loss_info is not None:
loss_info = loss_info["Train"]
self.train_loss_func = build_loss(loss_info) self.train_loss_func = build_loss(loss_info)
if self.train_metric_func is None: if self.train_metric_func is None:
metric_config = self.config.get("Metric", None) metric_config = self.config.get("Metric")
if metric_config is not None: if metric_config is not None:
metric_config = metric_config["Train"] metric_config = metric_config.get("Train")
self.train_metric_func = build_metrics(metric_config) if metric_config is not None:
self.train_metric_func = build_metrics(metric_config)
if self.train_dataloader is None: if self.train_dataloader is None:
self.train_dataloader = build_dataloader(self.config["DataLoader"], self.train_dataloader = build_dataloader(self.config["DataLoader"],
...@@ -223,20 +222,22 @@ class Trainer(object): ...@@ -223,20 +222,22 @@ class Trainer(object):
def eval(self, epoch_id=0): def eval(self, epoch_id=0):
self.model.eval() self.model.eval()
if self.eval_loss_func is None: if self.eval_loss_func is None:
loss_info = self.config.get("Loss", None) loss_config = self.config.get("Loss", None)
if loss_info is not None: if loss_config is not None:
loss_info = loss_info["Eval"] loss_config = loss_config.get("Eval")
self.eval_loss_func = build_loss(loss_info) if loss_config is not None:
self.eval_loss_func = build_loss(loss_config)
if self.eval_mode == "classification": if self.eval_mode == "classification":
if self.eval_dataloader is None: if self.eval_dataloader is None:
self.eval_dataloader = build_dataloader( self.eval_dataloader = build_dataloader(
self.config["DataLoader"], "Eval", self.device) self.config["DataLoader"], "Eval", self.device)
if self.eval_metric_func is None: if self.eval_metric_func is None:
metric_config = self.config.get("Metric", None) metric_config = self.config.get("Metric")
if metric_config is not None: if metric_config is not None:
metric_config = metric_config["Eval"] metric_config = metric_config.get("Eval")
self.eval_metric_func = build_metrics(metric_config) if metric_config is not None:
self.eval_metric_func = build_metrics(metric_config)
eval_result = self.eval_cls(epoch_id) eval_result = self.eval_cls(epoch_id)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册