未验证 提交 2ed20815 编写于 作者: W whs 提交者: GitHub

Add option for CE of pruning (#942)

上级 25f89072
...@@ -7,6 +7,7 @@ import paddle ...@@ -7,6 +7,7 @@ import paddle
import argparse import argparse
import functools import functools
import math import math
import random
import time import time
import numpy as np import numpy as np
sys.path.append( sys.path.append(
...@@ -43,6 +44,7 @@ add_arg('pruned_ratio', float, None, "The ratios to be pruned.") ...@@ -43,6 +44,7 @@ add_arg('pruned_ratio', float, None, "The ratios to be pruned.")
add_arg('criterion', str, "l1_norm", "The prune criterion to be used, support l1_norm and batch_norm_scale.") add_arg('criterion', str, "l1_norm", "The prune criterion to be used, support l1_norm and batch_norm_scale.")
add_arg('use_gpu', bool, True, "Whether to GPUs.") add_arg('use_gpu', bool, True, "Whether to GPUs.")
add_arg('checkpoint', str, None, "The path of checkpoint which is used for resume training.") add_arg('checkpoint', str, None, "The path of checkpoint which is used for resume training.")
add_arg('ce_test', bool, False, "Whether to CE test.")
# yapf: enable # yapf: enable
model_list = models.__all__ model_list = models.__all__
...@@ -109,6 +111,16 @@ def create_optimizer(args, parameters, steps_per_epoch): ...@@ -109,6 +111,16 @@ def create_optimizer(args, parameters, steps_per_epoch):
def compress(args): def compress(args):
num_workers = 4
shuffle = True
if args.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
num_workers = 0
shuffle = False
paddle.set_device('gpu' if args.use_gpu else 'cpu') paddle.set_device('gpu' if args.use_gpu else 'cpu')
train_reader = None train_reader = None
...@@ -187,7 +199,8 @@ def compress(args): ...@@ -187,7 +199,8 @@ def compress(args):
batch_size=args.batch_size // ParallelEnv().nranks, batch_size=args.batch_size // ParallelEnv().nranks,
verbose=1, verbose=1,
save_dir=args.model_path, save_dir=args.model_path,
num_workers=8) num_workers=num_workers,
shuffle=shuffle)
def main(): def main():
......
...@@ -5,6 +5,7 @@ import paddle ...@@ -5,6 +5,7 @@ import paddle
import argparse import argparse
import functools import functools
import math import math
import random
import time import time
import numpy as np import numpy as np
sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir) sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
...@@ -31,13 +32,14 @@ add_arg('momentum_rate', float, 0.9, "The value of momentum_ra ...@@ -31,13 +32,14 @@ add_arg('momentum_rate', float, 0.9, "The value of momentum_ra
add_arg('num_epochs', int, 120, "The number of total epochs.") add_arg('num_epochs', int, 120, "The number of total epochs.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step") parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('config_file', str, None, "The config file for compression with yaml format.") add_arg('config_file', str, None, "The config file for compression with yaml format.")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'") add_arg('data', str, "cifar10", "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.") add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.") add_arg('test_period', int, 10, "Test period in epoches.")
add_arg('model_path', str, "./models", "The path to save model.") add_arg('model_path', str, "./models", "The path to save model.")
add_arg('pruned_ratio', float, None, "The ratios to be pruned.") add_arg('pruned_ratio', float, None, "The ratios to be pruned.")
add_arg('criterion', str, "l1_norm", "The prune criterion to be used, support l1_norm and batch_norm_scale.") add_arg('criterion', str, "l1_norm", "The prune criterion to be used, support l1_norm and batch_norm_scale.")
add_arg('save_inference', bool, False, "Whether to save inference model.") add_arg('save_inference', bool, False, "Whether to save inference model.")
add_arg('ce_test', bool, False, "Whether to CE test.")
# yapf: enable # yapf: enable
model_list = models.__all__ model_list = models.__all__
...@@ -94,16 +96,31 @@ def create_optimizer(args, step_per_epoch): ...@@ -94,16 +96,31 @@ def create_optimizer(args, step_per_epoch):
def compress(args): def compress(args):
num_workers = 4
shuffle = True
if args.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
num_workers = 0
shuffle = False
train_reader = None train_reader = None
test_reader = None test_reader = None
if args.data == "mnist":
need_pretrain = True
if args.data == "cifar10":
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.MNIST( train_dataset = paddle.vision.datasets.Cifar10(
mode='train', backend="cv2", transform=transform) mode="train", backend="cv2", transform=transform)
val_dataset = paddle.vision.datasets.MNIST( val_dataset = paddle.vision.datasets.Cifar10(
mode='test', backend="cv2", transform=transform) mode="test", backend="cv2", transform=transform)
class_dim = 10 class_dim = 10
image_shape = "1,28,28" image_shape = "3, 32, 32"
need_pretrain = False
elif args.data == "imagenet": elif args.data == "imagenet":
import imagenet_reader as reader import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(mode='train') train_dataset = reader.ImageNetDataset(mode='train')
...@@ -129,10 +146,10 @@ def compress(args): ...@@ -129,10 +146,10 @@ def compress(args):
feed_list=[image, label], feed_list=[image, label],
drop_last=True, drop_last=True,
batch_size=batch_size_per_card, batch_size=batch_size_per_card,
shuffle=True, shuffle=shuffle,
return_list=False, return_list=False,
use_shared_memory=True, use_shared_memory=True,
num_workers=16) num_workers=num_workers)
valid_loader = paddle.io.DataLoader( valid_loader = paddle.io.DataLoader(
val_dataset, val_dataset,
places=place, places=place,
...@@ -147,6 +164,7 @@ def compress(args): ...@@ -147,6 +164,7 @@ def compress(args):
# model definition # model definition
model = models.__dict__[args.model]() model = models.__dict__[args.model]()
out = model.net(input=image, class_dim=class_dim) out = model.net(input=image, class_dim=class_dim)
label = paddle.reshape(label, [-1, 1])
cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label) cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
avg_cost = paddle.mean(x=cost) avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
...@@ -157,7 +175,7 @@ def compress(args): ...@@ -157,7 +175,7 @@ def compress(args):
exe.run(paddle.static.default_startup_program()) exe.run(paddle.static.default_startup_program())
if args.pretrained_model: if need_pretrain and args.pretrained_model:
def if_exist(var): def if_exist(var):
return os.path.exists(os.path.join(args.pretrained_model, var.name)) return os.path.exists(os.path.join(args.pretrained_model, var.name))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册