提交 2a033328 编写于 作者: W weishengyu

dbg

上级 1c7a23c5
......@@ -96,7 +96,12 @@ class Trainer(object):
def train(self):
# build train loss and metric info
if self.train_loss_func is None:
self.train_loss_func = build_loss(self.config["Loss"])
loss_info = self.config.get("Loss", None)
if loss_info is None:
loss_info = [{"CELoss": {"weight": 1.0}}]
else:
loss_info = loss_info["Train"]
self.train_loss_func = build_loss(loss_info)
if self.train_metric_func is None:
metric_config = self.config.get("Metric", None)
if metric_config is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册