未验证 提交 2ed20815 编写于 作者: W whs 提交者: GitHub

Add option for CE of pruning (#942)

上级 25f89072
......@@ -7,6 +7,7 @@ import paddle
import argparse
import functools
import math
import random
import time
import numpy as np
sys.path.append(
......@@ -43,6 +44,7 @@ add_arg('pruned_ratio', float, None, "The ratios to be pruned.")
add_arg('criterion', str, "l1_norm", "The prune criterion to be used, support l1_norm and batch_norm_scale.")
add_arg('use_gpu', bool, True, "Whether to GPUs.")
add_arg('checkpoint', str, None, "The path of checkpoint which is used for resume training.")
add_arg('ce_test', bool, False, "Whether to CE test.")
# yapf: enable
model_list = models.__all__
......@@ -109,6 +111,16 @@ def create_optimizer(args, parameters, steps_per_epoch):
def compress(args):
num_workers = 4
shuffle = True
if args.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
num_workers = 0
shuffle = False
paddle.set_device('gpu' if args.use_gpu else 'cpu')
train_reader = None
......@@ -187,7 +199,8 @@ def compress(args):
batch_size=args.batch_size // ParallelEnv().nranks,
verbose=1,
save_dir=args.model_path,
num_workers=8)
num_workers=num_workers,
shuffle=shuffle)
def main():
......
......@@ -5,6 +5,7 @@ import paddle
import argparse
import functools
import math
import random
import time
import numpy as np
sys.path[0] = os.path.join(os.path.dirname("__file__"), os.path.pardir)
......@@ -31,13 +32,14 @@ add_arg('momentum_rate', float, 0.9, "The value of momentum_ra
add_arg('num_epochs', int, 120, "The number of total epochs.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('config_file', str, None, "The config file for compression with yaml format.")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
add_arg('data', str, "cifar10", "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.")
add_arg('model_path', str, "./models", "The path to save model.")
add_arg('pruned_ratio', float, None, "The ratios to be pruned.")
add_arg('criterion', str, "l1_norm", "The prune criterion to be used, support l1_norm and batch_norm_scale.")
add_arg('save_inference', bool, False, "Whether to save inference model.")
add_arg('ce_test', bool, False, "Whether to CE test.")
# yapf: enable
model_list = models.__all__
......@@ -94,16 +96,31 @@ def create_optimizer(args, step_per_epoch):
def compress(args):
num_workers = 4
shuffle = True
if args.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
num_workers = 0
shuffle = False
train_reader = None
test_reader = None
if args.data == "mnist":
need_pretrain = True
if args.data == "cifar10":
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend="cv2", transform=transform)
val_dataset = paddle.vision.datasets.MNIST(
mode='test', backend="cv2", transform=transform)
train_dataset = paddle.vision.datasets.Cifar10(
mode="train", backend="cv2", transform=transform)
val_dataset = paddle.vision.datasets.Cifar10(
mode="test", backend="cv2", transform=transform)
class_dim = 10
image_shape = "1,28,28"
image_shape = "3, 32, 32"
need_pretrain = False
elif args.data == "imagenet":
import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(mode='train')
......@@ -129,10 +146,10 @@ def compress(args):
feed_list=[image, label],
drop_last=True,
batch_size=batch_size_per_card,
shuffle=True,
shuffle=shuffle,
return_list=False,
use_shared_memory=True,
num_workers=16)
num_workers=num_workers)
valid_loader = paddle.io.DataLoader(
val_dataset,
places=place,
......@@ -147,6 +164,7 @@ def compress(args):
# model definition
model = models.__dict__[args.model]()
out = model.net(input=image, class_dim=class_dim)
label = paddle.reshape(label, [-1, 1])
cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
......@@ -157,7 +175,7 @@ def compress(args):
exe.run(paddle.static.default_startup_program())
if args.pretrained_model:
if need_pretrain and args.pretrained_model:
def if_exist(var):
return os.path.exists(os.path.join(args.pretrained_model, var.name))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册