PaddlePaddle / hapi

Commit 0e47f4c4
Authored Mar 26, 2020 by guosheng

Merge branch 'master' of https://github.com/PaddlePaddle/hapi into hapi-transformer

Parents: 68dfe864, 6d9e77b9

Showing 7 changed files with 304 additions and 673 deletions (+304 -673)
callbacks.py            +8    -8
distributed.py          +27   -96
mnist.py                +12   -3
model.py                +119  -102
tests/test_model.py     +4    -3
transformer/reader.py   +89   -418
transformer/train.py    +45   -43
callbacks.py

@@ -16,7 +16,7 @@ import six
 import copy

 from progressbar import ProgressBar
-from distributed import get_local_rank
+from paddle.fluid.dygraph.parallel import ParallelEnv


 def config_callbacks(callbacks=None,

@@ -195,7 +195,7 @@ class ProgBarLogger(Callback):
         self.steps = self.params['steps']
         self.epoch = epoch
         self.train_step = 0
-        if self.verbose and self.epochs and get_local_rank() == 0:
+        if self.verbose and self.epochs and ParallelEnv().local_rank == 0:
             print('Epoch %d/%d' % (epoch + 1, self.epochs))
         self.train_progbar = ProgressBar(num=self.steps, verbose=self.verbose)

@@ -213,8 +213,8 @@ class ProgBarLogger(Callback):
         logs = logs or {}
         self.train_step += 1
-        if self.train_step % self.log_freq == 0 and self.verbose and get_local_rank(
-        ) == 0:
+        if self.train_step % self.log_freq == 0 and self.verbose and ParallelEnv(
+        ).local_rank == 0:
             # if steps is not None, last step will update in on_epoch_end
             if self.steps and self.train_step < self.steps:
                 self._updates(logs, 'train')

@@ -223,7 +223,7 @@ class ProgBarLogger(Callback):
     def on_epoch_end(self, epoch, logs=None):
         logs = logs or {}
-        if self.verbose and get_local_rank() == 0:
+        if self.verbose and ParallelEnv().local_rank == 0:
             self._updates(logs, 'train')

     def on_eval_begin(self, logs=None):

@@ -233,7 +233,7 @@ class ProgBarLogger(Callback):
         self.evaled_samples = 0
         self.eval_progbar = ProgressBar(
             num=self.eval_steps, verbose=self.verbose)
-        if get_local_rank() == 0:
+        if ParallelEnv().local_rank == 0:
             print('Eval begin...')

     def on_eval_batch_end(self, step, logs=None):

@@ -244,7 +244,7 @@ class ProgBarLogger(Callback):
     def on_eval_end(self, logs=None):
         logs = logs or {}
-        if self.verbose and get_local_rank() == 0:
+        if self.verbose and ParallelEnv().local_rank == 0:
             self._updates(logs, 'eval')
             print('Eval samples: %d' % (self.evaled_samples))

@@ -258,7 +258,7 @@ class ModelCheckpoint(Callback):
         self.epoch = epoch

     def _is_save(self):
-        return self.model and self.save_dir and get_local_rank() == 0
+        return self.model and self.save_dir and ParallelEnv().local_rank == 0

     def on_epoch_end(self, epoch, logs=None):
         if self._is_save() and self.epoch % self.save_freq == 0:
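The change that repeats through this file is a rank-0 guard around logging and checkpointing: the hapi helper get_local_rank() is dropped in favour of querying the framework object directly. A minimal sketch of that pattern, outside the diff, assuming the paddle.fluid 1.7-era dygraph API (ParallelEnv is the real class used above; do_logging is a hypothetical placeholder, not from the repo):

    from paddle.fluid.dygraph.parallel import ParallelEnv

    def do_logging(msg):  # hypothetical stand-in for the ProgBarLogger calls above
        # Old code asked a hapi helper (distributed.get_local_rank() == 0);
        # after this commit the framework is queried directly, so the helper
        # and hapi's own Env wrapper can be deleted from distributed.py.
        if ParallelEnv().local_rank == 0:
            print(msg)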
distributed.py

@@ -13,30 +13,21 @@
 # limitations under the License.

 import os
 import sys
 import six
 import time
 import math
 import socket
-from contextlib import closing
-from six import string_types
+import contextlib
 import numpy as np
-from collections import OrderedDict

-from paddle import fluid
-import paddle.fluid.unique_name as nameGen
-from paddle.fluid import core
-from paddle.fluid import framework
+from paddle import fluid
 from paddle.fluid.layers import collective
-from paddle.fluid.dygraph import to_variable, no_grad, layers
-from paddle.fluid.framework import Variable
-from paddle.fluid.executor import global_scope
-from paddle.fluid.dygraph.parallel import Env, DataParallel, ParallelStrategy
-from paddle.fluid.layers.collective import _c_allreduce, _c_allgather, _c_broadcast, \
-    _c_sync_comm_stream, _c_sync_calc_stream
-from paddle.fluid.io import BatchSampler, DataLoader
+from paddle.fluid.dygraph.parallel import ParallelEnv, ParallelStrategy
+from paddle.fluid.io import BatchSampler

-__parallel_context_init = False
+_parallel_context_initialized = False


 class DistributedBatchSampler(BatchSampler):
     """Sampler that restricts data loading to a subset of the dataset.

@@ -71,9 +62,10 @@ class DistributedBatchSampler(BatchSampler):
         self.shuffle = shuffle
         assert isinstance(drop_last, bool), \
             "drop_last should be a boolean number"
         self.drop_last = drop_last
-        self.nranks = get_nranks()
-        self.local_rank = get_local_rank()
+        self.nranks = ParallelEnv().nranks
+        self.local_rank = ParallelEnv().local_rank
         self.epoch = 0
         self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
         self.total_size = self.num_samples * self.nranks

@@ -106,27 +98,22 @@ class DistributedBatchSampler(BatchSampler):
         num_samples += int(not self.drop_last) * (self.batch_size - 1)
         return num_samples // self.batch_size

     def set_epoch(self, epoch):
         self.epoch = epoch


-def _all_gather(x, nranks, ring_id=0, use_calc_stream=True):
-    return _c_allgather(x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream)
-
-
-def get_local_rank():
-    return Env().local_rank
-
-
-def get_nranks():
-    return Env().nranks
+def _all_gather(x, nranks, ring_id=0, use_calc_stream=True):
+    return collective._c_allgather(
+        x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream)


 def wait_server_ready(endpoints):
-    assert not isinstance(endpoints, string_types)
+    assert not isinstance(endpoints, six.string_types)
     while True:
         all_ok = True
         not_ready_endpoints = []
         for ep in endpoints:
             ip_port = ep.split(":")
-            with closing(
+            with contextlib.closing(
                     socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
                 sock.settimeout(2)

@@ -135,10 +122,6 @@ def wait_server_ready(endpoints):
                     all_ok = False
                     not_ready_endpoints.append(ep)
         if not all_ok:
-            sys.stderr.write("server not ready, wait 3 sec to retry...\n")
-            sys.stderr.write("not ready endpoints:" + str(
-                not_ready_endpoints) + "\n")
-            sys.stderr.flush()
             time.sleep(3)
         else:
             break

@@ -154,9 +137,9 @@ def init_communicator(program, rank, nranks, wait_port,
         wait_server_ready(other_endpoints)
     block = program.global_block()
     nccl_id_var = block.create_var(
-        name=nameGen.generate('nccl_id'),
+        name=fluid.unique_name.generate('nccl_id'),
         persistable=True,
-        type=core.VarDesc.VarType.RAW)
+        type=fluid.core.VarDesc.VarType.RAW)
     block.append_op(
         type='c_gen_nccl_id',

@@ -181,23 +164,23 @@ def init_communicator(program, rank, nranks, wait_port,

 def prepare_distributed_context(place=None):
     if place is None:
-        place = fluid.CUDAPlace(Env().dev_id) if Env().nranks > 1 \
+        place = fluid.CUDAPlace(ParallelEnv().dev_id) if ParallelEnv().nranks > 1 \
             else fluid.CUDAPlace(0)

     strategy = ParallelStrategy()
-    strategy.nranks = Env().nranks
-    strategy.local_rank = Env().local_rank
-    strategy.trainer_endpoints = Env().trainer_endpoints
-    strategy.current_endpoint = Env().current_endpoint
+    strategy.nranks = ParallelEnv().nranks
+    strategy.local_rank = ParallelEnv().local_rank
+    strategy.trainer_endpoints = ParallelEnv().trainer_endpoints
+    strategy.current_endpoint = ParallelEnv().current_endpoint

     if strategy.nranks < 2:
         return

-    global __parallel_context_init
+    global _parallel_context_initialized

-    if not __parallel_context_init and isinstance(place, core.CUDAPlace):
+    if not _parallel_context_initialized and isinstance(place, fluid.CUDAPlace):

         def _init_context():
-            communicator_prog = framework.Program()
+            communicator_prog = fluid.Program()
             init_communicator(communicator_prog, strategy.local_rank,
                               strategy.nranks, True, strategy.current_endpoint,
                               strategy.trainer_endpoints)
             exe = fluid.Executor(place)

@@ -213,57 +196,5 @@ def prepare_distributed_context(place=None):
     else:
         assert ("Only support CUDAPlace for now.")

-    __parallel_context_init = True
-    return strategy
-
-
-class DistributedDataParallel(DataParallel):
-    def __init__(self, layers, strategy=None):
-        if strategy is None:
-            strategy = ParallelStrategy()
-            strategy.nranks = Env().nranks
-            strategy.local_rank = Env().local_rank
-            strategy.trainer_endpoints = Env().trainer_endpoints
-            strategy.current_endpoint = Env().current_endpoint
-
-        super(DistributedDataParallel, self).__init__(layers, strategy)
-
-    @no_grad
-    def apply_collective_grads(self):
-        """
-        AllReduce the Parameters' gradient.
-        """
-        if not self._is_data_parallel_mode():
-            return
-
-        grad_var_set = set()
-        grad_vars = []
-        for param in self._layers.parameters():
-            # NOTE(zcd): The grad_ivar maybe no generated.
-            if param.trainable and param._grad_ivar():
-                g_var = param._grad_ivar()
-                grad_vars.append(g_var)
-                assert g_var not in grad_var_set
-                grad_var_set.add(g_var)
-
-        mega_bytes = 128 * 1024 * 1024
-        group_idx = 0
-        memory_counter = 0
-        grad_var_groups = OrderedDict()
-        dtype = grad_vars[0].dtype
-        for g_var in grad_vars:
-            # Note: the dtype of the same group should be the same.
-            bytes = np.prod(g_var.shape) * core.size_of_dtype(g_var.dtype)
-            if memory_counter < mega_bytes and dtype == g_var.dtype:
-                memory_counter += bytes
-            else:
-                memory_counter = bytes
-                group_idx += 1
-            grad_var_groups.setdefault(group_idx, []).append(g_var)
-
-        coalesced_grads_and_vars = self._coalesce_tensors(grad_var_groups)
-
-        for coalesced_grad, _, _ in coalesced_grads_and_vars:
-            collective._c_allreduce(
-                coalesced_grad, coalesced_grad, use_calc_stream=True)
-
-        self._split_tensors(coalesced_grads_and_vars)
+    _parallel_context_initialized = True
+    return strategy
\ No newline at end of file
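For reference, the per-rank arithmetic kept in DistributedBatchSampler.__init__ above pads the dataset so every rank draws the same number of samples. A small standalone sketch of that arithmetic (pure Python, no Paddle required; dataset_len and nranks are illustrative values, and the slicing comment describes the usual implementation rather than code shown in this hunk):

    import math

    dataset_len, nranks = 10, 4                                # illustrative values
    num_samples = int(math.ceil(dataset_len * 1.0 / nranks))   # 3 samples per rank
    total_size = num_samples * nranks                          # 12, so 2 indices get repeated
    # Typically each rank then takes its own contiguous slice of the index list
    # after it has been padded up to total_size.
    print(num_samples, total_size)                             # 3 12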
mnist.py

@@ -26,7 +26,7 @@ from paddle.fluid.optimizer import Momentum
 from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
 from paddle.fluid.io import MNIST as MnistDataset

-from model import Model, CrossEntropy, Input, init_context
+from model import Model, CrossEntropy, Input, set_device
 from metrics import Accuracy

@@ -106,7 +106,8 @@ class MNIST(Model):

 def main():
-    init_context('dynamic' if FLAGS.dynamic else 'static')
+    device = set_device(FLAGS.device)
+    fluid.enable_dygraph(device) if FLAGS.dynamic else None

     train_dataset = MnistDataset(mode='train')
     val_dataset = MnistDataset(mode='test')

@@ -118,7 +119,13 @@ def main():
     optim = Momentum(
         learning_rate=FLAGS.lr, momentum=.9, parameter_list=model.parameters())

-    model.prepare(optim, CrossEntropy(), Accuracy(topk=(1, 2)), inputs, labels)
+    model.prepare(
+        optim,
+        CrossEntropy(),
+        Accuracy(topk=(1, 2)),
+        inputs,
+        labels,
+        device=FLAGS.device)

     if FLAGS.resume is not None:
         model.load(FLAGS.resume)

@@ -131,6 +138,8 @@ def main():
 if __name__ == '__main__':
     parser = argparse.ArgumentParser("CNN training on MNIST")
+    parser.add_argument(
+        "--device", type=str, default='gpu', help="device to use, gpu or cpu")
     parser.add_argument(
         "-d", "--dynamic", action='store_true', help="enable dygraph mode")
     parser.add_argument(
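Taken together, the mnist.py changes swap the old init_context('dynamic'/'static') switch for an explicit device plus an opt-in dygraph toggle. A minimal sketch of the new entry-point style, assuming the script is run from the repo root so hapi's own model.py is importable (set_device and fluid.enable_dygraph are exactly the calls shown above):

    import paddle.fluid as fluid
    from model import set_device

    # set_device() (added in model.py below) maps 'cpu'/'gpu' to a fluid Place,
    # falling back to CPUPlace when Paddle is not compiled with CUDA; dygraph is
    # then enabled explicitly instead of via the removed init_context().
    device = set_device('gpu')
    fluid.enable_dygraph(device)   # skip this line to stay in static-graph mode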
model.py

@@ -20,25 +20,34 @@ import pickle
 import numpy as np
 import six
 import warnings

-from collections import Iterable
-from collections import OrderedDict
+from collections import OrderedDict
+from collections import Iterable

 from paddle import fluid
 from paddle.fluid.framework import in_dygraph_mode, Variable
 from paddle.fluid.executor import global_scope
 from paddle.fluid.io import is_belong_to_optimizer
 from paddle.fluid.dygraph.base import to_variable
+from paddle.fluid.dygraph.parallel import ParallelEnv
 from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
-import paddle.fluid.incubate.fleet.base.role_maker as role_maker
-
-import distributed
-from distributed import DistributedBatchSampler
-from paddle.fluid.io import DataLoader
+from paddle.fluid.incubate.fleet.base import role_maker
+from paddle.fluid.io import DataLoader, Dataset

+from distributed import DistributedBatchSampler, _all_gather, prepare_distributed_context, _parallel_context_initialized
 from metrics import Metric
 from callbacks import config_callbacks

-__all__ = ['Model', 'Loss', 'CrossEntropy', 'Input']
+__all__ = ['Model', 'Loss', 'CrossEntropy', 'Input', 'set_device']
+
+
+def set_device(device):
+    assert isinstance(device, six.string_types) and device.lower() in ['cpu', 'gpu'], \
+        "Expected device in ['cpu', 'gpu'], but got {}".format(device)
+
+    place = fluid.CUDAPlace(ParallelEnv().dev_id) \
+        if device.lower() == 'gpu' and fluid.is_compiled_with_cuda() \
+        else fluid.CPUPlace()
+    return place


 def to_list(value):

@@ -84,18 +93,6 @@ def extract_args(func):
     return inspect.getargspec(func)[0]


-def init_context(backend):
-    assert isinstance(backend, str) and backend.lower() in ['dynamic', 'static'], \
-        "Expected backend in ['dynamic', 'static'], but got {}".format(backend)
-
-    place = fluid.CUDAPlace(distributed.Env().dev_id) if \
-        distributed.Env().nranks > 1 else fluid.CUDAPlace(0)
-    distributed.prepare_distributed_context(place)
-    backend = backend.lower()
-    if backend == 'dynamic':
-        fluid.enable_dygraph(place)
-
-
 class Input(fluid.dygraph.Layer):
     def __init__(self, shape=None, dtype=None, name=None):
         super(Input, self).__init__()

@@ -161,8 +158,8 @@ class StaticGraphAdapter(object):
             'test_batch': 0
         }

-        self._nranks = distributed.Env().nranks
-        self._local_rank = distributed.Env().local_rank
+        self._nranks = ParallelEnv().nranks
+        self._local_rank = ParallelEnv().local_rank

     @property
     def mode(self):

@@ -267,7 +264,8 @@ class StaticGraphAdapter(object):
                     # When using static learning rate, static-graph would make it
                     # a persistable var named 'unique_name.generate("learning_rate")',
                     # However, dygraph wouldn't save it.
-                    if var.name not in state: continue
+                    if var.name not in state:
+                        continue
                 else:
                     # moment and other accumulators
                     if var.name not in converted_state:

@@ -366,8 +364,8 @@ class StaticGraphAdapter(object):
         for metric, state in zip(self.model._metrics, metric_states):
             # cut off padding size
             if self.mode != 'train' and self.model._test_dataloader is not None \
-                and isinstance(self.model._test_dataloader, DataLoader) \
-                and self._nranks > 1:
+                    and isinstance(self.model._test_dataloader, DataLoader) \
+                    and self._nranks > 1:
                 total_size = len(self.model._test_dataloader.dataset)
                 # TODO: fixme if have better way to get batch size
                 samples = state[0].shape[0]

@@ -407,7 +405,7 @@ class StaticGraphAdapter(object):
         for op in list(prog.global_block().ops):
             prog.global_block()._remove_op(0)
         if mode == 'train' and self.model._optimizer \
-            and self.model._optimizer._learning_rate_map:
+                and self.model._optimizer._learning_rate_map:
             # HACK workaround learning rate map issue
             lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
             self.model._optimizer._learning_rate_map[prog] = lr_var

@@ -416,8 +414,10 @@ class StaticGraphAdapter(object):
         metrics = []
         with fluid.program_guard(prog, self._startup_prog):
             if isinstance(self.model._inputs, dict):
-                ins = [self.model._inputs[n] \
-                    for n in extract_args(self.model.forward) if n != 'self']
+                ins = [
+                    self.model._inputs[n]
+                    for n in extract_args(self.model.forward) if n != 'self'
+                ]
             else:
                 ins = self.model._inputs
             lbls = self.model._labels if self.model._labels else []

@@ -430,14 +430,9 @@ class StaticGraphAdapter(object):
                 losses = self.model._loss_function(outputs, labels)
             if self._nranks > 1 and mode != 'train':
-                outputs = [distributed._all_gather(o, self._nranks)
-                           for o in outputs]
+                outputs = [_all_gather(o, self._nranks) for o in outputs]
                 if mode != 'test':
-                    labels = [distributed._all_gather(l, self._nranks)
-                              for l in labels]
+                    labels = [_all_gather(l, self._nranks) for l in labels]
             if mode != 'test':
                 for metric in self.model._metrics:

@@ -474,31 +469,22 @@ class StaticGraphAdapter(object):
         if compiled_prog is not None:
             return compiled_prog

-        device = self.model._device
-        device_ids = self.model._device_ids
+        assert self.model._place is not None, \
+            "device is not set, please call `model.prepare()` first"

-        if device.lower() == 'gpu':
-            places = fluid.cuda_places(device_ids)
-        else:
-            places = fluid.cpu_places(len(device_ids) if device_ids else None)
+        place = self.model._place

         # XXX *ALL WEIGHTS* should be initialized upon model construction
         # even if `forward()` may run different code path for different mode
         # therefore startup program only needs to run once
         if self._executor is None:
-            if self._nranks > 1 and device.lower() == 'gpu':
-                gpu_id = int(distributed.Env().dev_id)
-                place = fluid.CUDAPlace(gpu_id) if device.lower(
-                ) == 'gpu' else fluid.CPUPlace()
-            else:
-                place = places[0]
             self._executor = fluid.Executor(place)
             # XXX incremental initialization
             uninitialized = []
             for var_py in self._startup_prog.list_vars():
                 var = fluid.global_scope().find_var(var_py.name)
                 if not var_py.name.startswith('nccl_id') and var and \
-                    var.get_tensor()._is_initialized():
+                        var.get_tensor()._is_initialized():
                     continue

                 uninitialized.append(var_py)

@@ -509,14 +495,8 @@ class StaticGraphAdapter(object):
         if self._nranks < 2:
             compiled_prog = fluid.CompiledProgram(prog)
         else:
-            compiled_prog = prog  #fleet.main_program
-        if len(places) > 1:
-            loss_name = None
-            if mode == 'train' and self._loss_endpoint is not None:
-                loss_name = self._loss_endpoint.name
-            compiled_prog = compiled_prog.with_data_parallel(
-                loss_name=loss_name, places=places)
+            compiled_prog = prog
         self._compiled_progs[mode] = compiled_prog

@@ -524,8 +504,8 @@ class DynamicGraphAdapter(object):
     def __init__(self, model):
         super(DynamicGraphAdapter, self).__init__()
         self.model = model
-        self._nranks = distributed.Env().nranks
-        self._local_rank = distributed.Env().local_rank
+        self._nranks = ParallelEnv().nranks
+        self._local_rank = ParallelEnv().local_rank
         self._merge_count = {
             'eval_total': 0,
             'test_total': 0,

@@ -534,7 +514,13 @@ class DynamicGraphAdapter(object):
         }

         if self._nranks > 1:
-            self.ddp_model = distributed.DistributedDataParallel(self.model)
+            stradegy = fluid.dygraph.parallel.ParallelStrategy()
+            stradegy.nranks = ParallelEnv().nranks
+            stradegy.local_rank = ParallelEnv().local_rank
+            stradegy.trainer_endpoints = ParallelEnv().trainer_endpoints
+            stradegy.current_endpoint = ParallelEnv().current_endpoint
+            self.ddp_model = fluid.dygraph.parallel.DataParallel(self.model,
+                                                                 stradegy)

     @property
     def mode(self):

@@ -576,7 +562,7 @@ class DynamicGraphAdapter(object):
             metrics.append(m)

         return ([to_numpy(l) for l in losses], metrics) \
-            if len(metrics) > 0 else [to_numpy(l) for l in losses]
+               if len(metrics) > 0 else [to_numpy(l) for l in losses]

     def eval(self, inputs, labels=None):
         super(Model, self.model).eval()

@@ -590,11 +576,8 @@ class DynamicGraphAdapter(object):
         else:
             losses = []
         if self._nranks > 1:
-            outputs = [
-                distributed._all_gather(o, self._nranks)
-                for o in to_list(outputs)
-            ]
-            labels = [distributed._all_gather(l, self._nranks) for l in labels]
+            outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)]
+            labels = [_all_gather(l, self._nranks) for l in labels]
         metrics = []
         for metric in self.model._metrics:
             # cut off padding value.

@@ -622,7 +605,7 @@ class DynamicGraphAdapter(object):
         # To be consistent with static graph
         # return empty loss if loss_function is None
         return ([to_numpy(l) for l in losses], metrics) \
-            if len(metrics) > 0 else [to_numpy(l) for l in losses]
+               if len(metrics) > 0 else [to_numpy(l) for l in losses]

     def test(self, inputs):
         super(Model, self.model).eval()

@@ -630,10 +613,7 @@ class DynamicGraphAdapter(object):
         inputs = [to_variable(x) for x in to_list(inputs)]
         outputs = self.model.forward(*inputs)
         if self._nranks > 2:
-            outputs = [
-                distributed._all_gather(o, self._nranks)
-                for o in to_list(outputs)
-            ]
+            outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)]
         return [to_numpy(o) for o in to_list(outputs)]

     def parameters(self, *args, **kwargs):

@@ -720,10 +700,6 @@ class Model(fluid.dygraph.Layer):
         self._optimizer = None
         self._test_dataloader = None

-        # init multiple gpus context
-        self._place = fluid.CUDAPlace(distributed.Env().dev_id) \
-            if distributed.Env().nranks > 1 else fluid.CUDAPlace(0)
-
         # init backend
         if fluid.in_dygraph_mode():
             self._adapter = DynamicGraphAdapter(self)

@@ -740,7 +716,7 @@ class Model(fluid.dygraph.Layer):
         return self._adapter.test(*args, **kwargs)

     def save(self, *args, **kwargs):
-        if distributed.get_local_rank() == 0:
+        if ParallelEnv().local_rank == 0:
             return self._adapter.save(*args, **kwargs)

     def load(self, path, skip_mismatch=False, reset_optimizer=False):

@@ -855,6 +831,35 @@ class Model(fluid.dygraph.Layer):
                 The default is None.
         """

+        if isinstance(device, fluid.CUDAPlace) or \
+            (isinstance(device, six.string_types) and device.lower() == 'gpu') \
+            or (device is None and fluid.is_compiled_with_cuda()):
+            if isinstance(device, fluid.CUDAPlace):
+                self._place = device
+            else:
+                self._place = fluid.CUDAPlace(ParallelEnv().dev_id) \
+                    if ParallelEnv().nranks > 1 else fluid.CUDAPlace(0)
+
+            global _parallel_context_initialized
+            if ParallelEnv().nranks > 1 and not _parallel_context_initialized:
+                if fluid.in_dygraph_mode():
+                    fluid.disable_dygraph()
+                    fluid.enable_dygraph(self._place)
+                    fluid.dygraph.parallel.prepare_context()
+                else:
+                    prepare_distributed_context(self._place)
+                _parallel_context_initialized = True
+        elif isinstance(device, fluid.CPUPlace):
+            self._place = device
+        elif (isinstance(device, six.string_types) and device.lower() == 'cpu') \
+            or (device is None):
+            self._place = fluid.CPUPlace()
+        else:
+            raise ValueError("Expected device in ('gpu', 'cpu', fluid.CUDAPlace, fluid.CPUPlace, None), \
+                but got {}".format(device))
+
         self._optimizer = optimizer
         if loss_function:
             if not isinstance(loss_function, Loss):

@@ -871,25 +876,20 @@ class Model(fluid.dygraph.Layer):
         metrics = metrics or []
         for metric in to_list(metrics):
             assert isinstance(metric, Metric), \
-            "{} is not sub class of Metric".format(metric.__class__.__name__)
+                "{} is not sub class of Metric".format(metric.__class__.__name__)
         self._metrics = to_list(metrics)
         self._inputs = inputs
         self._labels = labels
-        self._device = device
-        if device is None:
-            self._device = 'GPU' if fluid.is_compiled_with_cuda() else 'CPU'
-        self._device_ids = device_ids
         if not in_dygraph_mode():
             self._adapter.prepare()

     def fit(
             self,
-            train_dataset=None,
-            eval_dataset=None,
-            train_loader=None,
-            eval_loader=None,
+            train_data=None,
+            eval_data=None,
+            batch_size=1,
             epochs=1,
             eval_freq=1,

@@ -904,9 +904,16 @@
         """
         FIXME: add more comments and usage
         Args:
-            train_loader (DataLoader): An iterable data loader is used for train.
-            eval_loader (DataLoader): An iterable data loader is used for
-                evaluation at the end of epoch. If None, will not do evaluation.
+            train_data (Dataset|DataLoader): An iterable data loader is used for
+                train. An instance of paddle.fluid.io.Dataset or
+                paddle.fluid.io.Dataloader is recomended.
+            eval_data (Dataset|DataLoader): An iterable data loader is used for
+                evaluation at the end of epoch. If None, will not do evaluation.
+                An instance of paddle.fluid.io.Dataset or paddle.fluid.io.Dataloader
+                is recomended.
+            batch_size (int): Integer number. The batch size of train_data and eval_data.
+                When train_data and eval_data are both the instance of Dataloader, this
+                parameter will be ignored.
             epochs (int): Integer number. The number of epochs to train the model.
             eval_freq (int): The frequency, in number of epochs, an evalutation
                 is performed.

@@ -917,47 +924,57 @@
             save_freq (int): The frequency, in number of epochs, to save checkpoint.
             verbose (int): The verbosity mode, should be 0, 1, or 2.
                 0 = silent, 1 = progress bar, 2 = one line per epoch.
+            drop_last (bool): whether drop the last incomplete batch of train_data
+                when dataset size is not divisible by the batch size. When train_data
+                is an instance of Dataloader, this parameter will be ignored.
+            shuffle (bool): whther to shuffle train_data. When train_data is an instance
+                of Dataloader, this parameter will be ignored.
+            num_workers (int): the number of subprocess to load data, 0 for no subprocess
+                used and loading data in main process. When train_data and eval_data are
+                both the instance of Dataloader, this parameter will be ignored.
             callbacks (Callback|None): A list of `Callback` instances to apply
                 during training. If None, `ProgBarLogger` and `ModelCheckpoint`
                 are automatically inserted.
         """
-        assert train_dataset is not None or train_loader is not None, \
-            "train_dataset or train_loader must be given"
-        assert (train_loader is not None and train_dataset is None) or \
-            (train_loader is None and train_dataset is not None), \
-            "train_dataset should not be set when train_loader is given"
+        assert train_data is not None, \
+            "train_data must be given!"

         if fluid.in_dygraph_mode():
             feed_list = None
         else:
             feed_list = [x.forward() for x in self._inputs + self._labels]

-        if train_loader is None:
+        if isinstance(train_data, Dataset):
             train_sampler = DistributedBatchSampler(
-                train_dataset,
+                train_data,
                 batch_size=batch_size,
                 shuffle=shuffle,
                 drop_last=drop_last)
             train_loader = DataLoader(
-                train_dataset,
+                train_data,
                 batch_sampler=train_sampler,
                 places=self._place,
                 feed_list=feed_list,
                 num_workers=num_workers,
                 return_list=True)
+        else:
+            train_loader = train_data

-        if eval_loader is None and eval_dataset is not None:
+        if eval_data is not None and isinstance(eval_data, Dataset):
             eval_sampler = DistributedBatchSampler(
-                eval_dataset, batch_size=batch_size)
+                eval_data, batch_size=batch_size)
             eval_loader = DataLoader(
-                eval_dataset,
+                eval_data,
                 batch_sampler=eval_sampler,
                 places=self._place,
                 feed_list=feed_list,
                 num_workers=num_workers,
                 return_list=True)
+        elif eval_data is not None:
+            eval_loader = eval_data
+        else:
+            eval_loader = None

         do_eval = eval_loader is not None
         self._test_dataloader = eval_loader

@@ -1010,7 +1027,7 @@
             logs['step'] = step
             if mode == 'train' or self._adapter._merge_count.get(
                     mode + '_batch', 0) <= 0:
-                logs['batch_size'] = batch_size * distributed.Env().nranks
+                logs['batch_size'] = batch_size * ParallelEnv().nranks
             else:
                 logs['batch_size'] = self._adapter._merge_count[mode + '_batch']

@@ -1035,7 +1052,7 @@
                 loader = eval_loader
                 if not isinstance(eval_loader, Iterable):
                     loader = eval_loader()
-                logs = _run_one_epoch(eval_loader(), cbks, 'eval')
+                logs = _run_one_epoch(loader, cbks, 'eval')

                 cbks.on_end('eval', logs)

         cbks.on_end('train', logs)
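The fit() signature change above collapses train_dataset/eval_dataset/train_loader/eval_loader into train_data/eval_data, adding batch_size, shuffle, drop_last and num_workers for the case where a Dataset is passed. A hedged sketch of the new call shape, assuming a hapi Model instance named model that has already been prepare()d and a train_dataset/val_dataset pair as in mnist.py (none of these objects are constructed here):

    # train_data/eval_data accept a paddle.fluid.io.Dataset (hapi then builds the
    # DistributedBatchSampler + DataLoader itself) or an existing DataLoader
    # (in which case batch_size/shuffle/drop_last/num_workers are ignored).
    model.fit(train_data=train_dataset,
              eval_data=val_dataset,
              batch_size=128,
              epochs=2,
              eval_freq=1,
              save_freq=1,
              num_workers=0)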
tests/test_model.py

@@ -28,7 +28,7 @@ import contextlib
 import paddle
 from paddle import fluid
 from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear

-from model import Model, CrossEntropy, Input, Loss, init_context
+from model import Model, CrossEntropy, Input, Loss, set_device
 from metrics import Accuracy
 from callbacks import ProgBarLogger
 from paddle.fluid.io import BatchSampler, DataLoader

@@ -141,7 +141,8 @@ class MyCrossEntropy(Loss):
 class TestModel(unittest.TestCase):
     def fit(self, dynamic, is_mlp=False):
-        init_context('dynamic' if dynamic else 'static')
+        device = set_device('gpu')
+        fluid.enable_dygraph(device) if dynamic else None
         im_shape = (-1, 784)
         batch_size = 128

@@ -156,7 +157,7 @@ class TestModel(unittest.TestCase):
         optim = fluid.optimizer.Momentum(
             learning_rate=0.01, momentum=.9, parameter_list=model.parameters())
         loss = CrossEntropy() if not is_mlp else MyCrossEntropy()
-        model.prepare(optim, loss, Accuracy(), inputs, labels)
+        model.prepare(optim, loss, Accuracy(), inputs, labels, device=device)
         cbk = ProgBarLogger(50)
         model.fit(train_dataset,
                   val_dataset,
transformer/reader.py

@@ -138,6 +138,95 @@ def pad_batch_data(insts,
     return return_list if len(return_list) > 1 else return_list[0]


+class SortType(object):
+    GLOBAL = 'global'
+    POOL = 'pool'
+    NONE = "none"
+
+
+class Converter(object):
+    def __init__(self, vocab, beg, end, unk, delimiter, add_beg):
+        self._vocab = vocab
+        self._beg = beg
+        self._end = end
+        self._unk = unk
+        self._delimiter = delimiter
+        self._add_beg = add_beg
+
+    def __call__(self, sentence):
+        return ([self._beg] if self._add_beg else []) + [
+            self._vocab.get(w, self._unk)
+            for w in sentence.split(self._delimiter)
+        ] + [self._end]
+
+
+class ComposedConverter(object):
+    def __init__(self, converters):
+        self._converters = converters
+
+    def __call__(self, parallel_sentence):
+        return [
+            self._converters[i](parallel_sentence[i])
+            for i in range(len(self._converters))
+        ]
+
+
+class SentenceBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self._batch_size = batch_size
+
+    def append(self, info):
+        self.batch.append(info)
+        if len(self.batch) == self._batch_size:
+            tmp = self.batch
+            self.batch = []
+            return tmp
+
+
+class TokenBatchCreator(object):
+    def __init__(self, batch_size):
+        self.batch = []
+        self.max_len = -1
+        self._batch_size = batch_size
+
+    def append(self, info):
+        cur_len = info.max_len
+        max_len = max(self.max_len, cur_len)
+        if max_len * (len(self.batch) + 1) > self._batch_size:
+            result = self.batch
+            self.batch = [info]
+            self.max_len = cur_len
+            return result
+        else:
+            self.max_len = max_len
+            self.batch.append(info)
+
+
+class SampleInfo(object):
+    def __init__(self, i, max_len, min_len):
+        self.i = i
+        self.min_len = min_len
+        self.max_len = max_len
+
+
+class MinMaxFilter(object):
+    def __init__(self, max_len, min_len, underlying_creator):
+        self._min_len = min_len
+        self._max_len = max_len
+        self._creator = underlying_creator
+
+    def append(self, info):
+        if info.max_len > self._max_len or info.min_len < self._min_len:
+            return
+        else:
+            return self._creator.append(info)
+
+    @property
+    def batch(self):
+        return self._creator.batch
+
+
 class Seq2SeqDataset(Dataset):
     def __init__(self,
                  src_vocab_fpath,

@@ -338,421 +427,3 @@ class Seq2SeqBatchSampler(BatchSampler):
     @property
     def dev_id(self):
         return self._dev_id
-
-
-class SortType(object):
-    GLOBAL = 'global'
-    POOL = 'pool'
-    NONE = "none"
-
-
-class Converter(object):
-    def __init__(self, vocab, beg, end, unk, delimiter, add_beg):
-        self._vocab = vocab
-        self._beg = beg
-        self._end = end
-        self._unk = unk
-        self._delimiter = delimiter
-        self._add_beg = add_beg
-
-    def __call__(self, sentence):
-        return ([self._beg] if self._add_beg else []) + [
-            self._vocab.get(w, self._unk)
-            for w in sentence.split(self._delimiter)
-        ] + [self._end]
-
-
-class ComposedConverter(object):
-    def __init__(self, converters):
-        self._converters = converters
-
-    def __call__(self, parallel_sentence):
-        return [
-            self._converters[i](parallel_sentence[i])
-            for i in range(len(self._converters))
-        ]
-
-
-class SentenceBatchCreator(object):
-    def __init__(self, batch_size):
-        self.batch = []
-        self._batch_size = batch_size
-
-    def append(self, info):
-        self.batch.append(info)
-        if len(self.batch) == self._batch_size:
-            tmp = self.batch
-            self.batch = []
-            return tmp
-
-
-class TokenBatchCreator(object):
-    def __init__(self, batch_size):
-        self.batch = []
-        self.max_len = -1
-        self._batch_size = batch_size
-
-    def append(self, info):
-        cur_len = info.max_len
-        max_len = max(self.max_len, cur_len)
-        if max_len * (len(self.batch) + 1) > self._batch_size:
-            result = self.batch
-            self.batch = [info]
-            self.max_len = cur_len
-            return result
-        else:
-            self.max_len = max_len
-            self.batch.append(info)
-
-
-class SampleInfo(object):
-    def __init__(self, i, max_len, min_len):
-        self.i = i
-        self.min_len = min_len
-        self.max_len = max_len
-
-
-class MinMaxFilter(object):
-    def __init__(self, max_len, min_len, underlying_creator):
-        self._min_len = min_len
-        self._max_len = max_len
-        self._creator = underlying_creator
-
-    def append(self, info):
-        if info.max_len > self._max_len or info.min_len < self._min_len:
-            return
-        else:
-            return self._creator.append(info)
-
-    @property
-    def batch(self):
-        return self._creator.batch
-
-
-class DataProcessor(object):
-    """
-    The data reader loads all data from files and produces batches of data
-    in the way corresponding to settings.
-
-    An example of returning a generator producing data batches whose data
-    is shuffled in each pass and sorted in each pool:
-
-    ```
-    train_data = DataProcessor(
-        src_vocab_fpath='data/src_vocab_file',
-        trg_vocab_fpath='data/trg_vocab_file',
-        fpattern='data/part-*',
-        use_token_batch=True,
-        batch_size=2000,
-        device_count=8,
-        n_head=8,
-        pool_size=10000,
-        sort_type=SortType.POOL,
-        shuffle=True,
-        shuffle_batch=True,
-        start_mark='<s>',
-        end_mark='<e>',
-        unk_mark='<unk>',
-        clip_last_batch=False).data_generator(phase='train')
-    ```
-
-    :param src_vocab_fpath: The path of vocabulary file of source language.
-    :type src_vocab_fpath: basestring
-    :param trg_vocab_fpath: The path of vocabulary file of target language.
-    :type trg_vocab_fpath: basestring
-    :param fpattern: The pattern to match data files.
-    :type fpattern: basestring
-    :param batch_size: The number of sequences contained in a mini-batch.
-        or the maximum number of tokens (include paddings) contained in a
-        mini-batch.
-    :type batch_size: int
-    :param pool_size: The size of pool buffer.
-    :type device_count: int
-    :param device_count: The number of devices. The actual batch size is
-        determined by both batch_size and device_count.
-    :type n_head: int
-    :param n_head: The number of head used in multi-head attention. Actually,
-        this is not a reader related argument, but is used for input data.
-    :type pool_size: int
-    :param sort_type: The grain to sort by length: 'global' for all
-        instances; 'pool' for instances in pool; 'none' for no sort.
-    :type sort_type: basestring
-    :param clip_last_batch: Whether to clip the last uncompleted batch.
-    :type clip_last_batch: bool
-    :param tar_fname: The data file in tar if fpattern matches a tar file.
-    :type tar_fname: basestring
-    :param min_length: The minimum length used to filt sequences.
-    :type min_length: int
-    :param max_length: The maximum length used to filt sequences.
-    :type max_length: int
-    :param shuffle: Whether to shuffle all instances.
-    :type shuffle: bool
-    :param shuffle_batch: Whether to shuffle the generated batches.
-    :type shuffle_batch: bool
-    :param use_token_batch: Whether to produce batch data according to
-        token number.
-    :type use_token_batch: bool
-    :param field_delimiter: The delimiter used to split source and target in
-        each line of data file.
-    :type field_delimiter: basestring
-    :param token_delimiter: The delimiter used to split tokens in source or
-        target sentences.
-    :type token_delimiter: basestring
-    :param start_mark: The token representing for the beginning of
-        sentences in dictionary.
-    :type start_mark: basestring
-    :param end_mark: The token representing for the end of sentences
-        in dictionary.
-    :type end_mark: basestring
-    :param unk_mark: The token representing for unknown word in dictionary.
-    :type unk_mark: basestring
-    :param only_src: Whether each line is a source and target sentence
-        pair or only has the source sentence.
-    :type only_src: bool
-    :param seed: The seed for random.
-    :type seed: int
-    """
-
-    def __init__(self,
-                 src_vocab_fpath,
-                 trg_vocab_fpath,
-                 fpattern,
-                 batch_size,
-                 device_count,
-                 n_head,
-                 pool_size,
-                 sort_type=SortType.GLOBAL,
-                 clip_last_batch=False,
-                 tar_fname=None,
-                 min_length=0,
-                 max_length=100,
-                 shuffle=True,
-                 shuffle_batch=False,
-                 use_token_batch=False,
-                 field_delimiter="\t",
-                 token_delimiter=" ",
-                 start_mark="<s>",
-                 end_mark="<e>",
-                 unk_mark="<unk>",
-                 only_src=False,
-                 seed=0):
-        # convert str to bytes, and use byte data
-        field_delimiter = field_delimiter.encode("utf8")
-        token_delimiter = token_delimiter.encode("utf8")
-        start_mark = start_mark.encode("utf8")
-        end_mark = end_mark.encode("utf8")
-        unk_mark = unk_mark.encode("utf8")
-        self._src_vocab = self.load_dict(src_vocab_fpath)
-        self._trg_vocab = self.load_dict(trg_vocab_fpath)
-        self._bos_idx = self._src_vocab[start_mark]
-        self._eos_idx = self._src_vocab[end_mark]
-        self._unk_idx = self._src_vocab[unk_mark]
-        self._only_src = only_src
-        self._pool_size = pool_size
-        self._batch_size = batch_size
-        self._device_count = device_count
-        self._n_head = n_head
-        self._use_token_batch = use_token_batch
-        self._sort_type = sort_type
-        self._clip_last_batch = clip_last_batch
-        self._shuffle = shuffle
-        self._shuffle_batch = shuffle_batch
-        self._min_length = min_length
-        self._max_length = max_length
-        self._field_delimiter = field_delimiter
-        self._token_delimiter = token_delimiter
-        self.load_src_trg_ids(fpattern, tar_fname)
-        self._random = np.random
-        self._random.seed(seed)
-
-    def load_src_trg_ids(self, fpattern, tar_fname):
-        converters = [
-            Converter(
-                vocab=self._src_vocab,
-                beg=self._bos_idx,
-                end=self._eos_idx,
-                unk=self._unk_idx,
-                delimiter=self._token_delimiter,
-                add_beg=False)
-        ]
-        if not self._only_src:
-            converters.append(
-                Converter(
-                    vocab=self._trg_vocab,
-                    beg=self._bos_idx,
-                    end=self._eos_idx,
-                    unk=self._unk_idx,
-                    delimiter=self._token_delimiter,
-                    add_beg=True))
-
-        converters = ComposedConverter(converters)
-
-        self._src_seq_ids = []
-        self._trg_seq_ids = None if self._only_src else []
-        self._sample_infos = []
-
-        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
-            src_trg_ids = converters(line)
-            self._src_seq_ids.append(src_trg_ids[0])
-            lens = [len(src_trg_ids[0])]
-            if not self._only_src:
-                self._trg_seq_ids.append(src_trg_ids[1])
-                lens.append(len(src_trg_ids[1]))
-            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
-
-    def _load_lines(self, fpattern, tar_fname):
-        fpaths = glob.glob(fpattern)
-        assert len(fpaths) > 0, "no matching file to the provided data path"
-
-        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
-            if tar_fname is None:
-                raise Exception("If tar file provided, please set tar_fname.")
-
-            f = tarfile.open(fpaths[0], "rb")
-            for line in f.extractfile(tar_fname):
-                fields = line.strip(b"\n").split(self._field_delimiter)
-                if (not self._only_src and len(fields) == 2) or (
-                        self._only_src and len(fields) == 1):
-                    yield fields
-        else:
-            for fpath in fpaths:
-                if not os.path.isfile(fpath):
-                    raise IOError("Invalid file: %s" % fpath)
-
-                with open(fpath, "rb") as f:
-                    for line in f:
-                        fields = line.strip(b"\n").split(self._field_delimiter)
-                        if (not self._only_src and len(fields) == 2) or (
-                                self._only_src and len(fields) == 1):
-                            yield fields
-
-    @staticmethod
-    def load_dict(dict_path, reverse=False):
-        word_dict = {}
-        with open(dict_path, "rb") as fdict:
-            for idx, line in enumerate(fdict):
-                if reverse:
-                    word_dict[idx] = line.strip(b"\n")
-                else:
-                    word_dict[line.strip(b"\n")] = idx
-        return word_dict
-
-    def batch_generator(self, batch_size, use_token_batch):
-        def __impl__():
-            # global sort or global shuffle
-            if self._sort_type == SortType.GLOBAL:
-                infos = sorted(self._sample_infos, key=lambda x: x.max_len)
-            else:
-                if self._shuffle:
-                    infos = self._sample_infos
-                    self._random.shuffle(infos)
-                else:
-                    infos = self._sample_infos
-
-                if self._sort_type == SortType.POOL:
-                    reverse = True
-                    for i in range(0, len(infos), self._pool_size):
-                        # to avoid placing short next to long sentences
-                        reverse = not reverse
-                        infos[i:i + self._pool_size] = sorted(
-                            infos[i:i + self._pool_size],
-                            key=lambda x: x.max_len,
-                            reverse=reverse)
-
-            # concat batch
-            batches = []
-            batch_creator = TokenBatchCreator(
-                batch_size) if use_token_batch else SentenceBatchCreator(
-                    batch_size)
-            batch_creator = MinMaxFilter(self._max_length, self._min_length,
-                                         batch_creator)
-
-            for info in infos:
-                batch = batch_creator.append(info)
-                if batch is not None:
-                    batches.append(batch)
-
-            if not self._clip_last_batch and len(batch_creator.batch) != 0:
-                batches.append(batch_creator.batch)
-
-            if self._shuffle_batch:
-                self._random.shuffle(batches)
-
-            for batch in batches:
-                batch_ids = [info.i for info in batch]
-
-                if self._only_src:
-                    yield [[self._src_seq_ids[idx]] for idx in batch_ids]
-                else:
-                    yield [(self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1],
-                            self._trg_seq_ids[idx][1:]) for idx in batch_ids]
-
-        return __impl__
-
-    @staticmethod
-    def stack(data_reader, count, clip_last=True):
-        def __impl__():
-            res = []
-            for item in data_reader():
-                res.append(item)
-                if len(res) == count:
-                    yield res
-                    res = []
-            if len(res) == count:
-                yield res
-            elif not clip_last:
-                data = []
-                for item in res:
-                    data += item
-                if len(data) > count:
-                    inst_num_per_part = len(data) // count
-                    yield [
-                        data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
-                        for i in range(count)
-                    ]
-
-        return __impl__
-
-    @staticmethod
-    def split(data_reader, count):
-        def __impl__():
-            for item in data_reader():
-                inst_num_per_part = len(item) // count
-                for i in range(count):
-                    yield item[inst_num_per_part * i:inst_num_per_part * (i + 1)]
-
-        return __impl__
-
-    def data_generator(self, phase):
-        # Any token included in dict can be used to pad, since the paddings' loss
-        # will be masked out by weights and make no effect on parameter gradients.
-        src_pad_idx = trg_pad_idx = self._eos_idx
-        bos_idx = self._bos_idx
-        n_head = self._n_head
-        data_reader = self.batch_generator(
-            self._batch_size *
-            (1 if self._use_token_batch else self._device_count),
-            self._use_token_batch)
-        if not self._use_token_batch:
-            # to make data on each device have similar token number
-            data_reader = self.split(data_reader, self._device_count)
-
-        def __for_train__():
-            for data in data_reader():
-                data_inputs = prepare_train_input(data, src_pad_idx,
-                                                  trg_pad_idx, n_head)
-                yield data_inputs[:-2], data_inputs[-2:]
-
-        def __for_predict__():
-            for data in data_reader():
-                data_inputs = prepare_infer_input(data, src_pad_idx, bos_idx,
-                                                  n_head)
-                yield data_inputs
-
-        return __for_train__ if phase == "train" else __for_predict__
-
-    def get_vocab_summary(self):
-        return len(self._src_vocab), len(
-            self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
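TokenBatchCreator (kept near the top of the file after this change) caps a batch by max_len * (len(batch) + 1) rather than by sentence count, so short sentences pack densely and long sentences end up in small batches. A self-contained sketch of that budgeting rule, using plain lengths instead of SampleInfo objects (the class name _TokenBudgetBatcher and the numbers are illustrative, not from the repo):

    class _TokenBudgetBatcher(object):
        """Simplified restatement of TokenBatchCreator's rule, for illustration."""
        def __init__(self, token_budget):
            self.batch, self.max_len, self._budget = [], -1, token_budget

        def append(self, sample_len):
            max_len = max(self.max_len, sample_len)
            if max_len * (len(self.batch) + 1) > self._budget:  # padded size would overflow
                full, self.batch, self.max_len = self.batch, [sample_len], sample_len
                return full
            self.max_len = max_len
            self.batch.append(sample_len)

    batcher = _TokenBudgetBatcher(token_budget=12)
    for length in [3, 4, 5, 10, 2]:
        emitted = batcher.append(length)
        if emitted is not None:
            print(emitted)   # [3, 4] then [5] then [10]: lengths grouped under the 12-token cap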
transformer/train.py

@@ -19,6 +19,7 @@ import sys
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 import time
 import contextlib
+from functools import partial

 import numpy as np
 import paddle

@@ -30,9 +31,9 @@ from utils.configure import PDConfig
 from utils.check import check_gpu, check_version

 # include task-specific libs
-import reader
+from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler
 from transformer import Transformer, CrossEntropyCriterion, NoamDecay
-from model import Input
+from model import Input, set_device
 from callbacks import ProgBarLogger

@@ -62,7 +63,8 @@ class LoggerCallback(ProgBarLogger):

 def do_train(args):
-    # init_context('dynamic' if FLAGS.dynamic else 'static')
+    device = set_device("gpu" if args.use_cuda else "cpu")
+    fluid.enable_dygraph(device) if args.eager_run else None

     # set seed for CE
     random_seed = eval(str(args.random_seed))

@@ -72,20 +74,19 @@ def do_train(args):
     # define model
     inputs = [
-        Input(
-            [None, None], "int64", name="src_word"),
-        Input(
-            [None, None], "int64", name="src_pos"),
-        Input(
-            [None, args.n_head, None, None],
-            "float32",
-            name="src_slf_attn_bias"),
-        Input(
-            [None, None], "int64", name="trg_word"),
-        Input(
-            [None, None], "int64", name="trg_pos"),
-        Input(
-            [None, args.n_head, None, None],
-            "float32",
-            name="trg_slf_attn_bias"),
-        Input(
-            [None, args.n_head, None, None],
-            "float32",
-            name="trg_src_attn_bias")
+        Input([None, None], "int64", name="src_word"),
+        Input([None, None], "int64", name="src_pos"),
+        Input([None, args.n_head, None, None],
+              "float32",
+              name="src_slf_attn_bias"),
+        Input([None, None], "int64", name="trg_word"),
+        Input([None, None], "int64", name="trg_pos"),
+        Input([None, args.n_head, None, None],
+              "float32",
+              name="trg_slf_attn_bias"),
+        Input([None, args.n_head, None, None],
+              "float32",
+              name="trg_src_attn_bias"),
     ]
     labels = [
         Input(

@@ -94,32 +95,33 @@ def do_train(args):
             [None, 1], "float32", name="weight"),
     ]

-    dataset = reader.Seq2SeqDataset(
+    dataset = Seq2SeqDataset(
         fpattern=args.training_file,
         src_vocab_fpath=args.src_vocab_fpath,
         trg_vocab_fpath=args.trg_vocab_fpath,
         token_delimiter=args.token_delimiter,
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
         unk_mark=args.special_token[2])
     args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
         args.unk_idx = dataset.get_vocab_summary()
-    batch_sampler = reader.Seq2SeqBatchSampler(
+    batch_sampler = Seq2SeqBatchSampler(
         dataset=dataset,
         use_token_batch=args.use_token_batch,
         batch_size=args.batch_size,
         pool_size=args.pool_size,
         sort_type=args.sort_type,
         shuffle=args.shuffle,
         shuffle_batch=args.shuffle_batch,
         max_length=args.max_length)
     train_loader = DataLoader(
         dataset=dataset,
         batch_sampler=batch_sampler,
-        places=None,
+        places=device,
         feed_list=[x.forward() for x in inputs + labels],
+        collate_fn=partial(
+            prepare_train_input,
+            src_pad_idx=args.eos_idx,
+            trg_pad_idx=args.eos_idx,
+            n_head=args.n_head),
         num_workers=0,
         return_list=True)
     transformer = Transformer(
         args.src_vocab_size, args.trg_vocab_size, args.max_length + 1,

@@ -156,8 +158,8 @@ def do_train(args):
         (1. - args.label_smooth_eps)) + args.label_smooth_eps *
                       np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))

-    transformer.fit(train_loader=train_loader,
-                    eval_loader=None,
+    transformer.fit(train_data=train_loader,
+                    eval_data=None,
                     epochs=1,
                     eval_freq=1,
                     save_freq=1,
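The new DataLoader wiring above moves batch assembly into the loader itself via collate_fn=partial(prepare_train_input, ...), instead of the reader pre-batching the data. The same trick in isolation, with a toy collate function standing in for prepare_train_input (toy_collate and pad_id are illustrative names, not from the repo):

    from functools import partial

    def toy_collate(samples, pad_id):
        # pad variable-length index lists to the batch max, the way prepare_train_input
        # pads source/target sequences before building the attention-bias tensors
        max_len = max(len(s) for s in samples)
        return [s + [pad_id] * (max_len - len(s)) for s in samples]

    # Freezing the keyword arguments with partial() yields the single-argument
    # callable that a DataLoader-style collate_fn expects.
    collate = partial(toy_collate, pad_id=0)
    print(collate([[1, 2, 3], [4]]))   # [[1, 2, 3], [4, 0, 0]]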