Commit f38831bd authored by S slf12

fix details

Parent 38f3eba3
@@ -19,8 +19,6 @@ import logging
 import copy
 import numpy as np
-import paddle
-import paddle.fluid as fluid
 from paddle.fluid.framework import IrGraph
 from paddle.fluid import core
@@ -48,38 +46,32 @@ def _merge_config(old_config, new_config):
         new_config(dict): the user defined config, 'params_name' must be set.
             When 'threshold' is not set, quant embedding without clip .
     """
-    keys = new_config.keys()
+    old_config.update(new_config)
+    keys = old_config.keys()
     assert 'params_name' in keys, "params_name must be set"
-    old_config['params_name'] = new_config['params_name']
-    if 'quantize_type' in keys:
-        quantize_type = new_config['quantize_type']
-        assert isinstance(quantize_type, str), "quantize_type must be \
-                str"
-        assert quantize_type in support_quantize_types, " \
-                quantize_type {} is not supported, now supported quantize type \
-                are {}.".format(quantize_type, support_quantize_types)
-        old_config['quantize_type'] = quantize_type
-    if 'quantize_bits' in keys:
-        quantize_bits = new_config['quantize_bits']
-        assert isinstance(quantize_bits, int), "quantize_bits must be int"
-        assert quantize_bits in support_quantize_bits, " quantize_bits {} \
+    quantize_type = old_config['quantize_type']
+    assert isinstance(quantize_type, str), "quantize_type must be \
+            str"
+    assert quantize_type in support_quantize_types, " \
+            quantize_type {} is not supported, now supported quantize type \
+            are {}.".format(quantize_type, support_quantize_types)
+    quantize_bits = old_config['quantize_bits']
+    assert isinstance(quantize_bits, int), "quantize_bits must be int"
+    assert quantize_bits in support_quantize_bits, " quantize_bits {} \
            is not supported, now supported quantize bits are \
            {}. ".format(quantize_bits, support_quantize_bits)
-        old_config['quantize_bits'] = quantize_bits
-    if 'dtype' in keys:
-        dtype = new_config['dtype']
-        assert isinstance(dtype, str), "dtype must be str"
-        assert dtype in support_dtype, " dtype {} is not \
-                supported, now supported dtypes are {} \
-                ".format(dtype, support_dtype)
-        old_config['dtype'] = dtype
+    dtype = old_config['dtype']
+    assert isinstance(dtype, str), "dtype must be str"
+    assert dtype in support_dtype, " dtype {} is not \
+            supported, now supported dtypes are {} \
+            ".format(dtype, support_dtype)
     if 'threshold' in keys:
-        old_config['threshold'] = new_config['threshold']
         assert isinstance(new_config['threshold'], (float, int)), "threshold \
             must be number."
     print("quant_embedding config {}".format(old_config))
     return old_config
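Note: the rewritten `_merge_config` first folds the user dict into the defaults and then validates the merged result, so user overrides and defaults pass through the same checks. A minimal standalone sketch of that behavior (the default values here are assumed from the `quant_embedding` docstring further down, not taken from the diff):

```python
import copy

# Assumed defaults, mirroring the quant_embedding docstring below.
default_config = {
    'quantize_type': 'abs_max',
    'quantize_bits': 8,
    'dtype': 'int8',
}

def merge_config(old_config, new_config):
    # Merge first, then validate the merged dict (the fixed code path).
    old_config.update(new_config)
    assert 'params_name' in old_config, "params_name must be set"
    assert old_config['quantize_type'] in ['abs_max']
    assert old_config['quantize_bits'] in [8]
    assert old_config['dtype'] in ['int8']
    return old_config

config = merge_config(copy.deepcopy(default_config),
                      {'params_name': 'emb', 'threshold': 5.0})
print(config)  # defaults plus the two user-supplied keys
```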
@@ -97,18 +89,15 @@ def _get_var_tensor(scope, var_name):
     return np.array(scope.find_var(var_name).get_tensor())
 
 
-def _clip_tensor(tensor_array, config):
+def _clip_tensor(tensor_array, threshold):
     """
     when 'threshold' is set, clip tensor by 'threshold' and '-threshold'
     Args:
         tensor_array(np.array): array to clip
         config(dict): config dict
     """
-    if 'threshold' in config.keys():
-        threshold = config['threshold']
-        assert isinstance(threshold, (int, float)), "threshold must be number"
-        tensor_array[tensor_array > threshold] = threshold
-        tensor_array[tensor_array < threshold] = -threshold
+    tensor_array[tensor_array > threshold] = threshold
+    tensor_array[tensor_array < -threshold] = -threshold
     return tensor_array
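Note: the deleted branch compared against `threshold` on both sides, so `tensor_array[tensor_array < threshold] = -threshold` overwrote nearly every element with `-threshold`; the new comparison against `-threshold` clips symmetrically. A quick NumPy check of the fixed behavior:

```python
import numpy as np

def clip_tensor(tensor_array, threshold):
    # Symmetric in-place clip into [-threshold, threshold].
    tensor_array[tensor_array > threshold] = threshold
    tensor_array[tensor_array < -threshold] = -threshold
    return tensor_array

arr = np.array([-3.0, -0.5, 0.0, 0.5, 3.0], dtype="float32")
print(clip_tensor(arr, 1.0))  # [-1.  -0.5  0.   0.5  1. ]
```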
@@ -168,7 +157,7 @@ def _quant_embedding_abs_max(graph, scope, place, config):
     scale = np.max(np.abs(tensor_array)).astype("float32")
     quanted_tensor = np.round(tensor_array / scale * (
         (1 << (bit_length - 1)) - 1))
-    return scale, quanted_tensor.astype(np.int8)
+    return scale, quanted_tensor.astype(config['dtype'])
 
 
 def _insert_dequant_abs_max_op(graph, scope, var_node, scale_node, config):
     """
@@ -205,7 +194,8 @@ def _quant_embedding_abs_max(graph, scope, place, config):
     # find embedding var node by 'params_name'
     embedding_node = graph._find_node_by_name(all_var_nodes, var_name)
     embedding_tensor = _get_var_tensor(scope, var_name)
-    embedding_tensor = _clip_tensor(embedding_tensor, config)
+    if 'threshold' in config.keys():
+        embedding_tensor = _clip_tensor(embedding_tensor, config['threshold'])
     # get scale and quanted tensor
     scale, quanted_tensor = _quant_abs_max(embedding_tensor, config)
@@ -242,6 +232,21 @@ def _quant_embedding_abs_max(graph, scope, place, config):
 def quant_embedding(program, scope, place, config):
     """
     quant lookup_table op parameters
+    Args:
+        program(fluid.Program): infer program
+        scope(fluid.Scope): the scope to store var, usually is fluid.global_scope()
+        place(fluid.CPUPlace or fluid.CUDAPlace): place
+        config(dict): config to quant. The keys are 'params_name', 'quantize_type', \
+            'quantize_bits', 'dtype', 'threshold'. \
+            'params_name': parameter name to quant, must be set.
+            'quantize_type': quantize type, supported types are ['abs_max']. default is "abs_max".
+            'quantize_bits': quantize bits, supported bits are [8]. default is 8.
+            'dtype': quantize dtype, supported dtype are ['int8']. default is 'int8'.
+            'threshold': threshold to clip tensor before quant. When threshold is not set, \
+                tensor will not be clipped.
     """
     assert isinstance(config, dict), "config must be dict"
     config = _merge_config(copy.deepcopy(default_config), config)
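Note: based on the docstring added above, a hedged usage sketch; the model directory, the embedding parameter name 'emb', and the `paddleslim.quant` import path are placeholders/assumptions, and the call is assumed to return the transformed program:

```python
import paddle.fluid as fluid
from paddleslim.quant import quant_embedding  # assumed import path

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Load a saved inference program that contains a lookup_table op
# (the dirname is a placeholder).
[infer_program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    dirname='./infer_model', executor=exe)

config = {'params_name': 'emb', 'threshold': 5.0}  # other keys use defaults
quant_program = quant_embedding(infer_program, fluid.global_scope(), place,
                                config)
```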