未验证 提交 1c6c326f 编写于 作者: M minghaoBD 提交者: GitHub

[unstructured_prune]support reproducibility for ce (#957)

上级 9b358cab
...@@ -2,6 +2,7 @@ import paddle ...@@ -2,6 +2,7 @@ import paddle
import os import os
import sys import sys
import argparse import argparse
import random
import numpy as np import numpy as np
from paddleslim import UnstructuredPruner, GMPUnstructuredPruner from paddleslim import UnstructuredPruner, GMPUnstructuredPruner
sys.path.append( sys.path.append(
...@@ -51,6 +52,7 @@ add_arg('initial_ratio', float, 0.15, "The initial pruning ratio used ...@@ -51,6 +52,7 @@ add_arg('initial_ratio', float, 0.15, "The initial pruning ratio used
add_arg('pruning_strategy', str, 'base', "Which training strategy to use in pruning, we only support base and gmp for now. Default: base") add_arg('pruning_strategy', str, 'base', "Which training strategy to use in pruning, we only support base and gmp for now. Default: base")
add_arg('prune_params_type', str, None, "Which kind of params should be pruned, we only support None (all but norms) and conv1x1_only for now. Default: None") add_arg('prune_params_type', str, None, "Which kind of params should be pruned, we only support None (all but norms) and conv1x1_only for now. Default: None")
add_arg('local_sparsity', bool, False, "Whether to prune all the parameter matrix at the same ratio or not. Default: False") add_arg('local_sparsity', bool, False, "Whether to prune all the parameter matrix at the same ratio or not. Default: False")
add_arg('ce_test', bool, False, "Whether to CE test. Default: False")
# yapf: enable # yapf: enable
...@@ -109,6 +111,16 @@ def create_unstructured_pruner(model, args, configs=None): ...@@ -109,6 +111,16 @@ def create_unstructured_pruner(model, args, configs=None):
def compress(args): def compress(args):
shuffle = True
if args.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
args.num_workers = 0
shuffle = False
if args.use_gpu: if args.use_gpu:
place = paddle.set_device('gpu') place = paddle.set_device('gpu')
else: else:
...@@ -139,7 +151,10 @@ def compress(args): ...@@ -139,7 +151,10 @@ def compress(args):
raise ValueError("{} is not supported.".format(args.data)) raise ValueError("{} is not supported.".format(args.data))
batch_sampler = paddle.io.DistributedBatchSampler( batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) train_dataset,
batch_size=args.batch_size,
shuffle=shuffle,
drop_last=True)
train_loader = paddle.io.DataLoader( train_loader = paddle.io.DataLoader(
train_dataset, train_dataset,
......
...@@ -5,6 +5,7 @@ import paddle ...@@ -5,6 +5,7 @@ import paddle
import argparse import argparse
import functools import functools
import time import time
import random
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddleslim.prune.unstructured_pruner import UnstructuredPruner, GMPUnstructuredPruner from paddleslim.prune.unstructured_pruner import UnstructuredPruner, GMPUnstructuredPruner
...@@ -34,10 +35,11 @@ add_arg('momentum_rate', float, 0.9, "The value of momentum_ra ...@@ -34,10 +35,11 @@ add_arg('momentum_rate', float, 0.9, "The value of momentum_ra
add_arg('pruning_strategy', str, 'base', "The pruning strategy, currently we support base and gmp. Default: base") add_arg('pruning_strategy', str, 'base', "The pruning strategy, currently we support base and gmp. Default: base")
add_arg('threshold', float, 0.01, "The threshold to set zeros, the abs(weights) lower than which will be zeros. Default: 0.01") add_arg('threshold', float, 0.01, "The threshold to set zeros, the abs(weights) lower than which will be zeros. Default: 0.01")
add_arg('pruning_mode', str, 'ratio', "the pruning mode: whether by ratio or by threshold. Default: ratio") add_arg('pruning_mode', str, 'ratio', "the pruning mode: whether by ratio or by threshold. Default: ratio")
add_arg('ratio', float, 0.55, "The ratio to set zeros, the smaller portion will be zeros. Default: 0.55") add_arg('ratio', float, 0.55, "The ratio to set zeros, the smaller portion will be zeros. Default: 0.55")
add_arg('num_epochs', int, 120, "The number of total epochs. Default: 120") add_arg('num_epochs', int, 120, "The number of total epochs. Default: 120")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step") parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('data', str, "imagenet", "Which data to use. 'mnist' or 'imagenet'. Default: imagenet") add_arg('data', str, "imagenet", "Which data to use. 'mnist', 'cifar10' or 'imagenet'. Default: imagenet")
add_arg('log_period', int, 100, "Log period in batches. Default: 100") add_arg('log_period', int, 100, "Log period in batches. Default: 100")
add_arg('test_period', int, 5, "Test period in epoches. Default: 5") add_arg('test_period', int, 5, "Test period in epoches. Default: 5")
add_arg('model_path', str, "./models", "The path to save model. Default: ./models") add_arg('model_path', str, "./models", "The path to save model. Default: ./models")
...@@ -50,6 +52,8 @@ add_arg('pruning_steps', int, 120, "How many times you want to increas ...@@ -50,6 +52,8 @@ add_arg('pruning_steps', int, 120, "How many times you want to increas
add_arg('initial_ratio', float, 0.15, "The initial pruning ratio used at the start of pruning stage. Default: 0.15") add_arg('initial_ratio', float, 0.15, "The initial pruning ratio used at the start of pruning stage. Default: 0.15")
add_arg('prune_params_type', str, None, "Which kind of params should be pruned, we only support None (all but norms) and conv1x1_only for now. Default: None") add_arg('prune_params_type', str, None, "Which kind of params should be pruned, we only support None (all but norms) and conv1x1_only for now. Default: None")
add_arg('local_sparsity', bool, False, "Whether to prune all the parameter matrix at the same ratio or not. Default: False") add_arg('local_sparsity', bool, False, "Whether to prune all the parameter matrix at the same ratio or not. Default: False")
add_arg('ce_test', bool, False, "Whether to CE test. Default: False")
add_arg('num_workers', int, 32, "number of workers when loading dataset. Default: 32")
# yapf: enable # yapf: enable
model_list = models.__all__ model_list = models.__all__
...@@ -110,6 +114,16 @@ def create_unstructured_pruner(train_program, args, place, configs): ...@@ -110,6 +114,16 @@ def create_unstructured_pruner(train_program, args, place, configs):
def compress(args): def compress(args):
shuffle = True
if args.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
args.num_workers = 0
shuffle = False
env = os.environ env = os.environ
num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 1)) num_trainers = int(env.get('PADDLE_TRAINERS_NUM', 1))
use_data_parallel = num_trainers > 1 use_data_parallel = num_trainers > 1
...@@ -130,6 +144,15 @@ def compress(args): ...@@ -130,6 +144,15 @@ def compress(args):
class_dim = 10 class_dim = 10
image_shape = "1,28,28" image_shape = "1,28,28"
args.pretrained_model = False args.pretrained_model = False
elif args.data == "cifar10":
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(
mode="train", backend="cv2", transform=transform)
val_dataset = paddle.vision.datasets.Cifar10(
mode="test", backend="cv2", transform=transform)
class_dim = 10
image_shape = "3, 32, 32"
args.pretrained_model = False
elif args.data == "imagenet": elif args.data == "imagenet":
import imagenet_reader as reader import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(mode='train') train_dataset = reader.ImageNetDataset(mode='train')
...@@ -156,7 +179,7 @@ def compress(args): ...@@ -156,7 +179,7 @@ def compress(args):
batch_sampler = paddle.io.DistributedBatchSampler( batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, train_dataset,
batch_size=batch_size_per_card, batch_size=batch_size_per_card,
shuffle=True, shuffle=shuffle,
drop_last=True) drop_last=True)
train_loader = paddle.io.DataLoader( train_loader = paddle.io.DataLoader(
...@@ -166,7 +189,7 @@ def compress(args): ...@@ -166,7 +189,7 @@ def compress(args):
feed_list=[image, label], feed_list=[image, label],
return_list=False, return_list=False,
use_shared_memory=True, use_shared_memory=True,
num_workers=32) num_workers=args.num_workers)
valid_loader = paddle.io.DataLoader( valid_loader = paddle.io.DataLoader(
val_dataset, val_dataset,
...@@ -184,6 +207,8 @@ def compress(args): ...@@ -184,6 +207,8 @@ def compress(args):
# model definition # model definition
model = models.__dict__[args.model]() model = models.__dict__[args.model]()
out = model.net(input=image, class_dim=class_dim) out = model.net(input=image, class_dim=class_dim)
if args.data == 'cifar10':
label = paddle.reshape(label, [-1, 1])
cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label) cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
avg_cost = paddle.mean(x=cost) avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册