PaddlePaddle / hapi
Commit ee442428
Authored Apr 08, 2020 by guosheng

Fix distribute BatchSampler

Parent: 6431daed
Showing 2 changed files with 172 additions and 146 deletions

transformer/reader.py    +162  -94
transformer/train.py     +10   -52
transformer/reader.py
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -15,8 +15,9 @@
 import glob
 import six
 import os
-import tarfile
+import io
 import itertools
+from functools import partial

 import numpy as np
 import paddle.fluid as fluid
...
@@ -24,16 +25,67 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.fluid.io import BatchSampler, DataLoader, Dataset


-def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
+def create_data_loader(args, device):
+    data_loaders = [None, None]
+    data_files = [args.training_file, args.validation_file
+                  ] if args.validation_file else [args.training_file]
+    for i, data_file in enumerate(data_files):
+        dataset = Seq2SeqDataset(fpattern=data_file,
+                                 src_vocab_fpath=args.src_vocab_fpath,
+                                 trg_vocab_fpath=args.trg_vocab_fpath,
+                                 token_delimiter=args.token_delimiter,
+                                 start_mark=args.special_token[0],
+                                 end_mark=args.special_token[1],
+                                 unk_mark=args.special_token[2],
+                                 byte_data=True)
+        args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
+            args.unk_idx = dataset.get_vocab_summary()
+        batch_sampler = Seq2SeqBatchSampler(
+            dataset=dataset,
+            use_token_batch=args.use_token_batch,
+            batch_size=args.batch_size,
+            pool_size=args.pool_size,
+            sort_type=args.sort_type,
+            shuffle=args.shuffle,
+            shuffle_batch=args.shuffle_batch,
+            max_length=args.max_length,
+            distribute_mode=True if i == 0 else False)  # every device eval all data
+        data_loader = DataLoader(
+            dataset=dataset,
+            batch_sampler=batch_sampler,
+            places=device,
+            collate_fn=partial(prepare_train_input,
+                               bos_idx=args.bos_idx,
+                               eos_idx=args.eos_idx,
+                               src_pad_idx=args.eos_idx,
+                               trg_pad_idx=args.eos_idx,
+                               n_head=args.n_head),
+            num_workers=0,  # TODO: use multi-process
+            return_list=True)
+        data_loaders[i] = data_loader
+    return data_loaders
+
+
+def prepare_train_input(insts, bos_idx, eos_idx, src_pad_idx, trg_pad_idx,
+                        n_head):
     """
     Put all padded data needed by training into a list.
     """
     src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
-        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
+        [inst[0] + [eos_idx] for inst in insts],
+        src_pad_idx,
+        n_head,
+        is_target=False)
     src_word = src_word.reshape(-1, src_max_len)
     src_pos = src_pos.reshape(-1, src_max_len)
     trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
-        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
+        [[bos_idx] + inst[1] for inst in insts],
+        trg_pad_idx,
+        n_head,
+        is_target=True)
     trg_word = trg_word.reshape(-1, trg_max_len)
     trg_pos = trg_pos.reshape(-1, trg_max_len)
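
The collate change above moves <s>/<e> insertion out of Converter and into prepare_train_input. A minimal standalone sketch of that behaviour (pad() here is a stand-in, not the repo's pad_batch_data, and the indices are made up):

import numpy as np

def pad(seqs, pad_idx):
    # Pad a list of id lists to the length of the longest one.
    max_len = max(len(s) for s in seqs)
    return np.array([s + [pad_idx] * (max_len - len(s)) for s in seqs])

bos_idx, eos_idx = 0, 1
insts = [([5, 6, 7], [8, 9]), ([5, 6], [8, 9, 10])]  # (src_ids, trg_ids) pairs

src = pad([list(s) + [eos_idx] for s, _ in insts], eos_idx)     # source + <e>
trg_in = pad([[bos_idx] + list(t) for _, t in insts], eos_idx)  # <s> + target
label = pad([list(t) + [eos_idx] for _, t in insts], eos_idx)   # target + <e>

print(src)     # [[5 6 7 1] [5 6 1 1]]
print(trg_in)  # [[0 8 9 1] [0 8 9 10]]
print(label)   # [[8 9 1 1] [8 9 10 1]]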
...
@@ -41,7 +93,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
         [1, 1, trg_max_len, 1]).astype("float32")
     lbl_word, lbl_weight, num_token = pad_batch_data(
-        [inst[2] for inst in insts],
+        [inst[1] + [eos_idx] for inst in insts],
         trg_pad_idx,
         n_head,
         is_target=False,
...
@@ -71,9 +123,7 @@ def prepare_infer_input(insts, src_pad_idx, n_head):
     src_word = src_word.reshape(-1, src_max_len)
     src_pos = src_pos.reshape(-1, src_max_len)
     data_inputs = [
         src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias
     ]
     return data_inputs
...
@@ -142,29 +192,30 @@ class SortType(object):
 class Converter(object):
-    def __init__(self, vocab, beg, end, unk, delimiter, add_beg):
+    def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
         self._vocab = vocab
         self._beg = beg
         self._end = end
         self._unk = unk
         self._delimiter = delimiter
         self._add_beg = add_beg
+        self._add_end = add_end

     def __call__(self, sentence):
         return ([self._beg] if self._add_beg else []) + [
             self._vocab.get(w, self._unk)
             for w in sentence.split(self._delimiter)
-        ] + [self._end]
+        ] + ([self._end] if self._add_end else [])


 class ComposedConverter(object):
     def __init__(self, converters):
         self._converters = converters

-    def __call__(self, parallel_sentence):
+    def __call__(self, fields):
         return [
-            self._converters[i](parallel_sentence[i])
-            for i in range(len(self._converters))
+            converter(field)
+            for field, converter in zip(fields, self._converters)
         ]
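
For reference, the updated Converter above now gates both the begin and end marks; a self-contained copy with a toy vocabulary shows the effect when both flags are False, which is how Seq2SeqDataset now constructs it:

class Converter(object):
    def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
        self._vocab, self._beg, self._end, self._unk = vocab, beg, end, unk
        self._delimiter = delimiter
        self._add_beg, self._add_end = add_beg, add_end

    def __call__(self, sentence):
        # Optional <s>, then token ids (unknowns mapped to unk), then optional <e>.
        return ([self._beg] if self._add_beg else []) + [
            self._vocab.get(w, self._unk)
            for w in sentence.split(self._delimiter)
        ] + ([self._end] if self._add_end else [])

vocab = {"hello": 3, "world": 4}
conv = Converter(vocab, beg=0, end=1, unk=2, delimiter=" ",
                 add_beg=False, add_end=False)
print(conv("hello unknown world"))  # [3, 2, 4] -- no <s>/<e> added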
...
@@ -201,10 +252,11 @@ class TokenBatchCreator(object):
 class SampleInfo(object):
-    def __init__(self, i, max_len, min_len):
+    def __init__(self, i, lens):
         self.i = i
-        self.min_len = min_len
-        self.max_len = max_len
+        # take bos and eos into account
+        self.min_len = min(lens[0] + 1, lens[1] + 2)
+        self.max_len = max(lens[0] + 1, lens[1] + 2)


 class MinMaxFilter(object):
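
SampleInfo now receives the raw (src_len, trg_len) pair and folds the special tokens into its bounds itself: +1 for the source <e> and +2 for the target <s>/<e> added later by the collate function. A small check with example lengths:

class SampleInfo(object):
    def __init__(self, i, lens):
        self.i = i
        # take bos and eos into account
        self.min_len = min(lens[0] + 1, lens[1] + 2)
        self.max_len = max(lens[0] + 1, lens[1] + 2)

info = SampleInfo(0, [5, 3])       # 5 source tokens, 3 target tokens
print(info.min_len, info.max_len)  # 5 6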
...
@@ -229,98 +281,109 @@ class Seq2SeqDataset(Dataset):
                  src_vocab_fpath,
                  trg_vocab_fpath,
                  fpattern,
-                 tar_fname=None,
                  field_delimiter="\t",
                  token_delimiter=" ",
                  start_mark="<s>",
                  end_mark="<e>",
                  unk_mark="<unk>",
-                 only_src=False):
-        # convert str to bytes, and use byte data
-        field_delimiter = field_delimiter.encode("utf8")
-        token_delimiter = token_delimiter.encode("utf8")
-        start_mark = start_mark.encode("utf8")
-        end_mark = end_mark.encode("utf8")
-        unk_mark = unk_mark.encode("utf8")
-        self._src_vocab = self.load_dict(src_vocab_fpath)
-        self._trg_vocab = self.load_dict(trg_vocab_fpath)
+                 only_src=False,
+                 trg_fpattern=None,
+                 byte_data=False):
+        if byte_data:
+            # The WMT16 bpe data used here seems including bytes can not be
+            # decoded by utf8. Thus convert str to bytes, and use byte data
+            field_delimiter = field_delimiter.encode("utf8")
+            token_delimiter = token_delimiter.encode("utf8")
+            start_mark = start_mark.encode("utf8")
+            end_mark = end_mark.encode("utf8")
+            unk_mark = unk_mark.encode("utf8")
+        self._byte_data = byte_data
+        self._src_vocab = self.load_dict(src_vocab_fpath, byte_data=byte_data)
+        self._trg_vocab = self.load_dict(trg_vocab_fpath, byte_data=byte_data)
         self._bos_idx = self._src_vocab[start_mark]
         self._eos_idx = self._src_vocab[end_mark]
         self._unk_idx = self._src_vocab[unk_mark]
-        self._only_src = only_src
         self._field_delimiter = field_delimiter
         self._token_delimiter = token_delimiter
-        self.load_src_trg_ids(fpattern, tar_fname)
+        self.load_src_trg_ids(fpattern, trg_fpattern)

-    def load_src_trg_ids(self, fpattern, tar_fname):
-        converters = [
-            Converter(vocab=self._src_vocab,
-                      beg=self._bos_idx,
-                      end=self._eos_idx,
-                      unk=self._unk_idx,
-                      delimiter=self._token_delimiter,
-                      add_beg=False)
-        ]
-        if not self._only_src:
-            converters.append(
-                Converter(vocab=self._trg_vocab,
-                          beg=self._bos_idx,
-                          end=self._eos_idx,
-                          unk=self._unk_idx,
-                          delimiter=self._token_delimiter,
-                          add_beg=True))
-
-        converters = ComposedConverter(converters)
+    def load_src_trg_ids(self, fpattern, trg_fpattern=None):
+        src_converter = Converter(vocab=self._src_vocab,
+                                  beg=self._bos_idx,
+                                  end=self._eos_idx,
+                                  unk=self._unk_idx,
+                                  delimiter=self._token_delimiter,
+                                  add_beg=False,
+                                  add_end=False)
+
+        trg_converter = Converter(vocab=self._trg_vocab,
+                                  beg=self._bos_idx,
+                                  end=self._eos_idx,
+                                  unk=self._unk_idx,
+                                  delimiter=self._token_delimiter,
+                                  add_beg=False,
+                                  add_end=False)
+
+        converters = ComposedConverter([src_converter, trg_converter])

         self._src_seq_ids = []
-        self._trg_seq_ids = None if self._only_src else []
+        self._trg_seq_ids = []
         self._sample_infos = []

-        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
-            src_trg_ids = converters(line)
-            self._src_seq_ids.append(src_trg_ids[0])
-            lens = [len(src_trg_ids[0])]
-            if not self._only_src:
-                self._trg_seq_ids.append(src_trg_ids[1])
-                lens.append(len(src_trg_ids[1]))
-            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
+        slots = [self._src_seq_ids, self._trg_seq_ids]
+        for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
+            lens = []
+            for field, slot in zip(converters(line), slots):
+                slot.append(field)
+                lens.append(len(field))
+            self._sample_infos.append(SampleInfo(i, lens))

-    def _load_lines(self, fpattern, tar_fname):
+    def _load_lines(self, fpattern, trg_fpattern=None):
         fpaths = glob.glob(fpattern)
+        fpaths = sorted(fpaths)  # TODO: Add custum sort
         assert len(fpaths) > 0, "no matching file to the provided data path"

-        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
-            if tar_fname is None:
-                raise Exception("If tar file provided, please set tar_fname.")
-
-            f = tarfile.open(fpaths[0], "rb")
-            for line in f.extractfile(tar_fname):
-                fields = line.strip(b"\n").split(self._field_delimiter)
-                if (not self._only_src and len(fields) == 2) or (
-                        self._only_src and len(fields) == 1):
-                    yield fields
-        else:
+        (f_mode, f_encoding,
+         endl) = ("rb", None, b"\n") if self._byte_data else ("r", "utf8",
+                                                              "\n")
+        if trg_fpattern is None:
             for fpath in fpaths:
-                if not os.path.isfile(fpath):
-                    raise IOError("Invalid file: %s" % fpath)
-
-                with open(fpath, "rb") as f:
+                with io.open(fpath, f_mode, encoding=f_encoding) as f:
                     for line in f:
-                        fields = line.strip(b"\n").split(self._field_delimiter)
-                        if (not self._only_src and len(fields) == 2) or (
-                                self._only_src and len(fields) == 1):
-                            yield fields
+                        fields = line.strip(endl).split(self._field_delimiter)
+                        yield fields
+        else:
+            # separated source and target language data files
+            # assume we can get aligned data by sort the two language files
+            # TODO: Need more rigorous check
+            trg_fpaths = glob.glob(trg_fpattern)
+            trg_fpaths = sorted(trg_fpaths)
+            assert len(fpaths) == len(trg_fpaths), \
+                "the number of source language data files must equal \
+                with that of source language"
+            for fpath, trg_fpath in zip(fpaths, trg_fpaths):
+                with io.open(fpath, f_mode, encoding=f_encoding) as f:
+                    with io.open(trg_fpath, f_mode, encoding=f_encoding) as trg_f:
+                        for line in zip(f, trg_f):
+                            fields = [field.strip(endl) for field in line]
+                            yield fields

     @staticmethod
-    def load_dict(dict_path, reverse=False):
+    def load_dict(dict_path, reverse=False, byte_data=False):
         word_dict = {}
-        with open(dict_path, "rb") as fdict:
+        (f_mode, f_encoding,
+         endl) = ("rb", None, b"\n") if byte_data else ("r", "utf8", "\n")
+        with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
             for idx, line in enumerate(fdict):
                 if reverse:
-                    word_dict[idx] = line.strip(b"\n")
+                    word_dict[idx] = line.strip(endl)
                 else:
-                    word_dict[line.strip(b"\n")] = idx
+                    word_dict[line.strip(endl)] = idx
         return word_dict

     def get_vocab_summary(self):
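
The new trg_fpattern branch in _load_lines reads aligned source/target files in lockstep after sorting both file lists. A standalone sketch of that path (the function name and file patterns here are illustrative, not from the repo):

import glob
import io

def load_pairs(src_fpattern, trg_fpattern, endl="\n"):
    # Sort both file lists so the i-th source file lines up with the i-th target file.
    src_fpaths = sorted(glob.glob(src_fpattern))
    trg_fpaths = sorted(glob.glob(trg_fpattern))
    assert len(src_fpaths) == len(trg_fpaths), \
        "source and target file counts must match"
    for src_fpath, trg_fpath in zip(src_fpaths, trg_fpaths):
        with io.open(src_fpath, "r", encoding="utf8") as f, \
             io.open(trg_fpath, "r", encoding="utf8") as trg_f:
            # zip() pairs the n-th source line with the n-th target line.
            for line in zip(f, trg_f):
                yield [field.strip(endl) for field in line]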
...
@@ -328,9 +391,8 @@ class Seq2SeqDataset(Dataset):
             self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx

     def __getitem__(self, idx):
-        return (self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
-                self._trg_seq_ids[idx][1:]
-                ) if not self._only_src else self._src_seq_ids[idx]
+        return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
+                ) if self._trg_seq_ids else self._src_seq_ids[idx]

     def __len__(self):
         return len(self._sample_infos)
...
@@ -348,6 +410,7 @@ class Seq2SeqBatchSampler(BatchSampler):
                  shuffle_batch=False,
                  use_token_batch=False,
                  clip_last_batch=False,
+                 distribute_mode=True,
                  seed=0):
         for arg, value in locals().items():
             if arg != "self":
...
@@ -355,6 +418,7 @@ class Seq2SeqBatchSampler(BatchSampler):
         self._random = np.random
         self._random.seed(seed)
         # for multi-devices
+        self._distribute_mode = distribute_mode
         self._nranks = ParallelEnv().nranks
         self._local_rank = ParallelEnv().local_rank
         self._device_id = ParallelEnv().dev_id
...
@@ -362,8 +426,8 @@ class Seq2SeqBatchSampler(BatchSampler):
     def __iter__(self):
         # global sort or global shuffle
         if self._sort_type == SortType.GLOBAL:
-            infos = sorted(self._dataset._sample_infos,
-                           key=lambda x: x.max_len)
+            infos = sorted(
+                self._dataset._sample_infos, key=lambda x: x.max_len)
         else:
             if self._shuffle:
                 infos = self._dataset._sample_infos
...
@@ -383,9 +447,9 @@ class Seq2SeqBatchSampler(BatchSampler):
             batches = []
             batch_creator = TokenBatchCreator(
-                self._batch_size
-            ) if self._use_token_batch else SentenceBatchCreator(self._batch_size *
-                                                                 self._nranks)
+                self.
+                _batch_size) if self._use_token_batch else SentenceBatchCreator(
+                    self._batch_size * self._nranks)
             batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                          batch_creator)
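
TokenBatchCreator and SentenceBatchCreator are not part of this diff; for orientation only, a hypothetical sketch of the two strategies they name, fixed sentence count (scaled by nranks above) versus a token budget, could look like this:

class SentenceBatchCreator(object):
    # Group a fixed number of samples per batch.
    def __init__(self, batch_size):
        self.batch, self._batch_size = [], batch_size

    def append(self, info):
        self.batch.append(info)
        if len(self.batch) == self._batch_size:
            result, self.batch = self.batch, []
            return result  # a full batch; leftovers stay buffered

class TokenBatchCreator(object):
    # Fill a batch until padded token count would exceed the budget.
    def __init__(self, batch_size):
        self.batch, self.max_len, self._batch_size = [], -1, batch_size

    def append(self, info):
        cur_len = max(self.max_len, info.max_len)
        if cur_len * (len(self.batch) + 1) > self._batch_size:
            # Emit the current batch and start a new one with this sample.
            result, self.batch, self.max_len = self.batch, [info], info.max_len
            return result
        self.max_len = cur_len
        self.batch.append(info)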
...
@@ -413,11 +477,15 @@ class Seq2SeqBatchSampler(BatchSampler):
         # for multi-device
         for batch_id, batch in enumerate(batches):
-            if batch_id % self._nranks == self._local_rank:
+            if not self._distribute_mode or (
+                    batch_id % self._nranks == self._local_rank):
                 batch_indices = [info.i for info in batch]
                 yield batch_indices
-        if self._local_rank > len(batches) % self._nranks:
-            yield batch_indices
+        if self._distribute_mode and len(batches) % self._nranks != 0:
+            if self._local_rank >= len(batches) % self._nranks:
+                # use previous data to pad
+                yield batch_indices

     def __len__(self):
-        return 0
+        return 100  # TODO(guosheng): fix the uncertain length
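
The distribute_mode handling above is the core of this fix: batches are dealt out round-robin by rank, and when the final round is incomplete, ranks past the remainder re-yield their previous batch so every rank produces the same number of steps (evaluation uses distribute_mode=False, so every device iterates over all batches). A standalone sketch of that assignment:

def rank_batches(batches, nranks, local_rank, distribute_mode=True):
    last = None
    for batch_id, batch in enumerate(batches):
        # Round-robin assignment of batches to ranks.
        if not distribute_mode or batch_id % nranks == local_rank:
            last = batch
            yield batch
    if distribute_mode and len(batches) % nranks != 0:
        if local_rank >= len(batches) % nranks:
            # use previous data to pad, so all ranks yield equally often
            yield last

batches = [[0], [1], [2], [3], [4]]  # 5 batches across 2 ranks
print(list(rank_batches(batches, nranks=2, local_rank=0)))  # [[0], [2], [4]]
print(list(rank_batches(batches, nranks=2, local_rank=1)))  # [[1], [3], [3]]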
transformer/train.py
...
@@ -17,7 +17,6 @@ import os
 import six
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from functools import partial

 import numpy as np
 import paddle
...
@@ -29,14 +28,18 @@ from utils.check import check_gpu, check_version
 from model import Input, set_device
 from callbacks import ProgBarLogger
-from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler
+from reader import create_data_loader
 from transformer import Transformer, CrossEntropyCriterion


 class TrainCallback(ProgBarLogger):
-    def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
-        super(TrainCallback, self).__init__(log_freq, verbose)
+    def __init__(self, args, verbose=2):
+        super(TrainCallback, self).__init__(args.print_step, verbose)
+        # TODO: wrap these override function to simplify
+        # the best cross-entropy value with label smoothing
+        loss_normalizer = -(
+            (1. - args.label_smooth_eps) * np.log(
+                (1. - args.label_smooth_eps)) + args.label_smooth_eps *
+            np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
         self.loss_normalizer = loss_normalizer

     def on_train_begin(self, logs=None):
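
loss_normalizer is the best achievable cross-entropy under label smoothing, now computed inside TrainCallback instead of do_train. A quick numeric check with example settings (the values below are illustrative, not taken from args):

import numpy as np

label_smooth_eps = 0.1
trg_vocab_size = 10000

loss_normalizer = -(
    (1. - label_smooth_eps) * np.log(1. - label_smooth_eps) +
    label_smooth_eps * np.log(label_smooth_eps / (trg_vocab_size - 1) + 1e-20))
print(loss_normalizer)  # ~1.25 for these settings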
...
@@ -100,42 +103,7 @@ def do_train(args):
     ]

     # def dataloader
-    data_loaders = [None, None]
-    data_files = [args.training_file, args.validation_file
-                  ] if args.validation_file else [args.training_file]
-    for i, data_file in enumerate(data_files):
-        dataset = Seq2SeqDataset(fpattern=data_file,
-                                 src_vocab_fpath=args.src_vocab_fpath,
-                                 trg_vocab_fpath=args.trg_vocab_fpath,
-                                 token_delimiter=args.token_delimiter,
-                                 start_mark=args.special_token[0],
-                                 end_mark=args.special_token[1],
-                                 unk_mark=args.special_token[2])
-        args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
-            args.unk_idx = dataset.get_vocab_summary()
-        batch_sampler = Seq2SeqBatchSampler(dataset=dataset,
-                                            use_token_batch=args.use_token_batch,
-                                            batch_size=args.batch_size,
-                                            pool_size=args.pool_size,
-                                            sort_type=args.sort_type,
-                                            shuffle=args.shuffle,
-                                            shuffle_batch=args.shuffle_batch,
-                                            max_length=args.max_length)
-        data_loader = DataLoader(dataset=dataset,
-                                 batch_sampler=batch_sampler,
-                                 places=device,
-                                 collate_fn=partial(prepare_train_input,
-                                                    src_pad_idx=args.eos_idx,
-                                                    trg_pad_idx=args.eos_idx,
-                                                    n_head=args.n_head),
-                                 num_workers=0,  # TODO: use multi-process
-                                 return_list=True)
-        data_loaders[i] = data_loader
-    train_loader, eval_loader = data_loaders
+    train_loader, eval_loader = create_data_loader(args, device)

     # define model
     transformer = Transformer(
...
@@ -166,12 +134,6 @@ def do_train(args):
     if args.init_from_pretrain_model:
         transformer.load(args.init_from_pretrain_model, reset_optimizer=True)

-    # the best cross-entropy value with label smoothing
-    loss_normalizer = -(
-        (1. - args.label_smooth_eps) * np.log(
-            (1. - args.label_smooth_eps)) + args.label_smooth_eps *
-        np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
-
     # model train
     transformer.fit(train_data=train_loader,
                     eval_data=eval_loader,
...
@@ -180,11 +142,7 @@ def do_train(args):
                     save_freq=1,
                     save_dir=args.save_model,
                     verbose=2,
                     callbacks=[
-                        TrainCallback(log_freq=args.print_step,
-                                      loss_normalizer=loss_normalizer)
-                    ])
+                        TrainCallback(args)])


 if __name__ == "__main__":
...