未验证 提交 d896134f 编写于 作者: T tangwei12 提交者: GitHub

Merge pull request #10878 from seiriosPlus/new_api_about_cpkt

New api about checkpoint and models
...@@ -26,6 +26,7 @@ from trainer import BeginEpochEvent ...@@ -26,6 +26,7 @@ from trainer import BeginEpochEvent
from trainer import EndEpochEvent from trainer import EndEpochEvent
from trainer import BeginStepEvent from trainer import BeginStepEvent
from trainer import EndStepEvent from trainer import EndStepEvent
from trainer import CheckpointConfig
import inferencer import inferencer
from inferencer import Inferencer from inferencer import Inferencer
......
...@@ -24,7 +24,8 @@ __all__ = [ ...@@ -24,7 +24,8 @@ __all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params', 'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model', 'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'load_checkpoint', 'get_inference_program', 'save_checkpoint', 'load_checkpoint',
'clean_checkpoint' 'clean_checkpoint', 'load_persist_vars_without_grad',
'save_persist_vars_without_grad', 'get_latest_checkpoint_serial'
] ]
...@@ -457,95 +458,161 @@ def get_parameter_value_by_name(name, executor, program=None): ...@@ -457,95 +458,161 @@ def get_parameter_value_by_name(name, executor, program=None):
SUCCESS_MARK_FILENAME = "_SUCCESS" SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint" CHECKPOINT_PREFIX = "checkpoint"
MODEL_DIR = "__model__"
TRAINER_PREFIX = "trainer"
CHECKPOINT_SEPARATOR = "_" CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor, def save_checkpoint(executor,
checkpoint_dir=None, checkpoint_dir,
max_num_checkpoints=3, trainer_id,
save_interval_secs=600, trainer_args=None,
main_program=None): main_program=None,
max_num_checkpoints=3):
""" """
Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory, Save Checkpoint will save persistable LodTensor variables from main_program in checkpoint directory,
the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy the directory named by serial number from 0 to (n -1), save_checkpoint use LRU strategy
to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most, to keep numbers of checkpoint directory, the numbers of checkpoint directory are max_num_checkpoints at most,
The interval between two saved checkpoints must greater than save_interval_secs. The interval between two saved checkpoints must greater than save_interval_secs.
:param executor :param executor executor for save the value
:param checkpoint_dir :param checkpoint_dir the checkpoint directory
:param max_num_checkpoints :param trainer_id currect trainer id, if id is equal to 0, the trainer is chief
:param save_interval_secs :param main_program will save all variables in program
:param main_program :param max_num_checkpoints will keep numbers of checkpoint serials not bigger than max_num_checkpoints
""" """
if checkpoint_dir is None: if checkpoint_dir is None:
checkpoint_dir = os.getcwd() raise ValueError("'checkpoint_dir' should not be None")
if trainer_args:
assert isinstance(trainer_args, dict)
if not os.path.isdir(checkpoint_dir): if not os.path.isdir(checkpoint_dir):
os.makedirs(checkpoint_dir) os.makedirs(checkpoint_dir)
serial = _get_lastest_checkpoint_dir(checkpoint_dir) serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
if serial >= 0 and not _interval_secs_exceed( cur_dir = _get_serial_dir(checkpoint_dir, serial)
_get_serial_dir(serial, checkpoint_dir), save_interval_secs):
return
serial += 1 save_trainer_args(cur_dir, trainer_id, trainer_args)
cur_dir = _get_serial_dir(serial, checkpoint_dir)
save_vars( if trainer_id == 0:
executor, save_persist_vars_without_grad(executor, cur_dir, main_program)
dirname=cur_dir,
main_program=main_program, _scroll_delete(checkpoint_dir, max_num_checkpoints)
vars=None,
predicate=_is_checkpoint_var,
filename=None)
_write_success(cur_dir)
_lru_delete(checkpoint_dir, max_num_checkpoints)
def load_checkpoint(executor, checkpoint_dir=None, main_program=None): def load_checkpoint(executor, checkpoint_dir, serial, main_program):
""" """
Load checkpoint from a directory by executor, Load checkpoint from a directory by executor,
it will find the most recent saved checkpoint file and load it auto. it will find the most recent saved checkpoint file and load it auto.
:param executor :param executor executor for load the value
:param checkpoint_dir :param checkpoint_dir the checkpoint directory
:param main_program :param serial the serial folder in checkpoint directory will be load
:param main_program will load all variables in program
""" """
if checkpoint_dir is None: if checkpoint_dir is None:
checkpoint_dir = os.getcwd() raise ValueError("'checkpoint_dir' should not be None")
serial = _get_lastest_checkpoint_dir(checkpoint_dir) if serial is None or serial < 0:
raise ValueError("'serial' should not be None or <0 ")
if serial < 0: if main_program is None:
return raise ValueError('main_program should not be None.')
cur_dir = _get_serial_dir(serial, checkpoint_dir) cur_dir = _get_serial_dir(checkpoint_dir, serial)
load_persist_vars_without_grad(executor, cur_dir, main_program, True)
load_vars(
executor,
dirname=cur_dir,
main_program=main_program,
predicate=_is_checkpoint_var,
filename=None)
def clean_checkpoint(checkpoint_dir, delete_dir=False): def clean_checkpoint(checkpoint_dir, delete_dir=False):
""" """
clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before. clean the checkpoint dir, when the train exits normally, the trainer will call clean_checkpoint to delete checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised. delete_dir only works when the directory is empty, otherwise, OSError is raised.
:param checkpoint_dir
:param delete_dir
""" """
if checkpoint_dir is None: if checkpoint_dir is None:
checkpoint_dir = os.getcwd() raise ValueError("'checkpoint_dir' should not be None")
_lru_delete(checkpoint_dir, max_num_checkpoints=0) _scroll_delete(checkpoint_dir, max_num_checkpoints=0)
if delete_dir and not os.listdir(checkpoint_dir): if delete_dir and not os.listdir(checkpoint_dir):
os.rmdir(checkpoint_dir) os.rmdir(checkpoint_dir)
def _get_serial_dir(serial, checkpoint_dir): def load_persist_vars_without_grad(executor,
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial) dirname,
return os.path.join(checkpoint_dir, serial_folder) program,
has_model_dir=False):
"""
load_persist_vars_without_grad will load variables from a directory by an executor,
the variable named end with "@GRAD" will not be loaded.
:param executor executor for load the value
:param dirname the checkpoint directory
:param program will load all variables in program
:param has_model_dir if has_model_dir is True, will load variables from sub directory named __model__
"""
if has_model_dir:
dirname = _get_model_dir(dirname)
load_vars(
executor,
dirname=dirname,
main_program=program,
predicate=_is_checkpoint_var,
filename=None)
def save_persist_vars_without_grad(executor, dirname, program):
"""
save_persist_vars_without_grad will save variables to a directory by an executor,
the variable named end with "@GRAD" will not be saved.
:param executor executor for load the value
:param dirname the checkpoint directory
:param program will load all variables in program
"""
cur_dir = _get_model_dir(dirname)
save_vars(
executor,
dirname=cur_dir,
main_program=program,
vars=None,
predicate=_is_checkpoint_var,
filename=None)
_write_success(cur_dir)
def save_trainer_args(dirname, trainer_id, trainer_args):
assert isinstance(trainer_args, dict)
cur_dir = _get_trainer_dir(dirname, trainer_id)
for name, value in trainer_args.iteritems():
args_file = os.path.join(cur_dir, name)
with open(args_file, 'w') as f:
f.write(str(value))
_write_success(cur_dir)
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
assert isinstance(trainer_args, list)
cur_dir = _get_serial_dir(checkpoint_dir, serial)
cur_dir = _get_trainer_dir(cur_dir, trainer_id)
ret_values = []
for arg in trainer_args:
cur_file = os.path.join(cur_dir, arg)
with open(cur_file, 'r') as f:
contents = f.read()
ret_values.append(contents.strip())
return ret_values
def _is_checkpoint_var(var): def _is_checkpoint_var(var):
...@@ -559,36 +626,74 @@ def _is_checkpoint_var(var): ...@@ -559,36 +626,74 @@ def _is_checkpoint_var(var):
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.RAW: var.desc.type() == core.VarDesc.VarType.RAW:
return False return False
# @GRAD are named for gradient variables, checkpoint will not save it.
if "@GRAD" in var.name:
return False
# .trainer_ are named for distribute train variables, checkpoint will not save it.
if ".trainer_" in var.name:
return False
if var.name.endswith("@GRAD"): # .block is named for distribute train variables, checkpoint will not save it.
if ".block" in var.name:
return False return False
return var.persistable return var.persistable
def _interval_secs_exceed(dirname, save_interval_secs): def _get_dir_serial(dirname):
dir_time = os.path.getmtime(dirname) _, serial = dirname.split(CHECKPOINT_SEPARATOR)
if save_interval_secs > (time.time() - dir_time):
return False try:
return True serial_num = int(serial)
except ValueError:
serial_num = -1
return serial_num
def _get_serial_dir(dirname, serial):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
serial_dir = os.path.join(dirname, serial_folder)
if not os.path.isdir(serial_dir):
os.makedirs(serial_dir)
return serial_dir
def _get_model_dir(dirname):
model_dir = os.path.join(dirname, MODEL_DIR)
def _lru_delete(dirname, max_num_checkpoints=3): if not os.path.isdir(model_dir):
os.makedirs(model_dir)
return model_dir
def _get_trainer_dir(dirname, trainer_id):
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
trainer_dir = os.path.join(dirname, trainer_folder)
if not os.path.isdir(trainer_dir):
os.makedirs(trainer_dir)
return trainer_dir
def _scroll_delete(dirname, max_num_checkpoints=3):
dirs = os.listdir(dirname) dirs = os.listdir(dirname)
serials = [] serial_map = {}
for serial in dirs: for serial in dirs:
try: serial_num = _get_dir_serial(serial)
serials.append(int(serial)) serial_map[serial_num] = serial
except ValueError:
continue
if len(serials) <= max_num_checkpoints: if len(serial_map.keys()) <= max_num_checkpoints:
return return
serials = serial_map.keys()
serials.sort(reverse=True) serials.sort(reverse=True)
serials = serials[max_num_checkpoints:] serials = serials[max_num_checkpoints:]
for serial in serials: for serial in serials:
cur_dir = os.path.join(dirname, str(serial)) cur_dir = _get_serial_dir(dirname, serial)
shutil.rmtree(cur_dir) shutil.rmtree(cur_dir)
...@@ -604,33 +709,30 @@ def _write_success(dirname): ...@@ -604,33 +709,30 @@ def _write_success(dirname):
f.write(now) f.write(now)
def _get_lastest_checkpoint_dir(checkpoint_dir): def get_latest_checkpoint_serial(checkpoint_dir):
""" """
get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory get the latest file in checkpoint directory, the _SUCCESS file must exist in the directory
:param checkpoint_dir :param checkpoint_dir
""" """
if not checkpoint_dir.strip(): if not checkpoint_dir:
return -1 return -1
def has_success(checkpoint_dir, cur_dir): def has_success(checkpoint_dir, cur_dir):
""" """
is _SUCCESS in this dir is _SUCCESS in this dir
""" """
_, serial = cur_dir.split(CHECKPOINT_SEPARATOR)
try:
int(serial)
except ValueError:
return -1
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): serial = _get_dir_serial(cur_dir)
if serial == -1 or not os.path.isdir(
os.path.join(checkpoint_dir, cur_dir)):
return -1 return -1
success_path = os.path.join( success_path = os.path.join(
_get_serial_dir(serial, checkpoint_dir), SUCCESS_MARK_FILENAME) _get_serial_dir(checkpoint_dir, serial), MODEL_DIR,
SUCCESS_MARK_FILENAME)
if os.path.isfile(success_path): if os.path.isfile(success_path):
return int(serial) return serial
if not os.path.isdir(checkpoint_dir): if not os.path.isdir(checkpoint_dir):
return -1 return -1
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
import os
import tempfile
class TestCheckpoint(unittest.TestCase):
def setUp(self):
self.dirname = tempfile.mktemp()
self.max_num_checkpoints = 3
self.epoch_interval = 1
self.step_interval = 1
self.trainer_id = 0
self.chief = self.trainer_id == 0
self.place = fluid.CPUPlace()
self.epoch_id = 100
self.step_id = 20
def test_checkpoint(self):
self.save_checkpoint()
serial = fluid.io.get_latest_checkpoint_serial(self.dirname)
self.assertTrue(serial >= 0)
trainer_args = ["epoch_id", "step_id"]
epoch_id, step_id = fluid.io.load_trainer_args(
self.dirname, serial, self.trainer_id, trainer_args)
self.assertEqual(self.step_id, int(step_id))
self.assertEqual(self.epoch_id, int(epoch_id))
program = fluid.Program()
with fluid.program_guard(program):
exe = fluid.Executor(self.place)
fluid.io.load_checkpoint(exe, self.dirname, serial, program)
fluid.io.clean_checkpoint(self.dirname, delete_dir=True)
self.assertFalse(os.path.isdir(self.dirname))
def save_checkpoint(self):
config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints,
self.epoch_interval, self.step_interval)
trainer_args = {}
trainer_args["epoch_id"] = self.epoch_id
trainer_args["step_id"] = self.step_id
program = fluid.Program()
with fluid.program_guard(program):
program.global_block().create_var(
name="scale_0",
psersistable=True,
dtype="float32",
shape=[32, 32])
exe = fluid.Executor(self.place)
for i in xrange(10):
fluid.io.save_checkpoint(exe, config.checkpoint_dir,
self.trainer_id, trainer_args, program,
config.max_num_checkpoints)
if __name__ == '__main__':
unittest.main()
...@@ -27,11 +27,8 @@ import parallel_executor ...@@ -27,11 +27,8 @@ import parallel_executor
from transpiler import distribute_transpiler from transpiler import distribute_transpiler
__all__ = [ __all__ = [
'Trainer', 'Trainer', 'BeginEpochEvent', 'EndEpochEvent', 'BeginStepEvent',
'BeginEpochEvent', 'EndStepEvent', 'CheckpointConfig'
'EndEpochEvent',
'BeginStepEvent',
'EndStepEvent',
] ]
...@@ -59,6 +56,35 @@ class EndStepEvent(object): ...@@ -59,6 +56,35 @@ class EndStepEvent(object):
self.metrics = metrics self.metrics = metrics
class CheckpointConfig(object):
def __init__(self,
checkpoint_dir=None,
max_num_checkpoints=3,
epoch_interval=1,
step_interval=10):
if checkpoint_dir is None:
self.checkpoint_dir = os.getcwd()
else:
self.checkpoint_dir = checkpoint_dir
self.max_num_checkpoints = max_num_checkpoints
if epoch_interval < 1:
self.epoch_interval = 1
else:
self.epoch_interval = epoch_interval
if step_interval < 1:
self.step_interval = 10
else:
self.step_interval = step_interval
self.epoch_id = 0
self.step_id = 0
self.load_serial = None
self.is_pserver = False
def check_and_get_place(place): def check_and_get_place(place):
""" """
Check the type of place or get the default place Check the type of place or get the default place
...@@ -99,13 +125,24 @@ class Trainer(object): ...@@ -99,13 +125,24 @@ class Trainer(object):
optimizer_func, optimizer_func,
param_path=None, param_path=None,
place=None, place=None,
parallel=False): parallel=False,
checkpoint_config=None):
self.__stop = False self.__stop = False
self.parallel = parallel self.parallel = parallel
# 1. we need to generate a framework.Program by calling # 1. we need to generate a framework.Program by calling
# program_func. Reference: fluid.program_guard in # program_func. Reference: fluid.program_guard in
# test_word2vec.py # test_word2vec.py
# config for checkpoint
# only chief worker will save variables
self.trainer_id = 0
self.checkpoint_cfg = checkpoint_config
if self.checkpoint_cfg:
assert isinstance(self.checkpoint_cfg, CheckpointConfig)
serial = io.get_latest_checkpoint_serial(
self.checkpoint_cfg.checkpoint_dir)
self.checkpoint_cfg.load_serial = serial if serial >= 0 else None
self.scope = core.Scope() self.scope = core.Scope()
self.startup_program = framework.Program() self.startup_program = framework.Program()
...@@ -137,9 +174,25 @@ class Trainer(object): ...@@ -137,9 +174,25 @@ class Trainer(object):
exe = executor.Executor(place) exe = executor.Executor(place)
exe.run(self.startup_program) exe.run(self.startup_program)
if param_path: if self.checkpoint_cfg and self.checkpoint_cfg.load_serial:
with self._prog_and_scope_guard():
exe = executor.Executor(place)
io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir,
self.checkpoint_cfg.load_serial,
self.startup_program)
if not self.checkpoint_cfg.is_pserver:
epoch_id, step_id = io.load_trainer_args(
self.checkpoint_cfg.checkpoint_dir,
self.checkpoint_cfg.load_serial, self.trainer_id,
self._get_checkpoint_load_args())
self.checkpoint_cfg.epoch_id = int(epoch_id)
self.checkpoint_cfg.step_id = int(step_id)
if param_path and os.path.isdir(param_path):
# load params from param_path into scope # load params from param_path into scope
io.load_persistables(exe, dirname=param_path) io.load_persist_vars_without_grad(
exe, dirname=param_path, program=self.startup_program)
def _transpile_nccl2_dist(self): def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS # PADDLE_TRAINER_IPS
...@@ -194,14 +247,18 @@ class Trainer(object): ...@@ -194,14 +247,18 @@ class Trainer(object):
current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
# the unique trainer id, starting from 0, needed by trainer # the unique trainer id, starting from 0, needed by trainer
# only # only
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
# the role, should be either PSERVER or TRAINER # the role, should be either PSERVER or TRAINER
training_role = os.getenv("PADDLE_TRAINING_ROLE") training_role = os.getenv("PADDLE_TRAINING_ROLE")
with self._prog_and_scope_guard(): with self._prog_and_scope_guard():
t = distribute_transpiler.DistributeTranspiler() t = distribute_transpiler.DistributeTranspiler()
t.transpile( t.transpile(
trainer_id, pservers=pserver_endpoints, trainers=trainers) self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER": if training_role == "PSERVER":
if self.checkpoint_cfg:
self.is_pserver = True
self.train_program = t.get_pserver_program(current_endpoint) self.train_program = t.get_pserver_program(current_endpoint)
self.startup_program = t.get_startup_program(current_endpoint, self.startup_program = t.get_startup_program(current_endpoint,
self.train_program) self.train_program)
...@@ -294,11 +351,26 @@ class Trainer(object): ...@@ -294,11 +351,26 @@ class Trainer(object):
self._train_by_any_executor(event_handler, exe, num_epochs, reader) self._train_by_any_executor(event_handler, exe, num_epochs, reader)
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader): def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
for epoch_id in range(num_epochs): if self.checkpoint_cfg:
epochs = [
epoch_id for epoch_id in range(num_epochs)
if epoch_id >= self.checkpoint_cfg.epoch_id
]
else:
epochs = [epoch_id for epoch_id in range(num_epochs)]
for epoch_id in epochs:
event_handler(BeginEpochEvent(epoch_id)) event_handler(BeginEpochEvent(epoch_id))
for step_id, data in enumerate(reader()): for step_id, data in enumerate(reader()):
if self.__stop: if self.__stop:
if self.checkpoint_cfg:
self._clean_checkpoint()
return return
if self.checkpoint_cfg and self.checkpoint_cfg.load_serial \
and self.checkpoint_cfg.step_id >= step_id and self.checkpoint_cfg.epoch_id == epoch_id:
continue
begin_event = BeginStepEvent(epoch_id, step_id) begin_event = BeginStepEvent(epoch_id, step_id)
event_handler(begin_event) event_handler(begin_event)
if begin_event.fetch_metrics: if begin_event.fetch_metrics:
...@@ -309,8 +381,13 @@ class Trainer(object): ...@@ -309,8 +381,13 @@ class Trainer(object):
]) ])
else: else:
metrics = exe.run(feed=data, fetch_list=[]) metrics = exe.run(feed=data, fetch_list=[])
if self.checkpoint_cfg:
self._save_checkpoint(epoch_id, step_id)
event_handler(EndStepEvent(epoch_id, step_id, metrics)) event_handler(EndStepEvent(epoch_id, step_id, metrics))
event_handler(EndEpochEvent(epoch_id)) event_handler(EndEpochEvent(epoch_id))
if self.checkpoint_cfg:
self._clean_checkpoint()
def _test_by_executor(self, reader, feed_order, fetch_list): def _test_by_executor(self, reader, feed_order, fetch_list):
with executor.scope_guard(self.scope): with executor.scope_guard(self.scope):
...@@ -349,6 +426,38 @@ class Trainer(object): ...@@ -349,6 +426,38 @@ class Trainer(object):
loss_name=self.train_func_outputs[0].name) loss_name=self.train_func_outputs[0].name)
return self._get_parallel_executor() return self._get_parallel_executor()
def _clean_checkpoint(self):
assert self.checkpoint_cfg
io.clean_checkpoint(checkpoint_dir=self.checkpoint_cfg.checkpoint_dir)
def _get_checkpoint_load_args(self):
"""
epoch_id and step_id are runtime arguments, they are not variables, will load them independently.
"""
return ["epoch_id", "step_id"]
def _get_checkpoint_save_args(self, epoch_id, step_id):
"""
epoch_id and step_id are runtime arguments, they are not variables, will save them independently.
"""
trainer_args = {}
trainer_args["epoch_id"] = epoch_id
trainer_args["step_id"] = step_id
return trainer_args
def _save_checkpoint(self, epoch_id, step_id):
assert self.checkpoint_cfg
if epoch_id % self.checkpoint_cfg.epoch_interval == 0 and step_id % self.checkpoint_cfg.step_interval == 0:
exe = executor.Executor(self.place)
io.save_checkpoint(
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
trainer_id=self.trainer_id,
trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
main_program=self.train_program,
max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
def build_feed_var_list(program, feed_order): def build_feed_var_list(program, feed_order):
if not isinstance(program, framework.Program): if not isinstance(program, framework.Program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册