PaddlePaddle / hapi

Commit 2bb216b7: Update seq2seq

Authored on Apr 04, 2020 by guosheng. Parent commit: f91528a9.

4 changed files, with 379 additions and 234 deletions:
seq2seq/reader.py        (+320, -190)
seq2seq/seq2seq_attn.py  (+8, -5)
seq2seq/seq2seq_base.py  (+2, -2)
seq2seq/train.py         (+49, -37)
seq2seq/reader.py  (+320, -190)

@@ -16,203 +16,333 @@ from __future__ import absolute_import

This file is rebuilt around Paddle's `Dataset`/`BatchSampler`/`DataLoader` abstractions. The legacy reader helpers are removed: `_read_words`, `read_all_line`, `_build_vocab`, `_para_file_to_ids`, `filter_len`, `raw_data`, `raw_mono_data`, and `get_data_iter` with its inner `to_pad_np` padding routine. In their place the module gains `prepare_train_input`, `pad_batch_data`, `SortType`, the `Converter`/`ComposedConverter` vocabulary mappers, the `SentenceBatchCreator`/`TokenBatchCreator`/`SampleInfo`/`MinMaxFilter` batching helpers, and the `Seq2SeqDataset`/`Seq2SeqBatchSampler` classes. Reconstructed from the inline diff, the module now begins with the shared helpers:

```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os
import io
import sys
import glob
import itertools

import numpy as np
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import BatchSampler, DataLoader, Dataset


def prepare_train_input(insts, bos_id, eos_id, pad_id):
    src, src_length = pad_batch_data(
        [inst[0] for inst in insts], pad_id=pad_id)
    trg, trg_length = pad_batch_data(
        [[bos_id] + inst[1] + [eos_id] for inst in insts], pad_id=pad_id)
    trg_length = trg_length - 1
    return src, src_length, trg[:, :-1], trg_length, trg[:, 1:, np.newaxis]


def pad_batch_data(insts, pad_id):
    """
    Pad the instances to the max sequence length in batch, and generate the
    corresponding position data and attention bias.
    """
    inst_lens = np.array([len(inst) for inst in insts], dtype="int64")
    max_len = np.max(inst_lens)
    inst_data = np.array(
        [inst + [pad_id] * (max_len - len(inst)) for inst in insts],
        dtype="int64")
    return inst_data, inst_lens


class SortType(object):
    GLOBAL = 'global'
    POOL = 'pool'
    NONE = "none"


class Converter(object):
    def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
        self._vocab = vocab
        self._beg = beg
        self._end = end
        self._unk = unk
        self._delimiter = delimiter
        self._add_beg = add_beg
        self._add_end = add_end

    def __call__(self, sentence):
        return ([self._beg] if self._add_beg else []) + [
            self._vocab.get(w, self._unk)
            for w in sentence.split(self._delimiter)
        ] + ([self._end] if self._add_end else [])


class ComposedConverter(object):
    def __init__(self, converters):
        self._converters = converters

    def __call__(self, fields):
        return [
            converter(field)
            for field, converter in zip(fields, self._converters)
        ]


class SentenceBatchCreator(object):
    def __init__(self, batch_size):
        self.batch = []
        self._batch_size = batch_size

    def append(self, info):
        self.batch.append(info)
        if len(self.batch) == self._batch_size:
            tmp = self.batch
            self.batch = []
            return tmp


class TokenBatchCreator(object):
    def __init__(self, batch_size):
        self.batch = []
        self.max_len = -1
        self._batch_size = batch_size

    def append(self, info):
        cur_len = info.max_len
        max_len = max(self.max_len, cur_len)
        if max_len * (len(self.batch) + 1) > self._batch_size:
            result = self.batch
            self.batch = [info]
            self.max_len = cur_len
            return result
        else:
            self.max_len = max_len
            self.batch.append(info)


class SampleInfo(object):
    def __init__(self, i, max_len, min_len):
        self.i = i
        self.min_len = min_len
        self.max_len = max_len


class MinMaxFilter(object):
    def __init__(self, max_len, min_len, underlying_creator):
        self._min_len = min_len
        self._max_len = max_len
        self._creator = underlying_creator

    def append(self, info):
        if info.max_len > self._max_len or info.min_len < self._min_len:
            return
        else:
            return self._creator.append(info)

    @property
    def batch(self):
        return self._creator.batch
```
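A toy walk-through of how these batching helpers compose (hypothetical lengths and a deliberately small token budget; `SampleInfo(i, max_len, min_len)` records per-sample lengths):

```python
# Five samples with hypothetical lengths 3, 2, 5, 4, 1.
infos = [SampleInfo(i, max_len=l, min_len=l) for i, l in enumerate([3, 2, 5, 4, 1])]

# Budget of roughly 8 tokens per batch, and drop anything longer than 4 tokens.
creator = MinMaxFilter(max_len=4, min_len=1,
                       underlying_creator=TokenBatchCreator(batch_size=8))

batches = []
for info in infos:
    batch = creator.append(info)   # returns a finished batch, or None
    if batch is not None:
        batches.append(batch)
if creator.batch:                  # flush the last partial batch
    batches.append(creator.batch)

print([[info.i for info in b] for b in batches])
# [[0, 1], [3, 4]]  -- the length-5 sample (index 2) was filtered out
```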
The `Seq2SeqDataset` and `Seq2SeqBatchSampler` that build on these helpers complete the new module:

```python
class Seq2SeqDataset(Dataset):
    def __init__(self,
                 src_vocab_fpath,
                 trg_vocab_fpath,
                 fpattern,
                 field_delimiter="\t",
                 token_delimiter=" ",
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
                 only_src=False,
                 trg_fpattern=None):
        # convert str to bytes, and use byte data
        # field_delimiter = field_delimiter.encode("utf8")
        # token_delimiter = token_delimiter.encode("utf8")
        # start_mark = start_mark.encode("utf8")
        # end_mark = end_mark.encode("utf8")
        # unk_mark = unk_mark.encode("utf8")
        self._src_vocab = self.load_dict(src_vocab_fpath)
        self._trg_vocab = self.load_dict(trg_vocab_fpath)
        self._bos_idx = self._src_vocab[start_mark]
        self._eos_idx = self._src_vocab[end_mark]
        self._unk_idx = self._src_vocab[unk_mark]
        self._only_src = only_src
        self._field_delimiter = field_delimiter
        self._token_delimiter = token_delimiter
        self.load_src_trg_ids(fpattern, trg_fpattern)

    def load_src_trg_ids(self, fpattern, trg_fpattern=None):
        src_converter = Converter(vocab=self._src_vocab,
                                  beg=self._bos_idx,
                                  end=self._eos_idx,
                                  unk=self._unk_idx,
                                  delimiter=self._token_delimiter,
                                  add_beg=False,
                                  add_end=False)
        trg_converter = Converter(vocab=self._trg_vocab,
                                  beg=self._bos_idx,
                                  end=self._eos_idx,
                                  unk=self._unk_idx,
                                  delimiter=self._token_delimiter,
                                  add_beg=False,
                                  add_end=False)
        converters = ComposedConverter([src_converter, trg_converter])

        self._src_seq_ids = []
        self._trg_seq_ids = []
        self._sample_infos = []
        slots = [self._src_seq_ids, self._trg_seq_ids]
        for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
            lens = []
            for field, slot in zip(converters(line), slots):
                slot.append(field)
                lens.append(len(field))
            # self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
            self._sample_infos.append(SampleInfo(i, lens[0], lens[0]))

    def _load_lines(self, fpattern, trg_fpattern=None):
        fpaths = glob.glob(fpattern)
        fpaths = sorted(fpaths)  # TODO: Add custum sort
        assert len(fpaths) > 0, "no matching file to the provided data path"

        if trg_fpattern is None:
            for fpath in fpaths:
                # with io.open(fpath, "rb") as f:
                with io.open(fpath, "r", encoding="utf8") as f:
                    for line in f:
                        fields = line.strip("\n").split(self._field_delimiter)
                        yield fields
        else:
            # separated source and target language data files
            # assume we can get aligned data by sort the two language files
            # TODO: Need more rigorous check
            trg_fpaths = glob.glob(trg_fpattern)
            trg_fpaths = sorted(trg_fpaths)
            assert len(fpaths) == len(trg_fpaths), \
                "the number of source language data files must equal \
                with that of source language"
            for fpath, trg_fpath in zip(fpaths, trg_fpaths):
                # with io.open(fpath, "rb") as f:
                # with io.open(trg_fpath, "rb") as trg_f:
                with io.open(fpath, "r", encoding="utf8") as f:
                    with io.open(trg_fpath, "r", encoding="utf8") as trg_f:
                        for line in zip(f, trg_f):
                            fields = [field.strip("\n") for field in line]
                            yield fields

    @staticmethod
    def load_dict(dict_path, reverse=False):
        word_dict = {}
        # with io.open(dict_path, "rb") as fdict:
        with io.open(dict_path, "r", encoding="utf8") as fdict:
            for idx, line in enumerate(fdict):
                if reverse:
                    word_dict[idx] = line.strip("\n")
                else:
                    word_dict[line.strip("\n")] = idx
        return word_dict

    def get_vocab_summary(self):
        return len(self._src_vocab), len(
            self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx

    def __getitem__(self, idx):
        return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
                ) if self._trg_seq_ids else self._src_seq_ids[idx]

    def __len__(self):
        return len(self._sample_infos)


class Seq2SeqBatchSampler(BatchSampler):
    def __init__(self,
                 dataset,
                 batch_size,
                 pool_size=10000,
                 sort_type=SortType.NONE,
                 min_length=0,
                 max_length=100,
                 shuffle=False,
                 shuffle_batch=False,
                 use_token_batch=False,
                 clip_last_batch=False,
                 seed=None):
        for arg, value in locals().items():
            if arg != "self":
                setattr(self, "_" + arg, value)
        self._random = np.random
        self._random.seed(seed)
        # for multi-devices
        self._nranks = ParallelEnv().nranks
        self._local_rank = ParallelEnv().local_rank
        self._device_id = ParallelEnv().dev_id

    def __iter__(self):
        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            infos = sorted(
                self._dataset._sample_infos, key=lambda x: x.max_len)
        else:
            if self._shuffle:
                infos = self._dataset._sample_infos
                self._random.shuffle(infos)
            else:
                infos = self._dataset._sample_infos

            if self._sort_type == SortType.POOL:
                reverse = True
                for i in range(0, len(infos), self._pool_size):
                    # to avoid placing short next to long sentences
                    reverse = not reverse
                    infos[i:i + self._pool_size] = sorted(
                        infos[i:i + self._pool_size],
                        key=lambda x: x.max_len,
                        reverse=reverse)

        batches = []
        batch_creator = TokenBatchCreator(
            self._batch_size
        ) if self._use_token_batch else SentenceBatchCreator(
            self._batch_size * self._nranks)
        batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                     batch_creator)

        for info in infos:
            batch = batch_creator.append(info)
            if batch is not None:
                batches.append(batch)

        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        if not self._use_token_batch:
            # when producing batches according to sequence number, to confirm
            # neighbor batches which would be feed and run parallel have similar
            # length (thus similar computational cost) after shuffle, we as take
            # them as a whole when shuffling and split here
            batches = [[
                batch[self._batch_size * i:self._batch_size * (i + 1)]
                for i in range(self._nranks)
            ] for batch in batches]
            batches = list(itertools.chain.from_iterable(batches))

        # for multi-device
        for batch_id, batch in enumerate(batches):
            if batch_id % self._nranks == self._local_rank:
                batch_indices = [info.i for info in batch]
                yield batch_indices
        if self._local_rank > len(batches) % self._nranks:
            yield batch_indices

    def __len__(self):
        if not self._use_token_batch:
            batch_number = (
                len(self._dataset) + self._batch_size * self._nranks - 1) // (
                    self._batch_size * self._nranks)
        else:
            batch_number = 100
        return batch_number
```
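To make the padding contract concrete, here is a tiny self-contained check of `prepare_train_input` (toy token ids; `pad_id=0` is used only for readability, the training script below passes `eos_id`):

```python
# Run from the seq2seq/ directory so that reader.py is importable.
from reader import prepare_train_input

# Two (source_ids, target_ids) pairs of different lengths.
insts = [([5, 6, 7], [8, 9]), ([5, 6], [8, 9, 10, 11])]
src, src_length, trg_in, trg_length, label = prepare_train_input(
    insts, bos_id=1, eos_id=2, pad_id=0)

# src        -> [[5 6 7], [5 6 0]]             padded to the batch max
# src_length -> [3 2]
# trg_in     -> [[1 8 9 2 0], [1 8 9 10 11]]   <bos> + target (+ <eos>), last step dropped
# trg_length -> [3 5]                          original target lengths + 1
# label      -> trg shifted left by one step, shape (2, 5, 1)
```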
seq2seq/seq2seq_attn.py  (+8, -5)

@@ -41,9 +41,10 @@ class AttentionLayer(Layer):
The attention scores are now computed from the raw decoder hidden state; only `encoder_output` passes through `input_proj`, and the old query projection is kept as a comment:

```diff
     def forward(self, hidden, encoder_output, encoder_padding_mask):
-        query = self.input_proj(hidden)
+        # query = self.input_proj(hidden)
         encoder_output = self.input_proj(encoder_output)
         attn_scores = layers.matmul(
-            layers.unsqueeze(query, [1]), encoder_output, transpose_y=True)
+            layers.unsqueeze(hidden, [1]), encoder_output, transpose_y=True)
         if encoder_padding_mask is not None:
             attn_scores = layers.elementwise_add(attn_scores,
                                                  encoder_padding_mask)
```

@@ -73,7 +74,9 @@ class DecoderCell(RNNCell):
The stacked `BasicLSTMCell`s now get an explicit uniform parameter initializer:

```diff
                 BasicLSTMCell(
                     input_size=input_size + hidden_size
                     if i == 0 else hidden_size,
-                    hidden_size=hidden_size)))
+                    hidden_size=hidden_size,
+                    param_attr=ParamAttr(initializer=UniformInitializer(
+                        low=-init_scale, high=init_scale)))))
         self.attention_layer = AttentionLayer(hidden_size)

     def forward(self,
```

@@ -107,8 +110,8 @@ class Decoder(Layer):
`dropout_prob` is now passed through to `DecoderCell` when building the attention decoder's RNN:

```diff
             size=[vocab_size, embed_dim],
             param_attr=ParamAttr(initializer=UniformInitializer(
                 low=-init_scale, high=init_scale)))
-        self.lstm_attention = RNN(DecoderCell(num_layers, embed_dim,
-                                              hidden_size, init_scale),
+        self.lstm_attention = RNN(DecoderCell(num_layers, embed_dim, hidden_size,
+                                              dropout_prob, init_scale),
                                   is_reverse=False,
                                   time_major=False)
         self.output_layer = Linear(
```
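For reference, what the changed lines compute is a Luong-style (general) attention score: project the encoder states once, dot them with the current decoder state, mask padded positions, then softmax. A NumPy sketch of the shapes involved (the softmax and weighted sum are the usual follow-up steps and sit outside this hunk):

```python
import numpy as np

batch, src_len, hidden_size = 2, 5, 8
hidden = np.random.rand(batch, hidden_size)            # current decoder state
encoder_output = np.random.rand(batch, src_len, hidden_size)
W = np.random.rand(hidden_size, hidden_size)           # stands in for input_proj

projected = encoder_output @ W                         # input_proj(encoder_output)
# layers.matmul(layers.unsqueeze(hidden, [1]), projected, transpose_y=True)
attn_scores = hidden[:, None, :] @ projected.transpose(0, 2, 1)  # (batch, 1, src_len)

# encoder_padding_mask typically holds large negative values at pad positions,
# so adding it pushes their softmax weight toward zero.
encoder_padding_mask = np.zeros((batch, 1, src_len))
attn_scores = attn_scores + encoder_padding_mask
attn_weights = np.exp(attn_scores) / np.exp(attn_scores).sum(-1, keepdims=True)
context = attn_weights @ encoder_output                # (batch, 1, hidden_size)
```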
seq2seq/seq2seq_base.py  (+2, -2)

`dropout_prob` is now passed through to the encoder and decoder cells of the base (no-attention) model, which previously received only `init_scale` positionally:

@@ -86,7 +86,7 @@ class Encoder(Layer):
```diff
             param_attr=ParamAttr(initializer=UniformInitializer(
                 low=-init_scale, high=init_scale)))
         self.stack_lstm = RNN(EncoderCell(num_layers, embed_dim, hidden_size,
-                                          init_scale),
+                                          dropout_prob, init_scale),
                               is_reverse=False,
                               time_major=False)
```

@@ -114,7 +114,7 @@ class Decoder(Layer):
```diff
             param_attr=ParamAttr(initializer=UniformInitializer(
                 low=-init_scale, high=init_scale)))
         self.stack_lstm = RNN(DecoderCell(num_layers, embed_dim, hidden_size,
-                                          init_scale),
+                                          dropout_prob, init_scale),
                               is_reverse=False,
                               time_major=False)
         self.output_layer = Linear(
```
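Why one extra argument matters here: `EncoderCell`/`DecoderCell` take `dropout_prob` before `init_scale` (their signatures are unchanged by this commit), so the old four-argument call bound `init_scale` to the dropout slot and the configured dropout never reached the LSTM stack. A stand-alone illustration of that positional shift (the function below is a simplified stand-in, not the real cell constructor; the default values are assumptions):

```python
# Simplified stand-in for the cell constructor's trailing arguments
# (assumption: dropout_prob precedes init_scale, both optional).
def make_cell(num_layers, input_size, hidden_size, dropout_prob=0., init_scale=0.1):
    return {"dropout": dropout_prob, "init": init_scale}

# Before this commit: init_scale lands in the dropout_prob slot.
print(make_cell(2, 512, 512, 0.1))       # {'dropout': 0.1, 'init': 0.1}
# After this commit: both values go where they belong.
print(make_cell(2, 512, 512, 0.2, 0.1))  # {'dropout': 0.2, 'init': 0.1}
```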
seq2seq/train.py  (+49, -37)

@@ -17,8 +17,7 @@ import os
The import block at the top of the script is reworked (net one line shorter); it includes the `sys.path` tweak that lets the shared hapi modules resolve when running from `seq2seq/`, and the `functools.partial` import used by the `collate_fn` below. The lines appearing in this hunk:

```python
import six
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import time
import contextlib
import random
from functools import partial

import numpy as np
```

@@ -34,16 +33,17 @@ from seq2seq_base import BaseModel, CrossEntropyCriterion
The placeholder `PPL` metric is dropped in favor of the real reader utilities, and dygraph is only enabled when `--eager_run` is set:

```diff
 from seq2seq_attn import AttentionModel
 from model import Input, set_device
 from callbacks import ProgBarLogger
-from metrics import Metric
-
-
-class PPL(Metric):
-    pass
+from reader import Seq2SeqDataset, Seq2SeqBatchSampler, SortType, prepare_train_input


 def do_train(args):
     device = set_device("gpu" if args.use_gpu else "cpu")
-    fluid.enable_dygraph(device)  #if args.eager_run else None
+    fluid.enable_dygraph(device) if args.eager_run else None

     if args.enable_ce:
         fluid.default_main_program().random_seed = 102
         fluid.default_startup_program().random_seed = 102
         args.shuffle = False

     # define model
     inputs = [
```

@@ -58,6 +58,45 @@ def do_train(args):
After the `inputs`/`labels` declarations, a real data pipeline is added: for the train prefix (and the eval prefix, if given) a `Seq2SeqDataset` plus `Seq2SeqBatchSampler` is wrapped in a `DataLoader` whose `collate_fn` pads each batch with `prepare_train_input`:

```python
    labels = [
        Input([None, None, 1], "int64", name="label"),
    ]

    # def dataloader
    data_loaders = [None, None]
    data_prefixes = [args.train_data_prefix, args.eval_data_prefix
                     ] if args.eval_data_prefix else [args.train_data_prefix]
    for i, data_prefix in enumerate(data_prefixes):
        dataset = Seq2SeqDataset(fpattern=data_prefix + "." + args.src_lang,
                                 trg_fpattern=data_prefix + "." + args.tar_lang,
                                 src_vocab_fpath=args.vocab_prefix + "." +
                                 args.src_lang,
                                 trg_vocab_fpath=args.vocab_prefix + "." +
                                 args.tar_lang,
                                 token_delimiter=None,
                                 start_mark="<s>",
                                 end_mark="</s>",
                                 unk_mark="<unk>")
        (args.src_vocab_size, args.trg_vocab_size, bos_id, eos_id,
         unk_id) = dataset.get_vocab_summary()
        batch_sampler = Seq2SeqBatchSampler(dataset=dataset,
                                            use_token_batch=False,
                                            batch_size=args.batch_size,
                                            pool_size=args.batch_size * 20,
                                            sort_type=SortType.POOL,
                                            shuffle=args.shuffle)
        data_loader = DataLoader(dataset=dataset,
                                 batch_sampler=batch_sampler,
                                 places=device,
                                 feed_list=None if fluid.in_dygraph_mode()
                                 else [x.forward() for x in inputs + labels],
                                 collate_fn=partial(prepare_train_input,
                                                    bos_id=bos_id,
                                                    eos_id=eos_id,
                                                    pad_id=eos_id),
                                 num_workers=0,
                                 return_list=True)
        data_loaders[i] = data_loader
    train_loader, eval_loader = data_loaders

    model = AttentionModel(args.src_vocab_size, args.tar_vocab_size,
                           args.hidden_size, args.hidden_size, args.num_layers,
                           args.dropout)
```

@@ -69,39 +108,12 @@ def do_train(args):
The throwaway smoke test (a `random_generator` that yielded random `src`/`src_length`/`trg`/`trg_length`/`label` batches, fed them to `model.fit(train_data=random_generator, log_freq=1)`, and then called `exit(0)`) is removed, along with the old `data_files = [args.training_file, args.validation_file] ...` loader setup that the `Seq2SeqDataset` pipeline above replaces. `model.fit` now evaluates on the real eval loader:

```diff
             CrossEntropyCriterion(),
             inputs=inputs,
             labels=labels)

     model.fit(train_data=train_loader,
-              eval_data=None,
+              eval_data=eval_loader,
               epochs=1,
               eval_freq=1,
               save_freq=1,
               log_freq=1,
               verbose=2)
```
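Because `prepare_train_input` is bound as the `collate_fn`, every batch the loader yields already has the `(src, src_length, trg, trg_length, label)` layout that the declared `inputs` and `labels` expect. A quick, hypothetical sanity check (not part of the commit) that the wiring lines up before launching a full run:

```python
# Hypothetical check: pull one batch from the configured train_loader and
# compare it against the Input/label declarations above.
for src, src_length, trg, trg_length, label in train_loader:
    assert src.shape[0] == trg.shape[0] == label.shape[0]  # one batch size everywhere
    assert label.shape[-1] == 1            # matches Input([None, None, 1], "int64", name="label")
    assert trg.shape[1] == label.shape[1]  # teacher-forcing inputs align with the shifted labels
    break
```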