PaddlePaddle / hapi · commit 85e422bb
Authored Apr 20, 2020 by xyzhou-puck
Parent: ed14907e

Commit message: refine text.py

Showing 1 changed file with 637 additions and 115 deletions:
hapi/text/text.py (+637, -115)
@@ -44,13 +44,12 @@ from paddle.fluid import layers
 from paddle.fluid.dygraph import Layer
 from paddle.fluid.layers import BeamSearchDecoder
 
 __all__ = [
     'RNNCell', 'BasicLSTMCell', 'BasicGRUCell', 'RNN', 'DynamicDecode',
     'BeamSearchDecoder', 'MultiHeadAttention', 'FFN',
     'TransformerEncoderLayer', 'TransformerEncoder', 'TransformerDecoderLayer',
-    'TransformerDecoder', 'TransformerBeamSearchDecoder', 'GRUCell',
-    'GRUEncoderCell', 'BiGRU',
+    'TransformerDecoder', 'TransformerBeamSearchDecoder',
+    'BiGRU',
     'Linear_chain_crf', 'Crf_decoding', 'SequenceTagging'
 ]
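The first hunk only trims the public export list: `GRUCell` and `GRUEncoderCell` are still defined later in the module (see the TODO hunks below) but are no longer advertised in `__all__`. A quick sanity check of what stays importable, assuming the repository's package layout where this module lives at `hapi/text/text.py`:

```python
# Minimal sketch (import path assumed from the repo layout): the symbols kept in
# __all__ after this commit can still be imported directly from the module.
from hapi.text.text import (BasicLSTMCell, BasicGRUCell, MultiHeadAttention,
                            FFN, TransformerEncoder, TransformerDecoder)
# GRUCell / GRUEncoderCell remain defined in hapi/text/text.py, just unexported.
```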
@@ -219,7 +218,19 @@ class BasicLSTMCell(RNNCell):
                  gate_activation=None,
                  activation=None,
                  forget_bias=1.0,
-                 dtype='float32'):
+                 dtype='float32',
+                 forget_gate_weights={"w": None, "h": None, "b": None},
+                 input_gate_weights={"w": None, "h": None, "b": None},
+                 output_gate_weights={"w": None, "h": None, "b": None},
+                 cell_weights={"w": None, "h": None, "b": None}):
         super(BasicLSTMCell, self).__init__()
 
         self._hidden_size = hidden_size
@@ -233,6 +244,20 @@ class BasicLSTMCell(RNNCell):
         self._dtype = dtype
         self._input_size = input_size
 
-        self._weight = self.create_parameter(
-            attr=self._param_attr,
-            shape=[
+        self.use_customized_weight = False
+        for _weights in [
+                forget_gate_weights, input_gate_weights, output_gate_weights,
+                cell_weights
+        ]:
+            for _key in _weights:
+                if _weights[_key] is not None:
+                    self.use_customized_weight = True
+                    break
+            if self.use_customized_weight:
+                break
+
+        if not self.use_customized_weight:
+            self._weight = self.create_parameter(
+                attr=self._param_attr,
+                shape=[
@@ -245,13 +270,199 @@ class BasicLSTMCell(RNNCell):
-            shape=[4 * self._hidden_size],
-            dtype=self._dtype,
-            is_bias=True)
+                shape=[4 * self._hidden_size],
+                dtype=self._dtype,
+                is_bias=True)
+        else:
+            if "w" in forget_gate_weights and forget_gate_weights["w"] is not None:
+                self.fg_w = forget_gate_weights["w"]
+            else:
+                if self._param_attr is not None and self._param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(self._param_attr)
+                    tmp_param_attr.name += "_forget_gate_w"
+                else:
+                    tmp_param_attr = self._param_attr
+                self.fg_w = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._input_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "h" in forget_gate_weights and forget_gate_weights["h"] is not None:
+                self.fg_h = forget_gate_weights["h"]
+            else:
+                if self._param_attr is not None and self._param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(self._param_attr)
+                    tmp_param_attr.name += "_forget_gate_h"
+                else:
+                    tmp_param_attr = self._param_attr
+                self.fg_h = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "b" in forget_gate_weights and forget_gate_weights["b"] is not None:
+                self.fg_b = forget_gate_weights["b"]
+            else:
+                if self._bias_attr is not None and self._bias_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(self._bias_attr)
+                    tmp_param_attr.name += "_forget_gate_b"
+                else:
+                    tmp_param_attr = self._bias_attr
+                self.fg_b = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size],
+                    dtype=self._dtype,
+                    is_bias=True)
+
+            if "w" in input_gate_weights and input_gate_weights["w"] is not None:
+                self.ig_w = input_gate_weights["w"]
+            else:
+                if self._param_attr is not None and self._param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(self._param_attr)
+                    tmp_param_attr.name += "_input_gate_w"
+                else:
+                    tmp_param_attr = self._param_attr
+                self.ig_w = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._input_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "h" in input_gate_weights and input_gate_weights["h"] is not None:
+                self.ig_h = input_gate_weights["h"]
+            else:
+                if self._param_attr is not None and self._param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(self._param_attr)
+                    tmp_param_attr.name += "_input_gate_h"
+                else:
+                    tmp_param_attr = self._param_attr
+                self.ig_h = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "b" in input_gate_weights and input_gate_weights["b"] is not None:
+                self.ig_b = input_gate_weights["b"]
+            else:
+                if self._bias_attr is not None and self._bias_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(self._bias_attr)
+                    tmp_param_attr.name += "_input_gate_b"
+                else:
+                    tmp_param_attr = self._bias_attr
+                self.ig_b = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size],
+                    dtype=self._dtype,
+                    is_bias=True)
+
+            if "w" in output_gate_weights and output_gate_weights["w"] is not None:
+                self.og_w = output_gate_weights["w"]
+            else:
+                if self._param_attr is not None and self._param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(self._param_attr)
+                    tmp_param_attr.name += "_output_gate_w"
+                else:
+                    tmp_param_attr = self._param_attr
+                self.og_w = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._input_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "h" in output_gate_weights and output_gate_weights["h"] is not None:
+                self.og_h = output_gate_weights["h"]
+            else:
+                if self._param_attr is not None and self._param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(self._param_attr)
+                    tmp_param_attr.name += "_output_gate_h"
+                else:
+                    tmp_param_attr = self._param_attr
+                self.og_h = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "b" in output_gate_weights and output_gate_weights["b"] is not None:
+                self.og_b = output_gate_weights["b"]
+            else:
+                if self._bias_attr is not None and self._bias_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(self._bias_attr)
+                    tmp_param_attr.name += "_output_gate_b"
+                else:
+                    tmp_param_attr = self._bias_attr
+                self.og_b = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size],
+                    dtype=self._dtype,
+                    is_bias=True)
+
+            if "w" in cell_weights and cell_weights["w"] is not None:
+                self.c_w = cell_weights["w"]
+            else:
+                if self._param_attr is not None and self._param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(self._param_attr)
+                    tmp_param_attr.name += "_cell_w"
+                else:
+                    tmp_param_attr = self._param_attr
+                self.c_w = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._input_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "h" in cell_weights and cell_weights["h"] is not None:
+                self.c_h = cell_weights["h"]
+            else:
+                if self._param_attr is not None and self._param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(self._param_attr)
+                    tmp_param_attr.name += "_cell_h"
+                else:
+                    tmp_param_attr = self._param_attr
+                self.c_h = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "b" in cell_weights and cell_weights["b"] is not None:
+                self.c_b = cell_weights["b"]
+            else:
+                if self._bias_attr is not None and self._bias_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(self._bias_attr)
+                    tmp_param_attr.name += "_cell_b"
+                else:
+                    tmp_param_attr = self._bias_attr
+                self.c_b = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size],
+                    dtype=self._dtype,
+                    is_bias=True)
+
     def forward(self, input, state):
+        if self.use_customized_weight:
+            weight_w = fluid.layers.concat(
+                [self.ig_w, self.c_w, self.fg_w, self.og_w], axis=-1)
+            weight_h = fluid.layers.concat(
+                [self.ig_h, self.c_h, self.fg_h, self.og_h], axis=-1)
+            _weight = fluid.layers.concat([weight_w, weight_h], axis=0)
+            _bias = fluid.layers.concat(
+                [self.ig_b, self.c_b, self.fg_b, self.og_b])
+        else:
+            _weight = self._weight
+            _bias = self._bias
+
         pre_hidden, pre_cell = state
         concat_input_hidden = layers.concat([input, pre_hidden], 1)
-        gate_input = layers.matmul(x=concat_input_hidden, y=self._weight)
-        gate_input = layers.elementwise_add(gate_input, self._bias)
+        gate_input = layers.matmul(x=concat_input_hidden, y=_weight)
+        gate_input = layers.elementwise_add(gate_input, _bias)
         i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
         new_cell = layers.elementwise_add(
             layers.elementwise_mul(
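The effect of these hunks is that a caller can now inject externally created parameters into individual LSTM gates (for weight sharing or for loading weights converted from elsewhere) instead of always letting the cell build its single fused `[input+hidden, 4*hidden]` matrix; in `forward`, the per-gate pieces are concatenated back into that fused layout. A minimal sketch of two cells sharing one forget gate is below. The import path, the `SharedForgetGate` helper, the concrete sizes, and the assumption that `forward` returns the usual RNNCell `(output, new_state)` pair are all illustrative, not taken from this diff.

```python
import numpy as np
import paddle.fluid as fluid
from hapi.text.text import BasicLSTMCell  # import path assumed from the repo layout


class SharedForgetGate(fluid.dygraph.Layer):
    """Illustrative owner of one set of forget-gate parameters (not from the diff)."""

    def __init__(self, input_size, hidden_size):
        super(SharedForgetGate, self).__init__()
        # Shapes follow the diff: w is [input, hidden], h is [hidden, hidden], b is [hidden].
        self.w = self.create_parameter(
            attr=fluid.ParamAttr(name="shared_fg_w"),
            shape=[input_size, hidden_size], dtype='float32')
        self.h = self.create_parameter(
            attr=fluid.ParamAttr(name="shared_fg_h"),
            shape=[hidden_size, hidden_size], dtype='float32')
        self.b = self.create_parameter(
            attr=fluid.ParamAttr(name="shared_fg_b"),
            shape=[hidden_size], dtype='float32', is_bias=True)


with fluid.dygraph.guard():
    shared = SharedForgetGate(input_size=16, hidden_size=32)
    gate = {"w": shared.w, "h": shared.h, "b": shared.b}
    # Both cells now reference (and jointly train) the same forget-gate parameters;
    # the other gates fall back to internally created parameters as before.
    cell_a = BasicLSTMCell(input_size=16, hidden_size=32, forget_gate_weights=gate)
    cell_b = BasicLSTMCell(input_size=16, hidden_size=32, forget_gate_weights=gate)

    x = fluid.dygraph.to_variable(np.random.rand(4, 16).astype('float32'))
    h0 = fluid.dygraph.to_variable(np.zeros((4, 32), dtype='float32'))
    c0 = fluid.dygraph.to_variable(np.zeros((4, 32), dtype='float32'))
    step = cell_a(x, [h0, c0])  # one step; output/state structure per the RNNCell convention
```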
@@ -308,7 +519,16 @@ class BasicGRUCell(RNNCell):
                  bias_attr=None,
                  gate_activation=None,
                  activation=None,
-                 dtype='float32'):
+                 dtype='float32',
+                 update_gate_weights={"w": None, "h": None, "b": None},
+                 reset_gate_weights={"w": None, "h": None, "b": None},
+                 cell_weights={"w": None, "h": None, "b": None}):
         super(BasicGRUCell, self).__init__()
 
         self._input_size = input_size
         self._hidden_size = hidden_size
@@ -318,6 +538,20 @@ class BasicGRUCell(RNNCell):
         self._activation = activation or layers.tanh
         self._dtype = dtype
 
+        assert isinstance(update_gate_weights, dict)
+        assert isinstance(reset_gate_weights, dict)
+        assert isinstance(cell_weights, dict)
+
+        self.use_customized_weight = False
+        for _weights in [
+                update_gate_weights, reset_gate_weights, cell_weights
+        ]:
+            for _key in _weights:
+                if _weights[_key] is not None:
+                    self.use_customized_weight = True
+            if self.use_customized_weight:
+                break
+
         if self._param_attr is not None and self._param_attr.name is not None:
             gate_param_attr = copy.deepcopy(self._param_attr)
             candidate_param_attr = copy.deepcopy(self._param_attr)
@@ -327,14 +561,19 @@ class BasicGRUCell(RNNCell):
             gate_param_attr = self._param_attr
             candidate_param_attr = self._param_attr
 
-        self._gate_weight = self.create_parameter(
-            attr=gate_param_attr,
-            shape=[self._input_size + self._hidden_size, 2 * self._hidden_size],
-            dtype=self._dtype)
-
-        self._candidate_weight = self.create_parameter(
-            attr=candidate_param_attr,
-            shape=[self._input_size + self._hidden_size, self._hidden_size],
-            dtype=self._dtype)
+        if not self.use_customized_weight:
+            self._gate_weight = self.create_parameter(
+                attr=gate_param_attr,
+                shape=[
+                    self._input_size + self._hidden_size, 2 * self._hidden_size
+                ],
+                dtype=self._dtype)
+
+            self._candidate_weight = self.create_parameter(
+                attr=candidate_param_attr,
+                shape=[self._input_size + self._hidden_size, self._hidden_size],
+                dtype=self._dtype)
 
         if self._bias_attr is not None and self._bias_attr.name is not None:
@@ -357,13 +596,159 @@ class BasicGRUCell(RNNCell):
-            dtype=self._dtype,
-            is_bias=True)
+                dtype=self._dtype,
+                is_bias=True)
+        else:
+            # create the parameters of gates in gru
+            if "w" in update_gate_weights and update_gate_weights["w"] is not None:
+                self.ug_w = update_gate_weights["w"]
+            else:
+                if gate_param_attr is not None and gate_param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(gate_param_attr)
+                    tmp_param_attr.name += "_update_gate_w"
+                else:
+                    tmp_param_attr = gate_param_attr
+                self.ug_w = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._input_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "h" in update_gate_weights and update_gate_weights["h"] is not None:
+                self.ug_h = update_gate_weights["h"]
+            else:
+                if gate_param_attr is not None and gate_param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(gate_param_attr)
+                    tmp_param_attr.name += "_update_gate_h"
+                else:
+                    tmp_param_attr = gate_param_attr
+                self.ug_h = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "b" in update_gate_weights and update_gate_weights["b"] is not None:
+                self.ug_b = update_gate_weights["b"]
+            else:
+                if gate_bias_attr is not None and gate_bias_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(gate_bias_attr)
+                    tmp_param_attr.name += "_update_gate_b"
+                else:
+                    tmp_param_attr = gate_bias_attr
+                self.ug_b = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size],
+                    dtype=self._dtype,
+                    is_bias=True)
+
+            # reset gate parameters
+            if "w" in reset_gate_weights and reset_gate_weights["w"] is not None:
+                self.rg_w = reset_gate_weights["w"]
+            else:
+                if gate_param_attr is not None and gate_param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(gate_param_attr)
+                    tmp_param_attr.name += "_reset_gate_w"
+                else:
+                    tmp_param_attr = gate_param_attr
+                self.rg_w = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._input_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "h" in reset_gate_weights and reset_gate_weights["h"] is not None:
+                self.rg_h = reset_gate_weights["h"]
+            else:
+                if gate_param_attr is not None and gate_param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(gate_param_attr)
+                    tmp_param_attr.name += "_reset_gate_h"
+                else:
+                    tmp_param_attr = gate_param_attr
+                self.rg_h = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "b" in reset_gate_weights and reset_gate_weights["b"] is not None:
+                self.rg_b = reused_params["b"]
+            else:
+                if gate_bias_attr is not None and gate_bias_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(gate_bias_attr)
+                    tmp_param_attr.name += "_reset_gate_b"
+                else:
+                    tmp_param_attr = gate_bias_attr
+                self.rg_b = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size],
+                    dtype=self._dtype,
+                    is_bias=True)
+
+            # cell parameters
+            if "w" in cell_weights and cell_weights["w"] is not None:
+                self.c_w = cell_weights["w"]
+            else:
+                if candidate_param_attr is not None and candidate_param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(candidate_param_attr)
+                    tmp_param_attr.name += "_cell_w"
+                else:
+                    tmp_param_attr = gate_param_attr
+                self.c_w = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._input_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "h" in cell_weights and cell_weights["h"] is not None:
+                self.c_h = cell_weights["h"]
+            else:
+                if candidate_param_attr is not None and candidate_param_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(candidate_param_attr)
+                    tmp_param_attr.name += "_cell_h"
+                else:
+                    tmp_param_attr = gate_param_attr
+                self.c_h = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size, self._hidden_size],
+                    dtype=self._dtype)
+
+            if "b" in cell_weights and cell_weights["b"] is not None:
+                self.c_b = cell_weights["b"]
+            else:
+                if candidate_bias_attr is not None and candidate_bias_attr.name is not None:
+                    tmp_param_attr = copy.deepcopy(candidate_bias_attr)
+                    tmp_param_attr.name += "_cell_b"
+                else:
+                    tmp_param_attr = gate_bias_attr
+                self.c_b = self.create_parameter(
+                    attr=tmp_param_attr,
+                    shape=[self._hidden_size],
+                    dtype=self._dtype,
+                    is_bias=True)
+
     def forward(self, input, state):
+        if self.use_customized_weight:
+            rg_weights = layers.concat([self.rg_w, self.rg_h], axis=0)
+            ug_weights = layers.concat([self.ug_w, self.ug_h], axis=0)
+            _gate_weight = layers.concat([rg_weights, ug_weights], axis=-1)
+            _candidate_weight = layers.concat([self.c_w, self.c_h], axis=0)
+            _gate_bias = layers.concat([self.rg_b, self.ug_b], axis=0)
+            _candidate_bias = self.c_b
+        else:
+            _gate_weight = self._gate_weight
+            _gate_bias = self._gate_bias
+            _candidate_weight = self._candidate_weight
+            _candidate_bias = self._candidate_bias
+
         pre_hidden = state
         concat_input_hidden = layers.concat([input, pre_hidden], axis=1)
-        gate_input = layers.matmul(x=concat_input_hidden, y=self._gate_weight)
-        gate_input = layers.elementwise_add(gate_input, self._gate_bias)
+        gate_input = layers.matmul(x=concat_input_hidden, y=_gate_weight)
+        gate_input = layers.elementwise_add(gate_input, _gate_bias)
         gate_input = self._gate_activation(gate_input)
         r, u = layers.split(gate_input, num_or_sections=2, dim=1)
@@ -371,8 +756,8 @@ class BasicGRUCell(RNNCell):
         r_hidden = r * pre_hidden
 
         candidate = layers.matmul(
-            layers.concat([input, r_hidden], 1), self._candidate_weight)
-        candidate = layers.elementwise_add(candidate, self._candidate_bias)
+            layers.concat([input, r_hidden], 1), _candidate_weight)
+        candidate = layers.elementwise_add(candidate, _candidate_bias)
 
         c = self._activation(candidate)
         new_hidden = u * pre_hidden + (1 - u) * c
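BasicGRUCell gets the same treatment as the LSTM cell: per-gate `{"w", "h", "b"}` dictionaries that, when set, replace the internally created fused gate and candidate matrices; in `forward` they are concatenated back into the `[input+hidden, 2*hidden]` gate weight and `[input+hidden, hidden]` candidate weight. A short construction sketch, mirroring the LSTM example above; the import path and sizes are assumptions.

```python
import paddle.fluid as fluid
from hapi.text.text import BasicGRUCell  # import path assumed


class UpdateGateParams(fluid.dygraph.Layer):
    """Illustrative owner of one set of update-gate parameters (not from the diff)."""

    def __init__(self, input_size, hidden_size):
        super(UpdateGateParams, self).__init__()
        self.w = self.create_parameter(attr=fluid.ParamAttr(name="shared_ug_w"),
                                       shape=[input_size, hidden_size], dtype='float32')
        self.h = self.create_parameter(attr=fluid.ParamAttr(name="shared_ug_h"),
                                       shape=[hidden_size, hidden_size], dtype='float32')
        self.b = self.create_parameter(attr=fluid.ParamAttr(name="shared_ug_b"),
                                       shape=[hidden_size], dtype='float32', is_bias=True)


with fluid.dygraph.guard():
    shared = UpdateGateParams(16, 32)
    cell = BasicGRUCell(
        input_size=16,
        hidden_size=32,
        update_gate_weights={"w": shared.w, "h": shared.h, "b": shared.b})
    # reset_gate_weights and cell_weights keep their all-None defaults, so those
    # parameters are still created internally exactly as before this commit.
```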
@@ -700,7 +1085,11 @@ class PrePostProcessLayer(Layer):
     PrePostProcessLayer
     """
 
-    def __init__(self, process_cmd, d_model, dropout_rate):
+    def __init__(self,
+                 process_cmd,
+                 d_model,
+                 dropout_rate,
+                 reused_layer_norm=None):
         super(PrePostProcessLayer, self).__init__()
         self.process_cmd = process_cmd
         self.functors = []
@@ -708,16 +1097,21 @@ class PrePostProcessLayer(Layer):
             if cmd == "a":  # add residual connection
                 self.functors.append(lambda x, y: x + y if y else x)
             elif cmd == "n":  # add layer normalization
-                self.functors.append(
-                    self.add_sublayer(
-                        "layer_norm_%d" % len(
-                            self.sublayers(include_sublayers=False)),
-                        LayerNorm(
-                            normalized_shape=d_model,
-                            param_attr=fluid.ParamAttr(
-                                initializer=fluid.initializer.Constant(1.)),
-                            bias_attr=fluid.ParamAttr(
-                                initializer=fluid.initializer.Constant(0.)))))
+                if reused_layer_norm is not None:
+                    layer_norm = reused_layer_norm
+                else:
+                    layer_norm = LayerNorm(
+                        normalized_shape=d_model,
+                        param_attr=fluid.ParamAttr(
+                            initializer=fluid.initializer.Constant(1.)),
+                        bias_attr=fluid.ParamAttr(
+                            initializer=fluid.initializer.Constant(0.)))
+
+                self.functors.append(
+                    self.add_sublayer(
+                        "layer_norm_%d" % len(
+                            self.sublayers(include_sublayers=False)),
+                        layer_norm))
             elif cmd == "d":  # add dropout
                 self.functors.append(lambda x: layers.dropout(
                     x, dropout_prob=dropout_rate, is_test=False)
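With `reused_layer_norm`, several pre/post processors can normalize with one shared LayerNorm instead of each creating its own. A minimal sketch (import path and numbers assumed):

```python
import paddle.fluid as fluid
from paddle.fluid.dygraph import LayerNorm
from hapi.text.text import PrePostProcessLayer  # import path assumed

with fluid.dygraph.guard():
    d_model = 512
    shared_ln = LayerNorm(normalized_shape=d_model)
    # Both "n" processors register and use the same LayerNorm parameters.
    pre_a = PrePostProcessLayer("n", d_model, 0.1, reused_layer_norm=shared_ln)
    pre_b = PrePostProcessLayer("n", d_model, 0.1, reused_layer_norm=shared_ln)
```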
@@ -737,21 +1131,48 @@ class MultiHeadAttention(Layer):
     Multi-Head Attention
     """
 
-    def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
+    def __init__(self,
+                 d_key,
+                 d_value,
+                 d_model,
+                 n_head=1,
+                 dropout_rate=0.0,
+                 reused_query_fc=None,
+                 reused_key_fc=None,
+                 reused_value_fc=None,
+                 reused_proj_fc=None):
         super(MultiHeadAttention, self).__init__()
         self.n_head = n_head
         self.d_key = d_key
         self.d_value = d_value
         self.d_model = d_model
         self.dropout_rate = dropout_rate
 
-        self.q_fc = Linear(
-            input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
-        self.k_fc = Linear(
-            input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
-        self.v_fc = Linear(
-            input_dim=d_model, output_dim=d_value * n_head, bias_attr=False)
-        self.proj_fc = Linear(
-            input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)
+        if reused_query_fc is not None:
+            self.q_fc = reused_query_fc
+        else:
+            self.q_fc = Linear(
+                input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
+        if reused_key_fc is not None:
+            self.k_fc = reused_key_fc
+        else:
+            self.k_fc = Linear(
+                input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
+        if reused_value_fc is not None:
+            self.v_fc = reused_value_fc
+        else:
+            self.v_fc = Linear(
+                input_dim=d_model, output_dim=d_value * n_head, bias_attr=False)
+        if reused_proj_fc is not None:
+            self.proj_fc = reused_proj_fc
+        else:
+            self.proj_fc = Linear(
+                input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)
 
     def _prepare_qkv(self, queries, keys, values, cache=None):
        if keys is None:  # self-attention
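The `reused_*_fc` arguments let attention layers share their query/key/value/output projections. A minimal sketch of two attention layers tied to one set of projections (import path and sizes assumed; the `Linear` shapes follow the defaults in this diff):

```python
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
from hapi.text.text import MultiHeadAttention  # import path assumed

with fluid.dygraph.guard():
    d_model, d_key, d_value, n_head = 512, 64, 64, 8
    q_fc = Linear(input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
    k_fc = Linear(input_dim=d_model, output_dim=d_key * n_head, bias_attr=False)
    v_fc = Linear(input_dim=d_model, output_dim=d_value * n_head, bias_attr=False)
    proj_fc = Linear(input_dim=d_value * n_head, output_dim=d_model, bias_attr=False)

    # Both layers read and update the same projection parameters.
    attn_a = MultiHeadAttention(d_key, d_value, d_model, n_head, 0.1,
                                reused_query_fc=q_fc, reused_key_fc=k_fc,
                                reused_value_fc=v_fc, reused_proj_fc=proj_fc)
    attn_b = MultiHeadAttention(d_key, d_value, d_model, n_head, 0.1,
                                reused_query_fc=q_fc, reused_key_fc=k_fc,
                                reused_value_fc=v_fc, reused_proj_fc=proj_fc)
```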
@@ -828,11 +1249,23 @@ class FFN(Layer):
     Feed-Forward Network
     """
 
-    def __init__(self, d_inner_hid, d_model, dropout_rate):
+    def __init__(self,
+                 d_inner_hid,
+                 d_model,
+                 dropout_rate,
+                 fc1_act="relu",
+                 reused_fc1=None,
+                 reused_fc2=None):
         super(FFN, self).__init__()
         self.dropout_rate = dropout_rate
-        self.fc1 = Linear(
-            input_dim=d_model, output_dim=d_inner_hid, act="relu")
-        self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)
+        if reused_fc1 is not None:
+            self.fc1 = reused_fc1
+        else:
+            self.fc1 = Linear(
+                input_dim=d_model, output_dim=d_inner_hid, act=fc1_act)
+        if reused_fc2 is not None:
+            self.fc2 = reused_fc2
+        else:
+            self.fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)
 
     def forward(self, x):
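`fc1_act` is forwarded to `Linear`'s `act` argument (previously hard-coded to `"relu"`), and the two `reused_fc*` arguments allow injecting existing layers. A short sketch; the import path and sizes are assumed, and `"gelu"` is only an example of an activation name the framework may accept, not something this diff prescribes.

```python
import paddle.fluid as fluid
from paddle.fluid.dygraph import Linear
from hapi.text.text import FFN  # import path assumed

with fluid.dygraph.guard():
    d_model, d_inner_hid = 512, 2048
    # Pick the first projection's activation instead of the previous fixed "relu".
    ffn_gelu = FFN(d_inner_hid, d_model, 0.1, fc1_act="gelu")

    # Or inject existing Linear layers; the FFN then reuses them verbatim.
    fc1 = Linear(input_dim=d_model, output_dim=d_inner_hid, act="relu")
    fc2 = Linear(input_dim=d_inner_hid, output_dim=d_model)
    ffn_shared = FFN(d_inner_hid, d_model, 0.1, reused_fc1=fc1, reused_fc2=fc2)
```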
@@ -859,22 +1292,52 @@ class TransformerEncoderLayer(Layer):
                  attention_dropout,
                  relu_dropout,
                  preprocess_cmd="n",
-                 postprocess_cmd="da"):
+                 postprocess_cmd="da",
+                 ffn_fc1_act="relu",
+                 reused_pre_selatt_layernorm=None,
+                 reused_multihead_att_weights={
+                     "reused_query_fc": None,
+                     "reused_key_fc": None,
+                     "reused_value_fc": None,
+                     "reused_proj_fc": None
+                 },
+                 reused_post_selfatt_layernorm=None,
+                 reused_pre_ffn_layernorm=None,
+                 reused_ffn_weights={"reused_fc1": None,
+                                     "reused_fc2": None},
+                 reused_post_ffn_layernorm=None):
         super(TransformerEncoderLayer, self).__init__()
 
-        self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
-                                                 prepostprocess_dropout)
-        self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
-                                            attention_dropout)
-        self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
-                                                  prepostprocess_dropout)
-
-        self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
-                                                 prepostprocess_dropout)
-        self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
-        self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
-                                                  prepostprocess_dropout)
+        self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
+                                                 prepostprocess_dropout,
+                                                 reused_pre_selatt_layernorm)
+        self.self_attn = MultiHeadAttention(
+            d_key,
+            d_value,
+            d_model,
+            n_head,
+            attention_dropout,
+            reused_query_fc=reused_multihead_att_weights["reused_query_fc"],
+            reused_key_fc=reused_multihead_att_weights["reused_key_fc"],
+            reused_value_fc=reused_multihead_att_weights["reused_value_fc"],
+            reused_proj_fc=reused_multihead_att_weights["reused_proj_fc"])
+        self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
+                                                  prepostprocess_dropout,
+                                                  reused_post_selfatt_layernorm)
+
+        self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
+                                                 prepostprocess_dropout,
+                                                 reused_pre_ffn_layernorm)
+        self.ffn = FFN(d_inner_hid,
+                       d_model,
+                       relu_dropout,
+                       fc1_act=ffn_fc1_act,
+                       reused_fc1=reused_ffn_weights["reused_fc1"],
+                       reused_fc2=reused_ffn_weights["reused_fc2"])
+        self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
+                                                  prepostprocess_dropout,
+                                                  reused_post_ffn_layernorm)
 
     def forward(self, enc_input, attn_bias):
         attn_output = self.self_attn(
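The encoder layer simply threads the new reuse arguments down into its sub-layers, so an entire layer's attention projections and FFN can be tied to another layer's. A sketch of two encoder layers sharing those weights; the positional argument order follows the `TransformerEncoderLayer(...)` call shown in the next hunk, while the import path and sizes are assumptions.

```python
import paddle.fluid as fluid
from hapi.text.text import TransformerEncoderLayer  # import path assumed

with fluid.dygraph.guard():
    n_head, d_key, d_value, d_model, d_inner_hid = 8, 64, 64, 512, 2048
    base = TransformerEncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
                                   0.1, 0.1, 0.1, "n", "da", ffn_fc1_act="relu")
    # Tie the second layer's attention projections and FFN to `base`; the layer
    # norms stay independent because no reused_*_layernorm is passed.
    tied = TransformerEncoderLayer(
        n_head, d_key, d_value, d_model, d_inner_hid, 0.1, 0.1, 0.1, "n", "da",
        ffn_fc1_act="relu",
        reused_multihead_att_weights={
            "reused_query_fc": base.self_attn.q_fc,
            "reused_key_fc": base.self_attn.k_fc,
            "reused_value_fc": base.self_attn.v_fc,
            "reused_proj_fc": base.self_attn.proj_fc,
        },
        reused_ffn_weights={"reused_fc1": base.ffn.fc1,
                            "reused_fc2": base.ffn.fc2})
```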
@@ -902,7 +1365,8 @@ class TransformerEncoder(Layer):
                  attention_dropout,
                  relu_dropout,
                  preprocess_cmd="n",
-                 postprocess_cmd="da"):
+                 postprocess_cmd="da",
+                 ffn_fc1_act="relu"):
         super(TransformerEncoder, self).__init__()
@@ -912,9 +1376,17 @@ class TransformerEncoder(Layer):
             self.add_sublayer(
                 "layer_%d" % i,
                 TransformerEncoderLayer(
-                    n_head, d_key, d_value, d_model, d_inner_hid,
-                    prepostprocess_dropout, attention_dropout, relu_dropout,
-                    preprocess_cmd, postprocess_cmd)))
+                    n_head,
+                    d_key,
+                    d_value,
+                    d_model,
+                    d_inner_hid,
+                    prepostprocess_dropout,
+                    attention_dropout,
+                    relu_dropout,
+                    preprocess_cmd,
+                    postprocess_cmd,
+                    ffn_fc1_act=ffn_fc1_act)))
 
         self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
                                              prepostprocess_dropout)
@@ -941,28 +1413,79 @@ class TransformerDecoderLayer(Layer):
                  attention_dropout,
                  relu_dropout,
                  preprocess_cmd="n",
-                 postprocess_cmd="da"):
+                 postprocess_cmd="da",
+                 reused_pre_selfatt_layernorm=None,
+                 reused_self_multihead_att_weights={
+                     "reused_query_fc": None,
+                     "reused_key_fc": None,
+                     "reused_value_fc": None,
+                     "reused_proj_fc": None
+                 },
+                 reused_post_selfatt_layernorm=None,
+                 reused_pre_crossatt_layernorm=None,
+                 reused_cross_multihead_att_weights={
+                     "reused_query_fc": None,
+                     "reused_key_fc": None,
+                     "reused_value_fc": None,
+                     "reused_proj_fc": None
+                 },
+                 reused_post_crossatt_layernorm=None,
+                 reused_pre_ffn_layernorm=None,
+                 reused_ffn_weights={"reused_fc1": None,
+                                     "reused_fc2": None},
+                 reused_post_ffn_layernorm=None):
         super(TransformerDecoderLayer, self).__init__()
 
-        self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
-                                                 prepostprocess_dropout)
-        self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
-                                            attention_dropout)
-        self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
-                                                  prepostprocess_dropout)
-        self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
-                                                 prepostprocess_dropout)
-        self.cross_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
-                                             attention_dropout)
-        self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
-                                                  prepostprocess_dropout)
-        self.preprocesser3 = PrePostProcessLayer(preprocess_cmd, d_model,
-                                                 prepostprocess_dropout)
-        self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
-        self.postprocesser3 = PrePostProcessLayer(postprocess_cmd, d_model,
-                                                  prepostprocess_dropout)
+        self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
+                                                 prepostprocess_dropout,
+                                                 reused_pre_selfatt_layernorm)
+        self.self_attn = MultiHeadAttention(
+            d_key,
+            d_value,
+            d_model,
+            n_head,
+            attention_dropout,
+            reused_query_fc=reused_self_multihead_att_weights["reused_query_fc"],
+            reused_key_fc=reused_self_multihead_att_weights["reused_key_fc"],
+            reused_value_fc=reused_self_multihead_att_weights["reused_value_fc"],
+            reused_proj_fc=reused_self_multihead_att_weights["reused_proj_fc"])
+        self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
+                                                  prepostprocess_dropout,
+                                                  reused_post_selfatt_layernorm)
+        self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
+                                                 prepostprocess_dropout,
+                                                 reused_pre_crossatt_layernorm)
+        self.cross_attn = MultiHeadAttention(
+            d_key,
+            d_value,
+            d_model,
+            n_head,
+            attention_dropout,
+            reused_query_fc=reused_cross_multihead_att_weights["reused_query_fc"],
+            reused_key_fc=reused_cross_multihead_att_weights["reused_key_fc"],
+            reused_value_fc=reused_cross_multihead_att_weights["reused_value_fc"],
+            reused_proj_fc=reused_cross_multihead_att_weights["reused_proj_fc"])
+        self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
+                                                  prepostprocess_dropout,
+                                                  reused_post_crossatt_layernorm)
+        self.preprocesser3 = PrePostProcessLayer(preprocess_cmd, d_model,
+                                                 prepostprocess_dropout,
+                                                 reused_pre_ffn_layernorm)
+        self.ffn = FFN(d_inner_hid,
+                       d_model,
+                       relu_dropout,
+                       reused_fc1=reused_ffn_weights["reused_fc1"],
+                       reused_fc2=reused_ffn_weights["reused_fc2"])
+        self.postprocesser3 = PrePostProcessLayer(postprocess_cmd, d_model,
+                                                  prepostprocess_dropout,
+                                                  reused_post_ffn_layernorm)
 
     def forward(self,
                 dec_input,
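The decoder layer gets separate reuse hooks for its self-attention, cross-attention, FFN, and each layer norm. A sketch of tying a second decoder layer's self-attention and FFN to a first layer; keyword arguments are used because the early positional order is not shown in this hunk, and the parameter names are inferred from how they are used inside the constructor, so treat the exact signature as an assumption.

```python
import paddle.fluid as fluid
from hapi.text.text import TransformerDecoderLayer  # import path assumed

with fluid.dygraph.guard():
    common = dict(n_head=8, d_key=64, d_value=64, d_model=512, d_inner_hid=2048,
                  prepostprocess_dropout=0.1, attention_dropout=0.1,
                  relu_dropout=0.1, preprocess_cmd="n", postprocess_cmd="da")
    base = TransformerDecoderLayer(**common)
    # Share self-attention projections and the FFN with `base`; cross-attention
    # and all layer norms remain independent.
    tied = TransformerDecoderLayer(
        reused_self_multihead_att_weights={
            "reused_query_fc": base.self_attn.q_fc,
            "reused_key_fc": base.self_attn.k_fc,
            "reused_value_fc": base.self_attn.v_fc,
            "reused_proj_fc": base.self_attn.proj_fc,
        },
        reused_ffn_weights={"reused_fc1": base.ffn.fc1,
                            "reused_fc2": base.ffn.fc2},
        **common)
```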
@@ -1031,7 +1554,7 @@ class TransformerDecoder(Layer):
         ]
 
+#TODO: we should merge GRUCell with BasicGRUCell
 class GRUCell(RNNCell):
     def __init__(self,
                  input_size,
@@ -1044,9 +1567,7 @@ class GRUCell(RNNCell):
         super(GRUCell, self).__init__()
         self.hidden_size = hidden_size
-        self.fc_layer = Linear(
-            input_size,
-            hidden_size * 3,
-            param_attr=param_attr)
+        self.fc_layer = Linear(
+            input_size,
+            hidden_size * 3, param_attr=param_attr)
 
         self.gru_unit = GRUUnit(
             hidden_size * 3,
@@ -1067,6 +1588,7 @@ class GRUCell(RNNCell):
...
@@ -1067,6 +1588,7 @@ class GRUCell(RNNCell):
return
[
self
.
hidden_size
]
return
[
self
.
hidden_size
]
#TODO: we should merge GRUCell with BasicGRUCell
class
GRUEncoderCell
(
RNNCell
):
class
GRUEncoderCell
(
RNNCell
):
def
__init__
(
self
,
def
__init__
(
self
,
num_layers
,
num_layers
,
@@ -1086,7 +1608,8 @@ class GRUEncoderCell(RNNCell):
                 GRUCell(
                     input_size=input_size if i == 0 else hidden_size,
                     hidden_size=hidden_size,
-                    param_attr=fluid.ParamAttr(
-                        initializer=fluid.initializer.UniformInitializer(
-                            low=-init_scale, high=init_scale)))))
+                    param_attr=fluid.ParamAttr(
+                        initializer=fluid.initializer.UniformInitializer(
+                            low=-init_scale,
+                            high=init_scale)))))
 
     def forward(self, step_input, states):
@@ -1109,17 +1632,16 @@ class GRUEncoderCell(RNNCell):
 class BiGRU(fluid.dygraph.Layer):
     def __init__(self, input_dim, grnn_hidden_dim, init_bound, h_0=None):
         super(BiGRU, self).__init__()
-
         self.gru = RNN(GRUEncoderCell(1, input_dim,
                                       grnn_hidden_dim, 0.0,
                                       init_bound),
                        is_reverse=False,
                        time_major=False)
 
         self.gru_r = RNN(GRUEncoderCell(1, input_dim,
                                         grnn_hidden_dim, 0.0,
                                         init_bound),
                          is_reverse=True,
                          time_major=False)
 
     def forward(self, input_feature):
         pre_gru, pre_state = self.gru(input_feature)
         gru_r, r_state = self.gru_r(input_feature)
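For completeness, the bidirectional GRU wrapper touched at the end of the diff can be used as follows; the import path and sizes are assumptions, and since `time_major=False` in both internal RNNs, the input is laid out as [batch, seq_len, input_dim]. The rest of `forward` (beyond the lines shown in this hunk) combines the forward and reverse features.

```python
import numpy as np
import paddle.fluid as fluid
from hapi.text.text import BiGRU  # import path assumed

with fluid.dygraph.guard():
    bigru = BiGRU(input_dim=128, grnn_hidden_dim=256, init_bound=0.1)
    feats = fluid.dygraph.to_variable(
        np.random.rand(4, 20, 128).astype('float32'))  # [batch, seq_len, input_dim]
    out = bigru(feats)  # runs both the forward and the reverse GRU over the sequence
```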