未验证 提交 1af11949 编写于 作者: Z Zhen Wang 提交者: GitHub

Merge pull request #1965 from wzzju/add_checkoutpoint

Add checkpoint function for pass.
......@@ -20,6 +20,7 @@ sys.path.append('..')
import reader
import models
from utility import add_arguments, print_arguments
from utility import save_persistable_nodes, load_persistable_nodes
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
......@@ -31,7 +32,8 @@ add_arg('num_epochs', int, 120, "number of epochs.")
add_arg('class_dim', int, 1000, "Class number.")
add_arg('image_shape', str, "3,224,224", "input image size")
add_arg('model_save_dir', str, "output", "model save directory")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('pretrained_fp32_model', str, None, "Whether to use the pretrained float32 model to initialize the weights.")
add_arg('checkpoint', str, None, "Whether to resume the training process from the checkpoint.")
add_arg('lr', float, 0.1, "set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
......@@ -180,7 +182,8 @@ def build_program(is_train, main_prog, startup_prog, args):
def train(args):
# parameters from arguments
model_name = args.model
pretrained_model = args.pretrained_model
pretrained_fp32_model = args.pretrained_fp32_model
checkpoint = args.checkpoint
model_save_dir = args.model_save_dir
data_dir = args.data_dir
activation_quant_type = args.act_quant_type
......@@ -210,11 +213,11 @@ def train(args):
main_graph = IrGraph(core.Graph(train_prog.desc), for_test=False)
test_graph = IrGraph(core.Graph(test_prog.desc), for_test=True)
if pretrained_model:
if pretrained_fp32_model:
def if_exist(var):
return os.path.exists(os.path.join(pretrained_model, var.name))
return os.path.exists(os.path.join(pretrained_fp32_model, var.name))
fluid.io.load_vars(
exe, pretrained_model, main_program=train_prog, predicate=if_exist)
exe, pretrained_fp32_model, main_program=train_prog, predicate=if_exist)
if args.use_gpu:
visible_device = os.getenv('CUDA_VISIBLE_DEVICES')
......@@ -248,6 +251,9 @@ def train(args):
transform_pass.apply(main_graph)
transform_pass.apply(test_graph)
if checkpoint:
load_persistable_nodes(exe, checkpoint, main_graph)
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
......@@ -327,6 +333,11 @@ def train(args):
test_acc1, test_acc5))
sys.stdout.flush()
save_checkpoint_path = os.path.join(model_save_dir, model_name, str(pass_id))
if not os.path.isdir(save_checkpoint_path):
os.makedirs(save_checkpoint_path)
save_persistable_nodes(exe, save_checkpoint_path, main_graph)
model_path = os.path.join(model_save_dir, model_name, args.act_quant_type)
float_path = os.path.join(model_path, 'float')
int8_path = os.path.join(model_path, 'int8')
......
#!/usr/bin/env bash
export CUDA_VISIBLE_DEVICES=0
export CUDA_VISIBLE_DEVICES=0,1,2,3
#MobileNet v1:
python quant.py \
--model=MobileNet \
--pretrained_model=../data/pretrain/MobileNetV1_pretrained \
--pretrained_fp32_model=../data/pretrain/MobileNetV1_pretrained \
--use_gpu=True \
--data_dir=../data/ILSVRC2012 \
--batch_size=64 \
--batch_size=256 \
--total_images=1281167 \
--class_dim=1000 \
--image_shape=3,224,224 \
--model_save_dir=output/ \
--lr_strategy=piecewise_decay \
--num_epochs=10 \
--num_epochs=20 \
--lr=0.0001 \
--act_quant_type=abs_max \
--wt_quant_type=abs_max
......@@ -23,16 +23,16 @@ python quant.py \
#ResNet50:
#python quant.py \
# --model=ResNet50 \
# --pretrained_model=../data/pretrain/ResNet50_pretrained \
# --pretrained_fp32_model=../data/pretrain/ResNet50_pretrained \
# --use_gpu=True \
# --data_dir=../data/ILSVRC2012 \
# --batch_size=32 \
# --batch_size=128 \
# --total_images=1281167 \
# --class_dim=1000 \
# --image_shape=3,224,224 \
# --model_save_dir=output/ \
# --lr_strategy=piecewise_decay \
# --num_epochs=10 \
# --num_epochs=20 \
# --lr=0.0001 \
# --act_quant_type=abs_max \
# --wt_quant_type=abs_max
......
......@@ -17,9 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import os
import numpy as np
import six
import paddle.fluid as fluid
import paddle.compat as cpt
from paddle.fluid import core
from paddle.fluid.framework import Program
def print_arguments(args):
......@@ -61,3 +65,78 @@ def add_arguments(argname, type, default, help, argparser, **kwargs):
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def save_persistable_nodes(executor, dirname, graph):
    """
    Write the values of a graph's persistable nodes into a directory.

    Nodes are de-duplicated by name (the first node seen for a name is kept).
    Nodes whose variable type is RAW or READER are skipped. For every
    remaining node a matching variable is declared in a throwaway Program,
    and the resulting variables are handed to ``fluid.io.save_vars``.

    Args:
        executor(Executor): The executor to run for saving node values.
        dirname(str): The directory path the values are saved into.
        graph(IrGraph): The graph whose persistable nodes will be saved.
    """
    # Keep only the first node encountered for each distinct name.
    seen_names = set()
    unique_nodes = []
    for candidate in graph.all_persistable_nodes():
        node_name = cpt.to_text(candidate.name())
        if node_name in seen_names:
            continue
        seen_names.add(node_name)
        unique_nodes.append(candidate)

    # Variable types that are never written out.
    skipped_types = (core.VarDesc.VarType.RAW, core.VarDesc.VarType.READER)

    # The Program is only a container for the variable declarations that
    # save_vars needs; it is not used for anything else.
    scratch_program = Program()
    scratch_block = scratch_program.global_block()
    descs = (node.var() for node in unique_nodes)
    vars_to_save = [
        scratch_block.create_var(
            name=desc.name(),
            shape=desc.shape(),
            dtype=desc.dtype(),
            type=desc.type(),
            lod_level=desc.lod_level(),
            persistable=desc.persistable())
        for desc in descs
        if desc.type() not in skipped_types
    ]
    fluid.io.save_vars(executor=executor, dirname=dirname, vars=vars_to_save)
def load_persistable_nodes(executor, dirname, graph):
    """
    Read the values of a graph's persistable nodes back from a directory.

    Nodes are de-duplicated by name (the first node seen for a name is kept).
    Nodes whose variable type is RAW or READER are skipped. A matching
    variable is declared in a throwaway Program for every remaining node,
    but only variables that have a file named after them under ``dirname``
    are passed to ``fluid.io.load_vars``; the others are left untouched.

    Args:
        executor(Executor): The executor to run for loading node values.
        dirname(str): The directory path the values are loaded from.
        graph(IrGraph): The graph whose persistable nodes will be loaded.
    """
    # Keep only the first node encountered for each distinct name.
    seen_names = set()
    unique_nodes = []
    for candidate in graph.all_persistable_nodes():
        node_name = cpt.to_text(candidate.name())
        if node_name in seen_names:
            continue
        seen_names.add(node_name)
        unique_nodes.append(candidate)

    # Variable types that are never read back.
    skipped_types = (core.VarDesc.VarType.RAW, core.VarDesc.VarType.READER)

    # The Program is only a container for the variable declarations that
    # load_vars needs; it is not used for anything else.
    scratch_program = Program()
    scratch_block = scratch_program.global_block()
    vars_to_load = []
    for desc in (node.var() for node in unique_nodes):
        if desc.type() in skipped_types:
            continue
        declared = scratch_block.create_var(
            name=desc.name(),
            shape=desc.shape(),
            dtype=desc.dtype(),
            type=desc.type(),
            lod_level=desc.lod_level(),
            persistable=desc.persistable())
        # Skip variables with no saved file so a partial checkpoint
        # directory does not make the load fail.
        if os.path.exists(os.path.join(dirname, declared.name)):
            vars_to_load.append(declared)
    fluid.io.load_vars(executor=executor, dirname=dirname, vars=vars_to_load)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册