提交 8be15753 编写于 作者: C ceci3

mv file

...@@ -3,3 +3,8 @@ build/ ...@@ -3,3 +3,8 @@ build/
./dist/ ./dist/
*.pyc *.pyc
dist/ dist/
*.data
*.log
*.tar
*.tar.gz
*.zip
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import paddle.fluid as fluid
from paddleslim.prune import AutoPruner
from paddleslim.common import get_logger
from paddleslim.analysis import flops
sys.path.append(sys.path[0] + "/../")
import models
from utility import add_arguments, print_arguments
# Module-level logger obtained from paddleslim's helper, configured at INFO level.
_logger = get_logger(__name__, level=logging.INFO)
# Command-line interface. `add_arg` is `add_arguments` with this parser pre-bound,
# so each call below passes (name, type, default, help text).
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
# NOTE(review): the value is a directory path to load weights from, not a flag.
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretained", "Whether to use pretrained model.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.")
# Registered directly on the parser because it takes a variable-length int list.
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('config_file', str, None, "The config file for compression with yaml format.")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.")
# yapf: enable
# Names accepted for --model: every attribute of `models` without a double underscore.
model_list = [m for m in dir(models) if "__" not in m]
def piecewise_decay(args):
    """Build a Momentum optimizer with a piecewise-constant learning rate.

    The rate starts at ``args.lr`` and is multiplied by 0.1 at every epoch
    listed in ``args.step_epochs``; epochs are converted to step counts using
    ``args.total_images`` and ``args.batch_size``.

    Args:
        args: namespace providing ``total_images``, ``batch_size``,
            ``step_epochs``, ``lr``, ``momentum_rate`` and ``l2_decay``.

    Returns:
        The configured ``fluid.optimizer.Momentum`` instance.
    """
    steps_per_epoch = int(math.ceil(float(args.total_images) / args.batch_size))
    boundaries = [steps_per_epoch * epoch for epoch in args.step_epochs]
    # One more value than boundaries: the rate before the first boundary.
    values = [args.lr * (0.1**idx) for idx in range(len(boundaries) + 1)]
    return fluid.optimizer.Momentum(
        learning_rate=fluid.layers.piecewise_decay(
            boundaries=boundaries, values=values),
        momentum=args.momentum_rate,
        regularization=fluid.regularizer.L2Decay(args.l2_decay))
def cosine_decay(args):
    """Build a Momentum optimizer whose learning rate follows ``cosine_decay``.

    Args:
        args: namespace providing ``total_images``, ``batch_size``, ``lr``,
            ``num_epochs``, ``momentum_rate`` and ``l2_decay``.

    Returns:
        The configured ``fluid.optimizer.Momentum`` instance.
    """
    steps_per_epoch = int(math.ceil(float(args.total_images) / args.batch_size))
    schedule = fluid.layers.cosine_decay(
        learning_rate=args.lr,
        step_each_epoch=steps_per_epoch,
        epochs=args.num_epochs)
    return fluid.optimizer.Momentum(
        learning_rate=schedule,
        momentum=args.momentum_rate,
        regularization=fluid.regularizer.L2Decay(args.l2_decay))
def create_optimizer(args):
    """Return the optimizer selected by ``args.lr_strategy``.

    Args:
        args: namespace with ``lr_strategy`` plus the fields required by the
            chosen builder.

    Returns:
        The optimizer built by ``piecewise_decay`` or ``cosine_decay``.

    Raises:
        ValueError: if ``args.lr_strategy`` is not a supported strategy.
    """
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    if args.lr_strategy == "cosine_decay":
        return cosine_decay(args)
    # Fail loudly instead of returning None, which would only surface later
    # as an AttributeError on `opt.minimize`.
    raise ValueError(
        "lr_strategy {!r} is not supported; expected 'piecewise_decay' or "
        "'cosine_decay'.".format(args.lr_strategy))
def compress(args):
    """Build the classifier, then search pruning ratios with ``AutoPruner``.

    The function builds the network in the default main program, optionally
    loads pretrained variables, and then loops forever: ask the pruner for a
    pruned train/eval program pair, train it for one epoch, evaluate it, and
    report the top-1 accuracy back as the reward.

    Args:
        args: parsed command-line namespace (see the ``add_arg`` calls).

    Raises:
        ValueError: if ``args.data`` is neither ``"mnist"`` nor ``"imagenet"``.
        AssertionError: if ``args.model`` is not listed in ``model_list``.

    Note:
        The search loop is ``while True`` and this function never returns
        normally.
    """
    if args.data == "mnist":
        import paddle.dataset.mnist as reader
        train_reader = reader.train()
        val_reader = reader.test()
        class_dim = 10
        image_shape = "1,28,28"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_reader = reader.train()
        val_reader = reader.val()
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))

    image_shape = [int(m) for m in image_shape.split(",")]
    assert args.model in model_list, "{} is not in lists: {}".format(
        args.model, model_list)
    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    cost = fluid.layers.cross_entropy(input=out, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)

    # Clone the evaluation program before the optimizer ops are appended.
    val_program = fluid.default_main_program().clone(for_test=True)
    opt = create_optimizer(args)
    opt.minimize(avg_cost)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if args.pretrained_model:

        def if_exist(var):
            # Only load variables that have a file in the pretrained directory.
            return os.path.exists(
                os.path.join(args.pretrained_model, var.name))

        fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)

    val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
    train_reader = paddle.batch(
        train_reader, batch_size=args.batch_size, drop_last=True)

    train_feeder = fluid.DataFeeder([image, label], place)
    val_feeder = fluid.DataFeeder(
        [image, label], place, program=val_program)

    def test(epoch, program):
        """Evaluate ``program`` on the validation reader; return mean top-1."""
        acc_top1_ns = []
        acc_top5_ns = []
        for batch_id, data in enumerate(val_reader()):
            start_time = time.time()
            # Evaluation batches go through the feeder built for the
            # evaluation program.
            acc_top1_n, acc_top5_n = exe.run(
                program,
                feed=val_feeder.feed(data),
                fetch_list=[acc_top1.name, acc_top5.name])
            end_time = time.time()
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id,
                           np.mean(acc_top1_n),
                           np.mean(acc_top5_n), end_time - start_time))
            acc_top1_ns.append(np.mean(acc_top1_n))
            acc_top5_ns.append(np.mean(acc_top5_n))

        _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".
                     format(epoch,
                            np.mean(np.array(acc_top1_ns)),
                            np.mean(np.array(acc_top5_ns))))
        return np.mean(np.array(acc_top1_ns))

    def train(epoch, program):
        """Train ``program`` for one pass over the training reader."""
        build_strategy = fluid.BuildStrategy()
        exec_strategy = fluid.ExecutionStrategy()
        train_program = fluid.compiler.CompiledProgram(
            program).with_data_parallel(
                loss_name=avg_cost.name,
                build_strategy=build_strategy,
                exec_strategy=exec_strategy)

        for batch_id, data in enumerate(train_reader()):
            start_time = time.time()
            loss_n, acc_top1_n, acc_top5_n = exe.run(
                train_program,
                feed=train_feeder.feed(data),
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            end_time = time.time()
            loss_n = np.mean(loss_n)
            acc_top1_n = np.mean(acc_top1_n)
            acc_top5_n = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
                           end_time - start_time))

    # Only parameters whose name contains "_sep_weights" are offered to the pruner.
    params = [
        param.name
        for param in fluid.default_main_program().global_block()
        .all_parameters() if "_sep_weights" in param.name
    ]

    pruner = AutoPruner(
        val_program,
        fluid.global_scope(),
        place,
        params=params,
        init_ratios=[0.33] * len(params),
        pruned_flops=0.5,
        pruned_latency=None,
        server_addr=("", 0),
        init_temperature=100,
        reduce_rate=0.85,
        max_try_times=300,
        max_client_num=10,
        search_steps=100,
        max_ratios=0.9,
        min_ratios=0.,
        is_server=True,
        key="auto_pruner")

    while True:
        pruned_program, pruned_val_program = pruner.prune(
            fluid.default_main_program(), val_program)
        for i in range(1):
            train(i, pruned_program)
        score = test(0, pruned_val_program)
        pruner.reward(score)
def main():
    """Parse the command line, echo the arguments, and start compression."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    compress(cli_args)


if __name__ == '__main__':
    main()
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import math
import logging
import paddle
import argparse
import functools
import numpy as np
import paddle.fluid as fluid
sys.path.append(sys.path[0] + "/../")
import models
import imagenet_reader as reader
from utility import add_arguments, print_arguments
from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss
# Configure root logging output format, then create this module's INFO-level logger.
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
# Command-line interface. `add_arg` is `add_arguments` with this parser pre-bound,
# so each call below passes (name, type, default, help text).
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64*4, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('total_images', int, 1281167, "Training image number.")
# NOTE(review): compress() derives the image shape from --data and does not read this option.
add_arg('image_shape', str, "3,224,224", "Input image size")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 20, "Log period in batches.")
add_arg('model', str, "MobileNet", "Set the network to use.")
add_arg('pretrained_model', str, None, "Whether to use pretrained model.")
add_arg('teacher_model', str, "ResNet50", "Set the teacher network to use.")
# NOTE(review): the value is a directory path to load teacher weights from, not a flag.
add_arg('teacher_pretrained_model', str, "../pretrain/ResNet50_pretrained", "Whether to use pretrained model.")
# Registered directly on the parser because it takes a variable-length int list.
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
# yapf: enable
# Names accepted for --model / --teacher_model: attributes of `models` without a double underscore.
model_list = [m for m in dir(models) if "__" not in m]
def piecewise_decay(args):
    """Build a Momentum optimizer with a piecewise-constant learning rate.

    The rate starts at ``args.lr`` and is multiplied by 0.1 at every epoch
    listed in ``args.step_epochs``; epochs are converted to step counts using
    ``args.total_images`` and ``args.batch_size``.

    Args:
        args: namespace providing ``total_images``, ``batch_size``,
            ``step_epochs``, ``lr``, ``momentum_rate`` and ``l2_decay``.

    Returns:
        The configured ``fluid.optimizer.Momentum`` instance.
    """
    steps_per_epoch = int(math.ceil(float(args.total_images) / args.batch_size))
    boundaries = [steps_per_epoch * epoch for epoch in args.step_epochs]
    # One more value than boundaries: the rate before the first boundary.
    values = [args.lr * (0.1**idx) for idx in range(len(boundaries) + 1)]
    return fluid.optimizer.Momentum(
        learning_rate=fluid.layers.piecewise_decay(
            boundaries=boundaries, values=values),
        momentum=args.momentum_rate,
        regularization=fluid.regularizer.L2Decay(args.l2_decay))
def cosine_decay(args):
    """Build a Momentum optimizer whose learning rate follows ``cosine_decay``.

    Args:
        args: namespace providing ``total_images``, ``batch_size``, ``lr``,
            ``num_epochs``, ``momentum_rate`` and ``l2_decay``.

    Returns:
        The configured ``fluid.optimizer.Momentum`` instance.
    """
    steps_per_epoch = int(math.ceil(float(args.total_images) / args.batch_size))
    schedule = fluid.layers.cosine_decay(
        learning_rate=args.lr,
        step_each_epoch=steps_per_epoch,
        epochs=args.num_epochs)
    return fluid.optimizer.Momentum(
        learning_rate=schedule,
        momentum=args.momentum_rate,
        regularization=fluid.regularizer.L2Decay(args.l2_decay))
def create_optimizer(args):
    """Return the optimizer selected by ``args.lr_strategy``.

    Args:
        args: namespace with ``lr_strategy`` plus the fields required by the
            chosen builder.

    Returns:
        The optimizer built by ``piecewise_decay`` or ``cosine_decay``.

    Raises:
        ValueError: if ``args.lr_strategy`` is not a supported strategy.
    """
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    if args.lr_strategy == "cosine_decay":
        return cosine_decay(args)
    # Fail loudly instead of returning None, which would only surface later
    # as an AttributeError on `opt.minimize`.
    raise ValueError(
        "lr_strategy {!r} is not supported; expected 'piecewise_decay' or "
        "'cosine_decay'.".format(args.lr_strategy))
def compress(args):
    """Train a student classifier with distillation losses from a teacher.

    Builds the student program, builds and loads the teacher program in its
    own scope, merges the two with ``paddleslim.dist.merge``, adds an L2 loss
    and an FSP loss between named teacher/student variables to the student's
    classification loss, and then trains for ``args.num_epochs`` epochs with a
    validation pass after each one.

    Args:
        args: parsed command-line namespace (see the ``add_arg`` calls).

    Raises:
        ValueError: if ``args.data`` is neither ``"mnist"`` nor ``"imagenet"``.
        AssertionError: if ``args.model`` is not listed in ``model_list`` or
            the teacher pretrained directory is unset / missing.
    """
    if args.data == "mnist":
        import paddle.dataset.mnist as reader
        train_reader = reader.train()
        val_reader = reader.test()
        class_dim = 10
        image_shape = "1,28,28"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_reader = reader.train()
        val_reader = reader.val()
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))

    image_shape = [int(m) for m in image_shape.split(",")]
    assert args.model in model_list, "{} is not in lists: {}".format(
        args.model, model_list)

    student_program = fluid.Program()
    s_startup = fluid.Program()

    with fluid.program_guard(student_program, s_startup):
        with fluid.unique_name.guard():
            image = fluid.layers.data(
                name='image', shape=image_shape, dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            train_loader = fluid.io.DataLoader.from_generator(
                feed_list=[image, label],
                capacity=64,
                use_double_buffer=True,
                iterable=True)
            valid_loader = fluid.io.DataLoader.from_generator(
                feed_list=[image, label],
                capacity=64,
                use_double_buffer=True,
                iterable=True)
            # model definition
            model = models.__dict__[args.model]()
            out = model.net(input=image, class_dim=class_dim)
            cost = fluid.layers.cross_entropy(input=out, label=label)
            avg_cost = fluid.layers.mean(x=cost)
            acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
            acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    train_reader = paddle.batch(
        train_reader, batch_size=args.batch_size, drop_last=True)
    val_reader = paddle.batch(
        val_reader, batch_size=args.batch_size, drop_last=True)
    # Clone the evaluation program before any distillation/optimizer ops exist.
    val_program = student_program.clone(for_test=True)

    # Honour --use_gpu for the data-parallel training places as well; the
    # CUDA place list is only meaningful when running on GPU.
    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
    train_loader.set_sample_list_generator(train_reader, places)
    valid_loader.set_sample_list_generator(val_reader, place)

    teacher_model = models.__dict__[args.teacher_model]()
    # define teacher program
    teacher_program = fluid.Program()
    t_startup = fluid.Program()
    teacher_scope = fluid.Scope()
    # Teacher variables are initialised and loaded inside their own scope,
    # which is later handed to `merge`.
    with fluid.scope_guard(teacher_scope):
        with fluid.program_guard(teacher_program, t_startup):
            with fluid.unique_name.guard():
                image = fluid.layers.data(
                    name='image', shape=image_shape, dtype='float32')
                predict = teacher_model.net(image, class_dim=class_dim)

        exe.run(t_startup)
        assert args.teacher_pretrained_model and os.path.exists(
            args.teacher_pretrained_model
        ), "teacher_pretrained_model should be set when teacher_model is not None."

        def if_exist(var):
            # Load only variables with a file on disk, skipping three names
            # that are explicitly excluded from loading.
            return os.path.exists(
                os.path.join(args.teacher_pretrained_model, var.name)
            ) and var.name != 'conv1_weights' and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'

        fluid.io.load_vars(
            exe,
            args.teacher_pretrained_model,
            main_program=teacher_program,
            predicate=if_exist)

    data_name_map = {'image': 'image'}
    # Named `merged_program` so the module-level `main()` is not shadowed.
    merged_program = merge(
        teacher_program,
        student_program,
        data_name_map,
        place,
        teacher_scope=teacher_scope)

    with fluid.program_guard(merged_program, s_startup):
        # Teacher variables are addressed with a "teacher_" name prefix.
        l2_loss_v = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", merged_program)
        fsp_loss_v = fsp_loss("teacher_res2a_branch2a.conv2d.output.1.tmp_0",
                              "teacher_res3a_branch2a.conv2d.output.1.tmp_0",
                              "depthwise_conv2d_1.tmp_0", "conv2d_3.tmp_0",
                              merged_program)
        loss = avg_cost + l2_loss_v + fsp_loss_v
        opt = create_optimizer(args)
        opt.minimize(loss)
    exe.run(s_startup)

    build_strategy = fluid.BuildStrategy()
    build_strategy.fuse_all_reduce_ops = False
    parallel_main = fluid.CompiledProgram(merged_program).with_data_parallel(
        loss_name=loss.name, build_strategy=build_strategy)

    for epoch_id in range(args.num_epochs):
        for step_id, data in enumerate(train_loader):
            loss_1, loss_2, loss_3, loss_4 = exe.run(
                parallel_main,
                feed=data,
                fetch_list=[
                    loss.name, avg_cost.name, l2_loss_v.name, fsp_loss_v.name
                ])
            if step_id % args.log_period == 0:
                _logger.info(
                    "train_epoch {} step {} loss {:.6f}, class loss {:.6f}, l2 loss {:.6f}, fsp loss {:.6f}".
                    format(epoch_id, step_id, loss_1[0], loss_2[0], loss_3[0],
                           loss_4[0]))

        # Validation pass on the student-only evaluation program.
        val_acc1s = []
        val_acc5s = []
        for step_id, data in enumerate(valid_loader):
            val_loss, val_acc1, val_acc5 = exe.run(
                val_program,
                data,
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            val_acc1s.append(val_acc1)
            val_acc5s.append(val_acc5)
            if step_id % args.log_period == 0:
                _logger.info(
                    "valid_epoch {} step {} loss {:.6f}, top1 {:.6f}, top5 {:.6f}".
                    format(epoch_id, step_id, val_loss[0], val_acc1[0],
                           val_acc5[0]))
        _logger.info("epoch {} top1 {:.6f}, top5 {:.6f}".format(
            epoch_id, np.mean(val_acc1s), np.mean(val_acc5s)))
def main():
    """Parse the command line, echo the arguments, and start distillation."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    compress(cli_args)


if __name__ == '__main__':
    main()
import os
import math
import random
import functools
import numpy as np
import paddle
from PIL import Image, ImageEnhance
# Seed both the stdlib and NumPy RNGs at import time (fixed seed 0).
random.seed(0)
np.random.seed(0)
# Side length, in pixels, of the square crop produced by process_image.
DATA_DIM = 224
# Worker count and buffer size handed to paddle.reader.xmap_readers.
THREAD = 16
BUF_SIZE = 10240
#DATA_DIR = './data/ILSVRC2012/'
# Dataset root, resolved relative to the directory containing this file.
DATA_DIR = './data/'
DATA_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], DATA_DIR)
# Per-channel mean/std, shaped (3, 1, 1) so they broadcast over a
# channel-first image array.
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def resize_short(img, target_size):
    """Resize ``img`` so its shorter side becomes ``target_size`` pixels.

    Both sides are scaled by the same factor and rounded to the nearest
    integer; resampling uses ``Image.LANCZOS``.

    Args:
        img: image object exposing ``size`` as (width, height) and ``resize``.
        target_size: desired length of the shorter side.

    Returns:
        The resized image.
    """
    width, height = img.size[0], img.size[1]
    ratio = float(target_size) / min(width, height)
    new_size = (int(round(width * ratio)), int(round(height * ratio)))
    return img.resize(new_size, Image.LANCZOS)
def crop_image(img, target_size, center):
    """Crop a ``target_size`` x ``target_size`` square out of ``img``.

    Args:
        img: image object exposing ``size`` as (width, height) and ``crop``.
        target_size: side length of the square crop, in pixels.
        center: when truthy take the centred crop, otherwise pick the
            top-left corner uniformly at random.

    Returns:
        Whatever ``img.crop`` returns for the computed (left, upper, right,
        lower) box.
    """
    width, height = img.size
    size = target_size
    if center:
        # Floor division keeps the box coordinates integral on Python 3,
        # where `/` would produce floats.
        w_start = (width - size) // 2
        h_start = (height - size) // 2
    else:
        w_start = np.random.randint(0, width - size + 1)
        h_start = np.random.randint(0, height - size + 1)
    return img.crop((w_start, h_start, w_start + size, h_start + size))
def random_crop(img, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
    """Crop a random region of ``img`` and resize it to ``size`` x ``size``.

    A value is drawn uniformly from ``ratio`` and its square root is used to
    set the crop's width/height proportions; an area fraction is then drawn
    uniformly from ``scale`` after clamping both ends so the crop fits inside
    the image. The crop position is uniform over the valid range.

    Args:
        img: image object exposing ``size`` as (width, height), ``crop`` and
            ``resize``.
        size: side length of the square output.
        scale: (min, max) fraction of the image area to keep. Any two-item
            sequence works; the default is an immutable tuple.
        ratio: (low, high) bounds passed to ``np.random.uniform`` for the
            aspect-ratio draw.

    Returns:
        The cropped image resized with ``Image.LANCZOS``.
    """
    aspect_ratio = math.sqrt(np.random.uniform(*ratio))
    w = 1. * aspect_ratio
    h = 1. / aspect_ratio

    img_w, img_h = img.size[0], img.size[1]
    # Largest area fraction for which a crop of this shape still fits.
    bound = min((float(img_w) / img_h) / (w**2),
                (float(img_h) / img_w) / (h**2))
    scale_max = min(scale[1], bound)
    scale_min = min(scale[0], bound)

    target_area = img_w * img_h * np.random.uniform(scale_min, scale_max)
    target_size = math.sqrt(target_area)
    w = int(target_size * w)
    h = int(target_size * h)

    i = np.random.randint(0, img_w - w + 1)
    j = np.random.randint(0, img_h - h + 1)

    img = img.crop((i, j, i + w, j + h))
    return img.resize((size, size), Image.LANCZOS)
def rotate_image(img):
    """Rotate ``img`` by a random integer angle drawn from [-10, 10].

    Args:
        img: image object exposing ``rotate(angle)``.

    Returns:
        The result of ``img.rotate`` for the drawn angle.
    """
    return img.rotate(np.random.randint(-10, 11))
def distort_color(img):
    """Apply brightness, contrast and colour jitter in a random order.

    Each enhancement uses its own factor drawn uniformly from [0.5, 1.5].

    Args:
        img: a PIL image.

    Returns:
        The image after all three enhancements.
    """

    def _jitter(enhancer_cls, lower=0.5, upper=1.5):
        # Returns an op that draws its factor at application time.
        def apply(image):
            factor = np.random.uniform(lower, upper)
            return enhancer_cls(image).enhance(factor)

        return apply

    ops = [
        _jitter(ImageEnhance.Brightness),
        _jitter(ImageEnhance.Contrast),
        _jitter(ImageEnhance.Color),
    ]
    np.random.shuffle(ops)

    for op in ops:
        img = op(img)
    return img
def process_image(sample, mode, color_jitter, rotate):
    """Load one image and turn it into a normalised channel-first array.

    In ``'train'`` mode the image is optionally rotated, randomly cropped to
    ``DATA_DIM``, optionally colour-jittered and flipped left-right with
    probability one half. In every other mode it is resized so the short side
    is 256 and centre-cropped to ``DATA_DIM``. The result is converted to RGB,
    scaled to [0, 1], and normalised with ``img_mean`` / ``img_std``.

    Args:
        sample: sequence whose first item is the image path and, for
            ``'train'``/``'val'``, whose second item is the label.
        mode: ``'train'``, ``'val'`` or ``'test'``.
        color_jitter: enable colour distortion in training mode.
        rotate: enable random rotation in training mode.

    Returns:
        ``(img, label)`` for ``'train'``/``'val'``, ``[img]`` for ``'test'``,
        and ``None`` for any other mode.
    """
    is_train = mode == 'train'
    img = Image.open(sample[0])

    if is_train:
        if rotate:
            img = rotate_image(img)
        img = random_crop(img, DATA_DIM)
        if color_jitter:
            img = distort_color(img)
        if np.random.randint(0, 2) == 1:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
    else:
        img = resize_short(img, target_size=256)
        img = crop_image(img, target_size=DATA_DIM, center=True)

    if img.mode != 'RGB':
        img = img.convert('RGB')

    # HWC uint8 -> CHW float32 in [0, 1]; the in-place ops keep float32.
    img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
    img -= img_mean
    img /= img_std

    if mode in ('train', 'val'):
        return img, sample[1]
    if mode == 'test':
        return [img]
def _reader_creator(file_list,
                    mode,
                    shuffle=False,
                    color_jitter=False,
                    rotate=False,
                    data_dir=DATA_DIR,
                    batch_size=1):
    """Create a reader that yields processed samples listed in ``file_list``.

    The inner generator reads the list file, optionally shuffles it, shards it
    across trainers when ``PADDLE_TRAINING_ROLE`` is set (training mode only),
    and yields raw ``(path, label)`` or ``[path]`` records. Those records are
    mapped through ``process_image`` by ``paddle.reader.xmap_readers``.

    Args:
        file_list: path of the text file listing the samples.
        mode: ``'train'``, ``'val'`` or ``'test'``.
        shuffle: shuffle the list before iterating.
        color_jitter: forwarded to ``process_image``.
        rotate: forwarded to ``process_image``.
        data_dir: dataset root used to build image paths.
        batch_size: accepted for interface compatibility; not used here.

    Returns:
        The reader returned by ``paddle.reader.xmap_readers``.
    """

    def reader():
        try:
            with open(file_list) as flist:
                all_lines = [entry.strip() for entry in flist]
                if shuffle:
                    np.random.shuffle(all_lines)
                if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
                    # distributed mode if the env var `PADDLE_TRAINING_ROLE` exists
                    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
                    trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
                    shard_len = len(all_lines) // trainer_count
                    shard_start = trainer_id * shard_len
                    lines = all_lines[shard_start:(trainer_id + 1) *
                                      shard_len]
                    print(
                        "read images from %d, length: %d, lines length: %d, total: %d"
                        % (shard_start, shard_len, len(lines),
                           len(all_lines)))
                else:
                    lines = all_lines

                for line in lines:
                    if mode in ('train', 'val'):
                        rel_path, label = line.split()
                        yield os.path.join(data_dir + "/" + mode,
                                           rel_path), int(label)
                    elif mode == 'test':
                        yield [os.path.join(data_dir, line)]
        except Exception as e:
            # Any reader failure terminates the whole process immediately.
            print("Reader failed!\n{}".format(str(e)))
            os._exit(1)

    mapper = functools.partial(
        process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)

    return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def train(data_dir=DATA_DIR):
    """Return the shuffled training reader for ``<data_dir>/train_list.txt``.

    Colour jitter and rotation are both disabled.
    """
    list_path = os.path.join(data_dir, 'train_list.txt')
    options = dict(
        shuffle=True, color_jitter=False, rotate=False, data_dir=data_dir)
    return _reader_creator(list_path, 'train', **options)
def val(data_dir=DATA_DIR):
    """Return the unshuffled validation reader for ``<data_dir>/val_list.txt``."""
    list_path = os.path.join(data_dir, 'val_list.txt')
    options = dict(shuffle=False, data_dir=data_dir)
    return _reader_creator(list_path, 'val', **options)
def test(data_dir=DATA_DIR):
    """Return the unshuffled test reader for ``<data_dir>/test_list.txt``."""
    list_path = os.path.join(data_dir, 'test_list.txt')
    options = dict(shuffle=False, data_dir=data_dir)
    return _reader_creator(list_path, 'test', **options)
from .mobilenet import MobileNet
from .resnet import ResNet34, ResNet50
from .mobilenet_v2 import MobileNetV2
__all__ = ['MobileNet', 'ResNet34', 'ResNet50', 'MobileNetV2']
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
__all__ = ['MobileNet']
# Reference training configuration; MobileNet.__init__ stores it as `self.params`.
# NOTE(review): the mean/std values look like per-channel normalisation
# statistics and "steps" like per-stage learning rates — confirm with callers.
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [10, 16, 30],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class MobileNet():
    """MobileNet classifier graph built from ``paddle.fluid`` layers.

    The network is a strided 3x3 convolution followed by a fixed sequence of
    depthwise-separable blocks, global average pooling and a softmax FC layer.
    """

    def __init__(self):
        self.params = train_parameters

    def net(self, input, class_dim=1000, scale=1.0):
        """Append the full network to the current program.

        Args:
            input: image variable fed to the first convolution.
            class_dim: number of output classes.
            scale: multiplier applied to every block's filter and group count.

        Returns:
            The softmax output variable of the final FC layer.
        """
        # conv1: strided 3x3 convolution.
        input = self.conv_bn_layer(
            input,
            filter_size=3,
            channels=3,
            num_filters=int(32 * scale),
            stride=2,
            padding=1,
            name="conv1")

        # (num_filters1, num_filters2, num_groups, stride, name), in order.
        block_specs = [
            (32, 64, 32, 1, "conv2_1"),
            (64, 128, 64, 2, "conv2_2"),
            (128, 128, 128, 1, "conv3_1"),
            (128, 256, 128, 2, "conv3_2"),
            (256, 256, 256, 1, "conv4_1"),
            (256, 512, 256, 2, "conv4_2"),
        ]
        # Five identical 512-channel blocks: conv5_1 .. conv5_5.
        block_specs += [(512, 512, 512, 1, "conv5" + "_" + str(i + 1))
                        for i in range(5)]
        block_specs += [
            (512, 1024, 512, 2, "conv5_6"),
            (1024, 1024, 1024, 1, "conv6"),
        ]

        for filters_dw, filters_pw, groups, stride, block_name in block_specs:
            input = self.depthwise_separable(
                input,
                num_filters1=filters_dw,
                num_filters2=filters_pw,
                num_groups=groups,
                stride=stride,
                scale=scale,
                name=block_name)

        input = fluid.layers.pool2d(
            input=input,
            pool_size=0,
            pool_stride=1,
            pool_type='avg',
            global_pooling=True)

        fc_weights = ParamAttr(initializer=MSRA(), name="fc7_weights")
        fc_bias = ParamAttr(name="fc7_offset")
        return fluid.layers.fc(input=input,
                               size=class_dim,
                               act='softmax',
                               param_attr=fc_weights,
                               bias_attr=fc_bias)

    def conv_bn_layer(self,
                      input,
                      filter_size,
                      num_filters,
                      stride,
                      padding,
                      channels=None,
                      num_groups=1,
                      act='relu',
                      use_cudnn=True,
                      name=None):
        """Bias-free conv2d followed by batch norm with activation ``act``.

        Parameter names are derived from ``name``: ``<name>_weights`` for the
        convolution and ``<name>_bn_{scale,offset,mean,variance}`` for batch
        norm. ``channels`` is accepted but not used.
        """
        conv_weights = ParamAttr(initializer=MSRA(), name=name + "_weights")
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            act=None,
            use_cudnn=use_cudnn,
            param_attr=conv_weights,
            bias_attr=False)

        bn_prefix = name + "_bn"
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            param_attr=ParamAttr(name=bn_prefix + "_scale"),
            bias_attr=ParamAttr(name=bn_prefix + "_offset"),
            moving_mean_name=bn_prefix + '_mean',
            moving_variance_name=bn_prefix + '_variance')

    def depthwise_separable(self,
                            input,
                            num_filters1,
                            num_filters2,
                            num_groups,
                            stride,
                            scale,
                            name=None):
        """Grouped 3x3 conv (``<name>_dw``) then 1x1 conv (``<name>_sep``).

        Filter and group counts are multiplied by ``scale`` and truncated to
        int. Returns the output of the 1x1 convolution block.
        """
        depthwise = self.conv_bn_layer(
            input=input,
            filter_size=3,
            num_filters=int(num_filters1 * scale),
            stride=stride,
            padding=1,
            num_groups=int(num_groups * scale),
            use_cudnn=False,
            name=name + "_dw")

        return self.conv_bn_layer(
            input=depthwise,
            filter_size=1,
            num_filters=int(num_filters2 * scale),
            stride=1,
            padding=0,
            name=name + "_sep")
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
# Public names of this module. Every entry must be its own string literal:
# adjacent literals without a comma are silently concatenated, which would
# make `from <module> import *` fail on a non-existent name.
__all__ = [
    'MobileNetV2', 'MobileNetV2_x0_25', 'MobileNetV2_x0_5',
    'MobileNetV2_x1_0', 'MobileNetV2_x1_5', 'MobileNetV2_x2_0',
    'MobileNetV2_scale'
]
# Reference training configuration; MobileNetV2.__init__ stores it as `self.params`.
# NOTE(review): the mean/std values look like per-channel normalisation
# statistics and "steps" like per-stage learning rates — confirm with callers.
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [30, 60, 90],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class MobileNetV2():
    """MobileNetV2 classifier graph built from ``paddle.fluid`` layers.

    ``scale`` multiplies channel widths and ``change_depth`` selects the
    deeper per-stage repeat schedule used in :meth:`net`.
    """

    def __init__(self, scale=1.0, change_depth=False):
        self.params = train_parameters
        self.scale = scale
        self.change_depth = change_depth

    def net(self, input, class_dim=1000):
        """Append the full network to the current program.

        Args:
            input: image variable fed to the first convolution.
            class_dim: number of output classes.

        Returns:
            The softmax output variable of the final FC layer.
        """
        scale = self.scale

        # Rows are (t, c, n, s): expansion factor, output channels,
        # number of units in the stage, stride of the stage's first unit.
        # if change_depth is True, the new depth is 1.4 times as deep as before.
        if self.change_depth == False:
            bottleneck_params_list = [
                (1, 16, 1, 1),
                (6, 24, 2, 2),
                (6, 32, 3, 2),
                (6, 64, 4, 2),
                (6, 96, 3, 1),
                (6, 160, 3, 2),
                (6, 320, 1, 1),
            ]
        else:
            bottleneck_params_list = [
                (1, 16, 1, 1),
                (6, 24, 2, 2),
                (6, 32, 5, 2),
                (6, 64, 7, 2),
                (6, 96, 5, 1),
                (6, 160, 3, 2),
                (6, 320, 1, 1),
            ]

        # conv1
        stem_channels = int(32 * scale)
        input = self.conv_bn_layer(
            input,
            num_filters=stem_channels,
            filter_size=3,
            stride=2,
            padding=1,
            if_act=True,
            name='conv1_1')

        # bottleneck sequences, named conv2, conv3, ...
        in_c = stem_channels
        for stage_no, (t, c, n, s) in enumerate(
                bottleneck_params_list, start=2):
            out_c = int(c * scale)
            input = self.invresi_blocks(
                input=input,
                in_c=in_c,
                t=t,
                c=out_c,
                n=n,
                s=s,
                name='conv' + str(stage_no))
            in_c = out_c

        # last_conv: only widened beyond 1280 when scale > 1.0.
        last_channels = int(1280 * scale) if scale > 1.0 else 1280
        input = self.conv_bn_layer(
            input=input,
            num_filters=last_channels,
            filter_size=1,
            stride=1,
            padding=0,
            if_act=True,
            name='conv9')

        input = fluid.layers.pool2d(
            input=input,
            pool_size=7,
            pool_stride=1,
            pool_type='avg',
            global_pooling=True)

        return fluid.layers.fc(input=input,
                               size=class_dim,
                               act='softmax',
                               param_attr=ParamAttr(name='fc10_weights'),
                               bias_attr=ParamAttr(name='fc10_offset'))

    def conv_bn_layer(self,
                      input,
                      filter_size,
                      num_filters,
                      stride,
                      padding,
                      channels=None,
                      num_groups=1,
                      if_act=True,
                      name=None,
                      use_cudnn=True):
        """Bias-free conv2d + batch norm, optionally followed by relu6.

        Parameter names are derived from ``name``: ``<name>_weights`` for the
        convolution and ``<name>_bn_{scale,offset,mean,variance}`` for batch
        norm. ``channels`` is accepted but not used.
        """
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            act=None,
            use_cudnn=use_cudnn,
            param_attr=ParamAttr(name=name + '_weights'),
            bias_attr=False)

        bn_prefix = name + '_bn'
        bn = fluid.layers.batch_norm(
            input=conv,
            param_attr=ParamAttr(name=bn_prefix + "_scale"),
            bias_attr=ParamAttr(name=bn_prefix + "_offset"),
            moving_mean_name=bn_prefix + '_mean',
            moving_variance_name=bn_prefix + '_variance')

        if not if_act:
            return bn
        return fluid.layers.relu6(bn)

    def shortcut(self, input, data_residual):
        """Return the element-wise sum of ``input`` and ``data_residual``."""
        return fluid.layers.elementwise_add(input, data_residual)

    def inverted_residual_unit(self,
                               input,
                               num_in_filter,
                               num_filters,
                               ifshortcut,
                               stride,
                               filter_size,
                               padding,
                               expansion_factor,
                               name=None):
        """One unit: 1x1 expand -> grouped conv -> 1x1 linear projection.

        The expanded width is ``round(num_in_filter * expansion_factor)``.
        When ``ifshortcut`` is truthy the unit's input is added to the
        projection output.
        """
        num_expfilter = int(round(num_in_filter * expansion_factor))

        expanded = self.conv_bn_layer(
            input=input,
            num_filters=num_expfilter,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            if_act=True,
            name=name + '_expand')

        # One group per channel; cuDNN is disabled for this convolution.
        depthwise = self.conv_bn_layer(
            input=expanded,
            num_filters=num_expfilter,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            num_groups=num_expfilter,
            if_act=True,
            name=name + '_dwise',
            use_cudnn=False)

        projected = self.conv_bn_layer(
            input=depthwise,
            num_filters=num_filters,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            if_act=False,
            name=name + '_linear')

        if not ifshortcut:
            return projected
        return self.shortcut(input=input, data_residual=projected)

    def invresi_blocks(self, input, in_c, t, c, n, s, name=None):
        """Stack ``n`` inverted-residual units forming one stage.

        The first unit uses stride ``s`` and no shortcut; the remaining
        ``n - 1`` units use stride 1 with a shortcut. Units are named
        ``<name>_1`` .. ``<name>_<n>``.
        """
        block = self.inverted_residual_unit(
            input=input,
            num_in_filter=in_c,
            num_filters=c,
            ifshortcut=False,
            stride=s,
            filter_size=3,
            padding=1,
            expansion_factor=t,
            name=name + '_1')

        for unit_no in range(2, n + 1):
            block = self.inverted_residual_unit(
                input=block,
                num_in_filter=c,
                num_filters=c,
                ifshortcut=True,
                stride=1,
                filter_size=3,
                padding=1,
                expansion_factor=t,
                name=name + '_' + str(unit_no))
        return block
def MobileNetV2_x0_25():
    """Return a ``MobileNetV2`` builder with ``scale=0.25``."""
    return MobileNetV2(scale=0.25)
def MobileNetV2_x0_5():
    """Return a ``MobileNetV2`` builder with ``scale=0.5``."""
    return MobileNetV2(scale=0.5)
def MobileNetV2_x1_0():
    """Return a ``MobileNetV2`` builder with ``scale=1.0``."""
    return MobileNetV2(scale=1.0)
def MobileNetV2_x1_5():
    """Return a ``MobileNetV2`` builder with ``scale=1.5``."""
    return MobileNetV2(scale=1.5)
def MobileNetV2_x2_0():
    """Return a ``MobileNetV2`` builder with ``scale=2.0``."""
    return MobileNetV2(scale=2.0)
def MobileNetV2_scale():
    """Return a ``MobileNetV2`` builder with ``scale=1.2`` and the deeper schedule."""
    return MobileNetV2(scale=1.2, change_depth=True)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import math
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNet", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
# Reference training configuration; ResNet.__init__ stores it as `self.params`.
# NOTE(review): the mean/std values look like per-channel normalisation
# statistics and "steps" like per-stage learning rates — confirm with callers.
train_parameters = {
"input_size": [3, 224, 224],
"input_mean": [0.485, 0.456, 0.406],
"input_std": [0.229, 0.224, 0.225],
"learning_strategy": {
"name": "piecewise_decay",
"batch_size": 256,
"epochs": [10, 16, 30],
"steps": [0.1, 0.01, 0.001, 0.0001]
}
}
class ResNet():
def __init__(self, layers=50, prefix_name=''):
    """Store the network depth and the optional variable-name prefix.

    Args:
        layers: network depth; ``net`` accepts 34, 50, 101 or 152.
        prefix_name: when non-empty, ``net`` prepends ``<prefix_name>_`` to
            the layer names it generates.
    """
    self.layers = layers
    self.prefix_name = prefix_name
    self.params = train_parameters
def net(self, input, class_dim=1000, conv1_name='conv1', fc_name=None):
    """Append the ResNet graph for ``self.layers`` to the current program.

    Depths of 50 and above use bottleneck blocks; depth 34 uses basic blocks.
    Block names follow ``res<stage><letter>`` (or ``res4a`` / ``res4b<i>`` in
    the third stage of the 101/152 variants), each prefixed with
    ``<prefix_name>_`` when a prefix is configured.

    Args:
        input: image variable fed to the first convolution.
        class_dim: number of output classes.
        conv1_name: base name of the first convolution.
        fc_name: optional base name for the final FC layer; the prefix is
            only applied when a name is given.

    Returns:
        The softmax output variable of the final FC layer.

    Raises:
        AssertionError: if ``self.layers`` is not 34, 50, 101 or 152.
    """
    layers = self.layers
    # Compare by value: identity comparison against a string literal is not
    # guaranteed to hold for equal strings.
    prefix_name = self.prefix_name if self.prefix_name == '' else self.prefix_name + '_'
    supported_layers = [34, 50, 101, 152]
    assert layers in supported_layers, \
        "supported layers are {} but input layer is {}".format(supported_layers, layers)

    if layers == 34 or layers == 50:
        depth = [3, 4, 6, 3]
    elif layers == 101:
        depth = [3, 4, 23, 3]
    elif layers == 152:
        depth = [3, 8, 36, 3]
    num_filters = [64, 128, 256, 512]

    # TODO(wanghaoshuang@baidu.com):
    # fix name("conv1") conflict between student and teacher in distillation.
    conv = self.conv_bn_layer(
        input=input,
        num_filters=64,
        filter_size=7,
        stride=2,
        act='relu',
        name=prefix_name + conv1_name)
    conv = fluid.layers.pool2d(
        input=conv,
        pool_size=3,
        pool_stride=2,
        pool_padding=1,
        pool_type='max')

    if layers >= 50:
        for block in range(len(depth)):
            for i in range(depth[block]):
                if layers in [101, 152] and block == 2:
                    if i == 0:
                        conv_name = "res" + str(block + 2) + "a"
                    else:
                        conv_name = "res" + str(block + 2) + "b" + str(i)
                else:
                    conv_name = "res" + str(block + 2) + chr(97 + i)
                conv_name = prefix_name + conv_name
                conv = self.bottleneck_block(
                    input=conv,
                    num_filters=num_filters[block],
                    # Only the first unit of stages after the first uses stride 2.
                    stride=2 if i == 0 and block != 0 else 1,
                    name=conv_name)

        pool = fluid.layers.pool2d(
            input=conv, pool_size=7, pool_type='avg', global_pooling=True)
        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
        fc_name = fc_name if fc_name is None else prefix_name + fc_name
        out = fluid.layers.fc(input=pool,
                              size=class_dim,
                              act='softmax',
                              name=fc_name,
                              param_attr=fluid.param_attr.ParamAttr(
                                  initializer=fluid.initializer.Uniform(
                                      -stdv, stdv)))
    else:
        for block in range(len(depth)):
            for i in range(depth[block]):
                conv_name = "res" + str(block + 2) + chr(97 + i)
                conv_name = prefix_name + conv_name
                conv = self.basic_block(
                    input=conv,
                    num_filters=num_filters[block],
                    stride=2 if i == 0 and block != 0 else 1,
                    is_first=block == i == 0,
                    name=conv_name)

        pool = fluid.layers.pool2d(
            input=conv, pool_type='avg', global_pooling=True)
        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
        fc_name = fc_name if fc_name is None else prefix_name + fc_name
        out = fluid.layers.fc(
            input=pool,
            size=class_dim,
            act='softmax',
            name=fc_name,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv)))

    return out
def conv_bn_layer(self,
                  input,
                  num_filters,
                  filter_size,
                  stride=1,
                  groups=1,
                  act=None,
                  name=None):
    """Append a bias-free convolution followed by batch normalization.

    The convolution is padded with ``(filter_size - 1) // 2`` and has no
    activation of its own; ``act`` is applied by the batch-norm layer.

    Parameter names are derived from ``name``:

    * convolution weights are called ``name + "_weights"``;
    * without a prefix (``self.prefix_name == ''``) the batch-norm base name
      is ``"bn_conv1"`` for ``"conv1"``, otherwise ``"bn"`` followed by
      ``name`` minus its first three characters;
    * with a prefix, ``name`` is split at its first underscore and the same
      rule is applied to the remainder, keeping the leading part in front.

    Args:
        input: Variable fed to the convolution.
        num_filters: number of convolution output channels.
        filter_size: convolution kernel size.
        stride: convolution stride.
        groups: convolution group count.
        act: activation name handed to ``batch_norm`` (or None).
        name: base name used to derive every parameter name.

    Returns:
        The output of ``fluid.layers.batch_norm``.
    """
    conv_out = fluid.layers.conv2d(
        input=input,
        num_filters=num_filters,
        filter_size=filter_size,
        stride=stride,
        padding=(filter_size - 1) // 2,
        groups=groups,
        act=None,
        param_attr=ParamAttr(name=name + "_weights"),
        bias_attr=False,
        name=name + '.conv2d.output.1')

    if self.prefix_name == '':
        bn_name = "bn_" + name if name == "conv1" else "bn" + name[3:]
    else:
        # The stem check looks at the second underscore-separated field,
        # while the name itself is only split at the first underscore.
        is_stem = name.split("_")[1] == "conv1"
        pieces = name.split("_", 1)
        head, tail = pieces[0], pieces[1]
        if is_stem:
            bn_name = head + "_bn_" + tail
        else:
            bn_name = head + "_bn" + tail[3:]

    bn_kwargs = dict(
        input=conv_out,
        act=act,
        name=bn_name + '.output.1',
        param_attr=ParamAttr(name=bn_name + '_scale'),
        bias_attr=ParamAttr(bn_name + '_offset'),
        moving_mean_name=bn_name + '_mean',
        moving_variance_name=bn_name + '_variance')
    return fluid.layers.batch_norm(**bn_kwargs)
def shortcut(self, input, ch_out, stride, is_first, name):
    """Return the residual branch for a block.

    The input is passed through unchanged when it already has ``ch_out``
    channels, the block does not stride, and ``is_first`` is false.
    Otherwise it is projected with a 1x1 ``conv_bn_layer`` using ``stride``.

    Args:
        input: Variable whose ``shape[1]`` is read as the channel count.
        ch_out: channel count the residual branch must produce.
        stride: stride of the enclosing block.
        is_first: when true, forces the projection even if shapes match.
        name: base name forwarded to ``conv_bn_layer``.

    Returns:
        ``input`` itself, or the projected Variable.
    """
    ch_in = input.shape[1]
    # Truthiness test instead of comparing against the True literal.
    needs_projection = ch_in != ch_out or stride != 1 or is_first
    if not needs_projection:
        return input
    return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
def bottleneck_block(self, input, num_filters, stride, name):
    """Build a three-convolution bottleneck residual block.

    The main branch is 1x1 -> 3x3 (carrying ``stride``) -> 1x1, with the
    last convolution producing ``num_filters * 4`` channels and no
    activation. The result is added to the shortcut branch and passed
    through ReLU.

    Args:
        input: Variable entering the block.
        num_filters: channel count of the first two convolutions.
        stride: stride applied by the 3x3 convolution and the shortcut.
        name: base name; ``_branch2a``/``_branch2b``/``_branch2c``/``_branch1``
            suffixes are appended for the individual layers.

    Returns:
        The output of the ReLU-activated element-wise addition.
    """
    expanded_filters = num_filters * 4

    squeezed = self.conv_bn_layer(
        input=input,
        num_filters=num_filters,
        filter_size=1,
        act='relu',
        name=name + "_branch2a")
    spatial = self.conv_bn_layer(
        input=squeezed,
        num_filters=num_filters,
        filter_size=3,
        stride=stride,
        act='relu',
        name=name + "_branch2b")
    expanded = self.conv_bn_layer(
        input=spatial,
        num_filters=expanded_filters,
        filter_size=1,
        act=None,
        name=name + "_branch2c")

    residual = self.shortcut(
        input,
        expanded_filters,
        stride,
        is_first=False,
        name=name + "_branch1")

    return fluid.layers.elementwise_add(
        x=residual, y=expanded, act='relu', name=name + ".add.output.5")
def basic_block(self, input, num_filters, stride, is_first, name):
    """Build a two-convolution basic residual block.

    The main branch is two 3x3 convolutions: the first carries ``stride``
    and a ReLU, the second has no activation. Its output is added to the
    shortcut branch and passed through ReLU.

    Args:
        input: Variable entering the block.
        num_filters: channel count of both convolutions.
        stride: stride of the first convolution and of the shortcut.
        is_first: forwarded to ``shortcut``.
        name: base name; ``_branch2a``/``_branch2b``/``_branch1`` suffixes
            are appended for the individual layers.

    Returns:
        The output of the ReLU-activated element-wise addition.
    """
    first = self.conv_bn_layer(
        input=input,
        num_filters=num_filters,
        filter_size=3,
        act='relu',
        stride=stride,
        name=name + "_branch2a")
    second = self.conv_bn_layer(
        input=first,
        num_filters=num_filters,
        filter_size=3,
        act=None,
        name=name + "_branch2b")

    residual = self.shortcut(
        input, num_filters, stride, is_first, name=name + "_branch1")

    return fluid.layers.elementwise_add(x=residual, y=second, act='relu')
def ResNet34(prefix_name=''):
    """Return a ``ResNet`` configured with 34 layers and ``prefix_name``."""
    return ResNet(layers=34, prefix_name=prefix_name)
def ResNet50(prefix_name=''):
    """Return a ``ResNet`` configured with 50 layers and ``prefix_name``."""
    return ResNet(layers=50, prefix_name=prefix_name)
def ResNet101(prefix_name=''):
    """Return a ``ResNet`` configured with 101 layers.

    Args:
        prefix_name: forwarded to ``ResNet``, mirroring ``ResNet34`` and
            ``ResNet50``. NOTE(review): the default '' assumes it matches
            the default of ``ResNet.__init__`` -- confirm.
    """
    return ResNet(layers=101, prefix_name=prefix_name)
def ResNet152(prefix_name=''):
    """Return a ``ResNet`` configured with 152 layers.

    Args:
        prefix_name: forwarded to ``ResNet``, mirroring ``ResNet34`` and
            ``ResNet50``. NOTE(review): the default '' assumes it matches
            the default of ``ResNet.__init__`` -- confirm.
    """
    return ResNet(layers=152, prefix_name=prefix_name)
import sys
sys.path.append('..')
import numpy as np
import argparse
import ast
import time
import argparse
import ast
import logging
import paddle
import paddle.fluid as fluid
from paddleslim.nas.search_space.search_space_factory import SearchSpaceFactory
from paddleslim.analysis import flops
from paddleslim.nas import SANAS
from paddleslim.common import get_logger
from optimizer import create_optimizer
import imagenet_reader
# Module-level logger used by the search loop below; created at INFO level.
_logger = get_logger(__name__, level=logging.INFO)
def create_data_loader(image_shape):
    """Declare the image/label inputs and an iterable loader that feeds them.

    Args:
        image_shape: list giving the per-sample image shape; a leading
            ``-1`` dimension is prepended for the ``'data'`` variable.

    Returns:
        A ``(data_loader, data, label)`` tuple.
    """
    image_var = fluid.data(
        name='data', shape=[-1] + image_shape, dtype='float32')
    label_var = fluid.data(name='label', shape=[-1, 1], dtype='int64')

    loader = fluid.io.DataLoader.from_generator(
        feed_list=[image_var, label_var],
        capacity=1024,
        use_double_buffer=True,
        iterable=True)
    return loader, image_var, label_var
def build_program(main_program,
                  startup_program,
                  image_shape,
                  archs,
                  args,
                  is_test=False):
    """Populate ``main_program`` with a classification graph for ``archs``.

    Inside a ``program_guard`` the function creates the data loader, applies
    ``archs`` to the image input, and adds softmax, cross-entropy, mean loss
    and top-1/top-5 accuracy. Unless ``is_test`` is set, an optimizer built
    by ``create_optimizer(args)`` is asked to minimize the mean loss.

    Args:
        main_program: program that receives the network ops.
        startup_program: program that receives the initialization ops.
        image_shape: per-sample image shape passed to ``create_data_loader``.
        archs: callable mapping the image Variable to the network output.
        args: options object forwarded to ``create_optimizer``.
        is_test: when true, no optimizer ops are added.

    Returns:
        A ``(data_loader, avg_cost, acc_top1, acc_top5)`` tuple.
    """
    with fluid.program_guard(main_program, startup_program):
        data_loader, data, label = create_data_loader(image_shape)
        output = archs(data)

        softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)
        cost = fluid.layers.cross_entropy(input=softmax_out, label=label)
        avg_cost = fluid.layers.mean(cost)
        acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)
        acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5)

        # Truthiness test instead of comparing against the False literal.
        if not is_test:
            optimizer = create_optimizer(args)
            optimizer.minimize(avg_cost)
    return data_loader, avg_cost, acc_top1, acc_top5
def search_mobilenetv2(config, args, image_size, is_server=True):
factory = SearchSpaceFactory()
space = factory.get_search_space(config)
if is_server:
### start a server and a client
sa_nas = SANAS(
config,
server_addr=("", 8883),
init_temperature=args.init_temperature,
reduce_rate=args.reduce_rate,
search_steps=args.search_steps,
is_server=True)
else:
### start a client
sa_nas = SANAS(
config,
server_addr=("10.255.125.38", 8883),
init_temperature=args.init_temperature,
reduce_rate=args.reduce_rate,
search_steps=args.search_steps,
is_server=False)
image_shape = [3, image_size, image_size]
for step in range(args.search_steps):
archs = sa_nas.next_archs()[0]
train_program = fluid.Program()
test_program = fluid.Program()
startup_program = fluid.Program()
train_loader, avg_cost, acc_top1, acc_top5 = build_program(
train_program, startup_program, image_shape, archs, args)
current_flops = flops(train_program)
print('step: {}, current_flops: {}'.format(step, current_flops))
if current_flops > args.max_flops:
continue
test_loader, test_avg_cost, test_acc_top1, test_acc_top5 = build_program(
test_program,
startup_program,
image_shape,
archs,
args,
is_test=True)
test_program = test_program.clone(for_test=True)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
if args.data == 'cifar10':
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.train10(cycle=False), buf_size=1024),
batch_size=args.batch_size,
drop_last=True)
test_reader = paddle.batch(
paddle.dataset.cifar.test10(cycle=False),
batch_size=args.batch_size,
drop_last=False)
elif args.data == 'imagenet':
train_reader = paddle.batch(
imagenet_reader.train(),
batch_size=args.batch_size,
drop_last=True)
test_reader = paddle.batch(
imagenet_reader.val(),
batch_size=args.batch_size,
drop_last=False)
#test_loader, _, _ = create_data_loader(image_shape)
train_loader.set_sample_list_generator(
train_reader,
places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())
test_loader.set_sample_list_generator(test_reader, places=place)
build_strategy = fluid.BuildStrategy()
train_compiled_program = fluid.CompiledProgram(
train_program).with_data_parallel(
loss_name=avg_cost.name, build_strategy=build_strategy)
for epoch_id in range(args.retain_epoch):
for batch_id, data in enumerate(train_loader()):
fetches = [avg_cost.name]
s_time = time.time()
outs = exe.run(train_compiled_program,
feed=data,
fetch_list=fetches)[0]
batch_time = time.time() - s_time
if batch_id % 10 == 0:
_logger.info(
'TRAIN: steps: {}, epoch: {}, batch: {}, cost: {}, batch_time: {}ms'.
format(step, epoch_id, batch_id, outs[0], batch_time))
reward = []
for batch_id, data in enumerate(test_loader()):
test_fetches = [
test_avg_cost.name, test_acc_top1.name, test_acc_top5.name
]
batch_reward = exe.run(test_program,
feed=data,
fetch_list=test_fetches)
reward_avg = np.mean(np.array(batch_reward), axis=1)
reward.append(reward_avg)
_logger.info(
'TEST: step: {}, batch: {}, avg_cost: {}, acc_top1: {}, acc_top5: {}'.
format(step, batch_id, batch_reward[0], batch_reward[1],
batch_reward[2]))
finally_reward = np.mean(np.array(reward), axis=0)
_logger.info(
'FINAL TEST: avg_cost: {}, acc_top1: {}, acc_top5: {}'.format(
finally_reward[0], finally_reward[1], finally_reward[2]))
sa_nas.reward(float(finally_reward[1]))
if __name__ == '__main__':
    # Command-line entry point: parse options, derive the search-space
    # configuration from the dataset, then start the search.
    parser = argparse.ArgumentParser(
        description='SA NAS MobileNetV2 cifar10 argparase')
    parser.add_argument(
        '--use_gpu',
        type=ast.literal_eval,
        default=True,
        help='Whether to use GPU in train/test model.')
    parser.add_argument(
        '--batch_size', type=int, default=256, help='batch size.')
    parser.add_argument(
        '--data',
        type=str,
        default='cifar10',
        choices=['cifar10', 'imagenet'],
        help='dataset to use: cifar10 or imagenet.')
    # controller
    parser.add_argument(
        '--reduce_rate', type=float, default=0.85, help='reduce rate.')
    parser.add_argument(
        '--init_temperature',
        type=float,
        default=10.24,
        help='init temperature.')
    parser.add_argument(
        '--is_server',
        type=ast.literal_eval,
        default=True,
        help='Whether to start a server.')
    # nas args
    parser.add_argument(
        '--max_flops',
        type=int,
        default=592948064,
        help='upper bound on the FLOPs of a sampled architecture.')
    parser.add_argument(
        '--retain_epoch', type=int, default=5, help='train epoch before val.')
    parser.add_argument(
        '--end_epoch', type=int, default=500, help='end epoch present client.')
    parser.add_argument(
        '--search_steps',
        type=int,
        default=100,
        help='number of search steps.')
    parser.add_argument(
        '--server_address', type=str, default=None, help='server address.')
    # optimizer args
    parser.add_argument(
        '--lr_strategy',
        type=str,
        default='piecewise_decay',
        help='learning rate decay strategy.')
    parser.add_argument('--lr', type=float, default=0.1, help='learning rate.')
    parser.add_argument(
        '--l2_decay',
        type=float,
        default=1e-4,
        help='L2 weight decay coefficient.')
    parser.add_argument(
        '--step_epochs',
        nargs='+',
        type=int,
        default=[30, 60, 90],
        help="piecewise decay step")
    parser.add_argument(
        '--momentum_rate', type=float, default=0.9, help='momentum rate.')
    parser.add_argument(
        '--warm_up_epochs',
        type=float,
        default=5.0,
        help='number of warm-up epochs.')
    parser.add_argument(
        '--num_epochs', type=int, default=120, help='total number of epochs.')
    parser.add_argument(
        '--decay_epochs',
        type=float,
        default=2.4,
        help='number of epochs between exponential decay steps.')
    parser.add_argument(
        '--decay_rate',
        type=float,
        default=0.97,
        help='exponential decay rate.')
    parser.add_argument(
        '--total_images',
        type=int,
        default=1281167,
        help='total number of training images.')
    args = parser.parse_args()
    print(args)

    if args.data == 'cifar10':
        image_size = 32
        block_num = 3
    elif args.data == 'imagenet':
        image_size = 224
        block_num = 6
    else:
        # NotImplemented is a constant, not an exception class; calling it
        # would raise TypeError instead of reporting the bad dataset.
        raise NotImplementedError(
            'data must in [cifar10, imagenet], but received: {}'.format(
                args.data))

    config_info = {
        'input_size': image_size,
        'output_size': 1,
        'block_num': block_num,
        'block_mask': None
    }
    config = [('MobileNetV2Space', config_info)]
    search_mobilenetv2(config, args, image_size, is_server=args.is_server)
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle.fluid as fluid
import paddle.fluid.layers.ops as ops
from paddle.fluid.initializer import init_on_cpu
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
def cosine_decay(learning_rate, step_each_epoch, epochs=120):
    """Cosine learning-rate schedule that changes once per epoch.

    lr = learning_rate * (cos(epoch * pi / epochs) + 1) / 2

    Args:
        learning_rate: base learning rate.
        step_each_epoch: number of steps that make up one epoch.
        epochs: number of epochs spanned by the half cosine period.

    Returns:
        The decayed learning-rate Variable.
    """
    step_counter = _decay_step_counter()

    with init_on_cpu():
        current_epoch = ops.floor(step_counter / step_each_epoch)
        cosine = ops.cos(current_epoch * (math.pi / epochs))
        decayed_lr = learning_rate * (cosine + 1) / 2
    return decayed_lr
def cosine_decay_with_warmup(learning_rate,
                             step_each_epoch,
                             epochs=120,
                             warm_up_epoch=5.0):
    """Per-step cosine learning-rate schedule preceded by a linear warm-up.

    While ``epoch < warm_up_epoch`` the rate grows linearly:
        lr = learning_rate * step / (step_each_epoch * warm_up_epoch)
    afterwards it follows
        lr = learning_rate * (cos((step - warm_up_epoch * step_each_epoch)
                                  * pi / (epochs * step_each_epoch)) + 1) / 2

    Args:
        learning_rate: base learning rate.
        step_each_epoch: number of steps that make up one epoch.
        epochs: number of epochs used to scale the cosine argument.
        warm_up_epoch: length of the warm-up phase in epochs. The default
            of 5.0 reproduces the value that used to be hard-coded.

    Returns:
        The persistable ``learning_rate`` Variable updated by the schedule.
    """
    global_step = _decay_step_counter()
    lr = fluid.layers.tensor.create_global_var(
        shape=[1],
        value=0.0,
        dtype='float32',
        persistable=True,
        name="learning_rate")

    warmup_epoch = fluid.layers.fill_constant(
        shape=[1], dtype='float32', value=float(warm_up_epoch), force_cpu=True)

    with init_on_cpu():
        epoch = ops.floor(global_step / step_each_epoch)
        with fluid.layers.control_flow.Switch() as switch:
            with switch.case(epoch < warmup_epoch):
                # Linear ramp from 0 up to learning_rate.
                decayed_lr = learning_rate * (global_step /
                                              (step_each_epoch * warmup_epoch))
                fluid.layers.tensor.assign(input=decayed_lr, output=lr)
            with switch.default():
                # Cosine decay measured from the end of the warm-up.
                steps_after_warmup = global_step - warmup_epoch * step_each_epoch
                decayed_lr = learning_rate * (ops.cos(
                    steps_after_warmup *
                    (math.pi / (epochs * step_each_epoch))) + 1) / 2
                fluid.layers.tensor.assign(input=decayed_lr, output=lr)
    return lr
def exponential_decay_with_warmup(learning_rate,
                                  step_each_epoch,
                                  decay_epochs,
                                  decay_rate=0.97,
                                  warm_up_epoch=5.0):
    """Staircase exponential learning-rate schedule with linear warm-up.

    While ``epoch < warm_up_epoch`` the rate grows linearly with the step
    count; afterwards it is
        learning_rate * decay_rate ** floor(steps_after_warmup / decay_epochs)

    Args:
        learning_rate: base learning rate.
        step_each_epoch: number of steps that make up one epoch.
        decay_epochs: divisor applied to the number of post-warm-up steps,
            i.e. the length of one decay interval expressed in steps.
        decay_rate: multiplicative factor applied per decay interval.
        warm_up_epoch: length of the warm-up phase in epochs.

    Returns:
        The persistable ``learning_rate`` Variable updated by the schedule.
    """
    step_counter = _decay_step_counter()

    lr_var = fluid.layers.tensor.create_global_var(
        shape=[1],
        value=0.0,
        dtype='float32',
        persistable=True,
        name="learning_rate")
    warmup_epochs = fluid.layers.fill_constant(
        shape=[1], dtype='float32', value=float(warm_up_epoch), force_cpu=True)

    with init_on_cpu():
        current_epoch = ops.floor(step_counter / step_each_epoch)
        with fluid.layers.control_flow.Switch() as lr_switch:
            with lr_switch.case(current_epoch < warmup_epochs):
                # Linear ramp from 0 up to learning_rate.
                warmup_steps = step_each_epoch * warmup_epochs
                warmup_lr = learning_rate * (step_counter / warmup_steps)
                fluid.layers.assign(input=warmup_lr, output=lr_var)
            with lr_switch.default():
                steps_after_warmup = step_counter - warmup_epochs * step_each_epoch
                decay_count = ops.floor(steps_after_warmup / decay_epochs)
                decayed_lr = learning_rate * (decay_rate**decay_count)
                fluid.layers.assign(input=decayed_lr, output=lr_var)
    return lr_var
def lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
    """Linear learning-rate warm-up followed by ``learning_rate``.

    For steps below ``warmup_steps`` the rate is
        start_lr + (end_lr - start_lr) * step / warmup_steps
    and afterwards ``learning_rate`` is copied through unchanged.

    Args:
        learning_rate: value assigned once warm-up is over.
        warmup_steps: number of steps the warm-up lasts.
        start_lr: float learning rate at step 0.
        end_lr: float learning rate reached at the end of warm-up.

    Returns:
        The persistable ``learning_rate_warmup`` Variable.
    """
    assert (isinstance(end_lr, float))
    assert (isinstance(start_lr, float))
    lr_span = end_lr - start_lr

    with fluid.default_main_program()._lr_schedule_guard():
        warmup_lr_var = fluid.layers.tensor.create_global_var(
            shape=[1],
            value=0.0,
            dtype='float32',
            persistable=True,
            name="learning_rate_warmup")
        step_counter = fluid.layers.learning_rate_scheduler._decay_step_counter(
        )

        with fluid.layers.control_flow.Switch() as lr_switch:
            with lr_switch.case(step_counter < warmup_steps):
                progress = step_counter / warmup_steps
                ramped_lr = start_lr + lr_span * progress
                fluid.layers.tensor.assign(ramped_lr, warmup_lr_var)
            with lr_switch.default():
                fluid.layers.tensor.assign(learning_rate, warmup_lr_var)
        return warmup_lr_var
class Optimizer(object):
    """Factory for the optimizer / learning-rate schedule combinations.

    Each public method builds and returns one fluid optimizer; the method
    name doubles as the value expected in ``args.lr_strategy``.

    Attributes:
        batch_size: batch size on all devices.
        lr: base learning rate.
        lr_strategy: name of the learning rate decay strategy.
        l2_decay: L2 regularization coefficient.
        momentum_rate: momentum used by Momentum and RMSProp.
        step_epochs: epoch boundaries for piecewise decay.
        num_epochs: number of total epochs.
        warm_up_epochs: warm-up length passed to the exponential schedule.
        decay_epochs: decay interval, in epochs, of the exponential schedule.
        decay_rate: decay factor of the exponential schedule.
        total_images: number of training images.
        step: number of steps in one epoch, ceil(total_images / batch_size).
    """

    def __init__(self, args):
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.lr_strategy = args.lr_strategy
        self.l2_decay = args.l2_decay
        self.momentum_rate = args.momentum_rate
        self.step_epochs = args.step_epochs
        self.num_epochs = args.num_epochs
        self.warm_up_epochs = args.warm_up_epochs
        self.decay_epochs = args.decay_epochs
        self.decay_rate = args.decay_rate
        self.total_images = args.total_images
        images_per_batch = float(self.total_images) / self.batch_size
        self.step = int(math.ceil(images_per_batch))

    def _momentum(self, learning_rate):
        """Return a Momentum optimizer with this object's momentum and L2."""
        return fluid.optimizer.Momentum(
            learning_rate=learning_rate,
            momentum=self.momentum_rate,
            regularization=fluid.regularizer.L2Decay(self.l2_decay))

    def piecewise_decay(self):
        """Momentum with a piecewise-constant learning rate.

        The rate is divided by 10 at every boundary in ``step_epochs``.

        Returns:
            a piecewise_decay optimizer
        """
        boundaries = [self.step * epoch for epoch in self.step_epochs]
        values = [
            self.lr * (0.1**power) for power in range(len(boundaries) + 1)
        ]
        schedule = fluid.layers.piecewise_decay(
            boundaries=boundaries, values=values)
        return self._momentum(schedule)

    def cosine_decay(self):
        """Momentum with fluid's built-in cosine decay.

        Returns:
            a cosine_decay optimizer
        """
        schedule = fluid.layers.cosine_decay(
            learning_rate=self.lr,
            step_each_epoch=self.step,
            epochs=self.num_epochs)
        return self._momentum(schedule)

    def cosine_decay_warmup(self):
        """Momentum with the module's warm-up + cosine schedule.

        Returns:
            a cosine_decay_with_warmup optimizer
        """
        schedule = cosine_decay_with_warmup(
            learning_rate=self.lr,
            step_each_epoch=self.step,
            epochs=self.num_epochs)
        return self._momentum(schedule)

    def exponential_decay_warmup(self):
        """RMSProp with the module's warm-up + exponential schedule.

        Returns:
            a exponential_decay_with_warmup optimizer
        """
        # The schedule expects the decay interval in steps, not epochs.
        decay_interval_steps = self.step * self.decay_epochs
        schedule = exponential_decay_with_warmup(
            learning_rate=self.lr,
            step_each_epoch=self.step,
            decay_epochs=decay_interval_steps,
            decay_rate=self.decay_rate,
            warm_up_epoch=self.warm_up_epochs)
        return fluid.optimizer.RMSProp(
            learning_rate=schedule,
            regularization=fluid.regularizer.L2Decay(self.l2_decay),
            momentum=self.momentum_rate,
            rho=0.9,
            epsilon=0.001)

    def linear_decay(self):
        """Momentum with a linear (power-1 polynomial) decay to zero.

        NOTE(review): the decay length passed to ``polynomial_decay`` is
        ``self.step`` (the steps of one epoch) -- confirm this is intended.

        Returns:
            a linear_decay optimizer
        """
        final_lr = 0
        schedule = fluid.layers.polynomial_decay(
            self.lr, self.step, final_lr, power=1)
        return self._momentum(schedule)

    def adam_decay(self):
        """Adam with the constant base learning rate.

        Returns:
            an adam_decay optimizer
        """
        return fluid.optimizer.Adam(learning_rate=self.lr)

    def cosine_decay_RMSProp(self):
        """RMSProp with fluid's built-in cosine decay.

        Returns:
            an cosine_decay_RMSProp optimizer
        """
        schedule = fluid.layers.cosine_decay(
            learning_rate=self.lr,
            step_each_epoch=self.step,
            epochs=self.num_epochs)
        return fluid.optimizer.RMSProp(
            learning_rate=schedule,
            momentum=self.momentum_rate,
            regularization=fluid.regularizer.L2Decay(self.l2_decay),
            # Apply epsilon=1 on ImageNet dataset.
            epsilon=1)

    def default_decay(self):
        """Momentum with the constant base learning rate.

        Returns:
            default decay optimizer
        """
        return self._momentum(self.lr)
def create_optimizer(args):
    """Build the optimizer selected by ``args.lr_strategy``.

    The strategy name is looked up as a method on an ``Optimizer`` built
    from ``args`` and that method's result is returned.
    """
    build = getattr(Optimizer(args), args.lr_strategy)
    return build()
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import paddle.fluid as fluid
from paddleslim.prune import Pruner
from paddleslim.common import get_logger
from paddleslim.analysis import flops
sys.path.append(sys.path[0] + "/../")
import models
from utility import add_arguments, print_arguments
# Module-level logger for this script, created at INFO level.
_logger = get_logger(__name__, level=logging.INFO)
# Command-line options. `add_arg` is `utility.add_arguments` bound to this
# parser; presumably each call registers one `--<name>` option from
# (name, type, default, help) -- verify against utility.add_arguments.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretained", "Whether to use pretrained model.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.")
# Registered directly on the parser because it takes a list of ints.
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('config_file', str, None, "The config file for compression with yaml format.")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.")
# yapf: enable
# Names accepted for --model: every attribute of `models` without "__".
model_list = [m for m in dir(models) if "__" not in m]
def piecewise_decay(args):
    """Return a Momentum optimizer with a piecewise-constant learning rate.

    Boundaries fall at the epochs listed in ``args.step_epochs`` (converted
    to steps); the rate starts at ``args.lr`` and is multiplied by 0.1 at
    every boundary.
    """
    steps_per_epoch = int(
        math.ceil(float(args.total_images) / args.batch_size))
    boundaries = [steps_per_epoch * epoch for epoch in args.step_epochs]
    values = [args.lr * (0.1**power) for power in range(len(boundaries) + 1)]

    schedule = fluid.layers.piecewise_decay(
        boundaries=boundaries, values=values)
    l2 = fluid.regularizer.L2Decay(args.l2_decay)
    return fluid.optimizer.Momentum(
        learning_rate=schedule,
        momentum=args.momentum_rate,
        regularization=l2)
def cosine_decay(args):
    """Return a Momentum optimizer driven by fluid's cosine decay.

    The schedule spans ``args.num_epochs`` epochs of
    ceil(total_images / batch_size) steps each, starting from ``args.lr``.
    """
    steps_per_epoch = int(
        math.ceil(float(args.total_images) / args.batch_size))
    schedule = fluid.layers.cosine_decay(
        learning_rate=args.lr,
        step_each_epoch=steps_per_epoch,
        epochs=args.num_epochs)
    l2 = fluid.regularizer.L2Decay(args.l2_decay)
    return fluid.optimizer.Momentum(
        learning_rate=schedule,
        momentum=args.momentum_rate,
        regularization=l2)
def create_optimizer(args):
    """Build the optimizer selected by ``args.lr_strategy``.

    Args:
        args: options object; ``lr_strategy`` must be "piecewise_decay" or
            "cosine_decay".

    Returns:
        The optimizer produced by the matching builder.

    Raises:
        ValueError: if ``args.lr_strategy`` names no known strategy. This
            used to fall through and return None, which only failed later
            when the caller invoked ``minimize`` on it.
    """
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    if args.lr_strategy == "cosine_decay":
        return cosine_decay(args)
    raise ValueError(
        "lr_strategy must be 'piecewise_decay' or 'cosine_decay', "
        "but received: {}".format(args.lr_strategy))
def compress(args):
    """Prune a classifier by a fixed ratio, then fine-tune and evaluate it.

    Steps: pick the dataset, build the model named by ``args.model`` in the
    default program, clone an evaluation program, add the optimizer,
    optionally load pretrained variables, prune every parameter whose name
    contains "_sep_weights" by 33%, then train for ``args.num_epochs``
    epochs, evaluating every ``args.test_period`` epochs.

    Args:
        args: parsed options. Reads ``data``, ``model``, ``use_gpu``,
            ``pretrained_model``, ``batch_size``, ``log_period``,
            ``num_epochs``, ``test_period`` and whatever
            ``create_optimizer`` needs.

    Raises:
        ValueError: if ``args.data`` is neither "mnist" nor "imagenet".
        AssertionError: if ``args.model`` is not in ``model_list``.
    """
    if args.data == "mnist":
        import paddle.dataset.mnist as reader
        train_reader = reader.train()
        val_reader = reader.test()
        class_dim = 10
        image_shape = [1, 28, 28]
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_reader = reader.train()
        val_reader = reader.val()
        class_dim = 1000
        image_shape = [3, 224, 224]
    else:
        raise ValueError("{} is not supported.".format(args.data))

    assert args.model in model_list, "{} is not in lists: {}".format(
        args.model, model_list)

    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    cost = fluid.layers.cross_entropy(input=out, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)

    # Clone before the optimizer is added so the eval program has no
    # backward/update ops.
    val_program = fluid.default_main_program().clone(for_test=True)
    opt = create_optimizer(args)
    opt.minimize(avg_cost)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if args.pretrained_model:

        def if_exist(var):
            # Only load variables that have a file in the pretrained dir.
            return os.path.exists(
                os.path.join(args.pretrained_model, var.name))

        fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)

    val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
    train_reader = paddle.batch(
        train_reader, batch_size=args.batch_size, drop_last=True)

    train_feeder = fluid.DataFeeder([image, label], place)
    val_feeder = fluid.DataFeeder(
        [image, label], place, program=val_program)

    def test(epoch, program):
        """Evaluate ``program`` on the validation set and log accuracies."""
        acc_top1_ns = []
        acc_top5_ns = []
        for batch_id, data in enumerate(val_reader()):
            start_time = time.time()
            # Evaluation goes through the feeder built for val_program.
            acc_top1_n, acc_top5_n = exe.run(
                program,
                feed=val_feeder.feed(data),
                fetch_list=[acc_top1.name, acc_top5.name])
            end_time = time.time()
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id,
                           np.mean(acc_top1_n),
                           np.mean(acc_top5_n), end_time - start_time))
            acc_top1_ns.append(np.mean(acc_top1_n))
            acc_top5_ns.append(np.mean(acc_top5_n))

        _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".
                     format(epoch,
                            np.mean(np.array(acc_top1_ns)),
                            np.mean(np.array(acc_top5_ns))))

    def train(epoch, program):
        """Run one data-parallel training epoch of ``program``."""
        build_strategy = fluid.BuildStrategy()
        exec_strategy = fluid.ExecutionStrategy()
        train_program = fluid.compiler.CompiledProgram(
            program).with_data_parallel(
                loss_name=avg_cost.name,
                build_strategy=build_strategy,
                exec_strategy=exec_strategy)

        for batch_id, data in enumerate(train_reader()):
            start_time = time.time()
            loss_n, acc_top1_n, acc_top5_n = exe.run(
                train_program,
                feed=train_feeder.feed(data),
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            end_time = time.time()
            loss_n = np.mean(loss_n)
            acc_top1_n = np.mean(acc_top1_n)
            acc_top5_n = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
                           end_time - start_time))

    # Only parameters whose name contains "_sep_weights" are pruned.
    params = [
        param.name
        for param in fluid.default_main_program().global_block()
        .all_parameters() if "_sep_weights" in param.name
    ]
    prune_ratio = 0.33
    ratios = [prune_ratio] * len(params)

    _logger.info("FLOPs before pruning: {}".format(
        flops(fluid.default_main_program())))

    pruner = Pruner()
    # Graph-only pruning for evaluation: the training prune below is the
    # one that is not restricted to the graph.
    pruned_val_program = pruner.prune(
        val_program,
        fluid.global_scope(),
        params=params,
        ratios=ratios,
        place=place,
        only_graph=True)
    pruned_program = pruner.prune(
        fluid.default_main_program(),
        fluid.global_scope(),
        params=params,
        ratios=ratios,
        place=place)

    _logger.info("FLOPs after pruning: {}".format(flops(pruned_program)))

    for i in range(args.num_epochs):
        train(i, pruned_program)
        if i % args.test_period == 0:
            test(i, pruned_val_program)
def main():
    """Parse the command line, echo the options, and run ``compress``."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    compress(cli_args)


# Run the pruning demo only when executed as a script.
if __name__ == '__main__':
    main()
# Embedding量化示例
本示例介绍如何使用Embedding量化的接口 [paddleslim.quant.quant_embedding]()。``quant_embedding``接口将网络中的Embedding参数从``float32``类型量化到 ``8-bit``整数类型,在几乎不损失模型精度的情况下减少模型的存储空间和显存占用。
接口如下:
```
quant_embedding(program, place, config, scope=None)
```
参数介绍:
- program(fluid.Program) : 需要量化的program
- scope(fluid.Scope, optional) : 用来获取和写入``Variable``, 如果设置为``None``,则使用``fluid.global_scope()``.
- place(fluid.CPUPlace or fluid.CUDAPlace): 运行program的设备
- config(dict) : 定义量化的配置。可以配置的参数有:
- ``'params_name'`` (str, required): 需要进行量化的参数名称,此参数必须设置。
- ``'quantize_type'`` (str, optional): 量化的类型,目前支持的类型是``'abs_max'``, 待支持的类型有 ``'log', 'product_quantization'``。 默认值是``'abs_max'``.
- ``'quantize_bits'``(int, optional): 量化的``bit``数,目前支持的``bit``数为8。默认值是8.
- ``'dtype'``(str, optional): 量化之后的数据类型, 目前支持的是``'int8'``. 默认值是``int8``
- ``'threshold'``(float, optional): 量化之前将根据此阈值对需要量化的参数值进行``clip``. 如果不设置,则跳过``clip``过程直接量化。
该接口对program的修改:
量化前:
<p align="center">
<img src="./image/before.png" height=200 width=100 hspace='10'/> <br />
<strong>图1:量化前的模型结构</strong>
</p>
量化后:
<p align="center">
<img src="./image/after.png" height=300 width=300 hspace='10'/> <br />
<strong>图2: 量化后的模型结构</strong>
</p>
以下将以 ``基于skip-gram的word2vector模型`` 为例来说明如何使用``quant_embedding``接口。首先介绍 ``基于skip-gram的word2vector模型`` 的正常训练和测试流程。
## 基于skip-gram的word2vector模型
以下是本例的简要目录结构及说明:
```text
.
├── cluster_train.py # 分布式训练函数
├── cluster_train.sh # 本地模拟多机脚本
├── train.py # 训练函数
├── infer.py # 预测脚本
├── net.py # 网络结构
├── preprocess.py # 预处理脚本,包括构建词典和预处理文本
├── reader.py # 训练阶段的文本读写
├── train.py # 训练函数
└── utils.py # 通用函数
```
### 介绍
本例实现了skip-gram模式的word2vector模型。
同时推荐用户参考[ IPython Notebook demo](https://aistudio.baidu.com/aistudio/projectDetail/124377)
### 数据下载
全量数据集使用的是来自 [1 Billion Word Language Model Benchmark](http://www.statmt.org/lm-benchmark) 的数据集。
```bash
mkdir data
wget http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
tar xzvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz
mv 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ data/
```
备用数据地址下载命令如下
```bash
mkdir data
wget https://paddlerec.bj.bcebos.com/word2vec/1-billion-word-language-modeling-benchmark-r13output.tar
tar xvf 1-billion-word-language-modeling-benchmark-r13output.tar
mv 1-billion-word-language-modeling-benchmark-r13output/training-monolingual.tokenized.shuffled/ data/
```
为了方便快速验证,我们也提供了经典的text8样例数据集,包含1700w个词。 下载命令如下
```bash
mkdir data
wget https://paddlerec.bj.bcebos.com/word2vec/text.tar
tar xvf text.tar
mv text data/
```
### 数据预处理
以样例数据集为例进行预处理。全量数据集注意解压后以training-monolingual.tokenized.shuffled 目录为预处理目录,和样例数据集的text目录并列。
词典格式: 词<空格>词频。注意低频词用'UNK'表示
也可以按上述格式自建词典;如果使用自建词典,则跳过第一步。
```
the 1061396
of 593677
and 416629
one 411764
in 372201
a 325873
<UNK> 324608
to 316376
zero 264975
nine 250430
```
第一步根据英文语料生成词典,中文语料可以通过修改text_strip方法自定义处理方法。
```bash
python preprocess.py --build_dict --build_dict_corpus_dir data/text/ --dict_path data/test_build_dict
```
第二步根据词典将文本转成id, 同时进行downsample,按照概率过滤常见词, 同时生成word和id映射的文件,文件名为词典+"_word_to_id_"。
```bash
python preprocess.py --filter_corpus --dict_path data/test_build_dict --input_corpus_dir data/text --output_corpus_dir data/convert_text8 --min_count 5 --downsample 0.001
```
### 训练
具体的参数配置可运行
```bash
python train.py -h
```
单机多线程训练
```bash
OPENBLAS_NUM_THREADS=1 CPU_NUM=5 python train.py --train_data_dir data/convert_text8 --dict_path data/test_build_dict --num_passes 10 --batch_size 100 --model_output_dir v1_cpu5_b100_lr1dir --base_lr 1.0 --print_batch 1000 --with_speed --is_sparse
```
本地单机模拟多机训练
```bash
sh cluster_train.sh
```
本示例中按照单机多线程训练的命令进行训练,训练完毕后,可看到在当前文件夹下保存模型的路径为: ``v1_cpu5_b100_lr1dir``, 运行 ``ls v1_cpu5_b100_lr1dir``可看到该文件夹下保存了训练的10个epoch的模型文件。
```
pass-0 pass-1 pass-2 pass-3 pass-4 pass-5 pass-6 pass-7 pass-8 pass-9
```
### 预测
测试集下载命令如下
```bash
#全量数据集测试集
wget https://paddlerec.bj.bcebos.com/word2vec/test_dir.tar
#样本数据集测试集
wget https://paddlerec.bj.bcebos.com/word2vec/test_mid_dir.tar
```
预测命令,注意词典名称需要加后缀"_word_to_id_", 此文件是预处理阶段生成的。
```bash
python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size 20000 --model_dir v1_cpu5_b100_lr1dir/ --start_index 0 --last_index 9
```
运行该预测命令, 可看到如下输出
```
('start index: ', 0, ' last_index:', 9)
('vocab_size:', 63642)
step:1 249
epoch:0 acc:0.014
step:1 590
epoch:1 acc:0.033
step:1 982
epoch:2 acc:0.055
step:1 1338
epoch:3 acc:0.075
step:1 1653
epoch:4 acc:0.093
step:1 1914
epoch:5 acc:0.107
step:1 2204
epoch:6 acc:0.124
step:1 2416
epoch:7 acc:0.136
step:1 2606
epoch:8 acc:0.146
step:1 2722
epoch:9 acc:0.153
```
## 量化``基于skip-gram的word2vector模型``
量化配置为:
```
config = {
'params_name': 'emb',
'quantize_type': 'abs_max'
}
```
运行命令为:
```bash
python infer.py --infer_epoch --test_dir data/test_mid_dir --dict_path data/test_build_dict_word_to_id_ --batch_size 20000 --model_dir v1_cpu5_b100_lr1dir/ --start_index 0 --last_index 9 --emb_quant True
```
运行输出为:
```
('start index: ', 0, ' last_index:', 9)
('vocab_size:', 63642)
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 253
epoch:0 acc:0.014
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 586
epoch:1 acc:0.033
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 970
epoch:2 acc:0.054
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 1364
epoch:3 acc:0.077
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 1642
epoch:4 acc:0.092
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 1936
epoch:5 acc:0.109
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 2216
epoch:6 acc:0.124
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 2419
epoch:7 acc:0.136
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 2603
epoch:8 acc:0.146
quant_embedding config {'quantize_type': 'abs_max', 'params_name': 'emb', 'quantize_bits': 8, 'dtype': 'int8'}
step:1 2719
epoch:9 acc:0.153
```
量化后的模型保存在``./output_quant``中,可看到量化后的参数``'emb.int8'``的大小为3.9M, 在``./v1_cpu5_b100_lr1dir``中可看到量化前的参数``'emb'``的大小为16M。
from __future__ import print_function
import argparse
import logging
import os
import time
import math
import random
import numpy as np
import paddle
import paddle.fluid as fluid
import six
import reader
from net import skip_gram_word2vec
# Configure the root handler's message format, then use the "fluid" logger
# at INFO level for this script's messages.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
    """Parse the command-line options of the distributed word2vec trainer.

    Returns:
        argparse.Namespace: data paths, optimizer/batching settings and the
        pserver/trainer cluster layout.
    """
    parser = argparse.ArgumentParser(
        description="PaddlePaddle Word2vec example")
    parser.add_argument(
        '--train_data_dir',
        type=str,
        default='./data/text',
        help="The path of training dataset")
    parser.add_argument(
        '--base_lr',
        type=float,
        default=0.01,
        help="The base learning rate (default: 0.01)")
    parser.add_argument(
        '--save_step',
        type=int,
        default=500000,
        help="The number of step to save (default: 500000)")
    parser.add_argument(
        '--print_batch',
        type=int,
        default=100,
        help="The number of batches between two log lines (default: 100)")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./data/1-billion_dict',
        help="The path of data dict")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=500,
        help="The size of mini-batch (default:500)")
    parser.add_argument(
        '--num_passes',
        type=int,
        default=10,
        help="The number of passes to train (default: 10)")
    parser.add_argument(
        '--model_output_dir',
        type=str,
        default='models',
        help='The path for model to store (default: models)')
    parser.add_argument('--nce_num', type=int, default=5, help='nce_num')
    parser.add_argument(
        '--embedding_size',
        type=int,
        default=64,
        help='The size of the word embedding vectors (default: 64)')
    parser.add_argument(
        '--is_sparse',
        action='store_true',
        required=False,
        default=False,
        help='embedding and nce will use sparse or not, (default: False)')
    parser.add_argument(
        '--with_speed',
        action='store_true',
        required=False,
        default=False,
        help='print speed or not , (default: False)')
    parser.add_argument(
        '--role', type=str, default='pserver', help='trainer or pserver')
    parser.add_argument(
        '--endpoints',
        type=str,
        default='127.0.0.1:6000',
        help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001')
    parser.add_argument(
        '--current_endpoint',
        type=str,
        default='127.0.0.1:6000',
        help='The current_endpoint')
    parser.add_argument(
        '--trainer_id',
        type=int,
        default=0,
        help='trainer id ,only trainer_id=0 save model')
    parser.add_argument(
        '--trainers',
        type=int,
        default=1,
        help='The num of trainers, (default: 1)')
    return parser.parse_args()
def convert_python_to_tensor(weight, batch_size, sample_reader, nce_num=None):
    """Wrap a sample reader into a reader that yields batches of fluid tensors.

    Args:
        weight: per-word sampling weights; their cumulative sum is used to
            draw negative word ids.
        batch_size: number of samples grouped into one batch.
        sample_reader: callable returning an iterable of two-field samples.
        nce_num: number of negative ids drawn per batch. When ``None`` the
            value is read from the module-level ``args.nce_num``, which is
            what this function did before the parameter existed.

    Returns:
        A generator function. Each yielded item is a list of three
        ``fluid.Tensor`` objects: one per sample field plus the negative ids
        reshaped to ``(batch_size, nce_num)``. Samples left over after the
        last full batch are not yielded.
    """

    def __reader__():
        # Resolved lazily so existing callers, which rely on the global
        # ``args`` bound under ``__main__``, keep working unchanged.
        neg_num = args.nce_num if nce_num is None else nce_num
        cs = np.array(weight).cumsum()
        result = [[], []]
        for sample in sample_reader():
            for i, fea in enumerate(sample):
                result[i].append(fea)
            if len(result[0]) == batch_size:
                tensor_result = []
                for tensor in result:
                    t = fluid.Tensor()
                    dat = np.array(tensor, dtype='int64')
                    if len(dat.shape) > 2:
                        dat = dat.reshape((dat.shape[0], dat.shape[2]))
                    elif len(dat.shape) == 1:
                        dat = dat.reshape((-1, 1))
                    t.set(dat, fluid.CPUPlace())
                    tensor_result.append(t)
                tt = fluid.Tensor()
                # One draw of ``neg_num`` ids is tiled across the batch, so
                # every row of the batch shares the same negatives.
                neg_array = cs.searchsorted(np.random.sample(neg_num))
                neg_array = np.tile(neg_array, batch_size)
                tt.set(
                    neg_array.reshape((batch_size, neg_num)),
                    fluid.CPUPlace())
                tensor_result.append(tt)
                yield tensor_result
                result = [[], []]

    return __reader__
def train_loop(args, train_program, reader, py_reader, loss, trainer_id,
               weight):
    """Run the trainer side of the training loop on CPU.

    Args:
        args: parsed options; ``batch_size``, ``num_passes``, ``print_batch``,
            ``with_speed``, ``save_step`` and ``model_output_dir`` are read.
        train_program: the program executed at every step.
        reader: object whose ``train()`` returns the sample reader.
        py_reader: reader op that is fed the batched tensors.
        loss: loss variable; its name is fetched at every step.
        trainer_id: only trainer 0 saves parameters.
        weight: negative-sampling weights passed to
            ``convert_python_to_tensor``.
    """
    py_reader.decorate_tensor_provider(
        convert_python_to_tensor(weight, args.batch_size, reader.train()))

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    print("CPU_NUM:" + str(os.getenv("CPU_NUM")))
    # CPU_NUM may be unset; default to 1 instead of crashing on int(None)
    # when --with_speed is enabled.
    cpu_num = int(os.getenv("CPU_NUM", "1"))
    # Throughput is logged every `speed_interval` batches, so that is also
    # the number of batches accounted for in each report.
    speed_interval = 500

    train_exe = exe

    for pass_id in range(args.num_passes):
        py_reader.start()
        # NOTE(review): presumably lets the reader thread pre-fill its queue
        # before timing starts; confirm before removing.
        time.sleep(10)
        epoch_start = time.time()
        batch_id = 0
        start = time.time()

        try:
            while True:
                loss_val = train_exe.run(
                    program=train_program, fetch_list=[loss.name])
                loss_val = np.mean(loss_val)

                if batch_id % args.print_batch == 0:
                    logger.info(
                        "TRAIN --> pass: {} batch: {} loss: {} reader queue:{}".
                        format(pass_id, batch_id, loss_val,
                               py_reader.queue.size()))
                if args.with_speed:
                    if batch_id % speed_interval == 0 and batch_id != 0:
                        elapsed = (time.time() - start)
                        start = time.time()
                        samples = speed_interval * args.batch_size * cpu_num
                        logger.info("Time used: {}, Samples/Sec: {}".format(
                            elapsed, samples / elapsed))

                if batch_id % args.save_step == 0 and batch_id != 0:
                    model_dir = args.model_output_dir + '/pass-' + str(
                        pass_id) + ('/batch-' + str(batch_id))
                    if trainer_id == 0:
                        fluid.io.save_params(executor=exe, dirname=model_dir)
                        print("model saved in %s" % model_dir)
                batch_id += 1

        except fluid.core.EOFException:
            # The reader is exhausted: end of this pass.
            py_reader.reset()
            epoch_end = time.time()
            logger.info("Epoch: {0}, Train total expend: {1} ".format(
                pass_id, epoch_end - epoch_start))
            model_dir = args.model_output_dir + '/pass-' + str(pass_id)
            if trainer_id == 0:
                fluid.io.save_params(executor=exe, dirname=model_dir)
                print("model saved in %s" % model_dir)
def GetFileList(data_path):
    """Return the names of the entries found directly under ``data_path``."""
    entries = os.listdir(data_path)
    return entries
def train(args):
    """Build the skip-gram graph, transpile it and run the selected role.

    Args:
        args: parsed options from ``parse_args``. ``args.role`` selects
            between running a parameter server (``"pserver"``) and running
            the training loop (``"trainer"``).

    Raises:
        ValueError: if ``args.role`` is neither ``"pserver"`` nor
            ``"trainer"``.
    """
    if args.trainer_id == 0:
        # trainer_id defaults to 0, so several processes may get here at the
        # same time; tolerate the directory having been created meanwhile.
        # makedirs also handles nested output paths.
        try:
            os.makedirs(args.model_output_dir)
        except OSError:
            if not os.path.isdir(args.model_output_dir):
                raise

    filelist = GetFileList(args.train_data_dir)
    word2vec_reader = reader.Word2VecReader(
        args.dict_path, args.train_data_dir, filelist, 0, 1)

    logger.info("dict_size: {}".format(word2vec_reader.dict_size))
    # Negative-sampling distribution: frequencies raised to 0.75, normalised.
    np_power = np.power(np.array(word2vec_reader.id_frequencys), 0.75)
    id_frequencys_pow = np_power / np_power.sum()

    loss, py_reader = skip_gram_word2vec(
        word2vec_reader.dict_size,
        args.embedding_size,
        is_sparse=args.is_sparse,
        neg_num=args.nce_num)

    optimizer = fluid.optimizer.SGD(
        learning_rate=fluid.layers.exponential_decay(
            learning_rate=args.base_lr,
            decay_steps=100000,
            decay_rate=0.999,
            staircase=True))

    optimizer.minimize(loss)

    logger.info("run dist training")

    t = fluid.DistributeTranspiler()
    t.transpile(
        args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
    if args.role == "pserver":
        print("run psever")
        pserver_prog = t.get_pserver_program(args.current_endpoint)
        pserver_startup = t.get_startup_program(args.current_endpoint,
                                                pserver_prog)
        exe = fluid.Executor(fluid.CPUPlace())
        exe.run(pserver_startup)
        exe.run(pserver_prog)
    elif args.role == "trainer":
        print("run trainer")
        train_loop(args,
                   t.get_trainer_program(), word2vec_reader, py_reader, loss,
                   args.trainer_id, id_frequencys_pow)
    else:
        # Previously an unknown role built the whole graph and then exited
        # silently without doing any work.
        raise ValueError(
            "unknown role %r, expected 'pserver' or 'trainer'" % args.role)
if __name__ == '__main__':
    # `args` is deliberately bound at module level: convert_python_to_tensor
    # may read `args.nce_num` from this global rather than from a parameter.
    args = parse_args()
    train(args)
#!/bin/bash
# Launch a local word2vec cluster made of two parameter servers and two
# trainers. Every process is started in the background and writes its
# stdout/stderr to its own <role><id>.log file.
#export GLOG_v=30
#export GLOG_logtostderr=1
# start pserver0
export CPU_NUM=5
export FLAGS_rpc_deadline=3000000
python cluster_train.py \
--train_data_dir data/convert_text8 \
--dict_path data/test_build_dict \
--batch_size 100 \
--model_output_dir dis_model \
--base_lr 1.0 \
--print_batch 1 \
--is_sparse \
--with_speed \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6000 \
--trainers 2 \
> pserver0.log 2>&1 &
# start pserver1 (same options, listening on the second endpoint)
python cluster_train.py \
--train_data_dir data/convert_text8 \
--dict_path data/test_build_dict \
--batch_size 100 \
--model_output_dir dis_model \
--base_lr 1.0 \
--print_batch 1 \
--is_sparse \
--with_speed \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6001 \
--trainers 2 \
> pserver1.log 2>&1 &
# start trainer0
python cluster_train.py \
--train_data_dir data/convert_text8 \
--dict_path data/test_build_dict \
--batch_size 100 \
--model_output_dir dis_model \
--base_lr 1.0 \
--print_batch 1000 \
--is_sparse \
--with_speed \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 0 \
> trainer0.log 2>&1 &
# start trainer1
python cluster_train.py \
--train_data_dir data/convert_text8 \
--dict_path data/test_build_dict \
--batch_size 100 \
--model_output_dir dis_model \
--base_lr 1.0 \
--print_batch 1000 \
--is_sparse \
--with_speed \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 1 \
> trainer1.log 2>&1 &
import argparse
import sys
import time
import math
import unittest
import contextlib
import numpy as np
import six
import paddle.fluid as fluid
import paddle
import net
import utils
sys.path.append(sys.path[0] + "/../../../")
from paddleslim.quant import quant_embedding
def _str2bool(value):
    """Convert a command-line token such as ``True`` or ``false`` to a bool."""
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % (value, ))


def parse_args():
    """Parse the command-line options of the word2vec inference script.

    Returns:
        argparse.Namespace: the parsed options.
    """
    # The text is a description; passing it positionally would set `prog`.
    parser = argparse.ArgumentParser(
        description="PaddlePaddle Word2vec infer example")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./data/data_c/1-billion_dict_word_to_id_',
        help="The path of dic")
    parser.add_argument(
        '--infer_epoch',
        action='store_true',
        required=False,
        default=False,
        help='infer by epoch')
    parser.add_argument(
        '--infer_step',
        action='store_true',
        required=False,
        default=False,
        help='infer by step')
    parser.add_argument(
        '--test_dir', type=str, default='test_data', help='test file address')
    parser.add_argument(
        '--print_step', type=int, default=500000, help='print step')
    parser.add_argument(
        '--start_index', type=int, default=0, help='start index')
    parser.add_argument(
        '--start_batch', type=int, default=1, help='start batch')
    parser.add_argument(
        '--end_batch', type=int, default=13, help='end batch (exclusive)')
    parser.add_argument(
        '--last_index', type=int, default=100, help='last index')
    parser.add_argument(
        '--model_dir', type=str, default='model', help='model dir')
    parser.add_argument(
        '--use_cuda', type=int, default=0, help='whether use cuda')
    parser.add_argument(
        '--batch_size', type=int, default=5, help='batch_size')
    parser.add_argument(
        '--emb_size', type=int, default=64, help='embedding size')
    parser.add_argument(
        '--emb_quant',
        # type=bool would turn any non-empty string, including "False",
        # into True.
        type=_str2bool,
        default=False,
        help='whether to quant embedding parameter')

    args = parser.parse_args()
    return args
def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w):
    """Evaluate word-analogy accuracy for every per-epoch checkpoint.

    For each epoch in ``[args.start_index, args.last_index]`` the parameters
    under ``<args.model_dir>/pass-<epoch>`` are loaded into a clone of the
    inference program. When ``args.emb_quant`` is set the ``emb`` parameter
    is quantized and the program's persistables are saved under
    ``./output_quant/pass-<epoch>`` before evaluation.

    Args:
        args: parsed options; ``emb_size``, ``start_index``, ``last_index``,
            ``model_dir`` and ``emb_quant`` are read.
        vocab_size: number of words in the vocabulary.
        test_reader: callable returning an iterable of batches; each sample
            holds three input ids, the label and the list of input ids.
        use_cuda: run on ``CUDAPlace(0)`` when true, otherwise on CPU.
        i2w: id-to-word mapping (not used by this function).
    """
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    emb_size = args.emb_size
    # The candidate ids are identical for every step; build the feed once.
    all_label = np.arange(vocab_size).reshape(vocab_size, 1).astype("int64")

    with fluid.scope_guard(fluid.Scope()):
        main_program = fluid.Program()
        with fluid.program_guard(main_program):
            values, pred = net.infer_network(vocab_size, emb_size)
            # Use the values carried by `args` rather than module globals
            # that only exist when this file is run as a script.
            for epoch in range(args.start_index, args.last_index + 1):
                copy_program = main_program.clone()
                model_path = args.model_dir + "/pass-" + str(epoch)
                fluid.io.load_params(
                    executor=exe,
                    dirname=model_path,
                    main_program=copy_program)
                if args.emb_quant:
                    config = {'params_name': 'emb', 'quantize_type': 'abs_max'}
                    copy_program = quant_embedding(copy_program, place, config)
                    fluid.io.save_persistables(
                        exe,
                        './output_quant/pass-' + str(epoch),
                        main_program=copy_program)

                accum_num = 0
                accum_num_sum = 0.0
                step_id = 0
                for data in test_reader():
                    step_id += 1
                    b_size = len(data)
                    wa = np.array(
                        [dat[0] for dat in data]).astype("int64").reshape(
                            b_size, 1)
                    wb = np.array(
                        [dat[1] for dat in data]).astype("int64").reshape(
                            b_size, 1)
                    wc = np.array(
                        [dat[2] for dat in data]).astype("int64").reshape(
                            b_size, 1)

                    label = [dat[3] for dat in data]
                    input_word = [dat[4] for dat in data]
                    para = exe.run(copy_program,
                                   feed={
                                       "analogy_a": wa,
                                       "analogy_b": wb,
                                       "analogy_c": wc,
                                       "all_label": all_label,
                                   },
                                   fetch_list=[pred.name, values],
                                   return_numpy=False)
                    pre = np.array(para[0])
                    for ii in range(len(label)):
                        top4 = pre[ii]
                        accum_num_sum += 1
                        # Only the best-ranked prediction that is not one of
                        # the three input words is compared with the label.
                        for idx in top4:
                            if int(idx) in input_word[ii]:
                                continue
                            if int(idx) == int(label[ii][0]):
                                accum_num += 1
                            break
                    print("step:%d %d " % (step_id, accum_num))

                # An empty test set would otherwise divide by zero.
                acc = 1.0 * accum_num / accum_num_sum if accum_num_sum else 0.0
                print("epoch:%d \t acc:%.3f " % (epoch, acc))
def infer_step(args, vocab_size, test_reader, use_cuda, i2w):
    """Evaluate word-analogy accuracy for checkpoints saved inside epochs.

    For each epoch in ``[args.start_index, args.last_index]`` and each
    ``batchid`` in ``[args.start_batch, args.end_batch)`` the parameters
    under ``<args.model_dir>/pass-<epoch>/batch-<batchid * args.print_step>``
    are loaded and evaluated on the whole test set.

    Args:
        args: parsed options; ``emb_size``, ``start_index``, ``last_index``,
            ``start_batch``, ``end_batch``, ``print_step`` and ``model_dir``
            are read.
        vocab_size: number of words in the vocabulary.
        test_reader: callable returning an iterable of batches; each sample
            holds three input ids, the label and the list of input ids.
        use_cuda: run on ``CUDAPlace(0)`` when true, otherwise on CPU.
        i2w: id-to-word mapping (not used by this function).
    """
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)
    emb_size = args.emb_size
    # Built once, and cast to int64 like the other id feeds: np.arange
    # yields the platform default integer type.
    all_label = np.arange(vocab_size).reshape(vocab_size, 1).astype("int64")

    with fluid.scope_guard(fluid.Scope()):
        main_program = fluid.Program()
        with fluid.program_guard(main_program):
            values, pred = net.infer_network(vocab_size, emb_size)
            # Use the values carried by `args` rather than module globals
            # that only exist when this file is run as a script.
            for epoch in range(args.start_index, args.last_index + 1):
                for batchid in range(args.start_batch, args.end_batch):
                    copy_program = main_program.clone()
                    model_path = args.model_dir + "/pass-" + str(epoch) + (
                        '/batch-' + str(batchid * args.print_step))
                    fluid.io.load_params(
                        executor=exe,
                        dirname=model_path,
                        main_program=copy_program)
                    accum_num = 0
                    accum_num_sum = 0.0
                    step_id = 0
                    for data in test_reader():
                        step_id += 1
                        b_size = len(data)
                        wa = np.array(
                            [dat[0] for dat in data]).astype("int64").reshape(
                                b_size, 1)
                        wb = np.array(
                            [dat[1] for dat in data]).astype("int64").reshape(
                                b_size, 1)
                        wc = np.array(
                            [dat[2] for dat in data]).astype("int64").reshape(
                                b_size, 1)

                        label = [dat[3] for dat in data]
                        input_word = [dat[4] for dat in data]
                        para = exe.run(
                            copy_program,
                            feed={
                                "analogy_a": wa,
                                "analogy_b": wb,
                                "analogy_c": wc,
                                "all_label": all_label,
                            },
                            fetch_list=[pred.name, values],
                            return_numpy=False)
                        pre = np.array(para[0])
                        for ii in range(len(label)):
                            top4 = pre[ii]
                            accum_num_sum += 1
                            # Only the best-ranked prediction that is not one
                            # of the three input words is compared with the
                            # label.
                            for idx in top4:
                                if int(idx) in input_word[ii]:
                                    continue
                                if int(idx) == int(label[ii][0]):
                                    accum_num += 1
                                break
                        print("step:%d %d " % (step_id, accum_num))

                    # An empty test set would otherwise divide by zero.
                    acc = (1.0 * accum_num / accum_num_sum
                           if accum_num_sum else 0.0)
                    print("epoch:%d \t acc:%.3f " % (epoch, acc))
if __name__ == "__main__":
    args = parse_args()
    # These names are bound at module level; keep them so that code looking
    # them up as globals keeps working.
    start_index, last_index = args.start_index, args.last_index
    test_dir, model_dir = args.test_dir, args.model_dir
    batch_size, dict_path = args.batch_size, args.dict_path
    use_cuda = bool(args.use_cuda)
    print("start index: ", start_index, " last_index:", last_index)
    vocab_size, test_reader, id2word = utils.prepare_data(
        test_dir, dict_path, batch_size=batch_size)
    print("vocab_size:", vocab_size)
    # Per-step checkpoints are evaluated only when --infer_step is given;
    # otherwise per-epoch checkpoints are evaluated.
    run_inference = infer_step if args.infer_step else infer_epoch
    run_inference(
        args,
        vocab_size,
        test_reader=test_reader,
        use_cuda=use_cuda,
        i2w=id2word)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
neural network for word2vec
"""
from __future__ import print_function
import math
import numpy as np
import paddle.fluid as fluid
def skip_gram_word2vec(dict_size, embedding_size, is_sparse=False, neg_num=5):
    """Build the skip-gram training graph with explicitly fed negatives.

    Three int64 inputs are read through a py_reader: the input word, the
    true context word and ``neg_num`` negative words. The input word is
    looked up in the ``emb`` table; the context and negative words are
    looked up in the shared ``emb_w`` (weights) and ``emb_b`` (bias) tables.

    Args:
        dict_size: number of rows of every embedding table.
        embedding_size: width of the ``emb`` and ``emb_w`` tables.
        is_sparse: forwarded to every embedding layer.
        neg_num: number of negative words per sample.

    Returns:
        tuple: ``(avg_cost, py_reader)``, the mean over the batch of the
        summed sigmoid cross-entropy terms and the reader feeding the graph.
    """
    layers = fluid.layers

    input_word = layers.data(name="input_word", shape=[1], dtype='int64')
    true_word = layers.data(name='true_label', shape=[1], dtype='int64')
    neg_word = layers.data(name="neg_label", shape=[neg_num], dtype='int64')
    feed_vars = [input_word, true_word, neg_word]

    py_reader = layers.create_py_reader_by_data(
        capacity=64,
        feed_list=feed_vars,
        name='py_reader',
        use_double_buffer=True)
    center_ids, context_ids, negative_ids = layers.read_file(py_reader)

    init_width = 0.5 / embedding_size
    input_emb = layers.embedding(
        input=center_ids,
        is_sparse=is_sparse,
        size=[dict_size, embedding_size],
        param_attr=fluid.ParamAttr(
            name='emb',
            initializer=fluid.initializer.Uniform(-init_width, init_width)))

    true_emb_w = layers.embedding(
        input=context_ids,
        is_sparse=is_sparse,
        size=[dict_size, embedding_size],
        param_attr=fluid.ParamAttr(
            name='emb_w', initializer=fluid.initializer.Constant(value=0.0)))

    true_emb_b = layers.embedding(
        input=context_ids,
        is_sparse=is_sparse,
        size=[dict_size, 1],
        param_attr=fluid.ParamAttr(
            name='emb_b', initializer=fluid.initializer.Constant(value=0.0)))

    # Flatten the negatives to one id per row for the lookups below.
    neg_word_reshape = layers.reshape(negative_ids, shape=[-1, 1])
    neg_word_reshape.stop_gradient = True

    neg_emb_w = layers.embedding(
        input=neg_word_reshape,
        is_sparse=is_sparse,
        size=[dict_size, embedding_size],
        param_attr=fluid.ParamAttr(
            name='emb_w', learning_rate=1.0))
    neg_emb_w_re = layers.reshape(
        neg_emb_w, shape=[-1, neg_num, embedding_size])

    neg_emb_b = layers.embedding(
        input=neg_word_reshape,
        is_sparse=is_sparse,
        size=[dict_size, 1],
        param_attr=fluid.ParamAttr(
            name='emb_b', learning_rate=1.0))
    neg_emb_b_vec = layers.reshape(neg_emb_b, shape=[-1, neg_num])

    # Logit of the true pair: <input_emb, true_emb_w> + bias.
    true_product = layers.elementwise_mul(input_emb, true_emb_w)
    true_dot = layers.reduce_sum(true_product, dim=1, keep_dim=True)
    true_logits = layers.elementwise_add(true_dot, true_emb_b)

    # Logits of the negative pairs, one column per negative word.
    input_emb_re = layers.reshape(input_emb, shape=[-1, 1, embedding_size])
    neg_matmul = layers.matmul(input_emb_re, neg_emb_w_re, transpose_y=True)
    neg_matmul_re = layers.reshape(neg_matmul, shape=[-1, neg_num])
    neg_logits = layers.elementwise_add(neg_matmul_re, neg_emb_b_vec)

    # nce loss: target 1 for the true pair, 0 for every negative pair.
    label_ones = layers.fill_constant_batch_size_like(
        true_logits, shape=[-1, 1], value=1.0, dtype='float32')
    label_zeros = layers.fill_constant_batch_size_like(
        true_logits, shape=[-1, neg_num], value=0.0, dtype='float32')

    true_xent = layers.sigmoid_cross_entropy_with_logits(true_logits,
                                                         label_ones)
    neg_xent = layers.sigmoid_cross_entropy_with_logits(neg_logits,
                                                        label_zeros)
    true_cost = layers.reduce_sum(true_xent, dim=1)
    neg_cost = layers.reduce_sum(neg_xent, dim=1)
    cost = layers.elementwise_add(true_cost, neg_cost)
    avg_cost = layers.reduce_mean(cost)
    return avg_cost, py_reader
def infer_network(vocab_size, emb_size):
    """Build the word-analogy inference graph on top of the ``emb`` table.

    The query vector is ``emb[b] - emb[a] + emb[c]``; it is multiplied with
    the L2-normalised embeddings of every id fed through ``all_label`` and
    the four best-scoring candidates are returned.

    Args:
        vocab_size: number of rows of the ``emb`` table.
        emb_size: width of the ``emb`` table.

    Returns:
        tuple: ``(values, pred_idx)`` as produced by ``topk`` with ``k=4``.
    """
    layers = fluid.layers

    def lookup(ids):
        # Every lookup shares the parameter named "emb".
        return layers.embedding(
            input=ids, size=[vocab_size, emb_size], param_attr="emb")

    analogy_a = layers.data(name="analogy_a", shape=[1], dtype='int64')
    analogy_b = layers.data(name="analogy_b", shape=[1], dtype='int64')
    analogy_c = layers.data(name="analogy_c", shape=[1], dtype='int64')
    all_label = layers.data(
        name="all_label",
        shape=[vocab_size, 1],
        dtype='int64',
        append_batch_size=False)

    emb_all_label = lookup(all_label)
    emb_a = lookup(analogy_a)
    emb_b = lookup(analogy_b)
    emb_c = lookup(analogy_c)

    offset = layers.elementwise_sub(emb_b, emb_a)
    target = layers.elementwise_add(offset, emb_c)

    emb_all_label_l2 = layers.l2_normalize(x=emb_all_label, axis=1)
    dist = layers.matmul(x=target, y=emb_all_label_l2, transpose_y=True)
    values, pred_idx = layers.topk(input=dist, k=4)
    return values, pred_idx
# -*- coding: utf-8 -*
import os
import random
import re
import six
import argparse
import io
import math
# Matches every character that is neither a lowercase ASCII letter nor a
# space; text_strip removes these matches from the lower-cased text.
prog = re.compile("[^a-z ]", flags=0)
def parse_args():
    """Parse the command-line options of the corpus preprocessing script.

    Returns:
        argparse.Namespace: corpus/dictionary paths, filtering thresholds
        and the ``--build_dict`` / ``--filter_corpus`` mode switches.
    """
    parser = argparse.ArgumentParser(
        description="Paddle Fluid word2 vector preprocess")
    parser.add_argument(
        '--build_dict_corpus_dir', type=str, help="The dir of corpus")
    parser.add_argument(
        '--input_corpus_dir', type=str, help="The dir of input corpus")
    parser.add_argument(
        '--output_corpus_dir', type=str, help="The dir of output corpus")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./dict',
        help="The path of dictionary ")
    parser.add_argument(
        '--min_count',
        type=int,
        default=5,
        # build_dict compares with `<=`, so the threshold itself is removed.
        help="If the word count is less than or equal to min_count, it will "
        "be removed from dict")
    parser.add_argument(
        '--downsample',
        type=float,
        default=0.001,
        help="filter word by downsample")
    parser.add_argument(
        '--filter_corpus',
        action='store_true',
        default=False,
        help='Filter corpus')
    parser.add_argument(
        '--build_dict',
        action='store_true',
        default=False,
        help='Build dict from corpus')
    return parser.parse_args()
def text_strip(text):
    """Lower-case ``text`` and delete every match of the ``prog`` pattern."""
    # English preprocess rule
    lowered = text.lower()
    return prog.sub("", lowered)
# Shameless copy from Tensorflow https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/data_generators/text_encoder.py
# Unicode utility functions that work with Python 2 and 3
def native_to_unicode(s):
    """Return ``s`` as unicode text, dropping undecodable bytes if needed.

    Unicode input is returned as is. Otherwise a strict conversion is tried
    first and, on ``UnicodeDecodeError``, repeated with errors ignored.
    """
    if not _is_unicode(s):
        try:
            s = _to_unicode(s)
        except UnicodeDecodeError:
            s = _to_unicode(s, ignore_errors=True)
    return s
def _is_unicode(s):
    """Return True when ``s`` is text: ``unicode`` on Python 2, ``str`` on 3."""
    # The conditional expression only evaluates `unicode` under Python 2,
    # where that name exists.
    text_type = unicode if six.PY2 else str
    return isinstance(s, text_type)
def _to_unicode(s, ignore_errors=False):
    """Decode ``s`` from UTF-8 unless it already is unicode text.

    Args:
        s: the value to convert.
        ignore_errors: decode with ``errors="ignore"`` instead of
            ``errors="strict"``.
    """
    if not _is_unicode(s):
        return s.decode(
            "utf-8", errors="ignore" if ignore_errors else "strict")
    return s
def filter_corpus(args):
    """Down-sample a corpus and convert its words to ids.

    The dictionary at ``args.dict_path`` (one ``word count`` pair per line)
    assigns ids in line order and is also written back as
    ``<dict_path>_word_to_id_``. Every file of ``args.input_corpus_dir`` is
    then normalised with ``text_strip``, each word is replaced by its id
    (``<UNK>`` for words missing from the dictionary), words are randomly
    dropped according to ``args.downsample`` and the result is written to
    ``<args.output_corpus_dir>/convert_<file name>``.

    Args:
        args: parsed options; ``dict_path``, ``input_corpus_dir``,
            ``output_corpus_dir`` and ``downsample`` are read.
    """
    word_count = dict()
    word_to_id_ = dict()
    word_all_count = 0
    id_counts = []
    word_id = 0
    #read dict
    with io.open(args.dict_path, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            word, count = fields[0], int(fields[1])
            word_count[word] = count
            word_to_id_[word] = word_id
            word_id += 1
            id_counts.append(count)
            word_all_count += count

    #write word2id file
    print("write word2id file to : " + args.dict_path + "_word_to_id_")
    with io.open(
            args.dict_path + "_word_to_id_", 'w+', encoding='utf-8') as fid:
        for k, v in word_to_id_.items():
            fid.write(k + " " + str(v) + '\n')

    #filter corpus and convert id
    if not os.path.exists(args.output_corpus_dir):
        os.makedirs(args.output_corpus_dir)
    corpus_size = word_all_count
    # Loop-invariant part of the keep probability.
    threshold = args.downsample * corpus_size
    for name in os.listdir(args.input_corpus_dir):
        with io.open(
                args.output_corpus_dir + '/convert_' + name,
                "w",
                encoding='utf-8') as wf:
            with io.open(
                    args.input_corpus_dir + '/' + name,
                    encoding='utf-8') as rf:
                print(args.input_corpus_dir + '/' + name)
                for line in rf:
                    signal = False
                    line = text_strip(line)
                    words = line.split()
                    for item in words:
                        if item in word_count:
                            idx = word_to_id_[item]
                        else:
                            idx = word_to_id_[native_to_unicode('<UNK>')]
                        count_w = id_counts[idx]
                        if count_w > 0:
                            keep_prob = (
                                math.sqrt(count_w / threshold) + 1
                            ) * threshold / count_w
                        else:
                            # A zero count (e.g. an <UNK> entry written with
                            # count 0) used to divide by zero; always keep
                            # such words.
                            keep_prob = 1.0
                        r_value = random.random()
                        if r_value > keep_prob:
                            continue
                        wf.write(_to_unicode(str(idx) + " "))
                        signal = True
                    # Only lines that kept at least one word are terminated.
                    if signal:
                        wf.write(_to_unicode("\n"))
def build_dict(args):
    """Count the words of a corpus and write the ``word count`` dictionary.

    Every file under ``args.build_dict_corpus_dir`` is read line by line and
    normalised with ``text_strip``. Words whose count is less than or equal
    to ``args.min_count`` are removed and their counts are accumulated under
    the ``<UNK>`` entry. The result is written to ``args.dict_path``, one
    ``word count`` pair per line, ordered by descending count.

    Args:
        args: parsed options; ``build_dict_corpus_dir``, ``min_count`` and
            ``dict_path`` are read.
    """
    # word to count
    frequency = dict()
    for name in os.listdir(args.build_dict_corpus_dir):
        with io.open(
                args.build_dict_corpus_dir + "/" + name,
                encoding='utf-8') as f:
            print("build dict : ", args.build_dict_corpus_dir + "/" + name)
            for line in f:
                for token in text_strip(line).split():
                    frequency[token] = frequency.get(token, 0) + 1

    # Fold the rare words into a single <UNK> entry.
    rare_words = [w for w in frequency if frequency[w] <= args.min_count]
    unk_sum = 0
    for w in rare_words:
        unk_sum += frequency.pop(w)
    frequency[native_to_unicode('<UNK>')] = unk_sum

    #sort by count
    ranked = sorted(frequency.items(), key=lambda pair: -pair[1])
    with io.open(args.dict_path, 'w+', encoding='utf-8') as f:
        for word, count in ranked:
            f.write(word + " " + str(count) + '\n')
if __name__ == "__main__":
    args = parse_args()
    # --build_dict takes precedence when both switches are given.
    if not (args.build_dict or args.filter_corpus):
        print(
            "error command line, please choose --build_dict or --filter_corpus")
    elif args.build_dict:
        build_dict(args)
    else:
        filter_corpus(args)
# -*- coding: utf-8 -*
import numpy as np
import preprocess
import logging
import math
import random
import io
# Install a timestamped log format on the root logger, then use the shared
# "fluid" logger at INFO level for all messages emitted by this module.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
class NumpyRandomInt(object):
    """Callable returning random integers in ``[a, b]`` (both inclusive).

    Numbers are drawn ``buf_size`` at a time and handed out one per call;
    the buffer is refilled once it has been used up.
    """

    def __init__(self, a, b, buf_size=1000):
        self.idx = 0
        self.a = a
        self.b = b
        self.buffer = self._draw(buf_size)

    def _draw(self, size):
        """Draw ``size`` integers from the closed interval ``[a, b]``."""
        # np.random.random_integers is deprecated; randint excludes its
        # upper bound, hence the `+ 1` to keep `b` reachable.
        return np.random.randint(self.a, self.b + 1, size)

    def __call__(self):
        if self.idx == len(self.buffer):
            self.buffer = self._draw(len(self.buffer))
            self.idx = 0

        result = self.buffer[self.idx]
        self.idx += 1
        return result
class Word2VecReader(object):
    """Reader producing ``(target id, context id)`` skip-gram pairs.

    The dictionary file holds one ``word count`` pair per line; the corpus
    files hold whitespace-separated word ids, one sentence per line.
    """

    def __init__(self,
                 dict_path,
                 data_path,
                 filelist,
                 trainer_id,
                 trainer_num,
                 window_size=5):
        self.window_size_ = window_size
        self.data_path_ = data_path
        self.filelist = filelist
        self.trainer_id = trainer_id
        self.trainer_num = trainer_num

        # Only the count column of the dictionary is needed here.
        counts = []
        with io.open(dict_path, 'r', encoding='utf-8') as f:
            for line in f:
                counts.append(int(line.split()[1]))
        word_all_count = sum(counts)

        self.word_all_count = word_all_count
        self.corpus_size_ = word_all_count
        self.dict_size = len(counts)
        self.id_counts_ = counts

        print("corpus_size:", self.corpus_size_)
        self.id_frequencys = [
            float(count) / word_all_count for count in self.id_counts_
        ]
        print("dict_size = " + str(self.dict_size) + " word_all_count = " +
              str(word_all_count))

        self.random_generator = NumpyRandomInt(1, self.window_size_ + 1)

    def get_context_words(self, words, idx):
        """Return the words surrounding position ``idx`` in ``words``.

        The window radius is drawn from ``self.random_generator`` on every
        call. The word at ``idx`` itself is excluded and the window is
        clipped at the start of the list.

        words: the words of the current line
        idx: input word index
        """
        radius = self.random_generator()
        left = max(idx - radius, 0)
        right = idx + radius
        return words[left:idx] + words[idx + 1:right + 1]

    def train(self):
        """Return a generator function yielding ``([target], [context])``."""

        def nce_reader():
            for name in self.filelist:
                path = self.data_path_ + "/" + name
                with io.open(path, 'r', encoding='utf-8') as f:
                    logger.info("running data in {}".format(path))
                    # Lines are numbered from 1; a line is handled only by
                    # the trainer whose id equals the line number modulo the
                    # trainer count.
                    for count, line in enumerate(f, 1):
                        if self.trainer_id != count % self.trainer_num:
                            continue
                        word_ids = [int(w) for w in line.split()]
                        for idx, target_id in enumerate(word_ids):
                            context_word_ids = self.get_context_words(
                                word_ids, idx)
                            for context_id in context_word_ids:
                                yield [target_id], [context_id]

        return nce_reader
from __future__ import print_function
import argparse
import logging
import os
import time
import math
import random
import numpy as np
import paddle
import paddle.fluid as fluid
import six
import reader
from net import skip_gram_word2vec
# Install a timestamped log format on the root logger, then use the shared
# "fluid" logger at INFO level for all messages emitted by this script.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
    """Parse the command-line options of the local word2vec trainer.

    Returns:
        argparse.Namespace: data paths plus optimizer, batching and logging
        settings.
    """
    parser = argparse.ArgumentParser(
        description="PaddlePaddle Word2vec example")
    parser.add_argument(
        '--train_data_dir',
        type=str,
        default='./data/text',
        help="The path of training dataset")
    parser.add_argument(
        '--base_lr',
        type=float,
        default=0.01,
        help="The base learning rate (default: 0.01)")
    parser.add_argument(
        '--save_step',
        type=int,
        default=500000,
        help="The number of step to save (default: 500000)")
    parser.add_argument(
        '--print_batch',
        type=int,
        default=10,
        help="The number of batches between two log lines (default: 10)")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./data/1-billion_dict',
        help="The path of data dict")
    parser.add_argument(
        '--batch_size',
        type=int,
        default=500,
        help="The size of mini-batch (default:500)")
    parser.add_argument(
        '--num_passes',
        type=int,
        default=10,
        help="The number of passes to train (default: 10)")
    parser.add_argument(
        '--model_output_dir',
        type=str,
        default='models',
        help='The path for model to store (default: models)')
    parser.add_argument('--nce_num', type=int, default=5, help='nce_num')
    parser.add_argument(
        '--embedding_size',
        type=int,
        default=64,
        help='The size of the word embedding vectors (default: 64)')
    parser.add_argument(
        '--is_sparse',
        action='store_true',
        required=False,
        default=False,
        help='embedding and nce will use sparse or not, (default: False)')
    parser.add_argument(
        '--with_speed',
        action='store_true',
        required=False,
        default=False,
        help='print speed or not , (default: False)')
    return parser.parse_args()
def convert_python_to_tensor(weight, batch_size, sample_reader, nce_num=None):
    """Wrap a sample reader into a reader that yields batches of fluid tensors.

    Args:
        weight: per-word sampling weights; their cumulative sum is used to
            draw negative word ids.
        batch_size: number of samples grouped into one batch.
        sample_reader: callable returning an iterable of two-field samples.
        nce_num: number of negative ids drawn per batch. When ``None`` the
            value is read from the module-level ``args.nce_num``, which is
            what this function did before the parameter existed.

    Returns:
        A generator function. Each yielded item is a list of three
        ``fluid.Tensor`` objects: one per sample field plus the negative ids
        reshaped to ``(batch_size, nce_num)``. Samples left over after the
        last full batch are not yielded.
    """

    def __reader__():
        # Resolved lazily so existing callers, which rely on the global
        # ``args`` bound under ``__main__``, keep working unchanged.
        neg_num = args.nce_num if nce_num is None else nce_num
        cs = np.array(weight).cumsum()
        result = [[], []]
        for sample in sample_reader():
            for i, fea in enumerate(sample):
                result[i].append(fea)
            if len(result[0]) == batch_size:
                tensor_result = []
                for tensor in result:
                    t = fluid.Tensor()
                    dat = np.array(tensor, dtype='int64')
                    if len(dat.shape) > 2:
                        dat = dat.reshape((dat.shape[0], dat.shape[2]))
                    elif len(dat.shape) == 1:
                        dat = dat.reshape((-1, 1))
                    t.set(dat, fluid.CPUPlace())
                    tensor_result.append(t)
                tt = fluid.Tensor()
                # One draw of ``neg_num`` ids is tiled across the batch, so
                # every row of the batch shares the same negatives.
                neg_array = cs.searchsorted(np.random.sample(neg_num))
                neg_array = np.tile(neg_array, batch_size)
                tt.set(
                    neg_array.reshape((batch_size, neg_num)),
                    fluid.CPUPlace())
                tensor_result.append(tt)
                yield tensor_result
                result = [[], []]

    return __reader__
def train_loop(args, train_program, reader, py_reader, loss, trainer_id,
               weight):
    """Train locally on CPU with a ``ParallelExecutor``.

    Args:
        args: parsed options; ``batch_size``, ``num_passes``, ``print_batch``,
            ``with_speed``, ``save_step`` and ``model_output_dir`` are read.
        train_program: the program handed to the parallel executor.
        reader: object whose ``train()`` returns the sample reader.
        py_reader: reader op that is fed the batched tensors.
        loss: loss variable; its name is fetched at every step.
        trainer_id: only trainer 0 saves parameters.
        weight: negative-sampling weights passed to
            ``convert_python_to_tensor``.
    """
    py_reader.decorate_tensor_provider(
        convert_python_to_tensor(weight, args.batch_size, reader.train()))

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    exec_strategy = fluid.ExecutionStrategy()
    exec_strategy.use_experimental_executor = True

    print("CPU_NUM:" + str(os.getenv("CPU_NUM")))
    # CPU_NUM may be unset; default to 1 instead of crashing on int(None).
    cpu_num = int(os.getenv("CPU_NUM", "1"))
    exec_strategy.num_threads = cpu_num
    # Throughput is logged every `speed_interval` batches, so that is also
    # the number of batches accounted for in each report.
    speed_interval = 500

    build_strategy = fluid.BuildStrategy()
    if cpu_num > 1:
        build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce

    train_exe = fluid.ParallelExecutor(
        use_cuda=False,
        loss_name=loss.name,
        main_program=train_program,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    for pass_id in range(args.num_passes):
        py_reader.start()
        # NOTE(review): presumably lets the reader thread pre-fill its queue
        # before timing starts; confirm before removing.
        time.sleep(10)
        epoch_start = time.time()
        batch_id = 0
        start = time.time()

        try:
            while True:
                loss_val = train_exe.run(fetch_list=[loss.name])
                loss_val = np.mean(loss_val)

                if batch_id % args.print_batch == 0:
                    logger.info(
                        "TRAIN --> pass: {} batch: {} loss: {} reader queue:{}".
                        format(pass_id, batch_id, loss_val,
                               py_reader.queue.size()))
                if args.with_speed:
                    if batch_id % speed_interval == 0 and batch_id != 0:
                        elapsed = (time.time() - start)
                        start = time.time()
                        samples = speed_interval * args.batch_size * cpu_num
                        logger.info("Time used: {}, Samples/Sec: {}".format(
                            elapsed, samples / elapsed))

                if batch_id % args.save_step == 0 and batch_id != 0:
                    model_dir = args.model_output_dir + '/pass-' + str(
                        pass_id) + ('/batch-' + str(batch_id))
                    if trainer_id == 0:
                        fluid.io.save_params(executor=exe, dirname=model_dir)
                        print("model saved in %s" % model_dir)
                batch_id += 1

        except fluid.core.EOFException:
            # The reader is exhausted: end of this pass.
            py_reader.reset()
            epoch_end = time.time()
            logger.info("Epoch: {0}, Train total expend: {1} ".format(
                pass_id, epoch_end - epoch_start))
            model_dir = args.model_output_dir + '/pass-' + str(pass_id)
            if trainer_id == 0:
                fluid.io.save_params(executor=exe, dirname=model_dir)
                print("model saved in %s" % model_dir)
def GetFileList(data_path):
    """Return the names of the entries found directly under ``data_path``."""
    entries = os.listdir(data_path)
    return entries
def train(args):
    """Build the skip-gram graph and train it locally.

    Args:
        args: parsed options from ``parse_args``.
    """
    # makedirs also handles a nested output path; tolerate the directory
    # appearing between the check and the creation.
    if not os.path.isdir(args.model_output_dir):
        try:
            os.makedirs(args.model_output_dir)
        except OSError:
            if not os.path.isdir(args.model_output_dir):
                raise

    filelist = GetFileList(args.train_data_dir)
    word2vec_reader = reader.Word2VecReader(
        args.dict_path, args.train_data_dir, filelist, 0, 1)

    logger.info("dict_size: {}".format(word2vec_reader.dict_size))
    # Negative-sampling distribution: frequencies raised to 0.75, normalised.
    np_power = np.power(np.array(word2vec_reader.id_frequencys), 0.75)
    id_frequencys_pow = np_power / np_power.sum()

    loss, py_reader = skip_gram_word2vec(
        word2vec_reader.dict_size,
        args.embedding_size,
        is_sparse=args.is_sparse,
        neg_num=args.nce_num)

    optimizer = fluid.optimizer.SGD(
        learning_rate=fluid.layers.exponential_decay(
            learning_rate=args.base_lr,
            decay_steps=100000,
            decay_rate=0.999,
            staircase=True))

    optimizer.minimize(loss)

    # do local training
    logger.info("run local training")
    main_program = fluid.default_main_program()
    train_loop(args, main_program, word2vec_reader, py_reader, loss, 0,
               id_frequencys_pow)
if __name__ == '__main__':
    # `args` is deliberately bound at module level: convert_python_to_tensor
    # may read `args.nce_num` from this global rather than from a parameter.
    args = parse_args()
    train(args)
import sys
import collections
import six
import time
import numpy as np
import paddle.fluid as fluid
import paddle
import os
import preprocess
def BuildWord_IdMap(dict_path):
    """Load the ``word id`` file into two lookup tables.

    Args:
        dict_path: path of a text file with one ``word id`` pair per line.
            Blank lines are ignored.

    Returns:
        tuple: ``(word_to_id, id_to_word)`` dictionaries.
    """
    word_to_id = dict()
    id_to_word = dict()
    with open(dict_path, 'r') as f:
        for line in f:
            # Split once on any whitespace so blank lines and repeated
            # separators no longer raise.
            fields = line.split()
            if not fields:
                continue
            word, word_id = fields[0], int(fields[1])
            word_to_id[word] = word_id
            id_to_word[word_id] = word
    return word_to_id, id_to_word
def prepare_data(file_dir, dict_path, batch_size):
    """Create the batched test reader and the vocabulary information.

    Args:
        file_dir: directory holding the test files.
        dict_path: path of the ``word id`` file.
        batch_size: number of samples per batch.

    Returns:
        tuple: ``(vocab_size, reader, id_to_word)``.
    """
    word_to_id, id_to_word = BuildWord_IdMap(dict_path)
    batched_reader = paddle.batch(test(file_dir, word_to_id), batch_size)
    return len(id_to_word), batched_reader, id_to_word
def native_to_unicode(s):
    """Return ``s`` as unicode text, dropping undecodable bytes if needed.

    Unicode input is returned as is. Otherwise a strict conversion is tried
    first and, on ``UnicodeDecodeError``, repeated with errors ignored.
    """
    if not _is_unicode(s):
        try:
            s = _to_unicode(s)
        except UnicodeDecodeError:
            s = _to_unicode(s, ignore_errors=True)
    return s
def _is_unicode(s):
    """Return True when ``s`` is text: ``unicode`` on Python 2, ``str`` on 3."""
    # The conditional expression only evaluates `unicode` under Python 2,
    # where that name exists.
    text_type = unicode if six.PY2 else str
    return isinstance(s, text_type)
def _to_unicode(s, ignore_errors=False):
    """Decode ``s`` from UTF-8 unless it already is unicode text.

    Args:
        s: the value to convert.
        ignore_errors: decode with ``errors="ignore"`` instead of
            ``errors="strict"``.
    """
    if not _is_unicode(s):
        return s.decode(
            "utf-8", errors="ignore" if ignore_errors else "strict")
    return s
def strip_lines(line, vocab):
    """Convert ``line`` to unicode and replace its out-of-vocab words."""
    unicode_line = native_to_unicode(line)
    return _replace_oov(vocab, unicode_line)
def _replace_oov(original_vocab, line):
"""Replace out-of-vocab words with "<UNK>".
This maintains compatibility with published results.
Args:
original_vocab: a set of strings (The standard vocabulary for the dataset)
line: a unicode string - a space-delimited sequence of words.
Returns:
a unicode string - a space-delimited sequence of words.
"""
return u" ".join([
word if word in original_vocab else u"<UNK>" for word in line.split()
])
def reader_creator(file_dir, word_to_id):
    """Create a sample reader over every file found in ``file_dir``.

    Lines containing ':' are skipped, as are blank lines. Every other line
    is lower-cased, passed through ``strip_lines`` and must contain at
    least four words.

    Args:
        file_dir: Directory whose files are read line by line.
        word_to_id: Mapping from word to integer id; it is also used as the
            vocabulary for out-of-vocabulary replacement, so it must contain
            the replacement token when such words occur.

    Returns:
        A zero-argument generator function. Each sample is the tuple
        ``([a], [b], [c], [d], [a, b, c])`` of word ids for the first four
        words of a line.
    """

    def reader():
        for fi in os.listdir(file_dir):
            with open(os.path.join(file_dir, fi), "r") as f:
                for line in f:
                    if ':' in line:
                        continue
                    words = strip_lines(line.lower(), word_to_id).split()
                    if not words:
                        # Blank line: nothing to yield.
                        continue
                    ids = [word_to_id[w] for w in words[:4]]
                    yield [ids[0]], [ids[1]], [ids[2]], [ids[3]], ids[:3]

    return reader
def test(test_dir, w2i):
    """Return the sample reader for the test files under ``test_dir``."""
    sample_reader = reader_creator(test_dir, w2i)
    return sample_reader
# 离线量化示例
本示例介绍如何使用离线量化接口``paddleslim.quant.quant_post``来对训练好的分类模型进行离线量化, 该接口无需对模型进行训练就可得到量化模型,减少模型的存储空间和显存占用。
## 接口介绍
```
quant_post(executor,
model_dir,
quantize_model_path,
sample_generator,
model_filename=None,
params_filename=None,
batch_size=16,
batch_nums=None,
scope=None,
algo='KL',
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"])
```
参数介绍:
- executor (fluid.Executor): 执行模型的executor,可以在cpu或者gpu上执行。
- model_dir(str): 需要量化的模型所在的文件夹。
- quantize_model_path(str): 保存量化后的模型的路径
- sample_generator(python generator): 读取数据样本,每次返回一个样本。
- model_filename(str, optional): 模型文件名,如果需要量化的模型的参数存在一个文件中,则需要设置``model_filename``为模型文件的名称,否则设置为``None``即可。默认值是``None``
- params_filename(str, optional): 参数文件名,如果需要量化的模型的参数存在一个文件中,则需要设置``params_filename``为参数文件的名称,否则设置为``None``即可。默认值是``None``
- batch_size(int): 每个batch的图片数量。默认值为16 。
- batch_nums(int, optional): 迭代次数。如果设置为``None``,则会一直运行到``sample_generator`` 迭代结束, 否则,迭代次数为``batch_nums``, 也就是说参与对``Scale``进行校正的样本个数为 ``'batch_nums' * 'batch_size' ``.
- scope(fluid.Scope, optional): 用来获取和写入``Variable``, 如果设置为``None``,则使用``fluid.global_scope()``. 默认值是``None``.
- algo(str): 量化时使用的算法名称,可为``'KL'``或者``'direct'``。该参数仅针对激活值的量化,因为参数值的量化使用的方式为``'channel_wise_abs_max'``. 当``algo`` 设置为``'direct'``时,使用``'abs_max'``计算``Scale``值,当设置为``'KL'``时,则使用``KL``散度的方法来计算``Scale``值。默认值为``'KL'``
- quantizable_op_type(list[str]): 需要量化的``op``类型列表。默认值为``["conv2d", "depthwise_conv2d", "mul"]``
## 分类模型的离线量化流程
### 准备数据
在当前文件夹下创建``data``文件夹,将``imagenet``数据集解压在``data``文件夹下,解压后``data``文件夹下应包含以下文件:
- ``'train'``文件夹,训练图片
- ``'train_list.txt'``文件
- ``'val'``文件夹,验证图片
- ``'val_list.txt'``文件
### 准备需要量化的模型
因为离线量化接口只支持加载通过``fluid.io.save_inference_model``接口保存的模型,因此如果您的模型是通过其他接口保存的,那需要先将模型进行转化。本示例将以分类模型为例进行说明。
首先在[imagenet分类模型](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification#%E5%B7%B2%E5%8F%91%E5%B8%83%E6%A8%A1%E5%9E%8B%E5%8F%8A%E5%85%B6%E6%80%A7%E8%83%BD)中下载训练好的``mobilenetv1``模型。
在当前文件夹下创建``'pretrain'``文件夹,将``mobilenetv1``模型在该文件夹下解压,解压后的目录为``pretrain/MobileNetV1_pretrained``
### 导出模型
通过运行以下命令可将模型转化为离线量化接口可用的模型:
```
python export_model.py --model "MobileNet" --pretrained_model ./pretrain/MobileNetV1_pretrained --data imagenet
```
转化之后的模型存储在``inference_model/MobileNet/``文件夹下,可看到该文件夹下有``'model'``, ``'weights'``两个文件。
### 离线量化
接下来对导出的模型文件进行离线量化,离线量化的脚本为[quant_post.py](./quant_post.py),脚本中使用接口``paddleslim.quant.quant_post``对模型进行离线量化。运行命令为:
```
python quant_post.py --model_path ./inference_model/MobileNet --save_path ./quant_model_train/MobileNet --model_filename model --params_filename weights
```
- ``model_path``: 需要量化的模型所在的文件夹
- ``save_path``: 量化后的模型保存的路径
- ``model_filename``: 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的模型文件名称,如果参数文件保存在多个文件中,则不需要设置。
- ``params_filename``: 如果需要量化的模型的参数文件保存在一个文件中,则设置为该模型的参数文件名称,如果参数文件保存在多个文件中,则不需要设置。
运行以上命令后,可在``${save_path}``下看到量化后的模型文件和参数文件。
> 使用的量化算法为``'KL'``, 使用训练集中的160张图片进行量化参数的校正。
### 测试精度
使用[eval.py](./eval.py)脚本对量化前后的模型进行测试,得到模型的分类精度进行对比。
首先测试量化前的模型的精度,运行以下命令:
```
python eval.py --model_path ./inference_model/MobileNet --model_name model --params_name weights
```
精度输出为:
```
top1_acc/top5_acc= [0.70913923 0.89548034]
```
使用以下命令测试离线量化后的模型的精度:
```
python eval.py --model_path ./quant_model_train/MobileNet
```
精度输出为:
```
top1_acc/top5_acc= [0.70141864 0.89086477]
```
从以上精度对比可以看出,对``mobilenet``在``imagenet``上的分类模型进行离线量化后, ``top1``精度损失为``0.77%``, ``top5``精度损失为``0.46%``.
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import sys
import numpy as np
import argparse
import functools
import paddle
import paddle.fluid as fluid
sys.path.append('../../')
import imagenet_reader as reader
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
# yapf: disable
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model_path', str, "./pruning/checkpoints/resnet50/2/eval_model/", "Whether to use pretrained model.")
add_arg('model_name', str, None, "model filename for inference model")
add_arg('params_name', str, None, "params filename for inference model")
# yapf: enable
def eval(args):
    """Evaluate a saved inference model on the ImageNet validation reader.

    The model is loaded from ``args.model_path``. When it has a single feed
    target it is treated as an "infer model" (image in, class probabilities
    out) and top-1/top-5 accuracy are computed here with numpy. Otherwise it
    is treated as an "eval model" whose fetch targets are averaged directly.
    The per-batch values are averaged and printed.
    """
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    val_program, feed_target_names, fetch_targets = fluid.io.load_inference_model(
        args.model_path,
        exe,
        model_filename=args.model_name,
        params_filename=args.params_name)
    val_reader = paddle.batch(reader.val(), batch_size=128)
    feeder = fluid.DataFeeder(
        place=place, feed_list=feed_target_names, program=val_program)

    infer_only = len(feed_target_names) == 1
    results = []
    for batch_id, data in enumerate(val_reader()):
        if not infer_only:
            # "eval model": inputs are image and label, outputs are the
            # top1 and top5 accuracy.
            fetched = exe.run(val_program,
                              feed=feeder.feed(data),
                              fetch_list=fetch_targets)
            results.append([np.mean(r) for r in fetched])
            continue

        # "infer model": input is the image, output is the classification
        # probability; accuracy is computed on the host.
        image = [[sample[0]] for sample in data]
        label = np.array([[sample[1]] for sample in data])
        pred = exe.run(val_program,
                       feed=feeder.feed(image),
                       fetch_list=fetch_targets)
        pred = np.array(pred[0])

        # argsort is ascending, so the best classes sit at the end.
        sort_array = pred.argsort(axis=1)
        top_1_pred = sort_array[:, -1:][:, ::-1]
        top_5_pred = sort_array[:, -5:][:, ::-1]
        top_1 = np.mean(label == top_1_pred)
        hits = sum(
            1 for lbl, top5 in zip(label, top_5_pred) if lbl[0] in top5)
        top_5 = float(hits) / len(label)
        results.append([top_1, top_5])

    result = np.mean(np.array(results), axis=0)
    print("top1_acc/top5_acc= {}".format(result))
    sys.stdout.flush()
def main():
    """Parse the command line, echo the arguments and run the evaluation."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    eval(cli_args)
if __name__ == '__main__':
main()
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import paddle.fluid as fluid
sys.path.append(sys.path[0] + "/../../../")
from paddleslim.common import get_logger
sys.path.append(sys.path[0] + "/../../")
import models
from utility import add_arguments, print_arguments
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretained", "Whether to use pretrained model.")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
add_arg('test_period', int, 10, "Test period in epoches.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
def export_model(args):
    """Export a pretrained classifier as a fluid inference model.

    The network named by ``args.model`` is built for the dataset named by
    ``args.data``, weights found under ``args.pretrained_model`` are loaded,
    and the test-mode program is saved to ``./inference_model/<model>`` as
    the two files ``model`` and ``weights``.

    Raises:
        ValueError: If ``args.pretrained_model`` is empty, ``args.data`` is
            not "mnist"/"imagenet", or ``args.model`` is unknown.
    """
    # Validate up-front with real exceptions: asserts vanish under ``-O``
    # and the original only complained after the whole network was built.
    if not args.pretrained_model:
        raise ValueError("args.pretrained_model must set")
    if args.data == "mnist":
        class_dim = 10
        image_shape = [1, 28, 28]
    elif args.data == "imagenet":
        class_dim = 1000
        image_shape = [3, 224, 224]
    else:
        raise ValueError("{} is not supported.".format(args.data))
    if args.model not in model_list:
        raise ValueError("{} is not in lists: {}".format(args.model,
                                                         model_list))

    # Only the input shape and class count are needed for export; the
    # dataset readers the original created were never used.
    image = fluid.data(
        name='image', shape=[None] + image_shape, dtype='float32')

    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    val_program = fluid.default_main_program().clone(for_test=True)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    def if_exist(var):
        # Load only the variables that have a file in the pretrained dir.
        return os.path.exists(os.path.join(args.pretrained_model, var.name))

    fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)

    fluid.io.save_inference_model(
        './inference_model/' + args.model,
        feeded_var_names=[image.name],
        target_vars=[out],
        executor=exe,
        main_program=val_program,
        model_filename='model',
        params_filename='weights')
def main():
    """Parse the command line, echo the arguments and export the model."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    export_model(cli_args)
if __name__ == '__main__':
main()
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import paddle.fluid as fluid
import reader
sys.path.append(sys.path[0] + "/../../../")
from paddleslim.common import get_logger
from paddleslim.quant import quant_post
sys.path.append(sys.path[0] + "/../../")
from utility import add_arguments, print_arguments
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 16, "Minibatch size.")
add_arg('batch_num', int, 10, "Batch number")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model_path', str, "./inference_model/MobileNet/", "model dir")
add_arg('save_path', str, "./quant_model/MobileNet/", "model dir to save quanted model")
add_arg('model_filename', str, None, "model file name")
add_arg('params_filename', str, None, "params file name")
# yapf: enable
def quantize(args):
    """Run post-training quantization on a saved inference model.

    Samples from ``reader.train()`` are handed to ``quant_post`` as the
    sample generator. The model is read from ``args.model_path`` (which must
    be an existing directory) and the result is written to
    ``args.save_path``.
    """
    assert os.path.exists(args.model_path), "args.model_path doesn't exist"
    assert os.path.isdir(args.model_path), "args.model_path must be a dir"

    sample_reader = reader.train()
    if args.use_gpu:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    quant_post(
        executor=exe,
        model_dir=args.model_path,
        quantize_model_path=args.save_path,
        sample_generator=sample_reader,
        model_filename=args.model_filename,
        params_filename=args.params_filename,
        batch_size=args.batch_size,
        batch_nums=args.batch_num)
def main():
    """Parse the command line, echo the arguments and quantize the model."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    quantize(cli_args)
if __name__ == '__main__':
main()
import sys
sys.path.append('..')
import numpy as np
import argparse
import ast
import paddle
import paddle.fluid as fluid
from paddleslim.nas.search_space.search_space_factory import SearchSpaceFactory
from paddleslim.analysis import flops
from paddleslim.nas import SANAS
def create_data_loader():
    """Declare the image/label inputs and an iterable loader feeding them.

    Returns:
        A ``(data_loader, data, label)`` tuple, where ``data`` is the
        float32 input declared with shape ``[-1, 3, 32, 32]`` and ``label``
        the int64 input declared with shape ``[-1, 1]``.
    """
    image_var = fluid.data(
        name='data', shape=[-1, 3, 32, 32], dtype='float32')
    label_var = fluid.data(name='label', shape=[-1, 1], dtype='int64')
    loader = fluid.io.DataLoader.from_generator(
        feed_list=[image_var, label_var],
        capacity=1024,
        use_double_buffer=True,
        iterable=True)
    return loader, image_var, label_var
def init_sa_nas(config):
    """Create the SANAS controller, capped by the initial network's FLOPs.

    The first architecture produced by the search space's ``token2arch()``
    is built into a scratch program; its FLOPs become ``max_flops`` for the
    search.

    Returns:
        A ``(sa_nas, search_steps)`` tuple.
    """
    search_space = SearchSpaceFactory().get_search_space(config)
    init_arch = search_space.token2arch()[0]

    main_program = fluid.Program()
    startup_program = fluid.Program()
    with fluid.program_guard(main_program, startup_program):
        data_loader, data, label = create_data_loader()
        logits = init_arch(data)
        cost = fluid.layers.mean(
            fluid.layers.softmax_with_cross_entropy(
                logits=logits, label=label))
        base_flops = flops(main_program)

    search_steps = 10000000

    # Start a server and a client in this process.
    sa_nas = SANAS(config, max_flops=base_flops, search_steps=search_steps)
    # To start only a client, pass the server address instead, e.g.:
    #sa_nas = SANAS(config, max_flops = base_flops, server_addr=("10.255.125.38", 18607), search_steps = search_steps, is_server=False)
    return sa_nas, search_steps
def search_mobilenetv2_cifar10(config, args):
    """Run the SA-NAS search loop on CIFAR-10.

    For every search step the next candidate architecture is built, trained
    for 10 epochs on the CIFAR-10 training set and run over the test set.
    The value fetched for the last test batch is reported to the controller
    as the reward.
    """
    sa_nas, search_steps = init_sa_nas(config)
    for step in range(search_steps):
        print('search step: ', step)
        arch = sa_nas.next_archs()[0]

        train_program = fluid.Program()
        test_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(train_program, startup_program):
            train_loader, data, label = create_data_loader()
            logits = arch(data)
            cost = fluid.layers.mean(
                fluid.layers.softmax_with_cross_entropy(
                    logits=logits, label=label))[0]
            # Clone before adding the optimizer so the test program carries
            # no backward/update ops.
            test_program = train_program.clone(for_test=True)

            optimizer = fluid.optimizer.Momentum(
                learning_rate=0.1,
                momentum=0.9,
                regularization=fluid.regularizer.L2Decay(1e-4))
            optimizer.minimize(cost)

        place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_program)

        train_reader = paddle.reader.shuffle(
            paddle.dataset.cifar.train10(cycle=False), buf_size=1024)
        train_loader.set_sample_generator(
            train_reader,
            batch_size=512,
            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())

        test_loader, _, _ = create_data_loader()
        test_reader = paddle.dataset.cifar.test10(cycle=False)
        test_loader.set_sample_generator(
            test_reader,
            batch_size=256,
            drop_last=False,
            places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places())

        for epoch_id in range(10):
            for batch_id, batch in enumerate(train_loader()):
                loss = exe.run(train_program,
                               feed=batch,
                               fetch_list=[cost.name])[0]
                if batch_id % 5 != 0:
                    continue
                print('epoch: {}, batch: {}, loss: {}'.format(
                    epoch_id, batch_id, loss[0]))

        # NOTE(review): only the value of the final test batch survives this
        # loop and is used as the reward - confirm this is intended.
        for batch in test_loader():
            reward = exe.run(test_program,
                             feed=batch,
                             fetch_list=[cost.name])[0]

        print('reward:', reward)
        sa_nas.reward(float(reward))
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='SA NAS MobileNetV2 cifar10 argparase')
parser.add_argument(
'--use_gpu',
type=ast.literal_eval,
default=True,
help='Whether to use GPU in train/test model.')
args = parser.parse_args()
print(args)
config_info = {
'input_size': 32,
'output_size': 1,
'block_num': 5,
'block_mask': None
}
config = [('MobileNetV2Space', config_info)]
search_mobilenetv2_cifar10(config, args)
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import paddle.fluid as fluid
from paddleslim.prune import SensitivePruner
from paddleslim.common import get_logger
from paddleslim.analysis import flops
sys.path.append(sys.path[0] + "/../")
import models
from utility import add_arguments, print_arguments
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretained", "Whether to use pretrained model.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('config_file', str, None, "The config file for compression with yaml format.")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.")
add_arg('checkpoints', str, "./checkpoints", "Checkpoints path.")
add_arg('prune_steps', int, 1000, "prune steps.")
add_arg('retrain_epoch', int, 5, "Retrain epoch.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
def piecewise_decay(args):
    """Build a Momentum optimizer with a piecewise-constant learning rate.

    The rate starts at ``args.lr`` and is multiplied by 0.1 at each epoch
    listed in ``args.step_epochs`` (converted to iteration counts using
    ``args.total_images`` and ``args.batch_size``).
    """
    steps_per_epoch = int(
        math.ceil(float(args.total_images) / args.batch_size))
    boundaries = [steps_per_epoch * epoch for epoch in args.step_epochs]
    values = [args.lr * (0.1**power) for power in range(len(boundaries) + 1)]
    lr_schedule = fluid.layers.piecewise_decay(
        boundaries=boundaries, values=values)
    return fluid.optimizer.Momentum(
        learning_rate=lr_schedule,
        momentum=args.momentum_rate,
        regularization=fluid.regularizer.L2Decay(args.l2_decay))
def cosine_decay(args):
    """Build a Momentum optimizer with a cosine-decayed learning rate.

    The schedule starts at ``args.lr`` and spans ``args.num_epochs`` epochs,
    with the number of steps per epoch derived from ``args.total_images``
    and ``args.batch_size``.
    """
    steps_per_epoch = int(
        math.ceil(float(args.total_images) / args.batch_size))
    lr_schedule = fluid.layers.cosine_decay(
        learning_rate=args.lr,
        step_each_epoch=steps_per_epoch,
        epochs=args.num_epochs)
    return fluid.optimizer.Momentum(
        learning_rate=lr_schedule,
        momentum=args.momentum_rate,
        regularization=fluid.regularizer.L2Decay(args.l2_decay))
def create_optimizer(args):
    """Create the optimizer selected by ``args.lr_strategy``.

    Supported strategies are "piecewise_decay" and "cosine_decay".

    Raises:
        ValueError: If ``args.lr_strategy`` names an unknown strategy. The
            original silently returned ``None`` here, which only failed
            later when the caller tried to use the optimizer.
    """
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    elif args.lr_strategy == "cosine_decay":
        return cosine_decay(args)
    raise ValueError(
        "lr_strategy {} is not supported.".format(args.lr_strategy))
def compress(args):
    """Greedily prune a classifier with SensitivePruner, retraining each step.

    Builds the network named by ``args.model`` for ``args.data``, optionally
    loads pretrained weights, restores any pruning checkpoint from
    ``args.checkpoints`` and then, for every remaining step up to
    ``args.prune_steps``, prunes once, retrains for ``args.retrain_epoch``
    epochs, reports the pruned-FLOPs ratio and accuracy, and checkpoints.

    Raises:
        ValueError: If ``args.data`` is neither "mnist" nor "imagenet".
    """
    if args.data == "mnist":
        import paddle.dataset.mnist as reader
        train_reader = reader.train()
        val_reader = reader.test()
        class_dim = 10
        image_shape = "1,28,28"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_reader = reader.train()
        val_reader = reader.val()
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))

    image_shape = [int(m) for m in image_shape.split(",")]
    assert args.model in model_list, "{} is not in lists: {}".format(
        args.model, model_list)
    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    cost = fluid.layers.cross_entropy(input=out, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
    # Clone before minimize() so the eval program has no optimizer ops.
    val_program = fluid.default_main_program().clone(for_test=True)
    opt = create_optimizer(args)
    opt.minimize(avg_cost)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if args.pretrained_model:

        def if_exist(var):
            # Load only variables that have a file in the pretrained dir.
            return os.path.exists(
                os.path.join(args.pretrained_model, var.name))

        fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)

    val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
    train_reader = paddle.batch(
        train_reader, batch_size=args.batch_size, drop_last=True)

    train_feeder = fluid.DataFeeder([image, label], place)
    val_feeder = fluid.DataFeeder([image, label], place, program=val_program)

    def test(epoch, program):
        """Evaluate ``program`` on the validation set; return mean top-1."""
        batch_id = 0
        acc_top1_ns = []
        acc_top5_ns = []
        for data in val_reader():
            start_time = time.time()
            acc_top1_n, acc_top5_n = exe.run(
                program,
                feed=val_feeder.feed(data),
                fetch_list=[acc_top1.name, acc_top5.name])
            end_time = time.time()
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {:.3f}; acc_top5: {:.3f}; time: {:.3f}".
                    format(epoch, batch_id,
                           np.mean(acc_top1_n),
                           np.mean(acc_top5_n), end_time - start_time))
            acc_top1_ns.append(np.mean(acc_top1_n))
            acc_top5_ns.append(np.mean(acc_top5_n))
            batch_id += 1

        _logger.info(
            "Final eval epoch[{}] - acc_top1: {:.3f}; acc_top5: {:.3f}".format(
                epoch,
                np.mean(np.array(acc_top1_ns)), np.mean(
                    np.array(acc_top5_ns))))
        return np.mean(np.array(acc_top1_ns))

    def train(epoch, program):
        """Train ``program`` for one pass over the training reader."""
        build_strategy = fluid.BuildStrategy()
        exec_strategy = fluid.ExecutionStrategy()
        train_program = fluid.compiler.CompiledProgram(
            program).with_data_parallel(
                loss_name=avg_cost.name,
                build_strategy=build_strategy,
                exec_strategy=exec_strategy)

        batch_id = 0
        for data in train_reader():
            start_time = time.time()
            loss_n, acc_top1_n, acc_top5_n = exe.run(
                train_program,
                feed=train_feeder.feed(data),
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            end_time = time.time()
            loss_n = np.mean(loss_n)
            acc_top1_n = np.mean(acc_top1_n)
            acc_top5_n = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] - loss: {:.3f}; acc_top1: {:.3f}; acc_top5: {:.3f}; time: {:.3f}".
                    format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
                           end_time - start_time))
            batch_id += 1

    # Only parameters whose name contains "_sep_weights" are pruned.
    params = [
        param.name
        for param in fluid.default_main_program().global_block()
        .all_parameters() if "_sep_weights" in param.name
    ]

    def eval_func(program):
        return test(0, program)

    if args.data == "mnist":
        # One warm-up epoch before pruning when running on mnist.
        train(0, fluid.default_main_program())

    pruner = SensitivePruner(place, eval_func, checkpoints=args.checkpoints)
    pruned_program, pruned_val_program, start_iter = pruner.restore()

    if pruned_program is None:
        pruned_program = fluid.default_main_program()
    if pruned_val_program is None:
        pruned_val_program = val_program

    base_flops = flops(val_program)

    # ``prune_iter`` instead of ``iter`` so the builtin is not shadowed.
    for prune_iter in range(start_iter, args.prune_steps):
        pruned_program, pruned_val_program = pruner.greedy_prune(
            pruned_program, pruned_val_program, params, 0.03, topk=1)
        current_flops = flops(pruned_val_program)
        pruned_ratio = float(base_flops - current_flops) / base_flops
        print("iter:{}; pruned FLOPS: {}".format(prune_iter, pruned_ratio))
        acc = None
        for i in range(args.retrain_epoch):
            train(i, pruned_program)
            acc = test(i, pruned_val_program)
        print("iter:{}; pruned FLOPS: {}; acc: {}".format(
            prune_iter, pruned_ratio, acc))
        pruner.save_checkpoint(pruned_program, pruned_val_program)
def main():
    """Parse the command line, echo the arguments and run the pruning."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    compress(cli_args)
if __name__ == '__main__':
main()
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import numpy as np
import paddle.fluid as fluid
from paddleslim.prune import SensitivePruner
from paddleslim.common import get_logger
from paddleslim.analysis import flops
sys.path.append(sys.path[0] + "/../")
import models
from utility import add_arguments, print_arguments
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretained", "Whether to use pretrained model.")
add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
add_arg('config_file', str, None, "The config file for compression with yaml format.")
add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('test_period', int, 10, "Test period in epoches.")
add_arg('checkpoints', str, "./checkpoints", "Checkpoints path.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]
def piecewise_decay(args):
    """Build a Momentum optimizer with a piecewise-constant learning rate.

    The rate starts at ``args.lr`` and is multiplied by 0.1 at each epoch
    listed in ``args.step_epochs`` (converted to iteration counts using
    ``args.total_images`` and ``args.batch_size``).
    """
    steps_per_epoch = int(
        math.ceil(float(args.total_images) / args.batch_size))
    boundaries = [steps_per_epoch * epoch for epoch in args.step_epochs]
    values = [args.lr * (0.1**power) for power in range(len(boundaries) + 1)]
    lr_schedule = fluid.layers.piecewise_decay(
        boundaries=boundaries, values=values)
    return fluid.optimizer.Momentum(
        learning_rate=lr_schedule,
        momentum=args.momentum_rate,
        regularization=fluid.regularizer.L2Decay(args.l2_decay))
def cosine_decay(args):
    """Build a Momentum optimizer with a cosine-decayed learning rate.

    The schedule starts at ``args.lr`` and spans ``args.num_epochs`` epochs,
    with the number of steps per epoch derived from ``args.total_images``
    and ``args.batch_size``.
    """
    steps_per_epoch = int(
        math.ceil(float(args.total_images) / args.batch_size))
    lr_schedule = fluid.layers.cosine_decay(
        learning_rate=args.lr,
        step_each_epoch=steps_per_epoch,
        epochs=args.num_epochs)
    return fluid.optimizer.Momentum(
        learning_rate=lr_schedule,
        momentum=args.momentum_rate,
        regularization=fluid.regularizer.L2Decay(args.l2_decay))
def create_optimizer(args):
    """Create the optimizer selected by ``args.lr_strategy``.

    Supported strategies are "piecewise_decay" and "cosine_decay".

    Raises:
        ValueError: If ``args.lr_strategy`` names an unknown strategy. The
            original silently returned ``None`` here, which only failed
            later when the caller tried to use the optimizer.
    """
    if args.lr_strategy == "piecewise_decay":
        return piecewise_decay(args)
    elif args.lr_strategy == "cosine_decay":
        return cosine_decay(args)
    raise ValueError(
        "lr_strategy {} is not supported.".format(args.lr_strategy))
def compress(args):
    """Iteratively prune a classifier with SensitivePruner and fine-tune it.

    Builds the network named by ``args.model`` for ``args.data``, optionally
    loads pretrained weights, restores any pruning checkpoint from
    ``args.checkpoints`` and then runs the remaining prune/train/test/
    checkpoint iterations (six in total). FLOPs before and after pruning are
    printed at the end.

    Raises:
        ValueError: If ``args.data`` is neither "mnist" nor "imagenet".
    """
    if args.data == "mnist":
        import paddle.dataset.mnist as reader
        train_reader = reader.train()
        val_reader = reader.test()
        class_dim = 10
        image_shape = "1,28,28"
    elif args.data == "imagenet":
        import imagenet_reader as reader
        train_reader = reader.train()
        val_reader = reader.val()
        class_dim = 1000
        image_shape = "3,224,224"
    else:
        raise ValueError("{} is not supported.".format(args.data))

    image_shape = [int(m) for m in image_shape.split(",")]
    assert args.model in model_list, "{} is not in lists: {}".format(
        args.model, model_list)
    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    # model definition
    model = models.__dict__[args.model]()
    out = model.net(input=image, class_dim=class_dim)
    cost = fluid.layers.cross_entropy(input=out, label=label)
    avg_cost = fluid.layers.mean(x=cost)
    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
    # Clone before minimize() so the eval program has no optimizer ops.
    val_program = fluid.default_main_program().clone(for_test=True)
    opt = create_optimizer(args)
    opt.minimize(avg_cost)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if args.pretrained_model:

        def if_exist(var):
            # Load only variables that have a file in the pretrained dir.
            return os.path.exists(
                os.path.join(args.pretrained_model, var.name))

        fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)

    val_reader = paddle.batch(val_reader, batch_size=args.batch_size)
    train_reader = paddle.batch(
        train_reader, batch_size=args.batch_size, drop_last=True)

    train_feeder = fluid.DataFeeder([image, label], place)
    val_feeder = fluid.DataFeeder([image, label], place, program=val_program)

    def test(epoch, program):
        """Evaluate ``program`` on the validation set; return mean top-1."""
        batch_id = 0
        acc_top1_ns = []
        acc_top5_ns = []
        for data in val_reader():
            start_time = time.time()
            acc_top1_n, acc_top5_n = exe.run(
                program,
                feed=val_feeder.feed(data),
                fetch_list=[acc_top1.name, acc_top5.name])
            end_time = time.time()
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {:.3f}; acc_top5: {:.3f}; time: {:.3f}".
                    format(epoch, batch_id,
                           np.mean(acc_top1_n),
                           np.mean(acc_top5_n), end_time - start_time))
            acc_top1_ns.append(np.mean(acc_top1_n))
            acc_top5_ns.append(np.mean(acc_top5_n))
            batch_id += 1

        _logger.info(
            "Final eval epoch[{}] - acc_top1: {:.3f}; acc_top5: {:.3f}".format(
                epoch,
                np.mean(np.array(acc_top1_ns)), np.mean(
                    np.array(acc_top5_ns))))
        return np.mean(np.array(acc_top1_ns))

    def train(epoch, program):
        """Train ``program`` for one pass over the training reader."""
        build_strategy = fluid.BuildStrategy()
        exec_strategy = fluid.ExecutionStrategy()
        train_program = fluid.compiler.CompiledProgram(
            program).with_data_parallel(
                loss_name=avg_cost.name,
                build_strategy=build_strategy,
                exec_strategy=exec_strategy)

        batch_id = 0
        for data in train_reader():
            start_time = time.time()
            loss_n, acc_top1_n, acc_top5_n = exe.run(
                train_program,
                feed=train_feeder.feed(data),
                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
            end_time = time.time()
            loss_n = np.mean(loss_n)
            acc_top1_n = np.mean(acc_top1_n)
            acc_top5_n = np.mean(acc_top5_n)
            if batch_id % args.log_period == 0:
                _logger.info(
                    "epoch[{}]-batch[{}] - loss: {:.3f}; acc_top1: {:.3f}; acc_top5: {:.3f}; time: {:.3f}".
                    format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
                           end_time - start_time))
            batch_id += 1

    # Only parameters whose name contains "_sep_weights" are pruned.
    params = [
        param.name
        for param in fluid.default_main_program().global_block()
        .all_parameters() if "_sep_weights" in param.name
    ]

    def eval_func(program):
        return test(0, program)

    if args.data == "mnist":
        # One warm-up epoch before pruning when running on mnist.
        train(0, fluid.default_main_program())

    pruner = SensitivePruner(place, eval_func, checkpoints=args.checkpoints)
    pruned_program, pruned_val_program, start_iter = pruner.restore()

    if pruned_program is None:
        pruned_program = fluid.default_main_program()
    if pruned_val_program is None:
        pruned_val_program = val_program

    # Total number of pruning iterations, counting restored ones.
    end_iter = 6
    # ``prune_iter`` instead of ``iter`` so the builtin is not shadowed.
    for prune_iter in range(start_iter, end_iter):
        pruned_program, pruned_val_program = pruner.prune(
            pruned_program, pruned_val_program, params, 0.1)
        train(prune_iter, pruned_program)
        test(prune_iter, pruned_val_program)
        pruner.save_checkpoint(pruned_program, pruned_val_program)

    print("before flops: {}".format(flops(fluid.default_main_program())))
    print("after flops: {}".format(flops(pruned_val_program)))
def main():
    """Parse the command line, echo the arguments and run the pruning."""
    cli_args = parser.parse_args()
    print_arguments(cli_args)
    compress(cli_args)
if __name__ == '__main__':
main()
"""Contains common utility functions."""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import distutils.util
import os
import numpy as np
import six
import logging
import paddle.fluid as fluid
import paddle.compat as cpt
from paddle.fluid import core
from paddle.fluid.framework import Program
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
def print_arguments(args):
    """Print argparse's arguments, one ``name: value`` line per argument.

    Arguments are listed in sorted order between a header and footer rule.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        parser.add_argument("name", default="Jonh", type=str, help="User name.")
        args = parser.parse_args()
        print_arguments(args)

    :param args: Input argparse.Namespace for printing.
    :type args: argparse.Namespace
    """
    print("----------- Configuration Arguments -----------")
    # dict.items() exists on both Python 2 and 3, so the six indirection is
    # unnecessary here.
    for arg, value in sorted(vars(args).items()):
        print("%s: %s" % (arg, value))
    print("------------------------------------------------")
def add_arguments(argname, type, default, help, argparser, **kwargs):
    """Register the option ``--<argname>`` on ``argparser``.

    ``bool`` is swapped for ``distutils.util.strtobool`` so that textual
    values can be given on the command line, and the default value is
    appended to the help text. Extra keyword arguments are forwarded to
    ``add_argument``.

    Usage:

    .. code-block:: python

        parser = argparse.ArgumentParser()
        add_argument("name", str, "Jonh", "User name.", parser)
        args = parser.parse_args()
    """
    if type == bool:
        type = distutils.util.strtobool
    option = "--" + argname
    help_text = help + ' Default: %(default)s.'
    argparser.add_argument(
        option, default=default, type=type, help=help_text, **kwargs)
def save_persistable_nodes(executor, dirname, graph):
    """
    Save the graph's persistable nodes into a directory.

    Nodes are de-duplicated by name, RAW and READER variables are skipped,
    and the rest are mirrored as variables of a scratch Program which is
    then handed to ``fluid.io.save_vars``.

    Args:
        executor(Executor): The executor to run for saving node values.
        dirname(str): The directory path.
        graph(IrGraph): All the required persistable nodes in the graph will be saved.
    """
    seen_names = set()
    unique_nodes = []
    for node in graph.all_persistable_nodes():
        node_name = cpt.to_text(node.name())
        if node_name in seen_names:
            continue
        seen_names.add(node_name)
        unique_nodes.append(node)

    scratch = Program()
    vars_to_save = []
    for node in unique_nodes:
        desc = node.var()
        is_raw = desc.type() == core.VarDesc.VarType.RAW
        if is_raw or desc.type() == core.VarDesc.VarType.READER:
            continue
        vars_to_save.append(scratch.global_block().create_var(
            name=desc.name(),
            shape=desc.shape(),
            dtype=desc.dtype(),
            type=desc.type(),
            lod_level=desc.lod_level(),
            persistable=desc.persistable()))
    fluid.io.save_vars(executor=executor, dirname=dirname, vars=vars_to_save)
def load_persistable_nodes(executor, dirname, graph):
    """
    Load persistable node values from the given directory by the executor.

    Persistable nodes of the graph are de-duplicated by name, RAW and READER
    typed variables are skipped, and only variables that have a matching file
    under ``dirname`` are passed to ``fluid.io.load_vars``. Variables without
    a file are reported through the module logger and left untouched.

    Args:
        executor(Executor): The executor to run for loading node values.
        dirname(str): The directory path.
        graph(IrGraph): All the required persistable nodes in the graph will be loaded.
    """
    persistable_node_names = set()
    persistable_nodes = []
    for node in graph.all_persistable_nodes():
        name = cpt.to_text(node.name())
        if name not in persistable_node_names:
            persistable_node_names.add(name)
            persistable_nodes.append(node)

    program = Program()
    var_list = []
    for node in persistable_nodes:
        var_desc = node.var()
        if var_desc.type() == core.VarDesc.VarType.RAW or \
                var_desc.type() == core.VarDesc.VarType.READER:
            continue
        var = program.global_block().create_var(
            name=var_desc.name(),
            shape=var_desc.shape(),
            dtype=var_desc.dtype(),
            type=var_desc.type(),
            lod_level=var_desc.lod_level(),
            persistable=var_desc.persistable())
        # Only request variables that actually have a file in ``dirname``.
        if os.path.exists(os.path.join(dirname, var.name)):
            var_list.append(var)
        else:
            _logger.info("Cannot find the var %s!!!" % (node.name()))
    fluid.io.load_vars(executor=executor, dirname=dirname, vars=var_list)
...@@ -23,6 +23,8 @@ import controller_client ...@@ -23,6 +23,8 @@ import controller_client
from controller_client import * from controller_client import *
import lock_utils import lock_utils
from lock_utils import * from lock_utils import *
import cached_reader as cached_reader_module
from cached_reader import *
__all__ = [] __all__ = []
__all__ += controller.__all__ __all__ += controller.__all__
...@@ -30,3 +32,4 @@ __all__ += sa_controller.__all__ ...@@ -30,3 +32,4 @@ __all__ += sa_controller.__all__
__all__ += controller_server.__all__ __all__ += controller_server.__all__
__all__ += controller_client.__all__ __all__ += controller_client.__all__
__all__ += lock_utils.__all__ __all__ += lock_utils.__all__
__all__ += cached_reader_module.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import numpy as np
from .log_helper import get_logger
# Names exported by ``from <this module> import *``.
__all__ = ['cached_reader']
# Module-level logger, created with the ``get_logger`` helper from
# ``.log_helper`` and set to INFO level.
_logger = get_logger(__name__, level=logging.INFO)
def cached_reader(reader, sampled_rate, cache_path, cached_id):
    """
    Sample partial data from reader and cache them into local file system.

    The returned reader works in two modes, decided each time it is called:

    * If ``<cache_path>/<cached_id>`` already exists, the batches named in its
      ``list`` index file are loaded with ``np.load`` and yielded in order.
    * Otherwise the directory is created and ``reader`` is consumed. The first
      batch is always kept; every later batch is kept when a uniform random
      draw is below ``sampled_rate``. Kept batches are written with
      ``np.save``, recorded in the ``list`` index file and yielded.

    Args:
        reader: Iterative data source.
        sampled_rate(float): The sampled rate used to sample partial data for evaluation. None means using all data in reader.
        cache_path(str): The path to cache the sampled data.
        cached_id(int): The id of dataset sampled. Evaluations with same cached_id use the same sampled dataset.

    Returns:
        function: A reader creator that yields the sampled (or cached) batches.
    """
    # NOTE: this seeds numpy's global random generator as a side effect.
    np.random.seed(cached_id)
    cache_path = os.path.join(cache_path, str(cached_id))
    _logger.debug('read data from: {}'.format(cache_path))

    def s_reader():
        list_path = os.path.join(cache_path, "list")
        if os.path.isdir(cache_path):
            # ``with`` closes the index file even if iteration stops early.
            with open(list_path) as list_file:
                for file_name in list_file:
                    yield np.load(
                        os.path.join(cache_path, file_name.strip()),
                        allow_pickle=True)
        else:
            os.makedirs(cache_path)
            batch = 0
            # Closing the index file flushes it, so a later call that reads
            # the cache sees every recorded batch.
            with open(list_path, 'w') as list_file:
                for data in reader():
                    keep = (batch == 0 or sampled_rate is None or
                            np.random.uniform() < sampled_rate)
                    if keep:
                        batch_name = 'batch' + str(batch)
                        np.save(os.path.join(cache_path, batch_name), data)
                        list_file.write(batch_name + '.npy\n')
                        batch += 1
                        yield data

    return s_reader
...@@ -38,7 +38,7 @@ class ControllerClient(object): ...@@ -38,7 +38,7 @@ class ControllerClient(object):
self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._key = key self._key = key
def update(self, tokens, reward): def update(self, tokens, reward, iter):
""" """
Update the controller according to latest tokens and reward. Update the controller according to latest tokens and reward.
Args: Args:
...@@ -48,11 +48,13 @@ class ControllerClient(object): ...@@ -48,11 +48,13 @@ class ControllerClient(object):
socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_client.connect((self.server_ip, self.server_port)) socket_client.connect((self.server_ip, self.server_port))
tokens = ",".join([str(token) for token in tokens]) tokens = ",".join([str(token) for token in tokens])
socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward) socket_client.send("{}\t{}\t{}\t{}".format(self._key, tokens, reward,
.encode()) iter).encode())
tokens = socket_client.recv(1024).decode() response = socket_client.recv(1024).decode()
tokens = [int(token) for token in tokens.strip("\n").split(",")] if response.strip('\n').split("\t") == "ok":
return tokens return True
else:
return False
def next_tokens(self): def next_tokens(self):
""" """
......
...@@ -51,23 +51,8 @@ class ControllerServer(object): ...@@ -51,23 +51,8 @@ class ControllerServer(object):
self._port = address[1] self._port = address[1]
self._ip = address[0] self._ip = address[0]
self._key = key self._key = key
self._socket_file = "./controller_server.socket"
def start(self): def start(self):
open(self._socket_file, 'a').close()
socket_file = open(self._socket_file, 'r+')
lock(socket_file)
tid = socket_file.readline()
if tid == '':
_logger.info("start controller server...")
tid = self._start()
socket_file.write("tid: {}\nip: {}\nport: {}\n".format(
tid, self._ip, self._port))
_logger.info("started controller server...")
unlock(socket_file)
socket_file.close()
def _start(self):
self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket_server.bind(self._address) self._socket_server.bind(self._address)
self._socket_server.listen(self._max_client_num) self._socket_server.listen(self._max_client_num)
...@@ -82,7 +67,6 @@ class ControllerServer(object): ...@@ -82,7 +67,6 @@ class ControllerServer(object):
def close(self): def close(self):
"""Close the server.""" """Close the server."""
self._closed = True self._closed = True
os.remove(self._socket_file)
_logger.info("server closed!") _logger.info("server closed!")
def port(self): def port(self):
...@@ -109,20 +93,22 @@ class ControllerServer(object): ...@@ -109,20 +93,22 @@ class ControllerServer(object):
_logger.debug("recv message from {}: [{}]".format(addr, _logger.debug("recv message from {}: [{}]".format(addr,
message)) message))
messages = message.strip('\n').split("\t") messages = message.strip('\n').split("\t")
if (len(messages) < 3) or (messages[0] != self._key): if (len(messages) < 4) or (messages[0] != self._key):
_logger.debug("recv noise from {}: [{}]".format( _logger.debug("recv noise from {}: [{}]".format(
addr, message)) addr, message))
continue continue
tokens = messages[1] tokens = messages[1]
reward = messages[2] reward = messages[2]
iter = messages[3]
tokens = [int(token) for token in tokens.split(",")] tokens = [int(token) for token in tokens.split(",")]
self._controller.update(tokens, float(reward)) self._controller.update(tokens, float(reward), int(iter))
tokens = self._controller.next_tokens() response = "ok"
tokens = ",".join([str(token) for token in tokens]) conn.send(response.encode())
conn.send(tokens.encode())
_logger.debug("send message to {}: [{}]".format(addr, _logger.debug("send message to {}: [{}]".format(addr,
tokens)) tokens))
conn.close() conn.close()
except Exception, err:
_logger.error(err)
finally: finally:
self._socket_server.close() self._socket_server.close()
self.close() self.close()
...@@ -19,7 +19,7 @@ import logging ...@@ -19,7 +19,7 @@ import logging
__all__ = ['get_logger'] __all__ = ['get_logger']
def get_logger(name, level, fmt=None): def get_logger(name, level, fmt='%(asctime)s-%(levelname)s: %(message)s'):
""" """
Get logger from logging with given name, level and format without Get logger from logging with given name, level and format without
setting logging basicConfig. For setting basicConfig in paddle setting logging basicConfig. For setting basicConfig in paddle
...@@ -39,10 +39,10 @@ def get_logger(name, level, fmt=None): ...@@ -39,10 +39,10 @@ def get_logger(name, level, fmt=None):
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(level) logger.setLevel(level)
handler = logging.StreamHandler() handler = logging.StreamHandler()
if fmt: if fmt:
formatter = logging.Formatter(fmt=fmt) formatter = logging.Formatter(fmt=fmt)
handler.setFormatter(formatter) handler.setFormatter(formatter)
logger.addHandler(handler) logger.addHandler(handler)
logger.propagate = 0
return logger return logger
...@@ -32,7 +32,7 @@ class SAController(EvolutionaryController): ...@@ -32,7 +32,7 @@ class SAController(EvolutionaryController):
range_table=None, range_table=None,
reduce_rate=0.85, reduce_rate=0.85,
init_temperature=1024, init_temperature=1024,
max_iter_number=300, max_try_times=None,
init_tokens=None, init_tokens=None,
constrain_func=None): constrain_func=None):
"""Initialize. """Initialize.
...@@ -40,7 +40,7 @@ class SAController(EvolutionaryController): ...@@ -40,7 +40,7 @@ class SAController(EvolutionaryController):
range_table(list<int>): Range table. range_table(list<int>): Range table.
reduce_rate(float): The decay rate of temperature. reduce_rate(float): The decay rate of temperature.
init_temperature(float): Init temperature. init_temperature(float): Init temperature.
max_iter_number(int): max iteration number. max_try_times(int): max try times before get legal tokens.
init_tokens(list<int>): The initial tokens. init_tokens(list<int>): The initial tokens.
constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None. constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None.
""" """
...@@ -50,7 +50,7 @@ class SAController(EvolutionaryController): ...@@ -50,7 +50,7 @@ class SAController(EvolutionaryController):
len(self._range_table) == 2) len(self._range_table) == 2)
self._reduce_rate = reduce_rate self._reduce_rate = reduce_rate
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._max_iter_number = max_iter_number self._max_try_times = max_try_times
self._reward = -1 self._reward = -1
self._tokens = init_tokens self._tokens = init_tokens
self._constrain_func = constrain_func self._constrain_func = constrain_func
...@@ -65,14 +65,16 @@ class SAController(EvolutionaryController): ...@@ -65,14 +65,16 @@ class SAController(EvolutionaryController):
d[key] = self.__dict__[key] d[key] = self.__dict__[key]
return d return d
def update(self, tokens, reward): def update(self, tokens, reward, iter):
""" """
Update the controller according to latest tokens and reward. Update the controller according to latest tokens and reward.
Args: Args:
tokens(list<int>): The tokens generated in last step. tokens(list<int>): The tokens generated in last step.
reward(float): The reward of tokens. reward(float): The reward of tokens.
""" """
self._iter += 1 iter = int(iter)
if iter > self._iter:
self._iter = iter
temperature = self._init_temperature * self._reduce_rate**self._iter temperature = self._init_temperature * self._reduce_rate**self._iter
if (reward > self._reward) or (np.random.random() <= math.exp( if (reward > self._reward) or (np.random.random() <= math.exp(
(reward - self._reward) / temperature)): (reward - self._reward) / temperature)):
...@@ -96,12 +98,12 @@ class SAController(EvolutionaryController): ...@@ -96,12 +98,12 @@ class SAController(EvolutionaryController):
new_tokens = tokens[:] new_tokens = tokens[:]
index = int(len(self._range_table[0]) * np.random.random()) index = int(len(self._range_table[0]) * np.random.random())
new_tokens[index] = np.random.randint(self._range_table[0][index], new_tokens[index] = np.random.randint(self._range_table[0][index],
self._range_table[1][index] + 1) self._range_table[1][index])
_logger.debug("change index[{}] from {} to {}".format(index, tokens[ _logger.debug("change index[{}] from {} to {}".format(index, tokens[
index], new_tokens[index])) index], new_tokens[index]))
if self._constrain_func is None: if self._constrain_func is None or self._max_try_times is None:
return new_tokens return new_tokens
for _ in range(self._max_iter_number): for _ in range(self._max_try_times):
if not self._constrain_func(new_tokens): if not self._constrain_func(new_tokens):
index = int(len(self._range_table[0]) * np.random.random()) index = int(len(self._range_table[0]) * np.random.random())
new_tokens = tokens[:] new_tokens = tokens[:]
......
...@@ -54,6 +54,9 @@ class VarWrapper(object): ...@@ -54,6 +54,9 @@ class VarWrapper(object):
""" """
return self._var.name return self._var.name
def __repr__(self):
return self._var.name
def shape(self): def shape(self):
""" """
Get the shape of the varibale. Get the shape of the varibale.
...@@ -131,6 +134,11 @@ class OpWrapper(object): ...@@ -131,6 +134,11 @@ class OpWrapper(object):
""" """
return self._op.type return self._op.type
def __repr__(self):
return "op[id: {}, type: {}; inputs: {}]".format(self.idx(),
self.type(),
self.all_inputs())
def is_bwd_op(self): def is_bwd_op(self):
""" """
Whether this operator is backward op. Whether this operator is backward op.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
def merge(teacher_program,
          student_program,
          data_name_map,
          place,
          teacher_scope=None,
          student_scope=None,
          name_prefix='teacher_'):
    """
    Merge teacher program into student program and add a uniform prefix to the
    names of all vars in teacher program.

    The teacher program is first cloned with ``for_test=True``. Every teacher
    var except ``feed``/``fetch`` is then renamed, its tensor copied into the
    student scope, and the var cloned into the student program. Finally all
    teacher ops except ``feed``/``fetch`` are appended to the student
    program's global block.

    Args:
        teacher_program(Program): The input teacher model paddle program
        student_program(Program): The input student model paddle program;
                                  it is modified in place.
        data_name_map(dict): Describe the mapping between the teacher var name
                             and the student var name. Teacher vars listed
                             here are renamed to the mapped name instead of
                             receiving ``name_prefix``; a var mapped to its
                             own name is left unchanged.
        place(fluid.CPUPlace()|fluid.CUDAPlace(N)): This parameter represents
                                                    paddle run on which device.
        teacher_scope(Scope): The input teacher scope. None means the global
                              scope at call time. default: None
        student_scope(Scope): The input student scope. None means the global
                              scope at call time. default: None
        name_prefix(str): Name prefix added for all vars of the teacher program.
    Return(Program): Merged program (the same object as ``student_program``).
    """
    # Resolve default scopes when called, not when this module is imported,
    # so that an active ``fluid.scope_guard`` is respected.
    if teacher_scope is None:
        teacher_scope = fluid.global_scope()
    if student_scope is None:
        student_scope = fluid.global_scope()

    teacher_program = teacher_program.clone(for_test=True)
    for teacher_var in teacher_program.list_vars():
        skip_rename = False
        if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
            if teacher_var.name in data_name_map:
                new_name = data_name_map[teacher_var.name]
                if new_name == teacher_var.name:
                    skip_rename = True
            else:
                new_name = name_prefix + teacher_var.name
            if not skip_rename:
                # scope var rename: copy the tensor under the new name
                scope_var = teacher_scope.var(teacher_var.name).get_tensor()
                renamed_scope_var = teacher_scope.var(new_name).get_tensor()
                renamed_scope_var.set(np.array(scope_var), place)

                # program var rename
                teacher_program.global_block()._rename_var(teacher_var.name,
                                                           new_name)

    for teacher_var in teacher_program.list_vars():
        if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
            # student scope add var
            student_scope_var = student_scope.var(
                teacher_var.name).get_tensor()
            teacher_scope_var = teacher_scope.var(
                teacher_var.name).get_tensor()
            student_scope_var.set(np.array(teacher_scope_var), place)

            # student program add var
            new_var = student_program.global_block()._clone_variable(
                teacher_var, force_persistable=False)
            # Teacher vars are marked so that no gradient flows into them.
            new_var.stop_gradient = True

    for block in teacher_program.blocks:
        for op in block.ops:
            if op.type != 'feed' and op.type != 'fetch':
                inputs = {}
                outputs = {}
                attrs = {}
                for input_name in op.input_names:
                    inputs[input_name] = [
                        block.var(in_var_name)
                        for in_var_name in op.input(input_name)
                    ]
                for output_name in op.output_names:
                    outputs[output_name] = [
                        block.var(out_var_name)
                        for out_var_name in op.output(output_name)
                    ]
                for attr_name in op.attr_names:
                    attrs[attr_name] = op.attr(attr_name)
                student_program.global_block().append_op(
                    type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)
    return student_program
def fsp_loss(teacher_var1_name, teacher_var2_name, student_var1_name,
             student_var2_name, program=None):
    """
    Combine variables from student model and teacher model by fsp-loss.

    The loss is the mean of the squared difference between the student FSP
    matrix and the teacher FSP matrix.

    Args:
        teacher_var1_name(str): The name of teacher_var1.
        teacher_var2_name(str): The name of teacher_var2. Except for the
                                second dimension, all other dimensions should
                                be consistent with teacher_var1.
        student_var1_name(str): The name of student_var1.
        student_var2_name(str): The name of student_var2. Except for the
                                second dimension, all other dimensions should
                                be consistent with student_var1.
        program(Program): The input distiller program. None means
                          fluid.default_main_program() at call time.
                          default: None
    Return(Variable): fsp distiller loss.
    """
    # Look the default program up when called so that an enclosing
    # ``fluid.program_guard`` is honoured instead of the import-time program.
    if program is None:
        program = fluid.default_main_program()
    teacher_var1 = program.global_block().var(teacher_var1_name)
    teacher_var2 = program.global_block().var(teacher_var2_name)
    student_var1 = program.global_block().var(student_var1_name)
    student_var2 = program.global_block().var(student_var2_name)
    teacher_fsp_matrix = fluid.layers.fsp_matrix(teacher_var1, teacher_var2)
    student_fsp_matrix = fluid.layers.fsp_matrix(student_var1, student_var2)
    return fluid.layers.reduce_mean(
        fluid.layers.square(student_fsp_matrix - teacher_fsp_matrix))
def l2_loss(teacher_var_name, student_var_name, program=None):
    """
    Combine variables from student model and teacher model by l2-loss.

    The loss is the mean of the squared element-wise difference between the
    student variable and the teacher variable.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. None means
                          fluid.default_main_program() at call time.
                          default: None
    Return(Variable): l2 distiller loss.
    """
    # Look the default program up when called so that an enclosing
    # ``fluid.program_guard`` is honoured instead of the import-time program.
    if program is None:
        program = fluid.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    return fluid.layers.reduce_mean(
        fluid.layers.square(student_var - teacher_var))
def soft_label_loss(teacher_var_name,
                    student_var_name,
                    program=None,
                    teacher_temperature=1.,
                    student_temperature=1.):
    """
    Combine variables from student model and teacher model by soft-label-loss.

    Both variables are divided by their temperature and passed through
    softmax; the loss is the mean soft-label cross entropy of the student
    distribution against the teacher distribution. The teacher distribution
    has ``stop_gradient`` set.

    Args:
        teacher_var_name(str): The name of teacher_var.
        student_var_name(str): The name of student_var.
        program(Program): The input distiller program. None means
                          fluid.default_main_program() at call time.
                          default: None
        teacher_temperature(float): Temperature used to divide
            teacher_feature_map before softmax. default: 1.0
        student_temperature(float): Temperature used to divide
            student_feature_map before softmax. default: 1.0

    Return(Variable): soft label distiller loss.
    """
    # Look the default program up when called so that an enclosing
    # ``fluid.program_guard`` is honoured instead of the import-time program.
    if program is None:
        program = fluid.default_main_program()
    student_var = program.global_block().var(student_var_name)
    teacher_var = program.global_block().var(teacher_var_name)
    student_var = fluid.layers.softmax(student_var / student_temperature)
    teacher_var = fluid.layers.softmax(teacher_var / teacher_temperature)
    teacher_var.stop_gradient = True
    return fluid.layers.reduce_mean(
        fluid.layers.cross_entropy(
            student_var, teacher_var, soft_label=True))
def loss(loss_func, program=None, **kwargs):
    """
    Combine variables from student model and teacher model by self defined loss.

    Every keyword argument whose value is a ``str`` is treated as a variable
    name and replaced by the variable of that name from the program's global
    block; all other keyword arguments are passed through unchanged.

    Args:
        loss_func(function): The user self defined loss function.
        program(Program): The input distiller program. None means
                          fluid.default_main_program() at call time.
                          default: None
        kwargs: Arguments forwarded to ``loss_func`` by keyword.
    Return(Variable): self defined distiller loss.
    """
    # Look the default program up when called so that an enclosing
    # ``fluid.program_guard`` is honoured instead of the import-time program.
    if program is None:
        program = fluid.default_main_program()
    func_parameters = {}
    for arg_name, arg_value in kwargs.items():
        if isinstance(arg_value, str):
            func_parameters[arg_name] = program.global_block().var(arg_value)
        else:
            func_parameters[arg_name] = arg_value
    return loss_func(**func_parameters)
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import socket import socket
import logging import logging
import numpy as np import numpy as np
import hashlib
import paddle.fluid as fluid import paddle.fluid as fluid
from ..core import VarWrapper, OpWrapper, GraphWrapper from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController from ..common import SAController
...@@ -33,97 +34,75 @@ _logger = get_logger(__name__, level=logging.INFO) ...@@ -33,97 +34,75 @@ _logger = get_logger(__name__, level=logging.INFO)
class SANAS(object): class SANAS(object):
def __init__(self, def __init__(self,
configs, configs,
max_flops=None, server_addr=("", 8881),
max_latency=None,
server_addr=("", 0),
init_temperature=100, init_temperature=100,
reduce_rate=0.85, reduce_rate=0.85,
max_try_number=300,
max_client_num=10,
search_steps=300, search_steps=300,
key="sa_nas", key="sa_nas",
is_server=True): is_server=False):
""" """
Search a group of ratios used to prune program. Search a group of ratios used to prune program.
Args: Args:
configs(list<tuple>): A list of search space configuration with format (key, input_size, output_size, block_num). configs(list<tuple>): A list of search space configuration with format (key, input_size, output_size, block_num).
`key` is the name of search space with data type str. `input_size` and `output_size` are `key` is the name of search space with data type str. `input_size` and `output_size` are
input size and output size of searched sub-network. `block_num` is the number of blocks in searched network. input size and output size of searched sub-network. `block_num` is the number of blocks in searched network.
max_flops(int): The max flops of searched network. None means no constrains. Default: None.
max_latency(float): The max latency of searched network. None means no constrains. Default: None.
server_addr(tuple): A tuple of server ip and server port for controller server. server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy. init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy. reduce_rate(float): The decay rate used in simulated annealing search strategy.
max_try_number(int): The max number of trying to generate legal tokens.
max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching. search_steps(int): The steps of searching.
key(str): Identity used in communication between controller server and clients. key(str): Identity used in communication between controller server and clients.
is_server(bool): Whether current host is controller server. Default: True. is_server(bool): Whether current host is controller server. Default: True.
""" """
if not is_server:
assert server_addr[
0] != "", "You should set the IP and port of server when is_server is False."
self._reduce_rate = reduce_rate self._reduce_rate = reduce_rate
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._max_try_number = max_try_number
self._is_server = is_server self._is_server = is_server
self._max_flops = max_flops
self._max_latency = max_latency
self._configs = configs self._configs = configs
self._key = hashlib.md5(str(self._configs)).hexdigest()
factory = SearchSpaceFactory()
self._search_space = factory.get_search_space(configs)
init_tokens = self._search_space.init_tokens()
range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table)
print range_table
controller = SAController(range_table, self._reduce_rate,
self._init_temperature, self._max_try_number,
init_tokens, self._constrain_func)
server_ip, server_port = server_addr server_ip, server_port = server_addr
if server_ip == None or server_ip == "": if server_ip == None or server_ip == "":
server_ip = self._get_host_ip() server_ip = self._get_host_ip()
self._controller_server = ControllerServer( factory = SearchSpaceFactory()
controller=controller, self._search_space = factory.get_search_space(configs)
address=(server_ip, server_port),
max_client_num=max_client_num,
search_steps=search_steps,
key=key)
# create controller server # create controller server
if self._is_server: if self._is_server:
init_tokens = self._search_space.init_tokens()
range_table = self._search_space.range_table()
range_table = (len(range_table) * [0], range_table)
_logger.info("range table: {}".format(range_table))
controller = SAController(
range_table,
self._reduce_rate,
self._init_temperature,
max_try_times=None,
init_tokens=init_tokens,
constrain_func=None)
max_client_num = 100
self._controller_server = ControllerServer(
controller=controller,
address=(server_ip, server_port),
max_client_num=max_client_num,
search_steps=search_steps,
key=self._key)
self._controller_server.start() self._controller_server.start()
server_port = self._controller_server.port()
self._controller_client = ControllerClient( self._controller_client = ControllerClient(
self._controller_server.ip(), server_ip, server_port, key=self._key)
self._controller_server.port(),
key=key)
self._iter = 0 self._iter = 0
def _get_host_ip(self): def _get_host_ip(self):
return socket.gethostbyname(socket.gethostname()) return socket.gethostbyname(socket.gethostname())
def _constrain_func(self, tokens): def tokens2arch(self, tokens):
if (self._max_flops is None) and (self._max_latency is None): return self._search_space.token2arch(self.tokens)
return True
archs = self._search_space.token2arch(tokens)
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
i = 0
for config, arch in zip(self._configs, archs):
input_size = config[1]["input_size"]
input = fluid.data(
name="data_{}".format(i),
shape=[None, 3, input_size, input_size],
dtype="float32")
output = arch(input)
i += 1
return flops(main_program) < self._max_flops
def next_archs(self): def next_archs(self):
""" """
...@@ -140,6 +119,9 @@ class SANAS(object): ...@@ -140,6 +119,9 @@ class SANAS(object):
Return reward of current searched network. Return reward of current searched network.
Args: Args:
score(float): The score of current searched network. score(float): The score of current searched network.
Returns:
bool: True means updating successfully while false means failure.
""" """
self._controller_client.update(self._current_tokens, score)
self._iter += 1 self._iter += 1
return self._controller_client.update(self._current_tokens, score,
self._iter)
...@@ -39,6 +39,7 @@ class CombineSearchSpace(object): ...@@ -39,6 +39,7 @@ class CombineSearchSpace(object):
for config_list in config_lists: for config_list in config_lists:
key, config = config_list key, config = config_list
self.spaces.append(self._get_single_search_space(key, config)) self.spaces.append(self._get_single_search_space(key, config))
self.init_tokens()
def _get_single_search_space(self, key, config): def _get_single_search_space(self, key, config):
""" """
...@@ -51,9 +52,11 @@ class CombineSearchSpace(object): ...@@ -51,9 +52,11 @@ class CombineSearchSpace(object):
model space(class) model space(class)
""" """
cls = SEARCHSPACE.get(key) cls = SEARCHSPACE.get(key)
space = cls(config['input_size'], config['output_size'], block_mask = config['block_mask'] if 'block_mask' in config else None
config['block_num'], config['block_mask']) space = cls(config['input_size'],
config['output_size'],
config['block_num'],
block_mask=block_mask)
return space return space
def init_tokens(self): def init_tokens(self):
......
...@@ -32,10 +32,12 @@ class MobileNetV1Space(SearchSpaceBase): ...@@ -32,10 +32,12 @@ class MobileNetV1Space(SearchSpaceBase):
input_size, input_size,
output_size, output_size,
block_num, block_num,
block_mask,
scale=1.0, scale=1.0,
class_dim=1000): class_dim=1000):
super(MobileNetV1Space, self).__init__(input_size, output_size, super(MobileNetV1Space, self).__init__(input_size, output_size,
block_num) block_num, block_mask)
assert self.block_mask == None, 'MobileNetV1Space will use origin MobileNetV1 as seach space, so use input_size, output_size and block_num to search'
self.scale = scale self.scale = scale
self.class_dim = class_dim self.class_dim = class_dim
# self.head_num means the channel of first convolution # self.head_num means the channel of first convolution
......
...@@ -113,40 +113,69 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -113,40 +113,69 @@ class MobileNetV2Space(SearchSpaceBase):
if tokens is None: if tokens is None:
tokens = self.init_tokens() tokens = self.init_tokens()
print(tokens)
bottleneck_params_list = [] self.bottleneck_params_list = []
if self.block_num >= 1: if self.block_num >= 1:
bottleneck_params_list.append( self.bottleneck_params_list.append(
(1, self.head_num[tokens[0]], 1, 1, 3)) (1, self.head_num[tokens[0]], 1, 1, 3))
if self.block_num >= 2: if self.block_num >= 2:
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[1]], self.filter_num1[tokens[2]], (self.multiply[tokens[1]], self.filter_num1[tokens[2]],
self.repeat[tokens[3]], 2, self.k_size[tokens[4]])) self.repeat[tokens[3]], 2, self.k_size[tokens[4]]))
if self.block_num >= 3: if self.block_num >= 3:
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[5]], self.filter_num1[tokens[6]], (self.multiply[tokens[5]], self.filter_num1[tokens[6]],
self.repeat[tokens[7]], 2, self.k_size[tokens[8]])) self.repeat[tokens[7]], 2, self.k_size[tokens[8]]))
if self.block_num >= 4: if self.block_num >= 4:
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[9]], self.filter_num2[tokens[10]], (self.multiply[tokens[9]], self.filter_num2[tokens[10]],
self.repeat[tokens[11]], 2, self.k_size[tokens[12]])) self.repeat[tokens[11]], 2, self.k_size[tokens[12]]))
if self.block_num >= 5: if self.block_num >= 5:
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[13]], self.filter_num3[tokens[14]], (self.multiply[tokens[13]], self.filter_num3[tokens[14]],
self.repeat[tokens[15]], 2, self.k_size[tokens[16]])) self.repeat[tokens[15]], 2, self.k_size[tokens[16]]))
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[17]], self.filter_num4[tokens[18]], (self.multiply[tokens[17]], self.filter_num4[tokens[18]],
self.repeat[tokens[19]], 1, self.k_size[tokens[20]])) self.repeat[tokens[19]], 1, self.k_size[tokens[20]]))
if self.block_num >= 6: if self.block_num >= 6:
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[21]], self.filter_num5[tokens[22]], (self.multiply[tokens[21]], self.filter_num5[tokens[22]],
self.repeat[tokens[23]], 2, self.k_size[tokens[24]])) self.repeat[tokens[23]], 2, self.k_size[tokens[24]]))
bottleneck_params_list.append( self.bottleneck_params_list.append(
(self.multiply[tokens[25]], self.filter_num6[tokens[26]], (self.multiply[tokens[25]], self.filter_num6[tokens[26]],
self.repeat[tokens[27]], 1, self.k_size[tokens[28]])) self.repeat[tokens[27]], 1, self.k_size[tokens[28]]))
def net_arch(input): def _modify_bottle_params(output_stride=None):
if output_stride is not None and output_stride % 2 != 0:
raise Exception("output stride must to be even number")
if output_stride is None:
return
else:
stride = 2
for i, layer_setting in enumerate(self.bottleneck_params_list):
t, c, n, s, ks = layer_setting
stride = stride * s
if stride > output_stride:
s = 1
self.bottleneck_params_list[i] = (t, c, n, s, ks)
def net_arch(input,
end_points=None,
decode_points=None,
output_stride=None):
_modify_bottle_params(output_stride)
decode_ends = dict()
def check_points(count, points):
if points is None:
return False
else:
if isinstance(points, list):
return (True if count in points else False)
else:
return (True if count == points else False)
#conv1 #conv1
# all padding is 'SAME' in the conv2d, can compute the actual padding automatic. # all padding is 'SAME' in the conv2d, can compute the actual padding automatic.
input = conv_bn_layer( input = conv_bn_layer(
...@@ -157,14 +186,21 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -157,14 +186,21 @@ class MobileNetV2Space(SearchSpaceBase):
padding='SAME', padding='SAME',
act='relu6', act='relu6',
name='mobilenetv2_conv1_1') name='mobilenetv2_conv1_1')
layer_count = 1
if check_points(layer_count, decode_points):
decode_ends[layer_count] = input
if check_points(layer_count, end_points):
return input, decode_ends
# bottleneck sequences # bottleneck sequences
i = 1 i = 1
in_c = int(32 * self.scale) in_c = int(32 * self.scale)
for layer_setting in bottleneck_params_list: for layer_setting in self.bottleneck_params_list:
t, c, n, s, k = layer_setting t, c, n, s, k = layer_setting
i += 1 i += 1
input = self._invresi_blocks( #print(input)
input, depthwise_output = self._invresi_blocks(
input=input, input=input,
in_c=in_c, in_c=in_c,
t=t, t=t,
...@@ -174,6 +210,33 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -174,6 +210,33 @@ class MobileNetV2Space(SearchSpaceBase):
k=k, k=k,
name='mobilenetv2_conv' + str(i)) name='mobilenetv2_conv' + str(i))
in_c = int(c * self.scale) in_c = int(c * self.scale)
layer_count += 1
### decode_points and end_points means block num
if check_points(layer_count, decode_points):
decode_ends[layer_count] = depthwise_output
if check_points(layer_count, end_points):
return input, decode_ends
# last conv
input = conv_bn_layer(
input=input,
num_filters=int(1280 * self.scale)
if self.scale > 1.0 else 1280,
filter_size=1,
stride=1,
padding='SAME',
act='relu6',
name='mobilenetv2_conv' + str(i + 1))
input = fluid.layers.pool2d(
input=input,
pool_size=7,
pool_stride=1,
pool_type='avg',
global_pooling=True,
name='mobilenetv2_last_pool')
# if output_size is 1, add fc layer in the end # if output_size is 1, add fc layer in the end
if self.output_size == 1: if self.output_size == 1:
...@@ -248,6 +311,8 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -248,6 +311,8 @@ class MobileNetV2Space(SearchSpaceBase):
name=name + '_dwise', name=name + '_dwise',
use_cudnn=False) use_cudnn=False)
depthwise_output = bottleneck_conv
linear_out = conv_bn_layer( linear_out = conv_bn_layer(
input=bottleneck_conv, input=bottleneck_conv,
num_filters=num_filters, num_filters=num_filters,
...@@ -260,7 +325,7 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -260,7 +325,7 @@ class MobileNetV2Space(SearchSpaceBase):
out = linear_out out = linear_out
if ifshortcut: if ifshortcut:
out = self._shortcut(input=input, data_residual=out) out = self._shortcut(input=input, data_residual=out)
return out return out, depthwise_output
def _invresi_blocks(self, input, in_c, t, c, n, s, k, name=None): def _invresi_blocks(self, input, in_c, t, c, n, s, k, name=None):
"""Build inverted residual blocks. """Build inverted residual blocks.
...@@ -276,7 +341,7 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -276,7 +341,7 @@ class MobileNetV2Space(SearchSpaceBase):
Returns: Returns:
Variable, layers output. Variable, layers output.
""" """
first_block = self._inverted_residual_unit( first_block, depthwise_output = self._inverted_residual_unit(
input=input, input=input,
num_in_filter=in_c, num_in_filter=in_c,
num_filters=c, num_filters=c,
...@@ -290,7 +355,7 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -290,7 +355,7 @@ class MobileNetV2Space(SearchSpaceBase):
last_c = c last_c = c
for i in range(1, n): for i in range(1, n):
last_residual_block = self._inverted_residual_unit( last_residual_block, depthwise_output = self._inverted_residual_unit(
input=last_residual_block, input=last_residual_block,
num_in_filter=last_c, num_in_filter=last_c,
num_filters=c, num_filters=c,
...@@ -299,4 +364,4 @@ class MobileNetV2Space(SearchSpaceBase): ...@@ -299,4 +364,4 @@ class MobileNetV2Space(SearchSpaceBase):
filter_size=k, filter_size=k,
expansion_factor=t, expansion_factor=t,
name=name + '_' + str(i + 1)) name=name + '_' + str(i + 1))
return last_residual_block return last_residual_block, depthwise_output
...@@ -19,7 +19,9 @@ class SearchSpaceBase(object): ...@@ -19,7 +19,9 @@ class SearchSpaceBase(object):
"""Controller for Neural Architecture Search. """Controller for Neural Architecture Search.
""" """
def __init__(self, input_size, output_size, block_num, block_mask, *argss): def __init__(self, input_size, output_size, block_num, block_mask, *args):
"""init model config
"""
self.input_size = input_size self.input_size = input_size
self.output_size = output_size self.output_size = output_size
self.block_num = block_num self.block_num = block_num
......
...@@ -19,9 +19,15 @@ import controller_server ...@@ -19,9 +19,15 @@ import controller_server
from controller_server import * from controller_server import *
import controller_client import controller_client
from controller_client import * from controller_client import *
import sensitive_pruner
from sensitive_pruner import *
import sensitive
from sensitive import *
__all__ = [] __all__ = []
__all__ += pruner.__all__ __all__ += pruner.__all__
__all__ += auto_pruner.__all__ __all__ += auto_pruner.__all__
__all__ += controller_server.__all__ __all__ += controller_server.__all__
__all__ += controller_client.__all__ __all__ += controller_client.__all__
__all__ += sensitive_pruner.__all__
__all__ += sensitive.__all__
...@@ -42,7 +42,7 @@ class AutoPruner(object): ...@@ -42,7 +42,7 @@ class AutoPruner(object):
server_addr=("", 0), server_addr=("", 0),
init_temperature=100, init_temperature=100,
reduce_rate=0.85, reduce_rate=0.85,
max_try_number=300, max_try_times=300,
max_client_num=10, max_client_num=10,
search_steps=300, search_steps=300,
max_ratios=[0.9], max_ratios=[0.9],
...@@ -66,7 +66,7 @@ class AutoPruner(object): ...@@ -66,7 +66,7 @@ class AutoPruner(object):
server_addr(tuple): A tuple of server ip and server port for controller server. server_addr(tuple): A tuple of server ip and server port for controller server.
init_temperature(float): The init temperature used in simulated annealing search strategy. init_temperature(float): The init temperature used in simulated annealing search strategy.
reduce_rate(float): The decay rate used in simulated annealing search strategy. reduce_rate(float): The decay rate used in simulated annealing search strategy.
max_try_number(int): The max number of trying to generate legal tokens. max_try_times(int): The max number of trying to generate legal tokens.
max_client_num(int): The max number of connections of controller server. max_client_num(int): The max number of connections of controller server.
search_steps(int): The steps of searching. search_steps(int): The steps of searching.
max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`. max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`.
...@@ -88,7 +88,7 @@ class AutoPruner(object): ...@@ -88,7 +88,7 @@ class AutoPruner(object):
self._pruned_latency = pruned_latency self._pruned_latency = pruned_latency
self._reduce_rate = reduce_rate self._reduce_rate = reduce_rate
self._init_temperature = init_temperature self._init_temperature = init_temperature
self._max_try_number = max_try_number self._max_try_times = max_try_times
self._is_server = is_server self._is_server = is_server
self._range_table = self._get_range_table(min_ratios, max_ratios) self._range_table = self._get_range_table(min_ratios, max_ratios)
...@@ -96,8 +96,10 @@ class AutoPruner(object): ...@@ -96,8 +96,10 @@ class AutoPruner(object):
self._pruner = Pruner() self._pruner = Pruner()
if self._pruned_flops: if self._pruned_flops:
self._base_flops = flops(program) self._base_flops = flops(program)
_logger.info("AutoPruner - base flops: {};".format( self._max_flops = self._base_flops * (1 - self._pruned_flops)
self._base_flops)) _logger.info(
"AutoPruner - base flops: {}; pruned_flops: {}; max_flops: {}".
format(self._base_flops, self._pruned_flops, self._max_flops))
if self._pruned_latency: if self._pruned_latency:
self._base_latency = latency(program) self._base_latency = latency(program)
...@@ -106,9 +108,9 @@ class AutoPruner(object): ...@@ -106,9 +108,9 @@ class AutoPruner(object):
self, _program, self._params, self._pruned_flops, self, _program, self._params, self._pruned_flops,
self._pruned_latency) self._pruned_latency)
init_tokens = self._ratios2tokens(self._init_ratios) init_tokens = self._ratios2tokens(self._init_ratios)
_logger.info("range table: {}".format(self._range_table))
controller = SAController(self._range_table, self._reduce_rate, controller = SAController(self._range_table, self._reduce_rate,
self._init_temperature, self._max_try_number, self._init_temperature, self._max_try_times,
init_tokens, self._constrain_func) init_tokens, self._constrain_func)
server_ip, server_port = server_addr server_ip, server_port = server_addr
...@@ -143,10 +145,10 @@ class AutoPruner(object): ...@@ -143,10 +145,10 @@ class AutoPruner(object):
def _get_range_table(self, min_ratios, max_ratios): def _get_range_table(self, min_ratios, max_ratios):
assert isinstance(min_ratios, list) or isinstance(min_ratios, float) assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
assert isinstance(max_ratios, list) or isinstance(max_ratios, float) assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
min_ratios = min_ratios if isinstance(min_ratios, min_ratios = min_ratios if isinstance(
list) else [min_ratios] min_ratios, list) else [min_ratios] * len(self._params)
max_ratios = max_ratios if isinstance(max_ratios, max_ratios = max_ratios if isinstance(
list) else [max_ratios] max_ratios, list) else [max_ratios] * len(self._params)
min_tokens = self._ratios2tokens(min_ratios) min_tokens = self._ratios2tokens(min_ratios)
max_tokens = self._ratios2tokens(max_ratios) max_tokens = self._ratios2tokens(max_ratios)
return (min_tokens, max_tokens) return (min_tokens, max_tokens)
...@@ -160,10 +162,17 @@ class AutoPruner(object): ...@@ -160,10 +162,17 @@ class AutoPruner(object):
ratios, ratios,
place=self._place, place=self._place,
only_graph=True) only_graph=True)
return flops(pruned_program) < self._base_flops * ( current_flops = flops(pruned_program)
1 - self._pruned_flops) result = current_flops < self._max_flops
if not result:
def prune(self, program): _logger.info("Failed try ratios: {}; flops: {}; max_flops: {}".
format(ratios, current_flops, self._max_flops))
else:
_logger.info("Success try ratios: {}; flops: {}; max_flops: {}".
format(ratios, current_flops, self._max_flops))
return result
def prune(self, program, eval_program=None):
""" """
Prune program with latest tokens generated by controller. Prune program with latest tokens generated by controller.
Args: Args:
...@@ -178,10 +187,21 @@ class AutoPruner(object): ...@@ -178,10 +187,21 @@ class AutoPruner(object):
self._params, self._params,
self._current_ratios, self._current_ratios,
place=self._place, place=self._place,
only_graph=False,
param_backup=self._param_backup) param_backup=self._param_backup)
pruned_val_program = None
if eval_program is not None:
pruned_val_program = self._pruner.prune(
program,
self._scope,
self._params,
self._current_ratios,
place=self._place,
only_graph=True)
_logger.info("AutoPruner - pruned ratios: {}".format( _logger.info("AutoPruner - pruned ratios: {}".format(
self._current_ratios)) self._current_ratios))
return pruned_program return pruned_program, pruned_val_program
def reward(self, score): def reward(self, score):
""" """
...@@ -192,7 +212,7 @@ class AutoPruner(object): ...@@ -192,7 +212,7 @@ class AutoPruner(object):
self._restore(self._scope) self._restore(self._scope)
self._param_backup = {} self._param_backup = {}
tokens = self._ratios2tokens(self._current_ratios) tokens = self._ratios2tokens(self._current_ratios)
self._controller_client.update(tokens, score) self._controller_client.update(tokens, score, self._iter)
self._iter += 1 self._iter += 1
def _restore(self, scope): def _restore(self, scope):
......
...@@ -12,13 +12,17 @@ ...@@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import copy import copy
from ..core import VarWrapper, OpWrapper, GraphWrapper from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import get_logger
__all__ = ["Pruner"] __all__ = ["Pruner"]
_logger = get_logger(__name__, level=logging.INFO)
class Pruner(): class Pruner():
def __init__(self, criterion="l1_norm"): def __init__(self, criterion="l1_norm"):
...@@ -69,6 +73,10 @@ class Pruner(): ...@@ -69,6 +73,10 @@ class Pruner():
only_graph=only_graph, only_graph=only_graph,
param_backup=param_backup, param_backup=param_backup,
param_shape_backup=param_shape_backup) param_shape_backup=param_shape_backup)
for op in graph.ops():
if op.type() == 'depthwise_conv2d' or op.type(
) == 'depthwise_conv2d_grad':
op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
return graph.program return graph.program
def _prune_filters_by_ratio(self, def _prune_filters_by_ratio(self,
...@@ -94,27 +102,49 @@ class Pruner(): ...@@ -94,27 +102,49 @@ class Pruner():
""" """
if params[0].name() in self.pruned_list[0]: if params[0].name() in self.pruned_list[0]:
return return
param_t = scope.find_var(params[0].name()).get_tensor()
pruned_idx = self._cal_pruned_idx( if only_graph:
params[0].name(), np.array(param_t), ratio, axis=0) pruned_num = int(round(params[0].shape()[0] * ratio))
for param in params: for param in params:
assert isinstance(param, VarWrapper) ori_shape = param.shape()
param_t = scope.find_var(param.name()).get_tensor() if param_backup is not None and (
if param_backup is not None and (param.name() not in param_backup): param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(np.array(param_t)) param_backup[param.name()] = copy.deepcopy(ori_shape)
pruned_param = self._prune_tensor( new_shape = list(ori_shape)
np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy) new_shape[0] -= pruned_num
if not only_graph: param.set_shape(new_shape)
_logger.debug("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[0].append(param.name())
return range(pruned_num)
else:
param_t = scope.find_var(params[0].name()).get_tensor()
pruned_idx = self._cal_pruned_idx(
params[0].name(), np.array(param_t), ratio, axis=0)
for param in params:
assert isinstance(param, VarWrapper)
param_t = scope.find_var(param.name()).get_tensor()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(
np.array(param_t))
pruned_param = self._prune_tensor(
np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
param_t.set(pruned_param, place) param_t.set(pruned_param, place)
ori_shape = param.shape() ori_shape = param.shape()
if param_shape_backup is not None and ( if param_shape_backup is not None and (
param.name() not in param_shape_backup): param.name() not in param_shape_backup):
param_shape_backup[param.name()] = copy.deepcopy(param.shape()) param_shape_backup[param.name()] = copy.deepcopy(
new_shape = list(param.shape()) param.shape())
new_shape[0] = pruned_param.shape[0] new_shape = list(param.shape())
param.set_shape(new_shape) new_shape[0] = pruned_param.shape[0]
self.pruned_list[0].append(param.name()) param.set_shape(new_shape)
return pruned_idx _logger.debug("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[0].append(param.name())
return pruned_idx
def _prune_parameter_by_idx(self, def _prune_parameter_by_idx(self,
scope, scope,
...@@ -141,24 +171,44 @@ class Pruner(): ...@@ -141,24 +171,44 @@ class Pruner():
""" """
if params[0].name() in self.pruned_list[pruned_axis]: if params[0].name() in self.pruned_list[pruned_axis]:
return return
for param in params:
assert isinstance(param, VarWrapper) if only_graph:
param_t = scope.find_var(param.name()).get_tensor() pruned_num = len(pruned_idx)
if param_backup is not None and (param.name() not in param_backup): for param in params:
param_backup[param.name()] = copy.deepcopy(np.array(param_t)) ori_shape = param.shape()
pruned_param = self._prune_tensor( if param_backup is not None and (
np.array(param_t), pruned_idx, pruned_axis, lazy=lazy) param.name() not in param_backup):
if not only_graph: param_backup[param.name()] = copy.deepcopy(ori_shape)
new_shape = list(ori_shape)
new_shape[pruned_axis] -= pruned_num
param.set_shape(new_shape)
_logger.debug("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[pruned_axis].append(param.name())
else:
for param in params:
assert isinstance(param, VarWrapper)
param_t = scope.find_var(param.name()).get_tensor()
if param_backup is not None and (
param.name() not in param_backup):
param_backup[param.name()] = copy.deepcopy(
np.array(param_t))
pruned_param = self._prune_tensor(
np.array(param_t), pruned_idx, pruned_axis, lazy=lazy)
param_t.set(pruned_param, place) param_t.set(pruned_param, place)
ori_shape = param.shape() ori_shape = param.shape()
if param_shape_backup is not None and ( if param_shape_backup is not None and (
param.name() not in param_shape_backup): param.name() not in param_shape_backup):
param_shape_backup[param.name()] = copy.deepcopy(param.shape()) param_shape_backup[param.name()] = copy.deepcopy(
new_shape = list(param.shape()) param.shape())
new_shape[pruned_axis] = pruned_param.shape[pruned_axis] new_shape = list(param.shape())
param.set_shape(new_shape) new_shape[pruned_axis] = pruned_param.shape[pruned_axis]
self.pruned_list[pruned_axis].append(param.name()) param.set_shape(new_shape)
_logger.debug("prune [{}] from {} to {}".format(param.name(
), ori_shape, new_shape))
self.pruned_list[pruned_axis].append(param.name())
def _forward_search_related_op(self, graph, param): def _forward_search_related_op(self, graph, param):
""" """
...@@ -478,19 +528,24 @@ class Pruner(): ...@@ -478,19 +528,24 @@ class Pruner():
Returns: Returns:
list<VarWrapper>: A list of operators. list<VarWrapper>: A list of operators.
""" """
_logger.debug("######################search: {}######################".
format(op_node))
visited = [op_node.idx()] visited = [op_node.idx()]
stack = [] stack = []
brothers = [] brothers = []
for op in graph.next_ops(op_node): for op in graph.next_ops(op_node):
if (op.type() != 'conv2d') and (op.type() != 'fc') and ( if ("conv2d" not in op.type()) and (op.type() != 'fc') and (
not op.is_bwd_op()): not op.is_bwd_op()) and (not op.is_opt_op()):
stack.append(op) stack.append(op)
visited.append(op.idx()) visited.append(op.idx())
while len(stack) > 0: while len(stack) > 0:
top_op = stack.pop() top_op = stack.pop()
for parent in graph.pre_ops(top_op): for parent in graph.pre_ops(top_op):
if parent.idx() not in visited and (not parent.is_bwd_op()): if parent.idx() not in visited and (
if ((parent.type() == 'conv2d') or not parent.is_bwd_op()) and (not parent.is_opt_op()):
_logger.debug("----------go back from {} to {}----------".
format(top_op, parent))
if (('conv2d' in parent.type()) or
(parent.type() == 'fc')): (parent.type() == 'fc')):
brothers.append(parent) brothers.append(parent)
else: else:
...@@ -498,11 +553,16 @@ class Pruner(): ...@@ -498,11 +553,16 @@ class Pruner():
visited.append(parent.idx()) visited.append(parent.idx())
for child in graph.next_ops(top_op): for child in graph.next_ops(top_op):
if (child.type() != 'conv2d') and (child.type() != 'fc') and ( if ('conv2d' not in child.type()
) and (child.type() != 'fc') and (
child.idx() not in visited) and ( child.idx() not in visited) and (
not child.is_bwd_op()): not child.is_bwd_op()) and (not child.is_opt_op()):
stack.append(child) stack.append(child)
visited.append(child.idx()) visited.append(child.idx())
_logger.debug("brothers: {}".format(brothers))
_logger.debug(
"######################Finish search######################".format(
op_node))
return brothers return brothers
def _cal_pruned_idx(self, name, param, ratio, axis): def _cal_pruned_idx(self, name, param, ratio, axis):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import logging
import pickle
import numpy as np
import paddle.fluid as fluid
from ..core import GraphWrapper
from ..common import get_logger
from ..analysis import flops
from ..prune import Pruner
# Module-level logger, created with its level set to logging.INFO.
_logger = get_logger(__name__, level=logging.INFO)
# Names exported by `from <this module> import *`.
__all__ = ["sensitivity", "flops_sensitivity"]
def sensitivity(program,
                place,
                param_names,
                eval_func,
                sensitivities_file=None,
                step_size=0.2,
                max_pruned_times=None):
    """
    Measure the relative metric loss caused by pruning each parameter alone.

    Every parameter in the table is pruned on its own at the ratios
    `step_size`, `2 * step_size`, ... (all ratios below 1). After each
    pruning the program is scored with `eval_func`, the relative loss
    against the unpruned baseline is recorded, the table is saved and the
    backed-up parameter values are written back into the scope.

    Args:
        program: The program to analyse. It is wrapped in a `GraphWrapper`.
        place: The place handed to the pruner and used to restore tensors.
        param_names(list<str>): Names of the parameters to analyse.
        eval_func(function): Callback taking a program and returning its score.
        sensitivities_file(str): File the table is loaded from and saved to.
            Ratios already present in the loaded table are not recomputed.
        step_size(float): Increment between two successive pruned ratios.
        max_pruned_times(int): Upper bound on the number of ratios visited
            per parameter (skipped ratios count too). `None` means no bound.

    Returns:
        dict: Maps a parameter name to a dict with the keys
        'pruned_percent', 'loss' and 'size'.
    """
    scope = fluid.global_scope()
    graph = GraphWrapper(program)
    sensitivities = _load_sensitivities(sensitivities_file)

    # Create an empty record for every requested parameter not yet tracked.
    for name in param_names:
        if name in sensitivities:
            continue
        size = graph.var(name).shape()[0]
        sensitivities[name] = {
            'pruned_percent': [],
            'loss': [],
            'size': size
        }

    # The baseline is evaluated lazily, only once something must be computed.
    baseline = None
    for name, record in sensitivities.items():
        ratio = step_size
        pruned_times = 0
        while ratio < 1:
            if max_pruned_times is not None and pruned_times >= max_pruned_times:
                break
            # Rounding keeps accumulated float error out of the stored ratios.
            ratio = round(ratio, 2)
            if ratio in record['pruned_percent']:
                _logger.debug('{}, {} has computed.'.format(name, ratio))
            else:
                if baseline is None:
                    baseline = eval_func(graph.program)

                param_backup = {}
                pruner = Pruner()
                _logger.info("sensitive - param: {}; ratios: {}".format(
                    name, ratio))
                pruned_program = pruner.prune(
                    program=graph.program,
                    scope=scope,
                    params=[name],
                    ratios=[ratio],
                    place=place,
                    lazy=True,
                    only_graph=False,
                    param_backup=param_backup)
                pruned_metric = eval_func(pruned_program)
                loss = (baseline - pruned_metric) / baseline
                _logger.info("pruned param: {}; {}; loss={}".format(
                    name, ratio, loss))

                record['pruned_percent'].append(ratio)
                record['loss'].append(loss)
                _save_sensitivities(sensitivities, sensitivities_file)

                # Write the backed-up values back so the next measurement
                # starts from the unpruned parameters.
                for param_name, backup in param_backup.items():
                    scope.find_var(param_name).get_tensor().set(backup, place)
            ratio += step_size
            pruned_times += 1
    return sensitivities
def flops_sensitivity(program,
                      place,
                      param_names,
                      eval_func,
                      sensitivities_file=None,
                      pruned_flops_rate=0.1):
    """
    Measure the metric loss of pruning a fixed share of FLOPs per parameter.

    For every parameter in the table, the FLOPs attributed to it are
    estimated by a graph-only pruning at ratio 0.5 (the FLOPs difference is
    doubled). From that the ratio removing `pruned_flops_rate` of the base
    FLOPs is derived. If that ratio would remove at least all channels it is
    set to 1 and a loss of 1 is recorded without evaluating; otherwise the
    parameter is pruned, scored with `eval_func`, and restored.

    Args:
        program: The program to analyse. It is wrapped in a `GraphWrapper`.
        place: The place handed to the pruner and used to restore tensors.
        param_names(list<str>): Names of the parameters to analyse.
        eval_func(function): Callback taking a program and returning its score.
        sensitivities_file(str): File the table is loaded from and saved to.
            Parameters that already have a recorded ratio are skipped.
        pruned_flops_rate(float): Share of the base FLOPs to remove per
            parameter. Must be smaller than `1.0 / len(param_names)`.

    Returns:
        dict: Maps a parameter name to a dict with the keys
        'pruned_percent', 'loss' and 'size'.
    """
    assert (1.0 / len(param_names) > pruned_flops_rate)

    scope = fluid.global_scope()
    graph = GraphWrapper(program)
    sensitivities = _load_sensitivities(sensitivities_file)

    # Create an empty record for every requested parameter not yet tracked.
    for name in param_names:
        if name in sensitivities:
            continue
        size = graph.var(name).shape()[0]
        sensitivities[name] = {
            'pruned_percent': [],
            'loss': [],
            'size': size
        }

    base_flops = flops(program)
    target_pruned_flops = base_flops * pruned_flops_rate

    # `pruner` is deliberately a single name: the instance created further
    # down is the one used by the next iteration's graph-only probe.
    pruner = Pruner()
    baseline = None
    for name, record in sensitivities.items():
        # Graph-only probe at ratio 0.5 to estimate this parameter's FLOPs.
        half_pruned_program = pruner.prune(
            program=graph.program,
            scope=None,
            params=[name],
            ratios=[0.5],
            place=None,
            lazy=False,
            only_graph=True)
        param_flops = (base_flops - flops(half_pruned_program)) * 2
        channel_size = record["size"]
        pruned_ratio = target_pruned_flops / float(param_flops)
        pruned_size = round(pruned_ratio * channel_size)
        if pruned_size >= channel_size:
            pruned_ratio = 1

        if record["pruned_percent"]:
            _logger.debug('{} exist; pruned ratio: {}; excepted ratio: {}'.
                          format(name, record["pruned_percent"][0],
                                 pruned_ratio))
            continue

        if baseline is None:
            baseline = eval_func(graph.program)

        param_backup = {}
        pruner = Pruner()
        _logger.info("sensitive - param: {}; ratios: {}".format(
            name, pruned_ratio))

        # A ratio of 1 is not evaluated; it is recorded with a loss of 1.
        loss = 1
        if pruned_ratio < 1:
            pruned_program = pruner.prune(
                program=graph.program,
                scope=scope,
                params=[name],
                ratios=[pruned_ratio],
                place=place,
                lazy=True,
                only_graph=False,
                param_backup=param_backup)
            pruned_metric = eval_func(pruned_program)
            loss = (baseline - pruned_metric) / baseline
        _logger.info("pruned param: {}; {}; loss={}".format(
            name, pruned_ratio, loss))

        record['pruned_percent'].append(pruned_ratio)
        record['loss'].append(loss)
        _save_sensitivities(sensitivities, sensitivities_file)

        # Write the backed-up values back into the scope.
        for param_name, backup in param_backup.items():
            scope.find_var(param_name).get_tensor().set(backup, place)
    return sensitivities
def _load_sensitivities(sensitivities_file):
    """
    Load sensitivities from file.

    Args:
        sensitivities_file(str): Path of a pickle file holding the table.
            May be `None` or point to a file that does not exist.

    Returns:
        dict: The loaded table with every 'pruned_percent' value rounded to
        two decimals, or an empty dict when there is no file to read.
    """
    sensitivities = {}
    if sensitivities_file and os.path.exists(sensitivities_file):
        # NOTE: unpickling can execute arbitrary code; only load files that
        # were written by this module.
        with open(sensitivities_file, 'rb') as f:
            if sys.version_info < (3, 0):
                sensitivities = pickle.load(f)
            else:
                # 'latin1' decodes the 8-bit strings of a Python 2 pickle to
                # `str`. With 'bytes' the keys would come back as `bytes`
                # and the string-key lookups below would raise KeyError.
                sensitivities = pickle.load(f, encoding='latin1')

    # Rounding makes stored ratios comparable with freshly rounded ones.
    for param in sensitivities:
        sensitivities[param]['pruned_percent'] = [
            round(p, 2) for p in sensitivities[param]['pruned_percent']
        ]
    return sensitivities
def _save_sensitivities(sensitivities, sensitivities_file):
    """
    Save sensitivities into file.

    Args:
        sensitivities(dict): The table to persist.
        sensitivities_file(str): Path of the pickle file to write. When it
            is `None` or empty nothing is written.
    """
    # Callers pass their `sensitivities_file=None` default straight through;
    # skip saving instead of failing in `open(None, 'wb')`.
    if not sensitivities_file:
        return
    with open(sensitivities_file, 'wb') as f:
        pickle.dump(sensitivities, f)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import copy
from scipy.optimize import leastsq
import numpy as np
import paddle.fluid as fluid
from ..common import get_logger
from .sensitive import sensitivity
from .sensitive import flops_sensitivity
from ..analysis import flops
from .pruner import Pruner
# Names exported by `from <this module> import *`.
__all__ = ["SensitivePruner"]
# Module-level logger, created with its level set to logging.INFO.
_logger = get_logger(__name__, level=logging.INFO)
class SensitivePruner(object):
def __init__(self, place, eval_func, scope=None, checkpoints=None):
"""
Pruner used to prune parameters iteratively according to sensitivities of parameters in each step.
Args:
place(fluid.CUDAPlace | fluid.CPUPlace): The device place where program execute.
eval_func(function): A callback function used to evaluate pruned program. The argument of this function is pruned program. And it return a score of given program.
scope(fluid.scope): The scope used to execute program.
"""
self._eval_func = eval_func
self._iter = 0
self._place = place
self._scope = fluid.global_scope() if scope is None else scope
self._pruner = Pruner()
self._checkpoints = checkpoints
def save_checkpoint(self, train_program, eval_program):
checkpoint = os.path.join(self._checkpoints, str(self._iter - 1))
exe = fluid.Executor(self._place)
fluid.io.save_persistables(
exe, checkpoint, main_program=train_program, filename="__params__")
with open(checkpoint + "/main_program", "wb") as f:
f.write(train_program.desc.serialize_to_string())
with open(checkpoint + "/eval_program", "wb") as f:
f.write(eval_program.desc.serialize_to_string())
def restore(self, checkpoints=None):
exe = fluid.Executor(self._place)
checkpoints = self._checkpoints if checkpoints is None else checkpoints
print("check points: {}".format(checkpoints))
main_program = None
eval_program = None
if checkpoints is not None:
cks = [dir for dir in os.listdir(checkpoints)]
if len(cks) > 0:
latest = max([int(ck) for ck in cks])
latest_ck_path = os.path.join(checkpoints, str(latest))
self._iter += 1
with open(latest_ck_path + "/main_program", "rb") as f:
program_desc_str = f.read()
main_program = fluid.Program.parse_from_string(
program_desc_str)
with open(latest_ck_path + "/eval_program", "rb") as f:
program_desc_str = f.read()
eval_program = fluid.Program.parse_from_string(
program_desc_str)
with fluid.scope_guard(self._scope):
fluid.io.load_persistables(exe, latest_ck_path,
main_program, "__params__")
print("load checkpoint from: {}".format(latest_ck_path))
print("flops of eval program: {}".format(flops(eval_program)))
return main_program, eval_program, self._iter
def greedy_prune(self,
train_program,
eval_program,
params,
pruned_flops_rate,
topk=1):
sensitivities_file = "greedy_sensitivities_iter{}.data".format(
self._iter)
with fluid.scope_guard(self._scope):
sensitivities = flops_sensitivity(
eval_program,
self._place,
params,
self._eval_func,
sensitivities_file=sensitivities_file,
pruned_flops_rate=pruned_flops_rate)
print sensitivities
params, ratios = self._greedy_ratio_by_sensitive(sensitivities, topk)
_logger.info("Pruning: {} by {}".format(params, ratios))
pruned_program = self._pruner.prune(
train_program,
self._scope,
params,
ratios,
place=self._place,
only_graph=False)
pruned_val_program = None
if eval_program is not None:
pruned_val_program = self._pruner.prune(
eval_program,
self._scope,
params,
ratios,
place=self._place,
only_graph=True)
self._iter += 1
return pruned_program, pruned_val_program
def prune(self, train_program, eval_program, params, pruned_flops):
"""
Pruning parameters of training and evaluation network by sensitivities in current step.
Args:
train_program(fluid.Program): The training program to be pruned.
eval_program(fluid.Program): The evaluation program to be pruned. And it is also used to calculate sensitivities of parameters.
params(list<str>): The parameters to be pruned.
pruned_flops(float): The ratio of FLOPS to be pruned in current step.
Return:
tuple: A tuple of pruned training program and pruned evaluation program.
"""
_logger.info("Pruning: {}".format(params))
sensitivities_file = "sensitivities_iter{}.data".format(self._iter)
with fluid.scope_guard(self._scope):
sensitivities = sensitivity(
eval_program,
self._place,
params,
self._eval_func,
sensitivities_file=sensitivities_file,
step_size=0.1)
print sensitivities
_, ratios = self._get_ratios_by_sensitive(sensitivities, pruned_flops,
eval_program)
pruned_program = self._pruner.prune(
train_program,
self._scope,
params,
ratios,
place=self._place,
only_graph=False)
pruned_val_program = None
if eval_program is not None:
pruned_val_program = self._pruner.prune(
eval_program,
self._scope,
params,
ratios,
place=self._place,
only_graph=True)
self._iter += 1
return pruned_program, pruned_val_program
def _greedy_ratio_by_sensitive(self, sensitivities, topk=1):
losses = {}
percents = {}
for param in sensitivities:
losses[param] = sensitivities[param]['loss'][0]
percents[param] = sensitivities[param]['pruned_percent'][0]
topk_parms = sorted(losses, key=losses.__getitem__)[:topk]
topk_percents = [percents[param] for param in topk_parms]
return topk_parms, topk_percents
def _get_ratios_by_sensitive(self, sensitivities, pruned_flops,
                             eval_program):
    """
    Search a group of ratios for pruning target flops.

    A cubic curve ``loss = f(pruned_percent)`` is fitted to the sensitivity
    samples of every parameter. A bisection over the accuracy loss (at most
    20 rounds) then looks for the loss level whose per-parameter ratios
    remove about ``pruned_flops`` of the FLOPs of ``eval_program``.

    Args:
        sensitivities(dict): Maps a parameter name to a dict holding the
            lists 'loss' and 'pruned_percent'.
        pruned_flops(float): The ratio of FLOPs to be pruned.
        eval_program(fluid.Program): The program whose FLOPs are measured;
            it is pruned with ``only_graph=True`` in every round.

    Return:
        tuple: A list of parameter names and a list of ratios aligned with
            it. The ratio list is empty when no sample has a loss above 0,
            because the search loop never runs in that case.
    """

    def func(params, x):
        a, b, c, d = params
        return a * x * x * x + b * x * x + c * x + d

    def error(params, x, y):
        return func(params, x) - y

    def solve_coefficient(x, y):
        init_coefficient = [10, 10, 10, 10]
        # leastsq returns (solution, status flag); only the solution is used.
        coefficient, _ = leastsq(error, init_coefficient, args=(x, y))
        return coefficient

    # Freeze the parameter order once so names and ratios stay aligned and
    # the pruner receives a real list rather than a dict view.
    param_names = list(sensitivities.keys())

    min_loss = 0.
    max_loss = 0.

    # step 1: fit curve by sensitivities
    coefficients = {}
    sampled_upper = {}
    for param in param_names:
        # Five extra (0, 0) samples are prepended to every series before
        # fitting.
        losses = np.array([0] * 5 + sensitivities[param]['loss'])
        percents = np.array([0] * 5 + sensitivities[param][
            'pruned_percent'])
        coefficients[param] = solve_coefficient(percents, losses)
        sampled_upper[param] = float(np.max(percents))
        max_loss = max(max_loss, float(np.max(losses)))

    # step 2: Find a group of ratios by binary searching.
    base_flops = flops(eval_program)
    ratios = []
    max_times = 20
    while min_loss < max_loss and max_times > 0:
        loss = (max_loss + min_loss) / 2
        _logger.info(
            '-----------Try pruned ratios while acc loss={}-----------'.
            format(loss))
        ratios = []
        # step 2.1: Get ratios according to current loss
        for param in param_names:
            # Solve f(x) - loss = 0 on a private copy of the coefficients.
            coefficient = np.array(coefficients[param], dtype=float)
            coefficient[-1] -= loss
            candidates = [
                root.real for root in np.roots(coefficient)
                if np.isreal(root) and 0 < root.real < 1
            ]
            if candidates:
                # Smallest ratio that already reaches the target loss.
                ratio = min(candidates)
            elif func(coefficients[param], sampled_upper[param]) < loss:
                # The fitted curve stays below the target loss, so prune up
                # to the largest ratio that was actually sampled.
                # NOTE(review): fallback policy chosen here because the
                # previous code used an unbound/stale value in this case;
                # confirm it matches the intended search behaviour.
                ratio = sampled_upper[param]
            else:
                # The curve is above the target loss: leave it unpruned.
                ratio = 0.
            ratios.append(ratio)
        _logger.info('Pruned ratios={}'.format(
            [round(ratio, 3) for ratio in ratios]))
        # step 2.2: Pruning by current ratios
        pruned_program = self._pruner.prune(
            eval_program,
            None,  # scope
            param_names,
            ratios,
            None,  # place
            only_graph=True)
        pruned_ratio = 1 - (float(flops(pruned_program)) / base_flops)
        _logger.info('Pruned flops: {:.4f}'.format(pruned_ratio))

        # step 2.3: Check whether current ratios is enough
        if abs(pruned_ratio - pruned_flops) < 0.015:
            break
        if pruned_ratio > pruned_flops:
            max_loss = loss
        else:
            min_loss = loss
        max_times -= 1
    return param_names, ratios
...@@ -20,11 +20,20 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass ...@@ -20,11 +20,20 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core from paddle.fluid import core
# Quantization types accepted for weights.
WEIGHT_QUANTIZATION_TYPES = [
    'abs_max', 'channel_wise_abs_max', 'range_abs_max',
    'moving_average_abs_max'
]
# Quantization types accepted for activations.
ACTIVATION_QUANTIZATION_TYPES = [
    'abs_max', 'range_abs_max', 'moving_average_abs_max'
]
VALID_DTYPES = ['int8']
# Op types routed to QuantizationTransformPass.
TRANSFORM_PASS_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul']
# Op types routed to AddQuantDequantPass.
QUANT_DEQUANT_PASS_OP_TYPES = ['elementwise_add', 'pool2d']
_quant_config_default = { _quant_config_default = {
# weight quantize type, default is 'abs_max' # weight quantize type, default is 'abs_max'
...@@ -38,7 +47,8 @@ _quant_config_default = { ...@@ -38,7 +47,8 @@ _quant_config_default = {
# ops of name_scope in not_quant_pattern list, will not be quantized # ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'], 'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized # ops of type in quantize_op_types, will be quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], 'quantize_op_types':
['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'],
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8', 'dtype': 'int8',
# window size for 'range_abs_max' quantization. defaulf is 10000 # window size for 'range_abs_max' quantization. defaulf is 10000
...@@ -88,6 +98,12 @@ def _parse_configs(user_config): ...@@ -88,6 +98,12 @@ def _parse_configs(user_config):
assert isinstance(configs['quantize_op_types'], list), \ assert isinstance(configs['quantize_op_types'], list), \
"quantize_op_types must be a list" "quantize_op_types must be a list"
for op_type in configs['quantize_op_types']:
assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or (
op_type in TRANSFORM_PASS_OP_TYPES), "{} is not support, \
now support op types are {}".format(
op_type, TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES)
assert isinstance(configs['dtype'], str), \ assert isinstance(configs['dtype'], str), \
"dtype must be a str." "dtype must be a str."
...@@ -132,19 +148,37 @@ def quant_aware(program, place, config, scope=None, for_test=False): ...@@ -132,19 +148,37 @@ def quant_aware(program, place, config, scope=None, for_test=False):
config = _parse_configs(config) config = _parse_configs(config)
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test) main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
transform_pass = QuantizationTransformPass( transform_pass_ops = []
scope=scope, quant_dequant_ops = []
place=place, for op_type in config['quantize_op_types']:
weight_bits=config['weight_bits'], if op_type in TRANSFORM_PASS_OP_TYPES:
activation_bits=config['activation_bits'], transform_pass_ops.append(op_type)
activation_quantize_type=config['activation_quantize_type'], elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
weight_quantize_type=config['weight_quantize_type'], quant_dequant_ops.append(op_type)
window_size=config['window_size'], if len(transform_pass_ops) > 0:
moving_rate=config['moving_rate'], transform_pass = QuantizationTransformPass(
quantizable_op_type=config['quantize_op_types'], scope=scope,
skip_pattern=config['not_quant_pattern']) place=place,
weight_bits=config['weight_bits'],
transform_pass.apply(main_graph) activation_bits=config['activation_bits'],
activation_quantize_type=config['activation_quantize_type'],
weight_quantize_type=config['weight_quantize_type'],
window_size=config['window_size'],
moving_rate=config['moving_rate'],
quantizable_op_type=transform_pass_ops,
skip_pattern=config['not_quant_pattern'])
transform_pass.apply(main_graph)
if len(quant_dequant_ops) > 0:
quant_dequant_pass = AddQuantDequantPass(
scope=scope,
place=place,
moving_rate=config['moving_rate'],
quant_bits=config['activation_bits'],
skip_pattern=config['not_quant_pattern'],
quantizable_op_type=quant_dequant_ops)
quant_dequant_pass.apply(main_graph)
if for_test: if for_test:
quant_program = main_graph.to_program() quant_program = main_graph.to_program()
...@@ -153,22 +187,71 @@ def quant_aware(program, place, config, scope=None, for_test=False): ...@@ -153,22 +187,71 @@ def quant_aware(program, place, config, scope=None, for_test=False):
return quant_program return quant_program
def quant_post(executor,
               model_dir,
               quantize_model_path,
               sample_generator,
               model_filename=None,
               params_filename=None,
               batch_size=16,
               batch_nums=None,
               scope=None,
               algo='KL',
               quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"]):
    """
    The function utilizes post training quantization method to quantize the
    fp32 model. It uses calibrate data to calculate the scale factor of
    quantized variables, and inserts fake quant/dequant op to obtain the
    quantized model.

    Args:
        executor(fluid.Executor): The executor to load, run and save the
            quantized model.
        model_dir(str): The path of fp32 model that will be quantized, and
            the model and params that saved by fluid.io.save_inference_model
            are under the path.
        quantize_model_path(str): The path to save quantized model using api
            fluid.io.save_inference_model.
        sample_generator(Python Generator): The sample generator provides
            calibrate data for DataLoader, and it only returns a sample every
            time.
        model_filename(str, optional): The name of model file. If parameters
            are saved in separate files, set it as 'None'. Default is 'None'.
        params_filename(str, optional): The name of params file.
            When all parameters are saved in a single file, set it
            as filename. If parameters are saved in separate files,
            set it as 'None'. Default is 'None'.
        batch_size(int, optional): The batch size of DataLoader, default is 16.
        batch_nums(int, optional): If batch_nums is not None, the number of
            calibrate data is 'batch_size*batch_nums'. If batch_nums is None,
            use all data generated by sample_generator as calibrate data.
        scope(fluid.Scope, optional): The scope to run program, use it to load
            and save variables. If scope is None, will use
            fluid.global_scope().
        algo(str, optional): If algo='KL', use KL-divergence method to
            get the more precise scale factor. If algo='direct', use
            abs_max method to get the scale factor. Default is 'KL'.
        quantizable_op_type(list[str], optional): The list of op types
            that will be quantized. Default is ["conv2d", "depthwise_conv2d",
            "mul"].

    Returns:
        None
    """
    post_training_quantization = PostTrainingQuantization(
        executor=executor,
        sample_generator=sample_generator,
        model_dir=model_dir,
        model_filename=model_filename,
        params_filename=params_filename,
        batch_size=batch_size,
        batch_nums=batch_nums,
        scope=scope,
        algo=algo,
        # Pass a copy so the shared default list in the signature can never
        # be mutated by the callee.
        quantizable_op_type=list(quantizable_op_type),
        is_full_quantize=False)
    post_training_quantization.quantize()
    post_training_quantization.save_quantized_model(quantize_model_path)
def convert(program, scope, place, config, save_int8=False): def convert(program, place, config, scope=None, save_int8=False):
""" """
add quantization ops in program. the program returned is not trainable. add quantization ops in program. the program returned is not trainable.
Args: Args:
...@@ -183,7 +266,7 @@ def convert(program, scope, place, config, save_int8=False): ...@@ -183,7 +266,7 @@ def convert(program, scope, place, config, save_int8=False):
fluid.Program: freezed int8 program which can be used for inference. fluid.Program: freezed int8 program which can be used for inference.
if save_int8 is False, this value is None. if save_int8 is False, this value is None.
""" """
scope = fluid.global_scope() if not scope else scope
test_graph = IrGraph(core.Graph(program.desc), for_test=True) test_graph = IrGraph(core.Graph(program.desc), for_test=True)
# Freeze the graph after training by adjusting the quantize # Freeze the graph after training by adjusting the quantize
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Controllers and controller server"""
...@@ -15,7 +15,7 @@ import sys ...@@ -15,7 +15,7 @@ import sys
sys.path.append("../") sys.path.append("../")
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
from prune import Pruner from paddleslim.prune import Pruner
from layers import conv_bn_layer from layers import conv_bn_layer
......
...@@ -41,7 +41,10 @@ class TestSANAS(unittest.TestCase): ...@@ -41,7 +41,10 @@ class TestSANAS(unittest.TestCase):
search_steps = 3 search_steps = 3
sa_nas = SANAS( sa_nas = SANAS(
configs, max_flops=base_flops, search_steps=search_steps) configs,
search_steps=search_steps,
server_addr=("", 0),
is_server=True)
for i in range(search_steps): for i in range(search_steps):
archs = sa_nas.next_archs() archs = sa_nas.next_archs()
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import numpy
import paddle
import paddle.fluid as fluid
from paddleslim.analysis import sensitivity
from layers import conv_bn_layer
class TestSensitivity(unittest.TestCase):
    """Smoke test: run ``sensitivity`` on a small conv network over MNIST."""

    def test_sensitivity(self):
        main_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(main_program, startup_program):
            # ``image`` instead of ``input`` so the builtin is not shadowed.
            image = fluid.data(name="image", shape=[None, 1, 28, 28])
            label = fluid.data(name="label", shape=[None, 1], dtype="int64")
            conv1 = conv_bn_layer(image, 8, 3, "conv1")
            conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
            sum1 = conv1 + conv2
            conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
            conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
            sum2 = conv4 + sum1
            conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
            conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
            out = fluid.layers.fc(conv6, size=10, act='softmax')
            acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
        eval_program = main_program.clone(for_test=True)

        # Use the GPU when this Paddle build supports it, otherwise fall
        # back to the CPU so the test can also run on CPU-only machines.
        if fluid.core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
        else:
            place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_program)

        val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)

        def eval_func(program, scope):
            """Return the mean top-1 accuracy of ``program`` over the reader."""
            feeder = fluid.DataFeeder(
                feed_list=['image', 'label'], place=place, program=program)
            acc_set = []
            for data in val_reader():
                acc_np = exe.run(program=program,
                                 scope=scope,
                                 feed=feeder.feed(data),
                                 fetch_list=[acc_top1])
                acc_set.append(float(acc_np[0]))
            acc_val_mean = numpy.array(acc_set).mean()
            print("acc_val_mean: {}".format(acc_val_mean))
            return acc_val_mean

        # NOTE(review): the argument order is kept exactly as in the
        # original test; confirm it matches the current ``sensitivity``
        # signature.
        sensitivity(eval_program,
                    fluid.global_scope(), place, ["conv4_weights"], eval_func,
                    "./sensitivities_file")


if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册