Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
c34bb5f1
M
models
项目概览
PaddlePaddle
/
models
1 年多 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c34bb5f1
编写于
8月 15, 2018
作者:
Y
Yu Yang
提交者:
GitHub
8月 15, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1060 from reyoung/speed_up_transformer_python_reader
Speed up NTM reader
上级
f36588dc
6fc33ac5
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
178 addition
and
213 deletion
+178
-213
.gitignore
.gitignore
+6
-0
fluid/neural_machine_translation/transformer/infer.py
fluid/neural_machine_translation/transformer/infer.py
+0
-1
fluid/neural_machine_translation/transformer/model.py
fluid/neural_machine_translation/transformer/model.py
+1
-0
fluid/neural_machine_translation/transformer/reader.py
fluid/neural_machine_translation/transformer/reader.py
+166
-206
fluid/neural_machine_translation/transformer/train.py
fluid/neural_machine_translation/transformer/train.py
+5
-6
未找到文件。
.gitignore
浏览文件 @
c34bb5f1
.DS_Store
*.pyc
.*~
fluid/neural_machine_translation/transformer/deps
fluid/neural_machine_translation/transformer/train.data
fluid/neural_machine_translation/transformer/train.pkl
fluid/neural_machine_translation/transformer/train.sh
fluid/neural_machine_translation/transformer/train.tok.clean.bpe.32000.en-de
fluid/neural_machine_translation/transformer/vocab.bpe.32000.refined
fluid/neural_machine_translation/transformer/infer.py
浏览文件 @
c34bb5f1
...
...
@@ -553,7 +553,6 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece):
def
infer
(
args
,
inferencer
=
fast_infer
):
place
=
fluid
.
CUDAPlace
(
0
)
if
InferTaskConfig
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
test_data
=
reader
.
DataReader
(
src_vocab_fpath
=
args
.
src_vocab_fpath
,
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
...
...
fluid/neural_machine_translation/transformer/model.py
浏览文件 @
c34bb5f1
...
...
@@ -474,6 +474,7 @@ def transformer(
sum_cost
=
layers
.
reduce_sum
(
weighted_cost
)
token_num
=
layers
.
reduce_sum
(
weights
)
avg_cost
=
sum_cost
/
token_num
avg_cost
.
stop_gradient
=
True
return
sum_cost
,
avg_cost
,
predict
,
token_num
...
...
fluid/neural_machine_translation/transformer/reader.py
浏览文件 @
c34bb5f1
import
os
import
tarfile
import
glob
import
os
import
random
import
tarfile
import
cPickle
class
SortType
(
object
):
...
...
@@ -11,54 +11,86 @@ class SortType(object):
NONE
=
"none"
class
EndEpoch
():
pass
class
Converter
(
object
):
def
__init__
(
self
,
vocab
,
beg
,
end
,
unk
,
delimiter
):
self
.
_vocab
=
vocab
self
.
_beg
=
beg
self
.
_end
=
end
self
.
_unk
=
unk
self
.
_delimiter
=
delimiter
def
__call__
(
self
,
sentence
):
return
[
self
.
_beg
]
+
[
self
.
_vocab
.
get
(
w
,
self
.
_unk
)
for
w
in
sentence
.
split
(
self
.
_delimiter
)
]
+
[
self
.
_end
]
class
Pool
(
object
):
def
__init__
(
self
,
sample_generator
,
pool_size
,
sort
):
self
.
_pool_size
=
pool_size
self
.
_pool
=
[]
self
.
_sample_generator
=
sample_generator
()
self
.
_end
=
False
self
.
_sort
=
sort
def
_fill
(
self
):
while
len
(
self
.
_pool
)
<
self
.
_pool_size
and
not
self
.
_end
:
try
:
sample
=
self
.
_sample_generator
.
next
()
self
.
_pool
.
append
(
sample
)
except
StopIteration
as
e
:
self
.
_end
=
True
break
if
self
.
_sort
:
self
.
_pool
.
sort
(
key
=
lambda
sample
:
max
(
len
(
sample
[
0
]),
len
(
sample
[
1
]))
\
if
len
(
sample
)
>
1
else
len
(
sample
[
0
])
)
if
self
.
_end
and
len
(
self
.
_pool
)
<
self
.
_pool_size
:
self
.
_pool
.
append
(
EndEpoch
())
def
push_back
(
self
,
samples
):
if
len
(
self
.
_pool
)
!=
0
:
raise
Exception
(
"Pool should be empty."
)
if
len
(
samples
)
>=
self
.
_pool_size
:
raise
Exception
(
"Capacity of pool should be greater than a batch. "
"Please enlarge `pool_size`."
)
for
sample
in
samples
:
self
.
_pool
.
append
(
sample
)
self
.
_fill
()
def
next
(
self
,
look
=
False
):
if
len
(
self
.
_pool
)
==
0
:
return
None
class
ComposedConverter
(
object
):
def
__init__
(
self
,
converters
):
self
.
_converters
=
converters
def
__call__
(
self
,
parallel_sentence
):
return
[
self
.
_converters
[
i
](
parallel_sentence
[
i
])
for
i
in
range
(
len
(
self
.
_converters
))
]
class
SentenceBatchCreator
(
object
):
def
__init__
(
self
,
batch_size
):
self
.
batch
=
[]
self
.
_batch_size
=
batch_size
def
append
(
self
,
info
):
self
.
batch
.
append
(
info
)
if
len
(
self
.
batch
)
==
self
.
_batch_size
:
tmp
=
self
.
batch
self
.
batch
=
[]
return
tmp
class
TokenBatchCreator
(
object
):
def
__init__
(
self
,
batch_size
):
self
.
batch
=
[]
self
.
max_len
=
-
1
self
.
_batch_size
=
batch_size
def
append
(
self
,
info
):
cur_len
=
info
.
max_len
max_len
=
max
(
self
.
max_len
,
cur_len
)
if
max_len
*
(
len
(
self
.
batch
)
+
1
)
>
self
.
_batch_size
:
result
=
self
.
batch
self
.
batch
=
[
info
]
self
.
max_len
=
cur_len
return
result
else
:
self
.
max_len
=
max_len
self
.
batch
.
append
(
info
)
class
SampleInfo
(
object
):
def
__init__
(
self
,
i
,
max_len
,
min_len
):
self
.
i
=
i
self
.
min_len
=
min_len
self
.
max_len
=
max_len
class
MinMaxFilter
(
object
):
def
__init__
(
self
,
max_len
,
min_len
,
underlying_creator
):
self
.
_min_len
=
min_len
self
.
_max_len
=
max_len
self
.
_creator
=
underlying_creator
def
append
(
self
,
info
):
if
info
.
max_len
>
self
.
_max_len
or
info
.
min_len
<
self
.
_min_len
:
return
else
:
return
self
.
_pool
[
0
]
if
look
else
self
.
_pool
.
pop
(
0
)
return
self
.
_creator
.
append
(
info
)
@
property
def
batch
(
self
):
return
self
.
_creator
.
batch
class
DataReader
(
object
):
...
...
@@ -140,7 +172,7 @@ class DataReader(object):
fpattern
,
batch_size
,
pool_size
,
sort_type
=
SortType
.
NONE
,
sort_type
=
SortType
.
GLOBAL
,
clip_last_batch
=
True
,
tar_fname
=
None
,
min_length
=
0
,
...
...
@@ -170,92 +202,68 @@ class DataReader(object):
self
.
_max_length
=
max_length
self
.
_field_delimiter
=
field_delimiter
self
.
_token_delimiter
=
token_delimiter
self
.
_epoch_batches
=
[]
src_seq_words
,
trg_seq_words
=
self
.
_load_data
(
fpattern
,
tar_fname
)
self
.
_src_seq_ids
=
[[
self
.
_src_vocab
.
get
(
word
,
self
.
_src_vocab
.
get
(
unk_mark
))
for
word
in
([
start_mark
]
+
src_seq
+
[
end_mark
])
]
for
src_seq
in
src_seq_words
]
self
.
_sample_count
=
len
(
self
.
_src_seq_ids
)
self
.
load_src_trg_ids
(
end_mark
,
fpattern
,
start_mark
,
tar_fname
,
unk_mark
)
self
.
_random
=
random
.
Random
(
x
=
seed
)
def
load_src_trg_ids
(
self
,
end_mark
,
fpattern
,
start_mark
,
tar_fname
,
unk_mark
):
converters
=
[
Converter
(
vocab
=
self
.
_src_vocab
,
beg
=
self
.
_src_vocab
[
start_mark
],
end
=
self
.
_src_vocab
[
end_mark
],
unk
=
self
.
_src_vocab
[
unk_mark
],
delimiter
=
self
.
_token_delimiter
)
]
if
not
self
.
_only_src
:
self
.
_trg_seq_ids
=
[[
self
.
_trg_vocab
.
get
(
word
,
self
.
_trg_vocab
.
get
(
unk_mark
))
for
word
in
([
start_mark
]
+
trg_seq
+
[
end_mark
])
]
for
trg_seq
in
trg_seq_words
]
if
len
(
self
.
_trg_seq_ids
)
!=
self
.
_sample_count
:
raise
Exception
(
"Inconsistent sample count between "
"source sequences and target sequences."
)
else
:
self
.
_trg_seq_ids
=
None
self
.
_sample_idxs
=
[
i
for
i
in
xrange
(
self
.
_sample_count
)]
self
.
_sorted
=
False
random
.
seed
(
seed
)
def
_parse_file
(
self
,
f_obj
):
src_seq_words
=
[]
trg_seq_words
=
[]
for
line
in
f_obj
:
fields
=
line
.
strip
().
split
(
self
.
_field_delimiter
)
if
(
not
self
.
_only_src
and
len
(
fields
)
!=
2
)
or
(
self
.
_only_src
and
len
(
fields
)
!=
1
):
continue
sample_words
=
[]
is_valid_sample
=
True
max_len
=
-
1
for
i
,
seq
in
enumerate
(
fields
):
seq_words
=
seq
.
split
(
self
.
_token_delimiter
)
max_len
=
max
(
max_len
,
len
(
seq_words
))
if
len
(
seq_words
)
==
0
or
\
len
(
seq_words
)
<
self
.
_min_length
or
\
len
(
seq_words
)
>
self
.
_max_length
or
\
(
self
.
_use_token_batch
and
max_len
>
self
.
_batch_size
):
is_valid_sample
=
False
break
sample_words
.
append
(
seq_words
)
if
not
is_valid_sample
:
continue
src_seq_words
.
append
(
sample_words
[
0
])
converters
.
append
(
Converter
(
vocab
=
self
.
_trg_vocab
,
beg
=
self
.
_trg_vocab
[
start_mark
],
end
=
self
.
_trg_vocab
[
end_mark
],
unk
=
self
.
_trg_vocab
[
unk_mark
],
delimiter
=
self
.
_token_delimiter
))
converters
=
ComposedConverter
(
converters
)
self
.
_src_seq_ids
=
[]
self
.
_trg_seq_ids
=
None
if
self
.
_only_src
else
[]
self
.
_sample_infos
=
[]
for
i
,
line
in
enumerate
(
self
.
_load_lines
(
fpattern
,
tar_fname
)):
src_trg_ids
=
converters
(
line
)
self
.
_src_seq_ids
.
append
(
src_trg_ids
[
0
])
lens
=
[
len
(
src_trg_ids
[
0
])]
if
not
self
.
_only_src
:
trg_seq_words
.
append
(
sample_words
[
1
])
self
.
_trg_seq_ids
.
append
(
src_trg_ids
[
1
])
lens
.
append
(
len
(
src_trg_ids
[
1
]))
self
.
_sample_infos
.
append
(
SampleInfo
(
i
,
max
(
lens
),
min
(
lens
)))
return
(
src_seq_words
,
trg_seq_words
)
def
_load_data
(
self
,
fpattern
,
tar_fname
):
def
_load_lines
(
self
,
fpattern
,
tar_fname
):
fpaths
=
glob
.
glob
(
fpattern
)
src_seq_words
=
[]
trg_seq_words
=
[]
if
len
(
fpaths
)
==
1
and
tarfile
.
is_tarfile
(
fpaths
[
0
]):
if
tar_fname
is
None
:
raise
Exception
(
"If tar file provided, please set tar_fname."
)
f
=
tarfile
.
open
(
fpaths
[
0
],
'r'
)
part_file_data
=
self
.
_parse_file
(
f
.
extractfile
(
tar_fname
))
src_seq_words
=
part_file_data
[
0
]
trg_seq_words
=
part_file_data
[
1
]
f
=
tarfile
.
open
(
fpaths
[
0
],
"r"
)
for
line
in
f
.
extractfile
(
tar_fname
):
fields
=
line
.
strip
(
"
\n
"
).
split
(
self
.
_field_delimiter
)
if
(
not
self
.
_only_src
and
len
(
fields
)
==
2
)
or
(
self
.
_only_src
and
len
(
fields
)
==
1
):
yield
fields
else
:
for
fpath
in
fpaths
:
if
not
os
.
path
.
isfile
(
fpath
):
raise
IOError
(
"Invalid file: %s"
%
fpath
)
part_file_data
=
self
.
_parse_file
(
open
(
fpath
,
'r'
))
src_seq_words
.
extend
(
part_file_data
[
0
])
trg_seq_words
.
extend
(
part_file_data
[
1
])
return
src_seq_words
,
trg_seq_words
with
open
(
fpath
,
"r"
)
as
f
:
for
line
in
f
:
fields
=
line
.
strip
(
"
\n
"
).
split
(
self
.
_field_delimiter
)
if
(
not
self
.
_only_src
and
len
(
fields
)
==
2
)
or
(
self
.
_only_src
and
len
(
fields
)
==
1
):
yield
fields
@
staticmethod
def
load_dict
(
dict_path
,
reverse
=
False
):
...
...
@@ -263,100 +271,52 @@ class DataReader(object):
with
open
(
dict_path
,
"r"
)
as
fdict
:
for
idx
,
line
in
enumerate
(
fdict
):
if
reverse
:
word_dict
[
idx
]
=
line
.
strip
(
'
\n
'
)
word_dict
[
idx
]
=
line
.
strip
(
"
\n
"
)
else
:
word_dict
[
line
.
strip
(
'
\n
'
)]
=
idx
word_dict
[
line
.
strip
(
"
\n
"
)]
=
idx
return
word_dict
def
_sample_generator
(
self
):
def
batch_generator
(
self
):
# global sort or global shuffle
if
self
.
_sort_type
==
SortType
.
GLOBAL
:
if
not
self
.
_sorted
:
self
.
_sample_idxs
.
sort
(
key
=
lambda
idx
:
max
(
len
(
self
.
_src_seq_ids
[
idx
]),
len
(
self
.
_trg_seq_ids
[
idx
]
if
not
self
.
_only_src
else
0
))
)
self
.
_sorted
=
True
elif
self
.
_shuffle
:
random
.
shuffle
(
self
.
_sample_idxs
)
for
sample_idx
in
self
.
_sample_idxs
:
if
self
.
_only_src
:
yield
(
self
.
_src_seq_ids
[
sample_idx
],
)
infos
=
sorted
(
self
.
_sample_infos
,
key
=
lambda
x
:
x
.
max_len
,
reverse
=
True
)
else
:
if
self
.
_shuffle
:
infos
=
self
.
_sample_infos
self
.
_random
.
shuffle
(
infos
)
else
:
yield
(
self
.
_src_seq_ids
[
sample_idx
],
self
.
_trg_seq_ids
[
sample_idx
][:
-
1
],
self
.
_trg_seq_ids
[
sample_idx
][
1
:])
infos
=
self
.
_sample_infos
def
batch_generator
(
self
):
pool
=
Pool
(
self
.
_sample_generator
,
self
.
_pool_size
,
True
if
self
.
_sort_type
==
SortType
.
POOL
else
False
)
def
next_batch
():
batch_data
=
[]
max_len
=
-
1
batch_max_seq_len
=
-
1
if
self
.
_sort_type
==
SortType
.
POOL
:
for
i
in
range
(
0
,
len
(
infos
),
self
.
_pool_size
):
infos
[
i
:
i
+
self
.
_pool_size
]
=
sorted
(
infos
[
i
:
i
+
self
.
_pool_size
],
key
=
lambda
x
:
x
.
max_len
)
while
True
:
sample
=
pool
.
next
(
look
=
True
)
# concat batch
batches
=
[]
batch_creator
=
TokenBatchCreator
(
self
.
_batch_size
)
if
self
.
_use_token_batch
else
SentenceBatchCreator
(
self
.
_batch_size
)
batch_creator
=
MinMaxFilter
(
self
.
_max_length
,
self
.
_min_length
,
batch_creator
)
if
sample
is
None
:
pool
.
push_back
(
batch_data
)
batch_data
=
[]
continue
for
info
in
infos
:
batch
=
batch_creator
.
append
(
info
)
if
batch
is
not
None
:
batches
.
append
(
batch
)
if
isinstance
(
sample
,
EndEpoch
)
:
return
batch_data
,
batch_max_seq_len
,
True
if
not
self
.
_clip_last_batch
and
len
(
batch_creator
.
batch
)
!=
0
:
batches
.
append
(
batch_creator
.
batch
)
max_len
=
max
(
max_len
,
len
(
sample
[
0
]))
if
self
.
_shuffle_batch
:
self
.
_random
.
shuffle
(
batches
)
if
not
self
.
_only_src
:
max_len
=
max
(
max_len
,
len
(
sample
[
1
]))
for
batch
in
batches
:
batch_ids
=
[
info
.
i
for
info
in
batch
]
if
self
.
_use_token_batch
:
if
max_len
*
(
len
(
batch_data
)
+
1
)
<
self
.
_batch_size
:
batch_max_seq_len
=
max_len
batch_data
.
append
(
pool
.
next
())
else
:
return
batch_data
,
batch_max_seq_len
,
False
else
:
if
len
(
batch_data
)
<
self
.
_batch_size
:
batch_max_seq_len
=
max_len
batch_data
.
append
(
pool
.
next
())
else
:
return
batch_data
,
batch_max_seq_len
,
False
if
not
self
.
_shuffle_batch
:
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
while
not
last_batch
:
yield
batch_data
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
batch_size
=
len
(
batch_data
)
if
self
.
_use_token_batch
:
batch_size
*=
batch_max_seq_len
if
(
not
self
.
_clip_last_batch
and
len
(
batch_data
)
>
0
)
\
or
(
batch_size
==
self
.
_batch_size
):
yield
batch_data
else
:
# should re-generate batches
if
self
.
_sort_type
==
SortType
.
POOL
\
or
len
(
self
.
_epoch_batches
)
==
0
:
self
.
_epoch_batches
=
[]
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
while
not
last_batch
:
self
.
_epoch_batches
.
append
(
batch_data
)
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
batch_size
=
len
(
batch_data
)
if
self
.
_use_token_batch
:
batch_size
*=
batch_max_seq_len
if
(
not
self
.
_clip_last_batch
and
len
(
batch_data
)
>
0
)
\
or
(
batch_size
==
self
.
_batch_size
):
self
.
_epoch_batches
.
append
(
batch_data
)
random
.
shuffle
(
self
.
_epoch_batches
)
for
batch_data
in
self
.
_epoch_batches
:
yield
batch_data
if
self
.
_only_src
:
yield
[[
self
.
_src_seq_ids
[
idx
]]
for
idx
in
batch_ids
]
else
:
yield
[(
self
.
_src_seq_ids
[
idx
],
self
.
_trg_seq_ids
[
idx
][:
-
1
],
self
.
_trg_seq_ids
[
idx
][
1
:])
for
idx
in
batch_ids
]
fluid/neural_machine_translation/transformer/train.py
浏览文件 @
c34bb5f1
import
os
import
time
import
argparse
import
ast
import
numpy
as
np
import
multiprocessing
import
os
import
time
from
functools
import
partial
import
paddle
import
numpy
as
np
import
paddle.fluid
as
fluid
import
reader
from
config
import
*
from
model
import
transformer
,
position_encoding_init
from
optim
import
LearningRateScheduler
from
config
import
*
import
reader
def
parse_args
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录