Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
c34bb5f1
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c34bb5f1
编写于
8月 15, 2018
作者:
Y
Yu Yang
提交者:
GitHub
8月 15, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1060 from reyoung/speed_up_transformer_python_reader
Speed up NTM reader
上级
f36588dc
6fc33ac5
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
178 addition
and
213 deletion
+178
-213
.gitignore
.gitignore
+6
-0
fluid/neural_machine_translation/transformer/infer.py
fluid/neural_machine_translation/transformer/infer.py
+0
-1
fluid/neural_machine_translation/transformer/model.py
fluid/neural_machine_translation/transformer/model.py
+1
-0
fluid/neural_machine_translation/transformer/reader.py
fluid/neural_machine_translation/transformer/reader.py
+166
-206
fluid/neural_machine_translation/transformer/train.py
fluid/neural_machine_translation/transformer/train.py
+5
-6
未找到文件。
.gitignore
浏览文件 @
c34bb5f1
.DS_Store
.DS_Store
*.pyc
*.pyc
.*~
.*~
fluid/neural_machine_translation/transformer/deps
fluid/neural_machine_translation/transformer/train.data
fluid/neural_machine_translation/transformer/train.pkl
fluid/neural_machine_translation/transformer/train.sh
fluid/neural_machine_translation/transformer/train.tok.clean.bpe.32000.en-de
fluid/neural_machine_translation/transformer/vocab.bpe.32000.refined
fluid/neural_machine_translation/transformer/infer.py
浏览文件 @
c34bb5f1
...
@@ -553,7 +553,6 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece):
...
@@ -553,7 +553,6 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece):
def
infer
(
args
,
inferencer
=
fast_infer
):
def
infer
(
args
,
inferencer
=
fast_infer
):
place
=
fluid
.
CUDAPlace
(
0
)
if
InferTaskConfig
.
use_gpu
else
fluid
.
CPUPlace
()
place
=
fluid
.
CUDAPlace
(
0
)
if
InferTaskConfig
.
use_gpu
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
test_data
=
reader
.
DataReader
(
test_data
=
reader
.
DataReader
(
src_vocab_fpath
=
args
.
src_vocab_fpath
,
src_vocab_fpath
=
args
.
src_vocab_fpath
,
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
...
...
fluid/neural_machine_translation/transformer/model.py
浏览文件 @
c34bb5f1
...
@@ -474,6 +474,7 @@ def transformer(
...
@@ -474,6 +474,7 @@ def transformer(
sum_cost
=
layers
.
reduce_sum
(
weighted_cost
)
sum_cost
=
layers
.
reduce_sum
(
weighted_cost
)
token_num
=
layers
.
reduce_sum
(
weights
)
token_num
=
layers
.
reduce_sum
(
weights
)
avg_cost
=
sum_cost
/
token_num
avg_cost
=
sum_cost
/
token_num
avg_cost
.
stop_gradient
=
True
return
sum_cost
,
avg_cost
,
predict
,
token_num
return
sum_cost
,
avg_cost
,
predict
,
token_num
...
...
fluid/neural_machine_translation/transformer/reader.py
浏览文件 @
c34bb5f1
import
os
import
tarfile
import
glob
import
glob
import
os
import
random
import
random
import
tarfile
import
cPickle
class
SortType
(
object
):
class
SortType
(
object
):
...
@@ -11,54 +11,86 @@ class SortType(object):
...
@@ -11,54 +11,86 @@ class SortType(object):
NONE
=
"none"
NONE
=
"none"
class
EndEpoch
():
class
Converter
(
object
):
pass
def
__init__
(
self
,
vocab
,
beg
,
end
,
unk
,
delimiter
):
self
.
_vocab
=
vocab
self
.
_beg
=
beg
self
.
_end
=
end
self
.
_unk
=
unk
self
.
_delimiter
=
delimiter
def
__call__
(
self
,
sentence
):
return
[
self
.
_beg
]
+
[
self
.
_vocab
.
get
(
w
,
self
.
_unk
)
for
w
in
sentence
.
split
(
self
.
_delimiter
)
]
+
[
self
.
_end
]
class
Pool
(
object
):
def
__init__
(
self
,
sample_generator
,
pool_size
,
sort
):
class
ComposedConverter
(
object
):
self
.
_pool_size
=
pool_size
def
__init__
(
self
,
converters
):
self
.
_pool
=
[]
self
.
_converters
=
converters
self
.
_sample_generator
=
sample_generator
()
self
.
_end
=
False
def
__call__
(
self
,
parallel_sentence
):
self
.
_sort
=
sort
return
[
self
.
_converters
[
i
](
parallel_sentence
[
i
])
def
_fill
(
self
):
for
i
in
range
(
len
(
self
.
_converters
))
while
len
(
self
.
_pool
)
<
self
.
_pool_size
and
not
self
.
_end
:
]
try
:
sample
=
self
.
_sample_generator
.
next
()
self
.
_pool
.
append
(
sample
)
class
SentenceBatchCreator
(
object
):
except
StopIteration
as
e
:
def
__init__
(
self
,
batch_size
):
self
.
_end
=
True
self
.
batch
=
[]
break
self
.
_batch_size
=
batch_size
if
self
.
_sort
:
def
append
(
self
,
info
):
self
.
_pool
.
sort
(
self
.
batch
.
append
(
info
)
key
=
lambda
sample
:
max
(
len
(
sample
[
0
]),
len
(
sample
[
1
]))
\
if
len
(
self
.
batch
)
==
self
.
_batch_size
:
if
len
(
sample
)
>
1
else
len
(
sample
[
0
])
tmp
=
self
.
batch
)
self
.
batch
=
[]
return
tmp
if
self
.
_end
and
len
(
self
.
_pool
)
<
self
.
_pool_size
:
self
.
_pool
.
append
(
EndEpoch
())
class
TokenBatchCreator
(
object
):
def
push_back
(
self
,
samples
):
def
__init__
(
self
,
batch_size
):
if
len
(
self
.
_pool
)
!=
0
:
self
.
batch
=
[]
raise
Exception
(
"Pool should be empty."
)
self
.
max_len
=
-
1
self
.
_batch_size
=
batch_size
if
len
(
samples
)
>=
self
.
_pool_size
:
raise
Exception
(
"Capacity of pool should be greater than a batch. "
def
append
(
self
,
info
):
"Please enlarge `pool_size`."
)
cur_len
=
info
.
max_len
max_len
=
max
(
self
.
max_len
,
cur_len
)
for
sample
in
samples
:
if
max_len
*
(
len
(
self
.
batch
)
+
1
)
>
self
.
_batch_size
:
self
.
_pool
.
append
(
sample
)
result
=
self
.
batch
self
.
batch
=
[
info
]
self
.
_fill
()
self
.
max_len
=
cur_len
return
result
def
next
(
self
,
look
=
False
):
if
len
(
self
.
_pool
)
==
0
:
return
None
else
:
else
:
return
self
.
_pool
[
0
]
if
look
else
self
.
_pool
.
pop
(
0
)
self
.
max_len
=
max_len
self
.
batch
.
append
(
info
)
class
SampleInfo
(
object
):
def
__init__
(
self
,
i
,
max_len
,
min_len
):
self
.
i
=
i
self
.
min_len
=
min_len
self
.
max_len
=
max_len
class
MinMaxFilter
(
object
):
def
__init__
(
self
,
max_len
,
min_len
,
underlying_creator
):
self
.
_min_len
=
min_len
self
.
_max_len
=
max_len
self
.
_creator
=
underlying_creator
def
append
(
self
,
info
):
if
info
.
max_len
>
self
.
_max_len
or
info
.
min_len
<
self
.
_min_len
:
return
else
:
return
self
.
_creator
.
append
(
info
)
@
property
def
batch
(
self
):
return
self
.
_creator
.
batch
class
DataReader
(
object
):
class
DataReader
(
object
):
...
@@ -140,7 +172,7 @@ class DataReader(object):
...
@@ -140,7 +172,7 @@ class DataReader(object):
fpattern
,
fpattern
,
batch_size
,
batch_size
,
pool_size
,
pool_size
,
sort_type
=
SortType
.
NONE
,
sort_type
=
SortType
.
GLOBAL
,
clip_last_batch
=
True
,
clip_last_batch
=
True
,
tar_fname
=
None
,
tar_fname
=
None
,
min_length
=
0
,
min_length
=
0
,
...
@@ -170,92 +202,68 @@ class DataReader(object):
...
@@ -170,92 +202,68 @@ class DataReader(object):
self
.
_max_length
=
max_length
self
.
_max_length
=
max_length
self
.
_field_delimiter
=
field_delimiter
self
.
_field_delimiter
=
field_delimiter
self
.
_token_delimiter
=
token_delimiter
self
.
_token_delimiter
=
token_delimiter
self
.
_epoch_batches
=
[]
self
.
load_src_trg_ids
(
end_mark
,
fpattern
,
start_mark
,
tar_fname
,
unk_mark
)
src_seq_words
,
trg_seq_words
=
self
.
_load_data
(
fpattern
,
tar_fname
)
self
.
_random
=
random
.
Random
(
x
=
seed
)
self
.
_src_seq_ids
=
[[
self
.
_src_vocab
.
get
(
word
,
self
.
_src_vocab
.
get
(
unk_mark
))
def
load_src_trg_ids
(
self
,
end_mark
,
fpattern
,
start_mark
,
tar_fname
,
for
word
in
([
start_mark
]
+
src_seq
+
[
end_mark
])
unk_mark
):
]
for
src_seq
in
src_seq_words
]
converters
=
[
Converter
(
self
.
_sample_count
=
len
(
self
.
_src_seq_ids
)
vocab
=
self
.
_src_vocab
,
beg
=
self
.
_src_vocab
[
start_mark
],
end
=
self
.
_src_vocab
[
end_mark
],
unk
=
self
.
_src_vocab
[
unk_mark
],
delimiter
=
self
.
_token_delimiter
)
]
if
not
self
.
_only_src
:
if
not
self
.
_only_src
:
self
.
_trg_seq_ids
=
[[
converters
.
append
(
self
.
_trg_vocab
.
get
(
word
,
self
.
_trg_vocab
.
get
(
unk_mark
))
Converter
(
for
word
in
([
start_mark
]
+
trg_seq
+
[
end_mark
])
vocab
=
self
.
_trg_vocab
,
]
for
trg_seq
in
trg_seq_words
]
beg
=
self
.
_trg_vocab
[
start_mark
],
if
len
(
self
.
_trg_seq_ids
)
!=
self
.
_sample_count
:
end
=
self
.
_trg_vocab
[
end_mark
],
raise
Exception
(
"Inconsistent sample count between "
unk
=
self
.
_trg_vocab
[
unk_mark
],
"source sequences and target sequences."
)
delimiter
=
self
.
_token_delimiter
))
else
:
self
.
_trg_seq_ids
=
None
converters
=
ComposedConverter
(
converters
)
self
.
_sample_idxs
=
[
i
for
i
in
xrange
(
self
.
_sample_count
)]
self
.
_src_seq_ids
=
[]
self
.
_sorted
=
False
self
.
_trg_seq_ids
=
None
if
self
.
_only_src
else
[]
self
.
_sample_infos
=
[]
random
.
seed
(
seed
)
for
i
,
line
in
enumerate
(
self
.
_load_lines
(
fpattern
,
tar_fname
)):
def
_parse_file
(
self
,
f_obj
):
src_trg_ids
=
converters
(
line
)
src_seq_words
=
[]
self
.
_src_seq_ids
.
append
(
src_trg_ids
[
0
])
trg_seq_words
=
[]
lens
=
[
len
(
src_trg_ids
[
0
])]
for
line
in
f_obj
:
fields
=
line
.
strip
().
split
(
self
.
_field_delimiter
)
if
(
not
self
.
_only_src
and
len
(
fields
)
!=
2
)
or
(
self
.
_only_src
and
len
(
fields
)
!=
1
):
continue
sample_words
=
[]
is_valid_sample
=
True
max_len
=
-
1
for
i
,
seq
in
enumerate
(
fields
):
seq_words
=
seq
.
split
(
self
.
_token_delimiter
)
max_len
=
max
(
max_len
,
len
(
seq_words
))
if
len
(
seq_words
)
==
0
or
\
len
(
seq_words
)
<
self
.
_min_length
or
\
len
(
seq_words
)
>
self
.
_max_length
or
\
(
self
.
_use_token_batch
and
max_len
>
self
.
_batch_size
):
is_valid_sample
=
False
break
sample_words
.
append
(
seq_words
)
if
not
is_valid_sample
:
continue
src_seq_words
.
append
(
sample_words
[
0
])
if
not
self
.
_only_src
:
if
not
self
.
_only_src
:
trg_seq_words
.
append
(
sample_words
[
1
])
self
.
_trg_seq_ids
.
append
(
src_trg_ids
[
1
])
lens
.
append
(
len
(
src_trg_ids
[
1
]))
self
.
_sample_infos
.
append
(
SampleInfo
(
i
,
max
(
lens
),
min
(
lens
)))
return
(
src_seq_words
,
trg_seq_words
)
def
_load_lines
(
self
,
fpattern
,
tar_fname
):
def
_load_data
(
self
,
fpattern
,
tar_fname
):
fpaths
=
glob
.
glob
(
fpattern
)
fpaths
=
glob
.
glob
(
fpattern
)
src_seq_words
=
[]
trg_seq_words
=
[]
if
len
(
fpaths
)
==
1
and
tarfile
.
is_tarfile
(
fpaths
[
0
]):
if
len
(
fpaths
)
==
1
and
tarfile
.
is_tarfile
(
fpaths
[
0
]):
if
tar_fname
is
None
:
if
tar_fname
is
None
:
raise
Exception
(
"If tar file provided, please set tar_fname."
)
raise
Exception
(
"If tar file provided, please set tar_fname."
)
f
=
tarfile
.
open
(
fpaths
[
0
],
'r'
)
f
=
tarfile
.
open
(
fpaths
[
0
],
"r"
)
part_file_data
=
self
.
_parse_file
(
f
.
extractfile
(
tar_fname
))
for
line
in
f
.
extractfile
(
tar_fname
):
src_seq_words
=
part_file_data
[
0
]
fields
=
line
.
strip
(
"
\n
"
).
split
(
self
.
_field_delimiter
)
trg_seq_words
=
part_file_data
[
1
]
if
(
not
self
.
_only_src
and
len
(
fields
)
==
2
)
or
(
self
.
_only_src
and
len
(
fields
)
==
1
):
yield
fields
else
:
else
:
for
fpath
in
fpaths
:
for
fpath
in
fpaths
:
if
not
os
.
path
.
isfile
(
fpath
):
if
not
os
.
path
.
isfile
(
fpath
):
raise
IOError
(
"Invalid file: %s"
%
fpath
)
raise
IOError
(
"Invalid file: %s"
%
fpath
)
part_file_data
=
self
.
_parse_file
(
open
(
fpath
,
'r'
))
with
open
(
fpath
,
"r"
)
as
f
:
src_seq_words
.
extend
(
part_file_data
[
0
])
for
line
in
f
:
trg_seq_words
.
extend
(
part_file_data
[
1
])
fields
=
line
.
strip
(
"
\n
"
).
split
(
self
.
_field_delimiter
)
if
(
not
self
.
_only_src
and
len
(
fields
)
==
2
)
or
(
return
src_seq_words
,
trg_seq_words
self
.
_only_src
and
len
(
fields
)
==
1
):
yield
fields
@
staticmethod
@
staticmethod
def
load_dict
(
dict_path
,
reverse
=
False
):
def
load_dict
(
dict_path
,
reverse
=
False
):
...
@@ -263,100 +271,52 @@ class DataReader(object):
...
@@ -263,100 +271,52 @@ class DataReader(object):
with
open
(
dict_path
,
"r"
)
as
fdict
:
with
open
(
dict_path
,
"r"
)
as
fdict
:
for
idx
,
line
in
enumerate
(
fdict
):
for
idx
,
line
in
enumerate
(
fdict
):
if
reverse
:
if
reverse
:
word_dict
[
idx
]
=
line
.
strip
(
'
\n
'
)
word_dict
[
idx
]
=
line
.
strip
(
"
\n
"
)
else
:
else
:
word_dict
[
line
.
strip
(
'
\n
'
)]
=
idx
word_dict
[
line
.
strip
(
"
\n
"
)]
=
idx
return
word_dict
return
word_dict
def
_sample_generator
(
self
):
def
batch_generator
(
self
):
# global sort or global shuffle
if
self
.
_sort_type
==
SortType
.
GLOBAL
:
if
self
.
_sort_type
==
SortType
.
GLOBAL
:
if
not
self
.
_sorted
:
infos
=
sorted
(
self
.
_sample_idxs
.
sort
(
self
.
_sample_infos
,
key
=
lambda
x
:
x
.
max_len
,
reverse
=
True
)
key
=
lambda
idx
:
max
(
len
(
self
.
_src_seq_ids
[
idx
]),
len
(
self
.
_trg_seq_ids
[
idx
]
if
not
self
.
_only_src
else
0
))
)
self
.
_sorted
=
True
elif
self
.
_shuffle
:
random
.
shuffle
(
self
.
_sample_idxs
)
for
sample_idx
in
self
.
_sample_idxs
:
if
self
.
_only_src
:
yield
(
self
.
_src_seq_ids
[
sample_idx
],
)
else
:
else
:
yield
(
self
.
_src_seq_ids
[
sample_idx
],
if
self
.
_shuffle
:
self
.
_trg_seq_ids
[
sample_idx
][:
-
1
],
infos
=
self
.
_sample_infos
self
.
_trg_seq_ids
[
sample_idx
][
1
:])
self
.
_random
.
shuffle
(
infos
)
else
:
def
batch_generator
(
self
):
infos
=
self
.
_sample_infos
pool
=
Pool
(
self
.
_sample_generator
,
self
.
_pool_size
,
True
if
self
.
_sort_type
==
SortType
.
POOL
else
False
)
def
next_batch
():
batch_data
=
[]
max_len
=
-
1
batch_max_seq_len
=
-
1
while
True
:
sample
=
pool
.
next
(
look
=
True
)
if
sample
is
None
:
pool
.
push_back
(
batch_data
)
batch_data
=
[]
continue
if
isinstance
(
sample
,
EndEpoch
):
if
self
.
_sort_type
==
SortType
.
POOL
:
return
batch_data
,
batch_max_seq_len
,
True
for
i
in
range
(
0
,
len
(
infos
),
self
.
_pool_size
):
infos
[
i
:
i
+
self
.
_pool_size
]
=
sorted
(
infos
[
i
:
i
+
self
.
_pool_size
],
key
=
lambda
x
:
x
.
max_len
)
max_len
=
max
(
max_len
,
len
(
sample
[
0
]))
# concat batch
batches
=
[]
batch_creator
=
TokenBatchCreator
(
self
.
_batch_size
)
if
self
.
_use_token_batch
else
SentenceBatchCreator
(
self
.
_batch_size
)
batch_creator
=
MinMaxFilter
(
self
.
_max_length
,
self
.
_min_length
,
batch_creator
)
if
not
self
.
_only_src
:
for
info
in
infos
:
max_len
=
max
(
max_len
,
len
(
sample
[
1
]))
batch
=
batch_creator
.
append
(
info
)
if
batch
is
not
None
:
batches
.
append
(
batch
)
if
self
.
_use_token_batch
:
if
not
self
.
_clip_last_batch
and
len
(
batch_creator
.
batch
)
!=
0
:
if
max_len
*
(
len
(
batch_data
)
+
1
)
<
self
.
_batch_size
:
batches
.
append
(
batch_creator
.
batch
)
batch_max_seq_len
=
max_len
batch_data
.
append
(
pool
.
next
())
else
:
return
batch_data
,
batch_max_seq_len
,
False
else
:
if
len
(
batch_data
)
<
self
.
_batch_size
:
batch_max_seq_len
=
max_len
batch_data
.
append
(
pool
.
next
())
else
:
return
batch_data
,
batch_max_seq_len
,
False
if
not
self
.
_shuffle_batch
:
if
self
.
_shuffle_batch
:
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
self
.
_random
.
shuffle
(
batches
)
while
not
last_batch
:
yield
batch_data
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
batch_size
=
len
(
batch_data
)
for
batch
in
batches
:
if
self
.
_use_token_batch
:
batch_ids
=
[
info
.
i
for
info
in
batch
]
batch_size
*=
batch_max_seq_len
if
(
not
self
.
_clip_last_batch
and
len
(
batch_data
)
>
0
)
\
if
self
.
_only_src
:
or
(
batch_size
==
self
.
_batch_size
):
yield
[[
self
.
_src_seq_ids
[
idx
]]
for
idx
in
batch_ids
]
yield
batch_data
else
:
else
:
# should re-generate batches
yield
[(
self
.
_src_seq_ids
[
idx
],
self
.
_trg_seq_ids
[
idx
][:
-
1
],
if
self
.
_sort_type
==
SortType
.
POOL
\
self
.
_trg_seq_ids
[
idx
][
1
:])
for
idx
in
batch_ids
]
or
len
(
self
.
_epoch_batches
)
==
0
:
self
.
_epoch_batches
=
[]
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
while
not
last_batch
:
self
.
_epoch_batches
.
append
(
batch_data
)
batch_data
,
batch_max_seq_len
,
last_batch
=
next_batch
()
batch_size
=
len
(
batch_data
)
if
self
.
_use_token_batch
:
batch_size
*=
batch_max_seq_len
if
(
not
self
.
_clip_last_batch
and
len
(
batch_data
)
>
0
)
\
or
(
batch_size
==
self
.
_batch_size
):
self
.
_epoch_batches
.
append
(
batch_data
)
random
.
shuffle
(
self
.
_epoch_batches
)
for
batch_data
in
self
.
_epoch_batches
:
yield
batch_data
fluid/neural_machine_translation/transformer/train.py
浏览文件 @
c34bb5f1
import
os
import
time
import
argparse
import
argparse
import
ast
import
ast
import
numpy
as
np
import
multiprocessing
import
multiprocessing
import
os
import
time
from
functools
import
partial
from
functools
import
partial
import
paddle
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
reader
from
config
import
*
from
model
import
transformer
,
position_encoding_init
from
model
import
transformer
,
position_encoding_init
from
optim
import
LearningRateScheduler
from
optim
import
LearningRateScheduler
from
config
import
*
import
reader
def
parse_args
():
def
parse_args
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录