# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.optimizer as optimizer
from ..quant.quanter import quant_aware, _quant_config_default, _parse_configs, pact, get_pact_optimizer
from ..dist import *
from ..common.recover_program import recover_inference_program, _remove_fetch_node
from ..common import get_logger
from .strategy_config import ProgramInfo

_logger = get_logger(__name__, level=logging.INFO)

__all__ = [
    'build_distill_program', 'build_quant_program', 'build_prune_program',
    'remove_unused_var_nodes'
]


def _create_optimizer(train_config):
    """Create an optimizer from ``train_config``."""
    opt = getattr(optimizer, train_config.get('optimizer') or
                  'SGD')  ### default optimizer is SGD
    if 'optim_args' in train_config:
        if train_config['optim_args'] is not None and \
                'grad_clip' in train_config['optim_args'] and \
                train_config['optim_args']['grad_clip'] is not None:
            grad_clip = getattr(
                paddle.nn, train_config['optim_args']['grad_clip'])(
                    **train_config['optim_args']['grad_clip_args'])
            train_config['optim_args'].pop('grad_clip')
            train_config['optim_args'].pop('grad_clip_args')
        else:
            grad_clip = None
            if 'grad_clip' in train_config['optim_args'] and train_config[
                    'optim_args']['grad_clip'] is None:
                train_config['optim_args'].pop('grad_clip')
                train_config['optim_args'].pop('grad_clip_args')
    else:
        train_config['optim_args'] = {}
        grad_clip = None

    op = opt(learning_rate=train_config["learning_rate"],
             grad_clip=grad_clip,
             **train_config['optim_args'])
    return op


def _parse_distill_loss(distill_node_pair,
                        distill_loss='l2_loss',
                        distill_lambda=1.0):
    """Parse the distillation loss config into a weighted total loss."""
    loss_dist = 0.0
    losses = []
    # Normalize the single-pair form into the list-of-lists form.
    if isinstance(distill_node_pair[0], str):
        assert isinstance(distill_loss, str)
        assert isinstance(distill_lambda, float)
        distill_node_pair = [distill_node_pair]
        distill_loss = [distill_loss]
        distill_lambda = [distill_lambda]

    assert len(distill_node_pair) == len(distill_loss)
    assert len(distill_node_pair) == len(distill_lambda)
    for node, loss, lam in zip(distill_node_pair, distill_loss, distill_lambda):
        tmp_loss = 0.0
        _logger.info(
            "train config.distill_node_pair: {}, distill_loss: {}, distill_lambda: {}".
            format(node, loss, lam))
        assert len(node) % 2 == 0, \
            "distill_node_pair config is wrong, its length must be an even number"
        for i in range(len(node) // 2):
            # ``loss`` names a loss function imported from ``..dist``,
            # e.g. ``l2_loss``.
            tmp_loss += eval(loss)(node[i * 2], node[i * 2 + 1])
        loss_dist += lam * tmp_loss
        losses.append(tmp_loss)
    return loss_dist, losses
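
# Example (illustrative only; the node names below are hypothetical and in
# practice come from the compression config's ``distill_node_pair`` entry):
# a single teacher/student tensor pair distilled with the default l2 loss.
#
#     loss, losses = _parse_distill_loss(
#         ['teacher_relu_1.tmp_0', 'relu_1.tmp_0'],
#         distill_loss='l2_loss',
#         distill_lambda=1.0)
#
# Because the flat pair list is normalized to a list of lists, several pair
# groups with per-group losses and lambdas are equally valid.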

def _load_program_and_merge(executor,
                            place,
                            train_program,
                            config,
                            model_dir,
                            model_filename,
                            params_filename,
                            teacher_idx=None,
                            feed_target_names=None):
    scope = paddle.static.global_scope()
    new_scope = paddle.static.Scope()

    if params_filename == 'None':
        params_filename = None

    # First try the legacy fluid API (directory format), then fall back to
    # the paddle.static API (path-prefix format).
    try:
        with paddle.static.scope_guard(new_scope):
            [teacher_program, teacher_feed_target_names,
             teacher_fetch_targets] = paddle.fluid.io.load_inference_model(
                 dirname=model_dir,
                 model_filename=model_filename,
                 params_filename=params_filename,
                 executor=executor)
    except Exception:
        with paddle.static.scope_guard(new_scope):
            [teacher_program, teacher_feed_target_names,
             teacher_fetch_targets] = paddle.static.load_inference_model(
                 path_prefix=model_dir, executor=executor)

    _remove_fetch_node(teacher_program)

    if teacher_idx is None or teacher_idx == 1:
        test_program = train_program.clone(for_test=True)

    data_name_map = {}

    if 'merge_feed' not in config or config['merge_feed'] == True:
        assert len(feed_target_names) == len(teacher_feed_target_names), \
            "the number of feed nodes in the teacher model is not equal to the student model"
        for i, name in enumerate(feed_target_names):
            data_name_map[teacher_feed_target_names[i]] = name

    if teacher_idx is None:
        teacher_name_prefix = 'teacher_'
    else:
        teacher_name_prefix = 'teacher{}_'.format(str(teacher_idx))

    merge(
        teacher_program,
        train_program,
        data_name_map,
        place,
        teacher_scope=new_scope,
        name_prefix=teacher_name_prefix,
        merge_feed=config.get('merge_feed', True))
    if teacher_idx is None or teacher_idx == 1:
        return train_program, test_program, data_name_map
    else:
        return train_program, None, data_name_map
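
# A minimal usage sketch (the paths and filenames are hypothetical): loading
# the first (or only) teacher also clones the student test program.
#
#     train_program, test_program, data_name_map = _load_program_and_merge(
#         exe, place, train_program, config,
#         './teacher_model', 'model.pdmodel', 'model.pdiparams',
#         teacher_idx=None, feed_target_names=feed_target_names)
#
# Teacher variables are merged into ``train_program`` under the ``teacher_``
# (or ``teacher{idx}_``) prefix, so distillation losses can pair student and
# teacher tensors by name.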

def build_distill_program(executor,
                          place,
                          config,
                          train_config,
                          train_program_info=None,
                          pruner=None,
                          dist_strategy=None,
                          default_distill_node_pair=None):
    """Build the distillation program from an inference model."""
    startup_program = paddle.static.Program()
    if train_program_info is None:
        [train_program, feed_target_names, fetch_targets] = \
            paddle.static.load_inference_model(
                path_prefix=config["model_dir"] if "model_dir" in config else config["model_path_prefix"],
                executor=executor)
        train_program = recover_inference_program(train_program)
    else:
        train_program = train_program_info.program
        feed_target_names = train_program_info.feed_target_names
        fetch_targets = train_program_info.fetch_targets

    teacher_model_dir = config[
        "teacher_model_dir"] if "teacher_model_dir" in config else config[
            "teacher_model_path_prefix"]
    if isinstance(teacher_model_dir, list):
        for tea_idx in range(len(teacher_model_dir)):
            model_filename = config["teacher_model_filename"][
                tea_idx] if "teacher_model_filename" in config else None
            params_filename = config["teacher_params_filename"][
                tea_idx] if "teacher_params_filename" in config else None
            if tea_idx == 0:
                train_program, test_program, data_name_map = _load_program_and_merge(
                    executor,
                    place,
                    train_program,
                    config,
                    teacher_model_dir[tea_idx],
                    model_filename,
                    params_filename,
                    teacher_idx=(tea_idx + 1),
                    feed_target_names=feed_target_names)
            else:
                train_program, _, data_name_map = _load_program_and_merge(
                    executor,
                    place,
                    train_program,
                    config,
                    teacher_model_dir[tea_idx],
                    model_filename,
                    params_filename,
                    teacher_idx=(tea_idx + 1),
                    feed_target_names=feed_target_names)
    else:
        model_filename = config[
            "teacher_model_filename"] if "teacher_model_filename" in config else None
        params_filename = config[
            "teacher_params_filename"] if "teacher_params_filename" in config else None
        train_program, test_program, data_name_map = _load_program_and_merge(
            executor,
            place,
            train_program,
            config,
            teacher_model_dir,
            model_filename,
            params_filename,
            teacher_idx=None,
            feed_target_names=feed_target_names)

    # All feed nodes must have stop_gradient set to False for the PACT
    # quantization algorithm to work.
    for var in train_program.list_vars():
        if var.name in data_name_map.values() or var.name in data_name_map.keys():
            var.stop_gradient = False

    train_fetch_list = []
    with paddle.static.program_guard(train_program, startup_program):
        with paddle.utils.unique_name.guard('merge'):
            optimizer = _create_optimizer(train_config)

            if train_config.get('use_fleet'):
                optimizer = fleet.distributed_optimizer(optimizer,
                                                        dist_strategy)
            else:
                if train_config.get('amp_config') is not None:
                    custom_white_list = train_config['amp_config'].get(
                        'custom_white_list', None)
                    if custom_white_list is not None:
                        train_config['amp_config'].pop('custom_white_list')

                    custom_black_list = train_config['amp_config'].get(
                        'custom_black_list', None)
                    if custom_black_list is not None:
                        train_config['amp_config'].pop('custom_black_list')

                    custom_black_varnames = train_config['amp_config'].get(
                        'custom_black_varnames', None)
                    if custom_black_varnames is not None:
                        train_config['amp_config'].pop('custom_black_varnames')

                    amp_list = paddle.static.amp.CustomOpLists(
                        custom_white_list=custom_white_list,
                        custom_black_list=custom_black_list,
                        custom_black_varnames=custom_black_varnames)
                    optimizer = paddle.static.amp.decorate(
                        optimizer=optimizer,
                        amp_lists=amp_list,
                        init_loss_scaling=128.0,
                        use_dynamic_loss_scaling=True,
                        **train_config['amp_config'])

            distill_loss, losses = _parse_distill_loss(
                config.get('distill_node_pair') or default_distill_node_pair,
                config.get('distill_loss') or
                'l2_loss',  ### default loss is l2_loss
                config.get('distill_lambda') or
                1.0)  ### default lambda is 1.0
            loss = paddle.mean(distill_loss)
            loss.stop_gradient = False

            if 'prune_algo' in config:  ### prune & asp
                if config['prune_algo'] == 'asp':
                    optimizer = pruner.decorate(optimizer)
                optimizer.minimize(loss)
            elif 'prune_strategy' in config:  ### unstructured prune
                optimizer.minimize(loss, no_grad_set=pruner.no_grad_set)
            else:
                optimizer.minimize(loss)

            train_fetch_list.append(loss)

    train_program_info = ProgramInfo(startup_program, train_program,
                                     feed_target_names, train_fetch_list,
                                     optimizer)
    test_program_info = ProgramInfo(startup_program, test_program,
                                    feed_target_names, fetch_targets)
    return train_program_info, test_program_info


def build_quant_program(executor, place, config, train_program_info,
                        test_program_info):
    scope = paddle.static.global_scope()

    assert isinstance(config, dict), "quant config must be dict"

    default_config = _quant_config_default
    default_config.update(config)
    config = _parse_configs(default_config)

    use_pact = config["use_pact"]
    if use_pact:
        act_preprocess_func = pact
        optimizer_func = get_pact_optimizer
        pact_executor = executor
    else:
        act_preprocess_func = None
        optimizer_func = None
        pact_executor = None

    test_program = quant_aware(
        test_program_info.program,
        place,
        config,
        scope=scope,
        act_preprocess_func=None,
        optimizer_func=None,
        executor=None,
        for_test=True)

    train_program = quant_aware(
        train_program_info.program,
        place,
        config,
        scope=scope,
        act_preprocess_func=act_preprocess_func,
        optimizer_func=optimizer_func,
        executor=pact_executor,
        for_test=False,
        return_program=True)

    train_program_info.program = train_program
    test_program_info.program = test_program
    return train_program_info, test_program_info, config


def _get_label_info(dataloader, feed_target_names):
    """Infer the label's name, dtype and shape from the first batch."""
    label_info = {}
    for data in dataloader():
        for key, value in data[0].items():
            if key in feed_target_names:
                continue
            label_info['name'] = key
            label_info['dtype'] = np.array(value).dtype
            label_info['shape'] = list(np.array(value).shape)
            label_info['shape'][0] = -1
            break
        break
    return label_info
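
# For a dataloader that yields ``[{'input_ids': ..., 'label': ...}]`` with
# ``feed_target_names == ['input_ids']`` (names are illustrative), the helper
# above would return something like:
#
#     {'name': 'label', 'dtype': dtype('int64'), 'shape': [-1, 1]}
#
# Only the first batch is inspected, and the leading batch dimension is
# replaced with -1.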

def build_prune_program(executor,
                        place,
                        config,
                        train_program_info,
                        strategy,
                        patterns,
                        eval_dataloader=None):
    if 'unstructure' in strategy:
        from ..prune.unstructured_pruner import UnstructuredPruner, GMPUnstructuredPruner
        if config["prune_strategy"] is None:
            pruner = UnstructuredPruner(
                train_program_info.program,
                mode=config['prune_mode'],
                ratio=config['pruned_ratio'],
                threshold=config['threshold'],
                prune_params_type=config['prune_params_type'],
                place=place,
                local_sparsity=config['local_sparsity'])
        elif config["prune_strategy"] == "gmp":
            pruner = GMPUnstructuredPruner(
                train_program_info.program,
                ratio=config['pruned_ratio'],
                prune_params_type=config['prune_params_type'],
                place=place,
                local_sparsity=config['local_sparsity'],
                configs=config['gmp_config'])
    else:
        if config['prune_algo'] == 'prune':
            from ..prune import Pruner
            pruner = Pruner(config["criterion"])
            params = []
            ### TODO(ceci3): set default prune weight
            for param in train_program_info.program.global_block(
            ).all_parameters():
                if config['prune_params_name'] is not None and \
                        param.name in config['prune_params_name']:
                    params.append(param.name)

            pruned_program, _, _ = pruner.prune(
                train_program_info.program,
                paddle.static.global_scope(),
                params=params,
                ratios=[config['pruned_ratio']] * len(params),
                place=place)
            train_program_info.program = pruned_program
        elif config['prune_algo'] == 'asp':
            from paddle.static import sparsity
            pruner = sparsity
            excluded_params_name = []
            ### TODO(ceci3): set default prune weight
            for param in train_program_info.program.global_block(
            ).all_parameters():
                if config['prune_params_name'] is not None and \
                        param.name not in config['prune_params_name']:
                    excluded_params_name.append(param.name)
            pruner.set_excluded_layers(train_program_info.program,
                                       excluded_params_name)
        elif config['prune_algo'] == 'transformer_pruner':
            from .transformer_pruner import TransformerPruner
            assert eval_dataloader is not None, \
                "transformer_pruner must set eval_dataloader"
            label_info = _get_label_info(eval_dataloader,
                                         train_program_info.feed_target_names)
            assert len(label_info) != 0, \
                "maybe something wrong in get label name from eval_dataloader, please check your eval_dataloader"
            pruner = TransformerPruner(
                executor,
                place,
                train_program_info.program,
                patterns,
                label_info,
                width_mult=(1.0 - config['pruned_ratio']),
                dataloader=eval_dataloader,
                fetch_targets=train_program_info.fetch_targets)
            pruned_program = pruner.prune()
            train_program_info.program = pruned_program
        else:
            raise NotImplementedError(
                "prune_algo must be one of [\"prune\", \"asp\", \"transformer_pruner\"], "
                "but \"{}\" is not supported.".format(config['prune_algo']))

    return pruner, train_program_info


def remove_unused_var_nodes(program):
    '''
    This function is called before saving a sparse model to remove redundant nodes.
    Args:
        program(paddle.static.Program): The sparse model to be saved.
    Returns:
        program(paddle.static.Program): The pruned sparse model.
    '''
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    graph = IrGraph(core.Graph(program.desc), for_test=True)
    removed_nodes = set()
    ops = graph.all_op_nodes()
    for op_node in ops:
        for input_node in op_node.inputs:
            if '_mask' in input_node.name():
                removed_nodes.add(op_node)
    graph.safe_remove_nodes(removed_nodes)
    program = graph.to_program()
    return program
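
# A minimal sketch of the intended call site (variable names are
# illustrative): strip the ``_mask`` ops produced by unstructured pruning
# just before exporting the sparse inference model.
#
#     pruned_program = remove_unused_var_nodes(inference_program)
#     paddle.static.save_inference_model(
#         './sparse_model', feed_vars, fetch_vars, exe,
#         program=pruned_program)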