Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
26910132
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 1 年 前同步成功
通知
207
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
26910132
编写于
10月 08, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ctcloss can work w/ paddle2.1.2, but loss larger than before
上级
86a221d5
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
33 addition
and
4 deletion
+33
-4
deepspeech/__init__.py
deepspeech/__init__.py
+28
-0
deepspeech/modules/loss.py
deepspeech/modules/loss.py
+5
-4
未找到文件。
deepspeech/__init__.py
浏览文件 @
26910132
...
...
@@ -353,3 +353,31 @@ if not hasattr(paddle.Tensor, 'tolist'):
logger
.
debug
(
"register user tolist to paddle.Tensor, remove this when fixed!"
)
setattr
(
paddle
.
Tensor
,
'tolist'
,
tolist
)
# hack loss
def
ctc_loss
(
logits
,
labels
,
input_lengths
,
label_lengths
,
blank
=
0
,
reduction
=
'mean'
,
norm_by_times
=
True
):
#logger.info("my ctc loss with norm by times")
## https://github.com/PaddlePaddle/Paddle/blob/f5ca2db2cc/paddle/fluid/operators/warpctc_op.h#L403
loss_out
=
paddle
.
fluid
.
layers
.
warpctc
(
logits
,
labels
,
blank
,
norm_by_times
,
input_lengths
,
label_lengths
)
loss_out
=
paddle
.
fluid
.
layers
.
squeeze
(
loss_out
,
[
-
1
])
assert
reduction
in
[
'mean'
,
'sum'
,
'none'
]
if
reduction
==
'mean'
:
loss_out
=
paddle
.
mean
(
loss_out
/
label_lengths
)
elif
reduction
==
'sum'
:
loss_out
=
paddle
.
sum
(
loss_out
)
return
loss_out
logger
.
debug
(
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
)
F
.
ctc_loss
=
ctc_loss
deepspeech/modules/loss.py
浏览文件 @
26910132
...
...
@@ -67,10 +67,10 @@ class CTCLoss(nn.Layer):
except
ValueError
:
# Some function, e.g. built-in function, are failed
param
=
{}
_kwargs
=
{
k
:
v
for
k
,
v
in
self
.
kwargs
.
items
()
if
k
in
param
}
self
.
_kwargs
=
{
k
:
v
for
k
,
v
in
self
.
kwargs
.
items
()
if
k
in
param
}
_notin
=
{
k
:
v
for
k
,
v
in
self
.
kwargs
.
items
()
if
k
not
in
param
}
logger
.
info
(
f
"
{
self
.
loss
}
kwargs:
{
_kwargs
}
, not support:
{
_notin
}
"
)
self
.
loss_fn
=
partial
(
self
.
loss
.
forward
,
**
_kwargs
)
logger
.
info
(
f
"
{
self
.
loss
}
kwargs:
{
self
.
_kwargs
}
, not support:
{
_notin
}
"
)
#
self.loss_fn = partial(self.loss.forward, **_kwargs)
def
forward
(
self
,
logits
,
ys_pad
,
hlens
,
ys_lens
):
"""Compute CTC loss.
...
...
@@ -90,7 +90,8 @@ class CTCLoss(nn.Layer):
# logits: (B, L, D) -> (L, B, D)
logits
=
logits
.
transpose
([
1
,
0
,
2
])
ys_pad
=
ys_pad
.
astype
(
paddle
.
int32
)
loss
=
self
.
loss_fn
(
logits
,
ys_pad
,
hlens
,
ys_lens
)
#loss = self.loss_fn(logits, ys_pad, hlens, ys_lens)
loss
=
self
.
loss
(
logits
,
ys_pad
,
hlens
,
ys_lens
)
if
self
.
batch_average
:
# Batch-size average
loss
=
loss
/
B
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录