Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
60917f41
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
60917f41
编写于
4月 30, 2020
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add test for text.py
上级
61e218d0
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
525 addition
and
36 deletion
+525
-36
examples/transformer/transformer.py
examples/transformer/transformer.py
+3
-23
hapi/model.py
hapi/model.py
+11
-10
hapi/tests/test_text.py
hapi/tests/test_text.py
+508
-0
hapi/text/text.py
hapi/text/text.py
+3
-3
未找到文件。
examples/transformer/transformer.py
浏览文件 @
60917f41
...
@@ -18,10 +18,10 @@ import numpy as np
...
@@ -18,10 +18,10 @@ import numpy as np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
layers
import
paddle.fluid.layers
as
layers
from
paddle.fluid.dygraph
import
Embedding
,
LayerNorm
,
Linear
,
Layer
,
to_variable
from
paddle.fluid.dygraph
import
Embedding
,
LayerNorm
,
Linear
,
Layer
from
paddle.fluid.dygraph.learning_rate_scheduler
import
LearningRateDecay
from
paddle.fluid.dygraph.learning_rate_scheduler
import
LearningRateDecay
from
hapi.model
import
Model
,
CrossEntropy
,
Loss
from
hapi.model
import
Model
,
CrossEntropy
,
Loss
from
hapi.text
import
TransformerBeamSearchDecoder
,
DynamicDecode
from
hapi.text
import
Transformer
Cell
,
Transformer
BeamSearchDecoder
,
DynamicDecode
def
position_encoding_init
(
n_position
,
d_pos_vec
):
def
position_encoding_init
(
n_position
,
d_pos_vec
):
...
@@ -606,26 +606,6 @@ class Transformer(Model):
...
@@ -606,26 +606,6 @@ class Transformer(Model):
return
predict
return
predict
class
TransfomerCell
(
object
):
"""
Let inputs=(trg_word, trg_pos), states=cache to make Transformer can be
used as RNNCell
"""
def
__init__
(
self
,
decoder
):
self
.
decoder
=
decoder
def
__call__
(
self
,
inputs
,
states
,
trg_src_attn_bias
,
enc_output
,
static_caches
):
trg_word
,
trg_pos
=
inputs
for
cache
,
static_cache
in
zip
(
states
,
static_caches
):
cache
.
update
(
static_cache
)
logits
=
self
.
decoder
(
trg_word
,
trg_pos
,
None
,
trg_src_attn_bias
,
enc_output
,
states
)
new_states
=
[{
"k"
:
cache
[
"k"
],
"v"
:
cache
[
"v"
]}
for
cache
in
states
]
return
logits
,
new_states
class
InferTransformer
(
Transformer
):
class
InferTransformer
(
Transformer
):
"""
"""
model for prediction
model for prediction
...
@@ -657,7 +637,7 @@ class InferTransformer(Transformer):
...
@@ -657,7 +637,7 @@ class InferTransformer(Transformer):
self
.
beam_size
=
args
.
pop
(
"beam_size"
)
self
.
beam_size
=
args
.
pop
(
"beam_size"
)
self
.
max_out_len
=
args
.
pop
(
"max_out_len"
)
self
.
max_out_len
=
args
.
pop
(
"max_out_len"
)
super
(
InferTransformer
,
self
).
__init__
(
**
args
)
super
(
InferTransformer
,
self
).
__init__
(
**
args
)
cell
=
TransfomerCell
(
self
.
decoder
)
cell
=
Transfo
r
merCell
(
self
.
decoder
)
self
.
beam_search_decoder
=
DynamicDecode
(
self
.
beam_search_decoder
=
DynamicDecode
(
TransformerBeamSearchDecoder
(
TransformerBeamSearchDecoder
(
cell
,
bos_id
,
eos_id
,
beam_size
,
var_dim_in_state
=
2
),
cell
,
bos_id
,
eos_id
,
beam_size
,
var_dim_in_state
=
2
),
...
...
hapi/model.py
浏览文件 @
60917f41
...
@@ -38,7 +38,7 @@ from hapi.loss import Loss
...
@@ -38,7 +38,7 @@ from hapi.loss import Loss
from
hapi.distributed
import
DistributedBatchSampler
,
_all_gather
,
prepare_distributed_context
,
_parallel_context_initialized
from
hapi.distributed
import
DistributedBatchSampler
,
_all_gather
,
prepare_distributed_context
,
_parallel_context_initialized
from
hapi.metrics
import
Metric
from
hapi.metrics
import
Metric
from
hapi.callbacks
import
config_callbacks
from
hapi.callbacks
import
config_callbacks
from
hapi.utils
import
to_list
,
to_numpy
,
flatten_list
,
restore_flatten_list
from
hapi.utils
import
to_list
,
to_numpy
,
flatten_list
,
restore_flatten_list
,
extract_args
__all__
=
[
__all__
=
[
'Model'
,
'Model'
,
...
@@ -495,14 +495,15 @@ class DynamicGraphAdapter(object):
...
@@ -495,14 +495,15 @@ class DynamicGraphAdapter(object):
if
labels
is
not
None
:
if
labels
is
not
None
:
labels
=
[
to_variable
(
l
)
for
l
in
to_list
(
labels
)]
labels
=
[
to_variable
(
l
)
for
l
in
to_list
(
labels
)]
if
self
.
_nranks
>
1
:
if
self
.
_nranks
>
1
:
outputs
=
self
.
ddp_model
.
forward
(
*
[
to_variable
(
x
)
for
x
in
inputs
])
outputs
=
self
.
ddp_model
.
forward
(
*
[
to_variable
(
x
)
for
x
in
inputs
])
losses
=
self
.
model
.
_loss_function
(
outputs
,
labels
)
losses
=
self
.
model
.
_loss_function
(
outputs
,
labels
)
final_loss
=
fluid
.
layers
.
sum
(
losses
)
final_loss
=
fluid
.
layers
.
sum
(
losses
)
final_loss
=
self
.
ddp_model
.
scale_loss
(
final_loss
)
final_loss
=
self
.
ddp_model
.
scale_loss
(
final_loss
)
final_loss
.
backward
()
final_loss
.
backward
()
self
.
ddp_model
.
apply_collective_grads
()
self
.
ddp_model
.
apply_collective_grads
()
else
:
else
:
outputs
=
self
.
model
.
forward
(
*
[
to_variable
(
x
)
for
x
in
inputs
])
outputs
=
self
.
model
.
forward
(
*
[
to_variable
(
x
)
for
x
in
inputs
])
losses
=
self
.
model
.
_loss_function
(
outputs
,
labels
)
losses
=
self
.
model
.
_loss_function
(
outputs
,
labels
)
final_loss
=
fluid
.
layers
.
sum
(
losses
)
final_loss
=
fluid
.
layers
.
sum
(
losses
)
final_loss
.
backward
()
final_loss
.
backward
()
...
@@ -511,9 +512,9 @@ class DynamicGraphAdapter(object):
...
@@ -511,9 +512,9 @@ class DynamicGraphAdapter(object):
self
.
model
.
clear_gradients
()
self
.
model
.
clear_gradients
()
metrics
=
[]
metrics
=
[]
for
metric
in
self
.
model
.
_metrics
:
for
metric
in
self
.
model
.
_metrics
:
metric_outs
=
metric
.
add_metric_op
(
*
(
metric_outs
=
metric
.
add_metric_op
(
*
(
to_list
(
outputs
)
+
to_list
(
to_list
(
outputs
)
+
to_list
(
labels
)))
labels
)))
m
=
metric
.
update
(
*
[
to_numpy
(
m
)
for
m
in
to_list
(
metric_outs
)])
m
=
metric
.
update
(
*
[
to_numpy
(
m
)
for
m
in
to_list
(
metric_outs
)])
metrics
.
append
(
m
)
metrics
.
append
(
m
)
return
([
to_numpy
(
l
)
for
l
in
losses
],
metrics
)
\
return
([
to_numpy
(
l
)
for
l
in
losses
],
metrics
)
\
...
@@ -525,7 +526,7 @@ class DynamicGraphAdapter(object):
...
@@ -525,7 +526,7 @@ class DynamicGraphAdapter(object):
inputs
=
to_list
(
inputs
)
inputs
=
to_list
(
inputs
)
if
labels
is
not
None
:
if
labels
is
not
None
:
labels
=
[
to_variable
(
l
)
for
l
in
to_list
(
labels
)]
labels
=
[
to_variable
(
l
)
for
l
in
to_list
(
labels
)]
outputs
=
self
.
model
.
forward
(
*
[
to_variable
(
x
)
for
x
in
inputs
])
outputs
=
self
.
model
.
forward
(
*
[
to_variable
(
x
)
for
x
in
inputs
])
if
self
.
model
.
_loss_function
:
if
self
.
model
.
_loss_function
:
losses
=
self
.
model
.
_loss_function
(
outputs
,
labels
)
losses
=
self
.
model
.
_loss_function
(
outputs
,
labels
)
else
:
else
:
...
@@ -551,9 +552,9 @@ class DynamicGraphAdapter(object):
...
@@ -551,9 +552,9 @@ class DynamicGraphAdapter(object):
self
.
_merge_count
[
self
.
mode
+
'_total'
]
+=
samples
self
.
_merge_count
[
self
.
mode
+
'_total'
]
+=
samples
self
.
_merge_count
[
self
.
mode
+
'_batch'
]
=
samples
self
.
_merge_count
[
self
.
mode
+
'_batch'
]
=
samples
metric_outs
=
metric
.
add_metric_op
(
*
(
metric_outs
=
metric
.
add_metric_op
(
*
(
to_list
(
outputs
)
+
to_list
(
to_list
(
outputs
)
+
to_list
(
labels
)))
labels
)))
m
=
metric
.
update
(
*
[
to_numpy
(
m
)
for
m
in
to_list
(
metric_outs
)])
m
=
metric
.
update
(
*
[
to_numpy
(
m
)
for
m
in
to_list
(
metric_outs
)])
metrics
.
append
(
m
)
metrics
.
append
(
m
)
# To be consistent with static graph
# To be consistent with static graph
...
...
hapi/tests/test_text.py
0 → 100644
浏览文件 @
60917f41
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# when test, you should add hapi root path to the PYTHONPATH,
# export PYTHONPATH=PATH_TO_HAPI:$PYTHONPATH
import
unittest
import
time
import
random
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph
import
Embedding
,
Linear
,
Layer
from
paddle.fluid.layers
import
BeamSearchDecoder
import
hapi.text
as
text
from
hapi.model
import
Model
,
Input
,
set_device
from
hapi.text
import
BasicLSTMCell
,
BasicGRUCell
,
RNN
,
DynamicDecode
,
MultiHeadAttention
,
TransformerEncoder
from
hapi.text
import
*
def
sigmoid
(
x
):
return
1.
/
(
1.
+
np
.
exp
(
-
x
))
def
tanh
(
x
):
return
2.
*
sigmoid
(
2.
*
x
)
-
1.
def
lstm_step
(
step_in
,
pre_hidden
,
pre_cell
,
gate_w
,
gate_b
,
forget_bias
=
1.0
):
concat_1
=
np
.
concatenate
([
step_in
,
pre_hidden
],
1
)
gate_input
=
np
.
matmul
(
concat_1
,
gate_w
)
gate_input
+=
gate_b
i
,
j
,
f
,
o
=
np
.
split
(
gate_input
,
indices_or_sections
=
4
,
axis
=
1
)
new_cell
=
pre_cell
*
sigmoid
(
f
+
forget_bias
)
+
sigmoid
(
i
)
*
tanh
(
j
)
new_hidden
=
tanh
(
new_cell
)
*
sigmoid
(
o
)
return
new_hidden
,
new_cell
def
gru_step
(
step_in
,
pre_hidden
,
gate_w
,
gate_b
,
candidate_w
,
candidate_b
):
concat_1
=
np
.
concatenate
([
step_in
,
pre_hidden
],
1
)
gate_input
=
np
.
matmul
(
concat_1
,
gate_w
)
gate_input
+=
gate_b
gate_input
=
sigmoid
(
gate_input
)
r
,
u
=
np
.
split
(
gate_input
,
indices_or_sections
=
2
,
axis
=
1
)
r_hidden
=
r
*
pre_hidden
candidate
=
np
.
matmul
(
np
.
concatenate
([
step_in
,
r_hidden
],
1
),
candidate_w
)
candidate
+=
candidate_b
c
=
tanh
(
candidate
)
new_hidden
=
u
*
pre_hidden
+
(
1
-
u
)
*
c
return
new_hidden
class
ModuleApiTest
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
_np_rand_state
=
np
.
random
.
get_state
()
cls
.
_py_rand_state
=
random
.
getstate
()
cls
.
_random_seed
=
123
np
.
random
.
seed
(
cls
.
_random_seed
)
random
.
seed
(
cls
.
_random_seed
)
cls
.
model_cls
=
type
(
cls
.
__name__
+
"Model"
,
(
Model
,
),
{
"__init__"
:
cls
.
model_init_wrapper
(
cls
.
model_init
),
"forward"
:
cls
.
model_forward
})
@
classmethod
def
tearDownClass
(
cls
):
np
.
random
.
set_state
(
cls
.
_np_rand_state
)
random
.
setstate
(
cls
.
_py_rand_state
)
@
staticmethod
def
model_init_wrapper
(
func
):
def
__impl__
(
self
,
*
args
,
**
kwargs
):
Model
.
__init__
(
self
)
func
(
self
,
*
args
,
**
kwargs
)
return
__impl__
@
staticmethod
def
model_init
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
(
"model_init acts as `Model.__init__`, thus must implement it"
)
@
staticmethod
def
model_forward
(
self
,
*
args
,
**
kwargs
):
return
self
.
module
(
*
args
,
**
kwargs
)
def
make_inputs
(
self
):
# TODO(guosheng): add default from `self.inputs`
raise
NotImplementedError
(
"model_inputs makes inputs for model, thus must implement it"
)
def
setUp
(
self
):
"""
For the model which wraps the module to be tested:
Set input data by `self.inputs` list
Set init argument values by `self.attrs` list/dict
Set model parameter values by `self.param_states` dict
Set expected output data by `self.outputs` list
We can create a model instance and run once with these.
"""
self
.
inputs
=
[]
self
.
attrs
=
{}
self
.
param_states
=
{}
self
.
outputs
=
[]
def
_calc_output
(
self
,
place
,
mode
=
"test"
,
dygraph
=
True
):
if
dygraph
:
fluid
.
enable_dygraph
(
place
)
else
:
fluid
.
disable_dygraph
()
fluid
.
default_main_program
().
random_seed
=
self
.
_random_seed
fluid
.
default_startup_program
().
random_seed
=
self
.
_random_seed
model
=
self
.
model_cls
(
**
self
.
attrs
)
if
isinstance
(
self
.
attrs
,
dict
)
else
self
.
model_cls
(
*
self
.
attrs
)
model
.
prepare
(
inputs
=
self
.
make_inputs
(),
device
=
place
)
if
self
.
param_states
:
model
.
load
(
self
.
param_states
,
optim_state
=
None
)
return
model
.
test_batch
(
self
.
inputs
)
def
check_output_with_place
(
self
,
place
,
mode
=
"test"
):
dygraph_output
=
self
.
_calc_output
(
place
,
mode
,
dygraph
=
True
)
stgraph_output
=
self
.
_calc_output
(
place
,
mode
,
dygraph
=
False
)
expect_output
=
getattr
(
self
,
"outputs"
,
None
)
for
actual_t
,
expect_t
in
zip
(
dygraph_output
,
stgraph_output
):
self
.
assertTrue
(
np
.
allclose
(
actual_t
,
expect_t
,
rtol
=
1e-5
,
atol
=
0
))
if
expect_output
:
for
actual_t
,
expect_t
in
zip
(
dygraph_output
,
expect_output
):
self
.
assertTrue
(
np
.
allclose
(
actual_t
,
expect_t
,
rtol
=
1e-5
,
atol
=
0
))
def
check_output
(
self
):
devices
=
[
"CPU"
,
"GPU"
]
if
fluid
.
is_compiled_with_cuda
()
else
[
"CPU"
]
for
device
in
devices
:
place
=
set_device
(
device
)
self
.
check_output_with_place
(
place
)
class
TestBasicLSTM
(
ModuleApiTest
):
def
setUp
(
self
):
# TODO(guosheng): Change to big size. Currentlys bigger hidden size for
# LSTM would fail, the second static graph run might get diff output
# with others.
shape
=
(
2
,
4
,
16
)
self
.
inputs
=
[
np
.
random
.
random
(
shape
).
astype
(
"float32"
)]
self
.
outputs
=
None
self
.
attrs
=
{
"input_size"
:
16
,
"hidden_size"
:
16
}
self
.
param_states
=
{}
@
staticmethod
def
model_init
(
self
,
input_size
,
hidden_size
):
self
.
lstm
=
RNN
(
BasicLSTMCell
(
input_size
,
hidden_size
,
param_attr
=
fluid
.
ParamAttr
(
name
=
"lstm_weight"
),
bias_attr
=
fluid
.
ParamAttr
(
name
=
"lstm_bias"
)))
@
staticmethod
def
model_forward
(
self
,
inputs
):
return
self
.
lstm
(
inputs
)[
0
]
def
make_inputs
(
self
):
inputs
=
[
Input
(
[
None
,
None
,
self
.
inputs
[
-
1
].
shape
[
-
1
]],
"float32"
,
name
=
"input"
)
]
return
inputs
def
test_check_output
(
self
):
self
.
check_output
()
class
TestBasicGRU
(
ModuleApiTest
):
def
setUp
(
self
):
shape
=
(
2
,
4
,
128
)
self
.
inputs
=
[
np
.
random
.
random
(
shape
).
astype
(
"float32"
)]
self
.
outputs
=
None
self
.
attrs
=
{
"input_size"
:
128
,
"hidden_size"
:
128
}
self
.
param_states
=
{}
@
staticmethod
def
model_init
(
self
,
input_size
,
hidden_size
):
self
.
gru
=
RNN
(
BasicGRUCell
(
input_size
,
hidden_size
))
@
staticmethod
def
model_forward
(
self
,
inputs
):
return
self
.
gru
(
inputs
)[
0
]
def
make_inputs
(
self
):
inputs
=
[
Input
(
[
None
,
None
,
self
.
inputs
[
-
1
].
shape
[
-
1
]],
"float32"
,
name
=
"input"
)
]
return
inputs
def
test_check_output
(
self
):
self
.
check_output
()
class
TestBeamSearch
(
ModuleApiTest
):
def
setUp
(
self
):
shape
=
(
8
,
32
)
self
.
inputs
=
[
np
.
random
.
random
(
shape
).
astype
(
"float32"
),
np
.
random
.
random
(
shape
).
astype
(
"float32"
)
]
self
.
outputs
=
None
self
.
attrs
=
{
"vocab_size"
:
100
,
"embed_dim"
:
32
,
"hidden_size"
:
32
,
}
self
.
param_states
=
{}
@
staticmethod
def
model_init
(
self
,
vocab_size
,
embed_dim
,
hidden_size
,
bos_id
=
0
,
eos_id
=
1
,
beam_size
=
4
,
max_step_num
=
20
):
embedder
=
Embedding
(
size
=
[
vocab_size
,
embed_dim
])
output_layer
=
Linear
(
hidden_size
,
vocab_size
)
cell
=
BasicLSTMCell
(
embed_dim
,
hidden_size
)
decoder
=
BeamSearchDecoder
(
cell
,
start_token
=
bos_id
,
end_token
=
eos_id
,
beam_size
=
beam_size
,
embedding_fn
=
embedder
,
output_fn
=
output_layer
)
self
.
beam_search_decoder
=
DynamicDecode
(
decoder
,
max_step_num
=
max_step_num
,
is_test
=
True
)
@
staticmethod
def
model_forward
(
self
,
init_hidden
,
init_cell
):
return
self
.
beam_search_decoder
([
init_hidden
,
init_cell
])[
0
]
def
make_inputs
(
self
):
inputs
=
[
Input
(
[
None
,
self
.
inputs
[
0
].
shape
[
-
1
]],
"float32"
,
name
=
"init_hidden"
),
Input
(
[
None
,
self
.
inputs
[
1
].
shape
[
-
1
]],
"float32"
,
name
=
"init_cell"
)
]
return
inputs
def
test_check_output
(
self
):
self
.
check_output
()
class
TestTransformerEncoder
(
ModuleApiTest
):
def
setUp
(
self
):
self
.
inputs
=
[
# encoder input: [batch_size, seq_len, hidden_size]
np
.
random
.
random
([
2
,
4
,
512
]).
astype
(
"float32"
),
# self attention bias: [batch_size, n_head, seq_len, seq_len]
np
.
random
.
randint
(
0
,
1
,
[
2
,
8
,
4
,
4
]).
astype
(
"float32"
)
*
-
1e9
]
self
.
outputs
=
None
self
.
attrs
=
{
"n_layer"
:
2
,
"n_head"
:
8
,
"d_key"
:
64
,
"d_value"
:
64
,
"d_model"
:
512
,
"d_inner_hid"
:
1024
}
self
.
param_states
=
{}
@
staticmethod
def
model_init
(
self
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
=
0.1
,
attention_dropout
=
0.1
,
relu_dropout
=
0.1
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
,
ffn_fc1_act
=
"relu"
):
self
.
encoder
=
TransformerEncoder
(
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
ffn_fc1_act
)
@
staticmethod
def
model_forward
(
self
,
enc_input
,
attn_bias
):
return
self
.
encoder
(
enc_input
,
attn_bias
)
def
make_inputs
(
self
):
inputs
=
[
Input
(
[
None
,
None
,
self
.
inputs
[
0
].
shape
[
-
1
]],
"float32"
,
name
=
"enc_input"
),
Input
(
[
None
,
self
.
inputs
[
1
].
shape
[
1
],
None
,
None
],
"float32"
,
name
=
"attn_bias"
)
]
return
inputs
def
test_check_output
(
self
):
self
.
check_output
()
class
TestTransformerDecoder
(
TestTransformerEncoder
):
def
setUp
(
self
):
self
.
inputs
=
[
# decoder input: [batch_size, seq_len, hidden_size]
np
.
random
.
random
([
2
,
4
,
512
]).
astype
(
"float32"
),
# encoder output: [batch_size, seq_len, hidden_size]
np
.
random
.
random
([
2
,
5
,
512
]).
astype
(
"float32"
),
# self attention bias: [batch_size, n_head, seq_len, seq_len]
np
.
random
.
randint
(
0
,
1
,
[
2
,
8
,
4
,
4
]).
astype
(
"float32"
)
*
-
1e9
,
# cross attention bias: [batch_size, n_head, seq_len, seq_len]
np
.
random
.
randint
(
0
,
1
,
[
2
,
8
,
4
,
5
]).
astype
(
"float32"
)
*
-
1e9
]
self
.
outputs
=
None
self
.
attrs
=
{
"n_layer"
:
2
,
"n_head"
:
8
,
"d_key"
:
64
,
"d_value"
:
64
,
"d_model"
:
512
,
"d_inner_hid"
:
1024
}
self
.
param_states
=
{}
@
staticmethod
def
model_init
(
self
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
=
0.1
,
attention_dropout
=
0.1
,
relu_dropout
=
0.1
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
):
self
.
decoder
=
TransformerDecoder
(
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
)
@
staticmethod
def
model_forward
(
self
,
dec_input
,
enc_output
,
self_attn_bias
,
cross_attn_bias
,
caches
=
None
):
return
self
.
decoder
(
dec_input
,
enc_output
,
self_attn_bias
,
cross_attn_bias
,
caches
)
def
make_inputs
(
self
):
inputs
=
[
Input
(
[
None
,
None
,
self
.
inputs
[
0
].
shape
[
-
1
]],
"float32"
,
name
=
"dec_input"
),
Input
(
[
None
,
None
,
self
.
inputs
[
0
].
shape
[
-
1
]],
"float32"
,
name
=
"enc_output"
),
Input
(
[
None
,
self
.
inputs
[
-
1
].
shape
[
1
],
None
,
None
],
"float32"
,
name
=
"self_attn_bias"
),
Input
(
[
None
,
self
.
inputs
[
-
1
].
shape
[
1
],
None
,
None
],
"float32"
,
name
=
"cross_attn_bias"
)
]
return
inputs
def
test_check_output
(
self
):
self
.
check_output
()
class
TestTransformerBeamSearchDecoder
(
ModuleApiTest
):
def
setUp
(
self
):
shape
=
(
8
,
32
)
self
.
inputs
=
[
np
.
random
.
random
(
shape
).
astype
(
"float32"
),
np
.
random
.
random
(
shape
).
astype
(
"float32"
)
]
self
.
outputs
=
None
self
.
attrs
=
{
"vocab_size"
:
100
,
"embed_dim"
:
32
,
"hidden_size"
:
32
,
}
self
.
param_states
=
{}
@
staticmethod
def
model_init
(
self
,
vocab_size
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
=
0.1
,
attention_dropout
=
0.1
,
relu_dropout
=
0.1
,
preprocess_cmd
=
"n"
,
postprocess_cmd
=
"da"
,
bos_id
=
0
,
eos_id
=
1
,
beam_size
=
4
,
max_step_num
=
20
):
embedder
=
Embedding
(
size
=
[
vocab_size
,
d_model
])
output_layer
=
Linear
(
d_model
,
vocab_size
)
decoder
=
TransformerDecoder
(
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
)
transformer_cell
=
TransformerCell
(
decoder
)
self
.
beam_search_decoder
=
DynamicDecode
(
TransformerBeamSearchDecoder
(
transformer_cell
,
bos_id
,
eos_id
,
beam_size
,
var_dim_in_state
=
2
),
max_step_num
,
is_test
=
True
)
@
staticmethod
def
model_forward
(
self
,
enc_output
,
trg_src_attn_bias
):
caches
=
[{
"k"
:
layers
.
fill_constant_batch_size_like
(
input
=
enc_output
,
shape
=
[
-
1
,
self
.
n_head
,
0
,
self
.
d_key
],
dtype
=
enc_output
.
dtype
,
value
=
0
),
"v"
:
layers
.
fill_constant_batch_size_like
(
input
=
enc_output
,
shape
=
[
-
1
,
self
.
n_head
,
0
,
self
.
d_value
],
dtype
=
enc_output
.
dtype
,
value
=
0
),
}
for
i
in
range
(
self
.
n_layer
)]
enc_output
=
TransformerBeamSearchDecoder
.
tile_beam_merge_with_batch
(
enc_output
,
self
.
beam_size
)
trg_src_attn_bias
=
TransformerBeamSearchDecoder
.
tile_beam_merge_with_batch
(
trg_src_attn_bias
,
self
.
beam_size
)
static_caches
=
self
.
decoder
.
decoder
.
prepare_static_cache
(
enc_output
)
rs
,
_
=
self
.
beam_search_decoder
(
inits
=
caches
,
enc_output
=
enc_output
,
trg_src_attn_bias
=
trg_src_attn_bias
,
static_caches
=
static_caches
)
return
rs
def
make_inputs
(
self
):
inputs
=
[
Input
(
[
None
,
self
.
inputs
[
0
].
shape
[
-
1
]],
"float32"
,
name
=
"init_hidden"
),
Input
(
[
None
,
self
.
inputs
[
1
].
shape
[
-
1
]],
"float32"
,
name
=
"init_cell"
)
]
return
inputs
def
test_check_output
(
self
):
self
.
check_output
()
if
__name__
==
'__main__'
:
unittest
.
main
()
hapi/text/text.py
浏览文件 @
60917f41
...
@@ -48,8 +48,8 @@ __all__ = [
...
@@ -48,8 +48,8 @@ __all__ = [
'RNNCell'
,
'BasicLSTMCell'
,
'BasicGRUCell'
,
'RNN'
,
'DynamicDecode'
,
'RNNCell'
,
'BasicLSTMCell'
,
'BasicGRUCell'
,
'RNN'
,
'DynamicDecode'
,
'BeamSearchDecoder'
,
'MultiHeadAttention'
,
'FFN'
,
'BeamSearchDecoder'
,
'MultiHeadAttention'
,
'FFN'
,
'TransformerEncoderLayer'
,
'TransformerEncoder'
,
'TransformerDecoderLayer'
,
'TransformerEncoderLayer'
,
'TransformerEncoder'
,
'TransformerDecoderLayer'
,
'TransformerDecoder'
,
'Transformer
BeamSearchDecoder'
,
'Linear_chain_crf
'
,
'TransformerDecoder'
,
'Transformer
Cell'
,
'TransformerBeamSearchDecoder
'
,
'Crf_decoding'
,
'SequenceTagging'
,
'GRUEncoderLayer'
'
Linear_chain_crf'
,
'
Crf_decoding'
,
'SequenceTagging'
,
'GRUEncoderLayer'
]
]
...
@@ -1002,7 +1002,7 @@ class DynamicDecode(Layer):
...
@@ -1002,7 +1002,7 @@ class DynamicDecode(Layer):
**
kwargs
)
**
kwargs
)
class
Transfo
merCell
(
object
):
class
Transfo
rmerCell
(
Layer
):
"""
"""
Let inputs=(trg_word, trg_pos), states=cache to make Transformer can be
Let inputs=(trg_word, trg_pos), states=cache to make Transformer can be
used as RNNCell
used as RNNCell
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录