Unverified · Commit 304d7815 authored by ceci3, committed by GitHub

optimize auto compress (#1550)

* support sub block

* support post-process

* update

* fix

* add unittest

* revert prune

* fix unittest

* add unittest
Parent e72ce197
@@ -550,7 +550,7 @@ class AutoCompression:
         train_program_info = self._compiled_program(train_program_info,
                                                     strategy)
         test_program_info = self._compiled_program(test_program_info,
-                                                   self._strategy)
+                                                   strategy)
         return train_program_info, test_program_info
 
     def _compiled_program(self, program_info, strategy):
......
@@ -84,6 +84,13 @@ def _create_optimizer(train_config):
     return opt, lr
 
 
+def _find_var_from_program(program, var_name):
+    for block in program.blocks:
+        if block.has_var(var_name):
+            return block.var(var_name)
+    raise ValueError("var {} not in this program".format(var_name))
+
+
 def _get_distill_node(student_program, config):
     node = config.get('node')
     if len(node) == 0:
@@ -95,7 +102,7 @@ def _get_distill_node(student_program, config):
     else:
         test_node = node[0]
     try:
-        test_var = student_program.global_block().var(test_node)
+        test_var = _find_var_from_program(student_program, test_node)
         distill_node_pair = []
         if isinstance(node[0], list):
             for n_list in node:
@@ -113,6 +120,14 @@ def _get_distill_node(student_program, config):
     return node
 
 
+def _get_target_node(distill_node):
+    targets = []
+    for idx, node in enumerate(distill_node):
+        if idx % 2 != 0:
+            targets.append(node)
+    return targets
+
+
 def _parse_distill_loss(distill_node_pair,
                         distill_loss='l2',
                         distill_lambda=1.0):
@@ -149,6 +164,7 @@ def _load_program_and_merge(executor,
                             model_dir,
                             model_filename,
                             params_filename,
+                            distill_node_pair,
                             teacher_idx=None,
                             feed_target_names=None):
     scope = paddle.static.global_scope()
@@ -171,8 +187,8 @@ def _load_program_and_merge(executor,
 
     _remove_fetch_node(teacher_program)
 
-    if teacher_idx == None or teacher_idx == 1:
-        test_program = train_program.clone(for_test=True)
+    target_nodes = _get_target_node(distill_node_pair)
+    teacher_program = teacher_program._prune(target_nodes)
 
     data_name_map = {}
@@ -196,9 +212,9 @@ def _load_program_and_merge(executor,
         name_prefix=teacher_name_prefix,
         merge_feed=merge_feed)
 
     if teacher_idx == None or teacher_idx == 1:
-        return train_program, test_program, data_name_map
+        return train_program, data_name_map
     else:
-        return train_program, None, data_name_map
+        return train_program, data_name_map
 
 
 def build_distill_program(executor,
@@ -224,6 +240,38 @@ def build_distill_program(executor,
     distill_node_pair = _get_distill_node(train_program,
                                           config) or default_distill_node_pair
 
+    test_program = train_program.clone(for_test=True)
+    target_nodes = _get_target_node(distill_node_pair)
+
+    def _prepend_feed(block, feed_idx, feed_target_names):
+        for idx in feed_idx[::-1]:
+            block._remove_op(idx)
+        feed_var = block.create_var(
+            name='feed',
+            type=paddle.framework.core.VarDesc.VarType.FEED_MINIBATCH,
+            persistable=True, )
+        for i, name in enumerate(feed_target_names):
+            out = block.var(name)
+            block._prepend_op(
+                type='feed',
+                inputs={'X': [feed_var]},
+                outputs={'Out': [out]},
+                attrs={'col': i})
+
+    judge_feed_pos = False
+    if train_program.desc.block(0).op(0).type() != 'feed':
+        judge_feed_pos = True
+
+    if judge_feed_pos:
+        feed_idx = []
+        for op in train_program.global_block().ops:
+            if op.type == 'feed':
+                feed_idx.append(op.idx)
+        _prepend_feed(train_program.global_block(), feed_idx, feed_target_names)
+
+    train_program = train_program._prune(target_nodes)
+
     teacher_model_dir = config[
         "teacher_model_dir"] if "teacher_model_dir" in config else config[
             "teacher_model_path_prefix"]
@@ -234,7 +282,7 @@ def build_distill_program(executor,
             params_filename = config["teacher_params_filename"][
                 tea_idx] if "teacher_params_filename" in config else None
             if tea_idx == 0:
-                train_program, test_program, data_name_map = _load_program_and_merge(
+                train_program, data_name_map = _load_program_and_merge(
                     executor,
                     place,
                     train_program,
@@ -242,10 +290,11 @@ def build_distill_program(executor,
                     teacher_model_dir[tea_idx],
                     model_filename,
                     params_filename,
+                    distill_node_pair,
                     teacher_idx=(tea_idx + 1),
                     feed_target_names=feed_target_names)
             else:
-                train_program, _, data_name_map = _load_program_and_merge(
+                train_program, data_name_map = _load_program_and_merge(
                     executor,
                     place,
                     train_program,
@@ -253,6 +302,7 @@ def build_distill_program(executor,
                     teacher_model_dir[tea_idx],
                     model_filename,
                     params_filename,
+                    distill_node_pair,
                     teacher_idx=(tea_idx + 1),
                     feed_target_names=feed_target_names)
@@ -261,7 +311,7 @@ def build_distill_program(executor,
            "teacher_model_filename"] if "teacher_model_filename" in config else None
        params_filename = config[
            "teacher_params_filename"] if "teacher_params_filename" in config else None
-        train_program, test_program, data_name_map = _load_program_and_merge(
+        train_program, data_name_map = _load_program_and_merge(
            executor,
            place,
            train_program,
@@ -269,6 +319,7 @@ def build_distill_program(executor,
            teacher_model_dir,
            model_filename,
            params_filename,
+           distill_node_pair,
            teacher_idx=None,
            feed_target_names=feed_target_names)
     # all feed node should set stop_gradient is False, for using pact quant algo.
@@ -479,7 +530,7 @@ def build_prune_program(executor,
                 place=place)
             _logger.info(
                 "####################channel pruning##########################")
-            for param in pruned_program.global_block().all_parameters():
+            for param in pruned_program.all_parameters():
                 if param.name in original_shapes:
                     _logger.info("{}, from {} to {}".format(
                         param.name, original_shapes[param.name], param.shape))
......
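A note on the two helpers introduced above: `_find_var_from_program` looks a variable up in every block of a program instead of only the global block, which is what allows distillation nodes that live inside a while-op sub-block to be resolved. `_get_target_node` relies on `distill_node_pair` being a flat list that appears to alternate teacher-prefixed and student node names, keeping the odd-index entries as the prune targets. A minimal sketch of that convention follows; the node names are made up for illustration:

```python
# Illustrative only: distill_node_pair alternates teacher/student entries,
# so _get_target_node keeps the odd-index (student-side) names as targets.
distill_node_pair = [
    'teacher_softmax_0.tmp_0', 'softmax_0.tmp_0',
    'teacher_fc_1.tmp_1', 'fc_1.tmp_1',
]
targets = [name for idx, name in enumerate(distill_node_pair) if idx % 2 != 0]
print(targets)  # ['softmax_0.tmp_0', 'fc_1.tmp_1']
```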
@@ -19,7 +19,7 @@ from ..core import GraphWrapper
 from ..common import get_logger
 from ..common.recover_program import recover_inference_program
 from ..common.transformer_pattern import preprocess_transformer_patterns
-from ..common.patterns_common import is_dynamic_weight_op
+from ..common.patterns_common import has_trainable_var
 
 _logger = get_logger(__name__, level=logging.INFO)
@@ -297,7 +297,7 @@ class TransformerPruner:
         tmp_mha_ops = patterns['MHA$0']
         for op in tmp_mha_ops:
             if op.type() in ['matmul', 'matmul_v2'] and (
-                    not is_dynamic_weight_op(op)) and head_num == -1:
+                    not has_trainable_var(op)) and head_num == -1:
                 inp_var = op.inputs("X")
                 head_num = inp_var[0].shape()[1]
......
@@ -30,7 +30,7 @@ def find_final_nodes(program):
     final_nodes = []
     graph = GraphWrapper(program)
     for op in sorted(graph.ops()):
-        if op.type() in ALL_WEIGHT_OP and is_output_weight_ops(op, graph):
+        if has_trainable_var(op) and is_final_op_with_trainable_var(op, graph):
             n_op = has_bias(op, graph)
             if n_op is not None:
                 final_nodes.extend(n_op.all_outputs())
@@ -52,7 +52,7 @@ def _is_mha(pattern_ops, pattern_ops_type, skip_quant_tensor_list=[]):
     matmul_num = 0
     for op in pattern_ops:
         if op.type() in ['matmul', 'matmul_v2']:
-            if not is_dynamic_weight_op(op):
+            if not has_trainable_var(op):
                 matmul_num += 1
     if matmul_num == 2:
         return True
@@ -68,7 +68,7 @@ def _is_ffn(pattern_ops, pattern_ops_type):
     act_num = 0
     for op in pattern_ops:
         if op.type() in ['mul', 'matmul', 'matmul_v2']:
-            if is_dynamic_weight_op(op):
+            if has_trainable_var(op):
                 linear_num += 1
         if op.type() in ['relu', 'gelu']:
             act_num += 1
......
@@ -39,7 +39,7 @@ def find_weight_op(op, graph):
     """ Find operators with weight."""
     next_ops = sorted(graph.next_ops(op))
     for next_op in next_ops:
-        if is_dynamic_weight_op(next_op):
+        if has_trainable_var(next_op):
             return next_op
         else:
             return find_weight_op(next_op, graph)
@@ -56,25 +56,24 @@ def get_weight(op, return_name=True):
     return None
 
 
-def is_dynamic_weight_op(op):
+def has_trainable_var(op):
+    """ Judge whether the operator with trainable variable """
     weight_ops = ALL_WEIGHT_OP
     if op.type() in weight_ops:
-        if op.type() in ['mul', 'matmul', 'matmul_v2']:
-            for inp in sorted(op.all_inputs()):
-                if inp._var.persistable == True:
-                    return True
-            return False
-        return True
+        for inp in sorted(op.all_inputs()):
+            if inp._var.persistable == True:
+                return True
+        return False
     return False
 
 
-def is_output_weight_ops(op, graph):
+def is_final_op_with_trainable_var(op, graph):
     """ Judge whether is the final op with weights in the graph """
     next_ops = sorted(graph.next_ops(op))
     for next_op in next_ops:
-        if is_dynamic_weight_op(next_op):
+        if has_trainable_var(next_op):
             return False
-        return is_output_weight_ops(next_op, graph)
+        return is_final_op_with_trainable_var(next_op, graph)
     return True
......
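A note on the rename above: beyond the new name, the check itself changes. The old `is_dynamic_weight_op` only inspected inputs for `mul`/`matmul`/`matmul_v2` and returned True for every other op type in `ALL_WEIGHT_OP` unconditionally, while `has_trainable_var` inspects the inputs of every candidate op and reports True only when at least one input variable is persistable. A stand-alone sketch of the new decision rule, not the GraphWrapper API (the op set below is a made-up subset for illustration):

```python
# Illustrative restatement of the new rule only; real code works on GraphWrapper ops.
ALL_WEIGHT_OP = {'conv2d', 'mul', 'matmul', 'matmul_v2', 'embedding'}  # subset, for illustration


def has_trainable_var_rule(op_type, input_is_persistable):
    """True only if the op type can carry weights AND at least one of its
    input variables is persistable (i.e. a trainable parameter)."""
    if op_type in ALL_WEIGHT_OP:
        return any(input_is_persistable)
    return False


# A matmul between two activations (e.g. Q @ K in attention) is excluded:
print(has_trainable_var_rule('matmul_v2', [False, False]))  # False
# A linear layer whose weight input is persistable still qualifies:
print(has_trainable_var_rule('matmul_v2', [False, True]))   # True
```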
@@ -31,7 +31,7 @@ def _append_transformer_prune_params(op, graph, block_num, params_dict):
             continue
         next_op = _find_gemm_op(next_op, graph)
         if next_op.type() in ['mul', 'matmul', 'matmul_v2'
-                              ] and is_dynamic_weight_op(next_op):
+                              ] and has_trainable_var(next_op):
             if block_num not in params_dict:
                 params_dict[block_num] = {}
                 params_dict[block_num]['P1'] = [get_weight(next_op)]
......
@@ -17,6 +17,30 @@ import paddle
 from paddleslim.core import GraphWrapper
 
 
+def _find_var_from_program(program, var_name):
+    for block in program.blocks:
+        if block.has_var(var_name):
+            return block.var(var_name)
+    raise ValueError("var {} not in this program".format(var_name))
+
+
+def _except_feed_fetch(var_name, merge_feed):
+    if var_name != 'fetch' and (not merge_feed or var_name != 'feed'):
+        return True
+    return False
+
+
+def _is_same_block(block1, block2):
+    if len(block1.ops) != len(block2.ops):
+        return False
+    for op1, op2 in zip(block1.ops, block2.ops):
+        if op1.type != op2.type:
+            return False
+    return True
+
+
 def merge(teacher_program,
           student_program,
           data_name_map,
@@ -52,55 +76,127 @@ def merge(teacher_program,
     if teacher_scope == None:
         teacher_scope = scope
     teacher_program = teacher_program.clone(for_test=True)
-    for teacher_var in teacher_program.list_vars():
-        skip_rename = False
-        if teacher_var.name != 'fetch' and (not merge_feed or
-                                            teacher_var.name != 'feed'):
-            if teacher_var.name in data_name_map.keys():
-                new_name = data_name_map[teacher_var.name]
-                if new_name == teacher_var.name:
-                    skip_rename = True
-            else:
-                new_name = name_prefix + teacher_var.name
-            if not skip_rename:
-                # scope var rename
-                old_var = teacher_scope.var(teacher_var.name).get_tensor()
-                renamed_var = scope.var(new_name).get_tensor()
-                renamed_var.set(np.array(old_var), place)
-
-                # program var rename
-                renamed_var = teacher_program.global_block()._rename_var(
-                    teacher_var.name, new_name)
-
-    for teacher_var in teacher_program.list_vars():
-        if teacher_var.name != 'fetch' and (not merge_feed or
-                                            teacher_var.name != 'feed'):
-            # student program add var
-            new_var = student_program.global_block()._clone_variable(
-                teacher_var, force_persistable=False)
-            new_var.stop_gradient = True
-
-    for block in teacher_program.blocks:
-        for op in block.ops:
-            if (not merge_feed or op.type != 'feed') and op.type != 'fetch':
-                inputs = {}
-                outputs = {}
-                attrs = {}
-                for input_name in op.input_names:
-                    inputs[input_name] = [
-                        block.var(in_var_name)
-                        for in_var_name in op.input(input_name)
-                    ]
-                for output_name in op.output_names:
-                    outputs[output_name] = [
-                        block.var(out_var_name)
-                        for out_var_name in op.output(output_name)
-                    ]
-                for attr_name in op.attr_names:
-                    attrs[attr_name] = op.attr(attr_name)
-                student_program.global_block().append_op(
-                    type=op.type, inputs=inputs, outputs=outputs, attrs=attrs)
+
+    is_same_model = True
+    if len(student_program.blocks) == len(teacher_program.blocks):
+        for block in teacher_program.blocks:
+            if not _is_same_block(block, student_program.block(block.idx)):
+                is_same_model = False
+                break
+    else:
+        is_same_model = False
+
+    if is_same_model:
+        for block in student_program.blocks:
+            for op in block.ops:
+                if op.type == 'while':
+                    tmp_var = []
+                    for _var_name in op.input('X'):
+                        tmp_var.append('teacher_' + _var_name)
+                    tmp_var.extend(op.input('X'))
+                    op.desc.set_input("X", tmp_var)
+
+    for block in teacher_program.blocks:
+        for teacher_var in list(block.vars.values()):
+            skip_rename = False
+            if _except_feed_fetch(teacher_var.name, merge_feed):
+                if teacher_var.name in data_name_map.keys():
+                    new_name = data_name_map[teacher_var.name]
+                    if new_name == teacher_var.name:
+                        skip_rename = True
+                else:
+                    new_name = name_prefix + teacher_var.name
+                if not skip_rename:
+                    # scope var rename
+                    old_var = teacher_scope.var(teacher_var.name).get_tensor()
+                    renamed_var = scope.var(new_name).get_tensor()
+                    renamed_var.set(np.array(old_var), place)
+
+                    # program var rename
+                    renamed_var = block._rename_var(teacher_var.name, new_name)
+
+        ### input and output of the sub_block need to rename specially.
+        for op in block.ops:
+            for iname in op.input_names:
+                for in_var_name in op.input(iname):
+                    if _except_feed_fetch(
+                            in_var_name,
+                            merge_feed) and not block.has_var(in_var_name):
+                        if in_var_name in data_name_map.keys():
+                            new_name = data_name_map[in_var_name]
+                            if new_name != in_var_name:
+                                op._rename_input(in_var_name,
+                                                 name_prefix + in_var_name)
+                        else:
+                            op._rename_input(in_var_name,
+                                             name_prefix + in_var_name)
+            for oname in op.output_names:
+                for out_var_name in op.output(oname):
+                    if _except_feed_fetch(
+                            out_var_name,
+                            merge_feed) and not block.has_var(out_var_name):
+                        if out_var_name in data_name_map.keys():
+                            new_name = data_name_map[out_var_name]
+                            if new_name != out_var_name:
+                                op._rename_output(out_var_name,
+                                                  name_prefix + out_var_name)
+                        else:
+                            op._rename_output(out_var_name,
+                                              name_prefix + out_var_name)
+
+    for block in teacher_program.blocks:
+        for teacher_var in list(block.vars.values()):
+            if teacher_var.name != 'fetch' and (not merge_feed or
+                                                teacher_var.name != 'feed'):
+                # student program add var
+                if len(student_program.blocks) > 1 and is_same_model:
+                    new_var = student_program.block(block.idx)._clone_variable(
+                        teacher_var, force_persistable=False)
+                else:
+                    new_var = student_program.global_block()._clone_variable(
+                        teacher_var, force_persistable=False)
+                new_var.stop_gradient = True
+
+    for block in reversed(teacher_program.blocks):
+        for op_idx, op in enumerate(block.ops):
+            if (not merge_feed or op.type != 'feed') and op.type != 'fetch':
+                inputs = {}
+                outputs = {}
+                attrs = {}
+                for input_name in op.input_names:
+                    inputs[input_name] = []
+                    for in_var_name in op.input(input_name):
+                        inputs[input_name].append(
+                            block._find_var_recursive(in_var_name))
+                for output_name in op.output_names:
+                    outputs[output_name] = []
+                    for out_var_name in op.output(output_name):
+                        outputs[output_name].append(
+                            block._find_var_recursive(out_var_name))
+                for attr_name in op.attr_names:
+                    if attr_name == 'sub_block':
+                        attrs[attr_name] = student_program.block(
+                            op._block_attr("sub_block").idx)
+                    else:
+                        attrs[attr_name] = op.attr(attr_name)
+                if len(student_program.blocks) > 1 and is_same_model:
+                    student_program.block(op.block.idx)._insert_op(
+                        2 * op_idx,
+                        type=op.type,
+                        inputs=inputs,
+                        outputs=outputs,
+                        attrs=attrs)
+                else:
+                    student_program.global_block().append_op(
+                        type=op.type,
+                        inputs=inputs,
+                        outputs=outputs,
+                        attrs=attrs)
+    student_program._sync_with_cpp()
 
     student_graph = GraphWrapper(student_program)
     for op in student_graph.ops():
@@ -137,10 +233,10 @@ def fsp(teacher_var1_name,
     """
     if program == None:
         program = paddle.static.default_main_program()
-    teacher_var1 = program.global_block().var(teacher_var1_name)
-    teacher_var2 = program.global_block().var(teacher_var2_name)
-    student_var1 = program.global_block().var(student_var1_name)
-    student_var2 = program.global_block().var(student_var2_name)
+    teacher_var1 = _find_var_from_program(program, teacher_var1_name)
+    teacher_var2 = _find_var_from_program(program, teacher_var2_name)
+    student_var1 = _find_var_from_program(program, student_var1_name)
+    student_var2 = _find_var_from_program(program, student_var2_name)
     teacher_fsp_matrix = paddle.fluid.layers.fsp_matrix(teacher_var1,
                                                         teacher_var2)
     student_fsp_matrix = paddle.fluid.layers.fsp_matrix(student_var1,
@@ -165,8 +261,8 @@ def l2(teacher_var_name, student_var_name, program=None):
     """
     if program == None:
         program = paddle.static.default_main_program()
-    student_var = program.global_block().var(student_var_name)
-    teacher_var = program.global_block().var(teacher_var_name)
+    student_var = _find_var_from_program(program, student_var_name)
+    teacher_var = _find_var_from_program(program, teacher_var_name)
     l2_loss = paddle.mean(
         paddle.nn.functional.square_error_cost(student_var, teacher_var))
     return l2_loss
@@ -194,8 +290,8 @@ def soft_label(teacher_var_name,
     """
     if program == None:
         program = paddle.static.default_main_program()
-    student_var = program.global_block().var(student_var_name)
-    teacher_var = program.global_block().var(teacher_var_name)
+    student_var = _find_var_from_program(program, student_var_name)
+    teacher_var = _find_var_from_program(program, teacher_var_name)
     teacher_var.stop_gradient = True
     student_var = paddle.nn.functional.softmax(student_var /
@@ -225,7 +321,7 @@ def loss(loss_func, program=None, **kwargs):
     for item in kwargs.items():
         if isinstance(item[1], str):
             func_parameters.setdefault(item[0],
-                                       program.global_block().var(item[1]))
+                                       _find_var_from_program(program, item[1]))
         else:
             func_parameters.setdefault(item[0], item[1])
     loss = loss_func(**func_parameters)
@@ -297,8 +393,8 @@ def dkd(teacher_var_name,
     """
     if program == None:
         program = paddle.static.default_main_program()
-    student_var = program.global_block().var(student_var_name)
-    teacher_var = program.global_block().var(teacher_var_name)
+    student_var = _find_var_from_program(program, student_var_name)
+    teacher_var = _find_var_from_program(program, teacher_var_name)
     return _dkd_loss(
         student_var,
         teacher_var,
......
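One detail of the `merge` change worth spelling out: when the student and teacher programs share the same block structure, the student's while op is given the teacher-prefixed tensors as extra 'X' inputs so the merged sub-block can read the renamed teacher variables. A minimal sketch of that name manipulation, mirroring the `tmp_var` construction above; the variable names are hypothetical and `'teacher_'` is the default `name_prefix` of `merge`:

```python
# Illustrative only: the while op keeps its original inputs and additionally
# receives the teacher-prefixed copies, as in the op.desc.set_input("X", ...) call above.
name_prefix = 'teacher_'
original_inputs = ['hidden_0.tmp_2', 'cond_var']  # hypothetical while-op 'X' inputs
tmp_var = [name_prefix + v for v in original_inputs]
tmp_var.extend(original_inputs)
print(tmp_var)
# ['teacher_hidden_0.tmp_2', 'teacher_cond_var', 'hidden_0.tmp_2', 'cond_var']
```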
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-#
 # Licensed under the Apache License, Version 2.0 (the "License"
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -34,7 +32,7 @@ from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
 from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
 from ..common import get_logger
 from ..common.patterns import get_patterns
-from ..common.patterns_common import is_dynamic_weight_op, get_weight
+from ..common.patterns_common import has_trainable_var, get_weight
 from ..core.graph_wrapper import GraphWrapper
 _logger = get_logger(__name__, level=logging.INFO)
@@ -341,7 +339,7 @@ def quant_aware(program,
         for op in ops:
             if op._op.type in ['matmul', 'matmul_v2'] and (
-                    not is_dynamic_weight_op(op)):
+                    not has_trainable_var(op)):
                 input_names = op._op.input_arg_names
                 for input_name in input_names:
                     pre_op = find_pre_ops(program, input_name)[0]
@@ -387,6 +385,7 @@ def quant_aware(program,
         scale_dict = post_training_quantization._scale_dict
     else:
         main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
+        sub_graphs = [sub_graph for sub_graph in main_graph.all_sub_graphs()]
         transform_pass_ops = []
         quant_dequant_ops = []
         for op_type in config['quantize_op_types']:
@@ -416,7 +415,8 @@ def quant_aware(program,
                 executor=executor,
                 is_test=is_test)
 
-            transform_pass.apply(main_graph)
+            for sub_graph in sub_graphs:
+                transform_pass.apply(sub_graph)
 
         if len(quant_dequant_ops) > 0:
             qdq_func = 'AddQuantDequantPassV2' if config[
@@ -430,7 +430,8 @@ def quant_aware(program,
                 quantizable_op_type=quant_dequant_ops,
                 is_test=is_test)
 
-            quant_dequant_pass.apply(main_graph)
+            for sub_graph in sub_graphs:
+                quant_dequant_pass.apply(sub_graph)
 
         out_scale_training_pass = OutScaleForTrainingPass(
             scope=scope,
@@ -439,16 +440,18 @@ def quant_aware(program,
             is_test=is_test,
             scale_dict=scale_dict)
 
-        out_scale_training_pass.apply(main_graph)
+        for sub_graph in sub_graphs:
+            out_scale_training_pass.apply(sub_graph)
 
-        if (weight_preprocess_func is not None or
-                act_preprocess_func is not None) and not for_test:
+        if (weight_preprocess_func is not None or act_preprocess_func is not None
+            ) and not for_test and not config['onnx_format']:
             _logger.info(
                 "When a preprocess_func is used in quant_aware, Need to save a mapping table to match variable names in the convert phase."
             )
             _logger.info("The mapping table is saved as '{}'.".format(
                 VARS_MAPPING_TABLE))
-            save_dict(main_graph.out_node_mapping_table)
+            for sub_graph in sub_graphs:
+                save_dict(sub_graph.out_node_mapping_table)
 
     # TDOD: remove it.
     if draw_graph:
@@ -683,18 +686,21 @@ def convert(program,
     if config['onnx_format']:
         quant_weight_pass = QuantWeightPass(scope, place)
-        quant_weight_pass.apply(test_graph)
+        for sub_graph in test_graph.all_sub_graphs():
+            quant_weight_pass.apply(sub_graph)
         try:
             out_scale_infer_pass = AddQuantDequantForInferencePass(
                 scope=scope, place=place, quant_bits=config['activation_bits'])
-            out_scale_infer_pass.apply(test_graph)
+            for sub_graph in test_graph.all_sub_graphs():
+                out_scale_infer_pass.apply(sub_graph)
         except:
             _logger.warning(
                 "Unable to convert quant model with onnx_format=True, please update PaddlePaddle >= 2.4.0"
             )
     else:
         out_scale_infer_pass = OutScaleForInferencePass(scope=scope)
-        out_scale_infer_pass.apply(test_graph)
+        for sub_graph in test_graph.all_sub_graphs():
+            out_scale_infer_pass.apply(sub_graph)
         # Freeze the graph after training by adjusting the quantize
         # operators' order for the inference.
         freeze_pass = QuantizationFreezePass(
@@ -705,13 +711,31 @@ def convert(program,
             weight_quantize_type=config['weight_quantize_type'])
         if os.path.exists(VARS_MAPPING_TABLE):
             test_graph.out_node_mapping_table = load_dict()
-        freeze_pass.apply(test_graph)
+        for sub_graph in test_graph.all_sub_graphs():
+            freeze_pass.apply(sub_graph)
 
     freezed_program = test_graph.to_program()
+    # Move sub blocks persistable var to global block
+    global_block = freezed_program.global_block()
+    for _op in global_block.ops:
+        if _op.type == "while":
+            _block_id = _op.attr("sub_block").id
+            _block = freezed_program.block(_block_id)
+            persistables = []
+            for _name, _var in _block.vars.items():
+                if _var.persistable:
+                    global_block._clone_variable(_var)
+                    persistables.append(_name)
+            for _name in persistables:
+                _block._remove_var(_name)
+            persistables.extend(_op.input('X'))
+            _op.desc.set_input("X", persistables)
+
     if save_int8:
         convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
-        convert_int8_pass.apply(test_graph)
+        for sub_graph in test_graph.all_sub_graphs():
+            convert_int8_pass.apply(sub_graph)
         freezed_program_int8 = test_graph.to_program()
         return freezed_program, freezed_program_int8
     else:
......
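The quantization passes above are now applied to every sub-graph of the IrGraph rather than only to the main graph, which is what makes while-op sub-blocks participate in quantization. The pattern repeats for several passes; if you want to keep it in one place in your own code, a small wrapper along these lines works (this helper is hypothetical and not part of PaddleSlim; it only relies on `IrGraph.all_sub_graphs()` and the passes' `apply()` method, both used in the diff above):

```python
def apply_to_all_sub_graphs(ir_pass, graph):
    """Apply an IrGraph pass to every sub-graph of `graph`.

    Hypothetical convenience wrapper around the loop used repeatedly above:
    `graph` is a paddle IrGraph and `ir_pass` is any pass exposing apply().
    """
    for sub_graph in graph.all_sub_graphs():
        ir_pass.apply(sub_graph)

# e.g.:
# apply_to_all_sub_graphs(transform_pass, main_graph)
# apply_to_all_sub_graphs(freeze_pass, test_graph)
```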
import os
import sys
import unittest
sys.path.append("../../")
import numpy as np
import paddle
from paddle.io import Dataset
from paddleslim.auto_compression import AutoCompression

paddle.enable_static()


class RandomEvalDataset(Dataset):
    def __init__(self, num_samples, image_shape=[1, 28, 28], class_num=10):
        self.num_samples = num_samples
        self.image_shape = image_shape
        self.class_num = class_num

    def __getitem__(self, idx):
        image = np.random.random(self.image_shape).astype('float32')
        return image

    def __len__(self):
        return self.num_samples


class ACTQATWhileOP(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(ACTQATWhileOP, self).__init__(*args, **kwargs)
        if not os.path.exists('mnist_while'):
            os.system(
                "wget -q http://paddle-inference-dist.bj.bcebos.com/int8/mnist_while.tar.gz"
            )
            os.system('tar -xzvf mnist_while.tar.gz')
        self.create_dataloader()
        self.get_config()

    def create_dataloader(self):
        # define a random dataset
        self.eval_dataset = RandomEvalDataset(32)

    def get_config(self):
        self.config = {
            'QuantAware': {},
            'Distillation': {},
            'TrainConfig': {
                'epochs': 1,
                'eval_iter': 100,
                'learning_rate': 5.0e-03,
                'optimizer_builder': {
                    'optimizer': {
                        'type': 'SGD'
                    },
                    "weight_decay": 0.0005,
                }
            }
        }

    def test_demo(self):
        image = paddle.static.data(
            name='x', shape=[-1, 1, 28, 28], dtype='float32')
        train_loader = paddle.io.DataLoader(
            self.eval_dataset, feed_list=[image], batch_size=4)

        ac = AutoCompression(
            model_dir="./mnist_while",
            model_filename="model.pdmodel",
            params_filename="model.pdiparams",
            config=self.config,
            save_dir="qat_while_output",
            train_dataloader=train_loader)
        ac.compress()
        os.system('rm -rf qat_while_output')


class ACTQATWhileOPCase2(ACTQATWhileOP):
    def get_config(self):
        self.config = {
            'QuantAware': {
                'quantize_op_types': ['conv2d', 'mul', 'relu']
            },
            'Distillation': {},
            'TrainConfig': {
                'epochs': 1,
                'eval_iter': 100,
                'learning_rate': 5.0e-03,
                'optimizer_builder': {
                    'optimizer': {
                        'type': 'SGD'
                    },
                    "weight_decay": 0.0005,
                }
            }
        }


if __name__ == '__main__':
    unittest.main()