Commit 68dfe864

Update Transformer

Authored on Mar 26, 2020 by guosheng
Parent: 64a7e1a0

2 changed files, with 346 additions and 185 deletions:

- transformer/reader.py (+249, -53)
- transformer/train.py (+97, -132)
transformer/reader.py
@@ -16,18 +16,68 @@ import glob

The import block now reads:

```python
import six
import os
import tarfile
import itertools
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import BatchSampler, DataLoader, Dataset
```

with the `ParallelEnv` import and the `Dataset`-aware io import replacing the previous `from paddle.fluid.io import BatchSampler, DataLoader`. The empty placeholder sampler that used to follow the imports is removed:

```python
class TokenBatchSampler(BatchSampler):
    def __init__(self):
        pass

    def __iter(self):
        pass
```

and `prepare_train_input` / `prepare_infer_input` are now defined here, ahead of `pad_batch_data`:

```python
def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
    """
    Put all padded data needed by training into a list.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    src_word = src_word.reshape(-1, src_max_len)
    src_pos = src_pos.reshape(-1, src_max_len)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
    trg_word = trg_word.reshape(-1, trg_max_len)
    trg_pos = trg_pos.reshape(-1, trg_max_len)

    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, trg_max_len, 1]).astype("float32")

    lbl_word, lbl_weight, num_token = pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)
    lbl_word = lbl_word.reshape(-1, 1)
    lbl_weight = lbl_weight.reshape(-1, 1)

    data_inputs = [
        src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
        trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
    ]

    return data_inputs


def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
    """
    Put all padded data needed by beam search decoder into a list.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    # start tokens
    trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, 1, 1]).astype("float32")
    trg_word = trg_word.reshape(-1, 1)
    src_word = src_word.reshape(-1, src_max_len)
    src_pos = src_pos.reshape(-1, src_max_len)

    data_inputs = [
        src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias
    ]
    return data_inputs


def pad_batch_data(insts,
```
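For readers skimming the hunk above, the shapes involved are easier to see in a tiny, self-contained sketch. The snippet below is not the repo's `pad_batch_data`; `toy_pad_batch` and the example token ids are made up for illustration. It pads a batch to its longest sequence and builds the `[batch_size, n_head, max_len, max_len]` self-attention bias that `prepare_train_input` / `prepare_infer_input` then reshape and tile (note how `src_slf_attn_bias[:, :, ::src_max_len, :]` keeps a single query row before tiling it to the target length).

```python
import numpy as np


def toy_pad_batch(insts, pad_idx, n_head):
    """Pad token-id lists to the batch max length and build an attention
    bias that masks padding positions (mirrors the shapes used above)."""
    max_len = max(len(inst) for inst in insts)
    word = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts],
        dtype="int64")                           # [batch, max_len]
    pos = np.array(
        [list(range(len(inst))) + [0] * (max_len - len(inst))
         for inst in insts], dtype="int64")      # [batch, max_len]
    # -1e9 on padding positions, 0 elsewhere, broadcast over heads and queries
    bias = np.where(word == pad_idx, -1e9, 0.).astype("float32")
    bias = np.tile(bias[:, np.newaxis, np.newaxis, :],
                   [1, n_head, max_len, 1])      # [batch, n_head, max_len, max_len]
    return word, pos, bias, max_len


if __name__ == "__main__":
    insts = [[5, 7, 9], [3, 4]]                  # made-up token ids
    word, pos, bias, max_len = toy_pad_batch(insts, pad_idx=0, n_head=8)
    print(word.shape, pos.shape, bias.shape, max_len)  # (2, 3) (2, 3) (2, 8, 3, 3) 3
    # trg_src_attn_bias in prepare_train_input keeps one query row and re-tiles it:
    trg_src = np.tile(bias[:, :, ::max_len, :], [1, 1, 4, 1])
    print(trg_src.shape)                         # (2, 8, 4, 3)
```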
@@ -88,60 +138,206 @@ def pad_batch_data(insts,

```python
    return return_list if len(return_list) > 1 else return_list[0]
```

The old module-level `prepare_train_input` and `prepare_infer_input` that used to follow `pad_batch_data` are removed from this position; they are identical to the definitions now placed above `pad_batch_data` in the first hunk. In their place, `Dataset` and `BatchSampler` implementations are added:

```python
class Seq2SeqDataset(Dataset):
    def __init__(self,
                 src_vocab_fpath,
                 trg_vocab_fpath,
                 fpattern,
                 tar_fname=None,
                 field_delimiter="\t",
                 token_delimiter=" ",
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
                 only_src=False):
        # convert str to bytes, and use byte data
        field_delimiter = field_delimiter.encode("utf8")
        token_delimiter = token_delimiter.encode("utf8")
        start_mark = start_mark.encode("utf8")
        end_mark = end_mark.encode("utf8")
        unk_mark = unk_mark.encode("utf8")
        self._src_vocab = self.load_dict(src_vocab_fpath)
        self._trg_vocab = self.load_dict(trg_vocab_fpath)
        self._bos_idx = self._src_vocab[start_mark]
        self._eos_idx = self._src_vocab[end_mark]
        self._unk_idx = self._src_vocab[unk_mark]
        self._only_src = only_src
        self._field_delimiter = field_delimiter
        self._token_delimiter = token_delimiter
        self.load_src_trg_ids(fpattern, tar_fname)

    def load_src_trg_ids(self, fpattern, tar_fname):
        converters = [
            Converter(
                vocab=self._src_vocab,
                beg=self._bos_idx,
                end=self._eos_idx,
                unk=self._unk_idx,
                delimiter=self._token_delimiter,
                add_beg=False)
        ]
        if not self._only_src:
            converters.append(
                Converter(
                    vocab=self._trg_vocab,
                    beg=self._bos_idx,
                    end=self._eos_idx,
                    unk=self._unk_idx,
                    delimiter=self._token_delimiter,
                    add_beg=True))

        converters = ComposedConverter(converters)

        self._src_seq_ids = []
        self._trg_seq_ids = None if self._only_src else []
        self._sample_infos = []

        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
            src_trg_ids = converters(line)
            self._src_seq_ids.append(src_trg_ids[0])
            lens = [len(src_trg_ids[0])]
            if not self._only_src:
                self._trg_seq_ids.append(src_trg_ids[1])
                lens.append(len(src_trg_ids[1]))
            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))

    def _load_lines(self, fpattern, tar_fname):
        fpaths = glob.glob(fpattern)
        assert len(fpaths) > 0, "no matching file to the provided data path"

        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
            if tar_fname is None:
                raise Exception("If tar file provided, please set tar_fname.")

            f = tarfile.open(fpaths[0], "rb")
            for line in f.extractfile(tar_fname):
                fields = line.strip(b"\n").split(self._field_delimiter)
                if (not self._only_src and len(fields) == 2) or (
                        self._only_src and len(fields) == 1):
                    yield fields
        else:
            for fpath in fpaths:
                if not os.path.isfile(fpath):
                    raise IOError("Invalid file: %s" % fpath)

                with open(fpath, "rb") as f:
                    for line in f:
                        fields = line.strip(b"\n").split(self._field_delimiter)
                        if (not self._only_src and len(fields) == 2) or (
                                self._only_src and len(fields) == 1):
                            yield fields

    @staticmethod
    def load_dict(dict_path, reverse=False):
        word_dict = {}
        with open(dict_path, "rb") as fdict:
            for idx, line in enumerate(fdict):
                if reverse:
                    word_dict[idx] = line.strip(b"\n")
                else:
                    word_dict[line.strip(b"\n")] = idx
        return word_dict

    def get_vocab_summary(self):
        return len(self._src_vocab), len(
            self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx

    def __getitem__(self, idx):
        return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
                ) if not self._only_src else self._src_seq_ids[idx]

    def __len__(self):
        return len(self._sample_infos)


class Seq2SeqBatchSampler(BatchSampler):
    def __init__(self,
                 dataset,
                 batch_size,
                 pool_size,
                 sort_type=SortType.GLOBAL,
                 min_length=0,
                 max_length=100,
                 shuffle=True,
                 shuffle_batch=False,
                 use_token_batch=False,
                 clip_last_batch=False,
                 seed=0):
        for arg, value in locals().items():
            if arg != "self":
                setattr(self, "_" + arg, value)
        self._random = np.random
        self._random.seed(seed)
        # for multi-devices
        self._nranks = ParallelEnv().nranks
        self._local_rank = ParallelEnv().local_rank
        self._device_id = ParallelEnv().dev_id

    def __iter__(self):
        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            infos = sorted(
                self.dataset._sample_infos, key=lambda x: x.max_len)
        else:
            if self._shuffle:
                infos = self.dataset._sample_infos
                self._random.shuffle(infos)
            else:
                infos = self.dataset._sample_infos

            if self._sort_type == SortType.POOL:
                reverse = True
                for i in range(0, len(infos), self._pool_size):
                    # to avoid placing short next to long sentences
                    reverse = not reverse
                    infos[i:i + self._pool_size] = sorted(
                        infos[i:i + self._pool_size],
                        key=lambda x: x.max_len,
                        reverse=reverse)

        batches = []
        batch_creator = TokenBatchCreator(
            self._batch_size
        ) if self._use_token_batch else SentenceBatchCreator(
            self._batch_size * self._nranks)
        batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                     batch_creator)

        for info in infos:
            batch = batch_creator.append(info)
            if batch is not None:
                batches.append(batch)

        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        if not self._use_token_batch:
            # when producing batches according to sequence number, to confirm
            # neighbor batches which would be feed and run parallel have similar
            # length (thus similar computational cost) after shuffle, we take
            # them as a whole when shuffling and split here
            batches = [[
                batch[self._batch_size * i:self._batch_size * (i + 1)]
                for i in range(self._nranks)
            ] for batch in batches]
            batches = itertools.chain.from_iterable(batches)

        # for multi-device
        for batch_id, batch in enumerate(batches):
            if batch_id % self._nranks == self._local_rank:
                batch_indices = [info.i for info in batch]
                yield batch_indices
        if self._local_rank > len(batches) % self._nranks:
            yield batch_indices

    def __len__(self):
        pass

    @property
    def dev_id(self):
        return self._dev_id


class SortType(object):
```
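`Seq2SeqBatchSampler` only yields lists of sample indices; the interesting part is the pool-sort, batch building, and rank interleaving inside `__iter__`. The snippet below is a simplified, framework-free sketch of that flow following the sentence-count path (it omits `TokenBatchCreator` and `MinMaxFilter`); the `SampleInfo` namedtuple, `pool_sort`, `iter_batches`, and the fixed `nranks` / `local_rank` values are illustrative stand-ins, not the repo's helpers.

```python
import random
from collections import namedtuple

# stand-in for the SampleInfo objects stored in Seq2SeqDataset._sample_infos
SampleInfo = namedtuple("SampleInfo", ["i", "max_len", "min_len"])


def pool_sort(infos, pool_size):
    # sort within fixed-size pools, alternating direction so short and long
    # sentences do not sit next to each other across pool borders
    reverse = True
    for start in range(0, len(infos), pool_size):
        reverse = not reverse
        infos[start:start + pool_size] = sorted(
            infos[start:start + pool_size],
            key=lambda x: x.max_len,
            reverse=reverse)
    return infos


def iter_batches(infos, batch_size, pool_size, nranks, local_rank):
    infos = pool_sort(list(infos), pool_size)
    batches = [infos[i:i + batch_size]
               for i in range(0, len(infos), batch_size)]
    # deal batches out round-robin, one per rank, yielding sample indices
    for batch_id, batch in enumerate(batches):
        if batch_id % nranks == local_rank:
            yield [info.i for info in batch]


if __name__ == "__main__":
    random.seed(0)
    infos = [SampleInfo(i, random.randint(5, 40), 5) for i in range(32)]
    for idx_batch in iter_batches(infos, batch_size=4, pool_size=16,
                                  nranks=2, local_rank=0):
        print(idx_batch)
```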
transformer/train.py
@@ -24,6 +24,7 @@ import numpy as np

```python
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from paddle.fluid.io import DataLoader

from utils.configure import PDConfig
from utils.check import check_gpu, check_version
```

The `from paddle.fluid.io import DataLoader` line is the single addition in this hunk.
@@ -38,6 +39,7 @@ from callbacks import ProgBarLogger

```python
class LoggerCallback(ProgBarLogger):
    def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.):
        super(LoggerCallback, self).__init__(log_freq, verbose)
        # TODO: wrap these override function to simplify
        self.loss_normalizer = loss_normalizer

    def on_train_begin(self, logs=None):
```
@@ -60,148 +62,111 @@ class LoggerCallback(ProgBarLogger):

The body of `do_train` is rewritten: the `reader.DataProcessor` generators and the explicit dygraph guard are dropped in favour of a `Seq2SeqDataset`, a `Seq2SeqBatchSampler`, and a `fluid.io.DataLoader`.

Context:

```python
def do_train(args):
```

Removed (the old generator-based setup, run under a dygraph guard):

```python
    init_context('dynamic' if FLAGS.dynamic else 'static')
    trainer_count = 1  #get_nranks()

    @contextlib.contextmanager
    def null_guard():
        yield

    guard = fluid.dygraph.guard() if args.eager_run else null_guard()

    # define the data generator
    processor = reader.DataProcessor(
        fpattern=args.training_file,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size,
        device_count=trainer_count,
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        max_length=args.max_length,
        n_head=args.n_head)
    batch_generator = processor.data_generator(phase="train")
    if trainer_count > 1:
        # for multi-process gpu training
        batch_generator = fluid.contrib.reader.distributed_batch_reader(
            batch_generator)
    if args.validation_file:
        val_processor = reader.DataProcessor(
            fpattern=args.validation_file,
            src_vocab_fpath=args.src_vocab_fpath,
            trg_vocab_fpath=args.trg_vocab_fpath,
            token_delimiter=args.token_delimiter,
            use_token_batch=args.use_token_batch,
            batch_size=args.batch_size,
            device_count=trainer_count,
            pool_size=args.pool_size,
            sort_type=args.sort_type,
            shuffle=False,
            shuffle_batch=False,
            start_mark=args.special_token[0],
            end_mark=args.special_token[1],
            unk_mark=args.special_token[2],
            max_length=args.max_length,
            n_head=args.n_head)
        val_batch_generator = val_processor.data_generator(phase="train")
    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
        args.unk_idx = processor.get_vocab_summary()

    with guard:
        # set seed for CE
        random_seed = eval(str(args.random_seed))
        if random_seed is not None:
            fluid.default_main_program().random_seed = random_seed
            fluid.default_startup_program().random_seed = random_seed

        # define data loader
        train_loader = batch_generator
        if args.validation_file:
            val_loader = val_batch_generator

        # define model
        inputs = [
            Input([None, None], "int64", name="src_word"),
            Input([None, None], "int64", name="src_pos"),
            Input([None, args.n_head, None, None], "float32",
                  name="src_slf_attn_bias"),
            Input([None, None], "int64", name="trg_word"),
            Input([None, None], "int64", name="trg_pos"),
            Input([None, args.n_head, None, None], "float32",
                  name="trg_slf_attn_bias"),
            Input([None, args.n_head, None, None], "float32",
                  name="trg_src_attn_bias"),
        ]
        labels = [
            Input([None, 1], "int64", name="label"),
            Input([None, 1], "float32", name="weight"),
        ]

        transformer = Transformer(
            args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
            args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
            args.d_inner_hid, args.prepostprocess_dropout,
            args.attention_dropout, args.relu_dropout, args.preprocess_cmd,
            args.postprocess_cmd, args.weight_sharing, args.bos_idx,
            args.eos_idx)

        transformer.prepare(
            fluid.optimizer.Adam(
                learning_rate=fluid.layers.noam_decay(args.d_model,
                                                      args.warmup_steps),
                # args.learning_rate),
                beta1=args.beta1,
                beta2=args.beta2,
                epsilon=float(args.eps),
                parameter_list=transformer.parameters()),
            CrossEntropyCriterion(args.label_smooth_eps),
            inputs=inputs,
            labels=labels)

        ## init from some checkpoint, to resume the previous training
        if args.init_from_checkpoint:
            transformer.load(
                os.path.join(args.init_from_checkpoint, "transformer"))
        ## init from some pretrain models, to better solve the current task
        if args.init_from_pretrain_model:
            transformer.load(
                os.path.join(args.init_from_pretrain_model, "transformer"),
                reset_optimizer=True)

        # the best cross-entropy value with label smoothing
        loss_normalizer = -(
            (1. - args.label_smooth_eps) * np.log(
                (1. - args.label_smooth_eps)) + args.label_smooth_eps *
            np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))

        transformer.fit(train_loader=train_loader,
                        eval_loader=val_loader,
                        epochs=1,
                        eval_freq=1,
                        save_freq=1,
                        verbose=2,
                        callbacks=[
                            LoggerCallback(
                                log_freq=args.print_step,
                                loss_normalizer=loss_normalizer)
                        ])
```

Added (the new dataset/sampler/loader setup, no guard):

```python
    # init_context('dynamic' if FLAGS.dynamic else 'static')

    # set seed for CE
    random_seed = eval(str(args.random_seed))
    if random_seed is not None:
        fluid.default_main_program().random_seed = random_seed
        fluid.default_startup_program().random_seed = random_seed

    # define model
    inputs = [
        Input([None, None], "int64", name="src_word"),
        Input([None, None], "int64", name="src_pos"),
        Input([None, args.n_head, None, None], "float32",
              name="src_slf_attn_bias"),
        Input([None, None], "int64", name="trg_word"),
        Input([None, None], "int64", name="trg_pos"),
        Input([None, args.n_head, None, None], "float32",
              name="trg_slf_attn_bias"),
        Input([None, args.n_head, None, None], "float32",
              name="trg_src_attn_bias")
    ]
    labels = [
        Input([None, 1], "int64", name="label"),
        Input([None, 1], "float32", name="weight"),
    ]

    dataset = reader.Seq2SeqDataset(
        fpattern=args.training_file,
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        token_delimiter=args.token_delimiter,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2])
    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
        args.unk_idx = dataset.get_vocab_summary()
    batch_sampler = reader.Seq2SeqBatchSampler(
        dataset=dataset,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size,
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        shuffle=args.shuffle,
        shuffle_batch=args.shuffle_batch,
        max_length=args.max_length)
    train_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=None,
        feed_list=[x.forward() for x in inputs + labels],
        num_workers=0,
        return_list=True)

    transformer = Transformer(
        args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,
        args.n_layer, args.n_head, args.d_key, args.d_value, args.d_model,
        args.d_inner_hid, args.prepostprocess_dropout, args.attention_dropout,
        args.relu_dropout, args.preprocess_cmd, args.postprocess_cmd,
        args.weight_sharing, args.bos_idx, args.eos_idx)

    transformer.prepare(
        fluid.optimizer.Adam(
            learning_rate=fluid.layers.noam_decay(args.d_model,
                                                  args.warmup_steps),
            # args.learning_rate),
            beta1=args.beta1,
            beta2=args.beta2,
            epsilon=float(args.eps),
            parameter_list=transformer.parameters()),
        CrossEntropyCriterion(args.label_smooth_eps),
        inputs=inputs,
        labels=labels)

    ## init from some checkpoint, to resume the previous training
    if args.init_from_checkpoint:
        transformer.load(
            os.path.join(args.init_from_checkpoint, "transformer"))
    ## init from some pretrain models, to better solve the current task
    if args.init_from_pretrain_model:
        transformer.load(
            os.path.join(args.init_from_pretrain_model, "transformer"),
            reset_optimizer=True)

    # the best cross-entropy value with label smoothing
    loss_normalizer = -(
        (1. - args.label_smooth_eps) * np.log(
            (1. - args.label_smooth_eps)) + args.label_smooth_eps *
        np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))

    transformer.fit(train_loader=train_loader,
                    eval_loader=None,
                    epochs=1,
                    eval_freq=1,
                    save_freq=1,
                    verbose=2,
                    callbacks=[
                        LoggerCallback(
                            log_freq=args.print_step,
                            loss_normalizer=loss_normalizer)
                    ])
```

Trailing context:

```python
if __name__ == "__main__":
```
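Both versions of `do_train` compute the same `loss_normalizer`, described in the code comment as "the best cross-entropy value with label smoothing": the lowest loss a model can reach once the targets are smoothed, which is a useful reference point when reading the logged loss. A stand-alone check of that formula follows; `smoothed_ce_floor` is a hypothetical helper name, and the `label_smooth_eps` and vocabulary size below are example values, not the repo's defaults.

```python
import numpy as np


def smoothed_ce_floor(label_smooth_eps, trg_vocab_size):
    """Entropy of the label-smoothed target distribution, i.e. the lowest
    cross-entropy reachable under label smoothing; this matches the
    loss_normalizer expression in do_train above."""
    return -((1. - label_smooth_eps) * np.log(1. - label_smooth_eps) +
             label_smooth_eps *
             np.log(label_smooth_eps / (trg_vocab_size - 1) + 1e-20))


# example values only; in train.py they come from args
print(smoothed_ce_floor(0.1, 37007))  # roughly 1.38
```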