Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
DeepSpeech
提交
0f59459a
D
DeepSpeech
项目概览
PaddlePaddle
/
DeepSpeech
大约 2 年 前同步成功
通知
210
Star
8425
Fork
1598
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
245
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeech
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
245
Issue
245
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0f59459a
编写于
10月 15, 2021
作者:
H
Hui Zhang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add LayerDict
上级
666b42d1
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
152 addition
and
1 deletion
+152
-1
deepspeech/__init__.py
deepspeech/__init__.py
+151
-0
deepspeech/utils/utility.py
deepspeech/utils/utility.py
+1
-1
未找到文件。
deepspeech/__init__.py
浏览文件 @
0f59459a
...
@@ -355,6 +355,8 @@ if not hasattr(paddle.Tensor, 'tolist'):
...
@@ -355,6 +355,8 @@ if not hasattr(paddle.Tensor, 'tolist'):
setattr
(
paddle
.
Tensor
,
'tolist'
,
tolist
)
setattr
(
paddle
.
Tensor
,
'tolist'
,
tolist
)
########### hcak paddle.nn.functional #############
# hack loss
# hack loss
def
ctc_loss
(
logits
,
def
ctc_loss
(
logits
,
labels
,
labels
,
...
@@ -381,3 +383,152 @@ logger.debug(
...
@@ -381,3 +383,152 @@ logger.debug(
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
"override ctc_loss of paddle.nn.functional if exists, remove this when fixed!"
)
)
F
.
ctc_loss
=
ctc_loss
F
.
ctc_loss
=
ctc_loss
########### hcak paddle.nn #############
from
paddle.nn
import
Layer
from
typing
import
Optional
from
typing
import
Mapping
from
typing
import
Iterable
from
typing
import
Tuple
from
typing
import
Iterator
from
collections
import
OrderedDict
,
abc
as
container_abcs
class
LayerDict
(
paddle
.
nn
.
Layer
):
r
"""Holds submodules in a dictionary.
:class:`~paddle.nn.LayerDict` can be indexed like a regular Python dictionary,
but modules it contains are properly registered, and will be visible by all
:class:`~paddle.nn.Layer` methods.
:class:`~paddle.nn.LayerDict` is an **ordered** dictionary that respects
* the order of insertion, and
* in :meth:`~paddle.nn.LayerDict.update`, the order of the merged
``OrderedDict``, ``dict`` (started from Python 3.6) or another
:class:`~paddle.nn.LayerDict` (the argument to
:meth:`~paddle.nn.LayerDict.update`).
Note that :meth:`~paddle.nn.LayerDict.update` with other unordered mapping
types (e.g., Python's plain ``dict`` before Python version 3.6) does not
preserve the order of the merged mapping.
Args:
modules (iterable, optional): a mapping (dictionary) of (string: module)
or an iterable of key-value pairs of type (string, module)
Example::
class MyModule(nn.Layer):
def __init__(self):
super(MyModule, self).__init__()
self.choices = nn.LayerDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
})
self.activations = nn.LayerDict([
['lrelu', nn.LeakyReLU()],
['prelu', nn.PReLU()]
])
def forward(self, x, choice, act):
x = self.choices[choice](x)
x = self.activations[act](x)
return x
"""
def
__init__
(
self
,
modules
:
Optional
[
Mapping
[
str
,
Layer
]]
=
None
)
->
None
:
super
(
LayerDict
,
self
).
__init__
()
if
modules
is
not
None
:
self
.
update
(
modules
)
def
__getitem__
(
self
,
key
:
str
)
->
Layer
:
return
self
.
_modules
[
key
]
def
__setitem__
(
self
,
key
:
str
,
module
:
Layer
)
->
None
:
self
.
add_module
(
key
,
module
)
def
__delitem__
(
self
,
key
:
str
)
->
None
:
del
self
.
_modules
[
key
]
def
__len__
(
self
)
->
int
:
return
len
(
self
.
_modules
)
def
__iter__
(
self
)
->
Iterator
[
str
]:
return
iter
(
self
.
_modules
)
def
__contains__
(
self
,
key
:
str
)
->
bool
:
return
key
in
self
.
_modules
def
clear
(
self
)
->
None
:
"""Remove all items from the LayerDict.
"""
self
.
_modules
.
clear
()
def
pop
(
self
,
key
:
str
)
->
Layer
:
r
"""Remove key from the LayerDict and return its module.
Args:
key (string): key to pop from the LayerDict
"""
v
=
self
[
key
]
del
self
[
key
]
return
v
def
keys
(
self
)
->
Iterable
[
str
]:
r
"""Return an iterable of the LayerDict keys.
"""
return
self
.
_modules
.
keys
()
def
items
(
self
)
->
Iterable
[
Tuple
[
str
,
Layer
]]:
r
"""Return an iterable of the LayerDict key/value pairs.
"""
return
self
.
_modules
.
items
()
def
values
(
self
)
->
Iterable
[
Layer
]:
r
"""Return an iterable of the LayerDict values.
"""
return
self
.
_modules
.
values
()
def
update
(
self
,
modules
:
Mapping
[
str
,
Layer
])
->
None
:
r
"""Update the :class:`~paddle.nn.LayerDict` with the key-value pairs from a
mapping or an iterable, overwriting existing keys.
.. note::
If :attr:`modules` is an ``OrderedDict``, a :class:`~paddle.nn.LayerDict`, or
an iterable of key-value pairs, the order of new elements in it is preserved.
Args:
modules (iterable): a mapping (dictionary) from string to :class:`~paddle.nn.Layer`,
or an iterable of key-value pairs of type (string, :class:`~paddle.nn.Layer`)
"""
if
not
isinstance
(
modules
,
container_abcs
.
Iterable
):
raise
TypeError
(
"LayerDict.update should be called with an "
"iterable of key/value pairs, but got "
+
type
(
modules
).
__name__
)
if
isinstance
(
modules
,
(
OrderedDict
,
LayerDict
,
container_abcs
.
Mapping
)):
for
key
,
module
in
modules
.
items
():
self
[
key
]
=
module
else
:
# modules here can be a list with two items
for
j
,
m
in
enumerate
(
modules
):
if
not
isinstance
(
m
,
container_abcs
.
Iterable
):
raise
TypeError
(
"LayerDict update sequence element "
"#"
+
str
(
j
)
+
" should be Iterable; is"
+
type
(
m
).
__name__
)
if
not
len
(
m
)
==
2
:
raise
ValueError
(
"LayerDict update sequence element "
"#"
+
str
(
j
)
+
" has length "
+
str
(
len
(
m
))
+
"; 2 is required"
)
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
# that's too cumbersome to type correctly with overloads, so we add an ignore here
self
[
m
[
0
]]
=
m
[
1
]
# type: ignore[assignment]
# remove forward alltogether to fallback on Module's _forward_unimplemented
if
not
hasattr
(
paddle
.
nn
,
'LayerDict'
):
logger
.
debug
(
"register user LayerDict to paddle.nn, remove this when fixed!"
)
setattr
(
paddle
.
nn
,
'LayerDict'
,
LayerDict
)
deepspeech/utils/utility.py
浏览文件 @
0f59459a
...
@@ -42,7 +42,7 @@ def all_version():
...
@@ -42,7 +42,7 @@ def all_version():
"paddle_commit"
:
paddle
.
version
.
commit
,
"paddle_commit"
:
paddle
.
version
.
commit
,
"soundfile"
:
soundfile
.
__version__
,
"soundfile"
:
soundfile
.
__version__
,
}
}
logger
.
info
(
f
"Deps Module Version:
{
pformat
(
vers
.
items
(
))
}
"
)
logger
.
info
(
f
"Deps Module Version:
{
pformat
(
list
(
vers
.
items
()
))
}
"
)
@
contextmanager
@
contextmanager
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录