Unverified · Commit 188ec0d9 · authored by Liufang Sang, committed by GitHub

add default config and support more ops (#180)

* add default config and support more ops

* remove debug code

* remove debug code

* fix details

Parent 52502c06
@@ -18,24 +18,28 @@ from __future__ import print_function
 import logging
 import copy
 import numpy as np
+import math
 import paddle.fluid as fluid
 from paddle.fluid.framework import IrGraph
 from paddle.fluid import core
-#_logger = logging.basicConfig(level=logging.DEBUG)
+from ..common import get_logger
+
+_logger = get_logger(__name__, level=logging.INFO)
 
 __all__ = ['quant_embedding']
 
-default_config = {
+_default_single_config = {
     "quantize_type": "abs_max",
     "quantize_bits": 8,
     "dtype": "int8"
 }
 
-support_quantize_types = ['abs_max']
-support_quantize_bits = [8]
-support_dtype = ['int8']
+SUPPORT_OP_TYPES = ['lookup_table', 'fused_embedding_seq_pool', 'pyramid_hash']
+SUPPORT_QUANTIZE_TYPES = ['abs_max']
+SUPPORT_QUANTIZE_BITS = [8]
+SUPPORT_DTYPE = ['int8']
+
+_default_config = {"quantize_op_types": SUPPORT_OP_TYPES, }
 
 
 def _merge_config(old_config, new_config):
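With this change the config is keyed by op type: the top-level `quantize_op_types` selects which ops to quantize, and each selected op type may carry its own sub-config, falling back to `_default_single_config`. A hypothetical user config under the new scheme (the `threshold` value below is illustrative, not a recommended setting) could look like this:

```python
config = {
    'quantize_op_types': ['lookup_table'],
    'lookup_table': {
        'quantize_type': 'abs_max',  # only 'abs_max' is supported
        'quantize_bits': 8,          # only 8 bits is supported
        'dtype': 'int8',             # only 'int8' is supported
        'threshold': 1.0,            # optional: clip weights to [-1.0, 1.0]
    },
}
```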
@@ -49,32 +53,47 @@ def _merge_config(old_config, new_config):
     """
     old_config.update(new_config)
     keys = old_config.keys()
-    assert 'params_name' in keys, "params_name must be set"
-    quantize_type = old_config['quantize_type']
-    assert isinstance(quantize_type, str), "quantize_type must be str"
-    assert quantize_type in support_quantize_types, \
-        "quantize_type {} is not supported, now supported quantize type \
-        are {}.".format(quantize_type, support_quantize_types)
-    quantize_bits = old_config['quantize_bits']
-    assert isinstance(quantize_bits, int), "quantize_bits must be int"
-    assert quantize_bits in support_quantize_bits, \
-        "quantize_bits {} is not supported, now supported quantize bits \
-        are {}.".format(quantize_bits, support_quantize_bits)
-    dtype = old_config['dtype']
-    assert isinstance(dtype, str), "dtype must be str"
-    assert dtype in support_dtype, \
-        "dtype {} is not supported, now supported dtypes \
-        are {}".format(dtype, support_dtype)
-    if 'threshold' in keys:
-        assert isinstance(old_config['threshold'], (float, int)), \
-            "threshold must be number."
-    print("quant_embedding config {}".format(old_config))
+    assert isinstance(old_config['quantize_op_types'], (str, list)), \
+        'quantize_op_types can only be str or list[str]'
+    if isinstance(old_config['quantize_op_types'], str):
+        old_config['quantize_op_types'] = [old_config['quantize_op_types']]
+    for op_type in old_config['quantize_op_types']:
+        assert op_type in SUPPORT_OP_TYPES, \
+            '{} is not supported, supported op types are {}'.format(
+                op_type, SUPPORT_OP_TYPES)
+        if op_type not in keys:
+            old_config[op_type] = _default_single_config
+            continue
+        else:
+            assert isinstance(old_config[op_type], dict), \
+                "op type {}'s config must be dict".format(op_type)
+            config_tmp = copy.deepcopy(_default_single_config)
+            config_tmp.update(old_config[op_type])
+            old_config[op_type] = config_tmp
+
+        quantize_type = old_config[op_type]['quantize_type']
+        assert isinstance(quantize_type, str), "quantize_type must be str"
+        assert quantize_type in SUPPORT_QUANTIZE_TYPES, \
+            "quantize_type {} is not supported, now supported quantize "\
+            "types are {}.".format(quantize_type, SUPPORT_QUANTIZE_TYPES)
+        quantize_bits = old_config[op_type]['quantize_bits']
+        assert isinstance(quantize_bits, int), "quantize_bits must be int"
+        assert quantize_bits in SUPPORT_QUANTIZE_BITS, \
+            "quantize_bits {} is not supported, now supported quantize "\
+            "bits are {}.".format(quantize_bits, SUPPORT_QUANTIZE_BITS)
+        dtype = old_config[op_type]['dtype']
+        assert isinstance(dtype, str), "dtype must be str"
+        assert dtype in SUPPORT_DTYPE, \
+            "dtype {} is not supported, now supported dtypes "\
+            "are {}".format(dtype, SUPPORT_DTYPE)
+        if 'threshold' in old_config[op_type].keys():
+            assert isinstance(old_config[op_type]['threshold'], (float, int)), \
+                "threshold must be number."
+
+    _logger.info("quant_embedding config {}".format(old_config))
     return old_config
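For reference, a minimal sketch of how the merge behaves, assuming the module-level defaults defined above:

```python
import copy

# A bare string for quantize_op_types is normalized to a list, and any
# selected op type without an explicit sub-config picks up the defaults:
merged = _merge_config(copy.deepcopy(_default_config),
                       {'quantize_op_types': 'lookup_table'})
assert merged['quantize_op_types'] == ['lookup_table']
assert merged['lookup_table'] == {'quantize_type': 'abs_max',
                                  'quantize_bits': 8,
                                  'dtype': 'int8'}
```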
@@ -90,18 +109,6 @@ def _get_var_tensor(scope, var_name):
     return np.array(scope.find_var(var_name).get_tensor())
 
 
-def _clip_tensor(tensor_array, threshold):
-    """
-    when 'threshold' is set, clip tensor by 'threshold' and '-threshold'
-    Args:
-        tensor_array(np.array): array to clip
-        config(dict): config dict
-    """
-    tensor_array[tensor_array > threshold] = threshold
-    tensor_array[tensor_array < -threshold] = -threshold
-    return tensor_array
-
-
 def _get_scale_var_name(var_name):
     """
     get scale var name
@@ -139,7 +146,8 @@ def _clear_var(var_name, scope):
     tensor._clear()
 
 
-def _quant_embedding_abs_max(graph, scope, place, config):
+def _quant_embedding_abs_max(graph, scope, place, config, var_name,
+                             embedding_node):
     """
     quantize embedding using abs_max
@@ -190,16 +198,20 @@ def _quant_embedding_abs_max(graph, scope, place, config):
         for node in output_ops:
             graph.update_input_link(var_node, dequant_var_node, node)
 
-    all_var_nodes = graph.all_var_nodes()
-    var_name = config['params_name']
-    # find embedding var node by 'params_name'
-    embedding_node = graph._find_node_by_name(all_var_nodes, var_name)
-    embedding_tensor = _get_var_tensor(scope, var_name)
-    if 'threshold' in config.keys():
-        embedding_tensor = _clip_tensor(embedding_tensor, config['threshold'])
+    def _clip_array(array, config):
+        if 'threshold' in config.keys():
+            threshold = config['threshold']
+        else:
+            abs_array = np.max(np.abs(array))
+            if abs_array < 1.0:
+                return array
+            threshold = np.percentile(np.abs(array), 99.99)
+        return np.clip(array, -threshold, threshold)
 
+    embedding_tensor = _get_var_tensor(scope, var_name)
+    embedding_array = _clip_array(embedding_tensor, config)
     # get scale and quanted tensor
-    scale, quanted_tensor = _quant_abs_max(embedding_tensor, config)
+    scale, quanted_tensor = _quant_abs_max(embedding_array, config)
 
     #create params must to use create_persistable_node
     scale_var = graph.create_persistable_node(
@@ -221,18 +233,70 @@ def _quant_embedding_abs_max(graph, scope, place, config):
     # insert dequantize_abs_max op
     for op_node in embedding_node.outputs:
-        if op_node.name() == 'lookup_table':
-            graph.update_input_link(embedding_node, quant_tensor_var, op_node)
-            var_node = op_node.outputs[0]
-            _insert_dequant_abs_max_op(graph, scope, var_node, scale_var,
-                                       config)
+        graph.update_input_link(embedding_node, quant_tensor_var, op_node)
+        out_name = op_node.output('Out')[0]
+        var_node = graph._find_node_by_name(op_node.outputs, out_name)
+        _insert_dequant_abs_max_op(graph, scope, var_node, scale_var, config)
 
     # free float embedding params memory
     _clear_var(embedding_node.name(), scope)
     graph.safe_remove_nodes(embedding_node)
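`_quant_abs_max` itself is untouched by this commit and not shown above. As a point of reference, here is a minimal numpy sketch of what abs_max quantization computes; this is illustrative only, not the library code:

```python
import numpy as np

def abs_max_quant_sketch(array, bits=8):
    # Scale is the largest absolute weight; values are mapped onto the
    # symmetric int8 range [-127, 127].
    scale = np.max(np.abs(array))
    quant_range = (1 << (bits - 1)) - 1  # 127 for 8 bits
    quanted = np.round(array / scale * quant_range).astype(np.int8)
    return scale, quanted
```

Clipping outliers first, as the new `_clip_array` does with the 99.99th percentile, keeps the scale small so the bulk of the weight distribution uses more of the int8 range.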
-def quant_embedding(program, place, config, scope=None):
+def _remove_link(in_node, out_node):
+    in_node.remove_output(out_node)
+    out_node.remove_input(in_node)
+
+
+def _split_embedding_seq_pool(graph, op):
+    inputs = op.inputs
+    outputs = op.outputs
+    op_desc = op.node.op()
+    combiner = op_desc.attr("combiner")
+    padding_idx = op_desc.attr("padding_idx")
+    is_sparse = op_desc.attr("is_sparse")
+    ids = graph._find_node_by_name(inputs, op.input('Ids')[0])
+    weight = graph._find_node_by_name(inputs, op.input('W')[0])
+    out = outputs[0]
+    lookup_out = graph.create_var_node(
+        name=ids.name() + '.look_up_table.out',
+        var_type=core.VarDesc.VarType.LOD_TENSOR,
+        shape=[1],
+        var_dtype=weight.dtype())
+    lookup_table_op = graph.create_op_node(
+        op_type='lookup_table',
+        attrs={'is_sparse': is_sparse,
+               'padding_idx': padding_idx},
+        inputs={'W': weight,
+                'Ids': ids},
+        outputs={'Out': lookup_out})
+    _remove_link(ids, op)
+    _remove_link(weight, op)
+    _remove_link(op, out)
+    graph.link_to(ids, lookup_table_op)
+    graph.link_to(weight, lookup_table_op)
+    graph.link_to(lookup_table_op, lookup_out)
+    max_index = graph.create_var_node(
+        name=ids.name() + '.seq_pool_op.max_index',
+        var_type=core.VarDesc.VarType.LOD_TENSOR,
+        shape=[1],
+        var_dtype=weight.dtype())
+    seq_pool_op = graph.create_op_node(
+        op_type='sequence_pool',
+        inputs={'X': lookup_out},
+        outputs={'Out': out,
+                 'MaxIndex': max_index},
+        attrs={'pooltype': combiner.upper(),
+               'is_test': True})
+    if combiner == 'max':
+        max_index.stop_gradient = True
+    graph.link_to(lookup_out, seq_pool_op)
+    graph.link_to(seq_pool_op, out)
+    graph.link_to(seq_pool_op, max_index)
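The split preserves the fused op's semantics: `fused_embedding_seq_pool` is exactly a table lookup followed by a sequence pool. A numpy sketch of that equivalence for the 'sum' combiner (shapes and values are made up for illustration):

```python
import numpy as np

W = np.random.rand(100, 16).astype('float32')  # embedding table
ids = np.array([3, 7, 42])                     # one input sequence

fused_out = W[ids].sum(axis=0)       # what fused_embedding_seq_pool computes
lookup_out = W[ids]                  # lookup_table
pooled_out = lookup_out.sum(axis=0)  # sequence_pool with pooltype='SUM'
assert np.allclose(fused_out, pooled_out)
```

Splitting the fused op first means the quantization pass can rewire the weight into a standalone `lookup_table` and insert `dequantize_abs_max` directly on the lookup output, rather than after pooling.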
+def quant_embedding(program, place, config=None, scope=None):
     """quantize lookup_table op parameters
 
     Args:
@@ -241,7 +305,6 @@ def quant_embedding(program, place, config, scope=None):
         place(fluid.CPUPlace or fluid.CUDAPlace): This parameter represents the executor run on which device.
         config(dict): config to quantize. The keys are 'params_name', 'quantize_type', \
             'quantize_bits', 'dtype', 'threshold'. \
-            ``params_name`` is parameter name to quantize, must be set.
             ``quantize_type`` is quantize type, supported types are ['abs_max'], default is "abs_max".
             ``quantize_bits`` supported bits are [8] and default is 8.
             ``dtype`` is quantize dtype, supported dtype are ['int8'], default is 'int8'.
@@ -251,12 +314,31 @@ def quant_embedding(program, place, config, scope=None):
     Returns:
         None
     """
-    assert isinstance(config, dict), "config must be dict"
-    config = _merge_config(copy.deepcopy(default_config), config)
+    config = config or {}
+    config = _merge_config(copy.deepcopy(_default_config), config)
     scope = fluid.global_scope() if scope is None else scope
 
     graph = IrGraph(core.Graph(program.desc), for_test=True)
-    if config['quantize_type'] == 'abs_max':
-        _quant_embedding_abs_max(graph, scope, place, config)
+    quantize_params_map = {}
+    all_op = graph.all_op_nodes()
+    for op in all_op:
+        if op.inputs == [] and op.outputs == []:
+            continue
+        op_type = op.name()
+        if op_type in config['quantize_op_types']:
+            weight_name = op.input('W')[0]
+            if weight_name in quantize_params_map:
+                # this embedding table was already quantized via another op
+                continue
+            embedding_node = graph._find_node_by_name(op.inputs,
+                                                      op.input('W')[0])
+            for op_node in embedding_node.outputs:
+                if op_node.name() == 'fused_embedding_seq_pool':
+                    _split_embedding_seq_pool(graph, op_node)
+            _quant_embedding_abs_max(graph, scope, place, config[op_type],
+                                     weight_name, embedding_node)
+            quantize_params_map[weight_name] = _get_quant_var_name(weight_name)
+    for op in all_op:
+        if op.name() == 'fused_embedding_seq_pool':
+            graph.safe_remove_nodes(op)
 
     return graph.to_program()
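With the new default config, a call site no longer needs to name the embedding parameter. A minimal usage sketch, assuming `paddleslim.quant` as the import path (as in the test below) and an existing inference `Program` named `infer_program`:

```python
import paddle.fluid as fluid
import paddleslim.quant as quant

place = fluid.CPUPlace()
# Quantize every supported embedding op with the defaults:
quant_program = quant.quant_embedding(infer_program, place)

# Or restrict the pass and tune settings per op type:
quant_program = quant.quant_embedding(
    infer_program, place,
    config={'quantize_op_types': ['lookup_table'],
            'lookup_table': {'quantize_type': 'abs_max',
                             'threshold': 1.0}})
```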
@@ -24,8 +24,7 @@ class TestQuantEmbedding(unittest.TestCase):
         exe = fluid.Executor(place)
         exe.run(fluid.default_startup_program())
 
-        config = {'params_name': 'emb', 'quantize_type': 'abs_max'}
-        quant_program = quant.quant_embedding(infer_program, place, config)
+        quant_program = quant.quant_embedding(infer_program, place)
 
 
 if __name__ == '__main__':