未验证 提交 85e2407e 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix mix training for static program (#1234)

上级 d98b8816
......@@ -29,7 +29,7 @@ class CELoss(nn.Layer):
self.epsilon = epsilon
def _labelsmoothing(self, target, class_num):
if target.ndim == 1 or target.shape[-1] != class_num:
if len(target.shape) == 1 or target.shape[-1] != class_num:
one_hot_target = F.one_hot(target, class_num)
else:
one_hot_target = target
......
......@@ -96,7 +96,6 @@ def create_fetchs(out,
"""
fetchs = OrderedDict()
# build loss
# TODO(littletomatodonkey): support mix training
if use_mix:
y_a = paddle.reshape(feeds['y_a'], [-1, 1])
y_b = paddle.reshape(feeds['y_b'], [-1, 1])
......@@ -106,22 +105,14 @@ def create_fetchs(out,
loss_func = build_loss(config["Loss"][mode])
# TODO: support mix training
loss_dict = loss_func(out, target)
if use_mix:
loss_dict = loss_func(out, [y_a, y_b, lam])
else:
loss_dict = loss_func(out, target)
loss_out = loss_dict["loss"]
# if "AMP" in config and config.AMP.get("use_pure_fp16", False):
# loss_out = loss_out.astype("float16")
# if use_mix:
# return loss_func(out, feed_y_a, feed_y_b, feed_lam)
# else:
# return loss_func(out, target)
fetchs['loss'] = (loss_out, AverageMeter('loss', '7.4f', need_avg=True))
assert use_mix is False
# build metric
if not use_mix:
metric_func = build_metrics(config["Metric"][mode])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册