Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
506f2bfd
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
506f2bfd
编写于
10月 24, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add lm interface
上级
12ea02fc
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
73 addition
and
3 deletion
+73
-3
deepspeech/models/lm/transformer.py
deepspeech/models/lm/transformer.py
+4
-3
deepspeech/models/lm_interface.py
deepspeech/models/lm_interface.py
+69
-0
未找到文件。
deepspeech/models/lm/transformer.py
浏览文件 @
506f2bfd
...
@@ -23,9 +23,10 @@ import paddle.nn.functional as F
...
@@ -23,9 +23,10 @@ import paddle.nn.functional as F
from
deepspeech.modules.mask
import
subsequent_mask
from
deepspeech.modules.mask
import
subsequent_mask
from
deepspeech.modules.encoder
import
TransformerEncoder
from
deepspeech.modules.encoder
import
TransformerEncoder
from
deepspeech.decoders.scorers.scorer_interface
import
BatchScorerInterface
from
deepspeech.decoders.scorers.scorer_interface
import
BatchScorerInterface
from
deepspeech.models.lm_interface
import
#LMInterface
#LMInterface
class
TransformerLM
(
nn
.
Layer
,
BatchScorerInterface
):
class
TransformerLM
(
nn
.
Layer
,
LMInterface
,
BatchScorerInterface
):
def
__init__
(
def
__init__
(
self
,
self
,
n_vocab
:
int
,
n_vocab
:
int
,
...
@@ -90,7 +91,7 @@ class TransformerLM(nn.Layer, BatchScorerInterface):
...
@@ -90,7 +91,7 @@ class TransformerLM(nn.Layer, BatchScorerInterface):
return
ys_mask
.
unsqueeze
(
-
2
)
&
m
return
ys_mask
.
unsqueeze
(
-
2
)
&
m
def
forward
(
def
forward
(
self
,
x
:
paddle
.
Tensor
,
xlens
,
t
:
paddle
.
Tensor
self
,
x
:
paddle
.
Tensor
,
t
:
paddle
.
Tensor
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
,
paddle
.
Tensor
]:
)
->
Tuple
[
paddle
.
Tensor
,
paddle
.
Tensor
,
paddle
.
Tensor
]:
"""Compute LM loss value from buffer sequences.
"""Compute LM loss value from buffer sequences.
...
@@ -110,11 +111,11 @@ class TransformerLM(nn.Layer, BatchScorerInterface):
...
@@ -110,11 +111,11 @@ class TransformerLM(nn.Layer, BatchScorerInterface):
"""
"""
xm
=
x
!=
0
xm
=
x
!=
0
xlen
=
xm
.
sum
(
axis
=
1
)
if
self
.
embed_drop
is
not
None
:
if
self
.
embed_drop
is
not
None
:
emb
=
self
.
embed_drop
(
self
.
embed
(
x
))
emb
=
self
.
embed_drop
(
self
.
embed
(
x
))
else
:
else
:
emb
=
self
.
embed
(
x
)
emb
=
self
.
embed
(
x
)
xlen
=
xm
.
sum
(
axis
=
1
)
h
,
_
=
self
.
encoder
(
emb
,
xlen
)
h
,
_
=
self
.
encoder
(
emb
,
xlen
)
y
=
self
.
decoder
(
h
)
y
=
self
.
decoder
(
h
)
loss
=
F
.
cross_entropy
(
y
.
view
(
-
1
,
y
.
shape
[
-
1
]),
t
.
view
(
-
1
),
reduction
=
"none"
)
loss
=
F
.
cross_entropy
(
y
.
view
(
-
1
,
y
.
shape
[
-
1
]),
t
.
view
(
-
1
),
reduction
=
"none"
)
...
...
deepspeech/models/lm_interface.py
0 → 100644
浏览文件 @
506f2bfd
"""Language model interface."""
import
argparse
from
deepspeech.decoders.scorers.scorer_interface
import
ScorerInterface
from
deepspeech.utils.dynamic_import
import
dynamic_import
class
LMInterface
(
ScorerInterface
):
"""LM Interface for ESPnet model implementation."""
@
staticmethod
def
add_arguments
(
parser
):
"""Add arguments to command line argument parser."""
return
parser
@
classmethod
def
build
(
cls
,
n_vocab
:
int
,
**
kwargs
):
"""Initialize this class with python-level args.
Args:
idim (int): The number of vocabulary.
Returns:
LMinterface: A new instance of LMInterface.
"""
args
=
argparse
.
Namespace
(
**
kwargs
)
return
cls
(
n_vocab
,
args
)
def
forward
(
self
,
x
,
t
):
"""Compute LM loss value from buffer sequences.
Args:
x (torch.Tensor): Input ids. (batch, len)
t (torch.Tensor): Target ids. (batch, len)
Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
loss to backward (scalar),
negative log-likelihood of t: -log p(t) (scalar) and
the number of elements in x (scalar)
Notes:
The last two return values are used
in perplexity: p(t)^{-n} = exp(-log p(t) / n)
"""
raise
NotImplementedError
(
"forward method is not implemented"
)
predefined_lms
=
{
"transformer"
:
"deepspeech.models.lm.transformer:TransformerLM"
,
}
def
dynamic_import_lm
(
module
):
"""Import LM class dynamically.
Args:
module (str): module_name:class_name or alias in `predefined_lms`
Returns:
type: LM class
"""
model_class
=
dynamic_import
(
module
,
predefined_lms
)
assert
issubclass
(
model_class
,
LMInterface
),
f
"
{
module
}
does not implement LMInterface"
return
model_class
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录