A new design of our Python save / load interfaces
Created by: JiayiFeng
Related issue: #7163 (closed)
Issues
Currently, there are a few obvious issues in our program saving and loading interfaces:
- `save_params()` and `load_params()` are useless and misleading. Some variables required for making checkpoints are not `Parameter`s, so a model is unable to continue its training or make inference after `save_params()` and `load_params()`. The correct way of making a checkpoint is using `save_persistables()` and `load_persistables()`.
- `save_vars()` and `load_vars()` take an `executor`, build a temporary program, and then execute the program immediately. This makes variable saving and loading triggered by Python code, so we can't make checkpoints or save parameters in an environment without Python.
Proposed Solution
Base functions
To fix the existing issues, we redesign our Python io module. The proposed new io module mainly consists of the following base functions:
```python
# serializes the given program and saves it in dir
def save_program(program, dir):
    ...

# loads and deserializes a program from the given dir
def load_program(dir):
    ...
    return res_prog

# appends save_ops to the given program;
# the appended save_ops save the variables in var_list to dir
def save(var_list, dir, program):
    ...

# appends load_ops to the given program
def load(var_name_list, dir, program):
    ...
```
`save()` and `load()` can be considered as the layers of `save_op` and `load_op`. Unlike the current `load_vars()` and `save_vars()`, they don't execute immediately; they just append a `save_op` or `load_op` to the given program and leave the execution to the runtime.
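To make the deferred semantics concrete, here is a minimal, self-contained sketch in plain Python. The `ToyProgram`, `save`, and `run` names below are toy stand-ins, not the real Fluid API: appending a `save_op` only records it in the program, and nothing touches disk until the runtime executes the program.

```python
import json
import os
import tempfile

class ToyProgram:
    """A toy stand-in for a Fluid Program: an ordered list of ops."""
    def __init__(self):
        self.ops = []

def save(var_list, dir, program):
    # Only *appends* a save_op; no I/O happens here.
    program.ops.append(("save_op", {"vars": dict(var_list), "dir": dir}))

def run(program):
    # The "runtime": executes the recorded ops one by one.
    for op_type, attrs in program.ops:
        if op_type == "save_op":
            path = os.path.join(attrs["dir"], "vars.json")
            with open(path, "w") as f:
                json.dump(attrs["vars"], f)

prog = ToyProgram()
d = tempfile.mkdtemp()
save({"w": [1.0, 2.0]}, d, prog)
assert not os.listdir(d)  # nothing saved yet: only an op was appended
run(prog)                 # the runtime executes the save_op
assert os.path.exists(os.path.join(d, "vars.json"))
```

Because the save is just another op in the program, it can be executed by any runtime, with no Python required at execution time.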
By using these base functions, we can save our program at any stage of model configuration, and save or load any specific variable values at any phase of program execution.
Checkpoints
To make it more user-friendly, we can add some high-level wrappers for checkpoint related functions:
"""
checkpoint loader is a special startup program.
A regular startup holds only initializer ops,
while a checkpoint loader may hold some load_op.
In other words, a load_program initialize persistable variables from existing files.
"""
def build_checkpoint_loader(load_var_list, startup_program):
loader = Program()
load_var_name_set = set([var.name for var in load_var_list])
load(var_name_list=load_var_name_set,
program = loader)
if startup_program:
for op in startup_program:
if not (op.out in load_var_name_set):
loader.append(deep_copy(op))
return loader
def make_checkpoint(dir, predicate, startup_program, main_program):
persistable_var_list = filter(main_program.all_vars(), is_persistable())
if predicate:
persistable_var_list = filter(persistable_var_list, predicate)
save(persistable_var_list, dir, main_program)
loader = build_checkpoint_loader(persistable_var_list, startup_program)
save_program(loader, "./loader")
def get_checkpoint_loader(...):
loader = load_program(loader_dir)
return loader
A checkpoint consists of two parts: variables and a loader. A loader is a program. It acts like a startup program; the only difference between a loader and a regular startup program is that in a loader, some variables may be initialized from existing files instead of by initializer ops.
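The loader idea can be sketched with plain Python. The names below (`build_toy_loader`, `run_toy`, `fill_constant`) are hypothetical stand-ins for the proposed API: a `load_op` replaces the initializer op of each checkpointed variable, while every other initializer op is kept as-is.

```python
import json
import os
import tempfile

def build_toy_loader(load_var_names, startup_ops, dir):
    """Toy version of build_checkpoint_loader: load_ops replace the
    initializer ops of checkpointed variables; other init ops are kept."""
    loader = [("load_op", name, dir) for name in load_var_names]
    loader += [op for op in startup_ops if op[1] not in load_var_names]
    return loader

def run_toy(ops, scope):
    # A toy runtime: load_ops read from files, fill_constant ops initialize.
    for op in ops:
        if op[0] == "load_op":
            _, name, dir = op
            with open(os.path.join(dir, name + ".json")) as f:
                scope[name] = json.load(f)
        elif op[0] == "fill_constant":
            _, name, value = op
            scope[name] = value

# A checkpoint file for variable "w" already exists on disk.
d = tempfile.mkdtemp()
with open(os.path.join(d, "w.json"), "w") as f:
    json.dump([3.0, 4.0], f)

startup = [("fill_constant", "w", [0.0, 0.0]),  # regular initializer for w
           ("fill_constant", "lr", 0.01)]       # initializer for another var
loader = build_toy_loader({"w"}, startup, d)

scope = {}
run_toy(loader, scope)
assert scope["w"] == [3.0, 4.0]  # restored from the checkpoint file
assert scope["lr"] == 0.01       # still initialized by its original init op
```

Running the loader in place of the startup program is therefore enough to resume training: checkpointed variables come from disk, everything else is initialized normally.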
We can use the checkpoint as follows:
"""
Save checkpoints:
"""
x = layers.data(...)
var1 = layers.fc(...)
# some other model configurations
make_checkpoint(dir="./",
predicate=None,
startup_program=default_startup_program(),
main_program=default_main_program())
exe = Executor()
exe.run(default_startup_program())
while(...):
exe.run(default_main_program())
"""
Load a checkpoint and continue training:
"""
x = layers.data(...)
var1 = layers.fc(...)
# the same model configurations as above
make_checkpoint(dir="./",
predicate=None,
startup_program=default_startup_program(),
main_program=default_main_program())
loader = get_checkpoint_loader("./")
exe = Executor()
exe.run(loader)
while(...):
exe.run(default_main_program())
Inference Model Saving and Loading
Currently, we use `Program.prune` to cut the main program down to an inference model. However, the prune algorithm is complex and easy to get wrong. In recent discussions, we have tended to leave the building of the inference model to users:
"""
Saves the inference model:
"""
x = layers.data(...)
var1 = layers.fc(...)
# some other model configurations
cost = layers.mean(...)
save_program(default_main_program(), "./main_prog")
make_checkpoint(dir="./",
predicate=None,
startup_program=default_startup_program(),
main_program=default_main_program())
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost)
exe = Executor()
exe.run(default_startup_program())
while(...):
exe.run(default_main_program())
"""
Loads and uses the inference model:
"""
main_program = load_program("./main_prog")
loader = get_checkpoint_loader("./")
exe.run(loader)
while(...):
exe.run(main_program)
The key to getting an inference model is to save the main program and make the checkpoint precisely before the optimizer is appended: at that point, the program contains only the forward pass.
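The ordering can be illustrated with a plain-Python sketch. Here `ToyProgram` is a stand-in for a Fluid program, and `copy.deepcopy` stands in for `save_program`'s serialization: the snapshot taken before `minimize()` contains only forward ops, while the main program keeps growing afterwards.

```python
import copy

class ToyProgram:
    """A toy stand-in for a Fluid Program: an ordered list of ops."""
    def __init__(self):
        self.ops = []

prog = ToyProgram()
prog.ops += [("data", "x"), ("fc", "var1"), ("mean", "cost")]  # model config

# Snapshot the program *before* optimizer ops are appended:
# this snapshot is the inference model.
inference_prog = copy.deepcopy(prog)

# minimize() would append backward and optimizer ops to the main program;
# modeled here as two extra ops.
prog.ops += [("fc_grad", "var1@GRAD"), ("sgd", "var1")]

assert ("sgd", "var1") not in inference_prog.ops  # snapshot is forward-only
assert len(inference_prog.ops) == 3
assert len(prog.ops) == 5                         # training program grew
```

Snapshotting after `minimize()` would bake the backward and optimizer ops into the saved program, which is exactly what `Program.prune` was trying to remove.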