Unverified commit e4c6ae55, authored by Liufang Sang, committed by GitHub

add log embedding quantization (#189)

Parent 140d063d
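A minimal usage sketch of the new 'log' quantize_type (the toy program, the parameter sizes, and the `paddleslim.quant` import path are illustrative assumptions based on the updated docstring below, not part of this diff):

```python
# Hypothetical example: quantize the weight of a single lookup_table op with
# the 'log' type added in this commit. Any inference program that contains a
# supported embedding op would work the same way.
import paddle.fluid as fluid
from paddleslim.quant import quant_embedding

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    ids = fluid.layers.data(name='ids', shape=[1], dtype='int64')
    emb = fluid.layers.embedding(input=ids, size=[1000, 64])  # creates a lookup_table op

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)  # initialize the embedding parameter

config = {
    'quantize_op_types': ['lookup_table'],
    'lookup_table': {
        'quantize_type': 'log',   # 'abs_max' remains the default
        'quantize_bits': 8,
        'dtype': 'int8',
    },
}
quant_prog = quant_embedding(main_prog, place, config)
```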
......@@ -19,6 +19,7 @@ import logging
import copy
import numpy as np
import math
from multiprocessing.dummy import Pool as ThreadPool
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
......@@ -35,7 +36,7 @@ _default_single_config = {
"dtype": "int8"
}
SUPPORT_OP_TYPES = ['lookup_table', 'fused_embedding_seq_pool', 'pyramid_hash']
SUPPORT_QUANTIZE_TYPES = ['abs_max', 'log']
SUPPORT_QUANTIZE_BITS = [8]
SUPPORT_DTYPE = ['int8']
......@@ -116,6 +117,10 @@ def _get_scale_var_name(var_name):
return var_name + '.scale'
def _get_dict_var_name(var_name):
return var_name + '.dict'
def _get_quant_var_name(var_name):
"""
get quantized var name
......@@ -208,6 +213,8 @@ def _quant_embedding_abs_max(graph, scope, place, config, var_name,
threshold = np.percentile(np.abs(array), 99.99)
return np.clip(array, -threshold, threshold)
_logger.info("Embedding {}: abs_max quantization".format(var_name))
embedding_tensor = _get_var_tensor(scope, var_name)
embedding_array = _clip_array(embedding_tensor, config)
# get scale and quanted tensor
......@@ -243,6 +250,122 @@ def _quant_embedding_abs_max(graph, scope, place, config, var_name,
graph.safe_remove_nodes(embedding_node)
def _quant_embedding_log(graph, scope, place, config, var_name,
embedding_node):
"""
    quantize embedding weights using log quantization
Args:
graph(IrGraph): graph that includes Embedding Parameter
scope(fluid.Scope): scope
        place(fluid.CPUPlace or fluid.CUDAPlace): place to run program
        config(dict): config to quant Embedding
        var_name(str): name of the embedding weight variable to quantize
        embedding_node(IrNode): the embedding weight var node in the graph
    """
    _interval = 0.125
_dict_len = 256
_dict = np.zeros(_dict_len)
def _search(array, num_array):
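        """
        For each value in `num_array`, return the index of the nearest entry in the
        sorted 1-D `array`: start from np.searchsorted and move one step left when
        the left neighbour is closer.
        """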
length = len(array)
res = np.searchsorted(array, num_array)
res_refine = []
for i in range(len(num_array)):
value = num_array[i]
idx = res[i]
if idx > 0 and ((idx == length) or (
abs(array[idx - 1] - value) < abs(array[idx] - value))):
res_refine.append(idx - 1)
else:
res_refine.append(idx)
return np.array(res_refine)
def _quant_log(tensor_array, config):
"""
        quantize the array in the log domain
"""
bit_length = config['quantize_bits']
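        # quantize in the log domain: round log2(|x|) to the nearest multiple of _interval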
        log_and_quant = np.round(np.log2(np.abs(tensor_array)) /
                                 _interval) * _interval
unique, counts = np.unique(log_and_quant, return_counts=True)
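        # the 128 largest distinct rounded log values become the dictionary; each
        # element is replaced by the index (0..127) of its nearest dictionary entry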
topk_num = np.sort(unique)[-int(_dict_len / 2):]
pool = ThreadPool(8)
quanted_array = pool.map(lambda x: _search(topk_num, x), log_and_quant)
quanted_array = np.array(quanted_array)
pool.close()
pool.join()
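        # mark elements that were negative in the original tensor by offsetting their
        # codes by -128, so the int8 code also carries the sign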
index_tmp = tensor_array < 0
quanted_array_tmp = quanted_array[index_tmp]
quanted_array_tmp = quanted_array_tmp - 128
quanted_array[index_tmp] = quanted_array_tmp
quanted_array = quanted_array.astype(config['dtype'])
return topk_num, quanted_array
def _insert_dequant_log_op(graph, scope, var_node, topk_num_node, config):
"""
Insert dequantize_log op in graph
"""
assert var_node.is_var(), "{} is not a var".format(var_node.name())
dequant_var_node = graph.create_var_node(
name=_get_dequant_var_name(var_node.name()),
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=core.VarDesc.VarType.FP32)
scope.var(dequant_var_node.name())
output_ops = var_node.outputs
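        # create a dequantize_log op that takes the int8 codes (X) and the dictionary
        # (Dict) and produces the dequantized float output consumed downstream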
dequant_op = graph.create_op_node(
op_type='dequantize_log',
attrs={'op_role': core.op_proto_and_checker_maker.OpRole.Forward},
inputs={'X': var_node,
'Dict': topk_num_node},
outputs={'Out': dequant_var_node})
graph.link_to(var_node, dequant_op)
graph.link_to(topk_num_node, dequant_op)
graph.link_to(dequant_op, dequant_var_node)
for node in output_ops:
graph.update_input_link(var_node, dequant_var_node, node)
_logger.info("Embedding {}: log quantization".format(var_name))
# find embedding var node by 'var_name'
embedding_tensor = _get_var_tensor(scope, var_name)
# get quantize dict and quanted tensor
topk_num, quanted_tensor = _quant_log(embedding_tensor, config)
    # parameter nodes must be created with create_persistable_node
topk_num_var = graph.create_persistable_node(
_get_dict_var_name(var_name),
var_type=embedding_node.type(),
shape=topk_num.shape,
var_dtype=core.VarDesc.VarType.FP32)
quant_tensor_var = graph.create_persistable_node(
_get_quant_var_name(var_name),
var_type=embedding_node.type(),
shape=embedding_node.shape(),
var_dtype=core.VarDesc.VarType.INT8)
# create var in scope
scope.var(_get_quant_var_name(var_name))
scope.var(_get_dict_var_name(var_name))
    # fill the vars with the quantized tensor and the dictionary
_restore_var(_get_quant_var_name(var_name), quanted_tensor, scope, place)
_restore_var(_get_dict_var_name(var_name), topk_num, scope, place)
# insert dequantize_log op
for op_node in embedding_node.outputs:
graph.update_input_link(embedding_node, quant_tensor_var, op_node)
out_name = op_node.output('Out')[0]
var_node = graph._find_node_by_name(op_node.outputs, out_name)
_insert_dequant_log_op(graph, scope, var_node, topk_num_var, config)
# free float embedding params memory
_clear_var(embedding_node.name(), scope)
graph.safe_remove_nodes(embedding_node)
def _remove_link(in_node, out_node):
in_node.remove_output(out_node)
out_node.remove_input(in_node)
......@@ -301,9 +424,9 @@ def quant_embedding(program, place, config=None, scope=None):
Args:
program(fluid.Program): infer program
        scope(fluid.Scope, optional): Scope records the mapping between variable names and variables, similar to the braces that delimit scopes in programming languages. Usually `fluid.global_scope() <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_ can be used. When set to ``None``, `fluid.global_scope() <https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api_cn/executor_cn/global_scope_cn.html>`_ is used. Default: ``None``.
        place(fluid.CPUPlace or fluid.CUDAPlace): the device on which the executor runs.
        config(dict, optional): config to quantize. The key is 'quantize_op_types', a list of op types to quantize. For each op type in 'quantize_op_types', you can define 'quantize_type', \
'quantize_bits', 'dtype', 'threshold'. \
            ``quantize_type`` is the quantization type; supported types are ['abs_max', 'log'], default is "abs_max".
``quantize_bits`` supported bits are [8] and default is 8.
......@@ -334,8 +457,12 @@ def quant_embedding(program, place, config=None, scope=None):
for op_node in embedding_node.outputs:
if op_node.name() == 'fused_embedding_seq_pool':
_split_embedding_seq_pool(graph, op_node)
if config[op_type]['quantize_type'] == 'abs_max':
_quant_embedding_abs_max(graph, scope, place, config[op_type],
weight_name, embedding_node)
elif config[op_type]['quantize_type'] == 'log':
_quant_embedding_log(graph, scope, place, config[op_type],
weight_name, embedding_node)
quantize_params_map[weight_name] = _get_quant_var_name(weight_name)
for op in all_op:
if op.name() == 'fused_embedding_seq_pool':
......