Unverified commit 43f43c6d, authored by Guanghua Yu, committed by GitHub

support ce with quant (#932)

Parent 203ae656
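The change applies the same pattern to each of the three quantization training scripts: a new `ce_test` flag that, when enabled, fixes the Python, NumPy, and Paddle random seeds and forces single-process, unshuffled data loading so runs are reproducible for continuous-evaluation (CE) checks. Below is a minimal sketch of that pattern, pulled out into a helper for illustration only (the commit itself keeps the logic inline in `compress()`), assuming an argparse-style `args` object with a `ce_test` attribute:

```python
import random

import numpy as np
import paddle


def setup_ce_mode(args, num_workers=4, shuffle=True):
    """Return (num_workers, shuffle) for the data loaders.

    When args.ce_test is set, fix all random seeds and force
    single-process, unshuffled data loading so that results are
    reproducible across runs.
    """
    if args.ce_test:
        seed = 111
        paddle.seed(seed)     # Paddle global seed (parameter init, dropout, ...)
        np.random.seed(seed)  # NumPy-based augmentation / sampling
        random.seed(seed)     # Python-level randomness
        num_workers = 0       # avoid nondeterministic multi-process loading
        shuffle = False       # keep the sample order fixed
    return num_workers, shuffle
```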
......@@ -24,6 +24,7 @@ import argparse
import functools
import math
import time
import random
import numpy as np
from paddle.distributed import ParallelEnv
from paddle.static import load_program_state
......@@ -53,6 +54,7 @@ add_arg('lr_strategy', str, "piecewise_decay",
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('ls_epsilon', float, 0.0, "Label smooth epsilon.")
add_arg('use_pact', bool, False, "Whether to use PACT method.")
add_arg('ce_test', bool, False, "Whether to CE test.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 1, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.")
......@@ -88,6 +90,17 @@ def load_dygraph_pretrain(model, path=None, load_static_weights=False):
def compress(args):
num_workers = 4
shuffle = True
if args.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
num_workers = 0
shuffle = False
if args.data == "cifar10":
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(
......@@ -172,13 +185,16 @@ def compress(args):
net = paddle.DataParallel(net)
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
train_dataset,
batch_size=args.batch_size,
shuffle=shuffle,
drop_last=True)
train_loader = paddle.io.DataLoader(
train_dataset,
batch_sampler=train_batch_sampler,
places=place,
return_list=True,
num_workers=4)
num_workers=num_workers)
valid_loader = paddle.io.DataLoader(
val_dataset,
......@@ -187,7 +203,7 @@ def compress(args):
shuffle=False,
drop_last=False,
return_list=True,
num_workers=4)
num_workers=num_workers)
@paddle.no_grad()
def test(epoch, net):
......
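In the dygraph script, the CE-controlled `shuffle` and `num_workers` values then flow into both loaders; the batch sampler keeps `drop_last=True` but takes its `shuffle` from the flag, and validation stays unshuffled either way. A self-contained sketch of that loader construction, where the batch size, the CE settings, and the place are stand-in values rather than the script's parsed arguments:

```python
import paddle
import paddle.vision.transforms as T

# Stand-ins for the script's parsed arguments and CE settings.
batch_size, shuffle, num_workers = 256, False, 0
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() else paddle.CPUPlace()

transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(mode="train", transform=transform)
val_dataset = paddle.vision.datasets.Cifar10(mode="test", transform=transform)

train_batch_sampler = paddle.io.DistributedBatchSampler(
    train_dataset,
    batch_size=batch_size,
    shuffle=shuffle,  # False under ce_test
    drop_last=True)
train_loader = paddle.io.DataLoader(
    train_dataset,
    batch_sampler=train_batch_sampler,
    places=place,
    return_list=True,
    num_workers=num_workers)  # 0 under ce_test

valid_loader = paddle.io.DataLoader(
    val_dataset,
    places=place,
    batch_size=batch_size,
    shuffle=False,  # validation order is always fixed
    drop_last=False,
    return_list=True,
    num_workers=num_workers)
```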
......@@ -7,6 +7,7 @@ import functools
import math
import time
import numpy as np
import random
from collections import defaultdict
sys.path.append(os.path.dirname("__file__"))
......@@ -64,6 +65,7 @@ add_arg('use_pact', bool, True,
"Whether to use PACT or not.")
add_arg('analysis', bool, False,
"Whether analysis variables distribution.")
add_arg('ce_test', bool, False, "Whether to CE test.")
# yapf: enable
......@@ -108,6 +110,16 @@ def create_optimizer(args):
def compress(args):
num_workers = 4
shuffle = True
if args.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
num_workers = 0
shuffle = False
if args.data == "mnist":
train_dataset = paddle.vision.datasets.MNIST(mode='train')
......@@ -160,8 +172,8 @@ def compress(args):
return_list=False,
batch_size=args.batch_size,
use_shared_memory=True,
shuffle=True,
num_workers=4)
shuffle=shuffle,
num_workers=num_workers)
valid_loader = paddle.io.DataLoader(
val_dataset,
......
......@@ -6,6 +6,7 @@ import argparse
import functools
import math
import time
import random
import numpy as np
import paddle.fluid as fluid
sys.path[0] = os.path.join(
......@@ -13,6 +14,7 @@ sys.path[0] = os.path.join(
from paddleslim.common import get_logger
from paddleslim.analysis import flops
from paddleslim.quant import quant_aware, convert
import paddle.vision.transforms as T
import models
from utility import add_arguments, print_arguments
......@@ -35,9 +37,10 @@ add_arg('num_epochs', int, 1, "The number of total epochs."
add_arg('total_images', int, 1281167, "The number of total training images.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('config_file', str, None, "The config file for compression with yaml format.")
add_arg('data', str, "imagenet", "Which data to use. 'mnist' or 'imagenet'")
add_arg('data', str, "imagenet", "Which data to use. 'mnist', 'cifar10' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('checkpoint_dir', str, "output", "checkpoint save dir")
add_arg('ce_test', bool, False, "Whether to CE test.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
......@@ -81,6 +84,17 @@ def create_optimizer(args):
def compress(args):
num_workers = 4
shuffle = True
if args.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
num_workers = 0
shuffle = False
############################################################################################################
# 1. quantization configs
############################################################################################################
......@@ -105,11 +119,21 @@ def compress(args):
'moving_rate': 0.9,
}
pretrain = True
if args.data == "mnist":
train_dataset = paddle.vision.datasets.MNIST(mode='train')
val_dataset = paddle.vision.datasets.MNIST(mode='test')
class_dim = 10
image_shape = "1,28,28"
elif args.data == "cifar10":
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(
mode="train", backend="cv2", transform=transform)
val_dataset = paddle.vision.datasets.Cifar10(
mode="test", backend="cv2", transform=transform)
class_dim = 10
image_shape = "3, 32, 32"
pretrain = False
elif args.data == "imagenet":
import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(mode='train')
......@@ -153,11 +177,12 @@ def compress(args):
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
assert os.path.exists(
args.pretrained_model), "pretrained_model doesn't exist"
if pretrain:
assert os.path.exists(
args.pretrained_model), "pretrained_model doesn't exist"
if args.pretrained_model:
paddle.static.load(train_prog, args.pretrained_model, exe)
if args.pretrained_model:
paddle.static.load(train_prog, args.pretrained_model, exe)
places = paddle.static.cuda_places(
) if args.use_gpu else paddle.static.cpu_places()
......@@ -170,8 +195,8 @@ def compress(args):
batch_size=args.batch_size,
return_list=False,
use_shared_memory=True,
shuffle=True,
num_workers=4)
shuffle=shuffle,
num_workers=num_workers)
valid_loader = paddle.io.DataLoader(
val_dataset,
places=place,
......
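The static-graph script also gains a `cifar10` option. Because no ImageNet-pretrained checkpoint matches that configuration, the CIFAR-10 branch sets `pretrain = False`, and the pretrained-model assertion and `paddle.static.load` call are now gated on that flag instead of running unconditionally. A condensed sketch of the new branching (the `imagenet` branch and the surrounding program setup are omitted; `args`, `train_prog`, and `exe` are the script's existing objects):

```python
import os

import paddle
import paddle.vision.transforms as T


def build_dataset(args):
    """Select the dataset; CIFAR-10 trains from scratch (no pretrained weights)."""
    pretrain = True
    if args.data == "mnist":
        train_dataset = paddle.vision.datasets.MNIST(mode='train')
        val_dataset = paddle.vision.datasets.MNIST(mode='test')
        class_dim, image_shape = 10, "1,28,28"
    elif args.data == "cifar10":
        transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
        train_dataset = paddle.vision.datasets.Cifar10(
            mode="train", backend="cv2", transform=transform)
        val_dataset = paddle.vision.datasets.Cifar10(
            mode="test", backend="cv2", transform=transform)
        class_dim, image_shape = 10, "3, 32, 32"
        pretrain = False  # no pretrained checkpoint for CIFAR-10
    else:
        raise ValueError("only mnist/cifar10 are shown in this sketch")
    return train_dataset, val_dataset, class_dim, image_shape, pretrain


# Later in compress(), pretrained weights are only required when pretrain is True:
# if pretrain:
#     assert os.path.exists(args.pretrained_model), "pretrained_model doesn't exist"
#     if args.pretrained_model:
#         paddle.static.load(train_prog, args.pretrained_model, exe)
```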