提交 f96a344b 编写于 作者: Z Zhen Wang

add checkpoint function for pass.

上级 9726da54
...@@ -20,6 +20,7 @@ sys.path.append('..') ...@@ -20,6 +20,7 @@ sys.path.append('..')
import reader import reader
import models import models
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
from utility import save_persistable_nodes, load_persistable_nodes
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser) add_arg = functools.partial(add_arguments, argparser=parser)
...@@ -31,7 +32,8 @@ add_arg('num_epochs', int, 120, "number of epochs.") ...@@ -31,7 +32,8 @@ add_arg('num_epochs', int, 120, "number of epochs.")
add_arg('class_dim', int, 1000, "Class number.") add_arg('class_dim', int, 1000, "Class number.")
add_arg('image_shape', str, "3,224,224", "input image size") add_arg('image_shape', str, "3,224,224", "input image size")
add_arg('model_save_dir', str, "output", "model save directory") add_arg('model_save_dir', str, "output", "model save directory")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.") add_arg('pretrained_fp32_model', str, None, "Whether to use the pretrained float32 model to initialize the weights.")
add_arg('checkpoint', str, None, "Whether to resume the training process from the checkpoint.")
add_arg('lr', float, 0.1, "set learning rate.") add_arg('lr', float, 0.1, "set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.") add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.") add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
...@@ -180,7 +182,8 @@ def build_program(is_train, main_prog, startup_prog, args): ...@@ -180,7 +182,8 @@ def build_program(is_train, main_prog, startup_prog, args):
def train(args): def train(args):
# parameters from arguments # parameters from arguments
model_name = args.model model_name = args.model
pretrained_model = args.pretrained_model pretrained_fp32_model = args.pretrained_fp32_model
checkpoint = args.checkpoint
model_save_dir = args.model_save_dir model_save_dir = args.model_save_dir
data_dir = args.data_dir data_dir = args.data_dir
activation_quant_type = args.act_quant_type activation_quant_type = args.act_quant_type
...@@ -210,11 +213,11 @@ def train(args): ...@@ -210,11 +213,11 @@ def train(args):
main_graph = IrGraph(core.Graph(train_prog.desc), for_test=False) main_graph = IrGraph(core.Graph(train_prog.desc), for_test=False)
test_graph = IrGraph(core.Graph(test_prog.desc), for_test=True) test_graph = IrGraph(core.Graph(test_prog.desc), for_test=True)
if pretrained_model: if pretrained_fp32_model:
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name)) return os.path.exists(os.path.join(pretrained_fp32_model, var.name))
fluid.io.load_vars( fluid.io.load_vars(
exe, pretrained_model, main_program=train_prog, predicate=if_exist) exe, pretrained_fp32_model, main_program=train_prog, predicate=if_exist)
if args.use_gpu: if args.use_gpu:
visible_device = os.getenv('CUDA_VISIBLE_DEVICES') visible_device = os.getenv('CUDA_VISIBLE_DEVICES')
...@@ -248,6 +251,9 @@ def train(args): ...@@ -248,6 +251,9 @@ def train(args):
transform_pass.apply(main_graph) transform_pass.apply(main_graph)
transform_pass.apply(test_graph) transform_pass.apply(test_graph)
if checkpoint:
load_persistable_nodes(exe, checkpoint, main_graph)
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False build_strategy.memory_optimize = False
build_strategy.enable_inplace = False build_strategy.enable_inplace = False
...@@ -327,6 +333,11 @@ def train(args): ...@@ -327,6 +333,11 @@ def train(args):
test_acc1, test_acc5)) test_acc1, test_acc5))
sys.stdout.flush() sys.stdout.flush()
save_checkpoint_path = os.path.join(model_save_dir, model_name, str(pass_id))
if not os.path.isdir(save_checkpoint_path):
os.makedirs(save_checkpoint_path)
save_persistable_nodes(exe, save_checkpoint_path, main_graph)
model_path = os.path.join(model_save_dir, model_name, args.act_quant_type) model_path = os.path.join(model_save_dir, model_name, args.act_quant_type)
float_path = os.path.join(model_path, 'float') float_path = os.path.join(model_path, 'float')
int8_path = os.path.join(model_path, 'int8') int8_path = os.path.join(model_path, 'int8')
......
#!/usr/bin/env bash #!/usr/bin/env bash
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0,1,2,3
#MobileNet v1: #MobileNet v1:
python quant.py \ python quant.py \
--model=MobileNet \ --model=MobileNet \
--pretrained_model=../data/pretrain/MobileNetV1_pretrained \ --pretrained_fp32_model=../data/pretrain/MobileNetV1_pretrained \
--use_gpu=True \ --use_gpu=True \
--data_dir=../data/ILSVRC2012 \ --data_dir=../data/ILSVRC2012 \
--batch_size=64 \ --batch_size=256 \
--total_images=1281167 \ --total_images=1281167 \
--class_dim=1000 \ --class_dim=1000 \
--image_shape=3,224,224 \ --image_shape=3,224,224 \
--model_save_dir=output/ \ --model_save_dir=output/ \
--lr_strategy=piecewise_decay \ --lr_strategy=piecewise_decay \
--num_epochs=10 \ --num_epochs=20 \
--lr=0.0001 \ --lr=0.0001 \
--act_quant_type=abs_max \ --act_quant_type=abs_max \
--wt_quant_type=abs_max --wt_quant_type=abs_max
...@@ -23,16 +23,16 @@ python quant.py \ ...@@ -23,16 +23,16 @@ python quant.py \
#ResNet50: #ResNet50:
#python quant.py \ #python quant.py \
# --model=ResNet50 \ # --model=ResNet50 \
# --pretrained_model=../data/pretrain/ResNet50_pretrained \ # --pretrained_fp32_model=../data/pretrain/ResNet50_pretrained \
# --use_gpu=True \ # --use_gpu=True \
# --data_dir=../data/ILSVRC2012 \ # --data_dir=../data/ILSVRC2012 \
# --batch_size=32 \ # --batch_size=128 \
# --total_images=1281167 \ # --total_images=1281167 \
# --class_dim=1000 \ # --class_dim=1000 \
# --image_shape=3,224,224 \ # --image_shape=3,224,224 \
# --model_save_dir=output/ \ # --model_save_dir=output/ \
# --lr_strategy=piecewise_decay \ # --lr_strategy=piecewise_decay \
# --num_epochs=10 \ # --num_epochs=20 \
# --lr=0.0001 \ # --lr=0.0001 \
# --act_quant_type=abs_max \ # --act_quant_type=abs_max \
# --wt_quant_type=abs_max # --wt_quant_type=abs_max
......
...@@ -17,9 +17,12 @@ from __future__ import absolute_import ...@@ -17,9 +17,12 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import distutils.util import distutils.util
import os
import numpy as np import numpy as np
import six import six
import paddle.fluid as fluid
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.framework import Program
def print_arguments(args): def print_arguments(args):
...@@ -61,3 +64,78 @@ def add_arguments(argname, type, default, help, argparser, **kwargs): ...@@ -61,3 +64,78 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
type=type, type=type,
help=help + ' Default: %(default)s.', help=help + ' Default: %(default)s.',
**kwargs) **kwargs)
def save_persistable_nodes(executor, dirname, graph):
"""
Save persistable nodes to the given directory by the executor.
Args:
executor(Executor): The executor to run for saving node values.
dirname(str): The directory path.
graph(IrGraph): All the required persistable nodes in the graph will be saved.
"""
persistable_node_names = set()
persistable_nodes = []
all_persistable_nodes = graph.all_persistable_nodes()
for node in all_persistable_nodes:
name = node.name()
if name not in persistable_node_names:
persistable_node_names.add(name)
persistable_nodes.append(node)
program = Program()
var_list = []
for node in persistable_nodes:
var_desc = node.var()
if var_desc.type() == core.VarDesc.VarType.RAW or \
var_desc.type() == core.VarDesc.VarType.READER:
continue
var = program.global_block().create_var(
name=var_desc.name(),
shape=var_desc.shape(),
dtype=var_desc.dtype(),
type=var_desc.type(),
lod_level=var_desc.lod_level(),
persistable=var_desc.persistable())
var_list.append(var)
fluid.io.save_vars(executor=executor, dirname=dirname, vars=var_list)
def load_persistable_nodes(executor, dirname, graph):
"""
Load persistable node values from the given directory by the executor.
Args:
executor(Executor): The executor to run for loading node values.
dirname(str): The directory path.
graph(IrGraph): All the required persistable nodes in the graph will be loaded.
"""
persistable_node_names = set()
persistable_nodes = []
all_persistable_nodes = graph.all_persistable_nodes()
for node in all_persistable_nodes:
name = node.name()
if name not in persistable_node_names:
persistable_node_names.add(name)
persistable_nodes.append(node)
program = Program()
var_list = []
def _exist(var):
return os.path.exists(os.path.join(dirname, var.name))
for node in persistable_nodes:
var_desc = node.var()
if var_desc.type() == core.VarDesc.VarType.RAW or \
var_desc.type() == core.VarDesc.VarType.READER:
continue
var = program.global_block().create_var(
name=var_desc.name(),
shape=var_desc.shape(),
dtype=var_desc.dtype(),
type=var_desc.type(),
lod_level=var_desc.lod_level(),
persistable=var_desc.persistable())
if _exist(var):
var_list.append(var)
fluid.io.load_vars(executor=executor, dirname=dirname, vars=var_list)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册