From e18a3a459776ccd28a7e3aa37f621041b67df0bd Mon Sep 17 00:00:00 2001
From: whs
Date: Wed, 27 Mar 2019 14:28:10 +0800
Subject: [PATCH] Add demos for PaddleSlim on classification task. (#1923)

* Add demos for PaddleSlim on classification task.

* Refine structure of directory.
---
 fluid/PaddleSlim/compress.py                     | 154 ++++++++++++++
 .../configs/filter_pruning_sen.yaml              |  25 +++
 .../configs/filter_pruning_uniform.yaml          |  21 ++
 .../mobilenetv1_resnet50_distillation.yaml       |  23 ++
 fluid/PaddleSlim/configs/quantization.yaml       |  20 ++
 .../data/pretrain/download_pretrain.sh           |   9 +
 fluid/PaddleSlim/models/mobilenet.py             | 197 ++++++++++++++++++
 fluid/PaddleSlim/models/resnet.py                | 162 ++++++++++++++
 fluid/PaddleSlim/reader.py                       | 188 +++++++++++++++++
 fluid/PaddleSlim/run.sh                          |  36 ++++
 fluid/PaddleSlim/utility.py                      |  63 ++++++
 11 files changed, 898 insertions(+)
 create mode 100644 fluid/PaddleSlim/compress.py
 create mode 100644 fluid/PaddleSlim/configs/filter_pruning_sen.yaml
 create mode 100644 fluid/PaddleSlim/configs/filter_pruning_uniform.yaml
 create mode 100644 fluid/PaddleSlim/configs/mobilenetv1_resnet50_distillation.yaml
 create mode 100644 fluid/PaddleSlim/configs/quantization.yaml
 create mode 100644 fluid/PaddleSlim/data/pretrain/download_pretrain.sh
 create mode 100644 fluid/PaddleSlim/models/mobilenet.py
 create mode 100644 fluid/PaddleSlim/models/resnet.py
 create mode 100644 fluid/PaddleSlim/reader.py
 create mode 100644 fluid/PaddleSlim/run.sh
 create mode 100644 fluid/PaddleSlim/utility.py

diff --git a/fluid/PaddleSlim/compress.py b/fluid/PaddleSlim/compress.py
new file mode 100644
index 00000000..14b301fc
--- /dev/null
+++ b/fluid/PaddleSlim/compress.py
@@ -0,0 +1,154 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+import time
+import sys
+import logging
+import paddle
+import models
+import argparse
+import functools
+import paddle.fluid as fluid
+import reader
+from utility import add_arguments, print_arguments
+
+from paddle.fluid.contrib.slim import Compressor
+
+logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
+_logger = logging.getLogger(__name__)
+_logger.setLevel(logging.INFO)
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('batch_size',       int,  64*4,        "Minibatch size.")
+add_arg('use_gpu',          bool, True,        "Whether to use GPU or not.")
+add_arg('class_dim',        int,  1000,        "Class number.")
+add_arg('image_shape',      str,  "3,224,224", "Input image size")
+add_arg('model',            str,  "MobileNet", "Set the network to use.")
+add_arg('pretrained_model', str,  None,        "Whether to use pretrained model.")
+add_arg('teacher_model',    str,  None,        "Set the teacher network to use.")
+add_arg('teacher_pretrained_model', str, None, "Whether to use pretrained model.")
+add_arg('compress_config',  str,  None,        "The config file for compression with yaml format.")
+# yapf: enable

+model_list = [m for m in dir(models) if "__" not in m]
+
+
+def compress(args):
+    image_shape = [int(m) for m in args.image_shape.split(",")]
+
+    assert args.model in model_list, "{} is not in lists: {}".format(args.model,
+                                                                     model_list)
+    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+    # model definition
+    model = models.__dict__[args.model]()
+
+    if args.model == "GoogleNet":
+        out0, out1, out2 = model.net(input=image, class_dim=args.class_dim)
+        cost0 = fluid.layers.cross_entropy(input=out0, label=label)
+        cost1 = fluid.layers.cross_entropy(input=out1, label=label)
+        cost2 = fluid.layers.cross_entropy(input=out2, label=label)
+        avg_cost0 = fluid.layers.mean(x=cost0)
+        avg_cost1 = fluid.layers.mean(x=cost1)
+        avg_cost2 = fluid.layers.mean(x=cost2)
+        avg_cost = avg_cost0 + 0.3 * avg_cost1 + 0.3 * avg_cost2
+        acc_top1 = fluid.layers.accuracy(input=out0, label=label, k=1)
+        acc_top5 = fluid.layers.accuracy(input=out0, label=label, k=5)
+    else:
+        out = model.net(input=image, class_dim=args.class_dim)
+        cost = fluid.layers.cross_entropy(input=out, label=label)
+        avg_cost = fluid.layers.mean(x=cost)
+        acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
+        acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
+    val_program = fluid.default_main_program().clone()
+
+    opt = fluid.optimizer.Momentum(
+        momentum=0.9,
+        learning_rate=fluid.layers.piecewise_decay(
+            boundaries=[5000 * 30, 5000 * 60, 5000 * 90],
+            values=[0.01, 0.001, 0.0001, 0.00001]),
+        regularization=fluid.regularizer.L2Decay(4e-5))
+
+    # opt = fluid.optimizer.Momentum(
+    #     momentum=0.9,
+    #     learning_rate=0.01,
+    #     regularization=fluid.regularizer.L2Decay(4e-5))
+
+    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    exe.run(fluid.default_startup_program())
+
+    if args.pretrained_model:
+
+        def if_exist(var):
+            return os.path.exists(os.path.join(args.pretrained_model, var.name))
+
+        fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
+
+    val_reader = paddle.batch(reader.val(), batch_size=args.batch_size)
+    val_feed_list = [('image', image.name), ('label', label.name)]
+    val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5', acc_top5.name)]
+
+    train_reader = paddle.batch(
+        reader.train(), batch_size=args.batch_size, drop_last=True)
+    train_feed_list = [('image', image.name), ('label', label.name)]
+    train_fetch_list = [('loss', avg_cost.name)]
+
+    teacher_programs = []
+    distiller_optimizer = None
+    if args.teacher_model:
+        teacher_model = models.__dict__[args.teacher_model]()
+        # define teacher program
+        teacher_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(teacher_program, startup_program):
+            img = teacher_program.global_block()._clone_variable(
+                image, force_persistable=False)
+            predict = teacher_model.net(img, class_dim=args.class_dim)
+        exe.run(startup_program)
+        assert args.teacher_pretrained_model and os.path.exists(
+            args.teacher_pretrained_model
+        ), "teacher_pretrained_model should be set when teacher_model is not None."
+
+        def if_exist(var):
+            return os.path.exists(
+                os.path.join(args.teacher_pretrained_model, var.name))
+
+        fluid.io.load_vars(
+            exe,
+            args.teacher_pretrained_model,
+            main_program=teacher_program,
+            predicate=if_exist)
+
+        distiller_optimizer = opt
+        teacher_programs.append(teacher_program.clone(for_test=True))
+
+    com_pass = Compressor(
+        place,
+        fluid.global_scope(),
+        fluid.default_main_program(),
+        train_reader=train_reader,
+        train_feed_list=train_feed_list,
+        train_fetch_list=train_fetch_list,
+        eval_program=val_program,
+        eval_reader=val_reader,
+        eval_feed_list=val_feed_list,
+        eval_fetch_list=val_fetch_list,
+        teacher_programs=teacher_programs,
+        train_optimizer=opt,
+        distiller_optimizer=distiller_optimizer)
+    com_pass.config(args.compress_config)
+    com_pass.run()
+
+
+def main():
+    args = parser.parse_args()
+    print_arguments(args)
+    compress(args)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/fluid/PaddleSlim/configs/filter_pruning_sen.yaml b/fluid/PaddleSlim/configs/filter_pruning_sen.yaml
new file mode 100644
index 00000000..4537405f
--- /dev/null
+++ b/fluid/PaddleSlim/configs/filter_pruning_sen.yaml
@@ -0,0 +1,25 @@
+version: 1.0
+pruners:
+    pruner_1:
+        class: 'StructurePruner'
+        pruning_axis:
+            '*': 0
+        criterions:
+            '*': 'l1_norm'
+strategies:
+    sensitive_pruning_strategy:
+        class: 'SensitivePruneStrategy'
+        pruner: 'pruner_1'
+        start_epoch: 0
+        delta_rate: 0.1
+        target_ratio: 0.5
+        num_steps: 1
+#        eval_rate: 0.2
+        pruned_params: '.*_sep_weights'
+        sensitivities_file: 'mobilenet_acc_top1_sensitive.data'
+        metric_name: 'acc_top1'
+compressor:
+    epoch: 200
+    checkpoint_path: './checkpoints/'
+    strategies:
+        - sensitive_pruning_strategy
diff --git a/fluid/PaddleSlim/configs/filter_pruning_uniform.yaml b/fluid/PaddleSlim/configs/filter_pruning_uniform.yaml
new file mode 100644
index 00000000..1dea1070
--- /dev/null
+++ b/fluid/PaddleSlim/configs/filter_pruning_uniform.yaml
@@ -0,0 +1,21 @@
+version: 1.0
+pruners:
+    pruner_1:
+        class: 'StructurePruner'
+        pruning_axis:
+            '*': 0
+        criterions:
+            '*': 'l1_norm'
+strategies:
+    uniform_pruning_strategy:
+        class: 'UniformPruneStrategy'
+        pruner: 'pruner_1'
+        start_epoch: 0
+        target_ratio: 0.5
+        pruned_params: '.*_sep_weights'
+        metric_name: 'acc_top1'
+compressor:
+    epoch: 200
+    checkpoint_path: './checkpoints/'
+    strategies:
+        - uniform_pruning_strategy
diff --git a/fluid/PaddleSlim/configs/mobilenetv1_resnet50_distillation.yaml b/fluid/PaddleSlim/configs/mobilenetv1_resnet50_distillation.yaml
new file mode 100644
index 00000000..c825edf5
--- /dev/null
+++ b/fluid/PaddleSlim/configs/mobilenetv1_resnet50_distillation.yaml
@@ -0,0 +1,23 @@
+version: 1.0
+distillers:
+    fsp_distiller:
+        class: 'FSPDistiller'
+        teacher_pairs: [['res2a_branch2a.conv2d.output.1.tmp_0', 'res3a_branch2a.conv2d.output.1.tmp_0']]
+        student_pairs: [['depthwise_conv2d_1.tmp_0', 'conv2d_3.tmp_0']]
+        distillation_loss_weight: 1
+    l2_distiller:
+        class: 'L2Distiller'
+        teacher_feature_map: 'fc_1.tmp_0'
+        student_feature_map: 'fc_0.tmp_0'
+        distillation_loss_weight: 1
+strategies:
+    distillation_strategy:
+        class: 'DistillationStrategy'
+        distillers: ['fsp_distiller', 'l2_distiller']
+        start_epoch: 0
+        end_epoch: 130
+compressor:
+    epoch: 130
+    checkpoint_path: './checkpoints/'
+    strategies:
+        - distillation_strategy
diff --git a/fluid/PaddleSlim/configs/quantization.yaml b/fluid/PaddleSlim/configs/quantization.yaml
new file mode 100644
index 00000000..b8e74cdd
--- /dev/null
+++ b/fluid/PaddleSlim/configs/quantization.yaml
@@ -0,0 +1,20 @@
+version: 1.0
+strategies:
+    quantization_strategy:
+        class: 'QuantizationStrategy'
+        start_epoch: 0
+        end_epoch: 0
+        float_model_save_path: './output/float'
+#        mobile_model_save_path: './output/mobile'
+#        int8_model_save_path: './output/int8'
+        weight_bits: 8
+        activation_bits: 8
+        weight_quantize_type: 'abs_max'
+        activation_quantize_type: 'abs_max'
+        save_in_nodes: ['image']
+        save_out_nodes: ['fc_0.tmp_2']
+compressor:
+    epoch: 6
+    checkpoint_path: './checkpoints_quan/'
+    strategies:
+        - quantization_strategy
diff --git a/fluid/PaddleSlim/data/pretrain/download_pretrain.sh b/fluid/PaddleSlim/data/pretrain/download_pretrain.sh
new file mode 100644
index 00000000..92cb48da
--- /dev/null
+++ b/fluid/PaddleSlim/data/pretrain/download_pretrain.sh
@@ -0,0 +1,9 @@
+root_url="http://paddle-imagenet-models-name.bj.bcebos.com"
+MobileNetV1="MobileNetV1_pretrained.zip"
+ResNet50="ResNet50_pretrained.zip"
+
+wget ${root_url}/${MobileNetV1}
+unzip ${MobileNetV1}
+
+wget ${root_url}/${ResNet50}
+unzip ${ResNet50}
diff --git a/fluid/PaddleSlim/models/mobilenet.py b/fluid/PaddleSlim/models/mobilenet.py
new file mode 100644
index 00000000..5c4b16a5
--- /dev/null
+++ b/fluid/PaddleSlim/models/mobilenet.py
@@ -0,0 +1,197 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import paddle.fluid as fluid
+from paddle.fluid.initializer import MSRA
+from paddle.fluid.param_attr import ParamAttr
+
+__all__ = ['MobileNet']
+
+train_parameters = {
+    "input_size": [3, 224, 224],
+    "input_mean": [0.485, 0.456, 0.406],
+    "input_std": [0.229, 0.224, 0.225],
+    "learning_strategy": {
+        "name": "piecewise_decay",
+        "batch_size": 256,
+        "epochs": [30, 60, 90],
+        "steps": [0.1, 0.01, 0.001, 0.0001]
+    }
+}
+
+
+class MobileNet():
+    def __init__(self):
+        self.params = train_parameters
+
+    def net(self, input, class_dim=1000, scale=1.0):
+        # conv1: 112x112
+        input = self.conv_bn_layer(
+            input,
+            filter_size=3,
+            channels=3,
+            num_filters=int(32 * scale),
+            stride=2,
+            padding=1,
+            name="conv1")
+
+        # 56x56
+        input = self.depthwise_separable(
+            input,
+            num_filters1=32,
+            num_filters2=64,
+            num_groups=32,
+            stride=1,
+            scale=scale,
+            name="conv2_1")
+
+        input = self.depthwise_separable(
+            input,
+            num_filters1=64,
+            num_filters2=128,
+            num_groups=64,
+            stride=2,
+            scale=scale,
+            name="conv2_2")
+
+        # 28x28
+        input = self.depthwise_separable(
+            input,
+            num_filters1=128,
+            num_filters2=128,
+            num_groups=128,
+            stride=1,
+            scale=scale,
+            name="conv3_1")
+
+        input = self.depthwise_separable(
+            input,
+            num_filters1=128,
+            num_filters2=256,
+            num_groups=128,
+            stride=2,
+            scale=scale,
+            name="conv3_2")
+
+        # 14x14
+        input = self.depthwise_separable(
+            input,
+            num_filters1=256,
+            num_filters2=256,
+            num_groups=256,
+            stride=1,
+            scale=scale,
+            name="conv4_1")
+
+        input = self.depthwise_separable(
+            input,
+            num_filters1=256,
+            num_filters2=512,
+            num_groups=256,
+            stride=2,
+            scale=scale,
+            name="conv4_2")
+
+        # 14x14
+        for i in range(5):
+            input = self.depthwise_separable(
+                input,
+                num_filters1=512,
+                num_filters2=512,
+                num_groups=512,
+                stride=1,
+                scale=scale,
+                name="conv5" + "_" + str(i + 1))
+        # 7x7
+        input = self.depthwise_separable(
+            input,
+            num_filters1=512,
+            num_filters2=1024,
+            num_groups=512,
+            stride=2,
+            scale=scale,
+            name="conv5_6")
+
+        input = self.depthwise_separable(
+            input,
+            num_filters1=1024,
+            num_filters2=1024,
+            num_groups=1024,
+            stride=1,
+            scale=scale,
+            name="conv6")
+
+        input = fluid.layers.pool2d(
+            input=input,
+            pool_size=0,
+            pool_stride=1,
+            pool_type='avg',
+            global_pooling=True)
+
+        output = fluid.layers.fc(input=input,
+                                 size=class_dim,
+                                 act='softmax',
+                                 param_attr=ParamAttr(
+                                     initializer=MSRA(), name="fc7_weights"),
+                                 bias_attr=ParamAttr(name="fc7_offset"))
+
+        return output
+
+    def conv_bn_layer(self,
+                      input,
+                      filter_size,
+                      num_filters,
+                      stride,
+                      padding,
+                      channels=None,
+                      num_groups=1,
+                      act='relu',
+                      use_cudnn=True,
+                      name=None):
+        conv = fluid.layers.conv2d(
+            input=input,
+            num_filters=num_filters,
+            filter_size=filter_size,
+            stride=stride,
+            padding=padding,
+            groups=num_groups,
+            act=None,
+            use_cudnn=use_cudnn,
+            param_attr=ParamAttr(
+                initializer=MSRA(), name=name + "_weights"),
+            bias_attr=False)
+        bn_name = name + "_bn"
+        return fluid.layers.batch_norm(
+            input=conv,
+            act=act,
+            param_attr=ParamAttr(name=bn_name + "_scale"),
+            bias_attr=ParamAttr(name=bn_name + "_offset"),
+            moving_mean_name=bn_name + '_mean',
+            moving_variance_name=bn_name + '_variance')
+
+    def depthwise_separable(self,
+                            input,
+                            num_filters1,
+                            num_filters2,
+                            num_groups,
+                            stride,
+                            scale,
+                            name=None):
+        depthwise_conv = self.conv_bn_layer(
+            input=input,
+            filter_size=3,
+            num_filters=int(num_filters1 * scale),
+            stride=stride,
+            padding=1,
+            num_groups=int(num_groups * scale),
+            use_cudnn=False,
+            name=name + "_dw")
+
+        pointwise_conv = self.conv_bn_layer(
+            input=depthwise_conv,
+            filter_size=1,
+            num_filters=int(num_filters2 * scale),
+            stride=1,
+            padding=0,
+            name=name + "_sep")
+        return pointwise_conv
diff --git a/fluid/PaddleSlim/models/resnet.py b/fluid/PaddleSlim/models/resnet.py
new file mode 100644
index 00000000..17232552
--- /dev/null
+++ b/fluid/PaddleSlim/models/resnet.py
@@ -0,0 +1,162 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import paddle
+import paddle.fluid as fluid
+import math
+from paddle.fluid.param_attr import ParamAttr
+
+__all__ = ["ResNet", "ResNet50", "ResNet101", "ResNet152"]
+
+train_parameters = {
+    "input_size": [3, 224, 224],
+    "input_mean": [0.485, 0.456, 0.406],
+    "input_std": [0.229, 0.224, 0.225],
+    "learning_strategy": {
+        "name": "piecewise_decay",
+        "batch_size": 256,
+        "epochs": [30, 60, 90],
+        "steps": [0.1, 0.01, 0.001, 0.0001]
+    }
+}
+
+
+class ResNet():
+    def __init__(self, layers=50):
+        self.params = train_parameters
+        self.layers = layers
+
+    def net(self, input, class_dim=1000):
+        layers = self.layers
+        supported_layers = [50, 101, 152]
+        assert layers in supported_layers, \
+            "supported layers are {} but input layer is {}".format(supported_layers, layers)
+
+        if layers == 50:
+            depth = [3, 4, 6, 3]
+        elif layers == 101:
+            depth = [3, 4, 23, 3]
+        elif layers == 152:
+            depth = [3, 8, 36, 3]
+        num_filters = [64, 128, 256, 512]
+
+        conv = self.conv_bn_layer(
+            input=input,
+            num_filters=64,
+            filter_size=7,
+            stride=2,
+            act='relu',
+            name="res_conv1")  #debug
+        conv = fluid.layers.pool2d(
+            input=conv,
+            pool_size=3,
+            pool_stride=2,
+            pool_padding=1,
+            pool_type='max')
+
+        for block in range(len(depth)):
+            for i in range(depth[block]):
+                if layers in [101, 152] and block == 2:
+                    if i == 0:
+                        conv_name = "res" + str(block + 2) + "a"
+                    else:
+                        conv_name = "res" + str(block + 2) + "b" + str(i)
+                else:
+                    conv_name = "res" + str(block + 2) + chr(97 + i)
+                conv = self.bottleneck_block(
+                    input=conv,
+                    num_filters=num_filters[block],
+                    stride=2 if i == 0 and block != 0 else 1,
+                    name=conv_name)
+
+        pool = fluid.layers.pool2d(
+            input=conv, pool_size=7, pool_type='avg', global_pooling=True)
+        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
+        out = fluid.layers.fc(input=pool,
+                              size=class_dim,
+                              act='softmax',
+                              param_attr=fluid.param_attr.ParamAttr(
+                                  initializer=fluid.initializer.Uniform(-stdv,
+                                                                        stdv)))
+        return out
+
+    def conv_bn_layer(self,
+                      input,
+                      num_filters,
+                      filter_size,
+                      stride=1,
+                      groups=1,
+                      act=None,
+                      name=None):
+        conv = fluid.layers.conv2d(
+            input=input,
+            num_filters=num_filters,
+            filter_size=filter_size,
+            stride=stride,
+            padding=(filter_size - 1) // 2,
+            groups=groups,
+            act=None,
+            param_attr=ParamAttr(name=name + "_weights"),
+            bias_attr=False,
+            name=name + '.conv2d.output.1')
+        if name == "conv1":
+            bn_name = "bn_" + name
+        else:
+            bn_name = "bn" + name[3:]
+        return fluid.layers.batch_norm(
+            input=conv,
+            act=act,
+            name=bn_name + '.output.1',
+            param_attr=ParamAttr(name=bn_name + '_scale'),
+            bias_attr=ParamAttr(bn_name + '_offset'),
+            moving_mean_name=bn_name + '_mean',
+            moving_variance_name=bn_name + '_variance', )
+
+    def shortcut(self, input, ch_out, stride, name):
+        ch_in = input.shape[1]
+        if ch_in != ch_out or stride != 1:
+            return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
+        else:
+            return input
+
+    def bottleneck_block(self, input, num_filters, stride, name):
+        conv0 = self.conv_bn_layer(
+            input=input,
+            num_filters=num_filters,
+            filter_size=1,
+            act='relu',
+            name=name + "_branch2a")
+        conv1 = self.conv_bn_layer(
+            input=conv0,
+            num_filters=num_filters,
+            filter_size=3,
+            stride=stride,
+            act='relu',
+            name=name + "_branch2b")
+        conv2 = self.conv_bn_layer(
+            input=conv1,
+            num_filters=num_filters * 4,
+            filter_size=1,
+            act=None,
+            name=name + "_branch2c")
+
+        short = self.shortcut(
+            input, num_filters * 4, stride, name=name + "_branch1")
+
+        return fluid.layers.elementwise_add(
+            x=short, y=conv2, act='relu', name=name + ".add.output.5")
+
+
+def ResNet50():
+    model = ResNet(layers=50)
+    return model
+
+
+def ResNet101():
+    model = ResNet(layers=101)
+    return model
+
+
+def ResNet152():
+    model = ResNet(layers=152)
+    return model
diff --git a/fluid/PaddleSlim/reader.py b/fluid/PaddleSlim/reader.py
new file mode 100644
index 00000000..f4a9da1e
--- /dev/null
+++ b/fluid/PaddleSlim/reader.py
@@ -0,0 +1,188 @@
+import os
+import math
+import random
+import functools
+import numpy as np
+import paddle
+from PIL import Image, ImageEnhance
+
+random.seed(0)
+np.random.seed(0)
+
+DATA_DIM = 224
+
+THREAD = 16
+BUF_SIZE = 10240
+
+DATA_DIR = 'data/ILSVRC2012'
+
+img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
+img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
+
+
+def resize_short(img, target_size):
+    percent = float(target_size) / min(img.size[0], img.size[1])
+    resized_width = int(round(img.size[0] * percent))
+    resized_height = int(round(img.size[1] * percent))
+    img = img.resize((resized_width, resized_height), Image.LANCZOS)
+    return img
+
+
+def crop_image(img, target_size, center):
+    width, height = img.size
+    size = target_size
+    if center == True:
+        w_start = (width - size) / 2
+        h_start = (height - size) / 2
+    else:
+        w_start = np.random.randint(0, width - size + 1)
+        h_start = np.random.randint(0, height - size + 1)
+    w_end = w_start + size
+    h_end = h_start + size
+    img = img.crop((w_start, h_start, w_end, h_end))
+    return img
+
+
+def random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]):
+    aspect_ratio = math.sqrt(np.random.uniform(*ratio))
+    w = 1. * aspect_ratio
+    h = 1. / aspect_ratio
+
+    bound = min((float(img.size[0]) / img.size[1]) / (w**2),
+                (float(img.size[1]) / img.size[0]) / (h**2))
+    scale_max = min(scale[1], bound)
+    scale_min = min(scale[0], bound)
+
+    target_area = img.size[0] * img.size[1] * np.random.uniform(scale_min,
+                                                                scale_max)
+    target_size = math.sqrt(target_area)
+    w = int(target_size * w)
+    h = int(target_size * h)
+
+    i = np.random.randint(0, img.size[0] - w + 1)
+    j = np.random.randint(0, img.size[1] - h + 1)
+
+    img = img.crop((i, j, i + w, j + h))
+    img = img.resize((size, size), Image.LANCZOS)
+    return img
+
+
+def rotate_image(img):
+    angle = np.random.randint(-10, 11)
+    img = img.rotate(angle)
+    return img
+
+
+def distort_color(img):
+    def random_brightness(img, lower=0.5, upper=1.5):
+        e = np.random.uniform(lower, upper)
+        return ImageEnhance.Brightness(img).enhance(e)
+
+    def random_contrast(img, lower=0.5, upper=1.5):
+        e = np.random.uniform(lower, upper)
+        return ImageEnhance.Contrast(img).enhance(e)
+
+    def random_color(img, lower=0.5, upper=1.5):
+        e = np.random.uniform(lower, upper)
+        return ImageEnhance.Color(img).enhance(e)
+
+    ops = [random_brightness, random_contrast, random_color]
+    np.random.shuffle(ops)
+
+    img = ops[0](img)
+    img = ops[1](img)
+    img = ops[2](img)
+
+    return img
+
+
+def process_image(sample, mode, color_jitter, rotate):
+    img_path = sample[0]
+
+    img = Image.open(img_path)
+    if mode == 'train':
+        if rotate: img = rotate_image(img)
+        img = random_crop(img, DATA_DIM)
+    else:
+        img = resize_short(img, target_size=256)
+        img = crop_image(img, target_size=DATA_DIM, center=True)
+    if mode == 'train':
+        if color_jitter:
+            img = distort_color(img)
+        if np.random.randint(0, 2) == 1:
+            img = img.transpose(Image.FLIP_LEFT_RIGHT)
+
+    if img.mode != 'RGB':
+        img = img.convert('RGB')
+
+    img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
+    img -= img_mean
+    img /= img_std
+
+    if mode == 'train' or mode == 'val':
+        return img, sample[1]
+    elif mode == 'test':
+        return [img]
+
+
+def _reader_creator(file_list,
+                    mode,
+                    shuffle=False,
+                    color_jitter=False,
+                    rotate=False,
+                    data_dir=DATA_DIR,
+                    batch_size=1):
+    def reader():
+        with open(file_list) as flist:
+            full_lines = [line.strip() for line in flist]
+            if shuffle:
+                np.random.shuffle(full_lines)
+            if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
+                # distributed mode if the env var `PADDLE_TRAINING_ROLE` exists
+                trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+                trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
+                per_node_lines = len(full_lines) // trainer_count
+                lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1)
+                                   * per_node_lines]
+                print(
+                    "read images from %d, length: %d, lines length: %d, total: %d"
+                    % (trainer_id * per_node_lines, per_node_lines, len(lines),
+                       len(full_lines)))
+            else:
+                lines = full_lines
+
+            for line in lines:
+                if mode == 'train' or mode == 'val':
+                    img_path, label = line.split()
+                    # img_path = img_path.replace("JPEG", "jpeg")
+                    img_path = os.path.join(data_dir, img_path)
+                    yield img_path, int(label)
+                elif mode == 'test':
+                    img_path = os.path.join(data_dir, line)
+                    yield [img_path]
+
+    mapper = functools.partial(
+        process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)
+
+    return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
+
+
+def train(data_dir=DATA_DIR):
+    file_list = os.path.join(data_dir, 'train_list.txt')
+    return _reader_creator(
+        file_list,
+        'train',
+        shuffle=True,
+        color_jitter=False,
+        rotate=False,
+        data_dir=data_dir)
+
+
+def val(data_dir=DATA_DIR):
+    file_list = os.path.join(data_dir, 'val_list.txt')
+    return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
+
+
+def test(data_dir=DATA_DIR):
+    file_list = os.path.join(data_dir, 'val_list.txt')
+    return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
diff --git a/fluid/PaddleSlim/run.sh b/fluid/PaddleSlim/run.sh
new file mode 100644
index 00000000..4fa57cd8
--- /dev/null
+++ b/fluid/PaddleSlim/run.sh
@@ -0,0 +1,36 @@
+
+# for distillation
+#--------------------
+export CUDA_VISIBLE_DEVICES=0
+python compress.py \
+--model "MobileNet" \
+--teacher_model "ResNet50" \
+--teacher_pretrained_model ./data/pretrain/ResNet50_pretrained \
+--compress_config ./configs/mobilenetv1_resnet50_distillation.yaml
+
+
+# for sensitivity filter pruning
+#---------------------------
+#export CUDA_VISIBLE_DEVICES=0
+#python compress.py \
+#--model "MobileNet" \
+#--pretrained_model ./data/pretrain/MobileNetV1_pretrained \
+#--compress_config ./configs/filter_pruning_sen.yaml
+
+# for uniform filter pruning
+#---------------------------
+#export CUDA_VISIBLE_DEVICES=0
+#python compress.py \
+#--model "MobileNet" \
+#--pretrained_model ./data/pretrain/MobileNetV1_pretrained \
+#--compress_config ./configs/filter_pruning_uniform.yaml
+
+# for quantization
+#---------------------------
+#export CUDA_VISIBLE_DEVICES=0
+#python compress.py \
+#--batch_size 64 \
+#--model "MobileNet" \
+#--pretrained_model ./data/pretrain/MobileNetV1_pretrained \
+#--compress_config ./configs/quantization.yaml
+
diff --git a/fluid/PaddleSlim/utility.py b/fluid/PaddleSlim/utility.py
new file mode 100644
index 00000000..5b10a179
--- /dev/null
+++ b/fluid/PaddleSlim/utility.py
@@ -0,0 +1,63 @@
+"""Contains common utility functions."""
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import distutils.util
+import numpy as np
+import six
+from paddle.fluid import core
+
+
+def print_arguments(args):
+    """Print argparse's arguments.
+
+    Usage:
+
+    .. code-block:: python
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument("name", default="John", type=str, help="User name.")
+        args = parser.parse_args()
+        print_arguments(args)
+
+    :param args: Input argparse.Namespace for printing.
+    :type args: argparse.Namespace
+    """
+    print("----------- Configuration Arguments -----------")
+    for arg, value in sorted(six.iteritems(vars(args))):
+        print("%s: %s" % (arg, value))
+    print("------------------------------------------------")
+
+
+def add_arguments(argname, type, default, help, argparser, **kwargs):
+    """Add argparse's argument.
+
+    Usage:
+
+    .. code-block:: python
+
+        parser = argparse.ArgumentParser()
+        add_arguments("name", str, "John", "User name.", parser)
+        args = parser.parse_args()
+    """
+    type = distutils.util.strtobool if type == bool else type
+    argparser.add_argument(
+        "--" + argname,
+        default=default,
+        type=type,
+        help=help + ' Default: %(default)s.',
+        **kwargs)
-- 
GitLab