Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
82b8a3c5
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
82b8a3c5
编写于
9月 17, 2018
作者:
Y
yuyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move trainer to contrib
上级
3fbfcd9c
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
1374 addition
and
1351 deletion
+1374
-1351
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+0
-9
python/paddle/fluid/contrib/inferencer.py
python/paddle/fluid/contrib/inferencer.py
+112
-0
python/paddle/fluid/contrib/trainer.py
python/paddle/fluid/contrib/trainer.py
+1258
-0
python/paddle/fluid/inferencer.py
python/paddle/fluid/inferencer.py
+2
-98
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+2
-1244
未找到文件。
python/paddle/fluid/__init__.py
浏览文件 @
82b8a3c5
...
@@ -19,17 +19,8 @@ from .framework import *
...
@@ -19,17 +19,8 @@ from .framework import *
# import all class inside executor into fluid module
# import all class inside executor into fluid module
from
.
import
executor
from
.
import
executor
from
.executor
import
*
from
.executor
import
*
from
.
import
trainer
from
.
import
trainer
from
.trainer
import
Trainer
from
.trainer
import
BeginEpochEvent
from
.trainer
import
EndEpochEvent
from
.trainer
import
BeginStepEvent
from
.trainer
import
EndStepEvent
from
.trainer
import
CheckpointConfig
from
.
import
inferencer
from
.
import
inferencer
from
.inferencer
import
Inferencer
from
.
import
io
from
.
import
io
from
.
import
evaluator
from
.
import
evaluator
...
...
python/paddle/fluid/contrib/inferencer.py
0 → 100644
浏览文件 @
82b8a3c5
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
contextlib
from
..
import
core
from
..
import
executor
from
..
import
framework
from
..
import
io
from
..
import
parallel_executor
from
..
import
unique_name
from
.trainer
import
check_and_get_place
__all__
=
[
'Inferencer'
,
]
class
Inferencer
(
object
):
"""
Inferencer High Level API.
Args:
infer_func (Python func): Infer function that will return predict Variable
param_path (str): The path where the inference model is saved by fluid.io.save_params
place (Place): place to do the inference
parallel (bool): use parallel_executor to run the inference, it will use multi CPU/GPU.
Examples:
.. code-block:: python
def inference_program():
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
return y_predict
place = fluid.CPUPlace()
inferencer = fluid.Inferencer(
infer_func=inference_program, param_path="/tmp/model", place=place)
"""
def
__init__
(
self
,
infer_func
,
param_path
,
place
=
None
,
parallel
=
False
):
self
.
param_path
=
param_path
self
.
scope
=
core
.
Scope
()
self
.
parallel
=
parallel
self
.
place
=
check_and_get_place
(
place
)
self
.
inference_program
=
framework
.
Program
()
with
framework
.
program_guard
(
self
.
inference_program
):
with
unique_name
.
guard
():
self
.
predict_var
=
infer_func
()
with
self
.
_prog_and_scope_guard
():
# load params from param_path into scope
io
.
load_params
(
executor
.
Executor
(
self
.
place
),
param_path
)
if
parallel
:
with
self
.
_prog_and_scope_guard
():
self
.
exe
=
parallel_executor
.
ParallelExecutor
(
use_cuda
=
isinstance
(
self
.
place
,
core
.
CUDAPlace
),
loss_name
=
self
.
predict_var
.
name
)
else
:
self
.
exe
=
executor
.
Executor
(
self
.
place
)
self
.
inference_program
=
self
.
inference_program
.
clone
(
for_test
=
True
)
def
infer
(
self
,
inputs
,
return_numpy
=
True
):
"""
Do Inference for Inputs
Args:
inputs (map): a map of {"input_name": input_var} that will be feed into the inference program
return_numpy (bool): transform return value into numpy or not
Returns:
Tensor or Numpy: the predict value of the inference model for the inputs
Examples:
.. code-block:: python
tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32")
results = inferencer.infer({'x': tensor_x})
"""
if
not
isinstance
(
inputs
,
dict
):
raise
ValueError
(
"inputs should be a map of {'input_name': input_var}"
)
with
self
.
_prog_and_scope_guard
():
results
=
self
.
exe
.
run
(
feed
=
inputs
,
fetch_list
=
[
self
.
predict_var
.
name
],
return_numpy
=
return_numpy
)
return
results
@
contextlib
.
contextmanager
def
_prog_and_scope_guard
(
self
):
with
framework
.
program_guard
(
main_program
=
self
.
inference_program
):
with
executor
.
scope_guard
(
self
.
scope
):
yield
python/paddle/fluid/contrib/trainer.py
0 → 100644
浏览文件 @
82b8a3c5
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
contextlib
import
os
import
errno
import
shutil
import
six
import
time
from
..
import
core
from
..
import
data_feeder
from
..
import
executor
from
..
import
framework
from
..
import
io
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
from
..
import
optimizer
as
opt_module
from
..
import
parallel_executor
from
..transpiler
import
distribute_transpiler
__all__
=
[
'Trainer'
,
'BeginEpochEvent'
,
'EndEpochEvent'
,
'BeginStepEvent'
,
'EndStepEvent'
,
'CheckpointConfig'
]
class
BeginEpochEvent
(
object
):
"""
The begin of a training epoch.
Args:
epoch_id(int): The current epoch ID.
"""
def
__init__
(
self
,
epoch_id
):
self
.
epoch
=
epoch_id
class
EndEpochEvent
(
object
):
"""
The end of a training epoch.
Args:
epoch_id(int): The current epoch ID.
"""
def
__init__
(
self
,
epoch_id
):
self
.
epoch
=
epoch_id
class
BeginStepEvent
(
object
):
"""
The begin of a training epoch.
Args:
epoch_id(int): The current epoch ID.
step_id(int): The current step ID.
"""
def
__init__
(
self
,
epoch_id
,
step_id
):
self
.
epoch
=
epoch_id
self
.
step
=
step_id
self
.
fetch_metrics
=
True
"""
If fetch_metrics is true, the metrics will be fetched at the
EndStepEvent. Default is True.
"""
class
EndStepEvent
(
object
):
"""
The end of a training step.
Args:
epoch_id(int): The current epoch ID.
step_id(int): The current step ID.
metrics(list): A list of fetched tensor. The order of this list is same
as the :code:`train_func` returns.
"""
def
__init__
(
self
,
epoch_id
,
step_id
,
metrics
):
self
.
epoch
=
epoch_id
self
.
step
=
step_id
self
.
metrics
=
metrics
class
CheckpointConfig
(
object
):
"""
Parameter object for :code:`save_checkpoint` and
:code:`fluid.Trainer`. Used to configuration how to save checkpoint.
Args:
checkpoint_dir(str): Directory path to save check point. Default is the
current directory.
max_num_checkpoints(int): The max number of local check points.
epoch_interval(int): Every number of epoch to save check point.
step_interval(int): Every number of step to save check point.
Examples:
>>> config = fluid.CheckpointConfig("./checkpoints")
>>> trainer = fluid.Trainer(train_func=train_program,
>>> place=place,
>>> optimizer_func=optimizer_func,
>>> checkpoint_config=config)
>>> trainer.train(...)
"""
def
__init__
(
self
,
checkpoint_dir
=
None
,
max_num_checkpoints
=
3
,
epoch_interval
=
1
,
step_interval
=
10
):
assert
epoch_interval
>=
1
assert
step_interval
>=
1
self
.
checkpoint_dir
=
checkpoint_dir
\
if
checkpoint_dir
is
not
None
else
os
.
getcwd
()
self
.
max_num_checkpoints
=
max_num_checkpoints
self
.
epoch_interval
=
epoch_interval
self
.
step_interval
=
step_interval
self
.
epoch_id
=
0
self
.
step_id
=
0
self
.
load_serial
=
None
self
.
pserver_id
=
None
self
.
lookup_table_name
=
None
def
check_and_get_place
(
place
):
"""
Check the type of place or get the default place
Args:
place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.
Raises:
TypeError if the type mismatched.
Returns:
the original place if it is not None.
if fluid is compiled with CUDA, returns CUDAPlace(0) by default.
Otherwise returns CPUPlace by default.
"""
if
place
is
None
:
if
core
.
is_compiled_with_cuda
():
return
core
.
CUDAPlace
(
0
)
else
:
return
core
.
CPUPlace
()
else
:
if
not
isinstance
(
place
,
core
.
CUDAPlace
)
and
not
isinstance
(
place
,
core
.
CPUPlace
):
raise
TypeError
(
"Place should be either CUDAPlace or CPUPlace"
)
return
place
class
Trainer
(
object
):
"""
A trainer wraps MultiGPU/MultiNode training loops and can be used to train a
simple neural network easily.
This API takes a :code:`train_func`. A :code:`train_func` is a function that
return loss as it first return value. The reset value can be fetched by
EndStepEvent.metrics
This API also takes a :code:`optimizer_func` that will return an optimizer
instance.
For example, to train a MLP for MNIST dataset, the sample program is
>>> import paddle.fluid as fluid
>>>
>>> def mlp(image, layer_sizes=[200, 100], activation="relu", num_classes=10):
>>> hidden = image
>>> for layer_size in layer_sizes:
>>> hidden = fluid.layers.fc(input=hidden, size=layer_size, act=activation)
>>> return fluid.layers.fc(input=hidden, size=num_classes, act="softmax")
>>>
>>> def train_mnist_mlp():
>>> img = fluid.layers.data(name='image', shape=[784])
>>> label = fluid.layers.data(name='label', shape=[1], dtype='int64')
>>> prediction = mlp(img)
>>> return fluid.layers.mean(fluid.layers.cross_entropy(prediction, label))
>>>
>>> def optimizer():
>>> return fluid.optimizer.Adam()
>>>
>>> trainer = Trainer(train_func=train_mnist_mlp,
>>> optimizer_func=optimizer,
>>> place=fluid.CUDAPlace(0),
>>> parallel=True)
>>>
>>> def train_callback(event):
>>> if isinstance(event, fluid.EndStepEvent):
>>> print "Epoch ID", event.epoch, "Step ID",
\
>>> event.step, "AvgLoss", event.metrics[0]
>>> elif isinstance(event, fluid.EndEpochEvent):
>>> trainer.save_params("./model_{0}".format(event.epoch))
>>>
>>> trainer.train(num_epochs=100, event_handler=train_callback)
For more example, please see :ref:`api_guide_high_level_api`.
Args:
train_func(callable): A function which will return loss. The loss must be
a scalar tensor.
optimizer_func(callable): A function that returns an Optimizer object.
place(CUDAPlace|CPUPlace): The device place of this trainer. If
:code:`parallel=True,` all CUDA Places will be used if :code:`place`
is a :code:`CUDAPlace`.
parallel(bool): True if use multiple devices.
checkpoint_config(CheckpointConfig): Configuration about how to save
checkpoints.
"""
def
__init__
(
self
,
train_func
,
optimizer_func
,
param_path
=
None
,
place
=
None
,
parallel
=
False
,
checkpoint_config
=
None
):
self
.
__stop
=
False
self
.
parallel
=
parallel
# config for checkpoint
# only chief worker will save variables
self
.
trainer_id
=
0
self
.
checkpoint_cfg
=
checkpoint_config
if
self
.
checkpoint_cfg
:
assert
isinstance
(
self
.
checkpoint_cfg
,
CheckpointConfig
)
serial
=
_get_latest_checkpoint_serial
(
self
.
checkpoint_cfg
.
checkpoint_dir
)
self
.
checkpoint_cfg
.
load_serial
=
serial
if
serial
>=
0
else
None
self
.
scope
=
core
.
Scope
()
# 1. we need to generate a framework.Program by calling
# program_func. Reference: fluid.program_guard in
# test_word2vec.py
self
.
startup_program
=
framework
.
Program
()
self
.
train_program
=
framework
.
Program
()
with
framework
.
program_guard
(
self
.
train_program
,
self
.
startup_program
):
program_func_outs
=
train_func
()
self
.
train_func_outputs
=
program_func_outs
if
isinstance
(
program_func_outs
,
list
)
else
[
program_func_outs
]
self
.
test_program
=
self
.
train_program
.
clone
(
for_test
=
True
)
# The first element of program_func_outs is loss.
loss
=
self
.
train_func_outputs
[
0
]
optimizer
=
optimizer_func
()
if
not
isinstance
(
optimizer
,
opt_module
.
Optimizer
):
raise
TypeError
(
"The optimizer should be an instance of Optimizer"
)
optimize_ops
,
params_grads
=
optimizer
.
minimize
(
loss
)
self
.
place
=
check_and_get_place
(
place
)
self
.
_dist_transpile_if_necessary
(
optimize_ops
,
params_grads
)
# 2. move the default_main_program to self.program and run the
# default_startup program on an empty core.Scope()
# Run startup program
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
place
)
exe
.
run
(
self
.
startup_program
)
if
self
.
checkpoint_cfg
and
self
.
checkpoint_cfg
.
load_serial
is
not
None
:
self
.
_load_checkpoint
()
if
param_path
and
os
.
path
.
isdir
(
param_path
):
with
self
.
_prog_and_scope_guard
():
# load params from param_path into scope
io
.
load_persistables
(
executor
=
exe
,
dirname
=
param_path
,
main_program
=
self
.
startup_program
)
def
_transpile_nccl2_dist
(
self
):
# PADDLE_TRAINER_IPS
if
"PADDLE_TRAINER_IPS"
not
in
os
.
environ
:
self
.
nccl_id_var
=
None
else
:
self
.
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
))
port
=
os
.
getenv
(
"PADDLE_PSERVER_PORT"
)
worker_ips
=
os
.
getenv
(
"PADDLE_TRAINER_IPS"
)
worker_endpoints
=
[]
for
ip
in
worker_ips
.
split
(
","
):
worker_endpoints
.
append
(
':'
.
join
([
ip
,
port
]))
self
.
num_trainers
=
len
(
worker_endpoints
)
current_endpoint
=
os
.
getenv
(
"PADDLE_CURRENT_IP"
)
+
":"
+
port
worker_endpoints
.
remove
(
current_endpoint
)
# TODO(wuyi): use self.nccl_id_var, self.num_trainers and self.trainer_id
# in ParallelExecutor to start
# distributed training using NCCL2
self
.
nccl_id_var
=
self
.
startup_program
.
global_block
().
create_var
(
name
=
"NCCLID"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
self
.
startup_program
.
global_block
().
append_op
(
type
=
"gen_nccl_id"
,
inputs
=
{},
outputs
=
{
"NCCLID"
:
self
.
nccl_id_var
},
attrs
=
{
"endpoint"
:
current_endpoint
,
"endpoint_list"
:
worker_endpoints
,
"trainer_id"
:
self
.
trainer_id
})
def
_dist_transpile_if_necessary
(
self
,
optimize_ops
,
params_grads
):
self
.
_transpile_nccl2_dist
()
if
self
.
nccl_id_var
!=
None
:
return
if
"PADDLE_TRAINING_ROLE"
not
in
os
.
environ
:
return
# the port of all pservers, needed by both trainer and pserver
port
=
os
.
getenv
(
"PADDLE_PSERVER_PORT"
,
"6174"
)
# comma separated ips of all pservers, needed by trainer and
# pserver
pserver_ips
=
os
.
getenv
(
"PADDLE_PSERVER_IPS"
,
""
)
eplist
=
[]
for
ip
in
pserver_ips
.
split
(
","
):
eplist
.
append
(
':'
.
join
([
ip
,
port
]))
pserver_endpoints
=
","
.
join
(
eplist
)
# total number of workers/trainers in the job, needed by
# trainer and pserver
trainers
=
int
(
os
.
getenv
(
"PADDLE_TRAINERS"
))
# the IP of the local machine, needed by pserver only
current_endpoint
=
os
.
getenv
(
"PADDLE_CURRENT_IP"
,
""
)
+
":"
+
port
# the unique trainer id, starting from 0, needed by trainer
# only
self
.
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
"0"
))
# the role, should be either PSERVER or TRAINER
training_role
=
os
.
getenv
(
"PADDLE_TRAINING_ROLE"
)
with
self
.
_prog_and_scope_guard
():
t
=
distribute_transpiler
.
DistributeTranspiler
()
t
.
transpile
(
self
.
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
if
self
.
checkpoint_cfg
:
pserver_id
=
eplist
.
index
(
current_endpoint
)
self
.
checkpoint_cfg
.
pserver_id
=
pserver_id
if
t
.
has_distributed_lookup_table
:
self
.
checkpoint_cfg
.
lookup_table_name
=
t
.
table_name
self
.
train_program
=
t
.
get_pserver_program
(
current_endpoint
)
self
.
startup_program
=
t
.
get_startup_program
(
current_endpoint
,
self
.
train_program
)
elif
training_role
==
"TRAINER"
:
self
.
train_program
=
t
.
get_trainer_program
()
else
:
raise
ValueError
(
'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
def
stop
(
self
):
"""
stop training
"""
self
.
__stop
=
True
def
train
(
self
,
num_epochs
,
event_handler
,
reader
=
None
,
feed_order
=
None
):
"""
Start the train loop to train the model.
Args:
num_epochs(int): The number of epoch. An epoch will process all data in reader
event_handler(callable): The event handler. A function with type (ev:Event)->void
reader(callable): A reader creator object. See also
:ref:`api_guide_python_reader` .
feed_order(list): Feeding order of reader. None will following the defining
order in program
Returns:
None
"""
training_role
=
os
.
getenv
(
"PADDLE_TRAINING_ROLE"
,
""
)
if
training_role
==
"PSERVER"
:
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
exe
.
run
()
return
if
self
.
parallel
:
self
.
_train_by_parallel_executor
(
num_epochs
,
event_handler
,
reader
,
feed_order
)
else
:
self
.
_train_by_executor
(
num_epochs
,
event_handler
,
reader
,
feed_order
)
def
test
(
self
,
reader
,
feed_order
):
"""
Test the model on given test data
Args:
reader(callable): The reader that yields test data.
feed_order(list): Feeding order of reader. None will following the
defining order in program
"""
return
self
.
_test_by_executor
(
reader
,
feed_order
,
self
.
train_func_outputs
)
def
save_params
(
self
,
param_path
):
"""
Save all parameters into :code:`param_path`.
Args:
param_path(str): The path to save parameters.
Returns:
None
"""
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
io
.
save_persistables
(
exe
,
dirname
=
param_path
)
def
save_inference_model
(
self
,
param_path
,
feeded_var_names
,
target_var_indexes
):
"""
Save model for cpp inference into :code:`param_path`.
Args:
param_path(str): The path to save parameters.
feeded_var_names(list(str)): The name of the vars that you
need to feed in before run program.
target_var_indexes(list(int)): the index of target var that
you need to return in trainer.train_func.
Returns:
None
"""
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
target_vars
=
[
self
.
train_func_outputs
[
index
]
for
index
in
target_var_indexes
]
io
.
save_inference_model
(
param_path
,
feeded_var_names
,
target_vars
,
exe
)
@
contextlib
.
contextmanager
def
_prog_and_scope_guard
(
self
):
with
framework
.
program_guard
(
main_program
=
self
.
train_program
,
startup_program
=
self
.
startup_program
):
with
executor
.
scope_guard
(
self
.
scope
):
yield
def
_train_by_executor
(
self
,
num_epochs
,
event_handler
,
reader
,
feed_order
):
"""
Train by Executor and single device.
Args:
num_epochs:
event_handler:
reader:
feed_order:
Returns:
"""
with
self
.
_prog_and_scope_guard
():
feed_var_list
=
build_feed_var_list
(
self
.
train_program
,
feed_order
)
feeder
=
data_feeder
.
DataFeeder
(
feed_list
=
feed_var_list
,
place
=
self
.
place
)
exe
=
executor
.
Executor
(
self
.
place
)
reader
=
feeder
.
decorate_reader
(
reader
,
multi_devices
=
False
)
self
.
_train_by_any_executor
(
event_handler
,
exe
,
num_epochs
,
reader
)
def
_train_by_any_executor
(
self
,
event_handler
,
exe
,
num_epochs
,
reader
):
if
self
.
checkpoint_cfg
:
epochs
=
[
epoch_id
for
epoch_id
in
range
(
num_epochs
)
if
epoch_id
>=
self
.
checkpoint_cfg
.
epoch_id
]
else
:
epochs
=
[
epoch_id
for
epoch_id
in
range
(
num_epochs
)]
for
epoch_id
in
epochs
:
event_handler
(
BeginEpochEvent
(
epoch_id
))
for
step_id
,
data
in
enumerate
(
reader
()):
if
self
.
__stop
:
if
self
.
checkpoint_cfg
:
self
.
_clean_checkpoint
()
return
if
self
.
checkpoint_cfg
and
self
.
checkpoint_cfg
.
load_serial
\
and
self
.
checkpoint_cfg
.
step_id
>=
step_id
and
self
.
checkpoint_cfg
.
epoch_id
==
epoch_id
:
continue
begin_event
=
BeginStepEvent
(
epoch_id
,
step_id
)
event_handler
(
begin_event
)
if
begin_event
.
fetch_metrics
:
metrics
=
exe
.
run
(
feed
=
data
,
fetch_list
=
[
var
.
name
for
var
in
self
.
train_func_outputs
])
else
:
metrics
=
exe
.
run
(
feed
=
data
,
fetch_list
=
[])
if
self
.
checkpoint_cfg
:
self
.
_save_checkpoint
(
epoch_id
,
step_id
)
event_handler
(
EndStepEvent
(
epoch_id
,
step_id
,
metrics
))
event_handler
(
EndEpochEvent
(
epoch_id
))
if
self
.
checkpoint_cfg
:
self
.
_clean_checkpoint
()
def
_test_by_executor
(
self
,
reader
,
feed_order
,
fetch_list
):
with
executor
.
scope_guard
(
self
.
scope
):
feed_var_list
=
build_feed_var_list
(
self
.
test_program
,
feed_order
)
feeder
=
data_feeder
.
DataFeeder
(
feed_list
=
feed_var_list
,
place
=
self
.
place
)
exe
=
executor
.
Executor
(
self
.
place
)
accumulated
=
len
(
fetch_list
)
*
[
0
]
count
=
0
for
data
in
reader
():
outs
=
exe
.
run
(
program
=
self
.
test_program
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
fetch_list
)
accumulated
=
[
x
[
0
]
+
x
[
1
][
0
]
for
x
in
zip
(
accumulated
,
outs
)]
count
+=
1
return
[
x
/
count
for
x
in
accumulated
]
def
_train_by_parallel_executor
(
self
,
num_epochs
,
event_handler
,
reader
,
feed_order
):
with
self
.
_prog_and_scope_guard
():
pe
=
self
.
_get_or_create_parallel_executor
()
feed_var_list
=
build_feed_var_list
(
self
.
train_program
,
feed_order
)
feeder
=
data_feeder
.
DataFeeder
(
feed_list
=
feed_var_list
,
place
=
self
.
place
)
reader
=
feeder
.
decorate_reader
(
reader
,
multi_devices
=
True
)
self
.
_train_by_any_executor
(
event_handler
,
pe
,
num_epochs
,
reader
)
def
_get_parallel_executor
(
self
):
return
getattr
(
self
,
'parallel_executor'
,
None
)
def
_get_or_create_parallel_executor
(
self
):
if
self
.
_get_parallel_executor
()
is
None
:
self
.
parallel_executor
=
parallel_executor
.
ParallelExecutor
(
use_cuda
=
isinstance
(
self
.
place
,
core
.
CUDAPlace
),
loss_name
=
self
.
train_func_outputs
[
0
].
name
)
return
self
.
_get_parallel_executor
()
def
_clean_checkpoint
(
self
):
assert
self
.
checkpoint_cfg
clean_checkpoint
(
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
)
def
_get_checkpoint_load_args
(
self
):
"""
epoch_id and step_id are runtime arguments, they are not variables, will load them independently.
"""
return
[
"epoch_id"
,
"step_id"
]
def
_get_checkpoint_save_args
(
self
,
epoch_id
,
step_id
):
"""
epoch_id and step_id are runtime arguments, they are not variables, will save them independently.
"""
trainer_args
=
{}
trainer_args
[
"epoch_id"
]
=
epoch_id
trainer_args
[
"step_id"
]
=
step_id
return
trainer_args
def
_save_checkpoint
(
self
,
epoch_id
,
step_id
):
assert
self
.
checkpoint_cfg
if
epoch_id
%
self
.
checkpoint_cfg
.
epoch_interval
==
0
\
and
step_id
%
self
.
checkpoint_cfg
.
step_interval
==
0
:
exe
=
executor
.
Executor
(
self
.
place
)
save_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
trainer_id
=
self
.
trainer_id
,
trainer_args
=
self
.
_get_checkpoint_save_args
(
epoch_id
,
step_id
),
main_program
=
self
.
train_program
,
max_num_checkpoints
=
self
.
checkpoint_cfg
.
max_num_checkpoints
)
def
_load_checkpoint
(
self
):
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
)
if
not
self
.
checkpoint_cfg
.
pserver_id
:
load_trainer_args
=
self
.
_get_checkpoint_load_args
()
trainer_args
=
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
trainer_id
,
is_trainer
=
True
,
load_trainer_args
=
load_trainer_args
)
if
len
(
trainer_args
)
!=
2
:
raise
ValueError
(
"the return trainer_args length do not equal _get_checkpoint_load_args"
)
self
.
checkpoint_cfg
.
epoch_id
=
int
(
trainer_args
[
0
])
self
.
checkpoint_cfg
.
step_id
=
int
(
trainer_args
[
1
])
else
:
if
self
.
checkpoint_cfg
.
lookup_table_name
:
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
checkpoint_cfg
.
pserver_id
,
is_trainer
=
False
,
load_trainer_args
=
None
,
load_lookup_table
=
self
.
checkpoint_cfg
.
lookup_table_name
)
def
build_feed_var_list
(
program
,
feed_order
):
if
not
isinstance
(
program
,
framework
.
Program
):
raise
TypeError
(
"The 'program' should be an object of Program"
)
if
isinstance
(
feed_order
,
list
):
feed_var_list
=
[
program
.
global_block
().
var
(
var_name
)
for
var_name
in
feed_order
]
else
:
if
not
isinstance
(
feed_order
,
dict
):
raise
TypeError
(
"The 'feed_order' should be either None, list or dict."
)
if
not
sorted
(
feed_order
.
values
())
==
list
(
range
(
len
(
feed_order
))):
raise
ValueError
(
"The values of 'feed_order' should be a permutation of [0, len(feed_order))"
)
sorted_pair_list
=
sorted
(
six
.
iteritems
(
feed_order
),
key
=
lambda
item
:
item
[
1
])
feed_var_list
=
[
program
.
global_block
().
var
(
pair
[
0
])
for
pair
in
sorted_pair_list
]
return
feed_var_list
# move Checkpoint APIs from io.py to trainer.py, make all of them are private.
SUCCESS_MARK_FILENAME
=
"_SUCCESS"
CHECKPOINT_PREFIX
=
"checkpoint"
MODEL_DIR
=
"__model__"
LOOKUP_TABLE_DIR
=
"__lookup_table__"
TRAINER_PREFIX
=
"trainer"
CHECKPOINT_SEPARATOR
=
"_"
def
save_checkpoint
(
executor
,
checkpoint_dir
,
trainer_id
,
main_program
,
trainer_args
=
None
,
max_num_checkpoints
=
3
,
lookup_table
=
None
,
pserver_endpoints
=
None
):
"""
This function filters out all checkpoint variables from the give
main_program and then saves these variables to the `checkpoint_dir`
directory.
In the training precess, we generally save a checkpoint in each
iteration. So there might be a lot of checkpoints in the
`checkpoint_dir`. To avoid them taking too much disk space, the
`max_num_checkpoints` are introduced to limit the total number of
checkpoints. If the number of existing checkpints is greater than
the `max_num_checkpoints`, oldest ones will be scroll deleted.
A variable is a checkpoint variable and will be saved if it meets
all following conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for save checkpoint.
checkpoint_dir(str): The folder where to save checkpoints.
trainer_id(int): currect trainer id, if id is equal to 0, the trainer
is chief.
trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
and 'step_id'.
Defaut: None
main_program(Program): The program whose checkpoint variables will
be saved.
max_num_checkpoints(int): The max number of total number of existing
checkpoints.
Default: 3
lookup_table(string|None): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
pserver_endpoints(list|None): the parameter server ip:port list.
when use distribute lookup table, we can get pserver_endpoints by
distribute arguments.
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
AssertionError: If `trainer_args` is not a dict.
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
path = "./checkpoints"
prog = fluid.default_main_program()
trainer_args = {"epoch_id": 200,
"step_id": 20} # just an example
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
save_checkpoint(executor=exe,
checkpoint_dir=path,
trainer_id=0,
trainer_args=trainer_args,
main_program=prog,
max_num_checkpoints=3,
lookup_table=table_name,
pserver_endpoints = ps_endpoints)
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
if
main_program
is
None
:
raise
ValueError
(
'main_program should not be None.'
)
if
trainer_args
:
assert
isinstance
(
trainer_args
,
dict
)
is_chief
=
trainer_id
==
0
_make_chekcpoint_dirs
(
checkpoint_dir
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
_save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
if
is_chief
:
_save_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
if
is_chief
and
lookup_table
and
pserver_endpoints
:
_save_pserver_vars_by_notify
(
executor
,
cur_dir
,
lookup_table
,
pserver_endpoints
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
)
def
load_checkpoint
(
executor
,
checkpoint_dir
,
main_program
,
role_id
=
0
,
is_trainer
=
True
,
load_trainer_args
=
None
,
load_lookup_table
=
None
):
"""
This function filters out all checkpoint variables from the give
main_program and then try to load these variables from the
`checkpoint_dir` directory.
In the training precess, we generally save a checkpoint in each
iteration. So there are more than one checkpoint in the
`checkpoint_dir` (each checkpoint has its own sub folder), use
`serial` to specify which serial of checkpoint you would like to
load.
A variable is a checkpoint variable and will be loaded if it meets
all following conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for loading checkpoint.
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
main_program(Program): The program whose checkpoint variables will
be loaded.
role_id(int): the trainer id or the parameter server id.
is_trainer(bool): trainer is True and parameter server is False.
load_trainer_args(list|None): list about load trainer args.
load_lookup_table(str|None): the lookup table name
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
ValueError: If `main_program` is None.
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
path = "./checkpoints"
prog = fluid.default_main_program()
load_checkpoint(executor=exe, checkpoint_dir=path,
serial=9, main_program=prog)
# In this example, `load_checkpoint` function
# will first filters out all checkpoint variables in the default
# main program, and then try to load these variables form the
# folder "./checkpoints/checkpoint_9/__model__".
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
# there are nothing need to be loaded
if
serial
is
None
or
serial
<
0
:
return
if
main_program
is
None
:
raise
ValueError
(
'main_program should not be None.'
)
if
is_trainer
and
load_trainer_args
is
None
:
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
_load_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
,
True
)
return
if
is_trainer
and
load_trainer_args
:
return
_load_trainer_args
(
checkpoint_dir
,
serial
,
role_id
,
load_trainer_args
)
if
not
is_trainer
and
load_lookup_table
:
_load_lookup_table_vars
(
executor
,
checkpoint_dir
,
main_program
,
role_id
,
load_lookup_table
)
def
clean_checkpoint
(
checkpoint_dir
,
delete_dir
=
False
):
"""
clean the checkpoint dir, when the train exits normally,
the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
: param checkpoint_dir
: param delete_dir
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
=
0
)
if
delete_dir
and
not
os
.
listdir
(
checkpoint_dir
):
os
.
rmdir
(
checkpoint_dir
)
def
_load_persist_vars_without_grad
(
executor
,
dirname
,
program
,
has_model_dir
=
False
):
"""
This function filters out all checkpoint variables from the give
program and then trys to load these variables from the given directory.
A variable is a checkpoint variable if it meets all following
conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for loading variables.
dirname(str): The directory path.
program(Program): The program whose checkpoint variables will
be loaded.
has_model_dir(bool): if True, the function loads variables
from a sub directory named '__model__'.
Default: False
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_load_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog, has_model_dir=True)
# In this example, `_load_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then trys to load these variables form the
# folder "./my_paddle_model/__model__".
"""
if
has_model_dir
:
dirname
=
_get_model_dir
(
dirname
)
io
.
load_vars
(
executor
,
dirname
=
dirname
,
main_program
=
program
,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
def
_load_lookup_table_vars
(
executor
,
dirname
,
program
,
pserver_id
,
table_name
):
"""
The parameter server will load lookup table's local file in
selectedrows variable.
Args:
executor(Executor): The executor to run for loading persistable variables
dirname(str): The directory path
main_program(Program): Find the variable named table_name in main_program
pserver_id(int): the serial number in pserver_endpoints list
table_name(str): lookup table name
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
dirname = "./checkpoints/checkpoint_9/"
prog = fluid.default_main_program()
pserver_id = 1
table_name = "share_w"
_load_lookup_table_vars(executor=exe,
dirname=dirname, program=prog, pserver_id=pserver_id,
table_name=table_name)
"""
for
var
in
program
.
list_vars
():
if
var
.
name
==
table_name
:
lookup_table_var
=
var
break
assert
lookup_table_var
is
not
None
lookup_table_dir
=
os
.
path
.
join
(
dirname
,
LOOKUP_TABLE_DIR
)
table_file
=
table_name
+
CHECKPOINT_SEPARATOR
+
str
(
pserver_id
)
load_prog
=
framework
.
Program
()
load_block
=
load_prog
.
global_block
()
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
lookup_table_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
lookup_table_dir
,
table_file
)})
executor
.
run
(
load_prog
)
def
_save_persist_vars_without_grad
(
executor
,
dirname
,
program
):
"""
This function filters out all checkpoint variables from the give
program and then save these variables to a sub-folder '__model__' of
the given directory.
A variable is a checkpoint variable if it meets all following
conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for saving variables.
dirname(str): The directory path.
program(Program): The program whose checkpoint variables will
be saved.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_save_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog)
# In this example, `_save_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then saves these variables to the folder
# "./my_paddle_model/__model__".
"""
cur_dir
=
_get_model_dir
(
dirname
)
io
.
save_vars
(
executor
,
dirname
=
cur_dir
,
main_program
=
program
,
vars
=
None
,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
_write_success
(
cur_dir
)
def
_save_pserver_vars_by_notify
(
executor
,
dirname
,
lookup_table
,
ps_endpoint_list
):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save checkpoints.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
_save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
ps_endpoint_list=ps_endpoints)
"""
cur_dir
=
_get_lookuptable_dir
(
dirname
)
checkpoint_notify_program
=
framework
.
Program
()
checkpoint_notify_block
=
checkpoint_notify_program
.
global_block
()
attrs
=
{}
attrs
[
'epmap'
]
=
ps_endpoint_list
attrs
[
'dir'
]
=
cur_dir
attrs
[
'lookup_table'
]
=
lookup_table
checkpoint_notify_block
.
append_op
(
type
=
'checkpoint_notify'
,
inputs
=
{},
outputs
=
{},
attrs
=
attrs
)
executor
.
run
(
checkpoint_notify_program
)
def
_save_trainer_args
(
dirname
,
trainer_id
,
trainer_args
):
assert
isinstance
(
trainer_args
,
dict
)
cur_dir
=
_get_trainer_dir
(
dirname
,
trainer_id
)
for
name
,
value
in
six
.
iteritems
(
trainer_args
):
args_file
=
os
.
path
.
join
(
cur_dir
,
name
)
with
open
(
args_file
,
'w'
)
as
f
:
f
.
write
(
str
(
value
))
_write_success
(
cur_dir
)
def
_load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
"""
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
Args:
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
trainer_id(int): current trainer id.
trainer_args(list): list about load trainer args
Return:
None
Examples:
.. code-block:: python
param_path = "./checkpoint/"
serial = 7
trainer_id = 2
trainer_args = ["epoch_id", "step_id"]
_load_trainer_args(checkpoint_dir=param_path, serial=serial,
trainer_id=trainer_id, trainer_args=trainer_args)
"""
assert
isinstance
(
trainer_args
,
list
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_trainer_dir
(
cur_dir
,
trainer_id
)
ret_values
=
[]
for
arg
in
trainer_args
:
cur_file
=
os
.
path
.
join
(
cur_dir
,
arg
)
with
open
(
cur_file
,
'r'
)
as
f
:
contents
=
f
.
read
()
ret_values
.
append
(
contents
.
strip
())
return
ret_values
def
_is_checkpoint_var
(
var
):
"""
the checkpoint will not save or load all the variables.
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
: param var(Variable)
"""
if
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FEED_MINIBATCH
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FETCH_LIST
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
RAW
:
return
False
# @GRAD are named for gradient variables, checkpoint will not save it.
if
"@GRAD"
in
var
.
name
:
return
False
# .trainer_ are named for distribute train variables, checkpoint will not save it.
if
".trainer_"
in
var
.
name
:
return
False
# .block is named for distribute train variables, checkpoint will not save it.
if
".block"
in
var
.
name
:
return
False
return
var
.
persistable
def
_make_chekcpoint_dirs
(
dirs
):
"""
_make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it.
"""
assert
dirs
is
not
None
if
os
.
path
.
isfile
(
dirs
):
raise
OSError
(
errno
.
ENOTDIR
,
"dirs path shoule be a Directory."
,
dirs
)
if
not
os
.
path
.
isdir
(
dirs
):
try
:
os
.
makedirs
(
dirs
)
except
OSError
as
err
:
if
err
.
errno
!=
errno
.
EEXIST
:
raise
err
def
_get_dir_serial
(
dirname
):
_
,
serial
=
dirname
.
split
(
CHECKPOINT_SEPARATOR
)
try
:
serial_num
=
int
(
serial
)
except
ValueError
:
serial_num
=
-
1
return
serial_num
def
_get_serial_dir
(
dirname
,
serial
):
serial_folder
=
CHECKPOINT_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
serial
)
serial_dir
=
os
.
path
.
join
(
dirname
,
serial_folder
)
_make_chekcpoint_dirs
(
serial_dir
)
return
serial_dir
def
_get_model_dir
(
dirname
):
model_dir
=
os
.
path
.
join
(
dirname
,
MODEL_DIR
)
_make_chekcpoint_dirs
(
model_dir
)
return
model_dir
def
_get_lookuptable_dir
(
dirname
):
lookuptable_dir
=
os
.
path
.
join
(
dirname
,
LOOKUP_TABLE_DIR
)
_make_chekcpoint_dirs
(
lookuptable_dir
)
return
lookuptable_dir
def
_get_trainer_dir
(
dirname
,
trainer_id
):
trainer_folder
=
TRAINER_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
trainer_id
)
trainer_dir
=
os
.
path
.
join
(
dirname
,
trainer_folder
)
_make_chekcpoint_dirs
(
trainer_dir
)
return
trainer_dir
def
_scroll_delete
(
dirname
,
max_num_checkpoints
=
3
):
dirs
=
os
.
listdir
(
dirname
)
serial_map
=
{}
for
serial
in
dirs
:
serial_num
=
_get_dir_serial
(
serial
)
serial_map
[
serial_num
]
=
serial
if
len
(
list
(
serial_map
.
keys
()))
<=
max_num_checkpoints
:
return
serials
=
list
(
serial_map
.
keys
())
serials
.
sort
(
reverse
=
True
)
serials
=
serials
[
max_num_checkpoints
:]
for
serial
in
serials
:
cur_dir
=
_get_serial_dir
(
dirname
,
serial
)
try
:
shutil
.
rmtree
(
cur_dir
)
except
OSError
as
err
:
if
err
.
errno
!=
errno
.
ENOENT
:
raise
err
def
_write_success
(
dirname
):
"""
write an empty file named "_SUCCESS" in checkpoint dir, indicate this checkpoint is correct.
: param dirname
"""
success_file
=
os
.
path
.
join
(
dirname
,
SUCCESS_MARK_FILENAME
)
with
open
(
success_file
,
'a'
)
as
f
:
now
=
time
.
ctime
()
f
.
write
(
now
)
def
_get_latest_checkpoint_serial
(
checkpoint_dir
):
"""
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
: param checkpoint_dir
"""
if
not
checkpoint_dir
:
return
-
1
def
has_success
(
checkpoint_dir
,
cur_dir
):
"""
is _SUCCESS in this dir
"""
serial
=
_get_dir_serial
(
cur_dir
)
if
serial
==
-
1
or
not
os
.
path
.
isdir
(
os
.
path
.
join
(
checkpoint_dir
,
cur_dir
)):
return
-
1
success_path
=
os
.
path
.
join
(
_get_serial_dir
(
checkpoint_dir
,
serial
),
MODEL_DIR
,
SUCCESS_MARK_FILENAME
)
if
os
.
path
.
isfile
(
success_path
):
return
serial
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
return
-
1
current_dir
=
-
1
dirs
=
os
.
listdir
(
checkpoint_dir
)
for
cur_dir
in
dirs
:
success_num
=
has_success
(
checkpoint_dir
,
cur_dir
)
if
success_num
>
current_dir
:
current_dir
=
success_num
return
current_dir
python/paddle/fluid/inferencer.py
浏览文件 @
82b8a3c5
...
@@ -12,101 +12,5 @@
...
@@ -12,101 +12,5 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
__future__
import
print_function
# NOTE: inferencer is moved into fluid.contrib.inferencer.
__all__
=
[]
import
contextlib
from
.
import
core
from
.
import
executor
from
.
import
framework
from
.
import
io
from
.
import
parallel_executor
from
.
import
unique_name
from
.trainer
import
check_and_get_place
__all__
=
[
'Inferencer'
,
]
class
Inferencer
(
object
):
"""
Inferencer High Level API.
Args:
infer_func (Python func): Infer function that will return predict Variable
param_path (str): The path where the inference model is saved by fluid.io.save_params
place (Place): place to do the inference
parallel (bool): use parallel_executor to run the inference, it will use multi CPU/GPU.
Examples:
.. code-block:: python
def inference_program():
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
return y_predict
place = fluid.CPUPlace()
inferencer = fluid.Inferencer(
infer_func=inference_program, param_path="/tmp/model", place=place)
"""
def
__init__
(
self
,
infer_func
,
param_path
,
place
=
None
,
parallel
=
False
):
self
.
param_path
=
param_path
self
.
scope
=
core
.
Scope
()
self
.
parallel
=
parallel
self
.
place
=
check_and_get_place
(
place
)
self
.
inference_program
=
framework
.
Program
()
with
framework
.
program_guard
(
self
.
inference_program
):
with
unique_name
.
guard
():
self
.
predict_var
=
infer_func
()
with
self
.
_prog_and_scope_guard
():
# load params from param_path into scope
io
.
load_params
(
executor
.
Executor
(
self
.
place
),
param_path
)
if
parallel
:
with
self
.
_prog_and_scope_guard
():
self
.
exe
=
parallel_executor
.
ParallelExecutor
(
use_cuda
=
isinstance
(
self
.
place
,
core
.
CUDAPlace
),
loss_name
=
self
.
predict_var
.
name
)
else
:
self
.
exe
=
executor
.
Executor
(
self
.
place
)
self
.
inference_program
=
self
.
inference_program
.
clone
(
for_test
=
True
)
def
infer
(
self
,
inputs
,
return_numpy
=
True
):
"""
Do Inference for Inputs
Args:
inputs (map): a map of {"input_name": input_var} that will be feed into the inference program
return_numpy (bool): transform return value into numpy or not
Returns:
Tensor or Numpy: the predict value of the inference model for the inputs
Examples:
.. code-block:: python
tensor_x = numpy.random.uniform(0, 10, [batch_size, 13]).astype("float32")
results = inferencer.infer({'x': tensor_x})
"""
if
not
isinstance
(
inputs
,
dict
):
raise
ValueError
(
"inputs should be a map of {'input_name': input_var}"
)
with
self
.
_prog_and_scope_guard
():
results
=
self
.
exe
.
run
(
feed
=
inputs
,
fetch_list
=
[
self
.
predict_var
.
name
],
return_numpy
=
return_numpy
)
return
results
@
contextlib
.
contextmanager
def
_prog_and_scope_guard
(
self
):
with
framework
.
program_guard
(
main_program
=
self
.
inference_program
):
with
executor
.
scope_guard
(
self
.
scope
):
yield
python/paddle/fluid/trainer.py
浏览文件 @
82b8a3c5
...
@@ -12,1247 +12,5 @@
...
@@ -12,1247 +12,5 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
__future__
import
print_function
# NOTE: Trainer is moved into fluid.contrib.trainer.
__all__
=
[]
import
contextlib
import
os
import
errno
import
shutil
import
six
import
time
from
.
import
core
from
.
import
data_feeder
from
.
import
executor
from
.
import
framework
from
.
import
io
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
from
.
import
optimizer
as
opt_module
from
.
import
parallel_executor
from
.transpiler
import
distribute_transpiler
__all__
=
[
'Trainer'
,
'BeginEpochEvent'
,
'EndEpochEvent'
,
'BeginStepEvent'
,
'EndStepEvent'
,
'CheckpointConfig'
]
class
BeginEpochEvent
(
object
):
"""
The begin of a training epoch.
Args:
epoch_id(int): The current epoch ID.
"""
def
__init__
(
self
,
epoch_id
):
self
.
epoch
=
epoch_id
class
EndEpochEvent
(
object
):
"""
The end of a training epoch.
Args:
epoch_id(int): The current epoch ID.
"""
def
__init__
(
self
,
epoch_id
):
self
.
epoch
=
epoch_id
class
BeginStepEvent
(
object
):
"""
The begin of a training epoch.
Args:
epoch_id(int): The current epoch ID.
step_id(int): The current step ID.
"""
def
__init__
(
self
,
epoch_id
,
step_id
):
self
.
epoch
=
epoch_id
self
.
step
=
step_id
self
.
fetch_metrics
=
True
"""
If fetch_metrics is true, the metrics will be fetched at the
EndStepEvent. Default is True.
"""
class
EndStepEvent
(
object
):
"""
The end of a training step.
Args:
epoch_id(int): The current epoch ID.
step_id(int): The current step ID.
metrics(list): A list of fetched tensor. The order of this list is same
as the :code:`train_func` returns.
"""
def
__init__
(
self
,
epoch_id
,
step_id
,
metrics
):
self
.
epoch
=
epoch_id
self
.
step
=
step_id
self
.
metrics
=
metrics
class
CheckpointConfig
(
object
):
"""
Parameter object for :code:`save_checkpoint` and
:code:`fluid.Trainer`. Used to configuration how to save checkpoint.
Args:
checkpoint_dir(str): Directory path to save check point. Default is the
current directory.
max_num_checkpoints(int): The max number of local check points.
epoch_interval(int): Every number of epoch to save check point.
step_interval(int): Every number of step to save check point.
Examples:
>>> config = fluid.CheckpointConfig("./checkpoints")
>>> trainer = fluid.Trainer(train_func=train_program,
>>> place=place,
>>> optimizer_func=optimizer_func,
>>> checkpoint_config=config)
>>> trainer.train(...)
"""
def
__init__
(
self
,
checkpoint_dir
=
None
,
max_num_checkpoints
=
3
,
epoch_interval
=
1
,
step_interval
=
10
):
assert
epoch_interval
>=
1
assert
step_interval
>=
1
self
.
checkpoint_dir
=
checkpoint_dir
\
if
checkpoint_dir
is
not
None
else
os
.
getcwd
()
self
.
max_num_checkpoints
=
max_num_checkpoints
self
.
epoch_interval
=
epoch_interval
self
.
step_interval
=
step_interval
self
.
epoch_id
=
0
self
.
step_id
=
0
self
.
load_serial
=
None
self
.
pserver_id
=
None
self
.
lookup_table_name
=
None
def
check_and_get_place
(
place
):
"""
Check the type of place or get the default place
Args:
place(None|core.CUDAPlace|core.CPUPlace): the place that trainer will be executed on.
Raises:
TypeError if the type mismatched.
Returns:
the original place if it is not None.
if fluid is compiled with CUDA, returns CUDAPlace(0) by default.
Otherwise returns CPUPlace by default.
"""
if
place
is
None
:
if
core
.
is_compiled_with_cuda
():
return
core
.
CUDAPlace
(
0
)
else
:
return
core
.
CPUPlace
()
else
:
if
not
isinstance
(
place
,
core
.
CUDAPlace
)
and
not
isinstance
(
place
,
core
.
CPUPlace
):
raise
TypeError
(
"Place should be either CUDAPlace or CPUPlace"
)
return
place
class
Trainer
(
object
):
"""
A trainer wraps MultiGPU/MultiNode training loops and can be used to train a
simple neural network easily.
This API takes a :code:`train_func`. A :code:`train_func` is a function that
return loss as it first return value. The reset value can be fetched by
EndStepEvent.metrics
This API also takes a :code:`optimizer_func` that will return an optimizer
instance.
For example, to train a MLP for MNIST dataset, the sample program is
>>> import paddle.fluid as fluid
>>>
>>> def mlp(image, layer_sizes=[200, 100], activation="relu", num_classes=10):
>>> hidden = image
>>> for layer_size in layer_sizes:
>>> hidden = fluid.layers.fc(input=hidden, size=layer_size, act=activation)
>>> return fluid.layers.fc(input=hidden, size=num_classes, act="softmax")
>>>
>>> def train_mnist_mlp():
>>> img = fluid.layers.data(name='image', shape=[784])
>>> label = fluid.layers.data(name='label', shape=[1], dtype='int64')
>>> prediction = mlp(img)
>>> return fluid.layers.mean(fluid.layers.cross_entropy(prediction, label))
>>>
>>> def optimizer():
>>> return fluid.optimizer.Adam()
>>>
>>> trainer = Trainer(train_func=train_mnist_mlp,
>>> optimizer_func=optimizer,
>>> place=fluid.CUDAPlace(0),
>>> parallel=True)
>>>
>>> def train_callback(event):
>>> if isinstance(event, fluid.EndStepEvent):
>>> print "Epoch ID", event.epoch, "Step ID",
\
>>> event.step, "AvgLoss", event.metrics[0]
>>> elif isinstance(event, fluid.EndEpochEvent):
>>> trainer.save_params("./model_{0}".format(event.epoch))
>>>
>>> trainer.train(num_epochs=100, event_handler=train_callback)
For more example, please see :ref:`api_guide_high_level_api`.
Args:
train_func(callable): A function which will return loss. The loss must be
a scalar tensor.
optimizer_func(callable): A function that returns an Optimizer object.
place(CUDAPlace|CPUPlace): The device place of this trainer. If
:code:`parallel=True,` all CUDA Places will be used if :code:`place`
is a :code:`CUDAPlace`.
parallel(bool): True if use multiple devices.
checkpoint_config(CheckpointConfig): Configuration about how to save
checkpoints.
"""
def
__init__
(
self
,
train_func
,
optimizer_func
,
param_path
=
None
,
place
=
None
,
parallel
=
False
,
checkpoint_config
=
None
):
self
.
__stop
=
False
self
.
parallel
=
parallel
# config for checkpoint
# only chief worker will save variables
self
.
trainer_id
=
0
self
.
checkpoint_cfg
=
checkpoint_config
if
self
.
checkpoint_cfg
:
assert
isinstance
(
self
.
checkpoint_cfg
,
CheckpointConfig
)
serial
=
_get_latest_checkpoint_serial
(
self
.
checkpoint_cfg
.
checkpoint_dir
)
self
.
checkpoint_cfg
.
load_serial
=
serial
if
serial
>=
0
else
None
self
.
scope
=
core
.
Scope
()
# 1. we need to generate a framework.Program by calling
# program_func. Reference: fluid.program_guard in
# test_word2vec.py
self
.
startup_program
=
framework
.
Program
()
self
.
train_program
=
framework
.
Program
()
with
framework
.
program_guard
(
self
.
train_program
,
self
.
startup_program
):
program_func_outs
=
train_func
()
self
.
train_func_outputs
=
program_func_outs
if
isinstance
(
program_func_outs
,
list
)
else
[
program_func_outs
]
self
.
test_program
=
self
.
train_program
.
clone
(
for_test
=
True
)
# The first element of program_func_outs is loss.
loss
=
self
.
train_func_outputs
[
0
]
optimizer
=
optimizer_func
()
if
not
isinstance
(
optimizer
,
opt_module
.
Optimizer
):
raise
TypeError
(
"The optimizer should be an instance of Optimizer"
)
optimize_ops
,
params_grads
=
optimizer
.
minimize
(
loss
)
self
.
place
=
check_and_get_place
(
place
)
self
.
_dist_transpile_if_necessary
(
optimize_ops
,
params_grads
)
# 2. move the default_main_program to self.program and run the
# default_startup program on an empty core.Scope()
# Run startup program
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
place
)
exe
.
run
(
self
.
startup_program
)
if
self
.
checkpoint_cfg
and
self
.
checkpoint_cfg
.
load_serial
is
not
None
:
self
.
_load_checkpoint
()
if
param_path
and
os
.
path
.
isdir
(
param_path
):
with
self
.
_prog_and_scope_guard
():
# load params from param_path into scope
io
.
load_persistables
(
executor
=
exe
,
dirname
=
param_path
,
main_program
=
self
.
startup_program
)
def
_transpile_nccl2_dist
(
self
):
# PADDLE_TRAINER_IPS
if
"PADDLE_TRAINER_IPS"
not
in
os
.
environ
:
self
.
nccl_id_var
=
None
else
:
self
.
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
))
port
=
os
.
getenv
(
"PADDLE_PSERVER_PORT"
)
worker_ips
=
os
.
getenv
(
"PADDLE_TRAINER_IPS"
)
worker_endpoints
=
[]
for
ip
in
worker_ips
.
split
(
","
):
worker_endpoints
.
append
(
':'
.
join
([
ip
,
port
]))
self
.
num_trainers
=
len
(
worker_endpoints
)
current_endpoint
=
os
.
getenv
(
"PADDLE_CURRENT_IP"
)
+
":"
+
port
worker_endpoints
.
remove
(
current_endpoint
)
# TODO(wuyi): use self.nccl_id_var, self.num_trainers and self.trainer_id
# in ParallelExecutor to start
# distributed training using NCCL2
self
.
nccl_id_var
=
self
.
startup_program
.
global_block
().
create_var
(
name
=
"NCCLID"
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
self
.
startup_program
.
global_block
().
append_op
(
type
=
"gen_nccl_id"
,
inputs
=
{},
outputs
=
{
"NCCLID"
:
self
.
nccl_id_var
},
attrs
=
{
"endpoint"
:
current_endpoint
,
"endpoint_list"
:
worker_endpoints
,
"trainer_id"
:
self
.
trainer_id
})
def
_dist_transpile_if_necessary
(
self
,
optimize_ops
,
params_grads
):
self
.
_transpile_nccl2_dist
()
if
self
.
nccl_id_var
!=
None
:
return
if
"PADDLE_TRAINING_ROLE"
not
in
os
.
environ
:
return
# the port of all pservers, needed by both trainer and pserver
port
=
os
.
getenv
(
"PADDLE_PSERVER_PORT"
,
"6174"
)
# comma separated ips of all pservers, needed by trainer and
# pserver
pserver_ips
=
os
.
getenv
(
"PADDLE_PSERVER_IPS"
,
""
)
eplist
=
[]
for
ip
in
pserver_ips
.
split
(
","
):
eplist
.
append
(
':'
.
join
([
ip
,
port
]))
pserver_endpoints
=
","
.
join
(
eplist
)
# total number of workers/trainers in the job, needed by
# trainer and pserver
trainers
=
int
(
os
.
getenv
(
"PADDLE_TRAINERS"
))
# the IP of the local machine, needed by pserver only
current_endpoint
=
os
.
getenv
(
"PADDLE_CURRENT_IP"
,
""
)
+
":"
+
port
# the unique trainer id, starting from 0, needed by trainer
# only
self
.
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
"0"
))
# the role, should be either PSERVER or TRAINER
training_role
=
os
.
getenv
(
"PADDLE_TRAINING_ROLE"
)
with
self
.
_prog_and_scope_guard
():
t
=
distribute_transpiler
.
DistributeTranspiler
()
t
.
transpile
(
self
.
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
if
self
.
checkpoint_cfg
:
pserver_id
=
eplist
.
index
(
current_endpoint
)
self
.
checkpoint_cfg
.
pserver_id
=
pserver_id
if
t
.
has_distributed_lookup_table
:
self
.
checkpoint_cfg
.
lookup_table_name
=
t
.
table_name
self
.
train_program
=
t
.
get_pserver_program
(
current_endpoint
)
self
.
startup_program
=
t
.
get_startup_program
(
current_endpoint
,
self
.
train_program
)
elif
training_role
==
"TRAINER"
:
self
.
train_program
=
t
.
get_trainer_program
()
else
:
raise
ValueError
(
'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
def
stop
(
self
):
"""
stop training
"""
self
.
__stop
=
True
def
train
(
self
,
num_epochs
,
event_handler
,
reader
=
None
,
feed_order
=
None
):
"""
Start the train loop to train the model.
Args:
num_epochs(int): The number of epoch. An epoch will process all data in reader
event_handler(callable): The event handler. A function with type (ev:Event)->void
reader(callable): A reader creator object. See also
:ref:`api_guide_python_reader` .
feed_order(list): Feeding order of reader. None will following the defining
order in program
Returns:
None
"""
training_role
=
os
.
getenv
(
"PADDLE_TRAINING_ROLE"
,
""
)
if
training_role
==
"PSERVER"
:
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
exe
.
run
()
return
if
self
.
parallel
:
self
.
_train_by_parallel_executor
(
num_epochs
,
event_handler
,
reader
,
feed_order
)
else
:
self
.
_train_by_executor
(
num_epochs
,
event_handler
,
reader
,
feed_order
)
def
test
(
self
,
reader
,
feed_order
):
"""
Test the model on given test data
Args:
reader(callable): The reader that yields test data.
feed_order(list): Feeding order of reader. None will following the
defining order in program
"""
return
self
.
_test_by_executor
(
reader
,
feed_order
,
self
.
train_func_outputs
)
def
save_params
(
self
,
param_path
):
"""
Save all parameters into :code:`param_path`.
Args:
param_path(str): The path to save parameters.
Returns:
None
"""
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
io
.
save_persistables
(
exe
,
dirname
=
param_path
)
def
save_inference_model
(
self
,
param_path
,
feeded_var_names
,
target_var_indexes
):
"""
Save model for cpp inference into :code:`param_path`.
Args:
param_path(str): The path to save parameters.
feeded_var_names(list(str)): The name of the vars that you
need to feed in before run program.
target_var_indexes(list(int)): the index of target var that
you need to return in trainer.train_func.
Returns:
None
"""
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
target_vars
=
[
self
.
train_func_outputs
[
index
]
for
index
in
target_var_indexes
]
io
.
save_inference_model
(
param_path
,
feeded_var_names
,
target_vars
,
exe
)
@
contextlib
.
contextmanager
def
_prog_and_scope_guard
(
self
):
with
framework
.
program_guard
(
main_program
=
self
.
train_program
,
startup_program
=
self
.
startup_program
):
with
executor
.
scope_guard
(
self
.
scope
):
yield
def
_train_by_executor
(
self
,
num_epochs
,
event_handler
,
reader
,
feed_order
):
"""
Train by Executor and single device.
Args:
num_epochs:
event_handler:
reader:
feed_order:
Returns:
"""
with
self
.
_prog_and_scope_guard
():
feed_var_list
=
build_feed_var_list
(
self
.
train_program
,
feed_order
)
feeder
=
data_feeder
.
DataFeeder
(
feed_list
=
feed_var_list
,
place
=
self
.
place
)
exe
=
executor
.
Executor
(
self
.
place
)
reader
=
feeder
.
decorate_reader
(
reader
,
multi_devices
=
False
)
self
.
_train_by_any_executor
(
event_handler
,
exe
,
num_epochs
,
reader
)
def
_train_by_any_executor
(
self
,
event_handler
,
exe
,
num_epochs
,
reader
):
if
self
.
checkpoint_cfg
:
epochs
=
[
epoch_id
for
epoch_id
in
range
(
num_epochs
)
if
epoch_id
>=
self
.
checkpoint_cfg
.
epoch_id
]
else
:
epochs
=
[
epoch_id
for
epoch_id
in
range
(
num_epochs
)]
for
epoch_id
in
epochs
:
event_handler
(
BeginEpochEvent
(
epoch_id
))
for
step_id
,
data
in
enumerate
(
reader
()):
if
self
.
__stop
:
if
self
.
checkpoint_cfg
:
self
.
_clean_checkpoint
()
return
if
self
.
checkpoint_cfg
and
self
.
checkpoint_cfg
.
load_serial
\
and
self
.
checkpoint_cfg
.
step_id
>=
step_id
and
self
.
checkpoint_cfg
.
epoch_id
==
epoch_id
:
continue
begin_event
=
BeginStepEvent
(
epoch_id
,
step_id
)
event_handler
(
begin_event
)
if
begin_event
.
fetch_metrics
:
metrics
=
exe
.
run
(
feed
=
data
,
fetch_list
=
[
var
.
name
for
var
in
self
.
train_func_outputs
])
else
:
metrics
=
exe
.
run
(
feed
=
data
,
fetch_list
=
[])
if
self
.
checkpoint_cfg
:
self
.
_save_checkpoint
(
epoch_id
,
step_id
)
event_handler
(
EndStepEvent
(
epoch_id
,
step_id
,
metrics
))
event_handler
(
EndEpochEvent
(
epoch_id
))
if
self
.
checkpoint_cfg
:
self
.
_clean_checkpoint
()
def
_test_by_executor
(
self
,
reader
,
feed_order
,
fetch_list
):
with
executor
.
scope_guard
(
self
.
scope
):
feed_var_list
=
build_feed_var_list
(
self
.
test_program
,
feed_order
)
feeder
=
data_feeder
.
DataFeeder
(
feed_list
=
feed_var_list
,
place
=
self
.
place
)
exe
=
executor
.
Executor
(
self
.
place
)
accumulated
=
len
(
fetch_list
)
*
[
0
]
count
=
0
for
data
in
reader
():
outs
=
exe
.
run
(
program
=
self
.
test_program
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
fetch_list
)
accumulated
=
[
x
[
0
]
+
x
[
1
][
0
]
for
x
in
zip
(
accumulated
,
outs
)]
count
+=
1
return
[
x
/
count
for
x
in
accumulated
]
def
_train_by_parallel_executor
(
self
,
num_epochs
,
event_handler
,
reader
,
feed_order
):
with
self
.
_prog_and_scope_guard
():
pe
=
self
.
_get_or_create_parallel_executor
()
feed_var_list
=
build_feed_var_list
(
self
.
train_program
,
feed_order
)
feeder
=
data_feeder
.
DataFeeder
(
feed_list
=
feed_var_list
,
place
=
self
.
place
)
reader
=
feeder
.
decorate_reader
(
reader
,
multi_devices
=
True
)
self
.
_train_by_any_executor
(
event_handler
,
pe
,
num_epochs
,
reader
)
def
_get_parallel_executor
(
self
):
return
getattr
(
self
,
'parallel_executor'
,
None
)
def
_get_or_create_parallel_executor
(
self
):
if
self
.
_get_parallel_executor
()
is
None
:
self
.
parallel_executor
=
parallel_executor
.
ParallelExecutor
(
use_cuda
=
isinstance
(
self
.
place
,
core
.
CUDAPlace
),
loss_name
=
self
.
train_func_outputs
[
0
].
name
)
return
self
.
_get_parallel_executor
()
def
_clean_checkpoint
(
self
):
assert
self
.
checkpoint_cfg
clean_checkpoint
(
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
)
def
_get_checkpoint_load_args
(
self
):
"""
epoch_id and step_id are runtime arguments, they are not variables, will load them independently.
"""
return
[
"epoch_id"
,
"step_id"
]
def
_get_checkpoint_save_args
(
self
,
epoch_id
,
step_id
):
"""
epoch_id and step_id are runtime arguments, they are not variables, will save them independently.
"""
trainer_args
=
{}
trainer_args
[
"epoch_id"
]
=
epoch_id
trainer_args
[
"step_id"
]
=
step_id
return
trainer_args
def
_save_checkpoint
(
self
,
epoch_id
,
step_id
):
assert
self
.
checkpoint_cfg
if
epoch_id
%
self
.
checkpoint_cfg
.
epoch_interval
==
0
\
and
step_id
%
self
.
checkpoint_cfg
.
step_interval
==
0
:
exe
=
executor
.
Executor
(
self
.
place
)
save_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
trainer_id
=
self
.
trainer_id
,
trainer_args
=
self
.
_get_checkpoint_save_args
(
epoch_id
,
step_id
),
main_program
=
self
.
train_program
,
max_num_checkpoints
=
self
.
checkpoint_cfg
.
max_num_checkpoints
)
def
_load_checkpoint
(
self
):
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
)
if
not
self
.
checkpoint_cfg
.
pserver_id
:
load_trainer_args
=
self
.
_get_checkpoint_load_args
()
trainer_args
=
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
trainer_id
,
is_trainer
=
True
,
load_trainer_args
=
load_trainer_args
)
if
len
(
trainer_args
)
!=
2
:
raise
ValueError
(
"the return trainer_args length do not equal _get_checkpoint_load_args"
)
self
.
checkpoint_cfg
.
epoch_id
=
int
(
trainer_args
[
0
])
self
.
checkpoint_cfg
.
step_id
=
int
(
trainer_args
[
1
])
else
:
if
self
.
checkpoint_cfg
.
lookup_table_name
:
load_checkpoint
(
executor
=
exe
,
checkpoint_dir
=
self
.
checkpoint_cfg
.
checkpoint_dir
,
main_program
=
self
.
startup_program
,
role_id
=
self
.
checkpoint_cfg
.
pserver_id
,
is_trainer
=
False
,
load_trainer_args
=
None
,
load_lookup_table
=
self
.
checkpoint_cfg
.
lookup_table_name
)
def
build_feed_var_list
(
program
,
feed_order
):
if
not
isinstance
(
program
,
framework
.
Program
):
raise
TypeError
(
"The 'program' should be an object of Program"
)
if
isinstance
(
feed_order
,
list
):
feed_var_list
=
[
program
.
global_block
().
var
(
var_name
)
for
var_name
in
feed_order
]
else
:
if
not
isinstance
(
feed_order
,
dict
):
raise
TypeError
(
"The 'feed_order' should be either None, list or dict."
)
if
not
sorted
(
feed_order
.
values
())
==
list
(
range
(
len
(
feed_order
))):
raise
ValueError
(
"The values of 'feed_order' should be a permutation of [0, len(feed_order))"
)
sorted_pair_list
=
sorted
(
six
.
iteritems
(
feed_order
),
key
=
lambda
item
:
item
[
1
])
feed_var_list
=
[
program
.
global_block
().
var
(
pair
[
0
])
for
pair
in
sorted_pair_list
]
return
feed_var_list
# move Checkpoint APIs from io.py to trainer.py, make all of them are private.
SUCCESS_MARK_FILENAME
=
"_SUCCESS"
CHECKPOINT_PREFIX
=
"checkpoint"
MODEL_DIR
=
"__model__"
LOOKUP_TABLE_DIR
=
"__lookup_table__"
TRAINER_PREFIX
=
"trainer"
CHECKPOINT_SEPARATOR
=
"_"
def
save_checkpoint
(
executor
,
checkpoint_dir
,
trainer_id
,
main_program
,
trainer_args
=
None
,
max_num_checkpoints
=
3
,
lookup_table
=
None
,
pserver_endpoints
=
None
):
"""
This function filters out all checkpoint variables from the give
main_program and then saves these variables to the `checkpoint_dir`
directory.
In the training precess, we generally save a checkpoint in each
iteration. So there might be a lot of checkpoints in the
`checkpoint_dir`. To avoid them taking too much disk space, the
`max_num_checkpoints` are introduced to limit the total number of
checkpoints. If the number of existing checkpints is greater than
the `max_num_checkpoints`, oldest ones will be scroll deleted.
A variable is a checkpoint variable and will be saved if it meets
all following conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for save checkpoint.
checkpoint_dir(str): The folder where to save checkpoints.
trainer_id(int): currect trainer id, if id is equal to 0, the trainer
is chief.
trainer_args(dict|None): Current training arguments. Such as 'epoch_id'
and 'step_id'.
Defaut: None
main_program(Program): The program whose checkpoint variables will
be saved.
max_num_checkpoints(int): The max number of total number of existing
checkpoints.
Default: 3
lookup_table(string|None): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
pserver_endpoints(list|None): the parameter server ip:port list.
when use distribute lookup table, we can get pserver_endpoints by
distribute arguments.
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
AssertionError: If `trainer_args` is not a dict.
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
path = "./checkpoints"
prog = fluid.default_main_program()
trainer_args = {"epoch_id": 200,
"step_id": 20} # just an example
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
save_checkpoint(executor=exe,
checkpoint_dir=path,
trainer_id=0,
trainer_args=trainer_args,
main_program=prog,
max_num_checkpoints=3,
lookup_table=table_name,
pserver_endpoints = ps_endpoints)
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
if
main_program
is
None
:
raise
ValueError
(
'main_program should not be None.'
)
if
trainer_args
:
assert
isinstance
(
trainer_args
,
dict
)
is_chief
=
trainer_id
==
0
_make_chekcpoint_dirs
(
checkpoint_dir
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
+
1
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
_save_trainer_args
(
cur_dir
,
trainer_id
,
trainer_args
)
if
is_chief
:
_save_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
)
if
is_chief
and
lookup_table
and
pserver_endpoints
:
_save_pserver_vars_by_notify
(
executor
,
cur_dir
,
lookup_table
,
pserver_endpoints
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
)
def
load_checkpoint
(
executor
,
checkpoint_dir
,
main_program
,
role_id
=
0
,
is_trainer
=
True
,
load_trainer_args
=
None
,
load_lookup_table
=
None
):
"""
This function filters out all checkpoint variables from the give
main_program and then try to load these variables from the
`checkpoint_dir` directory.
In the training precess, we generally save a checkpoint in each
iteration. So there are more than one checkpoint in the
`checkpoint_dir` (each checkpoint has its own sub folder), use
`serial` to specify which serial of checkpoint you would like to
load.
A variable is a checkpoint variable and will be loaded if it meets
all following conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for loading checkpoint.
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
main_program(Program): The program whose checkpoint variables will
be loaded.
role_id(int): the trainer id or the parameter server id.
is_trainer(bool): trainer is True and parameter server is False.
load_trainer_args(list|None): list about load trainer args.
load_lookup_table(str|None): the lookup table name
Returns:
None
Raises:
ValueError: If `checkpoint_dir` is None.
ValueError: If `main_program` is None.
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
path = "./checkpoints"
prog = fluid.default_main_program()
load_checkpoint(executor=exe, checkpoint_dir=path,
serial=9, main_program=prog)
# In this example, `load_checkpoint` function
# will first filters out all checkpoint variables in the default
# main program, and then try to load these variables form the
# folder "./checkpoints/checkpoint_9/__model__".
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
serial
=
_get_latest_checkpoint_serial
(
checkpoint_dir
)
# there are nothing need to be loaded
if
serial
is
None
or
serial
<
0
:
return
if
main_program
is
None
:
raise
ValueError
(
'main_program should not be None.'
)
if
is_trainer
and
load_trainer_args
is
None
:
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
_load_persist_vars_without_grad
(
executor
,
cur_dir
,
main_program
,
True
)
return
if
is_trainer
and
load_trainer_args
:
return
_load_trainer_args
(
checkpoint_dir
,
serial
,
role_id
,
load_trainer_args
)
if
not
is_trainer
and
load_lookup_table
:
_load_lookup_table_vars
(
executor
,
checkpoint_dir
,
main_program
,
role_id
,
load_lookup_table
)
def
clean_checkpoint
(
checkpoint_dir
,
delete_dir
=
False
):
"""
clean the checkpoint dir, when the train exits normally,
the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised.
: param checkpoint_dir
: param delete_dir
"""
if
checkpoint_dir
is
None
:
raise
ValueError
(
"'checkpoint_dir' should not be None"
)
_scroll_delete
(
checkpoint_dir
,
max_num_checkpoints
=
0
)
if
delete_dir
and
not
os
.
listdir
(
checkpoint_dir
):
os
.
rmdir
(
checkpoint_dir
)
def
_load_persist_vars_without_grad
(
executor
,
dirname
,
program
,
has_model_dir
=
False
):
"""
This function filters out all checkpoint variables from the give
program and then trys to load these variables from the given directory.
A variable is a checkpoint variable if it meets all following
conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for loading variables.
dirname(str): The directory path.
program(Program): The program whose checkpoint variables will
be loaded.
has_model_dir(bool): if True, the function loads variables
from a sub directory named '__model__'.
Default: False
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_load_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog, has_model_dir=True)
# In this example, `_load_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then trys to load these variables form the
# folder "./my_paddle_model/__model__".
"""
if
has_model_dir
:
dirname
=
_get_model_dir
(
dirname
)
io
.
load_vars
(
executor
,
dirname
=
dirname
,
main_program
=
program
,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
def
_load_lookup_table_vars
(
executor
,
dirname
,
program
,
pserver_id
,
table_name
):
"""
The parameter server will load lookup table's local file in
selectedrows variable.
Args:
executor(Executor): The executor to run for loading persistable variables
dirname(str): The directory path
main_program(Program): Find the variable named table_name in main_program
pserver_id(int): the serial number in pserver_endpoints list
table_name(str): lookup table name
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
dirname = "./checkpoints/checkpoint_9/"
prog = fluid.default_main_program()
pserver_id = 1
table_name = "share_w"
_load_lookup_table_vars(executor=exe,
dirname=dirname, program=prog, pserver_id=pserver_id,
table_name=table_name)
"""
for
var
in
program
.
list_vars
():
if
var
.
name
==
table_name
:
lookup_table_var
=
var
break
assert
lookup_table_var
is
not
None
lookup_table_dir
=
os
.
path
.
join
(
dirname
,
LOOKUP_TABLE_DIR
)
table_file
=
table_name
+
CHECKPOINT_SEPARATOR
+
str
(
pserver_id
)
load_prog
=
framework
.
Program
()
load_block
=
load_prog
.
global_block
()
load_block
.
append_op
(
type
=
'load'
,
inputs
=
{},
outputs
=
{
'Out'
:
[
lookup_table_var
]},
attrs
=
{
'file_path'
:
os
.
path
.
join
(
lookup_table_dir
,
table_file
)})
executor
.
run
(
load_prog
)
def
_save_persist_vars_without_grad
(
executor
,
dirname
,
program
):
"""
This function filters out all checkpoint variables from the give
program and then save these variables to a sub-folder '__model__' of
the given directory.
A variable is a checkpoint variable if it meets all following
conditions:
1. It's persistable.
2. It's type is not FEED_MINIBATCH nor FETCH_LIST nor RAW.
3. It's name contains no "@GRAD" nor ".trainer_" nor ".block".
Args:
executor(Executor): The executor to run for saving variables.
dirname(str): The directory path.
program(Program): The program whose checkpoint variables will
be saved.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
_save_persist_vars_without_grad(executor=exe,
dirname=param_path, program=prog)
# In this example, `_save_persist_vars_without_grad` function
# will first filters out all checkpoint variables in the default
# main program, and then saves these variables to the folder
# "./my_paddle_model/__model__".
"""
cur_dir
=
_get_model_dir
(
dirname
)
io
.
save_vars
(
executor
,
dirname
=
cur_dir
,
main_program
=
program
,
vars
=
None
,
predicate
=
_is_checkpoint_var
,
filename
=
None
)
_write_success
(
cur_dir
)
def
_save_pserver_vars_by_notify
(
executor
,
dirname
,
lookup_table
,
ps_endpoint_list
):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save checkpoints.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
_save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
ps_endpoint_list=ps_endpoints)
"""
cur_dir
=
_get_lookuptable_dir
(
dirname
)
checkpoint_notify_program
=
framework
.
Program
()
checkpoint_notify_block
=
checkpoint_notify_program
.
global_block
()
attrs
=
{}
attrs
[
'epmap'
]
=
ps_endpoint_list
attrs
[
'dir'
]
=
cur_dir
attrs
[
'lookup_table'
]
=
lookup_table
checkpoint_notify_block
.
append_op
(
type
=
'checkpoint_notify'
,
inputs
=
{},
outputs
=
{},
attrs
=
attrs
)
executor
.
run
(
checkpoint_notify_program
)
def
_save_trainer_args
(
dirname
,
trainer_id
,
trainer_args
):
assert
isinstance
(
trainer_args
,
dict
)
cur_dir
=
_get_trainer_dir
(
dirname
,
trainer_id
)
for
name
,
value
in
six
.
iteritems
(
trainer_args
):
args_file
=
os
.
path
.
join
(
cur_dir
,
name
)
with
open
(
args_file
,
'w'
)
as
f
:
f
.
write
(
str
(
value
))
_write_success
(
cur_dir
)
def
_load_trainer_args
(
checkpoint_dir
,
serial
,
trainer_id
,
trainer_args
):
"""
trainer will load some args from it's independent directory,
such as epoch_id and step_id.
Args:
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of checkpoint you would like to load.
trainer_id(int): current trainer id.
trainer_args(list): list about load trainer args
Return:
None
Examples:
.. code-block:: python
param_path = "./checkpoint/"
serial = 7
trainer_id = 2
trainer_args = ["epoch_id", "step_id"]
_load_trainer_args(checkpoint_dir=param_path, serial=serial,
trainer_id=trainer_id, trainer_args=trainer_args)
"""
assert
isinstance
(
trainer_args
,
list
)
cur_dir
=
_get_serial_dir
(
checkpoint_dir
,
serial
)
cur_dir
=
_get_trainer_dir
(
cur_dir
,
trainer_id
)
ret_values
=
[]
for
arg
in
trainer_args
:
cur_file
=
os
.
path
.
join
(
cur_dir
,
arg
)
with
open
(
cur_file
,
'r'
)
as
f
:
contents
=
f
.
read
()
ret_values
.
append
(
contents
.
strip
())
return
ret_values
def
_is_checkpoint_var
(
var
):
"""
the checkpoint will not save or load all the variables.
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
: param var(Variable)
"""
if
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FEED_MINIBATCH
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
FETCH_LIST
or
\
var
.
desc
.
type
()
==
core
.
VarDesc
.
VarType
.
RAW
:
return
False
# @GRAD are named for gradient variables, checkpoint will not save it.
if
"@GRAD"
in
var
.
name
:
return
False
# .trainer_ are named for distribute train variables, checkpoint will not save it.
if
".trainer_"
in
var
.
name
:
return
False
# .block is named for distribute train variables, checkpoint will not save it.
if
".block"
in
var
.
name
:
return
False
return
var
.
persistable
def
_make_chekcpoint_dirs
(
dirs
):
"""
_make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it.
"""
assert
dirs
is
not
None
if
os
.
path
.
isfile
(
dirs
):
raise
OSError
(
errno
.
ENOTDIR
,
"dirs path shoule be a Directory."
,
dirs
)
if
not
os
.
path
.
isdir
(
dirs
):
try
:
os
.
makedirs
(
dirs
)
except
OSError
as
err
:
if
err
.
errno
!=
errno
.
EEXIST
:
raise
err
def
_get_dir_serial
(
dirname
):
_
,
serial
=
dirname
.
split
(
CHECKPOINT_SEPARATOR
)
try
:
serial_num
=
int
(
serial
)
except
ValueError
:
serial_num
=
-
1
return
serial_num
def
_get_serial_dir
(
dirname
,
serial
):
serial_folder
=
CHECKPOINT_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
serial
)
serial_dir
=
os
.
path
.
join
(
dirname
,
serial_folder
)
_make_chekcpoint_dirs
(
serial_dir
)
return
serial_dir
def
_get_model_dir
(
dirname
):
model_dir
=
os
.
path
.
join
(
dirname
,
MODEL_DIR
)
_make_chekcpoint_dirs
(
model_dir
)
return
model_dir
def
_get_lookuptable_dir
(
dirname
):
lookuptable_dir
=
os
.
path
.
join
(
dirname
,
LOOKUP_TABLE_DIR
)
_make_chekcpoint_dirs
(
lookuptable_dir
)
return
lookuptable_dir
def
_get_trainer_dir
(
dirname
,
trainer_id
):
trainer_folder
=
TRAINER_PREFIX
+
CHECKPOINT_SEPARATOR
+
str
(
trainer_id
)
trainer_dir
=
os
.
path
.
join
(
dirname
,
trainer_folder
)
_make_chekcpoint_dirs
(
trainer_dir
)
return
trainer_dir
def
_scroll_delete
(
dirname
,
max_num_checkpoints
=
3
):
dirs
=
os
.
listdir
(
dirname
)
serial_map
=
{}
for
serial
in
dirs
:
serial_num
=
_get_dir_serial
(
serial
)
serial_map
[
serial_num
]
=
serial
if
len
(
list
(
serial_map
.
keys
()))
<=
max_num_checkpoints
:
return
serials
=
list
(
serial_map
.
keys
())
serials
.
sort
(
reverse
=
True
)
serials
=
serials
[
max_num_checkpoints
:]
for
serial
in
serials
:
cur_dir
=
_get_serial_dir
(
dirname
,
serial
)
try
:
shutil
.
rmtree
(
cur_dir
)
except
OSError
as
err
:
if
err
.
errno
!=
errno
.
ENOENT
:
raise
err
def
_write_success
(
dirname
):
"""
write an empty file named "_SUCCESS" in checkpoint dir, indicate this checkpoint is correct.
: param dirname
"""
success_file
=
os
.
path
.
join
(
dirname
,
SUCCESS_MARK_FILENAME
)
with
open
(
success_file
,
'a'
)
as
f
:
now
=
time
.
ctime
()
f
.
write
(
now
)
def
_get_latest_checkpoint_serial
(
checkpoint_dir
):
"""
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
: param checkpoint_dir
"""
if
not
checkpoint_dir
:
return
-
1
def
has_success
(
checkpoint_dir
,
cur_dir
):
"""
is _SUCCESS in this dir
"""
serial
=
_get_dir_serial
(
cur_dir
)
if
serial
==
-
1
or
not
os
.
path
.
isdir
(
os
.
path
.
join
(
checkpoint_dir
,
cur_dir
)):
return
-
1
success_path
=
os
.
path
.
join
(
_get_serial_dir
(
checkpoint_dir
,
serial
),
MODEL_DIR
,
SUCCESS_MARK_FILENAME
)
if
os
.
path
.
isfile
(
success_path
):
return
serial
if
not
os
.
path
.
isdir
(
checkpoint_dir
):
return
-
1
current_dir
=
-
1
dirs
=
os
.
listdir
(
checkpoint_dir
)
for
cur_dir
in
dirs
:
success_num
=
has_success
(
checkpoint_dir
,
cur_dir
)
if
success_num
>
current_dir
:
current_dir
=
success_num
return
current_dir
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录