Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
6e962618
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6e962618
编写于
5月 12, 2020
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add unit test for StackedRNNCell.
上级
f75b39e8
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
37 addition
and
0 deletion
+37
-0
hapi/tests/test_text.py
hapi/tests/test_text.py
+34
-0
hapi/text/text.py
hapi/text/text.py
+3
-0
未找到文件。
hapi/tests/test_text.py
浏览文件 @
6e962618
...
@@ -567,6 +567,40 @@ class TestSequenceTaggingInfer(TestSequenceTagging):
...
@@ -567,6 +567,40 @@ class TestSequenceTaggingInfer(TestSequenceTagging):
return
inputs
return
inputs
class
TestStackedRNN
(
ModuleApiTest
):
def
setUp
(
self
):
shape
=
(
2
,
4
,
16
)
self
.
inputs
=
[
np
.
random
.
random
(
shape
).
astype
(
"float32"
)]
self
.
outputs
=
None
self
.
attrs
=
{
"input_size"
:
16
,
"hidden_size"
:
16
,
"num_layers"
:
2
}
self
.
param_states
=
{}
@
staticmethod
def
model_init
(
self
,
input_size
,
hidden_size
,
num_layers
):
cells
=
[
BasicLSTMCell
(
input_size
,
hidden_size
),
BasicLSTMCell
(
hidden_size
,
hidden_size
)
]
stacked_cell
=
StackedRNNCell
(
cells
)
self
.
lstm
=
RNN
(
stacked_cell
)
@
staticmethod
def
model_forward
(
self
,
inputs
):
return
self
.
lstm
(
inputs
)[
0
]
def
make_inputs
(
self
):
inputs
=
[
Input
(
[
None
,
None
,
self
.
inputs
[
-
1
].
shape
[
-
1
]],
"float32"
,
name
=
"input"
),
]
return
inputs
def
test_check_output
(
self
):
self
.
check_output
()
class
TestLSTM
(
ModuleApiTest
):
class
TestLSTM
(
ModuleApiTest
):
def
setUp
(
self
):
def
setUp
(
self
):
shape
=
(
2
,
4
,
16
)
shape
=
(
2
,
4
,
16
)
...
...
hapi/text/text.py
浏览文件 @
6e962618
...
@@ -49,6 +49,8 @@ __all__ = [
...
@@ -49,6 +49,8 @@ __all__ = [
'BasicLSTMCell'
,
'BasicLSTMCell'
,
'BasicGRUCell'
,
'BasicGRUCell'
,
'RNN'
,
'RNN'
,
'BidirectionalRNN'
,
'StackedRNNCell'
,
'StackedLSTMCell'
,
'StackedLSTMCell'
,
'LSTM'
,
'LSTM'
,
'BidirectionalLSTM'
,
'BidirectionalLSTM'
,
...
@@ -1025,6 +1027,7 @@ class StackedRNNCell(RNNCell):
...
@@ -1025,6 +1027,7 @@ class StackedRNNCell(RNNCell):
"""
"""
def
__init__
(
self
,
cells
):
def
__init__
(
self
,
cells
):
super
(
StackedRNNCell
,
self
).
__init__
()
self
.
cells
=
[]
self
.
cells
=
[]
for
i
,
cell
in
enumerate
(
cells
):
for
i
,
cell
in
enumerate
(
cells
):
self
.
cells
.
append
(
self
.
add_sublayer
(
"cell_%d"
%
i
,
cell
))
self
.
cells
.
append
(
self
.
add_sublayer
(
"cell_%d"
%
i
,
cell
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录