Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
models
提交
9fe93972
M
models
项目概览
PaddlePaddle
/
models
大约 2 年 前同步成功
通知
232
Star
6828
Fork
2962
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
602
列表
看板
标记
里程碑
合并请求
255
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
602
Issue
602
列表
看板
标记
里程碑
合并请求
255
合并请求
255
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9fe93972
编写于
7月 09, 2018
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change code arrangment
上级
336a73ba
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
218 addition
and
208 deletion
+218
-208
fluid/neural_machine_translation/transformer/train.py
fluid/neural_machine_translation/transformer/train.py
+218
-208
未找到文件。
fluid/neural_machine_translation/transformer/train.py
浏览文件 @
9fe93972
...
@@ -220,64 +220,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
...
@@ -220,64 +220,7 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
[
num_token
],
dtype
=
"float32"
)
[
num_token
],
dtype
=
"float32"
)
def
train
(
args
):
def
read_multiple
(
reader
,
count
,
clip_last
=
True
):
# priority: ENV > args > config
is_local
=
os
.
getenv
(
"PADDLE_IS_LOCAL"
,
"1"
)
if
is_local
==
'0'
:
args
.
local
=
False
print
args
if
args
.
device
==
'CPU'
:
TrainTaskConfig
.
use_gpu
=
False
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
if
training_role
==
"PSERVER"
or
(
not
TrainTaskConfig
.
use_gpu
):
place
=
fluid
.
CPUPlace
()
dev_count
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
else
:
place
=
fluid
.
CUDAPlace
(
0
)
dev_count
=
fluid
.
core
.
get_cuda_device_count
()
exe
=
fluid
.
Executor
(
place
)
sum_cost
,
avg_cost
,
predict
,
token_num
=
transformer
(
ModelHyperParams
.
src_vocab_size
,
ModelHyperParams
.
trg_vocab_size
,
ModelHyperParams
.
max_length
+
1
,
ModelHyperParams
.
n_layer
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_key
,
ModelHyperParams
.
d_value
,
ModelHyperParams
.
d_model
,
ModelHyperParams
.
d_inner_hid
,
ModelHyperParams
.
dropout
,
ModelHyperParams
.
weight_sharing
,
TrainTaskConfig
.
label_smooth_eps
)
if
args
.
local
:
lr_scheduler
=
LearningRateScheduler
(
ModelHyperParams
.
d_model
,
TrainTaskConfig
.
warmup_steps
,
TrainTaskConfig
.
learning_rate
)
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
lr_scheduler
.
learning_rate
,
beta1
=
TrainTaskConfig
.
beta1
,
beta2
=
TrainTaskConfig
.
beta2
,
epsilon
=
TrainTaskConfig
.
eps
)
optimizer
.
minimize
(
sum_cost
)
elif
args
.
sync
==
False
:
optimizer
=
fluid
.
optimizer
.
SGD
(
0.003
)
optimizer
.
minimize
(
sum_cost
)
else
:
lr_decay
=
fluid
.
layers
\
.
learning_rate_scheduler
\
.
noam_decay
(
ModelHyperParams
.
d_model
,
TrainTaskConfig
.
warmup_steps
)
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
lr_decay
,
beta1
=
TrainTaskConfig
.
beta1
,
beta2
=
TrainTaskConfig
.
beta2
,
epsilon
=
TrainTaskConfig
.
eps
)
optimizer
.
minimize
(
sum_cost
)
def
train_loop
(
exe
,
train_progm
):
def
read_multiple
(
reader
,
count
=
dev_count
if
args
.
use_token_batch
else
1
,
clip_last
=
True
):
"""
"""
Stack data from reader for multi-devices.
Stack data from reader for multi-devices.
"""
"""
...
@@ -298,14 +241,14 @@ def train(args):
...
@@ -298,14 +241,14 @@ def train(args):
if
len
(
data
)
>
count
:
if
len
(
data
)
>
count
:
inst_num_per_part
=
len
(
data
)
//
count
inst_num_per_part
=
len
(
data
)
//
count
yield
[
yield
[
data
[
inst_num_per_part
*
i
:
inst_num_per_part
*
(
i
+
data
[
inst_num_per_part
*
i
:
inst_num_per_part
*
(
i
+
1
)]
1
)]
for
i
in
range
(
count
)
for
i
in
range
(
count
)
]
]
return
__impl__
return
__impl__
def
split_data
(
data
,
num_part
=
dev_count
):
def
split_data
(
data
,
num_part
):
"""
"""
Split data for each device.
Split data for each device.
"""
"""
...
@@ -318,47 +261,9 @@ def train(args):
...
@@ -318,47 +261,9 @@ def train(args):
for
i
in
range
(
num_part
)
for
i
in
range
(
num_part
)
]
]
# Initialize the parameters.
if
TrainTaskConfig
.
ckpt_path
:
fluid
.
io
.
load_persistables
(
exe
,
TrainTaskConfig
.
ckpt_path
)
#lr_scheduler.current_steps = TrainTaskConfig.start_step
else
:
print
"init fluid.framework.default_startup_program"
exe
.
run
(
fluid
.
framework
.
default_startup_program
())
train_data
=
reader
.
DataReader
(
def
test_context
(
train_progm
,
avg_cost
,
train_exe
,
dev_count
,
data_input_names
,
src_vocab_fpath
=
args
.
src_vocab_fpath
,
util_input_names
,
sum_cost
,
token_num
):
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
fpattern
=
args
.
train_file_pattern
,
use_token_batch
=
args
.
use_token_batch
,
batch_size
=
args
.
batch_size
*
(
1
if
args
.
use_token_batch
else
dev_count
),
pool_size
=
args
.
pool_size
,
sort_type
=
args
.
sort_type
,
shuffle
=
args
.
shuffle
,
shuffle_batch
=
args
.
shuffle_batch
,
start_mark
=
args
.
special_token
[
0
],
end_mark
=
args
.
special_token
[
1
],
unk_mark
=
args
.
special_token
[
2
],
# count start and end tokens out
max_length
=
ModelHyperParams
.
max_length
-
2
,
clip_last_batch
=
False
)
train_data
=
read_multiple
(
reader
=
train_data
.
batch_generator
,
count
=
dev_count
if
args
.
use_token_batch
else
1
)
build_strategy
=
fluid
.
BuildStrategy
()
# Since the token number differs among devices, customize gradient scale to
# use token average cost among multi-devices. and the gradient scale is
# `1 / token_number` for average cost.
build_strategy
.
gradient_scale_strategy
=
fluid
.
BuildStrategy
.
GradientScaleStrategy
.
Customized
train_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
TrainTaskConfig
.
use_gpu
,
loss_name
=
sum_cost
.
name
,
main_program
=
train_progm
,
build_strategy
=
build_strategy
)
def
test_context
():
# Context to do validation.
# Context to do validation.
test_program
=
train_progm
.
clone
()
test_program
=
train_progm
.
clone
()
with
fluid
.
program_guard
(
test_program
):
with
fluid
.
program_guard
(
test_program
):
...
@@ -369,8 +274,7 @@ def train(args):
...
@@ -369,8 +274,7 @@ def train(args):
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
fpattern
=
args
.
val_file_pattern
,
fpattern
=
args
.
val_file_pattern
,
use_token_batch
=
args
.
use_token_batch
,
use_token_batch
=
args
.
use_token_batch
,
batch_size
=
args
.
batch_size
*
batch_size
=
args
.
batch_size
*
(
1
if
args
.
use_token_batch
else
dev_count
),
(
1
if
args
.
use_token_batch
else
dev_count
),
pool_size
=
args
.
pool_size
,
pool_size
=
args
.
pool_size
,
sort_type
=
args
.
sort_type
,
sort_type
=
args
.
sort_type
,
start_mark
=
args
.
special_token
[
0
],
start_mark
=
args
.
special_token
[
0
],
...
@@ -390,22 +294,24 @@ def train(args):
...
@@ -390,22 +294,24 @@ def train(args):
def
test
(
exe
=
test_exe
):
def
test
(
exe
=
test_exe
):
test_total_cost
=
0
test_total_cost
=
0
test_total_token
=
0
test_total_token
=
0
test_data
=
read_multiple
(
reader
=
val_data
.
batch_generator
)
test_data
=
read_multiple
(
reader
=
val_data
.
batch_generator
,
count
=
dev_count
if
args
.
use_token_batch
else
1
)
for
batch_id
,
data
in
enumerate
(
test_data
()):
for
batch_id
,
data
in
enumerate
(
test_data
()):
feed_list
=
[]
feed_list
=
[]
for
place_id
,
data_buffer
in
enumerate
(
split_data
(
data
)):
for
place_id
,
data_buffer
in
enumerate
(
split_data
(
data
,
num_part
=
dev_count
)):
data_input_dict
,
util_input_dict
,
_
=
prepare_batch_input
(
data_input_dict
,
util_input_dict
,
_
=
prepare_batch_input
(
data_buffer
,
data_input_names
,
util_input_names
,
data_buffer
,
data_input_names
,
util_input_names
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_model
)
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_model
)
feed_list
.
append
(
feed_list
.
append
(
dict
(
data_input_dict
.
items
()
+
dict
(
data_input_dict
.
items
()
+
util_input_dict
.
items
()))
util_input_dict
.
items
()))
outs
=
exe
.
run
(
feed
=
feed_list
,
outs
=
exe
.
run
(
feed
=
feed_list
,
fetch_list
=
[
sum_cost
.
name
,
token_num
.
name
])
fetch_list
=
[
sum_cost
.
name
,
token_num
.
name
])
sum_cost_val
,
token_num_val
=
np
.
array
(
outs
[
0
]),
np
.
array
(
sum_cost_val
,
token_num_val
=
np
.
array
(
outs
[
0
]),
np
.
array
(
outs
[
1
])
outs
[
1
])
test_total_cost
+=
sum_cost_val
.
sum
()
test_total_cost
+=
sum_cost_val
.
sum
()
test_total_token
+=
token_num_val
.
sum
()
test_total_token
+=
token_num_val
.
sum
()
test_avg_cost
=
test_total_cost
/
test_total_token
test_avg_cost
=
test_total_cost
/
test_total_token
...
@@ -414,26 +320,73 @@ def train(args):
...
@@ -414,26 +320,73 @@ def train(args):
return
test
return
test
if
args
.
val_file_pattern
is
not
None
:
test
=
test_context
()
def
train_loop
(
exe
,
train_progm
,
dev_count
,
sum_cost
,
avg_cost
,
lr_scheduler
,
token_num
,
predict
):
# Initialize the parameters.
if
TrainTaskConfig
.
ckpt_path
:
fluid
.
io
.
load_persistables
(
exe
,
TrainTaskConfig
.
ckpt_path
)
lr_scheduler
.
current_steps
=
TrainTaskConfig
.
start_step
else
:
print
"init fluid.framework.default_startup_program"
exe
.
run
(
fluid
.
framework
.
default_startup_program
())
train_data
=
reader
.
DataReader
(
src_vocab_fpath
=
args
.
src_vocab_fpath
,
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
fpattern
=
args
.
train_file_pattern
,
use_token_batch
=
args
.
use_token_batch
,
batch_size
=
args
.
batch_size
*
(
1
if
args
.
use_token_batch
else
dev_count
),
pool_size
=
args
.
pool_size
,
sort_type
=
args
.
sort_type
,
shuffle
=
args
.
shuffle
,
shuffle_batch
=
args
.
shuffle_batch
,
start_mark
=
args
.
special_token
[
0
],
end_mark
=
args
.
special_token
[
1
],
unk_mark
=
args
.
special_token
[
2
],
# count start and end tokens out
max_length
=
ModelHyperParams
.
max_length
-
2
,
clip_last_batch
=
False
)
train_data
=
read_multiple
(
reader
=
train_data
.
batch_generator
,
count
=
dev_count
if
args
.
use_token_batch
else
1
)
build_strategy
=
fluid
.
BuildStrategy
()
# Since the token number differs among devices, customize gradient scale to
# use token average cost among multi-devices. and the gradient scale is
# `1 / token_number` for average cost.
build_strategy
.
gradient_scale_strategy
=
fluid
.
BuildStrategy
.
GradientScaleStrategy
.
Customized
train_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
TrainTaskConfig
.
use_gpu
,
loss_name
=
sum_cost
.
name
,
main_program
=
train_progm
,
build_strategy
=
build_strategy
)
data_input_names
=
encoder_data_input_fields
+
decoder_data_input_fields
[:
data_input_names
=
encoder_data_input_fields
+
decoder_data_input_fields
[:
-
1
]
+
label_data_input_fields
-
1
]
+
label_data_input_fields
util_input_names
=
encoder_util_input_fields
+
decoder_util_input_fields
util_input_names
=
encoder_util_input_fields
+
decoder_util_input_fields
if
args
.
val_file_pattern
is
not
None
:
test
=
test_context
(
train_progm
,
avg_cost
,
train_exe
,
dev_count
,
data_input_names
,
util_input_names
,
sum_cost
,
token_num
)
init
=
False
init
=
False
for
pass_id
in
xrange
(
TrainTaskConfig
.
pass_num
):
for
pass_id
in
xrange
(
TrainTaskConfig
.
pass_num
):
pass_start_time
=
time
.
time
()
pass_start_time
=
time
.
time
()
for
batch_id
,
data
in
enumerate
(
train_data
()):
for
batch_id
,
data
in
enumerate
(
train_data
()):
feed_list
=
[]
feed_list
=
[]
total_num_token
=
0
total_num_token
=
0
for
place_id
,
data_buffer
in
enumerate
(
split_data
(
data
)):
for
place_id
,
data_buffer
in
enumerate
(
split_data
(
data
,
num_part
=
dev_count
)):
data_input_dict
,
util_input_dict
,
num_token
=
prepare_batch_input
(
data_input_dict
,
util_input_dict
,
num_token
=
prepare_batch_input
(
data_buffer
,
data_input_names
,
util_input_names
,
data_buffer
,
data_input_names
,
util_input_names
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
eos_idx
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_model
)
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_model
)
total_num_token
+=
num_token
total_num_token
+=
num_token
feed_kv_pairs
=
data
_input_dict
.
items
(
feed_kv_pairs
=
data_input_dict
.
items
()
+
util
_input_dict
.
items
(
)
+
util_input_dict
.
items
(
)
)
if
args
.
local
:
if
args
.
local
:
lr_rate
=
lr_scheduler
.
update_learning_rate
()
lr_rate
=
lr_scheduler
.
update_learning_rate
()
feed_kv_pairs
+=
{
feed_kv_pairs
+=
{
...
@@ -449,23 +402,21 @@ def train(args):
...
@@ -449,23 +402,21 @@ def train(args):
feed_list
[
place_id
][
pos_enc_param_name
]
=
pos_enc
feed_list
[
place_id
][
pos_enc_param_name
]
=
pos_enc
for
feed_dict
in
feed_list
:
for
feed_dict
in
feed_list
:
feed_dict
[
sum_cost
.
name
+
"@GRAD"
]
=
1.
/
total_num_token
feed_dict
[
sum_cost
.
name
+
"@GRAD"
]
=
1.
/
total_num_token
outs
=
train_exe
.
run
(
outs
=
train_exe
.
run
(
fetch_list
=
[
sum_cost
.
name
,
token_num
.
name
],
fetch_list
=
[
sum_cost
.
name
,
token_num
.
name
],
feed
=
feed_list
)
feed
=
feed_list
)
train_exe
.
bcast_params
()
train_exe
.
bcast_params
()
sum_cost_val
,
token_num_val
=
np
.
array
(
outs
[
0
]),
np
.
array
(
outs
[
sum_cost_val
,
token_num_val
=
np
.
array
(
outs
[
0
]),
np
.
array
(
outs
[
1
])
1
])
total_sum_cost
=
sum_cost_val
.
sum
(
total_sum_cost
=
sum_cost_val
.
sum
(
)
# sum the cost from multi-devices
)
# sum the cost from multi-devices
total_token_num
=
token_num_val
.
sum
()
total_token_num
=
token_num_val
.
sum
()
total_avg_cost
=
total_sum_cost
/
total_token_num
total_avg_cost
=
total_sum_cost
/
total_token_num
print
(
print
(
"epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f"
%
"epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f"
(
pass_id
,
batch_id
,
total_sum_cost
,
total_avg_cost
,
%
(
pass_id
,
batch_id
,
total_sum_cost
,
total_avg_cost
,
np
.
exp
([
min
(
total_avg_cost
,
100
)])))
np
.
exp
([
min
(
total_avg_cost
,
100
)])))
init
=
True
init
=
True
# Validate and save the model for inference.
# Validate and save the model for inference.
print
(
"epoch: %d, "
%
pass_id
+
(
print
(
"epoch: %d, "
%
pass_id
+
"val avg loss: %f, val ppl: %f, "
%
test
()
(
"val avg loss: %f, val ppl: %f, "
%
test
()
if
args
.
val_file_pattern
is
not
None
else
""
)
+
"consumed %fs"
%
if
args
.
val_file_pattern
is
not
None
else
""
)
+
"consumed %fs"
%
(
time
.
time
()
-
pass_start_time
))
(
time
.
time
()
-
pass_start_time
))
fluid
.
io
.
save_persistables
(
fluid
.
io
.
save_persistables
(
...
@@ -477,9 +428,67 @@ def train(args):
...
@@ -477,9 +428,67 @@ def train(args):
"pass_"
+
str
(
pass_id
)
+
".infer.model"
),
"pass_"
+
str
(
pass_id
)
+
".infer.model"
),
data_input_names
[:
-
2
]
+
util_input_names
,
[
predict
],
exe
)
data_input_names
[:
-
2
]
+
util_input_names
,
[
predict
],
exe
)
def
train
(
args
):
# priority: ENV > args > config
is_local
=
os
.
getenv
(
"PADDLE_IS_LOCAL"
,
"1"
)
if
is_local
==
'0'
:
args
.
local
=
False
print
args
if
args
.
device
==
'CPU'
:
TrainTaskConfig
.
use_gpu
=
False
training_role
=
os
.
getenv
(
"TRAINING_ROLE"
,
"TRAINER"
)
if
training_role
==
"PSERVER"
or
(
not
TrainTaskConfig
.
use_gpu
):
place
=
fluid
.
CPUPlace
()
dev_count
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
else
:
place
=
fluid
.
CUDAPlace
(
0
)
dev_count
=
fluid
.
core
.
get_cuda_device_count
()
exe
=
fluid
.
Executor
(
place
)
sum_cost
,
avg_cost
,
predict
,
token_num
=
transformer
(
ModelHyperParams
.
src_vocab_size
,
ModelHyperParams
.
trg_vocab_size
,
ModelHyperParams
.
max_length
+
1
,
ModelHyperParams
.
n_layer
,
ModelHyperParams
.
n_head
,
ModelHyperParams
.
d_key
,
ModelHyperParams
.
d_value
,
ModelHyperParams
.
d_model
,
ModelHyperParams
.
d_inner_hid
,
ModelHyperParams
.
dropout
,
ModelHyperParams
.
weight_sharing
,
TrainTaskConfig
.
label_smooth_eps
)
lr_scheduler
=
LearningRateScheduler
(
ModelHyperParams
.
d_model
,
TrainTaskConfig
.
warmup_steps
,
TrainTaskConfig
.
learning_rate
)
if
args
.
local
:
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
lr_scheduler
.
learning_rate
,
beta1
=
TrainTaskConfig
.
beta1
,
beta2
=
TrainTaskConfig
.
beta2
,
epsilon
=
TrainTaskConfig
.
eps
)
optimizer
.
minimize
(
sum_cost
)
elif
args
.
sync
==
False
:
optimizer
=
fluid
.
optimizer
.
SGD
(
0.003
)
optimizer
.
minimize
(
sum_cost
)
else
:
lr_decay
=
fluid
.
layers
\
.
learning_rate_scheduler
\
.
noam_decay
(
ModelHyperParams
.
d_model
,
TrainTaskConfig
.
warmup_steps
)
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
lr_decay
,
beta1
=
TrainTaskConfig
.
beta1
,
beta2
=
TrainTaskConfig
.
beta2
,
epsilon
=
TrainTaskConfig
.
eps
)
optimizer
.
minimize
(
sum_cost
)
if
args
.
local
:
if
args
.
local
:
print
(
"local start_up:"
)
print
(
"local start_up:"
)
train_loop
(
exe
,
fluid
.
default_main_program
())
train_loop
(
exe
,
fluid
.
default_main_program
(),
dev_count
,
sum_cost
,
avg_cost
,
lr_scheduler
,
token_num
,
predict
)
else
:
else
:
port
=
os
.
getenv
(
"PADDLE_PORT"
,
"6174"
)
port
=
os
.
getenv
(
"PADDLE_PORT"
,
"6174"
)
pserver_ips
=
os
.
getenv
(
"PADDLE_PSERVERS"
)
# ip,ip...
pserver_ips
=
os
.
getenv
(
"PADDLE_PSERVERS"
)
# ip,ip...
...
@@ -515,7 +524,8 @@ def train(args):
...
@@ -515,7 +524,8 @@ def train(args):
trainer_prog
=
t
.
get_trainer_program
()
trainer_prog
=
t
.
get_trainer_program
()
with
open
(
'trainer_prog.desc'
,
'w'
)
as
f
:
with
open
(
'trainer_prog.desc'
,
'w'
)
as
f
:
f
.
write
(
str
(
trainer_prog
))
f
.
write
(
str
(
trainer_prog
))
train_loop
(
exe
,
trainer_prog
)
train_loop
(
exe
,
trainer_prog
,
dev_count
,
sum_cost
,
avg_cost
,
lr_scheduler
,
token_num
,
predict
)
else
:
else
:
print
(
"environment var TRAINER_ROLE should be TRAINER os PSERVER"
)
print
(
"environment var TRAINER_ROLE should be TRAINER os PSERVER"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录