Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
e6436070
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e6436070
编写于
1月 17, 2020
作者:
G
Guo Sheng
提交者:
pkpk
1月 17, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update dygraph Transformer to be consistent with its graph counterpart. (#4206)
上级
801edd0d
变更
12
展开全部
隐藏空白更改
内联
并排
Showing
12 changed file
with
2081 addition
and
1002 deletion
+2081
-1002
dygraph/transformer/README.md
dygraph/transformer/README.md
+214
-122
dygraph/transformer/data_util.py
dygraph/transformer/data_util.py
+0
-75
dygraph/transformer/images/multi_head_attention.png
dygraph/transformer/images/multi_head_attention.png
+0
-0
dygraph/transformer/images/transformer_network.png
dygraph/transformer/images/transformer_network.png
+0
-0
dygraph/transformer/model.py
dygraph/transformer/model.py
+525
-530
dygraph/transformer/predict.py
dygraph/transformer/predict.py
+96
-108
dygraph/transformer/reader.py
dygraph/transformer/reader.py
+550
-0
dygraph/transformer/train.py
dygraph/transformer/train.py
+177
-167
dygraph/transformer/transformer.yaml
dygraph/transformer/transformer.yaml
+108
-0
dygraph/transformer/utils/__init__.py
dygraph/transformer/utils/__init__.py
+0
-0
dygraph/transformer/utils/check.py
dygraph/transformer/utils/check.py
+61
-0
dygraph/transformer/utils/configure.py
dygraph/transformer/utils/configure.py
+350
-0
未找到文件。
dygraph/transformer/README.md
浏览文件 @
e6436070
此差异已折叠。
点击以展开。
dygraph/transformer/data_util.py
已删除
100644 → 0
浏览文件 @
801edd0d
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
numpy
as
np
from
paddle.fluid.dygraph
import
to_variable
def
pad_batch_data
(
insts
,
pad_idx
,
n_head
,
is_target
=
False
,
is_label
=
False
,
return_attn_bias
=
True
,
return_max_len
=
True
,
return_num_token
=
False
):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list
=
[]
max_len
=
max
(
len
(
inst
)
for
inst
in
insts
)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data
=
np
.
array
(
[
inst
+
[
pad_idx
]
*
(
max_len
-
len
(
inst
))
for
inst
in
insts
])
return_list
+=
[
inst_data
.
astype
(
"int64"
).
reshape
([
-
1
,
1
])]
if
is_label
:
# label weight
inst_weight
=
np
.
array
([[
1.
]
*
len
(
inst
)
+
[
0.
]
*
(
max_len
-
len
(
inst
))
for
inst
in
insts
])
return_list
+=
[
inst_weight
.
astype
(
"float32"
).
reshape
([
-
1
,
1
])]
else
:
# position data
inst_pos
=
np
.
array
([
list
(
range
(
0
,
len
(
inst
)))
+
[
0
]
*
(
max_len
-
len
(
inst
))
for
inst
in
insts
])
return_list
+=
[
inst_pos
.
astype
(
"int64"
).
reshape
([
-
1
,
1
])]
if
return_attn_bias
:
if
is_target
:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data
=
np
.
ones
((
inst_data
.
shape
[
0
],
max_len
,
max_len
))
slf_attn_bias_data
=
np
.
triu
(
slf_attn_bias_data
,
1
).
reshape
([
-
1
,
1
,
max_len
,
max_len
])
slf_attn_bias_data
=
np
.
tile
(
slf_attn_bias_data
,
[
1
,
n_head
,
1
,
1
])
*
[
-
1e9
]
else
:
# This is used to avoid attention on paddings.
slf_attn_bias_data
=
np
.
array
([[
0
]
*
len
(
inst
)
+
[
-
1e9
]
*
(
max_len
-
len
(
inst
))
for
inst
in
insts
])
slf_attn_bias_data
=
np
.
tile
(
slf_attn_bias_data
.
reshape
([
-
1
,
1
,
1
,
max_len
]),
[
1
,
n_head
,
max_len
,
1
])
return_list
+=
[
slf_attn_bias_data
.
astype
(
"float32"
)]
if
return_max_len
:
return_list
+=
[
max_len
]
if
return_num_token
:
num_token
=
0
for
inst
in
insts
:
num_token
+=
len
(
inst
)
return_list
+=
[
num_token
]
return
return_list
if
len
(
return_list
)
>
1
else
return_list
[
0
]
dygraph/transformer/images/multi_head_attention.png
0 → 100644
浏览文件 @
e6436070
104.5 KB
dygraph/transformer/images/transformer_network.png
0 → 100644
浏览文件 @
e6436070
259.1 KB
dygraph/transformer/model.py
浏览文件 @
e6436070
此差异已折叠。
点击以展开。
dygraph/transformer/predict.py
浏览文件 @
e6436070
...
@@ -12,72 +12,22 @@
...
@@ -12,72 +12,22 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
__future__
import
print_function
import
logging
import
argparse
import
os
import
ast
import
six
import
sys
import
time
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.dataset.wmt16
as
wmt16
from
model
import
TransFormer
from
config
import
*
from
data_util
import
*
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
"Arguments for Inference"
)
parser
.
add_argument
(
"--use_data_parallel"
,
type
=
ast
.
literal_eval
,
default
=
False
,
help
=
"The flag indicating whether to shuffle instances in each pass."
)
parser
.
add_argument
(
"--model_file"
,
type
=
str
,
default
=
"transformer_params"
,
help
=
"Load model from the file named `model_file.pdparams`."
)
parser
.
add_argument
(
"--output_file"
,
type
=
str
,
default
=
"predict.txt"
,
help
=
"The file to output the translation results of predict_file to."
)
parser
.
add_argument
(
'opts'
,
help
=
'See config.py for all options'
,
default
=
None
,
nargs
=
argparse
.
REMAINDER
)
args
=
parser
.
parse_args
()
merge_cfg_from_list
(
args
.
opts
,
[
InferTaskConfig
,
ModelHyperParams
])
return
args
def
prepare_infer_input
(
insts
,
src_pad_idx
,
bos_idx
,
n_head
):
"""
inputs for inferencs
"""
src_word
,
src_pos
,
src_slf_attn_bias
,
src_max_len
=
pad_batch_data
(
[
inst
[
0
]
for
inst
in
insts
],
src_pad_idx
,
n_head
,
is_target
=
False
)
# start tokens
trg_word
=
np
.
asarray
([[
bos_idx
]]
*
len
(
insts
),
dtype
=
"int64"
)
trg_src_attn_bias
=
np
.
tile
(
src_slf_attn_bias
[:,
:,
::
src_max_len
,
:],
[
1
,
1
,
1
,
1
]).
astype
(
"float32"
)
trg_word
=
trg_word
.
reshape
(
-
1
,
1
,
1
)
src_word
=
src_word
.
reshape
(
-
1
,
src_max_len
,
1
)
src_pos
=
src_pos
.
reshape
(
-
1
,
src_max_len
,
1
)
data_inputs
=
[
src_word
,
src_pos
,
src_slf_attn_bias
,
trg_word
,
trg_src_attn_bias
]
var_inputs
=
[]
from
utils.configure
import
PDConfig
for
i
,
field
in
enumerate
(
encoder_data_input_fields
+
from
utils.check
import
check_gpu
,
check_version
fast_decoder_data_input_fields
):
var_inputs
.
append
(
to_variable
(
data_inputs
[
i
],
name
=
field
))
enc_inputs
=
var_inputs
[
0
:
len
(
encoder_data_input_fields
)]
# include task-specific libs
dec_inputs
=
var_inputs
[
len
(
encoder_data_input_fields
):]
import
reader
return
enc_inputs
,
dec_inputs
from
model
import
Transformer
,
position_encoding_init
def
post_process_seq
(
seq
,
bos_idx
,
eos_idx
,
output_bos
=
False
,
output_eos
=
False
):
def
post_process_seq
(
seq
,
bos_idx
,
eos_idx
,
output_bos
=
False
,
output_eos
=
False
):
...
@@ -96,60 +46,98 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
...
@@ -96,60 +46,98 @@ def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
return
seq
return
seq
def
infer
(
args
):
def
do_predict
(
args
):
place
=
fluid
.
CUDAPlace
(
fluid
.
dygraph
.
parallel
.
Env
().
dev_id
)
\
if
args
.
use_cuda
:
if
args
.
use_data_parallel
else
fluid
.
CUDAPlace
(
0
)
place
=
fluid
.
CUDAPlace
(
0
)
else
:
place
=
fluid
.
CPUPlace
()
# define the data generator
processor
=
reader
.
DataProcessor
(
fpattern
=
args
.
predict_file
,
src_vocab_fpath
=
args
.
src_vocab_fpath
,
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
token_delimiter
=
args
.
token_delimiter
,
use_token_batch
=
False
,
batch_size
=
args
.
batch_size
,
device_count
=
1
,
pool_size
=
args
.
pool_size
,
sort_type
=
reader
.
SortType
.
NONE
,
shuffle
=
False
,
shuffle_batch
=
False
,
start_mark
=
args
.
special_token
[
0
],
end_mark
=
args
.
special_token
[
1
],
unk_mark
=
args
.
special_token
[
2
],
max_length
=
args
.
max_length
,
n_head
=
args
.
n_head
)
batch_generator
=
processor
.
data_generator
(
phase
=
"predict"
,
place
=
place
)
args
.
src_vocab_size
,
args
.
trg_vocab_size
,
args
.
bos_idx
,
args
.
eos_idx
,
\
args
.
unk_idx
=
processor
.
get_vocab_summary
()
trg_idx2word
=
reader
.
DataProcessor
.
load_dict
(
dict_path
=
args
.
trg_vocab_fpath
,
reverse
=
True
)
args
.
src_vocab_size
,
args
.
trg_vocab_size
,
args
.
bos_idx
,
args
.
eos_idx
,
\
args
.
unk_idx
=
processor
.
get_vocab_summary
()
with
fluid
.
dygraph
.
guard
(
place
):
with
fluid
.
dygraph
.
guard
(
place
):
transformer
=
TransFormer
(
# define data loader
ModelHyperParams
.
src_vocab_size
,
test_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
capacity
=
10
)
ModelHyperParams
.
trg_vocab_size
,
ModelHyperParams
.
max_length
+
1
,
test_loader
.
set_batch_generator
(
batch_generator
,
places
=
place
)
ModelHyperParams
.
n_layer
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_key
,
ModelHyperParams
.
d_value
,
# define model
ModelHyperParams
.
d_model
,
ModelHyperParams
.
d_inner_hid
,
transformer
=
Transformer
(
ModelHyperParams
.
prepostprocess_dropout
,
args
.
src_vocab_size
,
args
.
trg_vocab_size
,
args
.
max_length
+
1
,
ModelHyperParams
.
attention_dropout
,
ModelHyperParams
.
relu_dropout
,
args
.
n_layer
,
args
.
n_head
,
args
.
d_key
,
args
.
d_value
,
args
.
d_model
,
ModelHyperParams
.
preprocess_cmd
,
ModelHyperParams
.
postprocess_cmd
,
args
.
d_inner_hid
,
args
.
prepostprocess_dropout
,
ModelHyperParams
.
weight_sharing
)
args
.
attention_dropout
,
args
.
relu_dropout
,
args
.
preprocess_cmd
,
# load checkpoint
args
.
postprocess_cmd
,
args
.
weight_sharing
,
args
.
bos_idx
,
model_dict
,
_
=
fluid
.
load_dygraph
(
args
.
model_file
)
args
.
eos_idx
)
# load the trained model
assert
args
.
init_from_params
,
(
"Please set init_from_params to load the infer model."
)
model_dict
,
_
=
fluid
.
load_dygraph
(
os
.
path
.
join
(
args
.
init_from_params
,
"transformer"
))
# to avoid a longer length than training, reset the size of position
# encoding to max_length
model_dict
[
"encoder.pos_encoder.weight"
]
=
position_encoding_init
(
args
.
max_length
+
1
,
args
.
d_model
)
model_dict
[
"decoder.pos_encoder.weight"
]
=
position_encoding_init
(
args
.
max_length
+
1
,
args
.
d_model
)
transformer
.
load_dict
(
model_dict
)
transformer
.
load_dict
(
model_dict
)
print
(
"checkpoint loaded"
)
# start evaluate mode
transformer
.
eval
()
reader
=
paddle
.
batch
(
wmt16
.
test
(
ModelHyperParams
.
src_vocab_size
,
# set evaluate mode
ModelHyperParams
.
trg_vocab_size
),
transformer
.
eval
()
batch_size
=
InferTaskConfig
.
batch_size
)
id2word
=
wmt16
.
get_dict
(
"de"
,
ModelHyperParams
.
trg_vocab_size
,
reverse
=
True
)
f
=
open
(
args
.
output_file
,
"wb"
)
f
=
open
(
args
.
output_file
,
"wb"
)
for
batch
in
reader
():
for
input_data
in
test_loader
():
enc_inputs
,
dec_inputs
=
prepare_infer_input
(
(
src_word
,
src_pos
,
src_slf_attn_bias
,
trg_word
,
batch
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
bos_idx
,
trg_src_attn_bias
)
=
input_data
ModelHyperParams
.
n_head
)
print
(
"enc inputs"
,
enc_inputs
[
0
].
shape
)
finished_seq
,
finished_scores
=
transformer
.
beam_search
(
finished_seq
,
finished_scores
=
transformer
.
beam_search
(
enc_inputs
,
src_word
,
dec_inputs
,
src_pos
,
bos_id
=
ModelHyperParams
.
bos_idx
,
src_slf_attn_bias
,
eos_id
=
ModelHyperParams
.
eos_idx
,
trg_word
,
max_len
=
InferTaskConfig
.
max_out_len
,
trg_src_attn_bias
,
alpha
=
InferTaskConfig
.
alpha
)
bos_id
=
args
.
bos_idx
,
eos_id
=
args
.
eos_idx
,
beam_size
=
args
.
beam_size
,
max_len
=
args
.
max_out_len
)
finished_seq
=
finished_seq
.
numpy
()
finished_seq
=
finished_seq
.
numpy
()
finished_scores
=
finished_scores
.
numpy
()
finished_scores
=
finished_scores
.
numpy
()
for
ins
in
finished_seq
:
for
ins
in
finished_seq
:
for
beam
in
ins
:
for
beam_idx
,
beam
in
enumerate
(
ins
):
id_list
=
post_process_seq
(
beam
,
ModelHyperParams
.
bos_idx
,
if
beam_idx
>=
args
.
n_best
:
break
ModelHyperParams
.
eos_idx
)
id_list
=
post_process_seq
(
beam
,
args
.
bos_idx
,
args
.
eos_idx
)
word_list
=
[
id2word
[
id
]
for
id
in
id_list
]
word_list
=
[
trg_idx2word
[
id
]
for
id
in
id_list
]
sequence
=
" "
.
join
(
word_list
)
+
"
\n
"
sequence
=
b
" "
.
join
(
word_list
)
+
b
"
\n
"
f
.
write
(
sequence
.
encode
(
"utf8"
))
f
.
write
(
sequence
)
break
# only print the best
if
__name__
==
"__main__"
:
if
__name__
==
'__main__'
:
args
=
PDConfig
(
yaml_file
=
"./transformer.yaml"
)
args
=
parse_args
()
args
.
build
()
infer
(
args
)
args
.
Print
()
check_gpu
(
args
.
use_cuda
)
check_version
()
do_predict
(
args
)
dygraph/transformer/reader.py
0 → 100644
浏览文件 @
e6436070
此差异已折叠。
点击以展开。
dygraph/transformer/train.py
浏览文件 @
e6436070
...
@@ -12,185 +12,195 @@
...
@@ -12,185 +12,195 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
__future__
import
print_function
import
logging
import
argparse
import
os
import
ast
import
six
import
sys
import
time
import
numpy
as
np
import
numpy
as
np
import
paddle
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
import
paddle.dataset.wmt16
as
wmt16
from
utils.configure
import
PDConfig
from
model
import
TransFormer
,
NoamDecay
from
utils.check
import
check_gpu
,
check_version
from
config
import
*
from
data_util
import
*
# include task-specific libs
import
reader
from
model
import
Transformer
,
CrossEntropyCriterion
,
NoamDecay
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
"Arguments for Training"
)
parser
.
add_argument
(
def
do_train
(
args
):
"--use_data_parallel"
,
if
args
.
use_cuda
:
type
=
ast
.
literal_eval
,
trainer_count
=
fluid
.
dygraph
.
parallel
.
Env
().
nranks
default
=
False
,
place
=
fluid
.
CUDAPlace
(
fluid
.
dygraph
.
parallel
.
Env
().
dev_id
help
=
"The flag indicating whether to use multi-GPU."
)
)
if
trainer_count
>
1
else
fluid
.
CUDAPlace
(
0
)
parser
.
add_argument
(
else
:
"--model_file"
,
trainer_count
=
1
type
=
str
,
place
=
fluid
.
CPUPlace
()
default
=
"transformer_params"
,
help
=
"Save the model as a file named `model_file.pdparams`."
)
# define the data generator
parser
.
add_argument
(
processor
=
reader
.
DataProcessor
(
fpattern
=
args
.
training_file
,
'opts'
,
src_vocab_fpath
=
args
.
src_vocab_fpath
,
help
=
'See config.py for all options'
,
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
default
=
None
,
token_delimiter
=
args
.
token_delimiter
,
nargs
=
argparse
.
REMAINDER
)
use_token_batch
=
args
.
use_token_batch
,
args
=
parser
.
parse_args
()
batch_size
=
args
.
batch_size
,
merge_cfg_from_list
(
args
.
opts
,
[
TrainTaskConfig
,
ModelHyperParams
])
device_count
=
trainer_count
,
return
args
pool_size
=
args
.
pool_size
,
sort_type
=
args
.
sort_type
,
shuffle
=
args
.
shuffle
,
def
prepare_train_input
(
insts
,
src_pad_idx
,
trg_pad_idx
,
n_head
):
shuffle_batch
=
args
.
shuffle_batch
,
"""
start_mark
=
args
.
special_token
[
0
],
inputs for training
end_mark
=
args
.
special_token
[
1
],
"""
unk_mark
=
args
.
special_token
[
2
],
src_word
,
src_pos
,
src_slf_attn_bias
,
src_max_len
=
pad_batch_data
(
max_length
=
args
.
max_length
,
[
inst
[
0
]
for
inst
in
insts
],
src_pad_idx
,
n_head
,
is_target
=
False
)
n_head
=
args
.
n_head
)
src_word
=
src_word
.
reshape
(
-
1
,
src_max_len
)
batch_generator
=
processor
.
data_generator
(
phase
=
"train"
)
src_pos
=
src_pos
.
reshape
(
-
1
,
src_max_len
)
if
trainer_count
>
1
:
# for multi-process gpu training
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_max_len
=
pad_batch_data
(
batch_generator
=
fluid
.
contrib
.
reader
.
distributed_batch_reader
(
[
inst
[
1
]
for
inst
in
insts
],
trg_pad_idx
,
n_head
,
is_target
=
True
)
batch_generator
)
trg_word
=
trg_word
.
reshape
(
-
1
,
trg_max_len
)
args
.
src_vocab_size
,
args
.
trg_vocab_size
,
args
.
bos_idx
,
args
.
eos_idx
,
\
trg_pos
=
trg_pos
.
reshape
(
-
1
,
trg_max_len
)
args
.
unk_idx
=
processor
.
get_vocab_summary
()
trg_src_attn_bias
=
np
.
tile
(
src_slf_attn_bias
[:,
:,
::
src_max_len
,
:],
[
1
,
1
,
trg_max_len
,
1
]).
astype
(
"float32"
)
lbl_word
,
lbl_weight
,
num_token
=
pad_batch_data
(
[
inst
[
2
]
for
inst
in
insts
],
trg_pad_idx
,
n_head
,
is_target
=
False
,
is_label
=
True
,
return_attn_bias
=
False
,
return_max_len
=
False
,
return_num_token
=
True
)
data_inputs
=
[
src_word
,
src_pos
,
src_slf_attn_bias
,
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
lbl_word
,
lbl_weight
]
var_inputs
=
[]
for
i
,
field
in
enumerate
(
encoder_data_input_fields
+
decoder_data_input_fields
[:
-
1
]
+
label_data_input_fields
):
var_inputs
.
append
(
to_variable
(
data_inputs
[
i
],
name
=
field
))
enc_inputs
=
var_inputs
[
0
:
len
(
encoder_data_input_fields
)]
dec_inputs
=
var_inputs
[
len
(
encoder_data_input_fields
):
len
(
encoder_data_input_fields
)
+
len
(
decoder_data_input_fields
[:
-
1
])]
label
=
var_inputs
[
-
2
]
weights
=
var_inputs
[
-
1
]
return
enc_inputs
,
dec_inputs
,
label
,
weights
def
train
(
args
):
"""
train models
:return:
"""
trainer_count
=
fluid
.
dygraph
.
parallel
.
Env
().
nranks
place
=
fluid
.
CUDAPlace
(
fluid
.
dygraph
.
parallel
.
Env
().
dev_id
)
\
if
args
.
use_data_parallel
else
fluid
.
CUDAPlace
(
0
)
with
fluid
.
dygraph
.
guard
(
place
):
with
fluid
.
dygraph
.
guard
(
place
):
if
args
.
use_data_parallel
:
# set seed for CE
strategy
=
fluid
.
dygraph
.
parallel
.
prepare_context
()
random_seed
=
eval
(
str
(
args
.
random_seed
))
if
random_seed
is
not
None
:
fluid
.
default_main_program
().
random_seed
=
random_seed
fluid
.
default_startup_program
().
random_seed
=
random_seed
# define data loader
train_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
capacity
=
10
)
train_loader
.
set_batch_generator
(
batch_generator
,
places
=
place
)
# define model
# define model
transformer
=
TransFormer
(
transformer
=
Transformer
(
ModelHyperParams
.
src_vocab_size
,
args
.
src_vocab_size
,
args
.
trg_vocab_size
,
args
.
max_length
+
1
,
ModelHyperParams
.
trg_vocab_size
,
ModelHyperParams
.
max_length
+
1
,
args
.
n_layer
,
args
.
n_head
,
args
.
d_key
,
args
.
d_value
,
args
.
d_model
,
ModelHyperParams
.
n_layer
,
ModelHyperParams
.
n_head
,
args
.
d_inner_hid
,
args
.
prepostprocess_dropout
,
ModelHyperParams
.
d_key
,
ModelHyperParams
.
d_value
,
args
.
attention_dropout
,
args
.
relu_dropout
,
args
.
preprocess_cmd
,
ModelHyperParams
.
d_model
,
ModelHyperParams
.
d_inner_hid
,
args
.
postprocess_cmd
,
args
.
weight_sharing
,
args
.
bos_idx
,
ModelHyperParams
.
prepostprocess_dropout
,
args
.
eos_idx
)
ModelHyperParams
.
attention_dropout
,
ModelHyperParams
.
relu_dropout
,
ModelHyperParams
.
preprocess_cmd
,
ModelHyperParams
.
postprocess_cmd
,
# define loss
ModelHyperParams
.
weight_sharing
,
TrainTaskConfig
.
label_smooth_eps
)
criterion
=
CrossEntropyCriterion
(
args
.
label_smooth_eps
)
# define optimizer
# define optimizer
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
NoamDecay
(
optimizer
=
fluid
.
optimizer
.
Adam
(
ModelHyperParams
.
d_model
,
TrainTaskConfig
.
warmup_steps
,
learning_rate
=
NoamDecay
(
args
.
d_model
,
args
.
warmup_steps
,
TrainTaskConfig
.
learning_rate
),
args
.
learning_rate
),
parameter_list
=
transformer
.
parameters
(),
beta1
=
args
.
beta1
,
beta1
=
TrainTaskConfig
.
beta1
,
beta2
=
args
.
beta2
,
beta2
=
TrainTaskConfig
.
beta2
,
epsilon
=
float
(
args
.
eps
),
epsilon
=
TrainTaskConfig
.
eps
)
parameter_list
=
transformer
.
parameters
())
#
if
args
.
use_data_parallel
:
## init from some checkpoint, to resume the previous training
if
args
.
init_from_checkpoint
:
model_dict
,
opt_dict
=
fluid
.
load_dygraph
(
os
.
path
.
join
(
args
.
init_from_checkpoint
,
"transformer"
))
transformer
.
load_dict
(
model_dict
)
optimizer
.
set_dict
(
opt_dict
)
## init from some pretrain models, to better solve the current task
if
args
.
init_from_pretrain_model
:
model_dict
,
_
=
fluid
.
load_dygraph
(
os
.
path
.
join
(
args
.
init_from_pretrain_model
,
"transformer"
))
transformer
.
load_dict
(
model_dict
)
if
trainer_count
>
1
:
strategy
=
fluid
.
dygraph
.
parallel
.
prepare_context
()
transformer
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
transformer
=
fluid
.
dygraph
.
parallel
.
DataParallel
(
transformer
,
strategy
)
transformer
,
strategy
)
# define data generator for training and validation
# the best cross-entropy value with label smoothing
train_reader
=
paddle
.
batch
(
wmt16
.
train
(
loss_normalizer
=
-
(
ModelHyperParams
.
src_vocab_size
,
ModelHyperParams
.
trg_vocab_size
),
(
1.
-
args
.
label_smooth_eps
)
*
np
.
log
(
batch_size
=
TrainTaskConfig
.
batch_size
)
(
1.
-
args
.
label_smooth_eps
))
+
if
args
.
use_data_parallel
:
args
.
label_smooth_eps
*
np
.
log
(
args
.
label_smooth_eps
/
train_reader
=
fluid
.
contrib
.
reader
.
distributed_batch_reader
(
(
args
.
trg_vocab_size
-
1
)
+
1e-20
))
train_reader
)
val_reader
=
paddle
.
batch
(
wmt16
.
test
(
ModelHyperParams
.
src_vocab_size
,
step_idx
=
0
ModelHyperParams
.
trg_vocab_size
),
# train loop
batch_size
=
TrainTaskConfig
.
batch_size
)
for
pass_id
in
range
(
args
.
epoch
):
pass_start_time
=
time
.
time
()
# loop for training iterations
batch_id
=
0
for
i
in
range
(
TrainTaskConfig
.
pass_num
):
for
input_data
in
train_loader
():
dy_step
=
0
(
src_word
,
src_pos
,
src_slf_attn_bias
,
trg_word
,
trg_pos
,
sum_cost
=
0
trg_slf_attn_bias
,
trg_src_attn_bias
,
lbl_word
,
transformer
.
train
()
lbl_weight
)
=
input_data
for
batch
in
train_reader
():
logits
=
transformer
(
src_word
,
src_pos
,
src_slf_attn_bias
,
enc_inputs
,
dec_inputs
,
label
,
weights
=
prepare_train_input
(
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
batch
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
eos_idx
,
trg_src_attn_bias
)
ModelHyperParams
.
n_head
)
sum_cost
,
avg_cost
,
token_num
=
criterion
(
dy_sum_cost
,
dy_avg_cost
,
dy_predict
,
dy_token_num
=
transformer
(
logits
,
lbl_word
,
lbl_weight
)
enc_inputs
,
dec_inputs
,
label
,
weights
)
if
trainer_count
>
1
:
if
args
.
use_data_parallel
:
avg_cost
=
transformer
.
scale_loss
(
avg_cost
)
dy_avg_cost
=
transformer
.
scale_loss
(
dy_avg_cost
)
avg_cost
.
backward
()
dy_avg_cost
.
backward
()
transformer
.
apply_collective_grads
()
transformer
.
apply_collective_grads
()
else
:
else
:
dy_avg_cost
.
backward
()
avg_cost
.
backward
()
optimizer
.
minimize
(
dy_avg_cost
)
optimizer
.
minimize
(
avg_cost
)
transformer
.
clear_gradients
()
transformer
.
clear_gradients
()
dy_step
=
dy_step
+
1
if
step_idx
%
args
.
print_step
==
0
:
if
dy_step
%
10
==
0
:
total_avg_cost
=
avg_cost
.
numpy
()
*
trainer_count
print
(
"pass num : {}, batch_id: {}, dy_graph avg loss: {}"
.
format
(
i
,
dy_step
,
if
step_idx
==
0
:
dy_avg_cost
.
numpy
()
*
trainer_count
))
logging
.
info
(
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
# switch to evaluation mode
"normalized loss: %f, ppl: %f"
%
transformer
.
eval
()
(
step_idx
,
pass_id
,
batch_id
,
total_avg_cost
,
sum_cost
=
0
total_avg_cost
-
loss_normalizer
,
token_num
=
0
np
.
exp
([
min
(
total_avg_cost
,
100
)])))
for
batch
in
val_reader
():
avg_batch_time
=
time
.
time
()
enc_inputs
,
dec_inputs
,
label
,
weights
=
prepare_train_input
(
else
:
batch
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
eos_idx
,
logging
.
info
(
ModelHyperParams
.
n_head
)
"step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s"
%
dy_sum_cost
,
dy_avg_cost
,
dy_predict
,
dy_token_num
=
transformer
(
(
step_idx
,
pass_id
,
batch_id
,
total_avg_cost
,
enc_inputs
,
dec_inputs
,
label
,
weights
)
total_avg_cost
-
loss_normalizer
,
sum_cost
+=
dy_sum_cost
.
numpy
()
np
.
exp
([
min
(
total_avg_cost
,
100
)]),
token_num
+=
dy_token_num
.
numpy
()
args
.
print_step
/
(
time
.
time
()
-
avg_batch_time
)))
print
(
"pass : {} finished, validation avg loss: {}"
.
format
(
avg_batch_time
=
time
.
time
()
i
,
sum_cost
/
token_num
))
if
step_idx
%
args
.
save_step
==
0
and
step_idx
!=
0
and
(
if
fluid
.
dygraph
.
parallel
.
Env
().
dev_id
==
0
:
trainer_count
==
1
fluid
.
save_dygraph
(
transformer
.
state_dict
(),
args
.
model_file
)
or
fluid
.
dygraph
.
parallel
.
Env
().
dev_id
==
0
):
if
args
.
save_model
:
model_dir
=
os
.
path
.
join
(
args
.
save_model
,
if
__name__
==
'__main__'
:
"step_"
+
str
(
step_idx
))
args
=
parse_args
()
if
not
os
.
path
.
exists
(
model_dir
):
train
(
args
)
os
.
makedirs
(
model_dir
)
fluid
.
save_dygraph
(
transformer
.
state_dict
(),
os
.
path
.
join
(
model_dir
,
"transformer"
))
fluid
.
save_dygraph
(
optimizer
.
state_dict
(),
os
.
path
.
join
(
model_dir
,
"transformer"
))
batch_id
+=
1
step_idx
+=
1
time_consumed
=
time
.
time
()
-
pass_start_time
if
args
.
save_model
:
model_dir
=
os
.
path
.
join
(
args
.
save_model
,
"step_final"
)
if
not
os
.
path
.
exists
(
model_dir
):
os
.
makedirs
(
model_dir
)
fluid
.
save_dygraph
(
transformer
.
state_dict
(),
os
.
path
.
join
(
model_dir
,
"transformer"
))
fluid
.
save_dygraph
(
optimizer
.
state_dict
(),
os
.
path
.
join
(
model_dir
,
"transformer"
))
if
__name__
==
"__main__"
:
args
=
PDConfig
(
yaml_file
=
"./transformer.yaml"
)
args
.
build
()
args
.
Print
()
check_gpu
(
args
.
use_cuda
)
check_version
()
do_train
(
args
)
dygraph/transformer/transformer.yaml
0 → 100644
浏览文件 @
e6436070
# used for continuous evaluation
enable_ce
:
False
# The frequency to save trained models when training.
save_step
:
10000
# The frequency to fetch and print output when training.
print_step
:
100
# path of the checkpoint, to resume the previous training
init_from_checkpoint
:
"
"
# path of the pretrain model, to better solve the current task
init_from_pretrain_model
:
"
"
# path of trained parameter, to make prediction
init_from_params
:
"
trained_params/step_100000/"
# the directory for saving model
save_model
:
"
trained_models"
# the directory for saving inference model.
inference_model_dir
:
"
infer_model"
# Set seed for CE or debug
random_seed
:
None
# The pattern to match training data files.
training_file
:
"
wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de"
# The pattern to match test data files.
predict_file
:
"
wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de"
# The file to output the translation results of predict_file to.
output_file
:
"
predict.txt"
# The path of vocabulary file of source language.
src_vocab_fpath
:
"
wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The path of vocabulary file of target language.
trg_vocab_fpath
:
"
wmt16_ende_data_bpe/vocab_all.bpe.32000"
# The <bos>, <eos> and <unk> tokens in the dictionary.
special_token
:
[
"
<s>"
,
"
<e>"
,
"
<unk>"
]
# whether to use cuda
use_cuda
:
True
# args for reader, see reader.py for details
token_delimiter
:
"
"
use_token_batch
:
True
pool_size
:
200000
sort_type
:
"
pool"
shuffle
:
True
shuffle_batch
:
True
batch_size
:
4096
# Hyparams for training:
# the number of epoches for training
epoch
:
30
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate
:
2.0
beta1
:
0.9
beta2
:
0.997
eps
:
1e-9
# the parameters for learning rate scheduling.
warmup_steps
:
8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps
:
0.1
# Hyparams for generation:
# the parameters for beam search.
beam_size
:
5
max_out_len
:
256
# the number of decoded sentences to output.
n_best
:
1
# Hyparams for model:
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size
:
10000
# size of target word dictionay
trg_vocab_size
:
10000
# index for <bos> token
bos_idx
:
0
# index for <eos> token
eos_idx
:
1
# index for <unk> token
unk_idx
:
2
# max length of sequences deciding the size of position encoding table.
max_length
:
256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model
:
512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid
:
2048
# the dimension that keys are projected to for dot-product attention.
d_key
:
64
# the dimension that values are projected to for dot-product attention.
d_value
:
64
# number of head used in multi-head attention.
n_head
:
8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer
:
6
# dropout rates of different modules.
prepostprocess_dropout
:
0.1
attention_dropout
:
0.1
relu_dropout
:
0.1
# to process before each sub-layer
preprocess_cmd
:
"
n"
# layer normalization
# to process after each sub-layer
postprocess_cmd
:
"
da"
# dropout + residual connection
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing
:
True
dygraph/transformer/utils/__init__.py
0 → 100644
浏览文件 @
e6436070
dygraph/transformer/utils/check.py
0 → 100644
浏览文件 @
e6436070
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
sys
import
paddle.fluid
as
fluid
import
logging
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'check_gpu'
,
'check_version'
]
def
check_gpu
(
use_gpu
):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
err
=
"Config use_gpu cannot be set as true while you are "
\
"using paddlepaddle cpu version !
\n
Please try:
\n
"
\
"
\t
1. Install paddlepaddle-gpu to run model on GPU
\n
"
\
"
\t
2. Set use_gpu as false in config file to run "
\
"model on CPU"
try
:
if
use_gpu
and
not
fluid
.
is_compiled_with_cuda
():
logger
.
error
(
err
)
sys
.
exit
(
1
)
except
Exception
as
e
:
pass
def
check_version
():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err
=
"PaddlePaddle version 1.6 or higher is required, "
\
"or a suitable develop version is satisfied as well.
\n
"
\
"Please make sure the version is good with your code."
\
try
:
fluid
.
require_version
(
'1.6.0'
)
except
Exception
as
e
:
logger
.
error
(
err
)
sys
.
exit
(
1
)
dygraph/transformer/utils/configure.py
0 → 100644
浏览文件 @
e6436070
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
import
argparse
import
json
import
yaml
import
six
import
logging
logging_only_message
=
"%(message)s"
logging_details
=
"%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s"
class
JsonConfig
(
object
):
"""
A high-level api for handling json configure file.
"""
def
__init__
(
self
,
config_path
):
self
.
_config_dict
=
self
.
_parse
(
config_path
)
def
_parse
(
self
,
config_path
):
try
:
with
open
(
config_path
)
as
json_file
:
config_dict
=
json
.
load
(
json_file
)
except
:
raise
IOError
(
"Error in parsing bert model config file '%s'"
%
config_path
)
else
:
return
config_dict
def
__getitem__
(
self
,
key
):
return
self
.
_config_dict
[
key
]
def
print_config
(
self
):
for
arg
,
value
in
sorted
(
six
.
iteritems
(
self
.
_config_dict
)):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
class
ArgumentGroup
(
object
):
def
__init__
(
self
,
parser
,
title
,
des
):
self
.
_group
=
parser
.
add_argument_group
(
title
=
title
,
description
=
des
)
def
add_arg
(
self
,
name
,
type
,
default
,
help
,
**
kwargs
):
type
=
str2bool
if
type
==
bool
else
type
self
.
_group
.
add_argument
(
"--"
+
name
,
default
=
default
,
type
=
type
,
help
=
help
+
' Default: %(default)s.'
,
**
kwargs
)
class
ArgConfig
(
object
):
"""
A high-level api for handling argument configs.
"""
def
__init__
(
self
):
parser
=
argparse
.
ArgumentParser
()
train_g
=
ArgumentGroup
(
parser
,
"training"
,
"training options."
)
train_g
.
add_arg
(
"epoch"
,
int
,
3
,
"Number of epoches for fine-tuning."
)
train_g
.
add_arg
(
"learning_rate"
,
float
,
5e-5
,
"Learning rate used to train with warmup."
)
train_g
.
add_arg
(
"lr_scheduler"
,
str
,
"linear_warmup_decay"
,
"scheduler of learning rate."
,
choices
=
[
'linear_warmup_decay'
,
'noam_decay'
])
train_g
.
add_arg
(
"weight_decay"
,
float
,
0.01
,
"Weight decay rate for L2 regularizer."
)
train_g
.
add_arg
(
"warmup_proportion"
,
float
,
0.1
,
"Proportion of training steps to perform linear learning rate warmup for."
)
train_g
.
add_arg
(
"save_steps"
,
int
,
1000
,
"The steps interval to save checkpoints."
)
train_g
.
add_arg
(
"use_fp16"
,
bool
,
False
,
"Whether to use fp16 mixed precision training."
)
train_g
.
add_arg
(
"loss_scaling"
,
float
,
1.0
,
"Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled."
)
train_g
.
add_arg
(
"pred_dir"
,
str
,
None
,
"Path to save the prediction results"
)
log_g
=
ArgumentGroup
(
parser
,
"logging"
,
"logging related."
)
log_g
.
add_arg
(
"skip_steps"
,
int
,
10
,
"The steps interval to print loss."
)
log_g
.
add_arg
(
"verbose"
,
bool
,
False
,
"Whether to output verbose log."
)
run_type_g
=
ArgumentGroup
(
parser
,
"run_type"
,
"running type options."
)
run_type_g
.
add_arg
(
"use_cuda"
,
bool
,
True
,
"If set, use GPU for training."
)
run_type_g
.
add_arg
(
"use_fast_executor"
,
bool
,
False
,
"If set, use fast parallel executor (in experiment)."
)
run_type_g
.
add_arg
(
"num_iteration_per_drop_scope"
,
int
,
1
,
"Ihe iteration intervals to clean up temporary variables."
)
run_type_g
.
add_arg
(
"do_train"
,
bool
,
True
,
"Whether to perform training."
)
run_type_g
.
add_arg
(
"do_predict"
,
bool
,
True
,
"Whether to perform prediction."
)
custom_g
=
ArgumentGroup
(
parser
,
"customize"
,
"customized options."
)
self
.
custom_g
=
custom_g
self
.
parser
=
parser
def
add_arg
(
self
,
name
,
dtype
,
default
,
descrip
):
self
.
custom_g
.
add_arg
(
name
,
dtype
,
default
,
descrip
)
def
build_conf
(
self
):
return
self
.
parser
.
parse_args
()
def
str2bool
(
v
):
# because argparse does not support to parse "true, False" as python
# boolean directly
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
def
print_arguments
(
args
,
log
=
None
):
if
not
log
:
print
(
'----------- Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
six
.
iteritems
(
vars
(
args
))):
print
(
'%s: %s'
%
(
arg
,
value
))
print
(
'------------------------------------------------'
)
else
:
log
.
info
(
'----------- Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
six
.
iteritems
(
vars
(
args
))):
log
.
info
(
'%s: %s'
%
(
arg
,
value
))
log
.
info
(
'------------------------------------------------'
)
class
PDConfig
(
object
):
"""
A high-level API for managing configuration files in PaddlePaddle.
Can jointly work with command-line-arugment, json files and yaml files.
"""
def
__init__
(
self
,
json_file
=
""
,
yaml_file
=
""
,
fuse_args
=
True
):
"""
Init funciton for PDConfig.
json_file: the path to the json configure file.
yaml_file: the path to the yaml configure file.
fuse_args: if fuse the json/yaml configs with argparse.
"""
assert
isinstance
(
json_file
,
str
)
assert
isinstance
(
yaml_file
,
str
)
if
json_file
!=
""
and
yaml_file
!=
""
:
raise
Warning
(
"json_file and yaml_file can not co-exist for now. please only use one configure file type."
)
return
self
.
args
=
None
self
.
arg_config
=
{}
self
.
json_config
=
{}
self
.
yaml_config
=
{}
parser
=
argparse
.
ArgumentParser
()
self
.
default_g
=
ArgumentGroup
(
parser
,
"default"
,
"default options."
)
self
.
yaml_g
=
ArgumentGroup
(
parser
,
"yaml"
,
"options from yaml."
)
self
.
json_g
=
ArgumentGroup
(
parser
,
"json"
,
"options from json."
)
self
.
com_g
=
ArgumentGroup
(
parser
,
"custom"
,
"customized options."
)
self
.
default_g
.
add_arg
(
"do_train"
,
bool
,
False
,
"Whether to perform training."
)
self
.
default_g
.
add_arg
(
"do_predict"
,
bool
,
False
,
"Whether to perform predicting."
)
self
.
default_g
.
add_arg
(
"do_eval"
,
bool
,
False
,
"Whether to perform evaluating."
)
self
.
default_g
.
add_arg
(
"do_save_inference_model"
,
bool
,
False
,
"Whether to perform model saving for inference."
)
# NOTE: args for profiler
self
.
default_g
.
add_arg
(
"is_profiler"
,
int
,
0
,
"the switch of profiler tools. (used for benchmark)"
)
self
.
default_g
.
add_arg
(
"profiler_path"
,
str
,
'./'
,
"the profiler output file path. (used for benchmark)"
)
self
.
default_g
.
add_arg
(
"max_iter"
,
int
,
0
,
"the max train batch num.(used for benchmark)"
)
self
.
parser
=
parser
if
json_file
!=
""
:
self
.
load_json
(
json_file
,
fuse_args
=
fuse_args
)
if
yaml_file
:
self
.
load_yaml
(
yaml_file
,
fuse_args
=
fuse_args
)
def
load_json
(
self
,
file_path
,
fuse_args
=
True
):
if
not
os
.
path
.
exists
(
file_path
):
raise
Warning
(
"the json file %s does not exist."
%
file_path
)
return
with
open
(
file_path
,
"r"
)
as
fin
:
self
.
json_config
=
json
.
loads
(
fin
.
read
())
fin
.
close
()
if
fuse_args
:
for
name
in
self
.
json_config
:
if
isinstance
(
self
.
json_config
[
name
],
list
):
self
.
json_g
.
add_arg
(
name
,
type
(
self
.
json_config
[
name
][
0
]),
self
.
json_config
[
name
],
"This is from %s"
%
file_path
,
nargs
=
len
(
self
.
json_config
[
name
]))
continue
if
not
isinstance
(
self
.
json_config
[
name
],
int
)
\
and
not
isinstance
(
self
.
json_config
[
name
],
float
)
\
and
not
isinstance
(
self
.
json_config
[
name
],
str
)
\
and
not
isinstance
(
self
.
json_config
[
name
],
bool
):
continue
self
.
json_g
.
add_arg
(
name
,
type
(
self
.
json_config
[
name
]),
self
.
json_config
[
name
],
"This is from %s"
%
file_path
)
def
load_yaml
(
self
,
file_path
,
fuse_args
=
True
):
if
not
os
.
path
.
exists
(
file_path
):
raise
Warning
(
"the yaml file %s does not exist."
%
file_path
)
return
with
open
(
file_path
,
"r"
)
as
fin
:
self
.
yaml_config
=
yaml
.
load
(
fin
,
Loader
=
yaml
.
SafeLoader
)
fin
.
close
()
if
fuse_args
:
for
name
in
self
.
yaml_config
:
if
isinstance
(
self
.
yaml_config
[
name
],
list
):
self
.
yaml_g
.
add_arg
(
name
,
type
(
self
.
yaml_config
[
name
][
0
]),
self
.
yaml_config
[
name
],
"This is from %s"
%
file_path
,
nargs
=
len
(
self
.
yaml_config
[
name
]))
continue
if
not
isinstance
(
self
.
yaml_config
[
name
],
int
)
\
and
not
isinstance
(
self
.
yaml_config
[
name
],
float
)
\
and
not
isinstance
(
self
.
yaml_config
[
name
],
str
)
\
and
not
isinstance
(
self
.
yaml_config
[
name
],
bool
):
continue
self
.
yaml_g
.
add_arg
(
name
,
type
(
self
.
yaml_config
[
name
]),
self
.
yaml_config
[
name
],
"This is from %s"
%
file_path
)
def
build
(
self
):
self
.
args
=
self
.
parser
.
parse_args
()
self
.
arg_config
=
vars
(
self
.
args
)
def
__add__
(
self
,
new_arg
):
assert
isinstance
(
new_arg
,
list
)
or
isinstance
(
new_arg
,
tuple
)
assert
len
(
new_arg
)
>=
3
assert
self
.
args
is
None
name
=
new_arg
[
0
]
dtype
=
new_arg
[
1
]
dvalue
=
new_arg
[
2
]
desc
=
new_arg
[
3
]
if
len
(
new_arg
)
==
4
else
"Description is not provided."
self
.
com_g
.
add_arg
(
name
,
dtype
,
dvalue
,
desc
)
return
self
def
__getattr__
(
self
,
name
):
if
name
in
self
.
arg_config
:
return
self
.
arg_config
[
name
]
if
name
in
self
.
json_config
:
return
self
.
json_config
[
name
]
if
name
in
self
.
yaml_config
:
return
self
.
yaml_config
[
name
]
raise
Warning
(
"The argument %s is not defined."
%
name
)
def
Print
(
self
):
print
(
"-"
*
70
)
for
name
in
self
.
arg_config
:
print
(
"%s:
\t\t\t\t
%s"
%
(
str
(
name
),
str
(
self
.
arg_config
[
name
])))
for
name
in
self
.
json_config
:
if
name
not
in
self
.
arg_config
:
print
(
"%s:
\t\t\t\t
%s"
%
(
str
(
name
),
str
(
self
.
json_config
[
name
])))
for
name
in
self
.
yaml_config
:
if
name
not
in
self
.
arg_config
:
print
(
"%s:
\t\t\t\t
%s"
%
(
str
(
name
),
str
(
self
.
yaml_config
[
name
])))
print
(
"-"
*
70
)
if
__name__
==
"__main__"
:
"""
pd_config = PDConfig(json_file = "./test/bert_config.json")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
pd_config = PDConfig(yaml_file = "./test/bert_config.yaml")
pd_config.build()
print(pd_config.do_train)
print(pd_config.hidden_size)
"""
pd_config
=
PDConfig
(
yaml_file
=
"./test/bert_config.yaml"
)
pd_config
+=
(
"my_age"
,
int
,
18
,
"I am forever 18."
)
pd_config
.
build
()
print
(
pd_config
.
do_train
)
print
(
pd_config
.
hidden_size
)
print
(
pd_config
.
my_age
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录