Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
dcdb6a00
M
models
项目概览
PaddlePaddle
/
models
大约 1 年 前同步成功
通知
222
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dcdb6a00
编写于
7月 20, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Speed up NTM reader
* Load Data --> 64 sec * Shuffle/Batch --> 14 sec
上级
0b48d785
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
156 addition
and
212 deletion
+156
-212
fluid/neural_machine_translation/transformer/infer.py
fluid/neural_machine_translation/transformer/infer.py
+0
-1
fluid/neural_machine_translation/transformer/model.py
fluid/neural_machine_translation/transformer/model.py
+1
-0
fluid/neural_machine_translation/transformer/reader.py
fluid/neural_machine_translation/transformer/reader.py
+150
-205
fluid/neural_machine_translation/transformer/train.py
fluid/neural_machine_translation/transformer/train.py
+5
-6
未找到文件。
fluid/neural_machine_translation/transformer/infer.py
浏览文件 @
dcdb6a00
...
...
@@ -529,7 +529,6 @@ def fast_infer(test_data, trg_idx2word):
def
infer
(
args
,
inferencer
=
fast_infer
):
place
=
fluid
.
CUDAPlace
(
0
)
if
InferTaskConfig
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
test_data
=
reader
.
DataReader
(
src_vocab_fpath
=
args
.
src_vocab_fpath
,
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
...
...
fluid/neural_machine_translation/transformer/model.py
浏览文件 @
dcdb6a00
...
...
@@ -466,6 +466,7 @@ def transformer(
sum_cost
=
layers
.
reduce_sum
(
weighted_cost
)
token_num
=
layers
.
reduce_sum
(
weights
)
avg_cost
=
sum_cost
/
token_num
avg_cost
.
stop_gradient
=
True
return
sum_cost
,
avg_cost
,
predict
,
token_num
...
...
fluid/neural_machine_translation/transformer/reader.py
浏览文件 @
dcdb6a00
import
os
import
tarfile
import
glob
import
os
import
random
import
tarfile
import
time
class
SortType
(
object
):
...
...
@@ -11,54 +11,84 @@ class SortType(object):
NONE
=
"none"
class
EndEpoch
():
pass
class
Converter
(
object
):
def
__init__
(
self
,
vocab
,
beg
,
end
,
unk
):
self
.
_vocab
=
vocab
self
.
_beg
=
beg
self
.
_end
=
end
self
.
_unk
=
unk
def
__call__
(
self
,
sentence
):
return
[
self
.
_beg
]
+
[
self
.
_vocab
.
get
(
w
,
self
.
_unk
)
for
w
in
sentence
.
split
()
]
+
[
self
.
_end
]
class
Pool
(
object
):
def
__init__
(
self
,
sample_generator
,
pool_size
,
sort
):
self
.
_pool_size
=
pool_size
self
.
_pool
=
[]
self
.
_sample_generator
=
sample_generator
()
self
.
_end
=
False
self
.
_sort
=
sort
def
_fill
(
self
):
while
len
(
self
.
_pool
)
<
self
.
_pool_size
and
not
self
.
_end
:
try
:
sample
=
self
.
_sample_generator
.
next
()
self
.
_pool
.
append
(
sample
)
except
StopIteration
as
e
:
self
.
_end
=
True
break
if
self
.
_sort
:
self
.
_pool
.
sort
(
key
=
lambda
sample
:
max
(
len
(
sample
[
0
]),
len
(
sample
[
1
]))
\
if
len
(
sample
)
>
1
else
len
(
sample
[
0
])
)
if
self
.
_end
and
len
(
self
.
_pool
)
<
self
.
_pool_size
:
self
.
_pool
.
append
(
EndEpoch
())
def
push_back
(
self
,
samples
):
if
len
(
self
.
_pool
)
!=
0
:
raise
Exception
(
"Pool should be empty."
)
if
len
(
samples
)
>=
self
.
_pool_size
:
raise
Exception
(
"Capacity of pool should be greater than a batch. "
"Please enlarge `pool_size`."
)
for
sample
in
samples
:
self
.
_pool
.
append
(
sample
)
self
.
_fill
()
def
next
(
self
,
look
=
False
):
if
len
(
self
.
_pool
)
==
0
:
return
None
class
ComposedConverter
(
object
):
def
__init__
(
self
,
converters
):
self
.
_converters
=
converters
def
__call__
(
self
,
parallel_sentence
):
return
[
self
.
_converters
[
i
](
parallel_sentence
[
i
])
for
i
in
range
(
len
(
self
.
_converters
))
]
class
SentenceBatchCreator
(
object
):
def
__init__
(
self
,
batch_size
):
self
.
batch
=
[]
self
.
_batch_size
=
batch_size
def
append
(
self
,
info
):
self
.
batch
.
append
(
info
)
if
len
(
self
.
batch
)
==
self
.
_batch_size
:
tmp
=
self
.
batch
self
.
batch
=
[]
return
tmp
class
TokenBatchCreator
(
object
):
def
__init__
(
self
,
batch_size
):
self
.
batch
=
[]
self
.
max_len
=
-
1
self
.
_batch_size
=
batch_size
def
append
(
self
,
info
):
cur_len
=
info
.
max_len
max_len
=
max
(
self
.
max_len
,
cur_len
)
if
max_len
*
(
len
(
self
.
batch
)
+
1
)
>
self
.
_batch_size
:
result
=
self
.
batch
self
.
batch
=
[
info
]
self
.
max_len
=
cur_len
return
result
else
:
self
.
max_len
=
max_len
self
.
batch
.
append
(
info
)
class
SampleInfo
(
object
):
def
__init__
(
self
,
i
,
max_len
,
min_len
):
self
.
i
=
i
self
.
min_len
=
min_len
self
.
max_len
=
max_len
class
MinMaxFilter
(
object
):
def
__init__
(
self
,
max_len
,
min_len
,
underlying_creator
):
self
.
_min_len
=
min_len
self
.
_max_len
=
max_len
self
.
_creator
=
underlying_creator
def
append
(
self
,
info
):
if
info
.
max_len
>
self
.
_max_len
or
info
.
min_len
<
self
.
_min_len
:
return
else
:
return
self
.
_pool
[
0
]
if
look
else
self
.
_pool
.
pop
(
0
)
return
self
.
_creator
.
append
(
info
)
@
property
def
batch
(
self
):
return
self
.
_creator
.
batch
class
DataReader
(
object
):
...
...
@@ -137,7 +167,7 @@ class DataReader(object):
fpattern
,
batch_size
,
pool_size
,
sort_type
=
SortType
.
NONE
,
sort_type
=
SortType
.
GLOBAL
,
clip_last_batch
=
True
,
tar_fname
=
None
,
min_length
=
0
,
...
...
@@ -165,92 +195,61 @@ class DataReader(object):
self
.
_min_length
=
min_length
self
.
_max_length
=
max_length
self
.
_delimiter
=
delimiter
self
.
_epoch_batches
=
[]
src_seq_words
,
trg_seq_words
=
self
.
_load_data
(
fpattern
,
tar_fname
)
self
.
_src_seq_ids
=
[[
self
.
_src_vocab
.
get
(
word
,
self
.
_src_vocab
.
get
(
unk_mark
))
for
word
in
([
start_mark
]
+
src_seq
+
[
end_mark
])
]
for
src_seq
in
src_seq_words
]
self
.
_sample_count
=
len
(
self
.
_src_seq_ids
)
self
.
load_src_trg_ids
(
end_mark
,
fpattern
,
start_mark
,
tar_fname
,
unk_mark
)
self
.
_random
=
random
.
Random
(
x
=
seed
)
def
load_src_trg_ids
(
self
,
end_mark
,
fpattern
,
start_mark
,
tar_fname
,
unk_mark
):
converters
=
[
Converter
(
vocab
=
self
.
_src_vocab
,
beg
=
self
.
_src_vocab
[
start_mark
],
end
=
self
.
_src_vocab
[
end_mark
],
unk
=
self
.
_src_vocab
[
unk_mark
])
]
if
not
self
.
_only_src
:
self
.
_trg_seq_ids
=
[[
self
.
_trg_vocab
.
get
(
word
,
self
.
_trg_vocab
.
get
(
unk_mark
))
for
word
in
([
start_mark
]
+
trg_seq
+
[
end_mark
])
]
for
trg_seq
in
trg_seq_words
]
if
len
(
self
.
_trg_seq_ids
)
!=
self
.
_sample_count
:
raise
Exception
(
"Inconsistent sample count between "
"source sequences and target sequences."
)
else
:
self
.
_trg_seq_ids
=
None
self
.
_sample_idxs
=
[
i
for
i
in
xrange
(
self
.
_sample_count
)]
self
.
_sorted
=
False
random
.
seed
(
seed
)
def
_parse_file
(
self
,
f_obj
):
src_seq_words
=
[]
trg_seq_words
=
[]
for
line
in
f_obj
:
fields
=
line
.
strip
().
split
(
self
.
_delimiter
)
if
(
not
self
.
_only_src
and
len
(
fields
)
!=
2
)
or
(
self
.
_only_src
and
len
(
fields
)
!=
1
):
continue
sample_words
=
[]
is_valid_sample
=
True
max_len
=
-
1
for
i
,
seq
in
enumerate
(
fields
):
seq_words
=
seq
.
split
()
max_len
=
max
(
max_len
,
len
(
seq_words
))
if
len
(
seq_words
)
==
0
or
\
len
(
seq_words
)
<
self
.
_min_length
or
\
len
(
seq_words
)
>
self
.
_max_length
or
\
(
self
.
_use_token_batch
and
max_len
>
self
.
_batch_size
):
is_valid_sample
=
False
break
sample_words
.
append
(
seq_words
)
if
not
is_valid_sample
:
continue
src_seq_words
.
append
(
sample_words
[
0
])
converters
.
append
(
Converter
(
vocab
=
self
.
_trg_vocab
,
beg
=
self
.
_trg_vocab
[
start_mark
],
end
=
self
.
_trg_vocab
[
end_mark
],
unk
=
self
.
_trg_vocab
[
unk_mark
]))
converters
=
ComposedConverter
(
converters
)
self
.
_src_seq_ids
=
[]
self
.
_trg_seq_ids
=
None
if
self
.
_only_src
else
[]
self
.
_sample_infos
=
[]
for
i
,
line
in
enumerate
(
self
.
_load_lines
(
fpattern
,
tar_fname
)):
src_trg_ids
=
converters
(
line
)
self
.
_src_seq_ids
.
append
(
src_trg_ids
[
0
])
lens
=
[
len
(
src_trg_ids
[
0
])]
if
not
self
.
_only_src
:
trg_seq_words
.
append
(
sample_wor
ds
[
1
])
return
(
src_seq_words
,
trg_seq_words
)
self
.
_trg_seq_ids
.
append
(
src_trg_i
ds
[
1
])
lens
.
append
(
len
(
src_trg_ids
[
1
]))
self
.
_sample_infos
.
append
(
SampleInfo
(
i
,
max
(
lens
),
min
(
lens
))
)
def
_load_data
(
self
,
fpattern
,
tar_fname
):
@
staticmethod
def
_load_lines
(
fpattern
,
tar_fname
):
fpaths
=
glob
.
glob
(
fpattern
)
src_seq_words
=
[]
trg_seq_words
=
[]
if
len
(
fpaths
)
==
1
and
tarfile
.
is_tarfile
(
fpaths
[
0
]):
if
tar_fname
is
None
:
raise
Exception
(
"If tar file provided, please set tar_fname."
)
f
=
tarfile
.
open
(
fpaths
[
0
],
'r'
)
part_file_data
=
self
.
_parse_file
(
f
.
extractfile
(
tar_fname
))
src_seq_words
=
part_file_data
[
0
]
trg_seq_words
=
part_file_data
[
1
]
for
line
in
f
.
extractfile
(
tar_fname
):
yield
line
.
split
()
else
:
for
fpath
in
fpaths
:
if
not
os
.
path
.
isfile
(
fpath
):
raise
IOError
(
"Invalid file: %s"
%
fpath
)
part_file_data
=
self
.
_parse_file
(
open
(
fpath
,
'r'
))
src_seq_words
.
extend
(
part_file_data
[
0
])
trg_seq_words
.
extend
(
part_file_data
[
1
])
return
src_seq_words
,
trg_seq_words
with
open
(
fpath
,
'r'
)
as
f
:
for
line
in
f
:
yield
line
.
split
()
@
staticmethod
def
load_dict
(
dict_path
,
reverse
=
False
):
...
...
@@ -263,95 +262,41 @@ class DataReader(object):
word_dict
[
line
.
strip
()]
=
idx
return
word_dict
def
_sample_generator
(
self
):
def
batch_generator
(
self
):
# global sort or global shuffle
beg
=
time
.
time
()
if
self
.
_sort_type
==
SortType
.
GLOBAL
:
if
not
self
.
_sorted
:
self
.
_sample_idxs
.
sort
(
key
=
lambda
idx
:
max
(
len
(
self
.
_src_seq_ids
[
idx
]),
len
(
self
.
_trg_seq_ids
[
idx
]
if
not
self
.
_only_src
else
0
))
)
self
.
_sorted
=
True
infos
=
sorted
(
self
.
_sample_infos
,
key
=
lambda
x
:
max
(
x
[
1
],
x
[
2
])
if
not
self
.
_only_src
else
x
[
1
])
elif
self
.
_shuffle
:
random
.
shuffle
(
self
.
_sample_idxs
)
for
sample_idx
in
self
.
_sample_idxs
:
if
self
.
_only_src
:
yield
(
self
.
_src_seq_ids
[
sample_idx
],
)
else
:
yield
(
self
.
_src_seq_ids
[
sample_idx
],
self
.
_trg_seq_ids
[
sample_idx
][:
-
1
],
self
.
_trg_seq_ids
[
sample_idx
][
1
:])
def
batch_generator
(
self
):
pool
=
Pool
(
self
.
_sample_generator
,
self
.
_pool_size
,
True
if
self
.
_sort_type
==
SortType
.
POOL
else
False
)
def
next_batch
():
batch_data
=
[]
max_len
=
-
1
batch_max_seq_len
=
-
1
while
True
:
sample
=
pool
.
next
(
look
=
True
)
infos
=
self
.
_sample_infos
self
.
_random
.
shuffle
(
infos
)
else
:
infos
=
self
.
_sample_infos
if
sample
is
None
:
pool
.
push_back
(
batch_data
)
batch_data
=
[]
continue
# concat batch
batches
=
[]
batch_creator
=
TokenBatchCreator
(
self
.
_batch_size
)
if
self
.
_use_token_batch
else
SentenceBatchCreator
(
self
.
_batch_size
)
batch_creator
=
MinMaxFilter
(
self
.
_max_length
,
self
.
_min_length
,
batch_creator
)
if
isinstance
(
sample
,
EndEpoch
):
return
batch_data
,
batch_max_seq_len
,
True
for
info
in
infos
:
batch
=
batch_creator
.
append
(
info
)
if
batch
is
not
None
:
batches
.
append
([
info
.
i
for
info
in
batch
])
max_len
=
max
(
max_len
,
len
(
sample
[
0
]))
if
not
self
.
_clip_last_batch
and
len
(
batch_creator
.
batch
)
!=
0
:
batches
.
append
([
info
.
i
for
info
in
batch_creator
.
batch
])
if
not
self
.
_only_src
:
max_len
=
max
(
max_len
,
len
(
sample
[
1
])
)
if
self
.
_shuffle
:
self
.
_random
.
shuffle
(
batches
)
if
self
.
_use_token_batch
:
if
max_len
*
(
len
(
batch_data
)
+
1
)
<
self
.
_batch_size
:
batch_max_seq_len
=
max_len
batch_data
.
append
(
pool
.
next
())
else
:
return
batch_data
,
batch_max_seq_len
,
False
else
:
if
len
(
batch_data
)
<
self
.
_batch_size
:
batch_max_seq_len
=
max_len
batch_data
.
append
(
pool
.
next
())
else
:
return
batch_data
,
batch_max_seq_len
,
False
if
not
self
.
_shuffle_batch
:
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
while
not
last_batch
:
yield
batch_data
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
batch_size
=
len
(
batch_data
)
if
self
.
_use_token_batch
:
batch_size
*=
batch_max_seq_len
if
(
not
self
.
_clip_last_batch
and
len
(
batch_data
)
>
0
)
\
or
(
batch_size
==
self
.
_batch_size
):
yield
batch_data
else
:
# should re-generate batches
if
self
.
_sort_type
==
SortType
.
POOL
\
or
len
(
self
.
_epoch_batches
)
==
0
:
self
.
_epoch_batches
=
[]
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
while
not
last_batch
:
self
.
_epoch_batches
.
append
(
batch_data
)
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
batch_size
=
len
(
batch_data
)
if
self
.
_use_token_batch
:
batch_size
*=
batch_max_seq_len
if
(
not
self
.
_clip_last_batch
and
len
(
batch_data
)
>
0
)
\
or
(
batch_size
==
self
.
_batch_size
):
self
.
_epoch_batches
.
append
(
batch_data
)
random
.
shuffle
(
self
.
_epoch_batches
)
for
batch_data
in
self
.
_epoch_batches
:
yield
batch_data
for
batch_ids
in
batches
:
if
self
.
_only_src
:
yield
[[
self
.
_src_seq_ids
[
idx
]]
for
idx
in
batch_ids
]
else
:
yield
[(
self
.
_src_seq_ids
[
idx
],
self
.
_trg_seq_ids
[
idx
][:
-
1
],
self
.
_trg_seq_ids
[
idx
][
1
:])
for
idx
in
batch_ids
]
fluid/neural_machine_translation/transformer/train.py
浏览文件 @
dcdb6a00
import
os
import
time
import
argparse
import
ast
import
numpy
as
np
import
multiprocessing
import
os
import
time
import
paddle
import
numpy
as
np
import
paddle.fluid
as
fluid
import
reader
from
config
import
*
from
model
import
transformer
,
position_encoding_init
from
optim
import
LearningRateScheduler
from
config
import
*
import
reader
def
parse_args
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录