diff --git a/.gitignore b/.gitignore index 2ea48a8b28c35f35d3880ccca1d54ea5b3947e0f..c59240cd16aeeb80cbe8501f65dba682edddc463 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ *.egg-info build/ ./dist/ +*.pyc +dist/ diff --git a/demo/auto_prune/train.py b/demo/auto_prune/train.py new file mode 100644 index 0000000000000000000000000000000000000000..70930774dc1c4306d12e63fbd1766a67ec2a5c3c --- /dev/null +++ b/demo/auto_prune/train.py @@ -0,0 +1,221 @@ +import os +import sys +import logging +import paddle +import argparse +import functools +import math +import time +import numpy as np +import paddle.fluid as fluid +from paddleslim.prune import AutoPruner +from paddleslim.common import get_logger +from paddleslim.analysis import flops +sys.path.append(sys.path[0] + "/../") +import models +from utility import add_arguments, print_arguments + +_logger = get_logger(__name__, level=logging.INFO) + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('batch_size', int, 64 * 4, "Minibatch size.") +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('model', str, "MobileNet", "The target model.") +add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretained", "Whether to use pretrained model.") +add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.") +add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.") +add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.") +add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.") +add_arg('num_epochs', int, 120, "The number of total epochs.") +add_arg('total_images', int, 1281167, "The number of total training images.") +parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step") +add_arg('config_file', str, None, "The config file for compression with yaml format.") +add_arg('data', str, "mnist", "Which data to use. 
'mnist' or 'imagenet'") +add_arg('log_period', int, 10, "Log period in batches.") +add_arg('test_period', int, 10, "Test period in epoches.") +# yapf: enable + +model_list = [m for m in dir(models) if "__" not in m] + + +def piecewise_decay(args): + step = int(math.ceil(float(args.total_images) / args.batch_size)) + bd = [step * e for e in args.step_epochs] + lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)] + learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr) + optimizer = fluid.optimizer.Momentum( + learning_rate=learning_rate, + momentum=args.momentum_rate, + regularization=fluid.regularizer.L2Decay(args.l2_decay)) + return optimizer + + +def cosine_decay(args): + step = int(math.ceil(float(args.total_images) / args.batch_size)) + learning_rate = fluid.layers.cosine_decay( + learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs) + optimizer = fluid.optimizer.Momentum( + learning_rate=learning_rate, + momentum=args.momentum_rate, + regularization=fluid.regularizer.L2Decay(args.l2_decay)) + return optimizer + + +def create_optimizer(args): + if args.lr_strategy == "piecewise_decay": + return piecewise_decay(args) + elif args.lr_strategy == "cosine_decay": + return cosine_decay(args) + + +def compress(args): + + train_reader = None + test_reader = None + if args.data == "mnist": + import paddle.dataset.mnist as reader + train_reader = reader.train() + val_reader = reader.test() + class_dim = 10 + image_shape = "1,28,28" + elif args.data == "imagenet": + import imagenet_reader as reader + train_reader = reader.train() + val_reader = reader.val() + class_dim = 1000 + image_shape = "3,224,224" + else: + raise ValueError("{} is not supported.".format(args.data)) + + image_shape = [int(m) for m in image_shape.split(",")] + assert args.model in model_list, "{} is not in lists: {}".format( + args.model, model_list) + image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + # model definition + model = models.__dict__[args.model]() + out = model.net(input=image, class_dim=class_dim) + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + val_program = fluid.default_main_program().clone(for_test=True) + opt = create_optimizer(args) + opt.minimize(avg_cost) + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + if args.pretrained_model: + + def if_exist(var): + return os.path.exists( + os.path.join(args.pretrained_model, var.name)) + + fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist) + + val_reader = paddle.batch(val_reader, batch_size=args.batch_size) + train_reader = paddle.batch( + train_reader, batch_size=args.batch_size, drop_last=True) + + train_feeder = feeder = fluid.DataFeeder([image, label], place) + val_feeder = feeder = fluid.DataFeeder( + [image, label], place, program=val_program) + + def test(epoch, program): + batch_id = 0 + acc_top1_ns = [] + acc_top5_ns = [] + for data in val_reader(): + start_time = time.time() + acc_top1_n, acc_top5_n = exe.run( + program, + feed=train_feeder.feed(data), + fetch_list=[acc_top1.name, acc_top5.name]) + end_time = time.time() + if batch_id % args.log_period == 0: + _logger.info( + "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}". 
+ format(epoch, batch_id, + np.mean(acc_top1_n), + np.mean(acc_top5_n), end_time - start_time)) + acc_top1_ns.append(np.mean(acc_top1_n)) + acc_top5_ns.append(np.mean(acc_top5_n)) + batch_id += 1 + + _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}". + format(epoch, + np.mean(np.array(acc_top1_ns)), + np.mean(np.array(acc_top5_ns)))) + return np.mean(np.array(acc_top1_ns)) + + def train(epoch, program): + + build_strategy = fluid.BuildStrategy() + exec_strategy = fluid.ExecutionStrategy() + train_program = fluid.compiler.CompiledProgram( + program).with_data_parallel( + loss_name=avg_cost.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + + batch_id = 0 + for data in train_reader(): + start_time = time.time() + loss_n, acc_top1_n, acc_top5_n = exe.run( + train_program, + feed=train_feeder.feed(data), + fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name]) + end_time = time.time() + loss_n = np.mean(loss_n) + acc_top1_n = np.mean(acc_top1_n) + acc_top5_n = np.mean(acc_top5_n) + if batch_id % args.log_period == 0: + _logger.info( + "epoch[{}]-batch[{}] - loss: {}; acc_top1: {}; acc_top5: {}; time: {}". + format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n, + end_time - start_time)) + batch_id += 1 + + params = [] + for param in fluid.default_main_program().global_block().all_parameters(): + if "_sep_weights" in param.name: + params.append(param.name) + + pruner = AutoPruner( + val_program, + fluid.global_scope(), + place, + params=params, + init_ratios=[0.33] * len(params), + pruned_flops=0.5, + pruned_latency=None, + server_addr=("", 0), + init_temperature=100, + reduce_rate=0.85, + max_try_number=300, + max_client_num=10, + search_steps=100, + max_ratios=0.9, + min_ratios=0., + key="auto_pruner") + + while True: + pruned_program, pruned_val_program = pruner.prune( + fluid.default_main_program(), val_program) + for i in range(1): + train(i, pruned_program) + score = test(0, pruned_val_program) + pruner.reward(score) + + +def main(): + args = parser.parse_args() + print_arguments(args) + compress(args) + + +if __name__ == '__main__': + main() diff --git a/demo/imagenet_reader.py b/demo/imagenet_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..25bc756e93db829f3566754e079ba7711074e577 --- /dev/null +++ b/demo/imagenet_reader.py @@ -0,0 +1,194 @@ +import os +import math +import random +import functools +import numpy as np +import paddle +from PIL import Image, ImageEnhance + +random.seed(0) +np.random.seed(0) + +DATA_DIM = 224 + +THREAD = 16 +BUF_SIZE = 10240 + +#DATA_DIR = './data/ILSVRC2012/' +DATA_DIR = './data/' +DATA_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], DATA_DIR) + +img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) +img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) + + +def resize_short(img, target_size): + percent = float(target_size) / min(img.size[0], img.size[1]) + resized_width = int(round(img.size[0] * percent)) + resized_height = int(round(img.size[1] * percent)) + img = img.resize((resized_width, resized_height), Image.LANCZOS) + return img + + +def crop_image(img, target_size, center): + width, height = img.size + size = target_size + if center == True: + w_start = (width - size) / 2 + h_start = (height - size) / 2 + else: + w_start = np.random.randint(0, width - size + 1) + h_start = np.random.randint(0, height - size + 1) + w_end = w_start + size + h_end = h_start + size + img = img.crop((w_start, h_start, w_end, h_end)) + return img + + +def 
random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]): + aspect_ratio = math.sqrt(np.random.uniform(*ratio)) + w = 1. * aspect_ratio + h = 1. / aspect_ratio + + bound = min((float(img.size[0]) / img.size[1]) / (w**2), + (float(img.size[1]) / img.size[0]) / (h**2)) + scale_max = min(scale[1], bound) + scale_min = min(scale[0], bound) + + target_area = img.size[0] * img.size[1] * np.random.uniform(scale_min, + scale_max) + target_size = math.sqrt(target_area) + w = int(target_size * w) + h = int(target_size * h) + + i = np.random.randint(0, img.size[0] - w + 1) + j = np.random.randint(0, img.size[1] - h + 1) + + img = img.crop((i, j, i + w, j + h)) + img = img.resize((size, size), Image.LANCZOS) + return img + + +def rotate_image(img): + angle = np.random.randint(-10, 11) + img = img.rotate(angle) + return img + + +def distort_color(img): + def random_brightness(img, lower=0.5, upper=1.5): + e = np.random.uniform(lower, upper) + return ImageEnhance.Brightness(img).enhance(e) + + def random_contrast(img, lower=0.5, upper=1.5): + e = np.random.uniform(lower, upper) + return ImageEnhance.Contrast(img).enhance(e) + + def random_color(img, lower=0.5, upper=1.5): + e = np.random.uniform(lower, upper) + return ImageEnhance.Color(img).enhance(e) + + ops = [random_brightness, random_contrast, random_color] + np.random.shuffle(ops) + + img = ops[0](img) + img = ops[1](img) + img = ops[2](img) + + return img + + +def process_image(sample, mode, color_jitter, rotate): + img_path = sample[0] + + img = Image.open(img_path) + if mode == 'train': + if rotate: img = rotate_image(img) + img = random_crop(img, DATA_DIM) + else: + img = resize_short(img, target_size=256) + img = crop_image(img, target_size=DATA_DIM, center=True) + if mode == 'train': + if color_jitter: + img = distort_color(img) + if np.random.randint(0, 2) == 1: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + + if img.mode != 'RGB': + img = img.convert('RGB') + + img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255 + img -= img_mean + img /= img_std + + if mode == 'train' or mode == 'val': + return img, sample[1] + elif mode == 'test': + return [img] + + +def _reader_creator(file_list, + mode, + shuffle=False, + color_jitter=False, + rotate=False, + data_dir=DATA_DIR, + batch_size=1): + def reader(): + try: + with open(file_list) as flist: + full_lines = [line.strip() for line in flist] + if shuffle: + np.random.shuffle(full_lines) + if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'): + # distributed mode if the env var `PADDLE_TRAINING_ROLE` exits + trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) + trainer_count = int(os.getenv("PADDLE_TRAINERS", "1")) + per_node_lines = len(full_lines) // trainer_count + lines = full_lines[trainer_id * per_node_lines:( + trainer_id + 1) * per_node_lines] + print( + "read images from %d, length: %d, lines length: %d, total: %d" + % (trainer_id * per_node_lines, per_node_lines, + len(lines), len(full_lines))) + else: + lines = full_lines + + for line in lines: + if mode == 'train' or mode == 'val': + img_path, label = line.split() + img_path = os.path.join(data_dir + "/" + mode, + img_path) + yield img_path, int(label) + elif mode == 'test': + img_path = os.path.join(data_dir, line) + yield [img_path] + except Exception as e: + print("Reader failed!\n{}".format(str(e))) + os._exit(1) + + mapper = functools.partial( + process_image, mode=mode, color_jitter=color_jitter, rotate=rotate) + + return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) + + +def 
train(data_dir=DATA_DIR): + file_list = os.path.join(data_dir, 'train_list.txt') + return _reader_creator( + file_list, + 'train', + shuffle=True, + color_jitter=False, + rotate=False, + data_dir=data_dir) + + +def val(data_dir=DATA_DIR): + file_list = os.path.join(data_dir, 'val_list.txt') + return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir) + + +def test(data_dir=DATA_DIR): + file_list = os.path.join(data_dir, 'test_list.txt') + return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir) diff --git a/demo/models/__init__.py b/demo/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e843697407850c049a5427d2b6533c417e59c228 --- /dev/null +++ b/demo/models/__init__.py @@ -0,0 +1,5 @@ +from .mobilenet import MobileNet +from .resnet import ResNet34, ResNet50 +from .mobilenet_v2 import MobileNetV2 + +__all__ = ['MobileNet', 'ResNet34', 'ResNet50', 'MobileNetV2'] diff --git a/demo/models/mobilenet.py b/demo/models/mobilenet.py new file mode 100644 index 0000000000000000000000000000000000000000..921d6226ca2a65d5c9b57e27bf6607c7376c51f6 --- /dev/null +++ b/demo/models/mobilenet.py @@ -0,0 +1,197 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr + +__all__ = ['MobileNet'] + +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [10, 16, 30], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + + +class MobileNet(): + def __init__(self): + self.params = train_parameters + + def net(self, input, class_dim=1000, scale=1.0): + # conv1: 112x112 + input = self.conv_bn_layer( + input, + filter_size=3, + channels=3, + num_filters=int(32 * scale), + stride=2, + padding=1, + name="conv1") + + # 56x56 + input = self.depthwise_separable( + input, + num_filters1=32, + num_filters2=64, + num_groups=32, + stride=1, + scale=scale, + name="conv2_1") + + input = self.depthwise_separable( + input, + num_filters1=64, + num_filters2=128, + num_groups=64, + stride=2, + scale=scale, + name="conv2_2") + + # 28x28 + input = self.depthwise_separable( + input, + num_filters1=128, + num_filters2=128, + num_groups=128, + stride=1, + scale=scale, + name="conv3_1") + + input = self.depthwise_separable( + input, + num_filters1=128, + num_filters2=256, + num_groups=128, + stride=2, + scale=scale, + name="conv3_2") + + # 14x14 + input = self.depthwise_separable( + input, + num_filters1=256, + num_filters2=256, + num_groups=256, + stride=1, + scale=scale, + name="conv4_1") + + input = self.depthwise_separable( + input, + num_filters1=256, + num_filters2=512, + num_groups=256, + stride=2, + scale=scale, + name="conv4_2") + + # 14x14 + for i in range(5): + input = self.depthwise_separable( + input, + num_filters1=512, + num_filters2=512, + num_groups=512, + stride=1, + scale=scale, + name="conv5" + "_" + str(i + 1)) + # 7x7 + input = self.depthwise_separable( + input, + num_filters1=512, + num_filters2=1024, + num_groups=512, + stride=2, + scale=scale, + name="conv5_6") + + input = self.depthwise_separable( + input, + num_filters1=1024, + num_filters2=1024, + num_groups=1024, + stride=1, + scale=scale, + name="conv6") + + input = fluid.layers.pool2d( + input=input, + pool_size=0, + pool_stride=1, + pool_type='avg', + 
global_pooling=True) + + output = fluid.layers.fc(input=input, + size=class_dim, + act='softmax', + param_attr=ParamAttr( + initializer=MSRA(), name="fc7_weights"), + bias_attr=ParamAttr(name="fc7_offset")) + + return output + + def conv_bn_layer(self, + input, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + act='relu', + use_cudnn=True, + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr( + initializer=MSRA(), name=name + "_weights"), + bias_attr=False) + bn_name = name + "_bn" + return fluid.layers.batch_norm( + input=conv, + act=act, + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(name=bn_name + "_offset"), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + + def depthwise_separable(self, + input, + num_filters1, + num_filters2, + num_groups, + stride, + scale, + name=None): + depthwise_conv = self.conv_bn_layer( + input=input, + filter_size=3, + num_filters=int(num_filters1 * scale), + stride=stride, + padding=1, + num_groups=int(num_groups * scale), + use_cudnn=False, + name=name + "_dw") + + pointwise_conv = self.conv_bn_layer( + input=depthwise_conv, + filter_size=1, + num_filters=int(num_filters2 * scale), + stride=1, + padding=0, + name=name + "_sep") + return pointwise_conv diff --git a/demo/models/mobilenet_v2.py b/demo/models/mobilenet_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..ccfb250b79a5365d28470886624287fbc87be50c --- /dev/null +++ b/demo/models/mobilenet_v2.py @@ -0,0 +1,259 @@ +#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle.fluid as fluid +from paddle.fluid.initializer import MSRA +from paddle.fluid.param_attr import ParamAttr + +__all__ = [ + 'MobileNetV2', 'MobileNetV2_x0_25, ' + 'MobileNetV2_x0_5', 'MobileNetV2_x1_0', 'MobileNetV2_x1_5', + 'MobileNetV2_x2_0', 'MobileNetV2_scale' +] + +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [30, 60, 90], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + + +class MobileNetV2(): + def __init__(self, scale=1.0, change_depth=False): + self.params = train_parameters + self.scale = scale + self.change_depth = change_depth + + def net(self, input, class_dim=1000): + scale = self.scale + change_depth = self.change_depth + #if change_depth is True, the new depth is 1.4 times as deep as before. 
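+        # Each tuple below is (t, c, n, s): expansion factor t, output
+        # channels c, number of inverted-residual blocks n, and stride s of
+        # the first block in the group (the remaining blocks use stride 1).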
+ bottleneck_params_list = [ + (1, 16, 1, 1), + (6, 24, 2, 2), + (6, 32, 3, 2), + (6, 64, 4, 2), + (6, 96, 3, 1), + (6, 160, 3, 2), + (6, 320, 1, 1), + ] if change_depth == False else [ + (1, 16, 1, 1), + (6, 24, 2, 2), + (6, 32, 5, 2), + (6, 64, 7, 2), + (6, 96, 5, 1), + (6, 160, 3, 2), + (6, 320, 1, 1), + ] + + #conv1 + input = self.conv_bn_layer( + input, + num_filters=int(32 * scale), + filter_size=3, + stride=2, + padding=1, + if_act=True, + name='conv1_1') + + # bottleneck sequences + i = 1 + in_c = int(32 * scale) + for layer_setting in bottleneck_params_list: + t, c, n, s = layer_setting + i += 1 + input = self.invresi_blocks( + input=input, + in_c=in_c, + t=t, + c=int(c * scale), + n=n, + s=s, + name='conv' + str(i)) + in_c = int(c * scale) + #last_conv + input = self.conv_bn_layer( + input=input, + num_filters=int(1280 * scale) if scale > 1.0 else 1280, + filter_size=1, + stride=1, + padding=0, + if_act=True, + name='conv9') + + input = fluid.layers.pool2d( + input=input, + pool_size=7, + pool_stride=1, + pool_type='avg', + global_pooling=True) + + output = fluid.layers.fc(input=input, + size=class_dim, + act='softmax', + param_attr=ParamAttr(name='fc10_weights'), + bias_attr=ParamAttr(name='fc10_offset')) + return output + + def conv_bn_layer(self, + input, + filter_size, + num_filters, + stride, + padding, + channels=None, + num_groups=1, + if_act=True, + name=None, + use_cudnn=True): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + bn_name = name + '_bn' + bn = fluid.layers.batch_norm( + input=conv, + param_attr=ParamAttr(name=bn_name + "_scale"), + bias_attr=ParamAttr(name=bn_name + "_offset"), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') + if if_act: + return fluid.layers.relu6(bn) + else: + return bn + + def shortcut(self, input, data_residual): + return fluid.layers.elementwise_add(input, data_residual) + + def inverted_residual_unit(self, + input, + num_in_filter, + num_filters, + ifshortcut, + stride, + filter_size, + padding, + expansion_factor, + name=None): + num_expfilter = int(round(num_in_filter * expansion_factor)) + + channel_expand = self.conv_bn_layer( + input=input, + num_filters=num_expfilter, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=True, + name=name + '_expand') + + bottleneck_conv = self.conv_bn_layer( + input=channel_expand, + num_filters=num_expfilter, + filter_size=filter_size, + stride=stride, + padding=padding, + num_groups=num_expfilter, + if_act=True, + name=name + '_dwise', + use_cudnn=False) + + linear_out = self.conv_bn_layer( + input=bottleneck_conv, + num_filters=num_filters, + filter_size=1, + stride=1, + padding=0, + num_groups=1, + if_act=False, + name=name + '_linear') + if ifshortcut: + out = self.shortcut(input=input, data_residual=linear_out) + return out + else: + return linear_out + + def invresi_blocks(self, input, in_c, t, c, n, s, name=None): + first_block = self.inverted_residual_unit( + input=input, + num_in_filter=in_c, + num_filters=c, + ifshortcut=False, + stride=s, + filter_size=3, + padding=1, + expansion_factor=t, + name=name + '_1') + + last_residual_block = first_block + last_c = c + + for i in range(1, n): + last_residual_block = self.inverted_residual_unit( + input=last_residual_block, + num_in_filter=last_c, + num_filters=c, + ifshortcut=True, + 
stride=1, + filter_size=3, + padding=1, + expansion_factor=t, + name=name + '_' + str(i + 1)) + return last_residual_block + + +def MobileNetV2_x0_25(): + model = MobileNetV2(scale=0.25) + return model + + +def MobileNetV2_x0_5(): + model = MobileNetV2(scale=0.5) + return model + + +def MobileNetV2_x1_0(): + model = MobileNetV2(scale=1.0) + return model + + +def MobileNetV2_x1_5(): + model = MobileNetV2(scale=1.5) + return model + + +def MobileNetV2_x2_0(): + model = MobileNetV2(scale=2.0) + return model + + +def MobileNetV2_scale(): + model = MobileNetV2(scale=1.2, change_depth=True) + return model diff --git a/demo/models/resnet.py b/demo/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4ceaef41ecc87d7388ae05d7fcb199de1841ebc2 --- /dev/null +++ b/demo/models/resnet.py @@ -0,0 +1,229 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import paddle +import paddle.fluid as fluid +import math +from paddle.fluid.param_attr import ParamAttr + +__all__ = ["ResNet", "ResNet34", "ResNet50", "ResNet101", "ResNet152"] + +train_parameters = { + "input_size": [3, 224, 224], + "input_mean": [0.485, 0.456, 0.406], + "input_std": [0.229, 0.224, 0.225], + "learning_strategy": { + "name": "piecewise_decay", + "batch_size": 256, + "epochs": [10, 16, 30], + "steps": [0.1, 0.01, 0.001, 0.0001] + } +} + + +class ResNet(): + def __init__(self, layers=50, prefix_name=''): + self.params = train_parameters + self.layers = layers + self.prefix_name = prefix_name + + def net(self, input, class_dim=1000, conv1_name='conv1', fc_name=None): + layers = self.layers + prefix_name = self.prefix_name if self.prefix_name is '' else self.prefix_name + '_' + supported_layers = [34, 50, 101, 152] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format(supported_layers, layers) + + if layers == 34 or layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + num_filters = [64, 128, 256, 512] + + # TODO(wanghaoshuang@baidu.com): + # fix name("conv1") conflict between student and teacher in distillation. 
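+        # Stem shared by all ResNet depths: a 7x7 stride-2 convolution
+        # followed by 3x3 stride-2 max pooling, ahead of the residual stages.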
+ conv = self.conv_bn_layer( + input=input, + num_filters=64, + filter_size=7, + stride=2, + act='relu', + name=prefix_name + conv1_name) + conv = fluid.layers.pool2d( + input=conv, + pool_size=3, + pool_stride=2, + pool_padding=1, + pool_type='max') + + if layers >= 50: + for block in range(len(depth)): + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + conv_name = prefix_name + conv_name + conv = self.bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + name=conv_name) + + pool = fluid.layers.pool2d( + input=conv, pool_size=7, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + fc_name = fc_name if fc_name is None else prefix_name + fc_name + out = fluid.layers.fc(input=pool, + size=class_dim, + act='softmax', + name=fc_name, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform( + -stdv, stdv))) + else: + for block in range(len(depth)): + for i in range(depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + conv_name = prefix_name + conv_name + conv = self.basic_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + is_first=block == i == 0, + name=conv_name) + + pool = fluid.layers.pool2d( + input=conv, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + fc_name = fc_name if fc_name is None else prefix_name + fc_name + out = fluid.layers.fc( + input=pool, + size=class_dim, + act='softmax', + name=fc_name, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.Uniform(-stdv, stdv))) + + return out + + def conv_bn_layer(self, + input, + num_filters, + filter_size, + stride=1, + groups=1, + act=None, + name=None): + conv = fluid.layers.conv2d( + input=input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=(filter_size - 1) // 2, + groups=groups, + act=None, + param_attr=ParamAttr(name=name + "_weights"), + bias_attr=False, + name=name + '.conv2d.output.1') + if self.prefix_name == '': + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + else: + if name.split("_")[1] == "conv1": + bn_name = name.split("_", 1)[0] + "_bn_" + name.split("_", + 1)[1] + else: + bn_name = name.split("_", 1)[0] + "_bn" + name.split("_", + 1)[1][3:] + return fluid.layers.batch_norm( + input=conv, + act=act, + name=bn_name + '.output.1', + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance', ) + + def shortcut(self, input, ch_out, stride, is_first, name): + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1 or is_first == True: + return self.conv_bn_layer(input, ch_out, 1, stride, name=name) + else: + return input + + def bottleneck_block(self, input, num_filters, stride, name): + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=1, + act='relu', + name=name + "_branch2a") + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + stride=stride, + act='relu', + name=name + "_branch2b") + conv2 = self.conv_bn_layer( + input=conv1, + num_filters=num_filters * 4, + filter_size=1, + act=None, + name=name + "_branch2c") + + short = 
self.shortcut( + input, + num_filters * 4, + stride, + is_first=False, + name=name + "_branch1") + + return fluid.layers.elementwise_add( + x=short, y=conv2, act='relu', name=name + ".add.output.5") + + def basic_block(self, input, num_filters, stride, is_first, name): + conv0 = self.conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=3, + act='relu', + stride=stride, + name=name + "_branch2a") + conv1 = self.conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + act=None, + name=name + "_branch2b") + short = self.shortcut( + input, num_filters, stride, is_first, name=name + "_branch1") + return fluid.layers.elementwise_add(x=short, y=conv1, act='relu') + + +def ResNet34(prefix_name=''): + model = ResNet(layers=34, prefix_name=prefix_name) + return model + + +def ResNet50(prefix_name=''): + model = ResNet(layers=50, prefix_name=prefix_name) + return model + + +def ResNet101(): + model = ResNet(layers=101) + return model + + +def ResNet152(): + model = ResNet(layers=152) + return model diff --git a/demo/nas/sa_nas_mobilenetv2_cifar10.py b/demo/nas/sa_nas_mobilenetv2_cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..249d4c214788c0ffc5a0d741dc48b4942ea5808b --- /dev/null +++ b/demo/nas/sa_nas_mobilenetv2_cifar10.py @@ -0,0 +1,122 @@ +import sys +sys.path.append('..') +import numpy as np +import argparse +import ast +import paddle +import paddle.fluid as fluid +from paddleslim.nas.search_space.search_space_factory import SearchSpaceFactory +from paddleslim.analysis import flops +from paddleslim.nas import SANAS + + +def create_data_loader(): + data = fluid.data(name='data', shape=[-1, 3, 32, 32], dtype='float32') + label = fluid.data(name='label', shape=[-1, 1], dtype='int64') + data_loader = fluid.io.DataLoader.from_generator( + feed_list=[data, label], + capacity=1024, + use_double_buffer=True, + iterable=True) + return data_loader, data, label + + +def init_sa_nas(config): + factory = SearchSpaceFactory() + space = factory.get_search_space(config) + model_arch = space.token2arch()[0] + main_program = fluid.Program() + startup_program = fluid.Program() + + with fluid.program_guard(main_program, startup_program): + data_loader, data, label = create_data_loader() + output = model_arch(data) + cost = fluid.layers.mean( + fluid.layers.softmax_with_cross_entropy( + logits=output, label=label)) + + base_flops = flops(main_program) + search_steps = 10000000 + + ### start a server and a client + sa_nas = SANAS(config, search_steps=search_steps, is_server=True) + + ### start a client, server_addr is server address + #sa_nas = SANAS(config, max_flops = base_flops, server_addr=("10.255.125.38", 18607), search_steps = search_steps, is_server=False) + + return sa_nas, search_steps + + +def search_mobilenetv2_cifar10(config, args): + sa_nas, search_steps = init_sa_nas(config) + for i in range(search_steps): + print('search step: ', i) + archs = sa_nas.next_archs()[0] + train_program = fluid.Program() + test_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + train_loader, data, label = create_data_loader() + output = archs(data) + cost = fluid.layers.mean( + fluid.layers.softmax_with_cross_entropy( + logits=output, label=label))[0] + test_program = train_program.clone(for_test=True) + + optimizer = fluid.optimizer.Momentum( + learning_rate=0.1, + momentum=0.9, + regularization=fluid.regularizer.L2Decay(1e-4)) + optimizer.minimize(cost) + + place = fluid.CUDAPlace(0) 
if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup_program) + train_reader = paddle.reader.shuffle( + paddle.dataset.cifar.train10(cycle=False), buf_size=1024) + train_loader.set_sample_generator( + train_reader, + batch_size=512, + places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places()) + + test_loader, _, _ = create_data_loader() + test_reader = paddle.dataset.cifar.test10(cycle=False) + test_loader.set_sample_generator( + test_reader, + batch_size=256, + drop_last=False, + places=fluid.cuda_places() if args.use_gpu else fluid.cpu_places()) + + for epoch_id in range(10): + for batch_id, data in enumerate(train_loader()): + loss = exe.run(train_program, + feed=data, + fetch_list=[cost.name])[0] + if batch_id % 5 == 0: + print('epoch: {}, batch: {}, loss: {}'.format( + epoch_id, batch_id, loss[0])) + + for data in test_loader(): + reward = exe.run(test_program, feed=data, + fetch_list=[cost.name])[0] + + print('reward:', reward) + sa_nas.reward(float(reward)) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser( + description='SA NAS MobileNetV2 cifar10 argparase') + parser.add_argument( + '--use_gpu', + type=ast.literal_eval, + default=True, + help='Whether to use GPU in train/test model.') + args = parser.parse_args() + print(args) + + config_info = {'input_size': 32, 'output_size': 1, 'block_num': 5} + config = [('MobileNetV2Space', config_info)] + + search_mobilenetv2_cifar10(config, args) diff --git a/demo/prune/train.py b/demo/prune/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f7d7f5cd854848e097c625b37d9c73f79d2aa662 --- /dev/null +++ b/demo/prune/train.py @@ -0,0 +1,216 @@ +import os +import sys +import logging +import paddle +import argparse +import functools +import math +import time +import numpy as np +import paddle.fluid as fluid +from paddleslim.prune import Pruner +from paddleslim.common import get_logger +from paddleslim.analysis import flops +sys.path.append(sys.path[0] + "/../") +import models +from utility import add_arguments, print_arguments + +_logger = get_logger(__name__, level=logging.INFO) + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('batch_size', int, 64 * 4, "Minibatch size.") +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('model', str, "MobileNet", "The target model.") +add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretained", "Whether to use pretrained model.") +add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.") +add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.") +add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.") +add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.") +add_arg('num_epochs', int, 120, "The number of total epochs.") +add_arg('total_images', int, 1281167, "The number of total training images.") +parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step") +add_arg('config_file', str, None, "The config file for compression with yaml format.") +add_arg('data', str, "mnist", "Which data to use. 
'mnist' or 'imagenet'") +add_arg('log_period', int, 10, "Log period in batches.") +add_arg('test_period', int, 10, "Test period in epoches.") +# yapf: enable + +model_list = [m for m in dir(models) if "__" not in m] + + +def piecewise_decay(args): + step = int(math.ceil(float(args.total_images) / args.batch_size)) + bd = [step * e for e in args.step_epochs] + lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)] + learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr) + optimizer = fluid.optimizer.Momentum( + learning_rate=learning_rate, + momentum=args.momentum_rate, + regularization=fluid.regularizer.L2Decay(args.l2_decay)) + return optimizer + + +def cosine_decay(args): + step = int(math.ceil(float(args.total_images) / args.batch_size)) + learning_rate = fluid.layers.cosine_decay( + learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs) + optimizer = fluid.optimizer.Momentum( + learning_rate=learning_rate, + momentum=args.momentum_rate, + regularization=fluid.regularizer.L2Decay(args.l2_decay)) + return optimizer + + +def create_optimizer(args): + if args.lr_strategy == "piecewise_decay": + return piecewise_decay(args) + elif args.lr_strategy == "cosine_decay": + return cosine_decay(args) + + +def compress(args): + train_reader = None + test_reader = None + if args.data == "mnist": + import paddle.dataset.mnist as reader + train_reader = reader.train() + val_reader = reader.test() + class_dim = 10 + image_shape = "1,28,28" + elif args.data == "imagenet": + import imagenet_reader as reader + train_reader = reader.train() + val_reader = reader.val() + class_dim = 1000 + image_shape = "3,224,224" + else: + raise ValueError("{} is not supported.".format(args.data)) + image_shape = [int(m) for m in image_shape.split(",")] + assert args.model in model_list, "{} is not in lists: {}".format( + args.model, model_list) + image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + # model definition + model = models.__dict__[args.model]() + out = model.net(input=image, class_dim=class_dim) + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + val_program = fluid.default_main_program().clone(for_test=True) + opt = create_optimizer(args) + opt.minimize(avg_cost) + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + if args.pretrained_model: + + def if_exist(var): + return os.path.exists( + os.path.join(args.pretrained_model, var.name)) + + fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist) + + val_reader = paddle.batch(val_reader, batch_size=args.batch_size) + train_reader = paddle.batch( + train_reader, batch_size=args.batch_size, drop_last=True) + + train_feeder = feeder = fluid.DataFeeder([image, label], place) + val_feeder = feeder = fluid.DataFeeder( + [image, label], place, program=val_program) + + def test(epoch, program): + batch_id = 0 + acc_top1_ns = [] + acc_top5_ns = [] + for data in val_reader(): + start_time = time.time() + acc_top1_n, acc_top5_n = exe.run( + program, + feed=train_feeder.feed(data), + fetch_list=[acc_top1.name, acc_top5.name]) + end_time = time.time() + if batch_id % args.log_period == 0: + _logger.info( + "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}". 
+ format(epoch, batch_id, + np.mean(acc_top1_n), + np.mean(acc_top5_n), end_time - start_time)) + acc_top1_ns.append(np.mean(acc_top1_n)) + acc_top5_ns.append(np.mean(acc_top5_n)) + batch_id += 1 + + _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}". + format(epoch, + np.mean(np.array(acc_top1_ns)), + np.mean(np.array(acc_top5_ns)))) + + def train(epoch, program): + + build_strategy = fluid.BuildStrategy() + exec_strategy = fluid.ExecutionStrategy() + train_program = fluid.compiler.CompiledProgram( + program).with_data_parallel( + loss_name=avg_cost.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + + batch_id = 0 + for data in train_reader(): + start_time = time.time() + loss_n, acc_top1_n, acc_top5_n = exe.run( + train_program, + feed=train_feeder.feed(data), + fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name]) + end_time = time.time() + loss_n = np.mean(loss_n) + acc_top1_n = np.mean(acc_top1_n) + acc_top5_n = np.mean(acc_top5_n) + if batch_id % args.log_period == 0: + _logger.info( + "epoch[{}]-batch[{}] - loss: {}; acc_top1: {}; acc_top5: {}; time: {}". + format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n, + end_time - start_time)) + batch_id += 1 + + params = [] + for param in fluid.default_main_program().global_block().all_parameters(): + if "_sep_weights" in param.name: + params.append(param.name) + _logger.info("fops before pruning: {}".format( + flops(fluid.default_main_program()))) + pruner = Pruner() + pruned_val_program = pruner.prune( + val_program, + fluid.global_scope(), + params=params, + ratios=[0.33] * len(params), + place=place, + only_graph=True) + + pruned_program = pruner.prune( + fluid.default_main_program(), + fluid.global_scope(), + params=params, + ratios=[0.33] * len(params), + place=place) + + _logger.info("fops after pruning: {}".format(flops(pruned_program))) + + for i in range(args.num_epochs): + train(i, pruned_program) + if i % args.test_period == 0: + test(i, pruned_val_program) + + +def main(): + args = parser.parse_args() + print_arguments(args) + compress(args) + + +if __name__ == '__main__': + main() diff --git a/demo/utility.py b/demo/utility.py new file mode 100644 index 0000000000000000000000000000000000000000..dd52f69457c9f8d94920b85dc09b58ff8e605a64 --- /dev/null +++ b/demo/utility.py @@ -0,0 +1,156 @@ +"""Contains common utility functions.""" +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import distutils.util +import os +import numpy as np +import six +import logging +import paddle.fluid as fluid +import paddle.compat as cpt +from paddle.fluid import core +from paddle.fluid.framework import Program + +logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s') +_logger = logging.getLogger(__name__) +_logger.setLevel(logging.INFO) + + +def print_arguments(args): + """Print argparse's arguments. + + Usage: + + .. 
code-block:: python + + parser = argparse.ArgumentParser() + parser.add_argument("name", default="Jonh", type=str, help="User name.") + args = parser.parse_args() + print_arguments(args) + + :param args: Input argparse.Namespace for printing. + :type args: argparse.Namespace + """ + print("----------- Configuration Arguments -----------") + for arg, value in sorted(six.iteritems(vars(args))): + print("%s: %s" % (arg, value)) + print("------------------------------------------------") + + +def add_arguments(argname, type, default, help, argparser, **kwargs): + """Add argparse's argument. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + add_argument("name", str, "Jonh", "User name.", parser) + args = parser.parse_args() + """ + type = distutils.util.strtobool if type == bool else type + argparser.add_argument( + "--" + argname, + default=default, + type=type, + help=help + ' Default: %(default)s.', + **kwargs) + + +def save_persistable_nodes(executor, dirname, graph): + """ + Save persistable nodes to the given directory by the executor. + + Args: + executor(Executor): The executor to run for saving node values. + dirname(str): The directory path. + graph(IrGraph): All the required persistable nodes in the graph will be saved. + """ + persistable_node_names = set() + persistable_nodes = [] + all_persistable_nodes = graph.all_persistable_nodes() + for node in all_persistable_nodes: + name = cpt.to_text(node.name()) + if name not in persistable_node_names: + persistable_node_names.add(name) + persistable_nodes.append(node) + program = Program() + var_list = [] + for node in persistable_nodes: + var_desc = node.var() + if var_desc.type() == core.VarDesc.VarType.RAW or \ + var_desc.type() == core.VarDesc.VarType.READER: + continue + var = program.global_block().create_var( + name=var_desc.name(), + shape=var_desc.shape(), + dtype=var_desc.dtype(), + type=var_desc.type(), + lod_level=var_desc.lod_level(), + persistable=var_desc.persistable()) + var_list.append(var) + fluid.io.save_vars(executor=executor, dirname=dirname, vars=var_list) + + +def load_persistable_nodes(executor, dirname, graph): + """ + Load persistable node values from the given directory by the executor. + + Args: + executor(Executor): The executor to run for loading node values. + dirname(str): The directory path. + graph(IrGraph): All the required persistable nodes in the graph will be loaded. + """ + persistable_node_names = set() + persistable_nodes = [] + all_persistable_nodes = graph.all_persistable_nodes() + for node in all_persistable_nodes: + name = cpt.to_text(node.name()) + if name not in persistable_node_names: + persistable_node_names.add(name) + persistable_nodes.append(node) + program = Program() + var_list = [] + + def _exist(var): + return os.path.exists(os.path.join(dirname, var.name)) + + def _load_var(name, scope): + return np.array(scope.find_var(name).get_tensor()) + + def _store_var(name, array, scope, place): + tensor = scope.find_var(name).get_tensor() + tensor.set(array, place) + + for node in persistable_nodes: + var_desc = node.var() + if var_desc.type() == core.VarDesc.VarType.RAW or \ + var_desc.type() == core.VarDesc.VarType.READER: + continue + var = program.global_block().create_var( + name=var_desc.name(), + shape=var_desc.shape(), + dtype=var_desc.dtype(), + type=var_desc.type(), + lod_level=var_desc.lod_level(), + persistable=var_desc.persistable()) + if _exist(var): + var_list.append(var) + else: + _logger.info("Cannot find the var %s!!!" 
% (node.name()))
+    fluid.io.load_vars(executor=executor, dirname=dirname, vars=var_list)
diff --git a/paddleslim/analysis/__init__.py b/paddleslim/analysis/__init__.py
index 9d0531501ca43921438ee5b2fb58ac0ad2396d1b..76904c8d548208adb29188f28e9e0c6a0f11f30d 100644
--- a/paddleslim/analysis/__init__.py
+++ b/paddleslim/analysis/__init__.py
@@ -11,3 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from . import flops as flops_module
+from .flops import *
+from . import model_size as model_size_module
+from .model_size import *
+from . import sensitive
+from .sensitive import *
+__all__ = []
+__all__ += flops_module.__all__
+__all__ += model_size_module.__all__
+__all__ += sensitive.__all__
diff --git a/paddleslim/analysis/flops.py b/paddleslim/analysis/flops.py
new file mode 100644
index 0000000000000000000000000000000000000000..583c8e6ebf1a41f95c5ca8aeab0a1297cd798948
--- /dev/null
+++ b/paddleslim/analysis/flops.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from ..core import GraphWrapper
+
+__all__ = ["flops"]
+
+
+def flops(program):
+    """
+    Get the FLOPs of the target graph.
+    Args:
+        program(Program): The program used to calculate the FLOPs.
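+    Returns:
+        The FLOPs accumulated over conv2d/depthwise_conv2d, pool2d, mul and
+        activation/batch_norm operators in the graph.
+
+    Example:
+        .. code-block:: python
+
+            # a minimal sketch using the fluid 1.x layers API
+            import paddle.fluid as fluid
+            from paddleslim.analysis import flops
+
+            main_prog = fluid.Program()
+            with fluid.program_guard(main_prog, fluid.Program()):
+                image = fluid.layers.data(
+                    name='image', shape=[3, 224, 224], dtype='float32')
+                out = fluid.layers.conv2d(image, num_filters=8, filter_size=3)
+            print(flops(main_prog))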
+ """ + graph = GraphWrapper(program) + return _graph_flops(graph) + + +def _graph_flops(graph, only_conv=False): + assert isinstance(graph, GraphWrapper) + flops = 0 + for op in graph.ops(): + if op.type() in ['conv2d', 'depthwise_conv2d']: + filter_shape = op.inputs("Filter")[0].shape() + input_shape = op.inputs("Input")[0].shape() + output_shape = op.outputs("Output")[0].shape() + c_out, c_in, k_h, k_w = filter_shape + _, _, h_out, w_out = output_shape + groups = op.attr("groups") + kernel_ops = k_h * k_w * (c_in / groups) + if len(op.inputs("Bias")) > 0: + with_bias = 1 + else: + with_bias = 0 + flops += 2 * h_out * w_out * c_out * (kernel_ops + with_bias) + elif op.type() == 'pool2d' and not only_conv: + input_shape = op.inputs("X")[0].shape() + output_shape = op.outputs("Out")[0].shape() + _, c_out, h_out, w_out = output_shape + k_size = op.attr("ksize") + flops += h_out * w_out * c_out * (k_size[0]**2) + + elif op.type() == 'mul' and not only_conv: + x_shape = list(op.inputs("X")[0].shape()) + y_shape = op.inputs("Y")[0].shape() + if x_shape[0] == -1: + x_shape[0] = 1 + flops += 2 * x_shape[0] * x_shape[1] * y_shape[1] + + elif op.type() in ['relu', 'sigmoid', 'batch_norm'] and not only_conv: + input_shape = list(op.inputs("X")[0].shape()) + if input_shape[0] == -1: + input_shape[0] = 1 + flops += np.product(input_shape) + + return flops diff --git a/paddleslim/analysis/model_size.py b/paddleslim/analysis/model_size.py new file mode 100644 index 0000000000000000000000000000000000000000..34574d5d53764810185112d7122aeb5b99d74682 --- /dev/null +++ b/paddleslim/analysis/model_size.py @@ -0,0 +1,31 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from ..core import GraphWrapper + +__all__ = ["model_size"] + + +def model_size(program): + """ + Get total value numbers of all parameters. + Args: + program(Program): The program used to calculate model size. + """ + size = 0 + for block in program.blocks: + for param in block.all_parameters(): + size += np.product(param.shape) + return size diff --git a/paddleslim/analysis/sensitive.py b/paddleslim/analysis/sensitive.py new file mode 100644 index 0000000000000000000000000000000000000000..09dd2a875ae21caf64034cf79421d7cc1661b817 --- /dev/null +++ b/paddleslim/analysis/sensitive.py @@ -0,0 +1,111 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import os +import logging +import pickle +import numpy as np +from ..core import GraphWrapper +from ..common import get_logger +from ..prune import Pruner + +_logger = get_logger(__name__, level=logging.INFO) + +__all__ = ["sensitivity"] + + +def sensitivity(program, + scope, + place, + param_names, + eval_func, + sensitivities_file=None, + step_size=0.2): + + graph = GraphWrapper(program) + sensitivities = _load_sensitivities(sensitivities_file) + + for name in param_names: + if name not in sensitivities: + size = graph.var(name).shape()[0] + sensitivities[name] = { + 'pruned_percent': [], + 'loss': [], + 'size': size + } + baseline = None + for name in sensitivities: + ratio = step_size + while ratio < 1: + ratio = round(ratio, 2) + if ratio in sensitivities[name]['pruned_percent']: + _logger.debug('{}, {} has computed.'.format(name, ratio)) + ratio += step_size + continue + if baseline is None: + baseline = eval_func(graph.program, scope) + + param_backup = {} + pruner = Pruner() + pruned_program = pruner.prune( + program=graph.program, + scope=scope, + params=[name], + ratios=[ratio], + place=place, + lazy=True, + only_graph=False, + param_backup=param_backup) + pruned_metric = eval_func(pruned_program, scope) + loss = (baseline - pruned_metric) / baseline + _logger.info("pruned param: {}; {}; loss={}".format(name, ratio, + loss)) + sensitivities[name]['pruned_percent'].append(ratio) + sensitivities[name]['loss'].append(loss) + _save_sensitivities(sensitivities, sensitivities_file) + + # restore pruned parameters + for param_name in param_backup.keys(): + param_t = scope.find_var(param_name).get_tensor() + param_t.set(param_backup[param_name], place) + ratio += step_size + return sensitivities + + +def _load_sensitivities(sensitivities_file): + """ + Load sensitivities from file. + """ + sensitivities = {} + if sensitivities_file and os.path.exists(sensitivities_file): + with open(sensitivities_file, 'rb') as f: + if sys.version_info < (3, 0): + sensitivities = pickle.load(f) + else: + sensitivities = pickle.load(f, encoding='bytes') + + for param in sensitivities: + sensitivities[param]['pruned_percent'] = [ + round(p, 2) for p in sensitivities[param]['pruned_percent'] + ] + return sensitivities + + +def _save_sensitivities(sensitivities, sensitivities_file): + """ + Save sensitivities into file. + """ + with open(sensitivities_file, 'wb') as f: + pickle.dump(sensitivities, f) diff --git a/paddleslim/common/__init__.py b/paddleslim/common/__init__.py index 9d0531501ca43921438ee5b2fb58ac0ad2396d1b..98b314ab6d144924bff6b68e3fb176ce73583f5c 100644 --- a/paddleslim/common/__init__.py +++ b/paddleslim/common/__init__.py @@ -11,3 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
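+# Flatten the public API of the submodules: every name in each submodule's
+# __all__ (e.g. get_logger, ControllerServer, ControllerClient) can then be
+# imported directly from paddleslim.common.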
+from . import controller
+from .controller import *
+from . import sa_controller
+from .sa_controller import *
+from . import log_helper
+from .log_helper import *
+from . import controller_server
+from .controller_server import *
+from . import controller_client
+from .controller_client import *
+from . import lock_utils
+from .lock_utils import *
+
+__all__ = []
+__all__ += controller.__all__
+__all__ += sa_controller.__all__
+__all__ += controller_server.__all__
+__all__ += controller_client.__all__
+__all__ += lock_utils.__all__
diff --git a/paddleslim/common/controller.py b/paddleslim/common/controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c30f49c3aec27a326417554bac3163789342ff6
--- /dev/null
+++ b/paddleslim/common/controller.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""The controller used to search hyperparameters or neural architectures."""
+
+import copy
+import math
+import numpy as np
+
+__all__ = ['EvolutionaryController']
+
+
+class EvolutionaryController(object):
+    """Abstract base class for all evolutionary search controllers.
+    """
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def update(self, tokens, reward):
+        """Update the controller according to the current tokens and reward.
+        Args:
+            tokens(list): A solution of the search task.
+            reward(float): The reward of the tokens.
+        """
+        raise NotImplementedError('Abstract method.')
+
+    def reset(self, range_table, constrain_func=None):
+        """Reset the controller.
+        Args:
+            range_table(list): It is used to define the search space of the controller.
+                The tokens[i] generated by the controller should be in [0, range_table[i]).
+            constrain_func(function): It is used to check whether tokens meet the constraint.
+                None means there is no constraint. Default: None.
+        """
+        raise NotImplementedError('Abstract method.')
+
+    def next_tokens(self):
+        """Generate new tokens.
+        """
+        raise NotImplementedError('Abstract method.')
diff --git a/paddleslim/common/controller_client.py b/paddleslim/common/controller_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad989dd16014fa8e6fa1495516e81048324fb826
--- /dev/null
+++ b/paddleslim/common/controller_client.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
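+"""Socket client for ControllerServer.
+
+A sketch of the protocol from the search agent's side (the server address
+and key must match the running ControllerServer):
+
+.. code-block:: python
+
+    client = ControllerClient(server_ip, server_port, key="auto_pruner")
+    tokens = client.next_tokens()      # ask the server for a new candidate
+    reward = evaluate(tokens)          # user-defined evaluation
+    client.update(tokens, reward)      # report the reward back
+"""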
+
+import logging
+import socket
+from .log_helper import get_logger
+
+__all__ = ['ControllerClient']
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+
+class ControllerClient(object):
+    """
+    The client that talks to a ControllerServer over a TCP socket.
+    """
+
+    def __init__(self, server_ip=None, server_port=None, key=None):
+        """
+        Args:
+            server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None.
+            server_port(int): The port that controller server listens on. Default: None.
+            key(str): The key used to identify legal agents for the controller server. Default: None.
+        """
+        self.server_ip = server_ip
+        self.server_port = server_port
+        self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self._key = key
+
+    def update(self, tokens, reward):
+        """
+        Update the controller according to the latest tokens and reward.
+        Args:
+            tokens(list): The tokens generated in the last step.
+            reward(float): The reward of tokens.
+        Returns:
+            bool: Whether the server acknowledged the update.
+        """
+        socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        socket_client.connect((self.server_ip, self.server_port))
+        tokens = ",".join([str(token) for token in tokens])
+        socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward)
+                           .encode())
+        response = socket_client.recv(1024).decode()
+        # The server replies with a bare "ok"; compare the stripped string
+        # itself, not the list returned by split().
+        return response.strip('\n') == "ok"
+
+    def next_tokens(self):
+        """
+        Get next tokens from the controller server.
+        """
+        socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        socket_client.connect((self.server_ip, self.server_port))
+        socket_client.send("next_tokens".encode())
+        tokens = socket_client.recv(1024).decode()
+        tokens = [int(token) for token in tokens.strip("\n").split(",")]
+        return tokens
diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4705a887727bf444b3ba285165d27df59a1ed57
--- /dev/null
+++ b/paddleslim/common/controller_server.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import logging
+import socket
+from .log_helper import get_logger
+from threading import Thread
+from .lock_utils import lock, unlock
+
+__all__ = ['ControllerServer']
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+
+class ControllerServer(object):
+    """
+    The controller wrapper with a socket server that handles requests from search agents.
+    """
+
+    def __init__(self,
+                 controller=None,
+                 address=('', 0),
+                 max_client_num=100,
+                 search_steps=None,
+                 key=None):
+        """
+        Args:
+            controller(slim.searcher.Controller): The controller used to generate tokens.
+            address(tuple): The (ip, port) pair the server binds to. Default: ('', 0),
+                            which means the ip is resolved automatically.
+            max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100.
+            search_steps(int): The total steps of searching.
None means never stopping. Default: None + """ + self._controller = controller + self._address = address + self._max_client_num = max_client_num + self._search_steps = search_steps + self._closed = False + self._port = address[1] + self._ip = address[0] + self._key = key + self._socket_file = "./controller_server.socket" + + def start(self): + open(self._socket_file, 'a').close() + socket_file = open(self._socket_file, 'r+') + lock(socket_file) + tid = socket_file.readline() + if tid == '': + _logger.info("start controller server...") + tid = self._start() + socket_file.write("tid: {}\nip: {}\nport: {}\n".format( + tid, self._ip, self._port)) + _logger.info("started controller server...") + unlock(socket_file) + socket_file.close() + + def _start(self): + self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket_server.bind(self._address) + self._socket_server.listen(self._max_client_num) + self._port = self._socket_server.getsockname()[1] + self._ip = self._socket_server.getsockname()[0] + _logger.info("ControllerServer - listen on: [{}:{}]".format( + self._ip, self._port)) + thread = Thread(target=self.run) + thread.start() + return str(thread) + + def close(self): + """Close the server.""" + self._closed = True + os.remove(self._socket_file) + _logger.info("server closed!") + + def port(self): + """Get the port.""" + return self._port + + def ip(self): + """Get the ip.""" + return self._ip + + def run(self): + _logger.info("Controller Server run...") + try: + while ((self._search_steps is None) or + (self._controller._iter < + (self._search_steps))) and not self._closed: + conn, addr = self._socket_server.accept() + message = conn.recv(1024).decode() + if message.strip("\n") == "next_tokens": + tokens = self._controller.next_tokens() + tokens = ",".join([str(token) for token in tokens]) + conn.send(tokens.encode()) + else: + _logger.debug("recv message from {}: [{}]".format(addr, + message)) + messages = message.strip('\n').split("\t") + if (len(messages) < 3) or (messages[0] != self._key): + _logger.debug("recv noise from {}: [{}]".format( + addr, message)) + continue + tokens = messages[1] + reward = messages[2] + tokens = [int(token) for token in tokens.split(",")] + self._controller.update(tokens, float(reward)) + response = "ok" + conn.send(response.encode()) + _logger.debug("send message to {}: [{}]".format(addr, + tokens)) + conn.close() + finally: + self._socket_server.close() + self.close() diff --git a/paddleslim/common/lock_utils.py b/paddleslim/common/lock_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9daf4f3f6e842609a39fd286dfa49eb705c631a7 --- /dev/null +++ b/paddleslim/common/lock_utils.py @@ -0,0 +1,38 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
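A sketch of how `ControllerServer` and `ControllerClient` are meant to be wired together; the port, the range table and the placeholder reward are assumptions, and `SAController` is the annealing controller defined in sa_controller.py below:

    # Hypothetical server/client wiring for a two-token search space.
    from paddleslim.common import SAController, ControllerServer, ControllerClient

    controller = SAController(
        range_table=([0, 0], [5, 5]),   # (lower bounds, upper bounds)
        init_tokens=[0, 0])
    server = ControllerServer(
        controller=controller,
        address=("127.0.0.1", 8989),    # example port
        search_steps=10,
        key="sa_nas")
    server.start()

    client = ControllerClient(server.ip(), server.port(), key="sa_nas")
    for _ in range(10):
        tokens = client.next_tokens()   # ask the server for a candidate
        reward = float(sum(tokens))     # placeholder for a real evaluation
        client.update(tokens, reward)   # report the reward over the socket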
+
+import os
+
+__all__ = ['lock', 'unlock']
+
+if os.name == 'nt':
+
+    def lock(file):
+        raise NotImplementedError('Windows is not supported.')
+
+    def unlock(file):
+        raise NotImplementedError('Windows is not supported.')
+
+elif os.name == 'posix':
+    from fcntl import flock, LOCK_EX, LOCK_UN
+
+    def lock(file):
+        """Lock the file in the local file system."""
+        flock(file.fileno(), LOCK_EX)
+
+    def unlock(file):
+        """Unlock the file in the local file system."""
+        flock(file.fileno(), LOCK_UN)
+else:
+    raise RuntimeError("The file locker only supports NT and POSIX platforms!")
diff --git a/paddleslim/common/log_helper.py b/paddleslim/common/log_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..18000ce4ec6c472914de49a053e960c02cfd8e32
--- /dev/null
+++ b/paddleslim/common/log_helper.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import logging
+
+__all__ = ['get_logger']
+
+
+def get_logger(name, level, fmt='%(asctime)s-%(levelname)s: %(message)s'):
+    """
+    Get a logger from the logging module with the given name, level and
+    format, without calling logging.basicConfig: paddle already sets
+    basicConfig on import, so a later basicConfig call would be ignored.
+    Args:
+        name (str): The logger name.
+        level (logging.LEVEL): The base level of the logger.
+        fmt (str): Format of logger output.
+    Returns:
+        logging.Logger: logging logger with the given settings.
+    Examples:
+        .. code-block:: python
+           logger = log_helper.get_logger(__name__, logging.INFO,
+                            fmt='%(asctime)s-%(levelname)s: %(message)s')
+    """
+
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+    handler = logging.StreamHandler()
+    if fmt:
+        formatter = logging.Formatter(fmt=fmt)
+        handler.setFormatter(formatter)
+
+    logger.addHandler(handler)
+    logger.propagate = 0
+    return logger
diff --git a/paddleslim/common/sa_controller.py b/paddleslim/common/sa_controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..b619b818a3208d740c1ddb6753cf5931f3d058f5
--- /dev/null
+++ b/paddleslim/common/sa_controller.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
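The two helpers above are small; a sketch of their intended use follows (the file name is an example, and the lock helpers are POSIX-only):

    # Sketch: flock-based mutual exclusion plus the shared logger helper.
    import logging
    from paddleslim.common import lock, unlock
    from paddleslim.common.log_helper import get_logger

    _logger = get_logger(__name__, level=logging.INFO)

    with open("./server_record.txt", "a+") as f:
        lock(f)                 # blocks until the exclusive lock is acquired
        f.write("ip: 127.0.0.1\nport: 8989\n")
        unlock(f)
    _logger.info("server record written")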
+"""The controller used to search hyperparameters or neural architecture""" + +import copy +import math +import logging +import numpy as np +from .controller import EvolutionaryController +from log_helper import get_logger + +__all__ = ["SAController"] + +_logger = get_logger(__name__, level=logging.INFO) + + +class SAController(EvolutionaryController): + """Simulated annealing controller.""" + + def __init__(self, + range_table=None, + reduce_rate=0.85, + init_temperature=1024, + max_iter_number=300, + init_tokens=None, + constrain_func=None): + """Initialize. + Args: + range_table(list): Range table. + reduce_rate(float): The decay rate of temperature. + init_temperature(float): Init temperature. + max_iter_number(int): max iteration number. + init_tokens(list): The initial tokens. + constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None. + """ + super(SAController, self).__init__() + self._range_table = range_table + assert isinstance(self._range_table, tuple) and ( + len(self._range_table) == 2) + self._reduce_rate = reduce_rate + self._init_temperature = init_temperature + self._max_iter_number = max_iter_number + self._reward = -1 + self._tokens = init_tokens + self._constrain_func = constrain_func + self._max_reward = -1 + self._best_tokens = None + self._iter = 0 + + def __getstate__(self): + d = {} + for key in self.__dict__: + if key != "_constrain_func": + d[key] = self.__dict__[key] + return d + + def update(self, tokens, reward): + """ + Update the controller according to latest tokens and reward. + Args: + tokens(list): The tokens generated in last step. + reward(float): The reward of tokens. + """ + self._iter += 1 + temperature = self._init_temperature * self._reduce_rate**self._iter + if (reward > self._reward) or (np.random.random() <= math.exp( + (reward - self._reward) / temperature)): + self._reward = reward + self._tokens = tokens + if reward > self._max_reward: + self._max_reward = reward + self._best_tokens = tokens + _logger.info( + "Controller - iter: {}; current_reward: {}; current tokens: {}". + format(self._iter, self._reward, self._tokens)) + + def next_tokens(self, control_token=None): + """ + Get next tokens. + """ + if control_token: + tokens = control_token[:] + else: + tokens = self._tokens + new_tokens = tokens[:] + index = int(len(self._range_table[0]) * np.random.random()) + new_tokens[index] = np.random.randint(self._range_table[0][index], + self._range_table[1][index] + 1) + _logger.debug("change index[{}] from {} to {}".format(index, tokens[ + index], new_tokens[index])) + if self._constrain_func is None: + return new_tokens + for _ in range(self._max_iter_number): + if not self._constrain_func(new_tokens): + index = int(len(self._range_table[0]) * np.random.random()) + new_tokens = tokens[:] + new_tokens[index] = np.random.randint( + self._range_table[0][index], + self._range_table[1][index] + 1) + else: + break + return new_tokens diff --git a/paddleslim/core/__init__.py b/paddleslim/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb4b2bf34f3aed2f74ce4fd5936527b17737181 --- /dev/null +++ b/paddleslim/core/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import graph_wrapper +from .graph_wrapper import * +from . import registry +from .registry import * + +__all__ = graph_wrapper.__all__ +__all__ += registry.__all__ diff --git a/paddleslim/core/graph_wrapper.py b/paddleslim/core/graph_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..72de894a2e4345c32e7a4eee2f35249b77c2f467 --- /dev/null +++ b/paddleslim/core/graph_wrapper.py @@ -0,0 +1,355 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import copy +import pickle +import numpy as np +from collections import OrderedDict +from collections import Iterable +from paddle.fluid.framework import Program, program_guard, Parameter, Variable + +__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper'] + +OPTIMIZER_OPS = [ + 'momentum', + 'lars_momentum', + 'adagrad', + 'adam', + 'adamax', + 'dpsgd', + 'decayed_adagrad', + 'adadelta', + 'rmsprop', +] + + +class VarWrapper(object): + def __init__(self, var, graph): + assert isinstance(var, Variable) + assert isinstance(graph, GraphWrapper) + self._var = var + self._graph = graph + + def __eq__(self, v): + """ + Overwrite this function for ...in... syntax in python. + """ + return self._var.name == v._var.name + + def name(self): + """ + Get the name of the variable. + """ + return self._var.name + + def shape(self): + """ + Get the shape of the varibale. + """ + return self._var.shape + + def set_shape(self, shape): + """ + Set the shape of the variable. + """ + self._var.desc.set_shape(shape) + + def inputs(self): + """ + Get all the operators that use this variable as output. + Returns: + list: A list of operators. + """ + ops = [] + for op in self._graph.ops(): + if self in op.all_outputs(): + ops.append(op) + return ops + + def outputs(self): + """ + Get all the operators that use this variable as input. + Returns: + list: A list of operators. + """ + ops = [] + for op in self._graph.ops(): + if self in op.all_inputs(): + ops.append(op) + return ops + + +class OpWrapper(object): + def __init__(self, op, graph): + assert isinstance(graph, GraphWrapper) + self._op = op + self._graph = graph + + def __eq__(self, op): + """ + Overwrite this function for ...in... syntax in python. + """ + return self.idx() == op.idx() + + def all_inputs(self): + """ + Get all the input variables of this operator. + """ + return [ + self._graph.var(var_name) for var_name in self._op.input_arg_names + ] + + def all_outputs(self): + """ + Get all the output variables of this operator. 
+ """ + return [ + self._graph.var(var_name) for var_name in self._op.output_arg_names + ] + + def idx(self): + """ + Get the id of this operator. + """ + return self._op.idx + + def type(self): + """ + Get the type of this operator. + """ + return self._op.type + + def is_bwd_op(self): + """ + Whether this operator is backward op. + """ + return self.type().endswith('_grad') + + def is_opt_op(self): + """ + Whether this operator is optimizer op. + """ + return self.type() in OPTIMIZER_OPS + + def inputs(self, name): + """ + Get all the varibales by the input name. + """ + return [self._graph.var(var_name) for var_name in self._op.input(name)] + + def outputs(self, name): + """ + Get all the varibales by the output name. + """ + return [ + self._graph.var(var_name) for var_name in self._op.output(name) + ] + + def set_attr(self, key, value): + """ + Set the value of attribute by attribute's name. + + Args: + key(str): the attribute name. + value(bool|int|str|float|list): the value of the attribute. + """ + self._op._set_attr(key, value) + + def attr(self, name): + """ + Get the attribute by name. + + Args: + name(str): the attribute name. + + Returns: + bool|int|str|float|list: The attribute value. The return value + can be any valid attribute type. + """ + return self._op.attr(name) + + +class GraphWrapper(object): + """ + It is a wrapper of paddle.fluid.framework.IrGraph with some special functions + for paddle slim framework. + """ + + def __init__(self, program=None, in_nodes=[], out_nodes=[]): + """ + Args: + program(framework.Program): A program with + in_nodes(dict): A dict to indicate the input nodes of the graph. + The key is user-defined and human-readable name. + The value is the name of Variable. + out_nodes(dict): A dict to indicate the input nodes of the graph. + The key is user-defined and human-readable name. + The value is the name of Variable. + """ + super(GraphWrapper, self).__init__() + self.program = Program() if program is None else program + self.persistables = {} + self.teacher_persistables = {} + for var in self.program.list_vars(): + if var.persistable: + self.persistables[var.name] = var + self.compiled_graph = None + in_nodes = [] if in_nodes is None else in_nodes + out_nodes = [] if out_nodes is None else out_nodes + self.in_nodes = OrderedDict(in_nodes) + self.out_nodes = OrderedDict(out_nodes) + self._attrs = OrderedDict() + + def all_parameters(self): + """ + Get all the parameters in this graph. + Returns: + list: A list of VarWrapper instances. + """ + params = [] + for block in self.program.blocks: + for param in block.all_parameters(): + params.append(VarWrapper(param, self)) + return params + + def is_parameter(self, var): + """ + Whether the given variable is parameter. + Args: + var(VarWrapper): The given varibale. + """ + return isinstance(var._var, Parameter) + + def is_persistable(self, var): + """ + Whether the given variable is persistable. + Args: + var(VarWrapper): The given varibale. + """ + return var._var.persistable + + def ops(self): + """ + Return all operator nodes included in the graph as a set. + """ + ops = [] + for block in self.program.blocks: + for op in block.ops: + ops.append(OpWrapper(op, self)) + return ops + + def vars(self): + """ + Get all the variables. + """ + return [VarWrapper(var, self) for var in self.program.list_vars()] + + def var(self, name): + """ + Get the variable by variable name. 
+ """ + return VarWrapper(self.program.global_block().var(name), self) + + def clone(self, for_test=False): + """ + Clone a new graph from current graph. + Returns: + (GraphWrapper): The wrapper of a new graph. + """ + return GraphWrapper( + self.program.clone(for_test), + copy.deepcopy(self.in_nodes), copy.deepcopy(self.out_nodes)) + + def program(self): + """ + Get the program in current wrapper. + """ + return self.program + + def pre_ops(self, op): + """ + Get all the previous operators of target operator. + Args: + op(OpWrapper): Target operator.. + Returns: + list: A list of operators. + """ + ops = [] + for p in self.ops(): + for in_var in op.all_inputs(): + if in_var in p.all_outputs(): + ops.append(p) + return ops + + def next_ops(self, op): + """ + Get all the next operators of target operator. + Args: + op(OpWrapper): Target operator.. + Returns: + list: A list of operators. + """ + ops = [] + for p in self.ops(): + for out_var in op.all_outputs(): + if out_var in p.all_inputs(): + ops.append(p) + return ops + + def get_param_by_op(self, op): + """ + Get the parameters used by target operator. + """ + assert isinstance(op, OpWrapper) + params = [] + for var in op.all_inputs(): + if isinstance(var._var, Parameter): + params.append(var) + assert len(params) > 0 + return params + + def numel_params(self): + """ + Get the number of elements in all parameters. + """ + ret = 0 + for param in self.all_parameters(): + ret += np.product(param.shape()) + return ret + + def update_param_shape(self, scope): + """ + Update the shape of parameters in the graph according to tensors in scope. + It is used after loading pruned parameters from file. + """ + for param in self.all_parameters(): + tensor_shape = np.array( + scope.find_var(param.name()).get_tensor()).shape + param.set_shape(tensor_shape) + + def infer_shape(self): + """ + Update the groups of convolution layer according to current filters. + It is used after loading pruned parameters from file. + """ + for op in self.ops(): + if op.type() != 'conditional_block': + op._op.desc.infer_shape(op._op.block.desc) + + def update_groups_of_conv(self): + for op in self.ops(): + if op.type() == 'depthwise_conv2d' or op.type( + ) == 'depthwise_conv2d_grad': + op.set_attr('groups', op.inputs('Filter')[0].shape()[0]) diff --git a/paddleslim/core/registry.py b/paddleslim/core/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..208dceca1ff9958591b7e427d47124f3c57e4d5b --- /dev/null +++ b/paddleslim/core/registry.py @@ -0,0 +1,53 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
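A short sketch of inspecting a program through `GraphWrapper`; the program itself is assumed to be built already:

    # Sketch: walking a fluid program with GraphWrapper.
    import paddle.fluid as fluid
    from paddleslim.core import GraphWrapper

    graph = GraphWrapper(fluid.default_main_program())
    print("parameter elements:", graph.numel_params())
    for op in graph.ops():
        if op.type() == "conv2d":
            weight = op.inputs("Filter")[0]
            print(op.type(), weight.name(), weight.shape(),
                  "consumers:", len(graph.next_ops(op)))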
+ +import inspect + +__all__ = ["Registry"] + + +class Registry(object): + def __init__(self, name): + self._name = name + self._module_dict = dict() + + def __repr__(self): + format_str = self.__class__.__name__ + '(name={}, items={})'.format( + self._name, list(self._module_dict.keys())) + return format_str + + @property + def name(self): + return self._name + + @property + def module_dict(self): + return self._module_dict + + def get(self, key): + return self._module_dict.get(key, None) + + def _register_module(self, module_class): + if not inspect.isclass(module_class): + raise TypeError('module must be a class, but receive {}.'.format( + type(module_class))) + module_name = module_class.__name__ + if module_name in self._module_dict: + raise KeyError('{} is already registered in {}.'.format( + module_name, self.name)) + self._module_dict[module_name] = module_class + + def register(self, cls): + self._register_module(cls) + return cls diff --git a/paddleslim/dist/mp_distiller.py b/paddleslim/dist/mp_distiller.py new file mode 100755 index 0000000000000000000000000000000000000000..ff15f5f17dd130edfd6fc5bfa1d8c358da2a5ae2 --- /dev/null +++ b/paddleslim/dist/mp_distiller.py @@ -0,0 +1,223 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging +import numpy as np +from six.moves.queue import Queue + +import paddle.fluid as fluid +from paddle.fluid.framework import Variable +from paddle.fluid.reader import DataLoaderBase +from paddle.fluid.core import EOFException +from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +__all__ = ['Knowledge'] + + +class Knowledge(object): + """ + The knowledge class describes how to extract and store the dark knowledge + of the teacher model, and how the student model learns these dark knowledge. + """ + + def __init__(self, + path, + items, + reduce_strategy={'type': 'sum', + 'key': 'image'}): + """Init a knowledge instance. + Args: + path(list, str, optional): Specifies the storage path of the knowledge, + supports AFS/HDFS, local file system, and memory. + items(list): Save the tensor of the specified name + reduce_strategy(dict, optional): The policy for performing the reduce + operation. If it is set to None, + the reduce operation is not performed. + reduce_strategy.type(str): Type of reduce operation. + reduce_strategy.key(str): The key of the reduce operation. + It is an element in the item. 
+        """
+        assert (isinstance(path, list) or isinstance(path, str) or
+                (path is None)), "path type should be list or str or None"
+        assert (isinstance(items, list)), "items should be a list"
+        assert (isinstance(reduce_strategy,
+                           dict)), "reduce_strategy should be a dict"
+        self.path = path
+        if isinstance(self.path, list):
+            self.write_type = 'HDFS/AFS'
+            assert (
+                len(self.path) == 4 and isinstance(self.path[0], str) and
+                isinstance(self.path[1], str) and
+                isinstance(self.path[2], str) and isinstance(self.path[3], str)
+            ), "path should contain four str: ['local hadoop home', 'fs.default.name', 'hadoop.job.ugi', 'FS path']"
+
+            hadoop_home = self.path[0]
+            configs = {
+                "fs.default.name": self.path[1],
+                "hadoop.job.ugi": self.path[2]
+            }
+            self.client = HDFSClient(hadoop_home, configs)
+            assert (
+                self.client.is_exist(self.path[3]) == True
+            ), "Please make sure your hadoop configuration is correct and the FS path exists"
+
+            self.hdfs_local_path = "./teacher_knowledge"
+            if not os.path.exists(self.hdfs_local_path):
+                os.mkdir(self.hdfs_local_path)
+        elif isinstance(self.path, str):
+            self.write_type = "LocalFS"
+            if not os.path.exists(path):
+                raise ValueError("The local path [%s] does not exist." %
+                                 (path))
+        else:
+            self.write_type = "MEM"
+            self.knowledge_queue = Queue(64)
+
+        self.items = items
+        self.reduce_strategy = reduce_strategy
+
+    def _write(self, data):
+        if self.write_type == 'HDFS/AFS':
+            file_name = 'knowledge_' + str(self.file_cnt)
+            file_path = os.path.join(self.hdfs_local_path, file_name)
+            file_path += ".npy"
+            np.save(file_path, data)
+            self.file_cnt += 1
+            self.client.upload(self.path[3], file_path)
+            logger.info('{}.npy pushed to HDFS/AFS: {}'.format(file_name,
+                                                               self.path[3]))
+
+        elif self.write_type == 'LocalFS':
+            file_name = 'knowledge_' + str(self.file_cnt)
+            file_path = os.path.join(self.path, file_name)
+            np.save(file_path, data)
+            logger.info('{}.npy saved'.format(file_name))
+            self.file_cnt += 1
+
+        else:
+            # The memory branch writes no file, so there is no file_name here;
+            # log the batch counter instead.
+            self.knowledge_queue.put(data)
+            logger.info('knowledge_{} pushed to Queue'.format(self.file_cnt))
+            self.file_cnt += 1
+
+    def run(self, teacher_program, exe, place, scope, reader, inputs, outputs,
+            call_back):
+        """Start the teacher model to run inference and dump its knowledge.
+        Args:
+            teacher_program(Program): teacher program.
+            exe(Executor): The executor used to run the teacher program.
+            place(CPUPlace|CUDAPlace): The device the teacher runs on.
+            scope(Scope): The scope used to execute the teacher,
+                          which contains the initialized variables.
+            reader(reader): The data reader used by the teacher.
+            inputs(list): The name of variables to feed the teacher program.
+            outputs(list): The names of the variables to fetch from the teacher,
+                           which must correspond one-to-one to the Knowledge
+                           instance's items.
+            call_back(func, optional): The callback function that handles the
+                                       outputs of the teacher. None (the default)
+                                       means the outputs are stored directly.
+        Return:
+            (bool): Whether the teacher task was successfully registered and started
+        """
+        assert (isinstance(
+            teacher_program,
+            fluid.Program)), "teacher_program should be a fluid.Program"
+        assert (isinstance(inputs, list)), "inputs should be a list"
+        assert (isinstance(outputs, list)), "outputs should be a list"
+        assert (len(self.items) == len(outputs)
+                ), "the length of outputs list should be equal with items list"
+        assert (callable(call_back) or (call_back is None)
+                ), "call_back should be a callable function or NoneType."
+ + for var in teacher_program.list_vars(): + var.stop_gradient = True + + compiled_teacher_program = fluid.compiler.CompiledProgram( + teacher_program) + self.file_cnt = 0 + if isinstance(reader, Variable) or ( + isinstance(reader, DataLoaderBase) and (not reader.iterable)): + reader.start() + try: + while True: + logits = exe.run(compiled_teacher_program, + scope=scope, + fetch_list=outputs, + feed=None) + knowledge = dict() + for index, array in enumerate(logits): + knowledge[self.items[index]] = array + self._write(knowledge) + except EOFException: + reader.reset() + + else: + if not isinstance(reader, DataLoaderBase): + feeder = fluid.DataFeeder( + feed_list=inputs, place=place, program=teacher_program) + for batch_id, data in enumerate(reader()): + if not isinstance(reader, DataLoaderBase): + data = feeder.feed(data) + logits = exe.run(compiled_teacher_program, + scope=scope, + fetch_list=outputs, + feed=data) + knowledge = dict() + for index, array in enumerate(logits): + knowledge[self.items[index]] = array + self._write(knowledge) + return True + + def dist(self, student_program, losses): + """Building the distillation network + Args: + student_program(Program): student program. + losses(list, optional): The losses need to add. If set to None + does not add any loss. + Return: + (Program): Program for distillation. + (startup_program): Program for initializing distillation network. + (reader): Data reader for distillation training. + (Variable): Loss of distillation training + """ + + def loss(self, loss_func, *variables): + """User-defined loss + Args: + loss_func(func): Function used to define loss. + *variables(list): Variable name list. + Return: + (Variable): Distillation loss. + """ + pass + + def fsp_loss(self): + """fsp loss + """ + pass + + def l2_loss(self): + """l2 loss + """ + pass + + def softlabel_loss(self): + """softlabel_loss + """ + pass diff --git a/paddleslim/nas/__init__.py b/paddleslim/nas/__init__.py index 9d0531501ca43921438ee5b2fb58ac0ad2396d1b..f11948f6bcbdd3d52e334bed3b06510e226825bc 100644 --- a/paddleslim/nas/__init__.py +++ b/paddleslim/nas/__init__.py @@ -11,3 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import search_space +from search_space import * +import sa_nas +from sa_nas import * + +__all__ = [] +__all__ += search_space.__all__ +__all__ += sa_nas.__all__ diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py new file mode 100644 index 0000000000000000000000000000000000000000..6d84df919881fceb8d2a26c0e03c3cbe8a0536aa --- /dev/null +++ b/paddleslim/nas/sa_nas.py @@ -0,0 +1,118 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
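A hedged sketch of using `Knowledge` to dump a teacher's logits to the local file system; `teacher_program`, `image`, `label`, `val_reader` and the fetch name are assumptions, not part of this patch:

    # Hypothetical driver for Knowledge.run (local-FS storage branch).
    import paddle.fluid as fluid
    from paddleslim.dist.mp_distiller import Knowledge

    know = Knowledge(
        path="./knowledge_dir",        # existing local directory
        items=["teacher_logits"],      # one item per fetched output
        reduce_strategy={'type': 'sum', 'key': 'image'})

    place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    know.run(teacher_program,          # assumed: an inference Program
             exe, place, fluid.global_scope(),
             val_reader,               # assumed: a batched data reader
             inputs=[image, label],    # feed variables of teacher_program
             outputs=["fc_0.tmp_1"],   # fetch names matching `items`
             call_back=None)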
+
+import socket
+import logging
+import numpy as np
+import paddle.fluid as fluid
+from ..core import VarWrapper, OpWrapper, GraphWrapper
+from ..common import SAController
+from ..common import get_logger
+from ..analysis import flops
+
+from ..common import ControllerServer
+from ..common import ControllerClient
+from .search_space import SearchSpaceFactory
+
+__all__ = ["SANAS"]
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+
+class SANAS(object):
+    def __init__(self,
+                 configs,
+                 server_addr=("", 8881),
+                 init_temperature=100,
+                 reduce_rate=0.85,
+                 search_steps=300,
+                 key="sa_nas",
+                 is_server=False):
+        """
+        Neural architecture search based on simulated annealing.
+        Args:
+            configs(list): A list of search space configurations with format (key, config).
+                           `key` is the name of the search space with data type str. `config` is
+                           a dict describing the `input_size`, `output_size` and `block_num` of
+                           the searched sub-network.
+            server_addr(tuple): A tuple of server ip and server port for controller server.
+            init_temperature(float): The initial temperature used in the simulated annealing search strategy.
+            reduce_rate(float): The decay rate used in the simulated annealing search strategy.
+            search_steps(int): The steps of searching.
+            key(str): Identity used in communication between controller server and clients.
+            is_server(bool): Whether the current host is the controller server. Default: False.
+        """
+        if not is_server:
+            assert server_addr[
+                0] != "", "You should set the IP and port of server when is_server is False."
+        self._reduce_rate = reduce_rate
+        self._init_temperature = init_temperature
+        # The retry budget passed to SAController for constrained mutations.
+        self._max_try_number = 300
+        self._is_server = is_server
+
+        self._configs = configs
+
+        factory = SearchSpaceFactory()
+        self._search_space = factory.get_search_space(configs)
+        init_tokens = self._search_space.init_tokens()
+        range_table = self._search_space.range_table()
+        range_table = (len(range_table) * [0], range_table)
+        _logger.info("range table: {}".format(range_table))
+        # No extra constraint is applied here, so pass None as constrain_func.
+        controller = SAController(range_table, self._reduce_rate,
+                                  self._init_temperature,
+                                  self._max_try_number, init_tokens, None)
+
+        server_ip, server_port = server_addr
+        if server_ip is None or server_ip == "":
+            server_ip = self._get_host_ip()
+        max_client_num = 100
+        self._controller_server = ControllerServer(
+            controller=controller,
+            address=(server_ip, server_port),
+            max_client_num=max_client_num,
+            search_steps=search_steps,
+            key=key)
+
+        # start the controller server only on the server host
+        if self._is_server:
+            self._controller_server.start()
+
+        self._controller_client = ControllerClient(
+            self._controller_server.ip(),
+            self._controller_server.port(),
+            key=key)
+
+        self._iter = 0
+
+    def _get_host_ip(self):
+        return socket.gethostbyname(socket.gethostname())
+
+    def next_archs(self):
+        """
+        Get next network architectures.
+        Returns:
+            list: A list of functions that define networks.
+        """
+        self._current_tokens = self._controller_client.next_tokens()
+        archs = self._search_space.token2arch(self._current_tokens)
+        return archs
+
+    def reward(self, score):
+        """
+        Return the reward of the current searched network.
+        Args:
+            score(float): The score of the current searched network.
+        Returns:
+            bool: True means updating successfully while False means failure.
+        """
+        self._iter += 1
+        return self._controller_client.update(self._current_tokens, score)
diff --git a/paddleslim/nas/search_space/__init__.py b/paddleslim/nas/search_space/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..51b433d452b8cd8c3eb32582d9caa43634b700d0
--- /dev/null
+++ b/paddleslim/nas/search_space/__init__.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import mobilenetv2
+from .mobilenetv2 import *
+from . import mobilenetv1
+from .mobilenetv1 import *
+from . import resnet
+from .resnet import *
+from . import search_space_registry
+from .search_space_registry import *
+from . import search_space_factory
+from .search_space_factory import *
+from . import search_space_base
+from .search_space_base import *
+
+__all__ = []
+__all__ += mobilenetv2.__all__
+__all__ += mobilenetv1.__all__
+__all__ += resnet.__all__
+__all__ += search_space_registry.__all__
+__all__ += search_space_factory.__all__
+__all__ += search_space_base.__all__
diff --git a/paddleslim/nas/search_space/base_layer.py b/paddleslim/nas/search_space/base_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b497c92a2ca57b4acab0c39c5dbd69d30083e295
--- /dev/null
+++ b/paddleslim/nas/search_space/base_layer.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+
+
+def conv_bn_layer(input,
+                  filter_size,
+                  num_filters,
+                  stride,
+                  padding='SAME',
+                  num_groups=1,
+                  act=None,
+                  name=None,
+                  use_cudnn=True):
+    """Build a convolution layer followed by batch normalization.
+    Args:
+        input(Variable): input.
+        filter_size(int): filter size.
+        num_filters(int): number of filters.
+        stride(int): stride.
+        padding(int|list|str): padding.
+        num_groups(int): number of groups.
+        act(str): activation type, applied after batch norm.
+        name(str): name.
+        use_cudnn(bool): whether to use cudnn.
+    Returns:
+        Variable, layers output.
+ """ + conv = fluid.layers.conv2d( + input, + num_filters=num_filters, + filter_size=filter_size, + stride=stride, + padding=padding, + groups=num_groups, + act=None, + use_cudnn=use_cudnn, + param_attr=ParamAttr(name=name + '_weights'), + bias_attr=False) + bn_name = name + '_bn' + return fluid.layers.batch_norm( + input=conv, + act = act, + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(name=bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') diff --git a/paddleslim/nas/search_space/combine_search_space.py b/paddleslim/nas/search_space/combine_search_space.py new file mode 100644 index 0000000000000000000000000000000000000000..667720a9110aa92e096a4f8fa30bb3e4b3e3cecb --- /dev/null +++ b/paddleslim/nas/search_space/combine_search_space.py @@ -0,0 +1,99 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from .search_space_base import SearchSpaceBase +from .search_space_registry import SEARCHSPACE +from .base_layer import conv_bn_layer + +__all__ = ["CombineSearchSpace"] + + +class CombineSearchSpace(object): + """ + Combine Search Space. + Args: + configs(list): multi config. + """ + + def __init__(self, config_lists): + self.lens = len(config_lists) + self.spaces = [] + for config_list in config_lists: + key, config = config_list + self.spaces.append(self._get_single_search_space(key, config)) + + def _get_single_search_space(self, key, config): + """ + get specific model space based on key and config. + + Args: + key(str): model space name. + config(dict): basic config information. + return: + model space(class) + """ + cls = SEARCHSPACE.get(key) + space = cls(config['input_size'], config['output_size'], + config['block_num'], config['block_mask']) + + return space + + def init_tokens(self): + """ + Combine init tokens. + """ + tokens = [] + self.single_token_num = [] + for space in self.spaces: + tokens.extend(space.init_tokens()) + self.single_token_num.append(len(space.init_tokens())) + return tokens + + def range_table(self): + """ + Combine range table. 
+ """ + range_tables = [] + for space in self.spaces: + range_tables.extend(space.range_table()) + return range_tables + + def token2arch(self, tokens=None): + """ + Combine model arch + """ + if tokens is None: + tokens = self.init_tokens() + + token_list = [] + start_idx = 0 + end_idx = 0 + + for i in range(len(self.single_token_num)): + end_idx += self.single_token_num[i] + token_list.append(tokens[start_idx:end_idx]) + start_idx = end_idx + + model_archs = [] + for space, token in zip(self.spaces, token_list): + model_archs.append(space.token2arch(token)) + + return model_archs diff --git a/paddleslim/nas/search_space/mobilenetv1.py b/paddleslim/nas/search_space/mobilenetv1.py new file mode 100644 index 0000000000000000000000000000000000000000..8b3277d2cb1b472ccd5e27407e3099b28e64f42b --- /dev/null +++ b/paddleslim/nas/search_space/mobilenetv1.py @@ -0,0 +1,224 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from .search_space_base import SearchSpaceBase +from .base_layer import conv_bn_layer +from .search_space_registry import SEARCHSPACE + +__all__ = ["MobileNetV1Space"] + + +@SEARCHSPACE.register +class MobileNetV1Space(SearchSpaceBase): + def __init__(self, + input_size, + output_size, + block_num, + scale=1.0, + class_dim=1000): + super(MobileNetV1Space, self).__init__(input_size, output_size, + block_num) + self.scale = scale + self.class_dim = class_dim + # self.head_num means the channel of first convolution + self.head_num = np.array([3, 4, 8, 12, 16, 24, 32]) # 7 + # self.filter_num1 ~ self.filtet_num9 means channel of the following convolution + self.filter_num1 = np.array([3, 4, 8, 12, 16, 24, 32, 48]) # 8 + self.filter_num2 = np.array([8, 12, 16, 24, 32, 48, 64, 80]) # 8 + self.filter_num3 = np.array( + [16, 24, 32, 48, 64, 80, 96, 128, 144, 160]) #10 + self.filter_num4 = np.array( + [24, 32, 48, 64, 80, 96, 128, 144, 160, 192]) #10 + self.filter_num5 = np.array( + [32, 48, 64, 80, 96, 128, 144, 160, 192, 224, 256, 320]) #12 + self.filter_num6 = np.array( + [64, 80, 96, 128, 144, 160, 192, 224, 256, 320, 384]) #11 + self.filter_num7 = np.array([ + 64, 80, 96, 128, 144, 160, 192, 224, 256, 320, 384, 512, 1024, 1048 + ]) #14 + self.filter_num8 = np.array( + [128, 144, 160, 192, 224, 256, 320, 384, 512, 576, 640, 704, + 768]) #13 + self.filter_num9 = np.array( + [160, 192, 224, 256, 320, 384, 512, 640, 768, 832, 1024, + 1048]) #12 + # self.k_size means kernel size + self.k_size = np.array([3, 5]) #2 + # self.repeat means repeat_num in forth downsample + self.repeat = np.array([1, 2, 3, 4, 5, 6]) #6 + + assert self.block_num < 6, 'MobileNetV1: block number must less than 6, but receive block number is {}'.format( + self.block_num) + + def init_tokens(self): + """ + The initial token. 
+ The first one is the index of the first layers' channel in self.head_num, + each line in the following represent the index of the [filter_num1, filter_num2, kernel_size] + and depth means repeat times for forth downsample + """ + # yapf: disable + base_init_tokens = [6, # 32 + 6, 6, 0, # 32, 64, 3 + 6, 7, 0, # 64, 128, 3 + 7, 6, 0, # 128, 128, 3 + 6, 10, 0, # 128, 256, 3 + 10, 8, 0, # 256, 256, 3 + 8, 11, 0, # 256, 512, 3 + 4, # depth 5 + 11, 8, 0, # 512, 512, 3 + 8, 10, 0, # 512, 1024, 3 + 10, 10, 0] # 1024, 1024, 3 + # yapf: enable + if self.block_num < 5: + self.token_len = 1 + (self.block_num * 2 - 1) * 3 + else: + self.token_len = 2 + (self.block_num * 2 - 1) * 3 + return base_init_tokens[:self.token_len] + + def range_table(self): + """ + Get range table of current search space, constrains the range of tokens. + """ + # yapf: disable + base_range_table = [len(self.head_num), + len(self.filter_num1), len(self.filter_num2), len(self.k_size), + len(self.filter_num2), len(self.filter_num3), len(self.k_size), + len(self.filter_num3), len(self.filter_num4), len(self.k_size), + len(self.filter_num4), len(self.filter_num5), len(self.k_size), + len(self.filter_num5), len(self.filter_num6), len(self.k_size), + len(self.filter_num6), len(self.filter_num7), len(self.k_size), + len(self.repeat), + len(self.filter_num7), len(self.filter_num8), len(self.k_size), + len(self.filter_num8), len(self.filter_num9), len(self.k_size), + len(self.filter_num9), len(self.filter_num9), len(self.k_size)] + # yapf: enable + return base_range_table[:self.token_len] + + def token2arch(self, tokens=None): + + if tokens is None: + tokens = self.tokens() + + bottleneck_param_list = [] + + if self.block_num >= 1: + # tokens[0] = 32 + # 32, 64 + bottleneck_param_list.append( + (self.filter_num1[tokens[1]], self.filter_num2[tokens[2]], 1, + self.k_size[tokens[3]])) + if self.block_num >= 2: + # 64 128 128 128 + bottleneck_param_list.append( + (self.filter_num2[tokens[4]], self.filter_num3[tokens[5]], 2, + self.k_size[tokens[6]])) + bottleneck_param_list.append( + (self.filter_num3[tokens[7]], self.filter_num4[tokens[8]], 1, + self.k_size[tokens[9]])) + if self.block_num >= 3: + # 128 256 256 256 + bottleneck_param_list.append( + (self.filter_num4[tokens[10]], self.filter_num5[tokens[11]], 2, + self.k_size[tokens[12]])) + bottleneck_param_list.append( + (self.filter_num5[tokens[13]], self.filter_num6[tokens[14]], 1, + self.k_size[tokens[15]])) + if self.block_num >= 4: + # 256 512 (512 512) * 5 + bottleneck_param_list.append( + (self.filter_num6[tokens[16]], self.filter_num7[tokens[17]], 2, + self.k_size[tokens[18]])) + for i in range(self.repeat[tokens[19]]): + bottleneck_param_list.append( + (self.filter_num7[tokens[20]], + self.filter_num8[tokens[21]], 1, self.k_size[tokens[22]])) + if self.block_num >= 5: + # 512 1024 1024 1024 + bottleneck_param_list.append( + (self.filter_num8[tokens[23]], self.filter_num9[tokens[24]], 2, + self.k_size[tokens[25]])) + bottleneck_param_list.append( + (self.filter_num9[tokens[26]], self.filter_num9[tokens[27]], 1, + self.k_size[tokens[28]])) + + def net_arch(input): + input = conv_bn_layer( + input=input, + filter_size=3, + num_filters=self.head_num[tokens[0]], + stride=2, + name='mobilenetv1') + + for i, layer_setting in enumerate(bottleneck_param_list): + filter_num1, filter_num2, stride, kernel_size = layer_setting + input = self._depthwise_separable( + input=input, + num_filters1=filter_num1, + num_filters2=filter_num2, + num_groups=filter_num1, + stride=stride, + 
scale=self.scale,
+                    kernel_size=kernel_size,
+                    name='mobilenetv1_{}'.format(str(i + 1)))
+
+            if self.output_size == 1:
+                print('NOTE: when output_size is 1, a fc layer is appended at the end.')
+                input = fluid.layers.fc(
+                    input=input,
+                    size=self.class_dim,
+                    # name the fc parameters after this space to avoid
+                    # clashing with MobileNetV2Space's fc parameters
+                    param_attr=ParamAttr(name='mobilenetv1_fc_weights'),
+                    bias_attr=ParamAttr(name='mobilenetv1_fc_offset'))
+            else:
+                assert self.output_size == input.shape[2], \
+                    ("output_size must equal input_size / (2^block_num). "
+                     "But receive input_size={}, output_size={}, block_num={}".format(
+                         self.input_size, self.output_size, self.block_num))
+
+            return input
+
+        return net_arch
+
+    def _depthwise_separable(self,
+                             input,
+                             num_filters1,
+                             num_filters2,
+                             num_groups,
+                             stride,
+                             scale,
+                             kernel_size,
+                             name=None):
+        depthwise_conv = conv_bn_layer(
+            input=input,
+            filter_size=kernel_size,
+            num_filters=int(num_filters1 * scale),
+            stride=stride,
+            num_groups=int(num_groups * scale),
+            use_cudnn=False,
+            name=name + '_dw')
+        pointwise_conv = conv_bn_layer(
+            input=depthwise_conv,
+            filter_size=1,
+            num_filters=int(num_filters2 * scale),
+            stride=1,
+            name=name + '_sep')
+
+        return pointwise_conv
diff --git a/paddleslim/nas/search_space/mobilenetv2.py b/paddleslim/nas/search_space/mobilenetv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..e974a676a70546e19aa4649679393031634e7822
--- /dev/null
+++ b/paddleslim/nas/search_space/mobilenetv2.py
@@ -0,0 +1,302 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
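Putting the pieces together, a sketch of a search loop driven by `SANAS` with the MobileNetV2 space that follows; `evaluate` stands in for a real train-and-eval routine, and the config schema mirrors `CombineSearchSpace` above:

    # Hypothetical SANAS driver; evaluate() is a placeholder.
    import paddle.fluid as fluid
    from paddleslim.nas import SANAS

    config = [('MobileNetV2Space',
               {'input_size': 224, 'output_size': 7,
                'block_num': 5, 'block_mask': None})]
    nas = SANAS(config, server_addr=("", 8989),
                search_steps=100, is_server=True)

    for step in range(100):
        archs = nas.next_archs()           # functions that build candidate nets
        program = fluid.Program()
        with fluid.program_guard(program):
            image = fluid.layers.data(
                name='image', shape=[3, 224, 224], dtype='float32')
            out = archs[0](image)          # instantiate the first sub-space
        score = evaluate(program)          # assumed: short training + eval
        nas.reward(score)                  # feed the reward back to the server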
+ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.param_attr import ParamAttr +from .search_space_base import SearchSpaceBase +from .base_layer import conv_bn_layer +from .search_space_registry import SEARCHSPACE + +__all__ = ["MobileNetV2Space"] + + +@SEARCHSPACE.register +class MobileNetV2Space(SearchSpaceBase): + def __init__(self, + input_size, + output_size, + block_num, + block_mask=None, + scale=1.0, + class_dim=1000): + super(MobileNetV2Space, self).__init__(input_size, output_size, + block_num, block_mask) + assert self.block_mask == None, 'MobileNetV2Space will use origin MobileNetV2 as seach space, so use input_size, output_size and block_num to search' + # self.head_num means the first convolution channel + self.head_num = np.array([3, 4, 8, 12, 16, 24, 32]) #7 + # self.filter_num1 ~ self.filter_num6 means following convlution channel + self.filter_num1 = np.array([3, 4, 8, 12, 16, 24, 32, 48]) #8 + self.filter_num2 = np.array([8, 12, 16, 24, 32, 48, 64, 80]) #8 + self.filter_num3 = np.array([16, 24, 32, 48, 64, 80, 96, 128]) #8 + self.filter_num4 = np.array( + [24, 32, 48, 64, 80, 96, 128, 144, 160, 192]) #10 + self.filter_num5 = np.array( + [32, 48, 64, 80, 96, 128, 144, 160, 192, 224]) #10 + self.filter_num6 = np.array( + [64, 80, 96, 128, 144, 160, 192, 224, 256, 320, 384, 512]) #12 + # self.k_size means kernel size + self.k_size = np.array([3, 5]) #2 + # self.multiply means expansion_factor of each _inverted_residual_unit + self.multiply = np.array([1, 2, 3, 4, 6]) #5 + # self.repeat means repeat_num _inverted_residual_unit in each _invresi_blocks + self.repeat = np.array([1, 2, 3, 4, 5, 6]) #6 + self.scale = scale + self.class_dim = class_dim + + assert self.block_num < 7, 'MobileNetV2: block number must less than 7, but receive block number is {}'.format( + self.block_num) + + def init_tokens(self): + """ + The initial token. + The first one is the index of the first layers' channel in self.head_num, + each line in the following represent the index of the [expansion_factor, filter_num, repeat_num, kernel_size] + """ + # original MobileNetV2 + # yapf: disable + init_token_base = [4, # 1, 16, 1 + 4, 5, 1, 0, # 6, 24, 1 + 4, 5, 1, 0, # 6, 24, 2 + 4, 4, 2, 0, # 6, 32, 3 + 4, 4, 3, 0, # 6, 64, 4 + 4, 5, 2, 0, # 6, 96, 3 + 4, 7, 2, 0, # 6, 160, 3 + 4, 9, 0, 0] # 6, 320, 1 + # yapf: enable + + if self.block_num < 5: + self.token_len = 1 + (self.block_num - 1) * 4 + else: + self.token_len = 1 + (self.block_num + 2 * + (self.block_num - 5)) * 4 + + return init_token_base[:self.token_len] + + def range_table(self): + """ + Get range table of current search space, constrains the range of tokens. 
+ """ + # head_num + 7 * [multiple(expansion_factor), filter_num, repeat, kernel_size] + # yapf: disable + range_table_base = [len(self.head_num), + len(self.multiply), len(self.filter_num1), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num1), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num2), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num3), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num4), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num5), len(self.repeat), len(self.k_size), + len(self.multiply), len(self.filter_num6), len(self.repeat), len(self.k_size)] + range_table_base = list(np.array(range_table_base) - 1) + # yapf: enable + return range_table_base[:self.token_len] + + def token2arch(self, tokens=None): + """ + return net_arch function + """ + + if tokens is None: + tokens = self.init_tokens() + print(tokens) + + bottleneck_params_list = [] + if self.block_num >= 1: + bottleneck_params_list.append( + (1, self.head_num[tokens[0]], 1, 1, 3)) + if self.block_num >= 2: + bottleneck_params_list.append( + (self.multiply[tokens[1]], self.filter_num1[tokens[2]], + self.repeat[tokens[3]], 2, self.k_size[tokens[4]])) + if self.block_num >= 3: + bottleneck_params_list.append( + (self.multiply[tokens[5]], self.filter_num1[tokens[6]], + self.repeat[tokens[7]], 2, self.k_size[tokens[8]])) + if self.block_num >= 4: + bottleneck_params_list.append( + (self.multiply[tokens[9]], self.filter_num2[tokens[10]], + self.repeat[tokens[11]], 2, self.k_size[tokens[12]])) + if self.block_num >= 5: + bottleneck_params_list.append( + (self.multiply[tokens[13]], self.filter_num3[tokens[14]], + self.repeat[tokens[15]], 2, self.k_size[tokens[16]])) + bottleneck_params_list.append( + (self.multiply[tokens[17]], self.filter_num4[tokens[18]], + self.repeat[tokens[19]], 1, self.k_size[tokens[20]])) + if self.block_num >= 6: + bottleneck_params_list.append( + (self.multiply[tokens[21]], self.filter_num5[tokens[22]], + self.repeat[tokens[23]], 2, self.k_size[tokens[24]])) + bottleneck_params_list.append( + (self.multiply[tokens[25]], self.filter_num6[tokens[26]], + self.repeat[tokens[27]], 1, self.k_size[tokens[28]])) + + def net_arch(input): + #conv1 + # all padding is 'SAME' in the conv2d, can compute the actual padding automatic. + input = conv_bn_layer( + input, + num_filters=int(32 * self.scale), + filter_size=3, + stride=2, + padding='SAME', + act='relu6', + name='mobilenetv2_conv1_1') + + # bottleneck sequences + i = 1 + in_c = int(32 * self.scale) + for layer_setting in bottleneck_params_list: + t, c, n, s, k = layer_setting + i += 1 + input = self._invresi_blocks( + input=input, + in_c=in_c, + t=t, + c=int(c * self.scale), + n=n, + s=s, + k=k, + name='mobilenetv2_conv' + str(i)) + in_c = int(c * self.scale) + + # if output_size is 1, add fc layer in the end + if self.output_size == 1: + input = fluid.layers.fc( + input=input, + size=self.class_dim, + param_attr=ParamAttr(name='mobilenetv2_fc_weights'), + bias_attr=ParamAttr(name='mobilenetv2_fc_offset')) + else: + assert self.output_size == input.shape[2], \ + ("output_size must EQUAL to input_size / (2^block_num)." + "But receive input_size={}, output_size={}, block_num={}".format( + self.input_size, self.output_size, self.block_num)) + + return input + + return net_arch + + def _shortcut(self, input, data_residual): + """Build shortcut layer. + Args: + input(Variable): input. 
+ data_residual(Variable): residual layer. + Returns: + Variable, layer output. + """ + return fluid.layers.elementwise_add(input, data_residual) + + def _inverted_residual_unit(self, + input, + num_in_filter, + num_filters, + ifshortcut, + stride, + filter_size, + expansion_factor, + reduction_ratio=4, + name=None): + """Build inverted residual unit. + Args: + input(Variable), input. + num_in_filter(int), number of in filters. + num_filters(int), number of filters. + ifshortcut(bool), whether using shortcut. + stride(int), stride. + filter_size(int), filter size. + padding(str|int|list), padding. + expansion_factor(float), expansion factor. + name(str), name. + Returns: + Variable, layers output. + """ + num_expfilter = int(round(num_in_filter * expansion_factor)) + channel_expand = conv_bn_layer( + input=input, + num_filters=num_expfilter, + filter_size=1, + stride=1, + padding='SAME', + num_groups=1, + act='relu6', + name=name + '_expand') + + bottleneck_conv = conv_bn_layer( + input=channel_expand, + num_filters=num_expfilter, + filter_size=filter_size, + stride=stride, + padding='SAME', + num_groups=num_expfilter, + act='relu6', + name=name + '_dwise', + use_cudnn=False) + + linear_out = conv_bn_layer( + input=bottleneck_conv, + num_filters=num_filters, + filter_size=1, + stride=1, + padding='SAME', + num_groups=1, + act=None, + name=name + '_linear') + out = linear_out + if ifshortcut: + out = self._shortcut(input=input, data_residual=out) + return out + + def _invresi_blocks(self, input, in_c, t, c, n, s, k, name=None): + """Build inverted residual blocks. + Args: + input: Variable, input. + in_c: int, number of in filters. + t: float, expansion factor. + c: int, number of filters. + n: int, number of layers. + s: int, stride. + k: int, filter size. + name: str, name. + Returns: + Variable, layers output. + """ + first_block = self._inverted_residual_unit( + input=input, + num_in_filter=in_c, + num_filters=c, + ifshortcut=False, + stride=s, + filter_size=k, + expansion_factor=t, + name=name + '_1') + + last_residual_block = first_block + last_c = c + + for i in range(1, n): + last_residual_block = self._inverted_residual_unit( + input=last_residual_block, + num_in_filter=last_c, + num_filters=c, + ifshortcut=True, + stride=1, + filter_size=k, + expansion_factor=t, + name=name + '_' + str(i + 1)) + return last_residual_block diff --git a/paddleslim/nas/search_space/resnet.py b/paddleslim/nas/search_space/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..fd761d417575988e8ba8bd99da25372613c5912f --- /dev/null +++ b/paddleslim/nas/search_space/resnet.py @@ -0,0 +1,175 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
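+# ResNetSpace below follows the same token protocol as MobileNetV2Space:
+# tokens index per-block candidate lists (here (filter_num, repeat) pairs),
+# and token2arch() returns a closure that builds the network for those picks.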
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+from .search_space_base import SearchSpaceBase
+from .base_layer import conv_bn_layer
+from .search_space_registry import SEARCHSPACE
+
+__all__ = ["ResNetSpace"]
+
+
+@SEARCHSPACE.register
+class ResNetSpace(SearchSpaceBase):
+    def __init__(self,
+                 input_size,
+                 output_size,
+                 block_num,
+                 block_mask=None,
+                 extract_feature=False,
+                 class_dim=1000):
+        super(ResNetSpace, self).__init__(input_size, output_size, block_num,
+                                          block_mask)
+        assert self.block_mask is None, 'ResNetSpace uses the original ResNet as the search space, so use input_size, output_size and block_num to search'
+        # self.filter_num1 ~ self.filter_num4 mean the convolution channel numbers
+        self.filter_num1 = np.array([48, 64, 96, 128, 160, 192, 224])  #7
+        self.filter_num2 = np.array([64, 96, 128, 160, 192, 256, 320])  #7
+        self.filter_num3 = np.array([128, 160, 192, 256, 320, 384])  #6
+        self.filter_num4 = np.array([192, 256, 384, 512, 640])  #5
+        # self.repeat1 ~ self.repeat4 mean the depth of the network
+        self.repeat1 = [2, 3, 4, 5, 6]  #5
+        self.repeat2 = [2, 3, 4, 5, 6, 7]  #6
+        self.repeat3 = [2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24]  #13
+        self.repeat4 = [2, 3, 4, 5, 6, 7]  #6
+        self.class_dim = class_dim
+        self.extract_feature = extract_feature
+        assert self.block_num < 5, 'ResNet: block number must be less than 5, but the received block number is {}'.format(
+            self.block_num)
+
+    def init_tokens(self):
+        """
+        The initial token.
+        Returns 2 * self.block_num tokens; the 2 stands for depth and num_filter.
+        """
+        init_token_base = [0, 0, 0, 0, 0, 0, 0, 0]
+        self.token_len = self.block_num * 2
+        return init_token_base[:self.token_len]
+
+    def range_table(self):
+        """
+        Get the range table of the current search space; it constrains the range of tokens.
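+
+        A hedged sketch of the token layout (hypothetical values, not part of this patch):
+
+            space = ResNetSpace(input_size=224, output_size=1, block_num=4)
+            tokens = space.init_tokens()   # 8 tokens: a (filter, repeat) pair per block
+            bounds = space.range_table()   # candidate-list sizes, one per token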
+ """ + #2 * self.block_num, 2 means depth and num_filter + range_table_base = [ + len(self.filter_num1), len(self.repeat1), len(self.filter_num2), + len(self.repeat2), len(self.filter_num3), len(self.repeat3), + len(self.filter_num4), len(self.repeat4) + ] + return range_table_base[:self.token_len] + + def token2arch(self, tokens=None): + """ + return net_arch function + """ + if tokens is None: + tokens = self.init_tokens() + + depth = [] + num_filters = [] + if self.block_num >= 1: + filter1 = self.filter_num1[tokens[0]] + repeat1 = self.repeat1[tokens[1]] + num_filters.append(filter1) + depth.append(repeat1) + if self.block_num >= 2: + filter2 = self.filter_num2[tokens[2]] + repeat2 = self.repeat2[tokens[3]] + num_filters.append(filter2) + depth.append(repeat2) + if self.block_num >= 3: + filter3 = self.filter_num3[tokens[4]] + repeat3 = self.repeat3[tokens[5]] + num_filters.append(filter3) + depth.append(repeat3) + if self.block_num >= 4: + filter4 = self.filter_num4[tokens[6]] + repeat4 = self.repeat4[tokens[7]] + num_filters.append(filter4) + depth.append(repeat4) + + def net_arch(input): + conv = conv_bn_layer( + input=input, + filter_size=5, + num_filters=filter1, + stride=2, + act='relu', + name='resnet_conv0') + for block in range(len(depth)): + for i in range(depth[block]): + conv = self._bottleneck_block( + input=conv, + num_filters=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + name='resnet_depth{}_block{}'.format(i, block)) + + if self.output_size == 1: + conv = fluid.layers.fc( + input=conv, + size=self.class_dim, + act=None, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.NormalInitializer(0.0, + 0.01)), + bias_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.ConstantInitializer(0))) + + return conv + + return net_arch + + def _shortcut(self, input, ch_out, stride, name=None): + ch_in = input.shape[1] + if ch_in != ch_out or stride != 1: + return conv_bn_layer( + input=input, + filter_size=1, + num_filters=ch_out, + stride=stride, + name=name + '_conv') + else: + return input + + def _bottleneck_block(self, input, num_filters, stride, name=None): + conv0 = conv_bn_layer( + input=input, + num_filters=num_filters, + filter_size=1, + act='relu', + name=name + '_bottleneck_conv0') + conv1 = conv_bn_layer( + input=conv0, + num_filters=num_filters, + filter_size=3, + stride=stride, + act='relu', + name=name + '_bottleneck_conv1') + conv2 = conv_bn_layer( + input=conv1, + num_filters=num_filters * 4, + filter_size=1, + act=None, + name=name + '_bottleneck_conv2') + + short = self._shortcut( + input, num_filters * 4, stride, name=name + '_shortcut') + + return fluid.layers.elementwise_add( + x=short, y=conv2, act='relu', name=name + '_bottleneck_add') diff --git a/paddleslim/nas/search_space/search_space_base.py b/paddleslim/nas/search_space/search_space_base.py new file mode 100644 index 0000000000000000000000000000000000000000..6a83f86005a5fb2408f7f85f40dff8a9e5cba819 --- /dev/null +++ b/paddleslim/nas/search_space/search_space_base.py @@ -0,0 +1,45 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ['SearchSpaceBase']
+
+
+class SearchSpaceBase(object):
+    """Base class of the search spaces used by Neural Architecture Search.
+    """
+
+    def __init__(self, input_size, output_size, block_num, block_mask, *args):
+        self.input_size = input_size
+        self.output_size = output_size
+        self.block_num = block_num
+        self.block_mask = block_mask
+
+    def init_tokens(self):
+        """Get the init tokens of the search space.
+        """
+        raise NotImplementedError('Abstract method.')
+
+    def range_table(self):
+        """Get the range table of the current search space.
+        """
+        raise NotImplementedError('Abstract method.')
+
+    def token2arch(self, tokens):
+        """Create networks for training and evaluation according to tokens.
+        Args:
+            tokens(list): The tokens which represent a network.
+        Return:
+            model arch
+        """
+        raise NotImplementedError('Abstract method.')
diff --git a/paddleslim/nas/search_space/search_space_factory.py b/paddleslim/nas/search_space/search_space_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fc0be834445e13ddef5d6664d13a69fb6904aa6
--- /dev/null
+++ b/paddleslim/nas/search_space/search_space_factory.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .combine_search_space import CombineSearchSpace
+
+__all__ = ["SearchSpaceFactory"]
+
+
+class SearchSpaceFactory(object):
+    def __init__(self):
+        pass
+
+    def get_search_space(self, config_lists):
+        """
+        Get model search spaces based on a list of (key, config) pairs.
+
+        """
+        assert isinstance(config_lists, list), "configs must be a list"
+
+        return CombineSearchSpace(config_lists)
diff --git a/paddleslim/search/__init__.py b/paddleslim/nas/search_space/search_space_registry.py
similarity index 86%
rename from paddleslim/search/__init__.py
rename to paddleslim/nas/search_space/search_space_registry.py
index 4f3182c3058cb33e46777ab1424242b42406a603..2fea80fba4c908759e6123d3d898e94d7ef54c42 100644
--- a/paddleslim/search/__init__.py
+++ b/paddleslim/nas/search_space/search_space_registry.py
@@ -11,4 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Controllers and controller server""" + +from ...core import Registry + +__all__ = ["SEARCHSPACE"] + +SEARCHSPACE = Registry('searchspace') diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py index 9d0531501ca43921438ee5b2fb58ac0ad2396d1b..bb615b9dfca03ed2b289f902f6d75c73543f6fb2 100644 --- a/paddleslim/prune/__init__.py +++ b/paddleslim/prune/__init__.py @@ -11,3 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pruner +from pruner import * +import auto_pruner +from auto_pruner import * +import controller_server +from controller_server import * +import controller_client +from controller_client import * + +__all__ = [] +__all__ += pruner.__all__ +__all__ += auto_pruner.__all__ +__all__ += controller_server.__all__ +__all__ += controller_client.__all__ diff --git a/paddleslim/prune/auto_pruner.py b/paddleslim/prune/auto_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..fba8c11170f3fbf2eddbe15942dc642ad448658b --- /dev/null +++ b/paddleslim/prune/auto_pruner.py @@ -0,0 +1,235 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket +import logging +import numpy as np +import paddle.fluid as fluid +from .pruner import Pruner +from ..core import VarWrapper, OpWrapper, GraphWrapper +from ..common import SAController +from ..common import get_logger +from ..analysis import flops + +from ..common import ControllerServer +from ..common import ControllerClient + +__all__ = ["AutoPruner"] + +_logger = get_logger(__name__, level=logging.INFO) + + +class AutoPruner(object): + def __init__(self, + program, + scope, + place, + params=[], + init_ratios=None, + pruned_flops=0.5, + pruned_latency=None, + server_addr=("", 0), + init_temperature=100, + reduce_rate=0.85, + max_try_number=300, + max_client_num=10, + search_steps=300, + max_ratios=[0.9], + min_ratios=[0], + key="auto_pruner", + is_server=True): + """ + Search a group of ratios used to prune program. + Args: + program(Program): The program to be pruned. + scope(Scope): The scope to be pruned. + place(fluid.Place): The device place of parameters. + params(list): The names of parameters to be pruned. + init_ratios(list|float): Init ratios used to pruned parameters in `params`. + List means ratios used for pruning each parameter in `params`. + The length of `init_ratios` should be equal to length of params when `init_ratios` is a list. + If it is a scalar, all the parameters in `params` will be pruned by uniform ratio. + None means get a group of init ratios by `pruned_flops` of `pruned_latency`. Default: None. + pruned_flops(float): The percent of FLOPS to be pruned. Default: None. + pruned_latency(float): The percent of latency to be pruned. Default: None. + server_addr(tuple): A tuple of server ip and server port for controller server. 
+            init_temperature(float): The init temperature used in the simulated annealing search strategy.
+            reduce_rate(float): The decay rate used in the simulated annealing search strategy.
+            max_try_number(int): The max number of attempts to generate legal tokens.
+            max_client_num(int): The max number of connections of the controller server.
+            search_steps(int): The steps of searching.
+            max_ratios(float|list): Max ratios used to prune parameters in `params`. A list means one max ratio per parameter in `params`.
+                The length of `max_ratios` should equal the length of `params` when `max_ratios` is a list.
+                If it is a scalar, it will be used for all the parameters in `params`.
+            min_ratios(float|list): Min ratios used to prune parameters in `params`. A list means one min ratio per parameter in `params`.
+                The length of `min_ratios` should equal the length of `params` when `min_ratios` is a list.
+                If it is a scalar, it will be used for all the parameters in `params`.
+            key(str): Identity used in communication between the controller server and clients.
+            is_server(bool): Whether the current host is the controller server. Default: True.
+        """
+
+        self._program = program
+        self._scope = scope
+        self._place = place
+        self._params = params
+        self._init_ratios = init_ratios
+        self._pruned_flops = pruned_flops
+        self._pruned_latency = pruned_latency
+        self._reduce_rate = reduce_rate
+        self._init_temperature = init_temperature
+        self._max_try_number = max_try_number
+        self._is_server = is_server
+
+        self._range_table = self._get_range_table(min_ratios, max_ratios)
+
+        self._pruner = Pruner()
+        if self._pruned_flops:
+            self._base_flops = flops(program)
+            self._max_flops = self._base_flops * (1 - self._pruned_flops)
+            _logger.info(
+                "AutoPruner - base flops: {}; pruned_flops: {}; max_flops: {}".
+                format(self._base_flops, self._pruned_flops, self._max_flops))
+        if self._pruned_latency:
+            self._base_latency = latency(program)
+
+        if self._init_ratios is None:
+            self._init_ratios = self._get_init_ratios(
+                self._program, self._params, self._pruned_flops,
+                self._pruned_latency)
+        init_tokens = self._ratios2tokens(self._init_ratios)
+        _logger.info("range table: {}".format(self._range_table))
+        controller = SAController(self._range_table, self._reduce_rate,
+                                  self._init_temperature, self._max_try_number,
+                                  init_tokens, self._constrain_func)
+
+        server_ip, server_port = server_addr
+        if server_ip is None or server_ip == "":
+            server_ip = self._get_host_ip()
+
+        self._controller_server = ControllerServer(
+            controller=controller,
+            address=(server_ip, server_port),
+            max_client_num=max_client_num,
+            search_steps=search_steps,
+            key=key)
+
+        # create controller server
+        if self._is_server:
+            self._controller_server.start()
+
+        self._controller_client = ControllerClient(
+            self._controller_server.ip(),
+            self._controller_server.port(),
+            key=key)
+
+        self._iter = 0
+        self._param_backup = {}
+
+    def _get_host_ip(self):
+        return socket.gethostbyname(socket.gethostname())
+
+    def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
+        pass
+
+    def _get_range_table(self, min_ratios, max_ratios):
+        assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
+        assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
+        min_ratios = min_ratios if isinstance(
+            min_ratios, list) else [min_ratios] * len(self._params)
+        max_ratios = max_ratios if isinstance(
+            max_ratios, list) else [max_ratios] * len(self._params)
+        min_tokens = self._ratios2tokens(min_ratios)
+        max_tokens = self._ratios2tokens(max_ratios)
+        return (min_tokens, max_tokens)
+
+    def _constrain_func(self, tokens):
+        ratios = self._tokens2ratios(tokens)
+        pruned_program = self._pruner.prune(
+            self._program,
+            self._scope,
+            self._params,
+            ratios,
+            place=self._place,
+            only_graph=True)
+        current_flops = flops(pruned_program)
+        result = current_flops < self._max_flops
+        if not result:
+            _logger.info("Failed try ratios: {}; flops: {}; max_flops: {}".
+                         format(ratios, current_flops, self._max_flops))
+        else:
+            _logger.info("Success try ratios: {}; flops: {}; max_flops: {}".
+                         format(ratios, current_flops, self._max_flops))
+        return result
+
+    def prune(self, program, eval_program=None):
+        """
+        Prune the program with the latest tokens generated by the controller.
+        Args:
+            program(fluid.Program): The program to be pruned.
+        Returns:
+            tuple: The pruned program and the pruned evaluation program.
+        """
+        self._current_ratios = self._next_ratios()
+        pruned_program = self._pruner.prune(
+            program,
+            self._scope,
+            self._params,
+            self._current_ratios,
+            place=self._place,
+            only_graph=False,
+            param_backup=self._param_backup)
+        pruned_val_program = None
+        if eval_program is not None:
+            pruned_val_program = self._pruner.prune(
+                eval_program,
+                self._scope,
+                self._params,
+                self._current_ratios,
+                place=self._place,
+                only_graph=True)
+
+        _logger.info("AutoPruner - pruned ratios: {}".format(
+            self._current_ratios))
+        return pruned_program, pruned_val_program
+
+    def reward(self, score):
+        """
+        Report the reward of the current pruned program and move to the next step.
+        Args:
+            score(float): The score of the pruned program.
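+
+        A hedged usage sketch of the search loop (assumes `pruner` is an
+        AutoPruner instance and `eval_acc(program)` returns a float metric;
+        the names are illustrative, not part of this patch):
+
+            for step in range(300):
+                pruned, pruned_val = pruner.prune(train_program, val_program)
+                # ... fine-tune `pruned` for a few epochs ...
+                pruner.reward(float(eval_acc(pruned_val)))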
+ """ + self._restore(self._scope) + self._param_backup = {} + tokens = self._ratios2tokens(self._current_ratios) + self._controller_client.update(tokens, score) + self._iter += 1 + + def _restore(self, scope): + for param_name in self._param_backup.keys(): + param_t = scope.find_var(param_name).get_tensor() + param_t.set(self._param_backup[param_name], self._place) + + def _next_ratios(self): + tokens = self._controller_client.next_tokens() + return self._tokens2ratios(tokens) + + def _ratios2tokens(self, ratios): + """Convert pruned ratios to tokens. + """ + return [int(ratio / 0.01) for ratio in ratios] + + def _tokens2ratios(self, tokens): + """Convert tokens to pruned ratios. + """ + return [token * 0.01 for token in tokens] diff --git a/paddleslim/prune/controller_client.py b/paddleslim/prune/controller_client.py new file mode 100644 index 0000000000000000000000000000000000000000..f133e8b28f823bba89024fe1473630feb509a616 --- /dev/null +++ b/paddleslim/prune/controller_client.py @@ -0,0 +1,66 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import socket +from ..common import get_logger + +__all__ = ['ControllerClient'] + +_logger = get_logger(__name__, level=logging.INFO) + + +class ControllerClient(object): + """ + Controller client. + """ + + def __init__(self, server_ip=None, server_port=None, key=None): + """ + Args: + server_ip(str): The ip that controller server listens on. None means getting the ip automatically. Default: None. + server_port(int): The port that controller server listens on. 0 means getting usable port automatically. Default: 0. + key(str): The key used to identify legal agent for controller server. Default: "light-nas" + """ + self.server_ip = server_ip + self.server_port = server_port + self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._key = key + + def update(self, tokens, reward): + """ + Update the controller according to latest tokens and reward. + Args: + tokens(list): The tokens generated in last step. + reward(float): The reward of tokens. + """ + socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + socket_client.connect((self.server_ip, self.server_port)) + tokens = ",".join([str(token) for token in tokens]) + socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward) + .encode()) + tokens = socket_client.recv(1024).decode() + tokens = [int(token) for token in tokens.strip("\n").split(",")] + return tokens + + def next_tokens(self): + """ + Get next tokens. 
+ """ + socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + socket_client.connect((self.server_ip, self.server_port)) + socket_client.send("next_tokens".encode()) + tokens = socket_client.recv(1024).decode() + tokens = [int(token) for token in tokens.strip("\n").split(",")] + return tokens diff --git a/paddleslim/prune/controller_server.py b/paddleslim/prune/controller_server.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc978444656d2650904eedfd37453b6b5e22207 --- /dev/null +++ b/paddleslim/prune/controller_server.py @@ -0,0 +1,128 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import socket +from ..common import get_logger +from threading import Thread +from .lock import lock, unlock + +__all__ = ['ControllerServer'] + +_logger = get_logger(__name__, level=logging.INFO) + + +class ControllerServer(object): + """ + The controller wrapper with a socket server to handle the request of search agent. + """ + + def __init__(self, + controller=None, + address=('', 0), + max_client_num=100, + search_steps=None, + key=None): + """ + Args: + controller(slim.searcher.Controller): The controller used to generate tokens. + address(tuple): The address of current server binding with format (ip, port). Default: ('', 0). + which means setting ip automatically + max_client_num(int): The maximum number of clients connecting to current server simultaneously. Default: 100. + search_steps(int): The total steps of searching. None means never stopping. 
Default: None + """ + self._controller = controller + self._address = address + self._max_client_num = max_client_num + self._search_steps = search_steps + self._closed = False + self._port = address[1] + self._ip = address[0] + self._key = key + self._socket_file = "./controller_server.socket" + + def start(self): + open(self._socket_file, 'a').close() + socket_file = open(self._socket_file, 'r+') + lock(socket_file) + tid = socket_file.readline() + if tid == '': + _logger.info("start controller server...") + tid = self._start() + socket_file.write("tid: {}\nip: {}\nport: {}\n".format( + tid, self._ip, self._port)) + _logger.info("started controller server...") + unlock(socket_file) + socket_file.close() + + def _start(self): + self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket_server.bind(self._address) + self._socket_server.listen(self._max_client_num) + self._port = self._socket_server.getsockname()[1] + self._ip = self._socket_server.getsockname()[0] + _logger.info("ControllerServer - listen on: [{}:{}]".format( + self._ip, self._port)) + thread = Thread(target=self.run) + thread.start() + return str(thread) + + def close(self): + """Close the server.""" + self._closed = True + os.remove(self._socket_file) + _logger.info("server closed!") + + def port(self): + """Get the port.""" + return self._port + + def ip(self): + """Get the ip.""" + return self._ip + + def run(self): + _logger.info("Controller Server run...") + try: + while ((self._search_steps is None) or + (self._controller._iter < + (self._search_steps))) and not self._closed: + conn, addr = self._socket_server.accept() + message = conn.recv(1024).decode() + if message.strip("\n") == "next_tokens": + tokens = self._controller.next_tokens() + tokens = ",".join([str(token) for token in tokens]) + conn.send(tokens.encode()) + else: + _logger.debug("recv message from {}: [{}]".format(addr, + message)) + messages = message.strip('\n').split("\t") + if (len(messages) < 3) or (messages[0] != self._key): + _logger.debug("recv noise from {}: [{}]".format( + addr, message)) + continue + tokens = messages[1] + reward = messages[2] + tokens = [int(token) for token in tokens.split(",")] + self._controller.update(tokens, float(reward)) + tokens = self._controller.next_tokens() + tokens = ",".join([str(token) for token in tokens]) + conn.send(tokens.encode()) + _logger.debug("send message to {}: [{}]".format(addr, + tokens)) + conn.close() + finally: + self._socket_server.close() + self.close() diff --git a/paddleslim/prune/lock.py b/paddleslim/prune/lock.py new file mode 100644 index 0000000000000000000000000000000000000000..5edcd317304f941c2e7c15ad56e95525dea85398 --- /dev/null +++ b/paddleslim/prune/lock.py @@ -0,0 +1,36 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
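+# Advisory file-locking helpers. ControllerServer uses them so that, out of
+# several trainer processes on one host, only the first actually starts the
+# socket server; the others read its address from the lock-guarded file.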
+
+import os
+__all__ = ['lock', 'unlock']
+if os.name == 'nt':
+
+    def lock(file):
+        raise NotImplementedError('Windows is not supported.')
+
+    def unlock(file):
+        raise NotImplementedError('Windows is not supported.')
+
+elif os.name == 'posix':
+    from fcntl import flock, LOCK_EX, LOCK_UN
+
+    def lock(file):
+        """Lock the file in the local file system."""
+        flock(file.fileno(), LOCK_EX)
+
+    def unlock(file):
+        """Unlock the file in the local file system."""
+        flock(file.fileno(), LOCK_UN)
+else:
+    raise RuntimeError("The file locker only supports NT and POSIX platforms!")
diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fdde525a793b90df63f3245ac5215365dd7ccf4
--- /dev/null
+++ b/paddleslim/prune/pruner.py
@@ -0,0 +1,606 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import numpy as np
+import paddle.fluid as fluid
+import copy
+from ..core import VarWrapper, OpWrapper, GraphWrapper
+from ..common import get_logger
+
+__all__ = ["Pruner"]
+
+_logger = get_logger(__name__, level=logging.INFO)
+
+
+class Pruner():
+    def __init__(self, criterion="l1_norm"):
+        """
+        Args:
+            criterion(str): the criterion used to sort channels for pruning.
+                It only supports 'l1_norm' currently.
+        """
+        self.criterion = criterion
+
+    def prune(self,
+              program,
+              scope,
+              params,
+              ratios,
+              place=None,
+              lazy=False,
+              only_graph=False,
+              param_backup=None,
+              param_shape_backup=None):
+        """
+        Prune the given parameters.
+        Args:
+            program(fluid.Program): The program to be pruned.
+            scope(fluid.Scope): The scope storing the parameters to be pruned.
+            params(list): A list of parameter names to be pruned.
+            ratios(list): A list of ratios used to prune the parameters.
+            place(fluid.Place): The device place of filter parameters. Default: None.
+            lazy(bool): True means setting the pruned elements to zero.
+                        False means cutting down the pruned elements. Default: False.
+            only_graph(bool): True means only modifying the graph.
+                              False means modifying the graph and the variables in the scope. Default: False.
+            param_backup(dict): A dict to backup the values of parameters. Default: None.
+            param_shape_backup(dict): A dict to backup the shapes of parameters. Default: None.
+        Returns:
+            Program: The pruned program.
+        """
+
+        self.pruned_list = []
+        graph = GraphWrapper(program.clone())
+        self._prune_parameters(
+            graph,
+            scope,
+            params,
+            ratios,
+            place,
+            lazy=lazy,
+            only_graph=only_graph,
+            param_backup=param_backup,
+            param_shape_backup=param_shape_backup)
+        for op in graph.ops():
+            if op.type() == 'depthwise_conv2d' or op.type(
+            ) == 'depthwise_conv2d_grad':
+                op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
+        return graph.program
+
+    def _prune_filters_by_ratio(self,
+                                scope,
+                                params,
+                                ratio,
+                                place,
+                                lazy=False,
+                                only_graph=False,
+                                param_shape_backup=None,
+                                param_backup=None):
+        """
+        Prune filters by the given ratio.
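+        The pruned filter indexes are computed once from params[0] and then
+        applied to every tensor in `params` (e.g. a conv weight together with
+        its optimizer accumulators), so their first axes stay aligned.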
+        Args:
+            scope(fluid.core.Scope): The scope used to prune filters.
+            params(list): A list of filter parameters.
+            ratio(float): The ratio to be pruned.
+            place(fluid.Place): The device place of filter parameters.
+            lazy(bool): True means setting the pruned elements to zero.
+                        False means cutting down the pruned elements.
+            only_graph(bool): True means only modifying the graph.
+                              False means modifying the graph and the variables in the scope.
+        """
+        if params[0].name() in self.pruned_list[0]:
+            return
+
+        if only_graph:
+            pruned_num = int(round(params[0].shape()[0] * ratio))
+            for param in params:
+                ori_shape = param.shape()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(ori_shape)
+                new_shape = list(ori_shape)
+                new_shape[0] -= pruned_num
+                param.set_shape(new_shape)
+                _logger.debug("prune [{}] from {} to {}".format(param.name(
+                ), ori_shape, new_shape))
+                self.pruned_list[0].append(param.name())
+            return range(pruned_num)
+
+        else:
+
+            param_t = scope.find_var(params[0].name()).get_tensor()
+            pruned_idx = self._cal_pruned_idx(
+                params[0].name(), np.array(param_t), ratio, axis=0)
+            for param in params:
+                assert isinstance(param, VarWrapper)
+                param_t = scope.find_var(param.name()).get_tensor()
+                if param_backup is not None and (
+                        param.name() not in param_backup):
+                    param_backup[param.name()] = copy.deepcopy(
+                        np.array(param_t))
+                pruned_param = self._prune_tensor(
+                    np.array(param_t), pruned_idx, pruned_axis=0, lazy=lazy)
+                param_t.set(pruned_param, place)
+                ori_shape = param.shape()
+                if param_shape_backup is not None and (
+                        param.name() not in param_shape_backup):
+                    param_shape_backup[param.name()] = copy.deepcopy(
+                        param.shape())
+                new_shape = list(param.shape())
+                new_shape[0] = pruned_param.shape[0]
+                param.set_shape(new_shape)
+                _logger.debug("prune [{}] from {} to {}".format(param.name(
+                ), ori_shape, new_shape))
+                self.pruned_list[0].append(param.name())
+            return pruned_idx
+
+    def _prune_parameter_by_idx(self,
+                                scope,
+                                params,
+                                pruned_idx,
+                                pruned_axis,
+                                place,
+                                lazy=False,
+                                only_graph=False,
+                                param_shape_backup=None,
+                                param_backup=None):
+        """
+        Prune parameters along the given axis.
+        Args:
+            scope(fluid.core.Scope): The scope storing the parameters to be pruned.
+            params(list): The parameters to be pruned.
+            pruned_idx(list): The indexes of the elements to be pruned.
+            pruned_axis(int): The pruning axis.
+            place(fluid.Place): The device place of filter parameters.
+            lazy(bool): True means setting the pruned elements to zero.
+                        False means cutting down the pruned elements.
+            only_graph(bool): True means only modifying the graph.
+                              False means modifying the graph and the variables in the scope.
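+            param_backup(dict): If not None, the original tensor values (or
+                                shapes, when only_graph is True) are saved here.
+            param_shape_backup(dict): If not None, the original parameter
+                                shapes are saved here before they are changed.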
+ """ + if params[0].name() in self.pruned_list[pruned_axis]: + return + + if only_graph: + pruned_num = len(pruned_idx) + for param in params: + ori_shape = param.shape() + if param_backup is not None and ( + param.name() not in param_backup): + param_backup[param.name()] = copy.deepcopy(ori_shape) + new_shape = list(ori_shape) + new_shape[pruned_axis] -= pruned_num + param.set_shape(new_shape) + _logger.debug("prune [{}] from {} to {}".format(param.name( + ), ori_shape, new_shape)) + self.pruned_list[pruned_axis].append(param.name()) + + else: + for param in params: + assert isinstance(param, VarWrapper) + param_t = scope.find_var(param.name()).get_tensor() + if param_backup is not None and ( + param.name() not in param_backup): + param_backup[param.name()] = copy.deepcopy( + np.array(param_t)) + pruned_param = self._prune_tensor( + np.array(param_t), pruned_idx, pruned_axis, lazy=lazy) + param_t.set(pruned_param, place) + ori_shape = param.shape() + + if param_shape_backup is not None and ( + param.name() not in param_shape_backup): + param_shape_backup[param.name()] = copy.deepcopy( + param.shape()) + new_shape = list(param.shape()) + new_shape[pruned_axis] = pruned_param.shape[pruned_axis] + param.set_shape(new_shape) + _logger.debug("prune [{}] from {} to {}".format(param.name( + ), ori_shape, new_shape)) + self.pruned_list[pruned_axis].append(param.name()) + + def _forward_search_related_op(self, graph, param): + """ + Forward search operators that will be affected by pruning of param. + Args: + graph(GraphWrapper): The graph to be searched. + param(VarWrapper): The current pruned parameter. + Returns: + list: A list of operators. + """ + assert isinstance(param, VarWrapper) + visited = {} + for op in graph.ops(): + visited[op.idx()] = False + stack = [] + for op in graph.ops(): + if (not op.is_bwd_op()) and (param in op.all_inputs()): + stack.append(op) + visit_path = [] + while len(stack) > 0: + top_op = stack[len(stack) - 1] + if visited[top_op.idx()] == False: + visit_path.append(top_op) + visited[top_op.idx()] = True + next_ops = None + if top_op.type() == "conv2d" and param not in top_op.all_inputs(): + next_ops = None + elif top_op.type() == "mul": + next_ops = None + else: + next_ops = self._get_next_unvisited_op(graph, visited, top_op) + if next_ops == None: + stack.pop() + else: + stack += next_ops + return visit_path + + def _get_next_unvisited_op(self, graph, visited, top_op): + """ + Get next unvisited adjacent operators of given operators. + Args: + graph(GraphWrapper): The graph used to search. + visited(list): The ids of operators that has been visited. + top_op: The given operator. + Returns: + list: A list of operators. + """ + assert isinstance(top_op, OpWrapper) + next_ops = [] + for op in graph.next_ops(top_op): + if (visited[op.idx()] == False) and (not op.is_bwd_op()): + next_ops.append(op) + return next_ops if len(next_ops) > 0 else None + + def _get_accumulator(self, graph, param): + """ + Get accumulators of given parameter. The accumulator was created by optimizer. + Args: + graph(GraphWrapper): The graph used to search. + param(VarWrapper): The given parameter. + Returns: + list: A list of accumulators which are variables. 
+ """ + assert isinstance(param, VarWrapper) + params = [] + for op in param.outputs(): + if op.is_opt_op(): + for out_var in op.all_outputs(): + if graph.is_persistable(out_var) and out_var.name( + ) != param.name(): + params.append(out_var) + return params + + def _forward_pruning_ralated_params(self, + graph, + scope, + param, + place, + ratio=None, + pruned_idxs=None, + lazy=False, + only_graph=False, + param_backup=None, + param_shape_backup=None): + """ + Pruning all the parameters affected by the pruning of given parameter. + Args: + graph(GraphWrapper): The graph to be searched. + scope(fluid.core.Scope): The scope storing paramaters to be pruned. + param(VarWrapper): The given parameter. + place(fluid.Place): The device place of filter parameters. + ratio(float): The target ratio to be pruned. + pruned_idx(list): The index of elements to be pruned. + lazy(bool): True means setting the pruned elements to zero. + False means cutting down the pruned elements. + only_graph(bool): True means only modifying the graph. + False means modifying graph and variables in scope. + """ + assert isinstance( + graph, + GraphWrapper), "graph must be instance of slim.core.GraphWrapper" + assert isinstance( + param, + VarWrapper), "param must be instance of slim.core.VarWrapper" + + if param.name() in self.pruned_list[0]: + return + related_ops = self._forward_search_related_op(graph, param) + + if ratio is None: + assert pruned_idxs is not None + self._prune_parameter_by_idx( + scope, [param] + self._get_accumulator(graph, param), + pruned_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + + else: + pruned_idxs = self._prune_filters_by_ratio( + scope, [param] + self._get_accumulator(graph, param), + ratio, + place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + corrected_idxs = pruned_idxs[:] + + for idx, op in enumerate(related_ops): + if op.type() == "conv2d" and (param not in op.all_inputs()): + for in_var in op.all_inputs(): + if graph.is_parameter(in_var): + conv_param = in_var + self._prune_parameter_by_idx( + scope, [conv_param] + self._get_accumulator( + graph, conv_param), + corrected_idxs, + pruned_axis=1, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + if op.type() == "depthwise_conv2d": + for in_var in op.all_inputs(): + if graph.is_parameter(in_var): + conv_param = in_var + self._prune_parameter_by_idx( + scope, [conv_param] + self._get_accumulator( + graph, conv_param), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + elif op.type() == "elementwise_add": + # pruning bias + for in_var in op.all_inputs(): + if graph.is_parameter(in_var): + bias_param = in_var + self._prune_parameter_by_idx( + scope, [bias_param] + self._get_accumulator( + graph, bias_param), + pruned_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + elif op.type() == "mul": # pruning fc layer + fc_input = None + fc_param = None + for in_var in op.all_inputs(): + if graph.is_parameter(in_var): + fc_param = in_var + else: + fc_input = in_var + + idx = [] + feature_map_size = fc_input.shape()[2] * fc_input.shape()[3] + range_idx = np.array(range(feature_map_size)) + for i in 
corrected_idxs: + idx += list(range_idx + i * feature_map_size) + corrected_idxs = idx + self._prune_parameter_by_idx( + scope, [fc_param] + self._get_accumulator(graph, fc_param), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + + elif op.type() == "concat": + concat_inputs = op.all_inputs() + last_op = related_ops[idx - 1] + for out_var in last_op.all_outputs(): + if out_var in concat_inputs: + concat_idx = concat_inputs.index(out_var) + offset = 0 + for ci in range(concat_idx): + offset += concat_inputs[ci].shape()[1] + corrected_idxs = [x + offset for x in pruned_idxs] + elif op.type() == "batch_norm": + bn_inputs = op.all_inputs() + mean = bn_inputs[2] + variance = bn_inputs[3] + alpha = bn_inputs[0] + beta = bn_inputs[1] + self._prune_parameter_by_idx( + scope, [mean] + self._get_accumulator(graph, mean), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + self._prune_parameter_by_idx( + scope, [variance] + self._get_accumulator(graph, variance), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + self._prune_parameter_by_idx( + scope, [alpha] + self._get_accumulator(graph, alpha), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + self._prune_parameter_by_idx( + scope, [beta] + self._get_accumulator(graph, beta), + corrected_idxs, + pruned_axis=0, + place=place, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + + def _prune_parameters(self, + graph, + scope, + params, + ratios, + place, + lazy=False, + only_graph=False, + param_backup=None, + param_shape_backup=None): + """ + Pruning the given parameters. + Args: + graph(GraphWrapper): The graph to be searched. + scope(fluid.core.Scope): The scope storing paramaters to be pruned. + params(list): A list of parameter names to be pruned. + ratios(list): A list of ratios to be used to pruning parameters. + place(fluid.Place): The device place of filter parameters. + pruned_idx(list): The index of elements to be pruned. + lazy(bool): True means setting the pruned elements to zero. + False means cutting down the pruned elements. + only_graph(bool): True means only modifying the graph. + False means modifying graph and variables in scope. + """ + assert len(params) == len(ratios) + self.pruned_list = [[], []] + for param, ratio in zip(params, ratios): + assert isinstance(param, str) or isinstance(param, unicode) + param = graph.var(param) + self._forward_pruning_ralated_params( + graph, + scope, + param, + place, + ratio=ratio, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + ops = param.outputs() + for op in ops: + if op.type() == 'conv2d': + brother_ops = self._search_brother_ops(graph, op) + for broher in brother_ops: + for p in graph.get_param_by_op(broher): + self._forward_pruning_ralated_params( + graph, + scope, + p, + place, + ratio=ratio, + lazy=lazy, + only_graph=only_graph, + param_backup=param_backup, + param_shape_backup=param_shape_backup) + + def _search_brother_ops(self, graph, op_node): + """ + Search brother operators that was affected by pruning of given operator. 
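+        Two convolutions that feed the same elementwise_add must keep matching
+        output channels, so when one of them is pruned its "brothers" are
+        collected here and pruned with the same ratio.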
+        Args:
+            graph(GraphWrapper): The graph to be searched.
+            op_node(OpWrapper): The start node for searching.
+        Returns:
+            list: A list of operators.
+        """
+        visited = [op_node.idx()]
+        stack = []
+        brothers = []
+        for op in graph.next_ops(op_node):
+            if (op.type() != 'conv2d') and (op.type() != 'fc') and (
+                    not op.is_bwd_op()):
+                stack.append(op)
+                visited.append(op.idx())
+        while len(stack) > 0:
+            top_op = stack.pop()
+            if top_op.type().startswith("elementwise_"):
+                for parent in graph.pre_ops(top_op):
+                    if parent.idx() not in visited and (
+                            not parent.is_bwd_op()):
+                        if ((parent.type() == 'conv2d') or
+                            (parent.type() == 'fc')):
+                            brothers.append(parent)
+                        else:
+                            stack.append(parent)
+                        visited.append(parent.idx())
+
+            for child in graph.next_ops(top_op):
+                if (child.type() != 'conv2d') and (child.type() != 'fc') and (
+                        child.idx() not in visited) and (
+                            not child.is_bwd_op()):
+                    stack.append(child)
+                    visited.append(child.idx())
+        return brothers
+
+    def _cal_pruned_idx(self, name, param, ratio, axis):
+        """
+        Calculate the indexes to be pruned on the given axis by the given pruning ratio.
+        Args:
+            name(str): The name of the parameter to be pruned.
+            param(np.array): The data of the parameter to be pruned.
+            ratio(float): The ratio to be pruned.
+            axis(int): The axis of the parameter to prune on, e.g. 0 for the
+                       output-channel axis of a convolution weight. The
+                       criterion is reduced over all the other axes.
+        Returns:
+            list: The indexes to be pruned on the axis.
+        """
+        prune_num = int(round(param.shape[axis] * ratio))
+        reduce_dims = [i for i in range(len(param.shape)) if i != axis]
+        if self.criterion == 'l1_norm':
+            criterions = np.sum(np.abs(param), axis=tuple(reduce_dims))
+        pruned_idx = criterions.argsort()[:prune_num]
+        return pruned_idx
+
+    def _prune_tensor(self, tensor, pruned_idx, pruned_axis, lazy=False):
+        """
+        Prune an array by indexes on the given axis.
+        Args:
+            tensor(numpy.array): The target array to be pruned.
+            pruned_idx(list): The indexes to be pruned.
+            pruned_axis(int): The axis of the given array to be pruned on.
+            lazy(bool): True means setting the pruned elements to zero.
+                        False means removing the pruned elements from memory.
+                        default: False.
+        Returns:
+            numpy.array: The pruned array.
+        """
+        mask = np.zeros(tensor.shape[pruned_axis], dtype=bool)
+        mask[pruned_idx] = True
+
+        def func(data):
+            return data[~mask]
+
+        def lazy_func(data):
+            data[mask] = 0
+            return data
+
+        if lazy:
+            return np.apply_along_axis(lazy_func, pruned_axis, tensor)
+        else:
+            return np.apply_along_axis(func, pruned_axis, tensor)
diff --git a/paddleslim/quant/__init__.py b/paddleslim/quant/__init__.py
index 9d0531501ca43921438ee5b2fb58ac0ad2396d1b..5f5f9a300630abac32a9c0301328e344da082c55 100644
--- a/paddleslim/quant/__init__.py
+++ b/paddleslim/quant/__init__.py
@@ -11,3 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from .quanter import quant_aware, quant_post, convert
+from .quant_embedding import quant_embedding
diff --git a/paddleslim/quant/quant_embedding.py b/paddleslim/quant/quant_embedding.py
new file mode 100755
index 0000000000000000000000000000000000000000..46a81db65c55f91fdf5525bf0da25414598a0b71
--- /dev/null
+++ b/paddleslim/quant/quant_embedding.py
@@ -0,0 +1,259 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import copy +import numpy as np + +import paddle.fluid as fluid +from paddle.fluid.framework import IrGraph +from paddle.fluid import core + +#_logger = logging.basicConfig(level=logging.DEBUG) + +__all__ = ['quant_embedding'] + +default_config = { + "quantize_type": "abs_max", + "quantize_bits": 8, + "dtype": "int8" +} + +support_quantize_types = ['abs_max'] +support_quantize_bits = [8] +support_dtype = ['int8'] + + +def _merge_config(old_config, new_config): + """ + merge default config and user defined config + + Args: + old_config(dict): the copy of default_config + new_config(dict): the user defined config, 'params_name' must be set. + When 'threshold' is not set, quant embedding without clip . + """ + old_config.update(new_config) + keys = old_config.keys() + assert 'params_name' in keys, "params_name must be set" + + quantize_type = old_config['quantize_type'] + assert isinstance(quantize_type, str), "quantize_type must be \ + str" + + assert quantize_type in support_quantize_types, " \ + quantize_type {} is not supported, now supported quantize type \ + are {}.".format(quantize_type, support_quantize_types) + + quantize_bits = old_config['quantize_bits'] + assert isinstance(quantize_bits, int), "quantize_bits must be int" + assert quantize_bits in support_quantize_bits, " quantize_bits {} \ + is not supported, now supported quantize bits are \ + {}. ".format(quantize_bits, support_quantize_bits) + + dtype = old_config['dtype'] + assert isinstance(dtype, str), "dtype must be str" + assert dtype in support_dtype, " dtype {} is not \ + supported, now supported dtypes are {} \ + ".format(dtype, support_dtype) + if 'threshold' in keys: + assert isinstance(old_config['threshold'], (float, int)), "threshold \ + must be number." + + print("quant_embedding config {}".format(old_config)) + return old_config + + +def _get_var_tensor(scope, var_name): + """ + get tensor array by name. 
+ Args: + scope(fluid.Scope): scope to get var + var_name(str): vatiable name + Return: + np.array + """ + return np.array(scope.find_var(var_name).get_tensor()) + + +def _clip_tensor(tensor_array, threshold): + """ + when 'threshold' is set, clip tensor by 'threshold' and '-threshold' + Args: + tensor_array(np.array): array to clip + config(dict): config dict + """ + tensor_array[tensor_array > threshold] = threshold + tensor_array[tensor_array < -threshold] = -threshold + return tensor_array + + +def _get_scale_var_name(var_name): + """ + get scale var name + """ + return var_name + '.scale' + + +def _get_quant_var_name(var_name): + """ + get quantized var name + """ + return var_name + '.int8' + + +def _get_dequant_var_name(var_name): + """ + get dequantized var name + """ + return var_name + '.dequantize' + + +def _restore_var(name, arr, scope, place): + """ + restore quantized array to quantized var + """ + tensor = scope.find_var(name).get_tensor() + tensor.set(arr, place) + + +def _clear_var(var_name, scope): + """ + free memory of var + """ + tensor = scope.find_var(var_name).get_tensor() + tensor._clear() + + +def _quant_embedding_abs_max(graph, scope, place, config): + """ + quantize embedding using abs_max + + Args: + graph(IrGraph): graph that includes lookup_table op + scope(fluid.Scope): scope + place(fluid.CPUPlace or flud.CUDAPlace): place + config(dict): config to quant + """ + + def _quant_abs_max(tensor_array, config): + """ + quant array using abs_max op + """ + bit_length = config['quantize_bits'] + scale = np.max(np.abs(tensor_array)).astype("float32") + quanted_tensor = np.round(tensor_array / scale * ( + (1 << (bit_length - 1)) - 1)) + return scale, quanted_tensor.astype(config['dtype']) + + def _insert_dequant_abs_max_op(graph, scope, var_node, scale_node, config): + """ + Insert dequantize_abs_max op in graph + """ + assert var_node.is_var(), "{} is not a var".format(var_node.name()) + + dequant_var_node = graph.create_var_node( + name=_get_dequant_var_name(var_node.name()), + var_type=var_node.type(), + shape=var_node.shape(), + var_dtype=core.VarDesc.VarType.FP32) + scope.var(dequant_var_node.name()) + + max_range = (1 << (config['quantize_bits'] - 1)) - 1 + output_ops = var_node.outputs + dequant_op = graph.create_op_node( + op_type='dequantize_abs_max', + attrs={ + 'max_range': float(max_range), + 'op_role': core.op_proto_and_checker_maker.OpRole.Forward + }, + inputs={'X': var_node, + 'Scale': scale_node}, + outputs={'Out': dequant_var_node}) + graph.link_to(var_node, dequant_op) + graph.link_to(scale_node, dequant_op) + graph.link_to(dequant_op, dequant_var_node) + for node in output_ops: + graph.update_input_link(var_node, dequant_var_node, node) + + all_var_nodes = graph.all_var_nodes() + var_name = config['params_name'] + # find embedding var node by 'params_name' + embedding_node = graph._find_node_by_name(all_var_nodes, var_name) + embedding_tensor = _get_var_tensor(scope, var_name) + if 'threshold' in config.keys(): + embedding_tensor = _clip_tensor(embedding_tensor, config['threshold']) + + # get scale and quanted tensor + scale, quanted_tensor = _quant_abs_max(embedding_tensor, config) + + #create params must to use create_persistable_node + scale_var = graph.create_persistable_node( + _get_scale_var_name(var_name), + var_type=embedding_node.type(), + shape=[1], + var_dtype=core.VarDesc.VarType.FP32) + quant_tensor_var = graph.create_persistable_node( + _get_quant_var_name(var_name), + var_type=embedding_node.type(), + shape=embedding_node.shape(), + 
var_dtype=core.VarDesc.VarType.INT8) + # create var in scope + scope.var(_get_quant_var_name(var_name)) + scope.var(_get_scale_var_name(var_name)) + #set var by tensor array or scale + _restore_var(_get_quant_var_name(var_name), quanted_tensor, scope, place) + _restore_var(_get_scale_var_name(var_name), np.array(scale), scope, place) + + # insert dequantize_abs_max op + for op_node in embedding_node.outputs: + if op_node.name() == 'lookup_table': + graph.update_input_link(embedding_node, quant_tensor_var, op_node) + var_node = op_node.outputs[0] + _insert_dequant_abs_max_op(graph, scope, var_node, scale_var, + config) + + # free float embedding params memory + _clear_var(embedding_node.name(), scope) + graph.safe_remove_nodes(embedding_node) + + +def quant_embedding(program, place, config, scope=None): + """ + quant lookup_table op parameters + Args: + program(fluid.Program): infer program + scope(fluid.Scope): the scope to store var, when is None will use fluid.global_scope() + place(fluid.CPUPlace or fluid.CUDAPlace): place + config(dict): config to quant. The keys are 'params_name', 'quantize_type', \ + 'quantize_bits', 'dtype', 'threshold'. \ + 'params_name': parameter name to quant, must be set. + 'quantize_type': quantize type, supported types are ['abs_max']. default is "abs_max". + 'quantize_bits': quantize bits, supported bits are [8]. default is 8. + 'dtype': quantize dtype, supported dtype are ['int8']. default is 'int8'. + 'threshold': threshold to clip tensor before quant. When threshold is not set, \ + tensor will not be clipped. + """ + assert isinstance(config, dict), "config must be dict" + config = _merge_config(copy.deepcopy(default_config), config) + scope = fluid.global_scope() if scope is None else scope + + graph = IrGraph(core.Graph(program.desc), for_test=True) + if config['quantize_type'] == 'abs_max': + _quant_embedding_abs_max(graph, scope, place, config) + + return graph.to_program() diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py new file mode 100755 index 0000000000000000000000000000000000000000..8ea9fbe32ee3f8617d9f00a1ce097b715957163e --- /dev/null +++ b/paddleslim/quant/quanter.py @@ -0,0 +1,238 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
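+# quanter.py glues PaddleSlim's dict-based quantization configs onto the
+# IrGraph passes shipped with Paddle (QuantizationTransformPass and friends);
+# _quant_config_default below lists every supported key and its default.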
+
+import copy
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.framework import IrGraph
+from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
+from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
+from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
+from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
+from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
+from paddle.fluid import core
+
+WEIGHT_QUANTIZATION_TYPES = [
+    'abs_max', 'channel_wise_abs_max', 'range_abs_max',
+    'moving_average_abs_max'
+]
+ACTIVATION_QUANTIZATION_TYPES = [
+    'abs_max', 'range_abs_max', 'moving_average_abs_max'
+]
+VALID_DTYPES = ['int8']
+TRANSFORM_PASS_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul']
+QUANT_DEQUANT_PASS_OP_TYPES = ['elementwise_add', 'pool2d']
+
+_quant_config_default = {
+    # weight quantize type, default is 'abs_max'
+    'weight_quantize_type': 'abs_max',
+    # activation quantize type, default is 'abs_max'
+    'activation_quantize_type': 'abs_max',
+    # weight quantize bit num, default is 8
+    'weight_bits': 8,
+    # activation quantize bit num, default is 8
+    'activation_bits': 8,
+    # ops whose name_scope matches an entry in not_quant_pattern will not be quantized
+    'not_quant_pattern': ['skip_quant'],
+    # only ops whose type is in quantize_op_types will be quantized
+    'quantize_op_types':
+    ['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'],
+    # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
+    'dtype': 'int8',
+    # window size for 'range_abs_max' quantization. default is 10000
+    'window_size': 10000,
+    # the decay coefficient of moving average, default is 0.9
+    'moving_rate': 0.9,
+    # if quant_weight_only is True, only the parameters of the layers to be
+    # quantized are quantized, and activations are left untouched.
+    'quant_weight_only': False
+}
+
+
+def _parse_configs(user_config):
+    """
+    Check that the user config is valid and fill in defaults for unset keys.
+    Args:
+        user_config(dict): the config of the user.
+    Return:
+        configs(dict): the final config that will be used.
+    """
+
+    configs = copy.deepcopy(_quant_config_default)
+    configs.update(user_config)
+
+    # check that the config is valid
+    assert configs['weight_quantize_type'] in WEIGHT_QUANTIZATION_TYPES, \
+        "Unknown weight_quantize_type: '{}'. It can only be one of {}.".format(
+            configs['weight_quantize_type'], WEIGHT_QUANTIZATION_TYPES)
+
+    assert configs['activation_quantize_type'] in ACTIVATION_QUANTIZATION_TYPES, \
+        "Unknown activation_quantize_type: '{}'. It can only be one of {}.".format(
+            configs['activation_quantize_type'], ACTIVATION_QUANTIZATION_TYPES)
+
+    assert isinstance(configs['weight_bits'], int), \
+        "weight_bits must be int value."
+
+    assert (configs['weight_bits'] >= 1 and configs['weight_bits'] <= 16), \
+        "weight_bits should be between 1 and 16."
+
+    assert isinstance(configs['activation_bits'], int), \
+        "activation_bits must be int value."
+
+    assert (configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16), \
+        "activation_bits should be between 1 and 16."
+
+    assert isinstance(configs['not_quant_pattern'], list), \
+        "not_quant_pattern must be a list"
+
+    assert isinstance(configs['quantize_op_types'], list), \
+        "quantize_op_types must be a list"
+
+    for op_type in configs['quantize_op_types']:
+        assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or (
+            op_type in TRANSFORM_PASS_OP_TYPES), "{} is not supported; " \
+            "the supported op types are {}".format(
+                op_type, TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES)
+
+    assert isinstance(configs['dtype'], str), \
+        "dtype must be a str."
+
+    assert (configs['dtype'] in VALID_DTYPES), \
+        "dtype can only be " + " ".join(VALID_DTYPES)
+
+    assert isinstance(configs['window_size'], int), \
+        "window_size must be an int value; it is the window size for " \
+        "'range_abs_max' quantization, default is 10000."
+
+    assert isinstance(configs['moving_rate'], float), \
+        "moving_rate must be a float value; it is the decay coefficient of " \
+        "the moving average, default is 0.9."
+
+    assert isinstance(configs['quant_weight_only'], bool), \
+        "quant_weight_only must be a bool value; if quant_weight_only is True, " \
+        "only the parameters of the layers to be quantized are quantized, " \
+        "and activations are left untouched."
+
+    return configs
+
+
+def quant_aware(program, place, config, scope=None, for_test=False):
+    """
+    Insert fake-quantization ops into the program for quantization-aware training.
+    Args:
+        program(fluid.Program): program to quantize
+        place(fluid.CPUPlace or fluid.CUDAPlace): place
+        config(dict): configs for quantization; default values are in the
+            _quant_config_default dict.
+        scope(fluid.Scope): the scope that stores the vars of the program,
+            usually fluid.global_scope().
+        for_test: set True if `program` is a test program, else False.
+    Return:
+        fluid.Program (or fluid.CompiledProgram when for_test is False): the
+            quantized program, which the user can fine-tune to recover accuracy.
+    """
+
+    scope = fluid.global_scope() if not scope else scope
+    assert isinstance(config, dict), "config must be dict"
+
+    assert 'weight_quantize_type' in config, \
+        'weight_quantize_type must be configured'
+    assert 'activation_quantize_type' in config, \
+        'activation_quantize_type must be configured'
+
+    config = _parse_configs(config)
+    main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
+
+    transform_pass_ops = []
+    quant_dequant_ops = []
+    for op_type in config['quantize_op_types']:
+        if op_type in TRANSFORM_PASS_OP_TYPES:
+            transform_pass_ops.append(op_type)
+        elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
+            quant_dequant_ops.append(op_type)
+    if len(transform_pass_ops) > 0:
+        transform_pass = QuantizationTransformPass(
+            scope=scope,
+            place=place,
+            weight_bits=config['weight_bits'],
+            activation_bits=config['activation_bits'],
+            activation_quantize_type=config['activation_quantize_type'],
+            weight_quantize_type=config['weight_quantize_type'],
+            window_size=config['window_size'],
+            moving_rate=config['moving_rate'],
+            quantizable_op_type=transform_pass_ops,
+            skip_pattern=config['not_quant_pattern'])
+
+        transform_pass.apply(main_graph)
+
+    if len(quant_dequant_ops) > 0:
+        quant_dequant_pass = AddQuantDequantPass(
+            scope=scope,
+            place=place,
+            moving_rate=config['moving_rate'],
+            quant_bits=config['activation_bits'],
+            skip_pattern=config['not_quant_pattern'],
+            quantizable_op_type=quant_dequant_ops)
+        quant_dequant_pass.apply(main_graph)
+
+    if for_test:
+        quant_program = main_graph.to_program()
+    else:
+        quant_program = fluid.CompiledProgram(main_graph.graph)
+    return quant_program
+
+
+def quant_post(program, place, config, scope=None):
+    """
+    Add quantization ops to the program; the returned program is not trainable.
+    Args:
+        program(fluid.Program): program to quantize
+        place(fluid.CPUPlace or fluid.CUDAPlace): place
+        config(dict): configs for quantization; default values are in the
+            _quant_config_default dict.
+        scope(fluid.Scope): the scope that stores the vars of the program,
+            usually fluid.global_scope().
+    Return:
+        fluid.Program: the quantized program, not trainable.
+    """
+    # post-training quantization is not implemented yet
+    pass
+
+
+def convert(program, place, config, scope=None, save_int8=False):
+    """
+    Convert a quantization-aware training program into an inference program:
+    the order of the quantize ops is adjusted and the weights are frozen.
+    Args:
+        program(fluid.Program): the quantization-aware training program
+        place(fluid.CPUPlace or fluid.CUDAPlace): place
+        config(dict): configs for quantization; default values are in the
+            _quant_config_default dict.
+        scope(fluid.Scope): the scope to store the vars; if None,
+            fluid.global_scope() will be used.
+        save_int8(bool): whether to additionally return a program whose
+            weights are converted to int8.
+    Return:
+        fluid.Program: the frozen program, which can be used for inference.
+            Its parameters are float32 but their values fall in the int8 range.
+        fluid.Program: the frozen int8 program, which can be used for
+            inference. If save_int8 is False, this value is None.
+    """
+    scope = fluid.global_scope() if not scope else scope
+    test_graph = IrGraph(core.Graph(program.desc), for_test=True)
+
+    # Freeze the graph after training by adjusting the order of the quantize
+    # operators for inference.
+    freeze_pass = QuantizationFreezePass(
+        scope=scope,
+        place=place,
+        weight_quantize_type=config['weight_quantize_type'])
+    freeze_pass.apply(test_graph)
+    freezed_program = test_graph.to_program()
+
+    if save_int8:
+        convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
+        convert_int8_pass.apply(test_graph)
+        freezed_program_int8 = test_graph.to_program()
+        return freezed_program, freezed_program_int8
+    else:
+        return freezed_program
diff --git a/setup.py b/setup.py
index d79620c5791c3a0144ca3aaa9f1d5d7b979dff31..5ff0a92fdd48668c9447d8625f122d93a168444c 100644
--- a/setup.py
+++ b/setup.py
@@ -33,8 +33,14 @@ with open('./requirements.txt') as f:
     setup_requires = f.read().splitlines()
 
 packages = [
-    'paddleslim', 'paddleslim.prune', 'paddleslim.dist', 'paddleslim.nas',
-    'paddleslim.analysis', 'paddleslim.quant'
+    'paddleslim',
+    'paddleslim.prune',
+    'paddleslim.dist',
+    'paddleslim.nas',
+    'paddleslim.analysis',
+    'paddleslim.quant',
+    'paddleslim.core',
+    'paddleslim.common',
 ]
 
 setup(
diff --git a/tests/layers.py b/tests/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..140ff5919b9d8c9821b371db5ca4896db28bf7f0
--- /dev/null
+++ b/tests/layers.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+
+
+def conv_bn_layer(input,
+                  num_filters,
+                  filter_size,
+                  name,
+                  stride=1,
+                  groups=1,
+                  act=None):
+    conv = fluid.layers.conv2d(
+        input=input,
+        num_filters=num_filters,
+        filter_size=filter_size,
+        stride=stride,
+        padding=(filter_size - 1) // 2,
+        groups=groups,
+        act=None,
+        param_attr=ParamAttr(name=name + "_weights"),
+        bias_attr=False,
+        name=name + "_out")
+    bn_name = name + "_bn"
+    return fluid.layers.batch_norm(
+        input=conv,
+        act=act,
+        name=bn_name + '_output',
+        param_attr=ParamAttr(name=bn_name + '_scale'),
+        bias_attr=ParamAttr(name=bn_name + '_offset'),
+        moving_mean_name=bn_name + '_mean',
+        moving_variance_name=bn_name + '_variance')
diff --git a/tests/test_auto_prune.py b/tests/test_auto_prune.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9cdc72c33ce683f2dc3ecbfdf406740ef6e69a8
--- /dev/null
+++ b/tests/test_auto_prune.py
@@ -0,0 +1,84 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
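The quant_aware/convert pair above defines the intended quantization-aware training workflow: insert fake-quant ops, fine-tune, then freeze for inference. A minimal sketch (editorial, not part of the patch; train_prog and val_prog are placeholder programs, and the import path assumes both functions are exported from paddleslim.quant):

    import paddle.fluid as fluid
    from paddleslim.quant import quant_aware, convert

    place = fluid.CUDAPlace(0)
    config = {
        'weight_quantize_type': 'abs_max',
        'activation_quantize_type': 'moving_average_abs_max',
    }

    # insert fake-quant ops into the train and eval programs
    quant_train_prog = quant_aware(train_prog, place, config, for_test=False)
    quant_eval_prog = quant_aware(val_prog, place, config, for_test=True)
    # ... run quantization-aware fine-tuning on quant_train_prog here ...
    # then freeze the eval program for inference
    infer_prog = convert(quant_eval_prog, place, config)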
+import sys
+sys.path.append("../")
+import unittest
+import paddle.fluid as fluid
+from paddleslim.prune import AutoPruner
+from paddleslim.analysis import flops
+from layers import conv_bn_layer
+
+
+class TestPrune(unittest.TestCase):
+    def test_prune(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        #   X       X              O       X              O
+        # conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
+        #     |            ^ |                    ^
+        #     |____________| |____________________|
+        #
+        # X: prune output channels
+        # O: prune input channels
+        with fluid.program_guard(main_program, startup_program):
+            input = fluid.data(name="image", shape=[None, 3, 16, 16])
+            conv1 = conv_bn_layer(input, 8, 3, "conv1")
+            conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+            sum1 = conv1 + conv2
+            conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
+            conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
+            sum2 = conv4 + sum1
+            conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
+            conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
+
+        shapes = {}
+        for param in main_program.global_block().all_parameters():
+            shapes[param.name] = param.shape
+
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        scope = fluid.Scope()
+        exe.run(startup_program, scope=scope)
+
+        pruned_flops = 0.5
+        pruner = AutoPruner(
+            main_program,
+            scope,
+            place,
+            params=["conv4_weights"],
+            init_ratios=[0.5],
+            pruned_flops=pruned_flops,
+            pruned_latency=None,
+            server_addr=("", 0),
+            init_temperature=100,
+            reduce_rate=0.85,
+            max_try_number=300,
+            max_client_num=10,
+            search_steps=2,
+            max_ratios=[0.9],
+            min_ratios=[0],
+            key="auto_pruner")
+
+        base_flops = flops(main_program)
+        program = pruner.prune(main_program)
+        self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
+        pruner.reward(1)
+
+        program = pruner.prune(main_program)
+        self.assertTrue(flops(program) <= base_flops * (1 - pruned_flops))
+        pruner.reward(1)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_flops.py b/tests/test_flops.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd16b8618d0271e6a0b7e609f8820e16c380b9db
--- /dev/null
+++ b/tests/test_flops.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
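The prune/reward calls above are the whole client protocol of AutoPruner: each prune() applies ratios proposed by the simulated-annealing controller, and reward() feeds the evaluation result back before the next proposal. Schematically (editorial sketch; eval_fn is a made-up stand-in for a real accuracy measurement):

    # hypothetical search loop; eval_fn(program) -> float is user-supplied
    for step in range(search_steps):
        pruned_program = pruner.prune(main_program)  # apply the next proposed ratios
        score = eval_fn(pruned_program)              # measure quality of the proposal
        pruner.reward(score)                         # feed back to the SA controller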
+import sys
+sys.path.append("../")
+import unittest
+import paddle.fluid as fluid
+from paddleslim.analysis import flops
+from layers import conv_bn_layer
+
+
+class TestFlops(unittest.TestCase):
+    def test_flops(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            input = fluid.data(name="image", shape=[None, 3, 16, 16])
+            conv1 = conv_bn_layer(input, 8, 3, "conv1")
+            conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+            sum1 = conv1 + conv2
+            conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
+            conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
+            sum2 = conv4 + sum1
+            conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
+            conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
+        self.assertEqual(1597440, flops(main_program))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_model_size.py b/tests/test_model_size.py
new file mode 100644
index 0000000000000000000000000000000000000000..314450eb449507a971aa3827a8d841e88e10a69e
--- /dev/null
+++ b/tests/test_model_size.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+sys.path.append("../")
+import unittest
+import paddle.fluid as fluid
+from paddleslim.analysis import model_size
+from layers import conv_bn_layer
+
+
+class TestModelSize(unittest.TestCase):
+    def test_model_size(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            input = fluid.data(name="image", shape=[None, 3, 16, 16])
+            conv1 = conv_bn_layer(input, 8, 3, "conv1")
+            conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+            sum1 = conv1 + conv2
+            conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
+            conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
+            sum2 = conv4 + sum1
+            conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
+            conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
+        self.assertEqual(3288, model_size(main_program))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_nas_search_space.py b/tests/test_nas_search_space.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad373cf146fecb1cf9ea2b3681eaf73e9e65dd3d
--- /dev/null
+++ b/tests/test_nas_search_space.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
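The two expected constants can be sanity-checked by hand: each conv_bn_layer contributes its conv kernel plus four per-channel batch-norm arrays (scale, offset, moving mean, moving variance), all with 8 channels here. A quick check of the model_size figure (editorial illustration, not part of the patch):

    conv_weights = 8 * 3 * 3 * 3 + 5 * (8 * 8 * 3 * 3)  # conv1 + conv2..conv6 = 3096
    bn_params = 6 * 4 * 8                                # 6 BN layers x 4 arrays x 8 channels = 192
    assert conv_weights + bn_params == 3288

The 1597440 in test_flops is likewise consistent with two FLOPs per multiply-accumulate on the 16x16 feature maps, 2 * (27 + 5 * 72) * (8 * 16 * 16) = 1585152, plus one FLOP per batch-norm output element, 6 * 2048 = 12288; the authoritative accounting lives in paddleslim.analysis.flops.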
+
+import sys
+sys.path.append('..')
+import unittest
+import paddle.fluid as fluid
+from paddleslim.nas import SearchSpaceFactory
+
+
+class TestSearchSpace(unittest.TestCase):
+    def test_searchspace(self):
+        # if output_size is 1, the model appends an fc layer at the end.
+        config = {'input_size': 224, 'output_size': 7, 'block_num': 5}
+        space = SearchSpaceFactory()
+
+        my_space = space.get_search_space([('MobileNetV2Space', config)])
+        model_arch = my_space.token2arch()
+
+        train_prog = fluid.Program()
+        startup_prog = fluid.Program()
+        with fluid.program_guard(train_prog, startup_prog):
+            input_size = config['input_size']
+            model_input = fluid.layers.data(
+                name='model_in',
+                shape=[1, 3, input_size, input_size],
+                dtype='float32',
+                append_batch_size=False)
+            predict = model_arch[0](model_input)
+            self.assertTrue(predict.shape[2] == config['output_size'])
+
+
+class TestMultiSearchSpace(unittest.TestCase):
+    def test_multi_searchspace(self):
+        space = SearchSpaceFactory()
+
+        config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5}
+        config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2}
+        my_space = space.get_search_space(
+            [('MobileNetV2Space', config0), ('ResNetSpace', config1)])
+        model_archs = my_space.token2arch()
+
+        train_prog = fluid.Program()
+        startup_prog = fluid.Program()
+        with fluid.program_guard(train_prog, startup_prog):
+            input_size = config0['input_size']
+            model_input = fluid.layers.data(
+                name='model_in',
+                shape=[1, 3, input_size, input_size],
+                dtype='float32',
+                append_batch_size=False)
+            # config1's input_size (7) equals config0's output_size, so the
+            # generated archs chain: each block's output feeds the next one.
+            for model_arch in model_archs:
+                predict = model_arch(model_input)
+                model_input = predict
+            print(predict)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_prune.py b/tests/test_prune.py
new file mode 100644
index 0000000000000000000000000000000000000000..93609367351618ce375f164a1dca284e85369e4c
--- /dev/null
+++ b/tests/test_prune.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+sys.path.append("../")
+import unittest
+import paddle.fluid as fluid
+from paddleslim.prune import Pruner
+from layers import conv_bn_layer
+
+
+class TestPrune(unittest.TestCase):
+    def test_prune(self):
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        #   X       X              O       X              O
+        # conv1-->conv2-->sum1-->conv3-->conv4-->sum2-->conv5-->conv6
+        #     |            ^ |                    ^
+        #     |____________| |____________________|
+        #
+        # X: prune output channels
+        # O: prune input channels
+        with fluid.program_guard(main_program, startup_program):
+            input = fluid.data(name="image", shape=[None, 3, 16, 16])
+            conv1 = conv_bn_layer(input, 8, 3, "conv1")
+            conv2 = conv_bn_layer(conv1, 8, 3, "conv2")
+            sum1 = conv1 + conv2
+            conv3 = conv_bn_layer(sum1, 8, 3, "conv3")
+            conv4 = conv_bn_layer(conv3, 8, 3, "conv4")
+            sum2 = conv4 + sum1
+            conv5 = conv_bn_layer(sum2, 8, 3, "conv5")
+            conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
+
+        shapes = {}
+        for param in main_program.global_block().all_parameters():
+            shapes[param.name] = param.shape
+
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        scope = fluid.Scope()
+        exe.run(startup_program, scope=scope)
+        pruner = Pruner()
+        main_program = pruner.prune(
+            main_program,
+            scope,
+            params=["conv4_weights"],
+            ratios=[0.5],
+            place=place,
+            lazy=False,
+            only_graph=False,
+            param_backup=None,
+            param_shape_backup=None)
+
+        # Pruning conv4 to 4 output channels propagates through the sums:
+        # sum2 forces sum1 (hence the conv1 and conv2 outputs) down to 4
+        # channels, so conv3 and conv5 now consume 4 input channels.
+        shapes = {
+            "conv1_weights": (4, 3, 3, 3),
+            "conv2_weights": (4, 4, 3, 3),
+            "conv3_weights": (8, 4, 3, 3),
+            "conv4_weights": (4, 8, 3, 3),
+            "conv5_weights": (8, 4, 3, 3),
+            "conv6_weights": (8, 8, 3, 3)
+        }
+
+        for param in main_program.global_block().all_parameters():
+            if "weights" in param.name:
+                self.assertTrue(param.shape == shapes[param.name])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_sa_nas.py b/tests/test_sa_nas.py
new file mode 100644
index 0000000000000000000000000000000000000000..5666e1410a820c09bc10fa0b10d282434c7837fe
--- /dev/null
+++ b/tests/test_sa_nas.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+sys.path.append("../")
+import unittest
+import paddle.fluid as fluid
+from paddleslim.nas import SANAS
+from paddleslim.nas import SearchSpaceFactory
+from paddleslim.analysis import flops
+
+
+class TestSANAS(unittest.TestCase):
+    def test_nas(self):
+
+        factory = SearchSpaceFactory()
+        config0 = {'input_size': 224, 'output_size': 7, 'block_num': 5}
+        config1 = {'input_size': 7, 'output_size': 1, 'block_num': 2}
+        configs = [('MobileNetV2Space', config0), ('ResNetSpace', config1)]
+
+        # baseline FLOPs from the un-searched MobileNetV2Space arch
+        space = factory.get_search_space([('MobileNetV2Space', config0)])
+        origin_arch = space.token2arch()[0]
+
+        main_program = fluid.Program()
+        s_program = fluid.Program()
+        with fluid.program_guard(main_program, s_program):
+            input = fluid.data(
+                name="input", shape=[None, 3, 224, 224], dtype="float32")
+            origin_arch(input)
+        base_flops = flops(main_program)
+
+        search_steps = 3
+        sa_nas = SANAS(configs, search_steps=search_steps, is_server=True)
+
+        for i in range(search_steps):
+            archs = sa_nas.next_archs()
+            main_program = fluid.Program()
+            s_program = fluid.Program()
+            with fluid.program_guard(main_program, s_program):
+                input = fluid.data(
+                    name="input", shape=[None, 3, 224, 224], dtype="float32")
+                archs[0](input)
+            sa_nas.reward(1)
+            self.assertTrue(flops(main_program) < base_flops)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/test_sensitivity.py b/tests/test_sensitivity.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2cfa01d889db2891fd7507b2d4d9aec018a1163
--- /dev/null
+++ b/tests/test_sensitivity.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys +sys.path.append("../") +import unittest +import numpy +import paddle +import paddle.fluid as fluid +from paddleslim.analysis import sensitivity +from layers import conv_bn_layer + + +class TestSensitivity(unittest.TestCase): + def test_sensitivity(self): + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(main_program, startup_program): + input = fluid.data(name="image", shape=[None, 1, 28, 28]) + label = fluid.data(name="label", shape=[None, 1], dtype="int64") + conv1 = conv_bn_layer(input, 8, 3, "conv1") + conv2 = conv_bn_layer(conv1, 8, 3, "conv2") + sum1 = conv1 + conv2 + conv3 = conv_bn_layer(sum1, 8, 3, "conv3") + conv4 = conv_bn_layer(conv3, 8, 3, "conv4") + sum2 = conv4 + sum1 + conv5 = conv_bn_layer(sum2, 8, 3, "conv5") + conv6 = conv_bn_layer(conv5, 8, 3, "conv6") + out = fluid.layers.fc(conv6, size=10, act='softmax') + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + eval_program = main_program.clone(for_test=True) + + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + exe.run(startup_program) + + val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + + def eval_func(program, scope): + feeder = fluid.DataFeeder( + feed_list=['image', 'label'], place=place, program=program) + acc_set = [] + for data in val_reader(): + acc_np = exe.run(program=program, + scope=scope, + feed=feeder.feed(data), + fetch_list=[acc_top1]) + acc_set.append(float(acc_np[0])) + acc_val_mean = numpy.array(acc_set).mean() + print("acc_val_mean: {}".format(acc_val_mean)) + return acc_val_mean + + sensitivity(eval_program, + fluid.global_scope(), place, ["conv4_weights"], eval_func, + "./sensitivities_file") + + +if __name__ == '__main__': + unittest.main()
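A closing note on the eval_func contract that sensitivity relies on above: it receives (program, scope) and returns a single scalar metric, which the analysis presumably records in the sensitivities file. Widening the analysis to more layers should only require a longer parameter list (editorial sketch; the extra weight names are illustrative, not part of the patch):

    # hypothetical: analyze several conv layers with the same eval_func
    sensitivity(eval_program,
                fluid.global_scope(), place,
                ["conv1_weights", "conv4_weights", "conv6_weights"],
                eval_func, "./sensitivities_file")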