diff --git a/paddleslim/quant/quant_embedding.py b/paddleslim/quant/quant_embedding.py
index 6e14d282c7d884ba8652f3b179dc38510206ac7b..c5c3bc62a7aaddafce11151fdaa0d49ab1f8c05a 100755
--- a/paddleslim/quant/quant_embedding.py
+++ b/paddleslim/quant/quant_embedding.py
@@ -19,6 +19,7 @@ import logging
 import copy
 import numpy as np
 import math
+from multiprocessing.dummy import Pool as ThreadPool
 
 import paddle.fluid as fluid
 from paddle.fluid.framework import IrGraph
@@ -35,7 +36,7 @@ _default_single_config = {
     "dtype": "int8"
 }
 SUPPORT_OP_TYPES = ['lookup_table', 'fused_embedding_seq_pool', 'pyramid_hash']
-SUPPORT_QUANTIZE_TYPES = ['abs_max']
+SUPPORT_QUANTIZE_TYPES = ['abs_max', 'log']
 SUPPORT_QUANTIZE_BITS = [8]
 SUPPORT_DTYPE = ['int8']
 
@@ -116,6 +117,10 @@ def _get_scale_var_name(var_name):
     return var_name + '.scale'
 
 
+def _get_dict_var_name(var_name):
+    return var_name + '.dict'
+
+
 def _get_quant_var_name(var_name):
     """
     get quantized var name
@@ -208,6 +213,8 @@ def _quant_embedding_abs_max(graph, scope, place, config, var_name,
         threshold = np.percentile(np.abs(array), 99.99)
         return np.clip(array, -threshold, threshold)
 
+    _logger.info("Embedding {}: abs_max quantization".format(var_name))
+
     embedding_tensor = _get_var_tensor(scope, var_name)
     embedding_array = _clip_array(embedding_tensor, config)
     # get scale and quanted tensor
@@ -243,6 +250,122 @@ def _quant_embedding_abs_max(graph, scope, place, config, var_name,
     graph.safe_remove_nodes(embedding_node)
 
 
+def _quant_embedding_log(graph, scope, place, config, var_name,
+                         embedding_node):
+    """
+    quantize embedding using log
+
+    Args:
+        graph(IrGraph): graph that includes Embedding Parameter
+        scope(fluid.Scope): scope
+        place(fluid.CPUPlace or fluid.CUDAPlace): place to run program
+        config(dict): config to quant Embedding
+    """
+
+    _interval = 0.125
+    _dict_len = 256
+    _dict = np.zeros(_dict_len)
+
+    def _search(array, num_array):
+        length = len(array)
+        res = np.searchsorted(array, num_array)
+        res_refine = []
+        for i in range(len(num_array)):
+            value = num_array[i]
+            idx = res[i]
+            if idx > 0 and ((idx == length) or (
+                    abs(array[idx - 1] - value) < abs(array[idx] - value))):
+                res_refine.append(idx - 1)
+            else:
+                res_refine.append(idx)
+        return np.array(res_refine)
+
+    def _quant_log(tensor_array, config):
+        """
+        quant array using log op
+        """
+        bit_length = config['quantize_bits']
+        log_and_quant = np.round(np.log2(np.abs(tensor_array)) /
+                                 _interval) * _interval
+        unique, counts = np.unique(log_and_quant, return_counts=True)
+        topk_num = np.sort(unique)[-int(_dict_len / 2):]
+
+        pool = ThreadPool(8)
+        quanted_array = pool.map(lambda x: _search(topk_num, x), log_and_quant)
+        quanted_array = np.array(quanted_array)
+        pool.close()
+        pool.join()
+        index_tmp = tensor_array < 0
+        quanted_array_tmp = quanted_array[index_tmp]
+        quanted_array_tmp = quanted_array_tmp - 128
+        quanted_array[index_tmp] = quanted_array_tmp
+        quanted_array = quanted_array.astype(config['dtype'])
+        return topk_num, quanted_array
+
+    def _insert_dequant_log_op(graph, scope, var_node, topk_num_node, config):
+        """
+        Insert dequantize_log op in graph
+        """
+        assert var_node.is_var(), "{} is not a var".format(var_node.name())
+
+        dequant_var_node = graph.create_var_node(
+            name=_get_dequant_var_name(var_node.name()),
+            var_type=var_node.type(),
+            shape=var_node.shape(),
+            var_dtype=core.VarDesc.VarType.FP32)
+        scope.var(dequant_var_node.name())
+
+        output_ops = var_node.outputs
+        dequant_op = graph.create_op_node(
+            op_type='dequantize_log',
+            attrs={'op_role': core.op_proto_and_checker_maker.OpRole.Forward},
+            inputs={'X': var_node,
+                    'Dict': topk_num_node},
+            outputs={'Out': dequant_var_node})
+        graph.link_to(var_node, dequant_op)
+        graph.link_to(topk_num_node, dequant_op)
+        graph.link_to(dequant_op, dequant_var_node)
+        for node in output_ops:
+            graph.update_input_link(var_node, dequant_var_node, node)
+
+    _logger.info("Embedding {}: log quantization".format(var_name))
+    # find embedding var node by 'var_name'
+    embedding_tensor = _get_var_tensor(scope, var_name)
+
+    # get quantize dict and quanted tensor
+    topk_num, quanted_tensor = _quant_log(embedding_tensor, config)
+
+    # create params must use create_persistable_node
+    topk_num_var = graph.create_persistable_node(
+        _get_dict_var_name(var_name),
+        var_type=embedding_node.type(),
+        shape=topk_num.shape,
+        var_dtype=core.VarDesc.VarType.FP32)
+    quant_tensor_var = graph.create_persistable_node(
+        _get_quant_var_name(var_name),
+        var_type=embedding_node.type(),
+        shape=embedding_node.shape(),
+        var_dtype=core.VarDesc.VarType.INT8)
+    # create var in scope
+    scope.var(_get_quant_var_name(var_name))
+    scope.var(_get_dict_var_name(var_name))
+    # set var by tensor array or dict
+    _restore_var(_get_quant_var_name(var_name), quanted_tensor, scope, place)
+    _restore_var(_get_dict_var_name(var_name), topk_num, scope, place)
+
+    # insert dequantize_log op
+    for op_node in embedding_node.outputs:
+        graph.update_input_link(embedding_node, quant_tensor_var, op_node)
+        out_name = op_node.output('Out')[0]
+        var_node = graph._find_node_by_name(op_node.outputs, out_name)
+
+        _insert_dequant_log_op(graph, scope, var_node, topk_num_var, config)
+
+    # free float embedding params memory
+    _clear_var(embedding_node.name(), scope)
+    graph.safe_remove_nodes(embedding_node)
+
+
 def _remove_link(in_node, out_node):
     in_node.remove_output(out_node)
     out_node.remove_input(in_node)
@@ -301,9 +424,9 @@ def quant_embedding(program, place, config=None, scope=None):
 
     Args:
         program(fluid.Program): infer program
-        scope(fluid.Scope): Scope records the mapping between variable names and variables, similar to brackets in programming languages. Usually users can use `fluid.global_scope() `_ . When ``None`` will use `fluid.global_scope() `_. Default : ``None``.
+        scope(fluid.Scope, optional): Scope records the mapping between variable names and variables, similar to brackets in programming languages. Usually users can use `fluid.global_scope() `_ . When ``None``, `fluid.global_scope() `_ will be used. Default: ``None``.
         place(fluid.CPUPlace or fluid.CUDAPlace): This parameter represents the executor run on which device.
-        config(dict): config to quantize. The keys are 'params_name', 'quantize_type', \
+        config(dict, optional): config to quantize. The keys are 'quantize_op_types'. For each op type in 'quantize_op_types', you can define 'quantize_type', \
                 'quantize_bits', 'dtype', 'threshold'. \
-                ``quantize_type`` is quantize type, supported types are ['abs_max'], default is "abs_max".
+                ``quantize_type`` is quantize type, supported types are ['abs_max', 'log'], default is "abs_max".
                 ``quantize_bits`` supported bits are [8] and default is 8.
@@ -334,8 +457,12 @@ def quant_embedding(program, place, config=None, scope=None):
             for op_node in embedding_node.outputs:
                 if op_node.name() == 'fused_embedding_seq_pool':
                     _split_embedding_seq_pool(graph, op_node)
-            _quant_embedding_abs_max(graph, scope, place, \
-                config[op_type], weight_name, embedding_node)
+            if config[op_type]['quantize_type'] == 'abs_max':
+                _quant_embedding_abs_max(graph, scope, place, config[op_type],
+                                         weight_name, embedding_node)
+            elif config[op_type]['quantize_type'] == 'log':
+                _quant_embedding_log(graph, scope, place, config[op_type],
+                                     weight_name, embedding_node)
             quantize_params_map[weight_name] = _get_quant_var_name(weight_name)
     for op in all_op:
         if op.name() == 'fused_embedding_seq_pool':
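
Reviewer note (not part of the patch): a minimal usage sketch for the new 'log'
quantize type. The per-op-type config layout is assumed from SUPPORT_OP_TYPES,
_default_single_config, and the config[op_type] lookup above; the model
directory name is a placeholder.

    import paddle.fluid as fluid
    from paddleslim.quant import quant_embedding

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Load a saved inference model (directory name is a placeholder).
    [infer_program, feed_names, fetch_targets] = fluid.io.load_inference_model(
        dirname='./infer_model', executor=exe)

    # Quantize every lookup_table embedding with the log scheme added above.
    config = {
        'quantize_op_types': ['lookup_table'],
        'lookup_table': {
            'quantize_type': 'log',  # 'abs_max' remains the default
            'quantize_bits': 8,
            'dtype': 'int8',
        },
    }
    quant_program = quant_embedding(infer_program, place, config)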
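
Reviewer note (not part of the patch): a self-contained NumPy sketch of the
encode/decode round trip implied by _quant_log. encode() mirrors the patch
(log2 magnitudes rounded to a 0.125 grid, a 128-entry dictionary of the
largest values, nearest-entry indices, negatives shifted down by 128). The
decode() helper is only my reading of how dequantize_log would invert this;
the real kernel may differ.

    import numpy as np

    interval = 0.125   # log2 grid used by _quant_log
    dict_len = 256     # int8 codes; 128 dictionary entries for magnitudes

    def encode(weights):
        # Round log2 magnitudes to the grid (assumes no exactly-zero weights)
        # and keep the 128 largest distinct values as the dictionary.
        log_q = np.round(np.log2(np.abs(weights)) / interval) * interval
        topk = np.sort(np.unique(log_q))[-(dict_len // 2):]
        # Map each element to the nearest dictionary entry (mirrors _search).
        idx = np.clip(np.searchsorted(topk, log_q), 0, len(topk) - 1)
        left = np.clip(idx - 1, 0, len(topk) - 1)
        nearer_left = np.abs(topk[left] - log_q) < np.abs(topk[idx] - log_q)
        idx = np.where(nearer_left, left, idx)
        # Negative weights are encoded by shifting the index down by 128.
        codes = np.where(weights < 0, idx - 128, idx).astype(np.int8)
        return topk, codes

    def decode(topk, codes):
        # Assumed inverse: recover the sign from the 128 offset and the
        # magnitude as 2 ** dictionary_entry.
        signed = codes.astype(np.int32)
        idx = np.where(signed < 0, signed + 128, signed)
        sign = np.where(signed < 0, -1.0, 1.0)
        return sign * np.power(2.0, topk[idx])

    w = np.random.randn(64, 16).astype('float32')
    topk, codes = encode(w)
    w_hat = decode(topk, codes)   # approximate reconstruction of w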