diff --git a/demo/auto_prune/train.py b/demo/auto_prune/train.py index 70930774dc1c4306d12e63fbd1766a67ec2a5c3c..d65dd875a8d650b57bbe429514e99dc6fa46e630 100644 --- a/demo/auto_prune/train.py +++ b/demo/auto_prune/train.py @@ -195,11 +195,12 @@ def compress(args): server_addr=("", 0), init_temperature=100, reduce_rate=0.85, - max_try_number=300, + max_try_times=300, max_client_num=10, search_steps=100, max_ratios=0.9, min_ratios=0., + is_server=True, key="auto_pruner") while True: diff --git a/demo/nas/sa_nas_mobilenetv2_cifar10.py b/demo/nas/sa_nas_mobilenetv2_cifar10.py index 3e903960b1c783c38d672238d5a2b3a0c1581c4d..249d4c214788c0ffc5a0d741dc48b4942ea5808b 100644 --- a/demo/nas/sa_nas_mobilenetv2_cifar10.py +++ b/demo/nas/sa_nas_mobilenetv2_cifar10.py @@ -39,7 +39,7 @@ def init_sa_nas(config): search_steps = 10000000 ### start a server and a client - sa_nas = SANAS(config, max_flops=base_flops, search_steps=search_steps) + sa_nas = SANAS(config, search_steps=search_steps, is_server=True) ### start a client, server_addr is server address #sa_nas = SANAS(config, max_flops = base_flops, server_addr=("10.255.125.38", 18607), search_steps = search_steps, is_server=False) diff --git a/demo/sensitive_prune/train.py b/demo/sensitive_prune/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e6c1ba7ccd09f41c8d0652075036a1c279251517 --- /dev/null +++ b/demo/sensitive_prune/train.py @@ -0,0 +1,223 @@ +import os +import sys +import logging +import paddle +import argparse +import functools +import math +import time +import numpy as np +import paddle.fluid as fluid +from paddleslim.prune import SensitivePruner +from paddleslim.common import get_logger +from paddleslim.analysis import flops +sys.path.append(sys.path[0] + "/../") +import models +from utility import add_arguments, print_arguments + +_logger = get_logger(__name__, level=logging.INFO) + +parser = argparse.ArgumentParser(description=__doc__) +add_arg = functools.partial(add_arguments, argparser=parser) +# yapf: disable +add_arg('batch_size', int, 64 * 4, "Minibatch size.") +add_arg('use_gpu', bool, True, "Whether to use GPU or not.") +add_arg('model', str, "MobileNet", "The target model.") +add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretained", "Whether to use pretrained model.") +add_arg('lr', float, 0.1, "The learning rate used to fine-tune pruned model.") +add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.") +add_arg('l2_decay', float, 3e-5, "The l2_decay parameter.") +add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.") +add_arg('num_epochs', int, 120, "The number of total epochs.") +add_arg('total_images', int, 1281167, "The number of total training images.") +parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step") +add_arg('config_file', str, None, "The config file for compression with yaml format.") +add_arg('data', str, "mnist", "Which data to use. 
'mnist' or 'imagenet'") +add_arg('log_period', int, 10, "Log period in batches.") +add_arg('test_period', int, 10, "Test period in epoches.") +add_arg('checkpoints', str, "./checkpoints", "Checkpoints path.") +# yapf: enable + +model_list = [m for m in dir(models) if "__" not in m] + + +def piecewise_decay(args): + step = int(math.ceil(float(args.total_images) / args.batch_size)) + bd = [step * e for e in args.step_epochs] + lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)] + learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr) + optimizer = fluid.optimizer.Momentum( + learning_rate=learning_rate, + momentum=args.momentum_rate, + regularization=fluid.regularizer.L2Decay(args.l2_decay)) + return optimizer + + +def cosine_decay(args): + step = int(math.ceil(float(args.total_images) / args.batch_size)) + learning_rate = fluid.layers.cosine_decay( + learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs) + optimizer = fluid.optimizer.Momentum( + learning_rate=learning_rate, + momentum=args.momentum_rate, + regularization=fluid.regularizer.L2Decay(args.l2_decay)) + return optimizer + + +def create_optimizer(args): + if args.lr_strategy == "piecewise_decay": + return piecewise_decay(args) + elif args.lr_strategy == "cosine_decay": + return cosine_decay(args) + + +def compress(args): + + train_reader = None + test_reader = None + if args.data == "mnist": + import paddle.dataset.mnist as reader + train_reader = reader.train() + val_reader = reader.test() + class_dim = 10 + image_shape = "1,28,28" + elif args.data == "imagenet": + import imagenet_reader as reader + train_reader = reader.train() + val_reader = reader.val() + class_dim = 1000 + image_shape = "3,224,224" + else: + raise ValueError("{} is not supported.".format(args.data)) + + image_shape = [int(m) for m in image_shape.split(",")] + assert args.model in model_list, "{} is not in lists: {}".format( + args.model, model_list) + image = fluid.layers.data(name='image', shape=image_shape, dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + # model definition + model = models.__dict__[args.model]() + out = model.net(input=image, class_dim=class_dim) + cost = fluid.layers.cross_entropy(input=out, label=label) + avg_cost = fluid.layers.mean(x=cost) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + val_program = fluid.default_main_program().clone(for_test=True) + opt = create_optimizer(args) + opt.minimize(avg_cost) + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + if args.pretrained_model: + + def if_exist(var): + return os.path.exists( + os.path.join(args.pretrained_model, var.name)) + + fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist) + + val_reader = paddle.batch(val_reader, batch_size=args.batch_size) + train_reader = paddle.batch( + train_reader, batch_size=args.batch_size, drop_last=True) + + train_feeder = feeder = fluid.DataFeeder([image, label], place) + val_feeder = feeder = fluid.DataFeeder( + [image, label], place, program=val_program) + + def test(epoch, program): + batch_id = 0 + acc_top1_ns = [] + acc_top5_ns = [] + for data in val_reader(): + start_time = time.time() + acc_top1_n, acc_top5_n = exe.run( + program, + feed=train_feeder.feed(data), + fetch_list=[acc_top1.name, acc_top5.name]) + end_time = time.time() + if batch_id % args.log_period == 0: + _logger.info( + 
"Eval epoch[{}] batch[{}] - acc_top1: {:.3f}; acc_top5: {:.3f}; time: {:.3f}". + format(epoch, batch_id, + np.mean(acc_top1_n), + np.mean(acc_top5_n), end_time - start_time)) + acc_top1_ns.append(np.mean(acc_top1_n)) + acc_top5_ns.append(np.mean(acc_top5_n)) + batch_id += 1 + + _logger.info( + "Final eval epoch[{}] - acc_top1: {:.3f}; acc_top5: {:.3f}".format( + epoch, + np.mean(np.array(acc_top1_ns)), np.mean( + np.array(acc_top5_ns)))) + return np.mean(np.array(acc_top1_ns)) + + def train(epoch, program): + + build_strategy = fluid.BuildStrategy() + exec_strategy = fluid.ExecutionStrategy() + train_program = fluid.compiler.CompiledProgram( + program).with_data_parallel( + loss_name=avg_cost.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + + batch_id = 0 + for data in train_reader(): + start_time = time.time() + loss_n, acc_top1_n, acc_top5_n = exe.run( + train_program, + feed=train_feeder.feed(data), + fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name]) + end_time = time.time() + loss_n = np.mean(loss_n) + acc_top1_n = np.mean(acc_top1_n) + acc_top5_n = np.mean(acc_top5_n) + if batch_id % args.log_period == 0: + _logger.info( + "epoch[{}]-batch[{}] - loss: {:.3f}; acc_top1: {:.3f}; acc_top5: {:.3f}; time: {:.3f}". + format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n, + end_time - start_time)) + batch_id += 1 + + params = [] + for param in fluid.default_main_program().global_block().all_parameters(): + if "_sep_weights" in param.name: + params.append(param.name) + + def eval_func(program): + return test(0, program) + + if args.data == "mnist": + train(0, fluid.default_main_program()) + + pruner = SensitivePruner(place, eval_func, checkpoints=args.checkpoints) + pruned_program, pruned_val_program, iter = pruner.restore() + + if pruned_program is None: + pruned_program = fluid.default_main_program() + if pruned_val_program is None: + pruned_val_program = val_program + + start = iter + end = 6 + for iter in range(start, end): + pruned_program, pruned_val_program = pruner.prune( + pruned_program, pruned_val_program, params, 0.1) + train(iter, pruned_program) + test(iter, pruned_val_program) + pruner.save_checkpoint(pruned_program, pruned_val_program) + + print("before flops: {}".format(flops(fluid.default_main_program()))) + print("after flops: {}".format(flops(pruned_val_program))) + + +def main(): + args = parser.parse_args() + print_arguments(args) + compress(args) + + +if __name__ == '__main__': + main() diff --git a/paddleslim/analysis/__init__.py b/paddleslim/analysis/__init__.py index 76904c8d548208adb29188f28e9e0c6a0f11f30d..9caa0d24006a3e59f2d39c646d247b7e68480f96 100644 --- a/paddleslim/analysis/__init__.py +++ b/paddleslim/analysis/__init__.py @@ -15,9 +15,6 @@ import flops as flops_module from flops import * import model_size as model_size_module from model_size import * -import sensitive -from sensitive import * __all__ = [] __all__ += flops_module.__all__ __all__ += model_size_module.__all__ -__all__ += sensitive.__all__ diff --git a/paddleslim/common/__init__.py b/paddleslim/common/__init__.py index 98b314ab6d144924bff6b68e3fb176ce73583f5c..2794cd4d86c0996155fd8d6e9dd830cdc8775e09 100644 --- a/paddleslim/common/__init__.py +++ b/paddleslim/common/__init__.py @@ -23,6 +23,8 @@ import controller_client from controller_client import * import lock_utils from lock_utils import * +import cached_reader as cached_reader_module +from cached_reader import * __all__ = [] __all__ += controller.__all__ @@ -30,3 +32,4 @@ __all__ += sa_controller.__all__ 
__all__ += controller_server.__all__ __all__ += controller_client.__all__ __all__ += lock_utils.__all__ +__all__ += cached_reader_module.__all__ diff --git a/paddleslim/common/cached_reader.py b/paddleslim/common/cached_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..55f27054efe55d9df90352b3e707fe51c8996023 --- /dev/null +++ b/paddleslim/common/cached_reader.py @@ -0,0 +1,57 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +import numpy as np +from .log_helper import get_logger + +__all__ = ['cached_reader'] + +_logger = get_logger(__name__, level=logging.INFO) + + +def cached_reader(reader, sampled_rate, cache_path, cached_id): + """ + Sample partial data from reader and cache them into local file system. + Args: + reader: Iterative data source. + sampled_rate(float): The sampled rate used to sample partial data for evaluation. None means using all data in eval_reader. default: None. + cache_path(str): The path to cache the sampled data. + cached_id(int): The id of dataset sampled. Evaluations with same cached_id use the same sampled dataset. default: 0. + """ + np.random.seed(cached_id) + cache_path = os.path.join(cache_path, str(cached_id)) + _logger.debug('read data from: {}'.format(cache_path)) + + def s_reader(): + if os.path.isdir(cache_path): + for file_name in open(os.path.join(cache_path, "list")): + yield np.load( + os.path.join(cache_path, file_name.strip()), + allow_pickle=True) + else: + os.makedirs(cache_path) + list_file = open(os.path.join(cache_path, "list"), 'w') + batch = 0 + dtype = None + for data in reader(): + if batch == 0 or (np.random.uniform() < sampled_rate): + np.save( + os.path.join(cache_path, 'batch' + str(batch)), data) + list_file.write('batch' + str(batch) + '.npy\n') + batch += 1 + yield data + + return s_reader diff --git a/paddleslim/common/controller_client.py b/paddleslim/common/controller_client.py index ad989dd16014fa8e6fa1495516e81048324fb826..8a8ebbde3d738438d3cca484ca9c824d853837b2 100644 --- a/paddleslim/common/controller_client.py +++ b/paddleslim/common/controller_client.py @@ -38,7 +38,7 @@ class ControllerClient(object): self.socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._key = key - def update(self, tokens, reward): + def update(self, tokens, reward, iter): """ Update the controller according to latest tokens and reward. 
Args: @@ -48,8 +48,8 @@ class ControllerClient(object): socket_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) socket_client.connect((self.server_ip, self.server_port)) tokens = ",".join([str(token) for token in tokens]) - socket_client.send("{}\t{}\t{}".format(self._key, tokens, reward) - .encode()) + socket_client.send("{}\t{}\t{}\t{}".format(self._key, tokens, reward, + iter).encode()) response = socket_client.recv(1024).decode() if response.strip('\n').split("\t") == "ok": return True diff --git a/paddleslim/common/controller_server.py b/paddleslim/common/controller_server.py index e4705a887727bf444b3ba285165d27df59a1ed57..bf3ee3ab2e27c468c929013be6954f4042e53537 100644 --- a/paddleslim/common/controller_server.py +++ b/paddleslim/common/controller_server.py @@ -51,23 +51,8 @@ class ControllerServer(object): self._port = address[1] self._ip = address[0] self._key = key - self._socket_file = "./controller_server.socket" def start(self): - open(self._socket_file, 'a').close() - socket_file = open(self._socket_file, 'r+') - lock(socket_file) - tid = socket_file.readline() - if tid == '': - _logger.info("start controller server...") - tid = self._start() - socket_file.write("tid: {}\nip: {}\nport: {}\n".format( - tid, self._ip, self._port)) - _logger.info("started controller server...") - unlock(socket_file) - socket_file.close() - - def _start(self): self._socket_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket_server.bind(self._address) self._socket_server.listen(self._max_client_num) @@ -82,7 +67,6 @@ class ControllerServer(object): def close(self): """Close the server.""" self._closed = True - os.remove(self._socket_file) _logger.info("server closed!") def port(self): @@ -109,14 +93,15 @@ class ControllerServer(object): _logger.debug("recv message from {}: [{}]".format(addr, message)) messages = message.strip('\n').split("\t") - if (len(messages) < 3) or (messages[0] != self._key): + if (len(messages) < 4) or (messages[0] != self._key): _logger.debug("recv noise from {}: [{}]".format( addr, message)) continue tokens = messages[1] reward = messages[2] + iter = messages[3] tokens = [int(token) for token in tokens.split(",")] - self._controller.update(tokens, float(reward)) + self._controller.update(tokens, float(reward), int(iter)) response = "ok" conn.send(response.encode()) _logger.debug("send message to {}: [{}]".format(addr, diff --git a/paddleslim/common/sa_controller.py b/paddleslim/common/sa_controller.py index b619b818a3208d740c1ddb6753cf5931f3d058f5..9a36da93c848821ac8b9d8992b4b4d5d6bf44994 100644 --- a/paddleslim/common/sa_controller.py +++ b/paddleslim/common/sa_controller.py @@ -32,7 +32,7 @@ class SAController(EvolutionaryController): range_table=None, reduce_rate=0.85, init_temperature=1024, - max_iter_number=300, + max_try_times=None, init_tokens=None, constrain_func=None): """Initialize. @@ -40,7 +40,7 @@ class SAController(EvolutionaryController): range_table(list): Range table. reduce_rate(float): The decay rate of temperature. init_temperature(float): Init temperature. - max_iter_number(int): max iteration number. + max_try_times(int): max try times before get legal tokens. init_tokens(list): The initial tokens. constrain_func(function): The callback function used to check whether the tokens meet constraint. None means there is no constraint. Default: None. 
""" @@ -50,7 +50,7 @@ class SAController(EvolutionaryController): len(self._range_table) == 2) self._reduce_rate = reduce_rate self._init_temperature = init_temperature - self._max_iter_number = max_iter_number + self._max_try_times = max_try_times self._reward = -1 self._tokens = init_tokens self._constrain_func = constrain_func @@ -65,15 +65,17 @@ class SAController(EvolutionaryController): d[key] = self.__dict__[key] return d - def update(self, tokens, reward): + def update(self, tokens, reward, iter): """ Update the controller according to latest tokens and reward. Args: tokens(list): The tokens generated in last step. reward(float): The reward of tokens. """ - self._iter += 1 - temperature = self._init_temperature * self._reduce_rate**self._iter + iter = int(iter) + if iter > self._iter: + self._iter = iter + temperature = self._init_temperature * self._reduce_rate**self._iter if (reward > self._reward) or (np.random.random() <= math.exp( (reward - self._reward) / temperature)): self._reward = reward @@ -99,9 +101,9 @@ class SAController(EvolutionaryController): self._range_table[1][index] + 1) _logger.debug("change index[{}] from {} to {}".format(index, tokens[ index], new_tokens[index])) - if self._constrain_func is None: + if self._constrain_func is None or self._max_try_times is None: return new_tokens - for _ in range(self._max_iter_number): + for _ in range(self._max_try_times): if not self._constrain_func(new_tokens): index = int(len(self._range_table[0]) * np.random.random()) new_tokens = tokens[:] diff --git a/paddleslim/dist/single_distiller.py b/paddleslim/dist/single_distiller.py new file mode 100644 index 0000000000000000000000000000000000000000..7e39a9e2d9c743681320eaa70e0d75476844018c --- /dev/null +++ b/paddleslim/dist/single_distiller.py @@ -0,0 +1,184 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle.fluid as fluid + + +def merge(teacher_program, + student_program, + data_name_map, + place, + teacher_scope=fluid.global_scope(), + student_scope=fluid.global_scope(), + name_prefix='teacher_'): + """ + Merge teacher program into student program and add a uniform prefix to the + names of all vars in teacher program + Args: + teacher_program(Program): The input teacher model paddle program + student_program(Program): The input student model paddle program + data_map_map(dict): Describe the mapping between the teacher var name + and the student var name + place(fluid.CPUPlace()|fluid.CUDAPlace(N)): This parameter represents + paddle run on which device. + student_scope(Scope): The input student scope + teacher_scope(Scope): The input teacher scope + name_prefix(str): Name prefix added for all vars of the teacher program. + Return(Program): Merged program. 
+ """ + teacher_program = teacher_program.clone(for_test=True) + for teacher_var in teacher_program.list_vars(): + skip_rename = False + if teacher_var.name != 'fetch' and teacher_var.name != 'feed': + if teacher_var.name in data_name_map.keys(): + new_name = data_name_map[teacher_var.name] + if new_name == teacher_var.name: + skip_rename = True + else: + new_name = name_prefix + teacher_var.name + if not skip_rename: + # scope var rename + scope_var = teacher_scope.var(teacher_var.name).get_tensor() + renamed_scope_var = teacher_scope.var(new_name).get_tensor() + renamed_scope_var.set(np.array(scope_var), place) + + # program var rename + renamed_var = teacher_program.global_block()._rename_var( + teacher_var.name, new_name) + + for teacher_var in teacher_program.list_vars(): + if teacher_var.name != 'fetch' and teacher_var.name != 'feed': + # student scope add var + student_scope_var = student_scope.var(teacher_var.name).get_tensor() + teacher_scope_var = teacher_scope.var(teacher_var.name).get_tensor() + student_scope_var.set(np.array(teacher_scope_var), place) + + # student program add var + new_var = student_program.global_block()._clone_variable( + teacher_var, force_persistable=False) + new_var.stop_gradient = True + + for block in teacher_program.blocks: + for op in block.ops: + if op.type != 'feed' and op.type != 'fetch': + inputs = {} + outputs = {} + attrs = {} + for input_name in op.input_names: + inputs[input_name] = [ + block.var(in_var_name) + for in_var_name in op.input(input_name) + ] + + for output_name in op.output_names: + outputs[output_name] = [ + block.var(out_var_name) + for out_var_name in op.output(output_name) + ] + for attr_name in op.attr_names: + attrs[attr_name] = op.attr(attr_name) + student_program.global_block().append_op( + type=op.type, inputs=inputs, outputs=outputs, attrs=attrs) + return student_program + + +def fsp_loss(teacher_var1_name, teacher_var2_name, student_var1_name, + student_var2_name, program): + """ + Combine variables from student model and teacher model by fsp-loss. + Args: + teacher_var1_name(str): The name of teacher_var1. + teacher_var2_name(str): The name of teacher_var2. Except for the + second dimension, all other dimensions should + be consistent with teacher_var1. + student_var1_name(str): The name of student_var1. + student_var2_name(str): The name of student_var2. Except for the + second dimension, all other dimensions should + be consistent with student_var1. + program(Program): The input distiller program. + Return(Variable): fsp distiller loss. + """ + teacher_var1 = program.global_block().var(teacher_var1_name) + teacher_var2 = program.global_block().var(teacher_var2_name) + student_var1 = program.global_block().var(student_var1_name) + student_var2 = program.global_block().var(student_var2_name) + teacher_fsp_matrix = fluid.layers.fsp_matrix(teacher_var1, teacher_var2) + student_fsp_matrix = fluid.layers.fsp_matrix(student_var1, student_var2) + fsp_loss = fluid.layers.reduce_mean( + fluid.layers.square(student_fsp_matrix - teacher_fsp_matrix)) + return fsp_loss + + +def l2_loss(teacher_var_name, student_var_name, program): + """ + Combine variables from student model and teacher model by l2-loss. + Args: + teacher_var_name(str): The name of teacher_var. + student_var_name(str): The name of student_var. + program(Program): The input distiller program. + Return(Variable): l2 distiller loss. 
+ """ + student_var = program.global_block().var(student_var_name) + teacher_var = program.global_block().var(teacher_var_name) + l2_loss = fluid.layers.reduce_mean( + fluid.layers.square(student_var - teacher_var)) + return l2_loss + + +def soft_label_loss(teacher_var_name, + student_var_name, + program, + teacher_temperature=1., + student_temperature=1.): + """ + Combine variables from student model and teacher model by soft-label-loss. + Args: + teacher_var_name(str): The name of teacher_var. + student_var_name(str): The name of student_var. + program(Program): The input distiller program. + teacher_temperature(float): Temperature used to divide + teacher_feature_map before softmax. default: 1.0 + student_temperature(float): Temperature used to divide + student_feature_map before softmax. default: 1.0 + Return(Variable): l2 distiller loss. + """ + student_var = program.global_block().var(student_var_name) + teacher_var = program.global_block().var(teacher_var_name) + student_var = fluid.layers.softmax(student_var / student_temperature) + teacher_var = fluid.layers.softmax(teacher_var / teacher_temperature) + teacher_var.stop_gradient = True + soft_label_loss = fluid.layers.reduce_mean( + fluid.layers.cross_entropy( + student_var, teacher_var, soft_label=True)) + return soft_label_loss + + +def loss(program, loss_func, **kwargs): + """ + Combine variables from student model and teacher model by self defined loss. + Args: + program(Program): The input distiller program. + loss_func(function): The user self defined loss function. + Return(Variable): self defined distiller loss. + """ + func_parameters = {} + for item in kwargs.items(): + if isinstance(item[1], str): + func_parameters.setdefault(item[0], + program.global_block().var(item[1])) + else: + func_parameters.setdefault(item[0], item[1]) + loss = loss_func(**func_parameters) + return loss diff --git a/paddleslim/nas/sa_nas.py b/paddleslim/nas/sa_nas.py index bbee0d8db641c5b61d520e5a8043721893e86ef5..f57caaa6beb6fec59b618a689b44652f0cf259fc 100644 --- a/paddleslim/nas/sa_nas.py +++ b/paddleslim/nas/sa_nas.py @@ -15,6 +15,7 @@ import socket import logging import numpy as np +import hashlib import paddle.fluid as fluid from ..core import VarWrapper, OpWrapper, GraphWrapper from ..common import SAController @@ -33,98 +34,71 @@ _logger = get_logger(__name__, level=logging.INFO) class SANAS(object): def __init__(self, configs, - max_flops=None, - max_latency=None, - server_addr=("", 0), + server_addr=("", 8881), init_temperature=100, reduce_rate=0.85, - max_try_number=300, - max_client_num=10, search_steps=300, key="sa_nas", - is_server=True): + is_server=False): """ Search a group of ratios used to prune program. Args: configs(list): A list of search space configuration with format (key, input_size, output_size, block_num). `key` is the name of search space with data type str. `input_size` and `output_size` are input size and output size of searched sub-network. `block_num` is the number of blocks in searched network. - max_flops(int): The max flops of searched network. None means no constrains. Default: None. - max_latency(float): The max latency of searched network. None means no constrains. Default: None. server_addr(tuple): A tuple of server ip and server port for controller server. init_temperature(float): The init temperature used in simulated annealing search strategy. reduce_rate(float): The decay rate used in simulated annealing search strategy. - max_try_number(int): The max number of trying to generate legal tokens. 
- max_client_num(int): The max number of connections of controller server. search_steps(int): The steps of searching. key(str): Identity used in communication between controller server and clients. is_server(bool): Whether current host is controller server. Default: True. """ - + if not is_server: + assert server_addr[ + 0] != "", "You should set the IP and port of server when is_server is False." self._reduce_rate = reduce_rate self._init_temperature = init_temperature - self._max_try_number = max_try_number self._is_server = is_server - self._max_flops = max_flops - self._max_latency = max_latency - self._configs = configs - - factory = SearchSpaceFactory() - self._search_space = factory.get_search_space(configs) - init_tokens = self._search_space.init_tokens() - range_table = self._search_space.range_table() - range_table = (len(range_table) * [0], range_table) - - print range_table - - controller = SAController(range_table, self._reduce_rate, - self._init_temperature, self._max_try_number, - init_tokens, self._constrain_func) + self._keys = hashlib.md5(str(self._configs)).hexdigest() server_ip, server_port = server_addr if server_ip == None or server_ip == "": server_ip = self._get_host_ip() - self._controller_server = ControllerServer( - controller=controller, - address=(server_ip, server_port), - max_client_num=max_client_num, - search_steps=search_steps, - key=key) - # create controller server if self._is_server: + factory = SearchSpaceFactory() + self._search_space = factory.get_search_space(configs) + init_tokens = self._search_space.init_tokens() + range_table = self._search_space.range_table() + range_table = (len(range_table) * [0], range_table) + _logger.info("range table: {}".format(range_table)) + controller = SAController( + range_table, + self._reduce_rate, + self._init_temperature, + max_try_times=None, + init_tokens=init_tokens, + constrain_func=None) + + max_client_num = 100 + self._controller_server = ControllerServer( + controller=controller, + address=(server_ip, server_port), + max_client_num=max_client_num, + search_steps=search_steps, + key=self._key) self._controller_server.start() self._controller_client = ControllerClient( - self._controller_server.ip(), - self._controller_server.port(), - key=key) + server_ip, server_port, key=self._key) self._iter = 0 def _get_host_ip(self): return socket.gethostbyname(socket.gethostname()) - def _constrain_func(self, tokens): - if (self._max_flops is None) and (self._max_latency is None): - return True - archs = self._search_space.token2arch(tokens) - main_program = fluid.Program() - startup_program = fluid.Program() - with fluid.program_guard(main_program, startup_program): - i = 0 - for config, arch in zip(self._configs, archs): - input_size = config[1]["input_size"] - input = fluid.data( - name="data_{}".format(i), - shape=[None, 3, input_size, input_size], - dtype="float32") - output = arch(input) - i += 1 - return flops(main_program) < self._max_flops - def next_archs(self): """ Get next network architectures. @@ -144,4 +118,5 @@ class SANAS(object): bool: True means updating successfully while false means failure. 
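The demo and unit test in this patch drive the slimmed-down SANAS interface roughly as below; configs comes from the search-space definition and evaluate() stands in for a user-supplied build/train/score step:

from paddleslim.nas import SANAS

# Single-machine run: this process also hosts the controller server (is_server=True).
sa_nas = SANAS(configs, server_addr=("", 8881), search_steps=100, is_server=True)
for step in range(100):
    archs = sa_nas.next_archs()   # sub-network builders proposed by the controller
    score = evaluate(archs)       # user-defined: build, briefly train and score the candidate
    sa_nas.reward(score)          # report the score; the client now also sends its step counter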
""" self._iter += 1 - return self._controller_client.update(self._current_tokens, score) + return self._controller_client.update(self._current_tokens, score, + self._iter) diff --git a/paddleslim/prune/__init__.py b/paddleslim/prune/__init__.py index bb615b9dfca03ed2b289f902f6d75c73543f6fb2..b012254170d4d63bf24fcccaf8fa5f3eaeccac11 100644 --- a/paddleslim/prune/__init__.py +++ b/paddleslim/prune/__init__.py @@ -19,9 +19,15 @@ import controller_server from controller_server import * import controller_client from controller_client import * +import sensitive_pruner +from sensitive_pruner import * +import sensitive +from sensitive import * __all__ = [] __all__ += pruner.__all__ __all__ += auto_pruner.__all__ __all__ += controller_server.__all__ __all__ += controller_client.__all__ +__all__ += sensitive_pruner.__all__ +__all__ += sensitive.__all__ diff --git a/paddleslim/prune/auto_pruner.py b/paddleslim/prune/auto_pruner.py index fba8c11170f3fbf2eddbe15942dc642ad448658b..8420d0c1b5d6ca1d0401ba249ebfa980037907d0 100644 --- a/paddleslim/prune/auto_pruner.py +++ b/paddleslim/prune/auto_pruner.py @@ -42,7 +42,7 @@ class AutoPruner(object): server_addr=("", 0), init_temperature=100, reduce_rate=0.85, - max_try_number=300, + max_try_times=300, max_client_num=10, search_steps=300, max_ratios=[0.9], @@ -66,7 +66,7 @@ class AutoPruner(object): server_addr(tuple): A tuple of server ip and server port for controller server. init_temperature(float): The init temperature used in simulated annealing search strategy. reduce_rate(float): The decay rate used in simulated annealing search strategy. - max_try_number(int): The max number of trying to generate legal tokens. + max_try_times(int): The max number of trying to generate legal tokens. max_client_num(int): The max number of connections of controller server. search_steps(int): The steps of searching. max_ratios(float|list): Max ratios used to pruned parameters in `params`. List means max ratios for each parameter in `params`. 
@@ -88,7 +88,7 @@ class AutoPruner(object): self._pruned_latency = pruned_latency self._reduce_rate = reduce_rate self._init_temperature = init_temperature - self._max_try_number = max_try_number + self._max_try_times = max_try_times self._is_server = is_server self._range_table = self._get_range_table(min_ratios, max_ratios) @@ -110,7 +110,7 @@ class AutoPruner(object): init_tokens = self._ratios2tokens(self._init_ratios) _logger.info("range table: {}".format(self._range_table)) controller = SAController(self._range_table, self._reduce_rate, - self._init_temperature, self._max_try_number, + self._init_temperature, self._max_try_times, init_tokens, self._constrain_func) server_ip, server_port = server_addr @@ -212,7 +212,7 @@ class AutoPruner(object): self._restore(self._scope) self._param_backup = {} tokens = self._ratios2tokens(self._current_ratios) - self._controller_client.update(tokens, score) + self._controller_client.update(tokens, score, self._iter) self._iter += 1 def _restore(self, scope): diff --git a/paddleslim/analysis/sensitive.py b/paddleslim/prune/sensitive.py similarity index 94% rename from paddleslim/analysis/sensitive.py rename to paddleslim/prune/sensitive.py index 09dd2a875ae21caf64034cf79421d7cc1661b817..ca9ee6f4ae7a790481a8e3b46c03cf37d096b3dc 100644 --- a/paddleslim/analysis/sensitive.py +++ b/paddleslim/prune/sensitive.py @@ -17,6 +17,7 @@ import os import logging import pickle import numpy as np +import paddle.fluid as fluid from ..core import GraphWrapper from ..common import get_logger from ..prune import Pruner @@ -27,13 +28,12 @@ __all__ = ["sensitivity"] def sensitivity(program, - scope, place, param_names, eval_func, sensitivities_file=None, step_size=0.2): - + scope = fluid.global_scope() graph = GraphWrapper(program) sensitivities = _load_sensitivities(sensitivities_file) @@ -55,7 +55,7 @@ def sensitivity(program, ratio += step_size continue if baseline is None: - baseline = eval_func(graph.program, scope) + baseline = eval_func(graph.program) param_backup = {} pruner = Pruner() @@ -68,7 +68,7 @@ def sensitivity(program, lazy=True, only_graph=False, param_backup=param_backup) - pruned_metric = eval_func(pruned_program, scope) + pruned_metric = eval_func(pruned_program) loss = (baseline - pruned_metric) / baseline _logger.info("pruned param: {}; {}; loss={}".format(name, ratio, loss)) @@ -81,7 +81,7 @@ def sensitivity(program, param_t = scope.find_var(param_name).get_tensor() param_t.set(param_backup[param_name], place) ratio += step_size - return sensitivities + return sensitivities def _load_sensitivities(sensitivities_file): diff --git a/paddleslim/prune/sensitive_pruner.py b/paddleslim/prune/sensitive_pruner.py new file mode 100644 index 0000000000000000000000000000000000000000..6213382fa9d47bae81c718f9f23c3e34146e05e4 --- /dev/null +++ b/paddleslim/prune/sensitive_pruner.py @@ -0,0 +1,207 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
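Note that with the move into paddleslim.prune, sensitivity() drops its scope argument and the eval callback now receives only the program; a hedged sketch, assuming val_program, place, and a test() helper like the demo's, with a hypothetical parameter name:

from paddleslim.prune import sensitivity

def eval_func(program):
    # The callback now takes only the candidate program (scope defaults to fluid.global_scope()).
    return test(0, program)

sens = sensitivity(
    val_program,
    place,
    ["conv2_1_sep_weights"],   # hypothetical parameter name
    eval_func,
    sensitivities_file="sensitivities_0.data",
    step_size=0.2)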
+ +import os +import logging +import copy +from scipy.optimize import leastsq +import numpy as np +import paddle.fluid as fluid +from ..common import get_logger +from .sensitive import sensitivity +from ..analysis import flops +from .pruner import Pruner + +__all__ = ["SensitivePruner"] + +_logger = get_logger(__name__, level=logging.INFO) + + +class SensitivePruner(object): + def __init__(self, place, eval_func, scope=None, checkpoints=None): + """ + Pruner used to prune parameters iteratively according to sensitivities of parameters in each step. + Args: + place(fluid.CUDAPlace | fluid.CPUPlace): The device place where program execute. + eval_func(function): A callback function used to evaluate pruned program. The argument of this function is pruned program. And it return a score of given program. + scope(fluid.scope): The scope used to execute program. + """ + self._eval_func = eval_func + self._iter = 0 + self._place = place + self._scope = fluid.global_scope() if scope is None else scope + self._pruner = Pruner() + self._checkpoints = checkpoints + + def save_checkpoint(self, train_program, eval_program): + checkpoint = os.path.join(self._checkpoints, str(self._iter - 1)) + exe = fluid.Executor(self._place) + fluid.io.save_persistables( + exe, checkpoint, main_program=train_program, filename="__params__") + + with open(checkpoint + "/main_program", "wb") as f: + f.write(train_program.desc.serialize_to_string()) + with open(checkpoint + "/eval_program", "wb") as f: + f.write(eval_program.desc.serialize_to_string()) + + def restore(self, checkpoints=None): + + exe = fluid.Executor(self._place) + checkpoints = self._checkpoints if checkpoints is None else checkpoints + print("check points: {}".format(checkpoints)) + main_program = None + eval_program = None + if checkpoints is not None: + cks = [dir for dir in os.listdir(checkpoints)] + if len(cks) > 0: + latest = max([int(ck) for ck in cks]) + latest_ck_path = os.path.join(checkpoints, str(latest)) + self._iter += 1 + + with open(latest_ck_path + "/main_program", "rb") as f: + program_desc_str = f.read() + main_program = fluid.Program.parse_from_string( + program_desc_str) + print main_program + + with open(latest_ck_path + "/eval_program", "rb") as f: + program_desc_str = f.read() + eval_program = fluid.Program.parse_from_string( + program_desc_str) + + with fluid.scope_guard(self._scope): + fluid.io.load_persistables(exe, latest_ck_path, + main_program, "__params__") + print("load checkpoint from: {}".format(latest_ck_path)) + print("flops of eval program: {}".format(flops(eval_program))) + return main_program, eval_program, self._iter + + def prune(self, train_program, eval_program, params, pruned_flops): + """ + Pruning parameters of training and evaluation network by sensitivities in current step. + Args: + train_program(fluid.Program): The training program to be pruned. + eval_program(fluid.Program): The evaluation program to be pruned. And it is also used to calculate sensitivities of parameters. + params(list): The parameters to be pruned. + pruned_flops(float): The ratio of FLOPS to be pruned in current step. + Return: + tuple: A tuple of pruned training program and pruned evaluation program. 
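demo/sensitive_prune/train.py above exercises this class roughly as follows; place, params, val_program, eval_func and the train()/test() helpers are assumed to be set up as in that demo:

import paddle.fluid as fluid
from paddleslim.prune import SensitivePruner

pruner = SensitivePruner(place, eval_func, checkpoints="./checkpoints")
# restore() resumes from the newest checkpoint, or returns (None, None, 0) on a fresh run.
train_prog, val_prog, start_iter = pruner.restore()
if train_prog is None:   # fresh run: fall back to the unpruned programs
    train_prog = fluid.default_main_program()
if val_prog is None:
    val_prog = val_program

for it in range(start_iter, 6):
    # Each step targets pruning about 10% of FLOPS, with ratios picked from the sensitivity curves.
    train_prog, val_prog = pruner.prune(train_prog, val_prog, params, 0.1)
    train(it, train_prog)
    test(it, val_prog)
    pruner.save_checkpoint(train_prog, val_prog)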
+ """ + _logger.info("Pruning: {}".format(params)) + sensitivities_file = "sensitivities_iter{}.data".format(self._iter) + with fluid.scope_guard(self._scope): + sensitivities = sensitivity( + eval_program, + self._place, + params, + self._eval_func, + sensitivities_file=sensitivities_file, + step_size=0.1) + print sensitivities + _, ratios = self._get_ratios_by_sensitive(sensitivities, pruned_flops, + eval_program) + + pruned_program = self._pruner.prune( + train_program, + self._scope, + params, + ratios, + place=self._place, + only_graph=False) + pruned_val_program = None + if eval_program is not None: + pruned_val_program = self._pruner.prune( + eval_program, + self._scope, + params, + ratios, + place=self._place, + only_graph=True) + self._iter += 1 + return pruned_program, pruned_val_program + + def _get_ratios_by_sensitive(self, sensitivities, pruned_flops, + eval_program): + """ + Search a group of ratios for pruning target flops. + """ + + def func(params, x): + a, b, c, d = params + return a * x * x * x + b * x * x + c * x + d + + def error(params, x, y): + return func(params, x) - y + + def slove_coefficient(x, y): + init_coefficient = [10, 10, 10, 10] + coefficient, loss = leastsq(error, init_coefficient, args=(x, y)) + return coefficient + + min_loss = 0. + max_loss = 0. + + # step 1: fit curve by sensitivities + coefficients = {} + for param in sensitivities: + losses = np.array([0] * 5 + sensitivities[param]['loss']) + precents = np.array([0] * 5 + sensitivities[param][ + 'pruned_percent']) + coefficients[param] = slove_coefficient(precents, losses) + loss = np.max(losses) + max_loss = np.max([max_loss, loss]) + + # step 2: Find a group of ratios by binary searching. + base_flops = flops(eval_program) + ratios = [] + max_times = 20 + while min_loss < max_loss and max_times > 0: + loss = (max_loss + min_loss) / 2 + _logger.info( + '-----------Try pruned ratios while acc loss={}-----------'. 
+ format(loss)) + ratios = [] + # step 2.1: Get ratios according to current loss + for param in sensitivities: + coefficient = copy.deepcopy(coefficients[param]) + coefficient[-1] = coefficient[-1] - loss + roots = np.roots(coefficient) + for root in roots: + min_root = 1 + if np.isreal(root) and root > 0 and root < 1: + selected_root = min(root.real, min_root) + ratios.append(selected_root) + _logger.info('Pruned ratios={}'.format( + [round(ratio, 3) for ratio in ratios])) + # step 2.2: Pruning by current ratios + param_shape_backup = {} + pruned_program = self._pruner.prune( + eval_program, + None, # scope + sensitivities.keys(), + ratios, + None, # place + only_graph=True) + pruned_ratio = 1 - (float(flops(pruned_program)) / base_flops) + _logger.info('Pruned flops: {:.4f}'.format(pruned_ratio)) + + # step 2.3: Check whether current ratios is enough + if abs(pruned_ratio - pruned_flops) < 0.015: + break + if pruned_ratio > pruned_flops: + max_loss = loss + else: + min_loss = loss + max_times -= 1 + return sensitivities.keys(), ratios diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py index 8ea9fbe32ee3f8617d9f00a1ce097b715957163e..254cf4958643ef5e4d4e6cd625028baef964e222 100755 --- a/paddleslim/quant/quanter.py +++ b/paddleslim/quant/quanter.py @@ -20,6 +20,7 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass from paddle.fluid.contrib.slim.quantization import TransformForMobilePass +from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass from paddle.fluid import core @@ -186,19 +187,68 @@ def quant_aware(program, place, config, scope=None, for_test=False): return quant_program -def quant_post(program, place, config, scope=None): +def quant_post(executor, + model_dir, + quantize_model_path, + sample_generator, + model_filename=None, + params_filename=None, + batch_size=16, + batch_nums=None, + scope=None, + algo='KL', + quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"]): """ - add quantization ops in program. the program returned is not trainable. + The function utilizes post training quantization method to quantize the + fp32 model. It uses calibrate data to calculate the scale factor of + quantized variables, and inserts fake quant/dequant op to obtain the + quantized model. + Args: - program(fluid.Program): program - scope(fluid.Scope): the scope to store var, it's should be the value of program's scope, usually it's fluid.global_scope(). - place(fluid.CPUPlace or fluid.CUDAPlace): place - config(dict): configs for quantization, default values are in quant_config_default dict. - for_test: is for test program. - Return: - fluid.Program: the quantization program is not trainable. + executor(fluid.Executor): The executor to load, run and save the + quantized model. + model_dir(str): The path of fp32 model that will be quantized, and + the model and params that saved by fluid.io.save_inference_model + are under the path. + quantize_model_path(str): The path to save quantized model using api + fluid.io.save_inference_model. + sample_generator(Python Generator): The sample generator provides + calibrate data for DataLoader, and it only returns a sample every time. + model_filename(str, optional): The name of model file. If parameters + are saved in separate files, set it as 'None'. Default is 'None'. 
+ params_filename(str, optional): The name of params file. + When all parameters are saved in a single file, set it + as filename. If parameters are saved in separate files, + set it as 'None'. Default is 'None'. + batch_size(int, optional): The batch size of DataLoader, default is 16. + batch_nums(int, optional): If batch_nums is not None, the number of calibrate + data is 'batch_size*batch_nums'. If batch_nums is None, use all data + generated by sample_generator as calibrate data. + scope(fluid.Scope, optional): The scope to run program, use it to load + and save variables. If scope is None, will use fluid.global_scope(). + algo(str, optional): If algo=KL, use KL-divergenc method to + get the more precise scale factor. If algo='direct', use + abs_max method to get the scale factor. Default is 'KL'. + quantizable_op_type(list[str], optional): The list of op types + that will be quantized. Default is ["conv2d", "depthwise_conv2d", + "mul"]. + Returns: + None """ - pass + post_training_quantization = PostTrainingQuantization( + executor=executor, + sample_generator=sample_generator, + model_dir=model_dir, + model_filename=model_filename, + params_filename=params_filename, + batch_size=batch_size, + batch_nums=batch_nums, + scope=scope, + algo=algo, + quantizable_op_type=quantizable_op_type, + is_full_quantize=False) + post_training_quantization.quantize() + post_training_quantization.save_quantized_model(quantize_model_path) def convert(program, place, config, scope=None, save_int8=False): diff --git a/tests/test_sa_nas.py b/tests/test_sa_nas.py index c1bcd08dadf87e24f31af1a525f67aa9a92bd26e..5666e1410a820c09bc10fa0b10d282434c7837fe 100644 --- a/tests/test_sa_nas.py +++ b/tests/test_sa_nas.py @@ -40,8 +40,7 @@ class TestSANAS(unittest.TestCase): base_flops = flops(main_program) search_steps = 3 - sa_nas = SANAS( - configs, max_flops=base_flops, search_steps=search_steps) + sa_nas = SANAS(configs, search_steps=search_steps, is_server=True) for i in range(search_steps): archs = sa_nas.next_archs()
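For reference, the reworked post-training quantization entry point is called roughly like this; the paths are placeholders and the calibration reader must yield samples matching the saved inference model's feed vars:

import numpy as np
import paddle.fluid as fluid
from paddleslim.quant.quanter import quant_post

exe = fluid.Executor(fluid.CPUPlace())

def sample_generator():
    # Placeholder calibration reader: yields one sample at a time;
    # replace the random array with real data of the model's input shape.
    for _ in range(160):
        yield [np.random.random([3, 224, 224]).astype("float32")]

quant_post(
    executor=exe,
    model_dir="./fp32_model",              # inference model saved via fluid.io.save_inference_model
    quantize_model_path="./quant_model",   # where the quantized inference model is written
    sample_generator=sample_generator,
    batch_size=16,
    batch_nums=10,                         # calibrate on 16 * 10 samples
    algo='KL')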