From 304d7815e439e938f1a3a7349762e589e72c7a78 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Wed, 21 Dec 2022 10:41:56 +0800 Subject: [PATCH] optimize auto compress (#1550) * support sub block * support post-precess * update * fix * add unittest * revert prune * fix unittest * add unittest --- paddleslim/auto_compression/compressor.py | 2 +- .../create_compressed_program.py | 69 ++++++- .../auto_compression/transformer_pruner.py | 4 +- paddleslim/common/patterns.py | 6 +- paddleslim/common/patterns_common.py | 21 +- paddleslim/common/transformer_pattern.py | 2 +- paddleslim/dist/single_distiller.py | 194 +++++++++++++----- paddleslim/quant/quanter.py | 54 +++-- tests/act/test_act_qat.py | 97 +++++++++ 9 files changed, 358 insertions(+), 91 deletions(-) create mode 100644 tests/act/test_act_qat.py diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py index e098830f..d59d8d9e 100644 --- a/paddleslim/auto_compression/compressor.py +++ b/paddleslim/auto_compression/compressor.py @@ -550,7 +550,7 @@ class AutoCompression: train_program_info = self._compiled_program(train_program_info, strategy) test_program_info = self._compiled_program(test_program_info, - self._strategy) + strategy) return train_program_info, test_program_info def _compiled_program(self, program_info, strategy): diff --git a/paddleslim/auto_compression/create_compressed_program.py b/paddleslim/auto_compression/create_compressed_program.py index eacc39ff..60885c4c 100644 --- a/paddleslim/auto_compression/create_compressed_program.py +++ b/paddleslim/auto_compression/create_compressed_program.py @@ -84,6 +84,13 @@ def _create_optimizer(train_config): return opt, lr +def _find_var_from_program(program, var_name): + for block in program.blocks: + if block.has_var(var_name): + return block.var(var_name) + raise ValueError("var {} not in this program".format(var_name)) + + def _get_distill_node(student_program, config): node = config.get('node') if len(node) == 0: @@ -95,7 +102,7 @@ def _get_distill_node(student_program, config): else: test_node = node[0] try: - test_var = student_program.global_block().var(test_node) + test_var = _find_var_from_program(student_program, test_node) distill_node_pair = [] if isinstance(node[0], list): for n_list in node: @@ -113,6 +120,14 @@ def _get_distill_node(student_program, config): return node +def _get_target_node(distill_node): + targets = [] + for idx, node in enumerate(distill_node): + if idx % 2 != 0: + targets.append(node) + return targets + + def _parse_distill_loss(distill_node_pair, distill_loss='l2', distill_lambda=1.0): @@ -149,6 +164,7 @@ def _load_program_and_merge(executor, model_dir, model_filename, params_filename, + distill_node_pair, teacher_idx=None, feed_target_names=None): scope = paddle.static.global_scope() @@ -171,8 +187,8 @@ def _load_program_and_merge(executor, _remove_fetch_node(teacher_program) - if teacher_idx == None or teacher_idx == 1: - test_program = train_program.clone(for_test=True) + target_nodes = _get_target_node(distill_node_pair) + teacher_program = teacher_program._prune(target_nodes) data_name_map = {} @@ -196,9 +212,9 @@ def _load_program_and_merge(executor, name_prefix=teacher_name_prefix, merge_feed=merge_feed) if teacher_idx == None or teacher_idx == 1: - return train_program, test_program, data_name_map + return train_program, data_name_map else: - return train_program, None, data_name_map + return train_program, data_name_map def build_distill_program(executor, @@ -224,6 +240,38 @@ def 
build_distill_program(executor, distill_node_pair = _get_distill_node(train_program, config) or default_distill_node_pair + test_program = train_program.clone(for_test=True) + + target_nodes = _get_target_node(distill_node_pair) + + def _prepend_feed(block, feed_idx, feed_target_names): + for idx in feed_idx[::-1]: + block._remove_op(idx) + + feed_var = block.create_var( + name='feed', + type=paddle.framework.core.VarDesc.VarType.FEED_MINIBATCH, + persistable=True, ) + + for i, name in enumerate(feed_target_names): + out = block.var(name) + block._prepend_op( + type='feed', + inputs={'X': [feed_var]}, + outputs={'Out': [out]}, + attrs={'col': i}) + + judge_feed_pos = False + if train_program.desc.block(0).op(0).type() != 'feed': + judge_feed_pos = True + if judge_feed_pos: + feed_idx = [] + for op in train_program.global_block().ops: + if op.type == 'feed': + feed_idx.append(op.idx) + _prepend_feed(train_program.global_block(), feed_idx, feed_target_names) + train_program = train_program._prune(target_nodes) + teacher_model_dir = config[ "teacher_model_dir"] if "teacher_model_dir" in config else config[ "teacher_model_path_prefix"] @@ -234,7 +282,7 @@ def build_distill_program(executor, params_filename = config["teacher_params_filename"][ tea_idx] if "teacher_params_filename" in config else None if tea_idx == 0: - train_program, test_program, data_name_map = _load_program_and_merge( + train_program, data_name_map = _load_program_and_merge( executor, place, train_program, @@ -242,10 +290,11 @@ def build_distill_program(executor, teacher_model_dir[tea_idx], model_filename, params_filename, + distill_node_pair, teacher_idx=(tea_idx + 1), feed_target_names=feed_target_names) else: - train_program, _, data_name_map = _load_program_and_merge( + train_program, data_name_map = _load_program_and_merge( executor, place, train_program, @@ -253,6 +302,7 @@ def build_distill_program(executor, teacher_model_dir[tea_idx], model_filename, params_filename, + distill_node_pair, teacher_idx=(tea_idx + 1), feed_target_names=feed_target_names) @@ -261,7 +311,7 @@ def build_distill_program(executor, "teacher_model_filename"] if "teacher_model_filename" in config else None params_filename = config[ "teacher_params_filename"] if "teacher_params_filename" in config else None - train_program, test_program, data_name_map = _load_program_and_merge( + train_program, data_name_map = _load_program_and_merge( executor, place, train_program, @@ -269,6 +319,7 @@ def build_distill_program(executor, teacher_model_dir, model_filename, params_filename, + distill_node_pair, teacher_idx=None, feed_target_names=feed_target_names) # all feed node should set stop_gradient is False, for using pact quant algo. 
@@ -479,7 +530,7 @@ def build_prune_program(executor, place=place) _logger.info( "####################channel pruning##########################") - for param in pruned_program.global_block().all_parameters(): + for param in pruned_program.all_parameters(): if param.name in original_shapes: _logger.info("{}, from {} to {}".format( param.name, original_shapes[param.name], param.shape)) diff --git a/paddleslim/auto_compression/transformer_pruner.py b/paddleslim/auto_compression/transformer_pruner.py index a7737585..f393eac4 100644 --- a/paddleslim/auto_compression/transformer_pruner.py +++ b/paddleslim/auto_compression/transformer_pruner.py @@ -19,7 +19,7 @@ from ..core import GraphWrapper from ..common import get_logger from ..common.recover_program import recover_inference_program from ..common.transformer_pattern import preprocess_transformer_patterns -from ..common.patterns_common import is_dynamic_weight_op +from ..common.patterns_common import has_trainable_var _logger = get_logger(__name__, level=logging.INFO) @@ -297,7 +297,7 @@ class TransformerPruner: tmp_mha_ops = patterns['MHA$0'] for op in tmp_mha_ops: if op.type() in ['matmul', 'matmul_v2'] and ( - not is_dynamic_weight_op(op)) and head_num == -1: + not has_trainable_var(op)) and head_num == -1: inp_var = op.inputs("X") head_num = inp_var[0].shape()[1] diff --git a/paddleslim/common/patterns.py b/paddleslim/common/patterns.py index 8b629d99..048b1fdc 100644 --- a/paddleslim/common/patterns.py +++ b/paddleslim/common/patterns.py @@ -30,7 +30,7 @@ def find_final_nodes(program): final_nodes = [] graph = GraphWrapper(program) for op in sorted(graph.ops()): - if op.type() in ALL_WEIGHT_OP and is_output_weight_ops(op, graph): + if has_trainable_var(op) and is_final_op_with_trainable_var(op, graph): n_op = has_bias(op, graph) if n_op is not None: final_nodes.extend(n_op.all_outputs()) @@ -52,7 +52,7 @@ def _is_mha(pattern_ops, pattern_ops_type, skip_quant_tensor_list=[]): matmul_num = 0 for op in pattern_ops: if op.type() in ['matmul', 'matmul_v2']: - if not is_dynamic_weight_op(op): + if not has_trainable_var(op): matmul_num += 1 if matmul_num == 2: return True @@ -68,7 +68,7 @@ def _is_ffn(pattern_ops, pattern_ops_type): act_num = 0 for op in pattern_ops: if op.type() in ['mul', 'matmul', 'matmul_v2']: - if is_dynamic_weight_op(op): + if has_trainable_var(op): linear_num += 1 if op.type() in ['relu', 'gelu']: act_num += 1 diff --git a/paddleslim/common/patterns_common.py b/paddleslim/common/patterns_common.py index 98b87562..ddfc1335 100644 --- a/paddleslim/common/patterns_common.py +++ b/paddleslim/common/patterns_common.py @@ -39,7 +39,7 @@ def find_weight_op(op, graph): """ Find operators with weight.""" next_ops = sorted(graph.next_ops(op)) for next_op in next_ops: - if is_dynamic_weight_op(next_op): + if has_trainable_var(next_op): return next_op else: return find_weight_op(next_op, graph) @@ -56,25 +56,24 @@ def get_weight(op, return_name=True): return None -def is_dynamic_weight_op(op): +def has_trainable_var(op): + """ Judge whether the operator with trainable variable """ weight_ops = ALL_WEIGHT_OP if op.type() in weight_ops: - if op.type() in ['mul', 'matmul', 'matmul_v2']: - for inp in sorted(op.all_inputs()): - if inp._var.persistable == True: - return True - return False - return True + for inp in sorted(op.all_inputs()): + if inp._var.persistable == True: + return True + return False return False -def is_output_weight_ops(op, graph): +def is_final_op_with_trainable_var(op, graph): """ Judge whether is the final op with 
weights in the graph """ next_ops = sorted(graph.next_ops(op)) for next_op in next_ops: - if is_dynamic_weight_op(next_op): + if has_trainable_var(next_op): return False - return is_output_weight_ops(next_op, graph) + return is_final_op_with_trainable_var(next_op, graph) return True diff --git a/paddleslim/common/transformer_pattern.py b/paddleslim/common/transformer_pattern.py index 32100139..1ba9c977 100644 --- a/paddleslim/common/transformer_pattern.py +++ b/paddleslim/common/transformer_pattern.py @@ -31,7 +31,7 @@ def _append_transformer_prune_params(op, graph, block_num, params_dict): continue next_op = _find_gemm_op(next_op, graph) if next_op.type() in ['mul', 'matmul', 'matmul_v2' - ] and is_dynamic_weight_op(next_op): + ] and has_trainable_var(next_op): if block_num not in params_dict: params_dict[block_num] = {} params_dict[block_num]['P1'] = [get_weight(next_op)] diff --git a/paddleslim/dist/single_distiller.py b/paddleslim/dist/single_distiller.py index 5cfdf808..fba5bb08 100644 --- a/paddleslim/dist/single_distiller.py +++ b/paddleslim/dist/single_distiller.py @@ -17,6 +17,30 @@ import paddle from paddleslim.core import GraphWrapper +def _find_var_from_program(program, var_name): + for block in program.blocks: + if block.has_var(var_name): + return block.var(var_name) + raise ValueError("var {} not in this program".format(var_name)) + + +def _except_feed_fetch(var_name, merge_feed): + if var_name != 'fetch' and (not merge_feed or var_name != 'feed'): + return True + return False + + +def _is_same_block(block1, block2): + if len(block1.ops) != len(block2.ops): + return False + + for op1, op2 in zip(block1.ops, block2.ops): + if op1.type != op2.type: + return False + + return True + + def merge(teacher_program, student_program, data_name_map, @@ -52,55 +76,127 @@ def merge(teacher_program, if teacher_scope == None: teacher_scope = scope teacher_program = teacher_program.clone(for_test=True) - for teacher_var in teacher_program.list_vars(): - skip_rename = False - if teacher_var.name != 'fetch' and (not merge_feed or - teacher_var.name != 'feed'): - if teacher_var.name in data_name_map.keys(): - new_name = data_name_map[teacher_var.name] - if new_name == teacher_var.name: - skip_rename = True - else: - new_name = name_prefix + teacher_var.name - if not skip_rename: - # scope var rename - old_var = teacher_scope.var(teacher_var.name).get_tensor() - renamed_var = scope.var(new_name).get_tensor() - renamed_var.set(np.array(old_var), place) - - # program var rename - renamed_var = teacher_program.global_block()._rename_var( - teacher_var.name, new_name) - - for teacher_var in teacher_program.list_vars(): - if teacher_var.name != 'fetch' and (not merge_feed or - teacher_var.name != 'feed'): - # student program add var - new_var = student_program.global_block()._clone_variable( - teacher_var, force_persistable=False) - new_var.stop_gradient = True + + is_same_model = True + if len(student_program.blocks) == len(teacher_program.blocks): + for block in teacher_program.blocks: + if not _is_same_block(block, student_program.block(block.idx)): + is_same_model = False + break + else: + is_same_model = False + + if is_same_model: + for block in student_program.blocks: + for op in block.ops: + if op.type == 'while': + tmp_var = [] + for _var_name in op.input('X'): + tmp_var.append('teacher_' + _var_name) + tmp_var.extend(op.input('X')) + op.desc.set_input("X", tmp_var) for block in teacher_program.blocks: + for teacher_var in list(block.vars.values()): + skip_rename = False + if 
_except_feed_fetch(teacher_var.name, merge_feed): + if teacher_var.name in data_name_map.keys(): + new_name = data_name_map[teacher_var.name] + if new_name == teacher_var.name: + skip_rename = True + else: + new_name = name_prefix + teacher_var.name + if not skip_rename: + # scope var rename + old_var = teacher_scope.var(teacher_var.name).get_tensor() + renamed_var = scope.var(new_name).get_tensor() + renamed_var.set(np.array(old_var), place) + + # program var rename + renamed_var = block._rename_var(teacher_var.name, new_name) + + ### input and output of the sub_block need to rename specially. for op in block.ops: + for iname in op.input_names: + for in_var_name in op.input(iname): + if _except_feed_fetch( + in_var_name, + merge_feed) and not block.has_var(in_var_name): + if in_var_name in data_name_map.keys(): + new_name = data_name_map[in_var_name] + if new_name != in_var_name: + op._rename_input(in_var_name, + name_prefix + in_var_name) + else: + op._rename_input(in_var_name, + name_prefix + in_var_name) + + for oname in op.output_names: + for out_var_name in op.output(oname): + if _except_feed_fetch( + out_var_name, + merge_feed) and not block.has_var(out_var_name): + if out_var_name in data_name_map.keys(): + new_name = data_name_map[out_var_name] + if new_name != out_var_name: + op._rename_output(out_var_name, + name_prefix + out_var_name) + else: + op._rename_output(out_var_name, + name_prefix + out_var_name) + + for block in teacher_program.blocks: + for teacher_var in list(block.vars.values()): + if teacher_var.name != 'fetch' and (not merge_feed or + teacher_var.name != 'feed'): + # student program add var + if len(student_program.blocks) > 1 and is_same_model: + new_var = student_program.block(block.idx)._clone_variable( + teacher_var, force_persistable=False) + else: + new_var = student_program.global_block()._clone_variable( + teacher_var, force_persistable=False) + new_var.stop_gradient = True + + for block in reversed(teacher_program.blocks): + for op_idx, op in enumerate(block.ops): if (not merge_feed or op.type != 'feed') and op.type != 'fetch': inputs = {} outputs = {} attrs = {} for input_name in op.input_names: - inputs[input_name] = [ - block.var(in_var_name) - for in_var_name in op.input(input_name) - ] + inputs[input_name] = [] + for in_var_name in op.input(input_name): + inputs[input_name].append( + block._find_var_recursive(in_var_name)) for output_name in op.output_names: - outputs[output_name] = [ - block.var(out_var_name) - for out_var_name in op.output(output_name) - ] + outputs[output_name] = [] + for out_var_name in op.output(output_name): + outputs[output_name].append( + block._find_var_recursive(out_var_name)) + for attr_name in op.attr_names: - attrs[attr_name] = op.attr(attr_name) - student_program.global_block().append_op( - type=op.type, inputs=inputs, outputs=outputs, attrs=attrs) + if attr_name == 'sub_block': + attrs[attr_name] = student_program.block( + op._block_attr("sub_block").idx) + else: + attrs[attr_name] = op.attr(attr_name) + if len(student_program.blocks) > 1 and is_same_model: + student_program.block(op.block.idx)._insert_op( + 2 * op_idx, + type=op.type, + inputs=inputs, + outputs=outputs, + attrs=attrs) + else: + student_program.global_block().append_op( + type=op.type, + inputs=inputs, + outputs=outputs, + attrs=attrs) + + student_program._sync_with_cpp() student_graph = GraphWrapper(student_program) for op in student_graph.ops(): @@ -137,10 +233,10 @@ def fsp(teacher_var1_name, """ if program == None: program = 
paddle.static.default_main_program() - teacher_var1 = program.global_block().var(teacher_var1_name) - teacher_var2 = program.global_block().var(teacher_var2_name) - student_var1 = program.global_block().var(student_var1_name) - student_var2 = program.global_block().var(student_var2_name) + teacher_var1 = _find_var_from_program(program, teacher_var1_name) + teacher_var2 = _find_var_from_program(program, teacher_var2_name) + student_var1 = _find_var_from_program(program, student_var1_name) + student_var2 = _find_var_from_program(program, student_var2_name) teacher_fsp_matrix = paddle.fluid.layers.fsp_matrix(teacher_var1, teacher_var2) student_fsp_matrix = paddle.fluid.layers.fsp_matrix(student_var1, @@ -165,8 +261,8 @@ def l2(teacher_var_name, student_var_name, program=None): """ if program == None: program = paddle.static.default_main_program() - student_var = program.global_block().var(student_var_name) - teacher_var = program.global_block().var(teacher_var_name) + student_var = _find_var_from_program(program, student_var_name) + teacher_var = _find_var_from_program(program, teacher_var_name) l2_loss = paddle.mean( paddle.nn.functional.square_error_cost(student_var, teacher_var)) return l2_loss @@ -194,8 +290,8 @@ def soft_label(teacher_var_name, """ if program == None: program = paddle.static.default_main_program() - student_var = program.global_block().var(student_var_name) - teacher_var = program.global_block().var(teacher_var_name) + student_var = _find_var_from_program(program, student_var_name) + teacher_var = _find_var_from_program(program, teacher_var_name) teacher_var.stop_gradient = True student_var = paddle.nn.functional.softmax(student_var / @@ -225,7 +321,7 @@ def loss(loss_func, program=None, **kwargs): for item in kwargs.items(): if isinstance(item[1], str): func_parameters.setdefault(item[0], - program.global_block().var(item[1])) + _find_var_from_program(program, item[1])) else: func_parameters.setdefault(item[0], item[1]) loss = loss_func(**func_parameters) @@ -297,8 +393,8 @@ def dkd(teacher_var_name, """ if program == None: program = paddle.static.default_main_program() - student_var = program.global_block().var(student_var_name) - teacher_var = program.global_block().var(teacher_var_name) + student_var = _find_var_from_program(program, student_var_name) + teacher_var = _find_var_from_program(program, teacher_var_name) return _dkd_loss( student_var, teacher_var, diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py index a038da56..e11d341b 100755 --- a/paddleslim/quant/quanter.py +++ b/paddleslim/quant/quanter.py @@ -1,5 +1,3 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# # Licensed under the Apache License, Version 2.0 (the "License" # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at @@ -34,7 +32,7 @@ from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass from ..common import get_logger from ..common.patterns import get_patterns -from ..common.patterns_common import is_dynamic_weight_op, get_weight +from ..common.patterns_common import has_trainable_var, get_weight from ..core.graph_wrapper import GraphWrapper _logger = get_logger(__name__, level=logging.INFO) @@ -341,7 +339,7 @@ def quant_aware(program, for op in ops: if op._op.type in ['matmul', 'matmul_v2'] and ( - not is_dynamic_weight_op(op)): + not has_trainable_var(op)): input_names = op._op.input_arg_names for input_name in input_names: pre_op = find_pre_ops(program, input_name)[0] @@ -387,6 +385,7 @@ def quant_aware(program, scale_dict = post_training_quantization._scale_dict else: main_graph = IrGraph(core.Graph(program.desc), for_test=for_test) + sub_graphs = [sub_graph for sub_graph in main_graph.all_sub_graphs()] transform_pass_ops = [] quant_dequant_ops = [] for op_type in config['quantize_op_types']: @@ -416,7 +415,8 @@ def quant_aware(program, executor=executor, is_test=is_test) - transform_pass.apply(main_graph) + for sub_graph in sub_graphs: + transform_pass.apply(sub_graph) if len(quant_dequant_ops) > 0: qdq_func = 'AddQuantDequantPassV2' if config[ @@ -430,7 +430,8 @@ def quant_aware(program, quantizable_op_type=quant_dequant_ops, is_test=is_test) - quant_dequant_pass.apply(main_graph) + for sub_graph in sub_graphs: + quant_dequant_pass.apply(sub_graph) out_scale_training_pass = OutScaleForTrainingPass( scope=scope, @@ -439,16 +440,18 @@ def quant_aware(program, is_test=is_test, scale_dict=scale_dict) - out_scale_training_pass.apply(main_graph) + for sub_graph in sub_graphs: + out_scale_training_pass.apply(sub_graph) - if (weight_preprocess_func is not None or - act_preprocess_func is not None) and not for_test: + if (weight_preprocess_func is not None or act_preprocess_func is not None + ) and not for_test and not config['onnx_format']: _logger.info( "When a preprocess_func is used in quant_aware, Need to save a mapping table to match variable names in the convert phase." ) _logger.info("The mapping table is saved as '{}'.".format( VARS_MAPPING_TABLE)) - save_dict(main_graph.out_node_mapping_table) + for sub_graph in sub_graphs: + save_dict(sub_graph.out_node_mapping_table) # TDOD: remove it. if draw_graph: @@ -683,18 +686,21 @@ def convert(program, if config['onnx_format']: quant_weight_pass = QuantWeightPass(scope, place) - quant_weight_pass.apply(test_graph) + for sub_graph in test_graph.all_sub_graphs(): + quant_weight_pass.apply(sub_graph) try: out_scale_infer_pass = AddQuantDequantForInferencePass( scope=scope, place=place, quant_bits=config['activation_bits']) - out_scale_infer_pass.apply(test_graph) + for sub_graph in test_graph.all_sub_graphs(): + out_scale_infer_pass.apply(sub_graph) except: _logger.warning( "Unable to convert quant model with onnx_format=True, please update PaddlePaddle >= 2.4.0" ) else: out_scale_infer_pass = OutScaleForInferencePass(scope=scope) - out_scale_infer_pass.apply(test_graph) + for sub_graph in test_graph.all_sub_graphs(): + out_scale_infer_pass.apply(sub_graph) # Freeze the graph after training by adjusting the quantize # operators' order for the inference. 
freeze_pass = QuantizationFreezePass( @@ -705,13 +711,31 @@ def convert(program, weight_quantize_type=config['weight_quantize_type']) if os.path.exists(VARS_MAPPING_TABLE): test_graph.out_node_mapping_table = load_dict() - freeze_pass.apply(test_graph) + for sub_graph in test_graph.all_sub_graphs(): + freeze_pass.apply(sub_graph) freezed_program = test_graph.to_program() + # Move sub blocks persistable var to global block + global_block = freezed_program.global_block() + for _op in global_block.ops: + if _op.type == "while": + _block_id = _op.attr("sub_block").id + _block = freezed_program.block(_block_id) + persistables = [] + for _name, _var in _block.vars.items(): + if _var.persistable: + global_block._clone_variable(_var) + persistables.append(_name) + for _name in persistables: + _block._remove_var(_name) + persistables.extend(_op.input('X')) + _op.desc.set_input("X", persistables) + if save_int8: convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place) - convert_int8_pass.apply(test_graph) + for sub_graph in test_graph.all_sub_graphs(): + convert_int8_pass.apply(sub_graph) freezed_program_int8 = test_graph.to_program() return freezed_program, freezed_program_int8 else: diff --git a/tests/act/test_act_qat.py b/tests/act/test_act_qat.py new file mode 100644 index 00000000..7389c110 --- /dev/null +++ b/tests/act/test_act_qat.py @@ -0,0 +1,97 @@ +import os +import sys +import unittest +sys.path.append("../../") +import numpy as np +import paddle +from paddle.io import Dataset +from paddleslim.auto_compression import AutoCompression +paddle.enable_static() + + +class RandomEvalDataset(Dataset): + def __init__(self, num_samples, image_shape=[1, 28, 28], class_num=10): + self.num_samples = num_samples + self.image_shape = image_shape + self.class_num = class_num + + def __getitem__(self, idx): + image = np.random.random(self.image_shape).astype('float32') + return image + + def __len__(self): + return self.num_samples + + +class ACTQATWhileOP(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(ACTQATWhileOP, self).__init__(*args, **kwargs) + if not os.path.exists('mnist_while'): + os.system( + "wget -q http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz" + ) + os.system('tar -xzvf mnist_while.tar.gz') + self.create_dataloader() + self.get_config() + + def create_dataloader(self): + # define a random dataset + self.eval_dataset = RandomEvalDataset(32) + + def get_config(self): + self.config = { + 'QuantAware': {}, + 'Distillation': {}, + 'TrainConfig': { + 'epochs': 1, + 'eval_iter': 100, + 'learning_rate': 5.0e-03, + 'optimizer_builder': { + 'optimizer': { + 'type': 'SGD' + }, + "weight_decay": 0.0005, + } + } + } + + def test_demo(self): + image = paddle.static.data( + name='x', shape=[-1, 1, 28, 28], dtype='float32') + train_loader = paddle.io.DataLoader( + self.eval_dataset, feed_list=[image], batch_size=4) + + ac = AutoCompression( + model_dir="./mnist_while", + model_filename="model.pdmodel", + params_filename="model.pdiparams", + config=self.config, + save_dir="qat_while_output", + train_dataloader=train_loader) + ac.compress() + os.system('rm -rf qat_while_output') + + +class ACTQATWhileOPCase2(ACTQATWhileOP): + def get_config(self): + self.config = { + 'QuantAware': { + 'quantize_op_types': ['conv2d', 'mul', 'relu'] + }, + 'Distillation': {}, + 'TrainConfig': { + 'epochs': 1, + 'eval_iter': 100, + 'learning_rate': 5.0e-03, + 'optimizer_builder': { + 'optimizer': { + 'type': 'SGD' + }, + "weight_decay": 0.0005, + } + } + } + + +if __name__ == 
'__main__': + unittest.main() -- GitLab
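
Illustrative note on the block-aware variable lookup introduced above: programs that contain control flow (for example the while op exercised by the new tests/act/test_act_qat.py) keep the loop body's variables in a sub-block, so a plain program.global_block().var(name) lookup cannot resolve them. The sketch below is a minimal, assumed example — the toy while_loop program is not part of this patch — showing why the _find_var_from_program helper added in create_compressed_program.py and single_distiller.py searches every block:

import paddle

paddle.enable_static()


def _find_var_from_program(program, var_name):
    # Same helper as added in the patch: search every block, not only block 0.
    for block in program.blocks:
        if block.has_var(var_name):
            return block.var(var_name)
    raise ValueError("var {} not in this program".format(var_name))


main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    i = paddle.full(shape=[1], fill_value=0, dtype='int64')
    limit = paddle.full(shape=[1], fill_value=10, dtype='int64')
    # while_loop builds a while op whose body lives in its own sub-block.
    out = paddle.static.nn.while_loop(
        cond=lambda x: x < limit, body=lambda x: [x + 1], loop_vars=[i])

# Variables created inside the while body are registered in the sub-block,
# not in the global block, so global_block().var() can fail for them while
# the block-aware helper still resolves them.
sub_block = main_prog.blocks[-1]
for name in sub_block.vars:
    if not main_prog.global_block().has_var(name):
        print(name, _find_var_from_program(main_prog, name).shape)
        break

The same reasoning drives the quanter.py changes, where each quantization pass is applied to every sub-graph returned by all_sub_graphs() instead of only the main graph, so ops inside while blocks are also quantized.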