PaddlePaddle / models

Commit dcdb6a00
Authored on Jul 20, 2018 by Yu Yang
Parent: 0b48d785

Speed up NTM reader

* Load Data      --> 64 sec
* Shuffle/Batch  --> 14 sec
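The two figures in the message read like wall-clock timings of the reader's two stages: loading and tokenizing the corpus, then sorting/shuffling it into batches. A minimal sketch of how such numbers could be collected; the constructor arguments below are placeholders, not values taken from this commit:

```python
import time

import reader

start = time.time()
# Hypothetical arguments: point the reader at a vocabulary pair and a file pattern.
train_data = reader.DataReader(
    src_vocab_fpath="vocab.src",
    trg_vocab_fpath="vocab.trg",
    fpattern="train.tok.*",
    batch_size=2048,
    pool_size=10000)
print("Load Data      --> %.0f sec" % (time.time() - start))

start = time.time()
for batch in train_data.batch_generator():
    pass  # drain the generator just to time shuffling/batching
print("Shuffle/Batch  --> %.0f sec" % (time.time() - start))
```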
Showing 4 changed files with 156 additions and 212 deletions (+156 −212):

    fluid/neural_machine_translation/transformer/infer.py    +0    -1
    fluid/neural_machine_translation/transformer/model.py    +1    -0
    fluid/neural_machine_translation/transformer/reader.py   +150  -205
    fluid/neural_machine_translation/transformer/train.py    +5    -6
fluid/neural_machine_translation/transformer/infer.py

```diff
@@ -529,7 +529,6 @@ def fast_infer(test_data, trg_idx2word):
 def infer(args, inferencer=fast_infer):
     place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
-    exe = fluid.Executor(place)
     test_data = reader.DataReader(
         src_vocab_fpath=args.src_vocab_fpath,
         trg_vocab_fpath=args.trg_vocab_fpath,
...
```
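The only change here drops an Executor that was constructed right before the reader; since nothing else in the file changes, it was evidently not needed at this point. For context, a minimal sketch of the usual Fluid pattern that line followed, with an illustrative use_gpu flag (not taken from this commit):

```python
import paddle.fluid as fluid

use_gpu = True  # illustrative flag; infer.py reads this from InferTaskConfig
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)  # an Executor runs programs on the chosen device
```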
fluid/neural_machine_translation/transformer/model.py

```diff
@@ -466,6 +466,7 @@ def transformer(
     sum_cost = layers.reduce_sum(weighted_cost)
     token_num = layers.reduce_sum(weights)
     avg_cost = sum_cost / token_num
+    avg_cost.stop_gradient = True
     return sum_cost, avg_cost, predict, token_num
...
```
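The one-line addition marks the averaged cost as a monitoring value: with stop_gradient set, the framework will not backpropagate through avg_cost, so only sum_cost participates in the training objective. A self-contained sketch of the same pattern; the data layers below are placeholders standing in for the label-smoothed cost and token weights computed by transformer():

```python
import paddle.fluid as fluid
import paddle.fluid.layers as layers

# Placeholders for the per-token weighted cost and token weights.
weighted_cost = layers.data(name="weighted_cost", shape=[1], dtype="float32")
weights = layers.data(name="weights", shape=[1], dtype="float32")

sum_cost = layers.reduce_sum(weighted_cost)   # trainable objective
token_num = layers.reduce_sum(weights)
avg_cost = sum_cost / token_num               # per-token cost, reported only
avg_cost.stop_gradient = True                 # keep the metric out of the backward pass
```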
fluid/neural_machine_translation/transformer/reader.py

```diff
-import os
-import tarfile
 import glob
+import os
 import random
+import tarfile
+import time
 
 
 class SortType(object):
@@ -11,54 +11,84 @@ class SortType(object):
     NONE = "none"
 
-class EndEpoch():
-    pass
-
-
-class Pool(object):
-    def __init__(self, sample_generator, pool_size, sort):
-        self._pool_size = pool_size
-        self._pool = []
-        self._sample_generator = sample_generator()
-        self._end = False
-        self._sort = sort
-
-    def _fill(self):
-        while len(self._pool) < self._pool_size and not self._end:
-            try:
-                sample = self._sample_generator.next()
-                self._pool.append(sample)
-            except StopIteration as e:
-                self._end = True
-                break
-
-        if self._sort:
-            self._pool.sort(
-                key=lambda sample: max(len(sample[0]), len(sample[1])) \
-                    if len(sample) > 1 else len(sample[0]))
-
-        if self._end and len(self._pool) < self._pool_size:
-            self._pool.append(EndEpoch())
-
-    def push_back(self, samples):
-        if len(self._pool) != 0:
-            raise Exception("Pool should be empty.")
-
-        if len(samples) >= self._pool_size:
-            raise Exception("Capacity of pool should be greater than a batch. "
-                            "Please enlarge `pool_size`.")
-
-        for sample in samples:
-            self._pool.append(sample)
-
-        self._fill()
-
-    def next(self, look=False):
-        if len(self._pool) == 0:
-            return None
-        else:
-            return self._pool[0] if look else self._pool.pop(0)
+class Converter(object):
+    def __init__(self, vocab, beg, end, unk):
+        self._vocab = vocab
+        self._beg = beg
+        self._end = end
+        self._unk = unk
+
+    def __call__(self, sentence):
+        return [self._beg] + [
+            self._vocab.get(w, self._unk) for w in sentence.split()
+        ] + [self._end]
+
+
+class ComposedConverter(object):
+    def __init__(self, converters):
+        self._converters = converters
+
+    def __call__(self, parallel_sentence):
+        return [
+            self._converters[i](parallel_sentence[i])
+            for i in range(len(self._converters))
+        ]
+
+
+class SentenceBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self._batch_size = batch_size
+
+    def append(self, info):
+        self.batch.append(info)
+        if len(self.batch) == self._batch_size:
+            tmp = self.batch
+            self.batch = []
+            return tmp
+
+
+class TokenBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self.max_len = -1
+        self._batch_size = batch_size
+
+    def append(self, info):
+        cur_len = info.max_len
+        max_len = max(self.max_len, cur_len)
+        if max_len * (len(self.batch) + 1) > self._batch_size:
+            result = self.batch
+            self.batch = [info]
+            self.max_len = cur_len
+            return result
+        else:
+            self.max_len = max_len
+            self.batch.append(info)
+
+
+class SampleInfo(object):
+    def __init__(self, i, max_len, min_len):
+        self.i = i
+        self.min_len = min_len
+        self.max_len = max_len
+
+
+class MinMaxFilter(object):
+    def __init__(self, max_len, min_len, underlying_creator):
+        self._min_len = min_len
+        self._max_len = max_len
+        self._creator = underlying_creator
+
+    def append(self, info):
+        if info.max_len > self._max_len or info.min_len < self._min_len:
+            return
+        else:
+            return self._creator.append(info)
+
+    @property
+    def batch(self):
+        return self._creator.batch
 
 
 class DataReader(object):
@@ -137,7 +167,7 @@ class DataReader(object):
                  fpattern,
                  batch_size,
                  pool_size,
-                 sort_type=SortType.NONE,
+                 sort_type=SortType.GLOBAL,
                  clip_last_batch=True,
                  tar_fname=None,
                  min_length=0,
@@ -165,92 +195,61 @@ class DataReader(object):
         self._min_length = min_length
         self._max_length = max_length
         self._delimiter = delimiter
-        self._epoch_batches = []
-
-        src_seq_words, trg_seq_words = self._load_data(fpattern, tar_fname)
-
-        self._src_seq_ids = [[
-            self._src_vocab.get(word, self._src_vocab.get(unk_mark))
-            for word in ([start_mark] + src_seq + [end_mark])
-        ] for src_seq in src_seq_words]
-
-        self._sample_count = len(self._src_seq_ids)
-
-        if not self._only_src:
-            self._trg_seq_ids = [[
-                self._trg_vocab.get(word, self._trg_vocab.get(unk_mark))
-                for word in ([start_mark] + trg_seq + [end_mark])
-            ] for trg_seq in trg_seq_words]
-            if len(self._trg_seq_ids) != self._sample_count:
-                raise Exception("Inconsistent sample count between "
-                                "source sequences and target sequences.")
-        else:
-            self._trg_seq_ids = None
-
-        self._sample_idxs = [i for i in xrange(self._sample_count)]
-        self._sorted = False
-        random.seed(seed)
-
-    def _parse_file(self, f_obj):
-        src_seq_words = []
-        trg_seq_words = []
-        for line in f_obj:
-            fields = line.strip().split(self._delimiter)
-            if (not self._only_src and len(fields) != 2) or (
-                    self._only_src and len(fields) != 1):
-                continue
-            sample_words = []
-            is_valid_sample = True
-            max_len = -1
-            for i, seq in enumerate(fields):
-                seq_words = seq.split()
-                max_len = max(max_len, len(seq_words))
-                if len(seq_words) == 0 or \
-                        len(seq_words) < self._min_length or \
-                        len(seq_words) > self._max_length or \
-                        (self._use_token_batch and max_len > self._batch_size):
-                    is_valid_sample = False
-                    break
-                sample_words.append(seq_words)
-            if not is_valid_sample:
-                continue
-            src_seq_words.append(sample_words[0])
-            if not self._only_src:
-                trg_seq_words.append(sample_words[1])
-        return (src_seq_words, trg_seq_words)
-
-    def _load_data(self, fpattern, tar_fname):
-        fpaths = glob.glob(fpattern)
-        src_seq_words = []
-        trg_seq_words = []
-
-        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
-            if tar_fname is None:
-                raise Exception("If tar file provided, please set tar_fname.")
-            f = tarfile.open(fpaths[0], 'r')
-            part_file_data = self._parse_file(f.extractfile(tar_fname))
-            src_seq_words = part_file_data[0]
-            trg_seq_words = part_file_data[1]
-        else:
-            for fpath in fpaths:
-                if not os.path.isfile(fpath):
-                    raise IOError("Invalid file: %s" % fpath)
-                part_file_data = self._parse_file(open(fpath, 'r'))
-                src_seq_words.extend(part_file_data[0])
-                trg_seq_words.extend(part_file_data[1])
-        return src_seq_words, trg_seq_words
+        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
+                              unk_mark)
+        self._random = random.Random(x=seed)
+
+    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
+                         unk_mark):
+        converters = [
+            Converter(
+                vocab=self._src_vocab,
+                beg=self._src_vocab[start_mark],
+                end=self._src_vocab[end_mark],
+                unk=self._src_vocab[unk_mark])
+        ]
+        if not self._only_src:
+            converters.append(
+                Converter(
+                    vocab=self._trg_vocab,
+                    beg=self._trg_vocab[start_mark],
+                    end=self._trg_vocab[end_mark],
+                    unk=self._trg_vocab[unk_mark]))
+
+        converters = ComposedConverter(converters)
+
+        self._src_seq_ids = []
+        self._trg_seq_ids = None if self._only_src else []
+        self._sample_infos = []
+
+        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
+            src_trg_ids = converters(line)
+            self._src_seq_ids.append(src_trg_ids[0])
+            lens = [len(src_trg_ids[0])]
+            if not self._only_src:
+                self._trg_seq_ids.append(src_trg_ids[1])
+                lens.append(len(src_trg_ids[1]))
+            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
+
+    @staticmethod
+    def _load_lines(fpattern, tar_fname):
+        fpaths = glob.glob(fpattern)
+
+        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
+            if tar_fname is None:
+                raise Exception("If tar file provided, please set tar_fname.")
+
+            f = tarfile.open(fpaths[0], 'r')
+            for line in f.extractfile(tar_fname):
+                yield line.split()
+        else:
+            for fpath in fpaths:
+                if not os.path.isfile(fpath):
+                    raise IOError("Invalid file: %s" % fpath)
+
+                with open(fpath, 'r') as f:
+                    for line in f:
+                        yield line.split()
 
     @staticmethod
     def load_dict(dict_path, reverse=False):
@@ -263,95 +262,41 @@ class DataReader(object):
             word_dict[line.strip()] = idx
         return word_dict
 
-    def _sample_generator(self):
-        if self._sort_type == SortType.GLOBAL:
-            if not self._sorted:
-                self._sample_idxs.sort(
-                    key=lambda idx: max(len(self._src_seq_ids[idx]),
-                                        len(self._trg_seq_ids[idx]
-                                            if not self._only_src else 0))
-                )
-                self._sorted = True
-        elif self._shuffle:
-            random.shuffle(self._sample_idxs)
-
-        for sample_idx in self._sample_idxs:
-            if self._only_src:
-                yield (self._src_seq_ids[sample_idx], )
-            else:
-                yield (self._src_seq_ids[sample_idx],
-                       self._trg_seq_ids[sample_idx][:-1],
-                       self._trg_seq_ids[sample_idx][1:])
-
-    def batch_generator(self):
-        pool = Pool(self._sample_generator, self._pool_size,
-                    True if self._sort_type == SortType.POOL else False)
-
-        def next_batch():
-            batch_data = []
-            max_len = -1
-            batch_max_seq_len = -1
-            while True:
-                sample = pool.next(look=True)
-                if sample is None:
-                    pool.push_back(batch_data)
-                    batch_data = []
-                    continue
-                if isinstance(sample, EndEpoch):
-                    return batch_data, batch_max_seq_len, True
-
-                max_len = max(max_len, len(sample[0]))
-                if not self._only_src:
-                    max_len = max(max_len, len(sample[1]))
-
-                if self._use_token_batch:
-                    if max_len * (len(batch_data) + 1) < self._batch_size:
-                        batch_max_seq_len = max_len
-                        batch_data.append(pool.next())
-                    else:
-                        return batch_data, batch_max_seq_len, False
-                else:
-                    if len(batch_data) < self._batch_size:
-                        batch_max_seq_len = max_len
-                        batch_data.append(pool.next())
-                    else:
-                        return batch_data, batch_max_seq_len, False
-
-        if not self._shuffle_batch:
-            batch_data, batch_max_seq_len, last_batch = next_batch()
-            while not last_batch:
-                yield batch_data
-                batch_data, batch_max_seq_len, last_batch = next_batch()
-            batch_size = len(batch_data)
-            if self._use_token_batch:
-                batch_size *= batch_max_seq_len
-            if (not self._clip_last_batch and len(batch_data) > 0) \
-                    or (batch_size == self._batch_size):
-                yield batch_data
-        else:
-            # should re-generate batches
-            if self._sort_type == SortType.POOL \
-                    or len(self._epoch_batches) == 0:
-                self._epoch_batches = []
-                batch_data, batch_max_seq_len, last_batch = next_batch()
-                while not last_batch:
-                    self._epoch_batches.append(batch_data)
-                    batch_data, batch_max_seq_len, last_batch = next_batch()
-                batch_size = len(batch_data)
-                if self._use_token_batch:
-                    batch_size *= batch_max_seq_len
-                if (not self._clip_last_batch and len(batch_data) > 0) \
-                        or (batch_size == self._batch_size):
-                    self._epoch_batches.append(batch_data)
-            random.shuffle(self._epoch_batches)
-            for batch_data in self._epoch_batches:
-                yield batch_data
+    def batch_generator(self):
+        # global sort or global shuffle
+        beg = time.time()
+        if self._sort_type == SortType.GLOBAL:
+            infos = sorted(
+                self._sample_infos,
+                key=lambda x: max(x[1], x[2]) if not self._only_src else x[1])
+        elif self._shuffle:
+            infos = self._sample_infos
+            self._random.shuffle(infos)
+        else:
+            infos = self._sample_infos
+
+        # concat batch
+        batches = []
+        batch_creator = TokenBatchCreator(
+            self._batch_size
+        ) if self._use_token_batch else SentenceBatchCreator(self._batch_size)
+        batch_creator = MinMaxFilter(self._max_length, self._min_length,
+                                     batch_creator)
+
+        for info in infos:
+            batch = batch_creator.append(info)
+            if batch is not None:
+                batches.append([info.i for info in batch])
+
+        if not self._clip_last_batch and len(batch_creator.batch) != 0:
+            batches.append([info.i for info in batch_creator.batch])
+
+        if self._shuffle:
+            self._random.shuffle(batches)
+
+        for batch_ids in batches:
+            if self._only_src:
+                yield [[self._src_seq_ids[idx]] for idx in batch_ids]
+            else:
+                yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
+                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]
```
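Taken together, the rewrite converts every sentence pair to ids once up front (load_src_trg_ids), keeps only lightweight SampleInfo records for sorting and filtering, and assembles batches of sample indices with SentenceBatchCreator or TokenBatchCreator before yielding the id sequences. A usage sketch of the rewritten reader; paths and sizes below are placeholders, and the constructor takes further options (tar_fname, length limits, token-based batching) that are elided here:

```python
import reader

train_data = reader.DataReader(
    src_vocab_fpath="data/vocab.src",     # placeholder vocabulary files
    trg_vocab_fpath="data/vocab.trg",
    fpattern="data/train.tok.*",          # placeholder training file pattern
    batch_size=2048,
    pool_size=10000,
    sort_type=reader.SortType.POOL,
    clip_last_batch=False)

for pass_id in range(2):
    # Each call to batch_generator() re-shuffles and yields one pass over the data.
    for batch in train_data.batch_generator():
        # Each element is (src_ids, trg_ids[:-1], trg_ids[1:]) unless the reader
        # was built in source-only mode, in which case it is [src_ids].
        pass  # placeholder for the training step
```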
fluid/neural_machine_translation/transformer/train.py

```diff
-import os
-import time
 import argparse
 import ast
-import numpy as np
 import multiprocessing
+import os
+import time
 
-import paddle
+import numpy as np
 import paddle.fluid as fluid
 
-import reader
-from config import *
 from model import transformer, position_encoding_init
 from optim import LearningRateScheduler
+from config import *
+import reader
 
 
 def parse_args():
...
```