diff --git a/demo/distillation/train.py b/demo/distillation/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f389168440a59f0872d44ab6e62f262e373f6f0
--- /dev/null
+++ b/demo/distillation/train.py
@@ -0,0 +1,238 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import math
+import logging
+import paddle
+import argparse
+import functools
+import numpy as np
+import paddle.fluid as fluid
+sys.path.append(sys.path[0] + "/../")
+import models
+import imagenet_reader as reader
+from utility import add_arguments, print_arguments
+from paddleslim.dist import merge, l2_loss, soft_label_loss, fsp_loss
+
+logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
+_logger = logging.getLogger(__name__)
+_logger.setLevel(logging.INFO)
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('batch_size', int, 64*4, "Minibatch size.")
+add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
+add_arg('total_images', int, 1281167, "Training image number.")
+add_arg('image_shape', str, "3,224,224", "Input image size.")
+add_arg('lr', float, 0.1, "The learning rate used to train the student model.")
+add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
+add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.")
+add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
+add_arg('num_epochs', int, 120, "The number of total epochs.")
+add_arg('data', str, "mnist", "Which data to use. 'mnist' or 'imagenet'")
+add_arg('log_period', int, 20, "Log period in batches.")
+add_arg('model', str, "MobileNet", "Set the network to use.")
+add_arg('pretrained_model', str, None, "The path of the student's pretrained model, if any.")
+add_arg('teacher_model', str, "ResNet50", "Set the teacher network to use.")
+add_arg('teacher_pretrained_model', str, "../pretrain/ResNet50_pretrained", "The path of the teacher's pretrained model.")
+parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
+# yapf: enable
+
+model_list = [m for m in dir(models) if "__" not in m]
+
+
+def piecewise_decay(args):
+    step = int(math.ceil(float(args.total_images) / args.batch_size))
+    bd = [step * e for e in args.step_epochs]
+    lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
+    learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
+    optimizer = fluid.optimizer.Momentum(
+        learning_rate=learning_rate,
+        momentum=args.momentum_rate,
+        regularization=fluid.regularizer.L2Decay(args.l2_decay))
+    return optimizer
+
+
+def cosine_decay(args):
+    step = int(math.ceil(float(args.total_images) / args.batch_size))
+    learning_rate = fluid.layers.cosine_decay(
+        learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs)
+    optimizer = fluid.optimizer.Momentum(
+        learning_rate=learning_rate,
+        momentum=args.momentum_rate,
+        regularization=fluid.regularizer.L2Decay(args.l2_decay))
+    return optimizer
+
+
+def create_optimizer(args):
+    if args.lr_strategy == "piecewise_decay":
+        return piecewise_decay(args)
+    elif args.lr_strategy == "cosine_decay":
+        return cosine_decay(args)
+
+
+def compress(args):
+    if args.data == "mnist":
+        import paddle.dataset.mnist as reader
+        train_reader = reader.train()
+        val_reader = reader.test()
+        class_dim = 10
+        image_shape = "1,28,28"
+    elif args.data == "imagenet":
+        import imagenet_reader as reader
+        train_reader = reader.train()
+        val_reader = reader.val()
+        class_dim = 1000
+        image_shape = "3,224,224"
+    else:
+        raise ValueError("{} is not supported.".format(args.data))
+    image_shape = [int(m) for m in image_shape.split(",")]
+
+    assert args.model in model_list, "{} is not in lists: {}".format(
+        args.model, model_list)
+    student_program = fluid.Program()
+    s_startup = fluid.Program()
+
+    with fluid.program_guard(student_program, s_startup):
+        with fluid.unique_name.guard():
+            image = fluid.layers.data(
+                name='image', shape=image_shape, dtype='float32')
+            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+            train_loader = fluid.io.DataLoader.from_generator(
+                feed_list=[image, label],
+                capacity=64,
+                use_double_buffer=True,
+                iterable=True)
+            valid_loader = fluid.io.DataLoader.from_generator(
+                feed_list=[image, label],
+                capacity=64,
+                use_double_buffer=True,
+                iterable=True)
+            # model definition
+            model = models.__dict__[args.model]()
+            out = model.net(input=image, class_dim=class_dim)
+            cost = fluid.layers.cross_entropy(input=out, label=label)
+            avg_cost = fluid.layers.mean(x=cost)
+            acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
+            acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
+    #print("="*50+"student_model_params"+"="*50)
+    #for v in student_program.list_vars():
+    #    print(v.name, v.shape)
+
+    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+    exe = fluid.Executor(place)
+
+    train_reader = paddle.batch(
+        train_reader, batch_size=args.batch_size, drop_last=True)
+    val_reader = paddle.batch(
+        val_reader, batch_size=args.batch_size, drop_last=True)
+    val_program = student_program.clone(for_test=True)
+
+    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
+    train_loader.set_sample_list_generator(train_reader, places)
+    valid_loader.set_sample_list_generator(val_reader, place)
+
+    teacher_model = models.__dict__[args.teacher_model]()
+    # define teacher program
+    teacher_program = fluid.Program()
+    t_startup = fluid.Program()
+    teacher_scope = fluid.Scope()
+    with fluid.scope_guard(teacher_scope):
+        with fluid.program_guard(teacher_program, t_startup):
+            with fluid.unique_name.guard():
+                image = fluid.layers.data(
+                    name='image', shape=image_shape, dtype='float32')
+                predict = teacher_model.net(image, class_dim=class_dim)
+
+    #print("="*50+"teacher_model_params"+"="*50)
+    #for v in teacher_program.list_vars():
+    #    print(v.name, v.shape)
+
+    exe.run(t_startup)
+    assert args.teacher_pretrained_model and os.path.exists(
+        args.teacher_pretrained_model
+    ), "teacher_pretrained_model should be set when teacher_model is not None."
+
+    def if_exist(var):
+        return os.path.exists(
+            os.path.join(args.teacher_pretrained_model, var.name)
+        ) and var.name != 'conv1_weights' and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'
+
+    fluid.io.load_vars(
+        exe,
+        args.teacher_pretrained_model,
+        main_program=teacher_program,
+        predicate=if_exist)
+
+    data_name_map = {'image': 'image'}
+    main = merge(
+        teacher_program,
+        student_program,
+        data_name_map,
+        place,
+        teacher_scope=teacher_scope)
+
+    #print("="*50+"teacher_vars"+"="*50)
+    #for v in teacher_program.list_vars():
+    #    if '_generated_var' not in v.name and 'fetch' not in v.name and 'feed' not in v.name:
+    #        print(v.name, v.shape)
+    #return
+
+    with fluid.program_guard(main, s_startup):
+        l2_loss_v = l2_loss("teacher_fc_0.tmp_0", "fc_0.tmp_0", main)
+        fsp_loss_v = fsp_loss("teacher_res2a_branch2a.conv2d.output.1.tmp_0",
+                              "teacher_res3a_branch2a.conv2d.output.1.tmp_0",
+                              "depthwise_conv2d_1.tmp_0", "conv2d_3.tmp_0",
+                              main)
+        loss = avg_cost + l2_loss_v + fsp_loss_v
+        opt = create_optimizer(args)
+        opt.minimize(loss)
+    exe.run(s_startup)
+    build_strategy = fluid.BuildStrategy()
+    build_strategy.fuse_all_reduce_ops = False
+    parallel_main = fluid.CompiledProgram(main).with_data_parallel(
+        loss_name=loss.name, build_strategy=build_strategy)
+
+    for epoch_id in range(args.num_epochs):
+        for step_id, data in enumerate(train_loader):
+            loss_1, loss_2, loss_3, loss_4 = exe.run(
+                parallel_main,
+                feed=data,
+                fetch_list=[
+                    loss.name, avg_cost.name, l2_loss_v.name, fsp_loss_v.name
+                ])
+            if step_id % args.log_period == 0:
+                _logger.info(
+                    "train_epoch {} step {} loss {:.6f}, class loss {:.6f}, l2 loss {:.6f}, fsp loss {:.6f}".
+                    format(epoch_id, step_id, loss_1[0], loss_2[0], loss_3[0],
+                           loss_4[0]))
+        val_acc1s = []
+        val_acc5s = []
+        for step_id, data in enumerate(valid_loader):
+            val_loss, val_acc1, val_acc5 = exe.run(
+                val_program,
+                feed=data,
+                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
+            val_acc1s.append(val_acc1)
+            val_acc5s.append(val_acc5)
+            if step_id % args.log_period == 0:
+                _logger.info(
+                    "valid_epoch {} step {} loss {:.6f}, top1 {:.6f}, top5 {:.6f}".
+                    format(epoch_id, step_id, val_loss[0], val_acc1[0],
+                           val_acc5[0]))
+        _logger.info("epoch {} top1 {:.6f}, top5 {:.6f}".format(
+            epoch_id, np.mean(val_acc1s), np.mean(val_acc5s)))
+
+
+def main():
+    args = parser.parse_args()
+    print_arguments(args)
+    compress(args)
+
+
+if __name__ == '__main__':
+    main()
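Note that the demo imports soft_label_loss but only combines the cross-entropy, l2 and fsp terms. Below is a minimal sketch of how a soft-label term could additionally be folded into the same loss composition, reusing the merged program `main`, `s_startup`, `avg_cost`, `l2_loss_v` and `fsp_loss_v` defined above; the logit variable names follow the demo's "teacher_" prefix convention and the temperature values are illustrative only, not part of this patch.

```python
# Hypothetical extension of the demo's loss composition (not part of the patch).
with fluid.program_guard(main, s_startup):
    soft_loss_v = soft_label_loss(
        "teacher_fc_0.tmp_0",   # teacher logits in the merged program
        "fc_0.tmp_0",           # student logits
        main,
        teacher_temperature=4.0,   # illustrative temperatures
        student_temperature=4.0)
    loss = avg_cost + l2_loss_v + fsp_loss_v + soft_loss_v
```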
diff --git a/paddleslim/analysis/__init__.py b/paddleslim/analysis/__init__.py
index 76904c8d548208adb29188f28e9e0c6a0f11f30d..9caa0d24006a3e59f2d39c646d247b7e68480f96 100644
--- a/paddleslim/analysis/__init__.py
+++ b/paddleslim/analysis/__init__.py
@@ -15,9 +15,6 @@ import flops as flops_module
 from flops import *
 import model_size as model_size_module
 from model_size import *
-import sensitive
-from sensitive import *
 __all__ = []
 __all__ += flops_module.__all__
 __all__ += model_size_module.__all__
-__all__ += sensitive.__all__
diff --git a/paddleslim/core/graph_wrapper.py b/paddleslim/core/graph_wrapper.py
index 72de894a2e4345c32e7a4eee2f35249b77c2f467..dc01846a10feb8bf212f9e35b9cd585df47ba739 100644
--- a/paddleslim/core/graph_wrapper.py
+++ b/paddleslim/core/graph_wrapper.py
@@ -54,6 +54,9 @@ class VarWrapper(object):
         """
         return self._var.name
 
+    def __repr__(self):
+        return self._var.name
+
     def shape(self):
         """
         Get the shape of the variable.
@@ -131,6 +134,11 @@ class OpWrapper(object):
         """
         return self._op.type
 
+    def __repr__(self):
+        return "op[id: {}, type: {}; inputs: {}]".format(self.idx(),
+                                                          self.type(),
+                                                          self.all_inputs())
+
     def is_bwd_op(self):
         """
         Whether this operator is backward op.
diff --git a/paddleslim/dist/single_distiller.py b/paddleslim/dist/single_distiller.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e39a9e2d9c743681320eaa70e0d75476844018c
--- /dev/null
+++ b/paddleslim/dist/single_distiller.py
@@ -0,0 +1,184 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle.fluid as fluid
+
+
+def merge(teacher_program,
+          student_program,
+          data_name_map,
+          place,
+          teacher_scope=fluid.global_scope(),
+          student_scope=fluid.global_scope(),
+          name_prefix='teacher_'):
+    """
+    Merge the teacher program into the student program and add a uniform
+    prefix to the names of all vars in the teacher program.
+    Args:
+        teacher_program(Program): The input teacher model paddle program
+        student_program(Program): The input student model paddle program
+        data_name_map(dict): Describe the mapping between the teacher var name
+                             and the student var name
+        place(fluid.CPUPlace()|fluid.CUDAPlace(N)): The device on which the
+                                                    programs run.
+        student_scope(Scope): The input student scope
+        teacher_scope(Scope): The input teacher scope
+        name_prefix(str): Name prefix added for all vars of the teacher program.
+    Return(Program): Merged program.
+ """ + teacher_program = teacher_program.clone(for_test=True) + for teacher_var in teacher_program.list_vars(): + skip_rename = False + if teacher_var.name != 'fetch' and teacher_var.name != 'feed': + if teacher_var.name in data_name_map.keys(): + new_name = data_name_map[teacher_var.name] + if new_name == teacher_var.name: + skip_rename = True + else: + new_name = name_prefix + teacher_var.name + if not skip_rename: + # scope var rename + scope_var = teacher_scope.var(teacher_var.name).get_tensor() + renamed_scope_var = teacher_scope.var(new_name).get_tensor() + renamed_scope_var.set(np.array(scope_var), place) + + # program var rename + renamed_var = teacher_program.global_block()._rename_var( + teacher_var.name, new_name) + + for teacher_var in teacher_program.list_vars(): + if teacher_var.name != 'fetch' and teacher_var.name != 'feed': + # student scope add var + student_scope_var = student_scope.var(teacher_var.name).get_tensor() + teacher_scope_var = teacher_scope.var(teacher_var.name).get_tensor() + student_scope_var.set(np.array(teacher_scope_var), place) + + # student program add var + new_var = student_program.global_block()._clone_variable( + teacher_var, force_persistable=False) + new_var.stop_gradient = True + + for block in teacher_program.blocks: + for op in block.ops: + if op.type != 'feed' and op.type != 'fetch': + inputs = {} + outputs = {} + attrs = {} + for input_name in op.input_names: + inputs[input_name] = [ + block.var(in_var_name) + for in_var_name in op.input(input_name) + ] + + for output_name in op.output_names: + outputs[output_name] = [ + block.var(out_var_name) + for out_var_name in op.output(output_name) + ] + for attr_name in op.attr_names: + attrs[attr_name] = op.attr(attr_name) + student_program.global_block().append_op( + type=op.type, inputs=inputs, outputs=outputs, attrs=attrs) + return student_program + + +def fsp_loss(teacher_var1_name, teacher_var2_name, student_var1_name, + student_var2_name, program): + """ + Combine variables from student model and teacher model by fsp-loss. + Args: + teacher_var1_name(str): The name of teacher_var1. + teacher_var2_name(str): The name of teacher_var2. Except for the + second dimension, all other dimensions should + be consistent with teacher_var1. + student_var1_name(str): The name of student_var1. + student_var2_name(str): The name of student_var2. Except for the + second dimension, all other dimensions should + be consistent with student_var1. + program(Program): The input distiller program. + Return(Variable): fsp distiller loss. + """ + teacher_var1 = program.global_block().var(teacher_var1_name) + teacher_var2 = program.global_block().var(teacher_var2_name) + student_var1 = program.global_block().var(student_var1_name) + student_var2 = program.global_block().var(student_var2_name) + teacher_fsp_matrix = fluid.layers.fsp_matrix(teacher_var1, teacher_var2) + student_fsp_matrix = fluid.layers.fsp_matrix(student_var1, student_var2) + fsp_loss = fluid.layers.reduce_mean( + fluid.layers.square(student_fsp_matrix - teacher_fsp_matrix)) + return fsp_loss + + +def l2_loss(teacher_var_name, student_var_name, program): + """ + Combine variables from student model and teacher model by l2-loss. + Args: + teacher_var_name(str): The name of teacher_var. + student_var_name(str): The name of student_var. + program(Program): The input distiller program. + Return(Variable): l2 distiller loss. 
+ """ + student_var = program.global_block().var(student_var_name) + teacher_var = program.global_block().var(teacher_var_name) + l2_loss = fluid.layers.reduce_mean( + fluid.layers.square(student_var - teacher_var)) + return l2_loss + + +def soft_label_loss(teacher_var_name, + student_var_name, + program, + teacher_temperature=1., + student_temperature=1.): + """ + Combine variables from student model and teacher model by soft-label-loss. + Args: + teacher_var_name(str): The name of teacher_var. + student_var_name(str): The name of student_var. + program(Program): The input distiller program. + teacher_temperature(float): Temperature used to divide + teacher_feature_map before softmax. default: 1.0 + student_temperature(float): Temperature used to divide + student_feature_map before softmax. default: 1.0 + Return(Variable): l2 distiller loss. + """ + student_var = program.global_block().var(student_var_name) + teacher_var = program.global_block().var(teacher_var_name) + student_var = fluid.layers.softmax(student_var / student_temperature) + teacher_var = fluid.layers.softmax(teacher_var / teacher_temperature) + teacher_var.stop_gradient = True + soft_label_loss = fluid.layers.reduce_mean( + fluid.layers.cross_entropy( + student_var, teacher_var, soft_label=True)) + return soft_label_loss + + +def loss(program, loss_func, **kwargs): + """ + Combine variables from student model and teacher model by self defined loss. + Args: + program(Program): The input distiller program. + loss_func(function): The user self defined loss function. + Return(Variable): self defined distiller loss. + """ + func_parameters = {} + for item in kwargs.items(): + if isinstance(item[1], str): + func_parameters.setdefault(item[0], + program.global_block().var(item[1])) + else: + func_parameters.setdefault(item[0], item[1]) + loss = loss_func(**func_parameters) + return loss diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py index f57caaa6beb6fec59b618a689b44652f0cf259fc..00decbfd1ae38dfa3fedf3234665ca740674d603 100644 --- a/paddleslim/nas/sa_nas.py +++ b/paddleslim/nas/sa_nas.py @@ -60,16 +60,17 @@ class SANAS(object): self._init_temperature = init_temperature self._is_server = is_server self._configs = configs - self._keys = hashlib.md5(str(self._configs)).hexdigest() + self._key = hashlib.md5(str(self._configs)).hexdigest() server_ip, server_port = server_addr if server_ip == None or server_ip == "": server_ip = self._get_host_ip() + factory = SearchSpaceFactory() + self._search_space = factory.get_search_space(configs) + # create controller server if self._is_server: - factory = SearchSpaceFactory() - self._search_space = factory.get_search_space(configs) init_tokens = self._search_space.init_tokens() range_table = self._search_space.range_table() range_table = (len(range_table) * [0], range_table) @@ -90,6 +91,7 @@ class SANAS(object): search_steps=search_steps, key=self._key) self._controller_server.start() + server_port = self._controller_server.port() self._controller_client = ControllerClient( server_ip, server_port, key=self._key) @@ -99,6 +101,9 @@ class SANAS(object): def _get_host_ip(self): return socket.gethostbyname(socket.gethostname()) + def tokens2arch(self, tokens): + return self._search_space.token2arch(self.tokens) + def next_archs(self): """ Get next network architectures. 
diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py
index f57caaa6beb6fec59b618a689b44652f0cf259fc..00decbfd1ae38dfa3fedf3234665ca740674d603 100644
--- a/paddleslim/nas/sa_nas.py
+++ b/paddleslim/nas/sa_nas.py
@@ -60,16 +60,17 @@ class SANAS(object):
         self._init_temperature = init_temperature
         self._is_server = is_server
         self._configs = configs
-        self._keys = hashlib.md5(str(self._configs)).hexdigest()
+        self._key = hashlib.md5(str(self._configs)).hexdigest()
 
         server_ip, server_port = server_addr
         if server_ip == None or server_ip == "":
             server_ip = self._get_host_ip()
 
+        factory = SearchSpaceFactory()
+        self._search_space = factory.get_search_space(configs)
+
         # create controller server
         if self._is_server:
-            factory = SearchSpaceFactory()
-            self._search_space = factory.get_search_space(configs)
             init_tokens = self._search_space.init_tokens()
             range_table = self._search_space.range_table()
             range_table = (len(range_table) * [0], range_table)
@@ -90,6 +91,7 @@ class SANAS(object):
                 search_steps=search_steps,
                 key=self._key)
             self._controller_server.start()
+            server_port = self._controller_server.port()
 
         self._controller_client = ControllerClient(
             server_ip, server_port, key=self._key)
@@ -99,6 +101,9 @@ class SANAS(object):
     def _get_host_ip(self):
         return socket.gethostbyname(socket.gethostname())
 
+    def tokens2arch(self, tokens):
+        return self._search_space.token2arch(tokens)
+
     def next_archs(self):
         """
         Get next network architectures.
diff --git a/paddleslim/nas/search_space/combine_search_space.py b/paddleslim/nas/search_space/combine_search_space.py
index 667720a9110aa92e096a4f8fa30bb3e4b3e3cecb..17ebbd3939798ad0e2a7d3fd763bb9427f6e13f0 100644
--- a/paddleslim/nas/search_space/combine_search_space.py
+++ b/paddleslim/nas/search_space/combine_search_space.py
@@ -39,6 +39,7 @@ class CombineSearchSpace(object):
         for config_list in config_lists:
             key, config = config_list
             self.spaces.append(self._get_single_search_space(key, config))
+        self.init_tokens()
 
     def _get_single_search_space(self, key, config):
         """
@@ -51,9 +52,11 @@
             model space(class)
         """
         cls = SEARCHSPACE.get(key)
-        space = cls(config['input_size'], config['output_size'],
-                    config['block_num'], config['block_mask'])
-
+        block_mask = config['block_mask'] if 'block_mask' in config else None
+        space = cls(config['input_size'],
+                    config['output_size'],
+                    config['block_num'],
+                    block_mask=block_mask)
         return space
 
     def init_tokens(self):
diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py
index f8f87862f7c0e9c09c23b753be600eed5c915a90..b012254170d4d63bf24fcccaf8fa5f3eaeccac11 100644
--- a/paddleslim/prune/__init__.py
+++ b/paddleslim/prune/__init__.py
@@ -21,6 +21,8 @@ import controller_client
 from controller_client import *
 import sensitive_pruner
 from sensitive_pruner import *
+import sensitive
+from sensitive import *
 
 __all__ = []
 __all__ += pruner.__all__
@@ -28,3 +30,4 @@ __all__ += auto_pruner.__all__
 __all__ += controller_server.__all__
 __all__ += controller_client.__all__
 __all__ += sensitive_pruner.__all__
+__all__ += sensitive.__all__
diff --git a/paddleslim/prune/pruner.py b/paddleslim/prune/pruner.py
index 0fdde525a793b90df63f3245ac5215365dd7ccf4..e2b6a7e1d28078abef97c5fa53b215b098f18cca 100644
--- a/paddleslim/prune/pruner.py
+++ b/paddleslim/prune/pruner.py
@@ -528,33 +528,41 @@ class Pruner():
         Returns:
             list: A list of operators.
         """
+        _logger.debug("######################search: {}######################".
+                      format(op_node))
         visited = [op_node.idx()]
         stack = []
         brothers = []
         for op in graph.next_ops(op_node):
-            if (op.type() != 'conv2d') and (op.type() != 'fc') and (
-                    not op.is_bwd_op()):
+            if ("conv2d" not in op.type()) and (op.type() != 'fc') and (
+                    not op.is_bwd_op()) and (not op.is_opt_op()):
                 stack.append(op)
                 visited.append(op.idx())
         while len(stack) > 0:
            top_op = stack.pop()
-            if top_op.type().startswith("elementwise_"):
-                for parent in graph.pre_ops(top_op):
-                    if parent.idx() not in visited and (
-                            not parent.is_bwd_op()):
-                        if ((parent.type() == 'conv2d') or
-                            (parent.type() == 'fc')):
-                            brothers.append(parent)
-                        else:
-                            stack.append(parent)
-                        visited.append(parent.idx())
+            for parent in graph.pre_ops(top_op):
+                if parent.idx() not in visited and (
+                        not parent.is_bwd_op()) and (not parent.is_opt_op()):
+                    _logger.debug("----------go back from {} to {}----------".
+                                  format(top_op, parent))
+                    if (('conv2d' in parent.type()) or
+                        (parent.type() == 'fc')):
+                        brothers.append(parent)
+                    else:
+                        stack.append(parent)
+                    visited.append(parent.idx())
 
             for child in graph.next_ops(top_op):
-                if (child.type() != 'conv2d') and (child.type() != 'fc') and (
+                if ('conv2d' not in child.type()
+                    ) and (child.type() != 'fc') and (
                         child.idx() not in visited) and (
-                        not child.is_bwd_op()):
+                        not child.is_bwd_op()) and (not child.is_opt_op()):
                     stack.append(child)
                     visited.append(child.idx())
+        _logger.debug("brothers: {}".format(brothers))
+        _logger.debug(
+            "######################Finish search: {}######################".
+            format(op_node))
         return brothers
 
     def _cal_pruned_idx(self, name, param, ratio, axis):
diff --git a/paddleslim/analysis/sensitive.py b/paddleslim/prune/sensitive.py
similarity index 100%
rename from paddleslim/analysis/sensitive.py
rename to paddleslim/prune/sensitive.py
diff --git a/paddleslim/prune/sensitive_pruner.py b/paddleslim/prune/sensitive_pruner.py
index 21f8900336d736de85fdaef42fb7479488dfe3ee..6213382fa9d47bae81c718f9f23c3e34146e05e4 100644
--- a/paddleslim/prune/sensitive_pruner.py
+++ b/paddleslim/prune/sensitive_pruner.py
@@ -19,7 +19,7 @@ from scipy.optimize import leastsq
 import numpy as np
 import paddle.fluid as fluid
 from ..common import get_logger
-from ..analysis import sensitivity
+from .sensitive import sensitivity
 from ..analysis import flops
 from .pruner import Pruner
diff --git a/tests/test_prune.py b/tests/test_prune.py
index 93609367351618ce375f164a1dca284e85369e4c..3fdaa867e350af876648871f83fe70cc83b548b6 100644
--- a/tests/test_prune.py
+++ b/tests/test_prune.py
@@ -15,7 +15,7 @@ import sys
 sys.path.append("../")
 import unittest
 import paddle.fluid as fluid
-from prune import Pruner
+from paddleslim.prune import Pruner
 from layers import conv_bn_layer
diff --git a/tests/test_sa_nas.py b/tests/test_sa_nas.py
index 5666e1410a820c09bc10fa0b10d282434c7837fe..a4203a85a898632ac2102eb61ab7dd7b475e73ef 100644
--- a/tests/test_sa_nas.py
+++ b/tests/test_sa_nas.py
@@ -40,7 +40,11 @@ class TestSANAS(unittest.TestCase):
         base_flops = flops(main_program)
 
         search_steps = 3
-        sa_nas = SANAS(configs, search_steps=search_steps, is_server=True)
+        sa_nas = SANAS(
+            configs,
+            search_steps=search_steps,
+            server_addr=("", 0),
+            is_server=True)
 
         for i in range(search_steps):
             archs = sa_nas.next_archs()