Unverified commit 85e2407e, authored by littletomatodonkey, committed by GitHub

fix mix training for static program (#1234)

Parent d98b8816
...@@ -29,7 +29,7 @@ class CELoss(nn.Layer): ...@@ -29,7 +29,7 @@ class CELoss(nn.Layer):
self.epsilon = epsilon self.epsilon = epsilon
def _labelsmoothing(self, target, class_num): def _labelsmoothing(self, target, class_num):
if target.ndim == 1 or target.shape[-1] != class_num: if len(target.shape) == 1 or target.shape[-1] != class_num:
one_hot_target = F.one_hot(target, class_num) one_hot_target = F.one_hot(target, class_num)
else: else:
one_hot_target = target one_hot_target = target
......
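Note (not part of the diff): this check decides whether `target` holds class indices or is already a one-hot tensor; switching to `len(target.shape)` keeps the rank check working under the static program, where `.ndim` was apparently not usable. Below is a minimal sketch of such a label-smoothing path, assuming Paddle 2.x and the public `F.one_hot` / `F.label_smooth` APIs; the standalone helper name and the example values are illustrative, not the repository code.

```python
# Minimal sketch (illustrative, not the repository code): label smoothing for CE loss.
import paddle
import paddle.nn.functional as F

def labelsmoothing(target, class_num, epsilon=0.1):
    # Rank check via `shape` works for both dynamic and static graph tensors.
    if len(target.shape) == 1 or target.shape[-1] != class_num:
        one_hot_target = F.one_hot(target, class_num)   # sparse indices -> one-hot
    else:
        one_hot_target = target                         # already one-hot
    # (1 - epsilon) on the true class, epsilon / class_num spread over the rest.
    soft_target = F.label_smooth(one_hot_target, epsilon=epsilon)
    return paddle.reshape(soft_target, shape=[-1, class_num])

# Example: 4 samples, 10 classes
labels = paddle.to_tensor([1, 3, 0, 7], dtype="int64")
print(labelsmoothing(labels, class_num=10).shape)  # [4, 10]
```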
...@@ -96,7 +96,6 @@ def create_fetchs(out, ...@@ -96,7 +96,6 @@ def create_fetchs(out,
""" """
fetchs = OrderedDict() fetchs = OrderedDict()
# build loss # build loss
# TODO(littletomatodonkey): support mix training
if use_mix: if use_mix:
y_a = paddle.reshape(feeds['y_a'], [-1, 1]) y_a = paddle.reshape(feeds['y_a'], [-1, 1])
y_b = paddle.reshape(feeds['y_b'], [-1, 1]) y_b = paddle.reshape(feeds['y_b'], [-1, 1])
...@@ -106,22 +105,14 @@ def create_fetchs(out, ...@@ -106,22 +105,14 @@ def create_fetchs(out,
loss_func = build_loss(config["Loss"][mode]) loss_func = build_loss(config["Loss"][mode])
# TODO: support mix training if use_mix:
loss_dict = loss_func(out, target) loss_dict = loss_func(out, [y_a, y_b, lam])
else:
loss_dict = loss_func(out, target)
loss_out = loss_dict["loss"] loss_out = loss_dict["loss"]
# if "AMP" in config and config.AMP.get("use_pure_fp16", False):
# loss_out = loss_out.astype("float16")
# if use_mix:
# return loss_func(out, feed_y_a, feed_y_b, feed_lam)
# else:
# return loss_func(out, target)
fetchs['loss'] = (loss_out, AverageMeter('loss', '7.4f', need_avg=True)) fetchs['loss'] = (loss_out, AverageMeter('loss', '7.4f', need_avg=True))
assert use_mix is False
# build metric # build metric
if not use_mix: if not use_mix:
metric_func = build_metrics(config["Metric"][mode]) metric_func = build_metrics(config["Metric"][mode])
......
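Note (not part of the diff): with `use_mix` enabled, the loss is now built from the two label tensors and the mixing coefficient instead of a single `target`, so the static-program fetch list covers mixup-style training. How the configured loss combines `[y_a, y_b, lam]` depends on `config["Loss"][mode]`; a common formulation is a `lam`-weighted sum of two cross-entropy terms, sketched below (function name and random inputs are illustrative only).

```python
# Minimal sketch (illustrative, not the repository code): mixup-style CE loss.
# `out` are logits; y_a / y_b are the two label tensors of the mixed batch.
import paddle
import paddle.nn.functional as F

def mix_celoss(out, y_a, y_b, lam):
    loss_a = F.cross_entropy(out, y_a)              # CE against the first label set
    loss_b = F.cross_entropy(out, y_b)              # CE against the second label set
    return {"loss": lam * loss_a + (1.0 - lam) * loss_b}

# Example usage with random data: 8 samples, 10 classes
logits = paddle.randn([8, 10])
y_a = paddle.randint(0, 10, shape=[8, 1])
y_b = paddle.randint(0, 10, shape=[8, 1])
print(mix_celoss(logits, y_a, y_b, lam=0.7)["loss"])
```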