未验证 提交 1a0c0ae0 编写于 作者: J Jackwaterveg 提交者: GitHub

Merge pull request #892 from PaddlePaddle/ctc_grad_norm

[WIP] ctc loss api debug
...@@ -362,11 +362,19 @@ def ctc_loss(logits, ...@@ -362,11 +362,19 @@ def ctc_loss(logits,
label_lengths, label_lengths,
blank=0, blank=0,
reduction='mean', reduction='mean',
norm_by_times=True): norm_by_times=True,
norm_by_batchsize=False,
norm_by_total_logits_len=False):
#logger.info("my ctc loss with norm by times") #logger.info("my ctc loss with norm by times")
## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403 ## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
loss_out = paddle.fluid.layers.warpctc(logits, labels, blank, norm_by_times, loss_out = paddle.fluid.layers.warpctc(
input_lengths, label_lengths) logits,
labels,
blank,
norm_by_times,
input_lengths,
label_lengths,
norm_by_batchsize, )
loss_out = paddle.fluid.layers.squeeze(loss_out, [-1]) loss_out = paddle.fluid.layers.squeeze(loss_out, [-1])
assert reduction in ['mean', 'sum', 'none'] assert reduction in ['mean', 'sum', 'none']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册