提交 f38831bd 编写于 作者: S slf12

fix details

上级 38f3eba3
...@@ -19,8 +19,6 @@ import logging ...@@ -19,8 +19,6 @@ import logging
import copy import copy
import numpy as np import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph from paddle.fluid.framework import IrGraph
from paddle.fluid import core from paddle.fluid import core
...@@ -48,38 +46,32 @@ def _merge_config(old_config, new_config): ...@@ -48,38 +46,32 @@ def _merge_config(old_config, new_config):
new_config(dict): the user defined config, 'params_name' must be set. new_config(dict): the user defined config, 'params_name' must be set.
When 'threshold' is not set, the embedding is quantized without clipping. When 'threshold' is not set, the embedding is quantized without clipping.
""" """
keys = new_config.keys() old_config.update(new_config)
keys = old_config.keys()
assert 'params_name' in keys, "params_name must be set" assert 'params_name' in keys, "params_name must be set"
old_config['params_name'] = new_config['params_name']
quantize_type = old_config['quantize_type']
if 'quantize_type' in keys: assert isinstance(quantize_type, str), "quantize_type must be \
quantize_type = new_config['quantize_type'] str"
assert isinstance(quantize_type, str), "quantize_type must be \
str" assert quantize_type in support_quantize_types, " \
quantize_type {} is not supported, now supported quantize type \
assert quantize_type in support_quantize_types, " \ are {}.".format(quantize_type, support_quantize_types)
quantize_type {} is not supported, now supported quantize type \
are {}.".format(quantize_type, support_quantize_types) quantize_bits = old_config['quantize_bits']
old_config['quantize_type'] = quantize_type assert isinstance(quantize_bits, int), "quantize_bits must be int"
assert quantize_bits in support_quantize_bits, " quantize_bits {} \
if 'quantize_bits' in keys:
quantize_bits = new_config['quantize_bits']
assert isinstance(quantize_bits, int), "quantize_bits must be int"
assert quantize_bits in support_quantize_bits, " quantize_bits {} \
is not supported, now supported quantize bits are \ is not supported, now supported quantize bits are \
{}. ".format(quantize_bits, support_quantize_bits) {}. ".format(quantize_bits, support_quantize_bits)
old_config['quantize_bits'] = quantize_bits
if 'dtype' in keys:
dtype = new_config['dtype']
assert isinstance(dtype, str), "dtype must be str"
assert dtype in support_dtype, " dtype {} is not \
supported, now supported dtypes are {} \
".format(dtype, support_dtype)
old_config['dtype'] = dtype
dtype = old_config['dtype']
assert isinstance(dtype, str), "dtype must be str"
assert dtype in support_dtype, " dtype {} is not \
supported, now supported dtypes are {} \
".format(dtype, support_dtype)
if 'threshold' in keys: if 'threshold' in keys:
old_config['threshold'] = new_config['threshold'] assert isinstance(new_config['threshold'], (float, int)), "threshold \
must be number."
print("quant_embedding config {}".format(old_config)) print("quant_embedding config {}".format(old_config))
return old_config return old_config
...@@ -97,18 +89,15 @@ def _get_var_tensor(scope, var_name): ...@@ -97,18 +89,15 @@ def _get_var_tensor(scope, var_name):
return np.array(scope.find_var(var_name).get_tensor()) return np.array(scope.find_var(var_name).get_tensor())
def _clip_tensor(tensor_array, config): def _clip_tensor(tensor_array, threshold):
""" """
when 'threshold' is set, clip tensor by 'threshold' and '-threshold' when 'threshold' is set, clip tensor by 'threshold' and '-threshold'
Args: Args:
tensor_array(np.array): array to clip tensor_array(np.array): array to clip
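config(dict): config dict threshold(int or float): clipping bound; values above 'threshold' are set to 'threshold' and values below '-threshold' are set to '-threshold'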
""" """
if 'threshold' in config.keys(): tensor_array[tensor_array > threshold] = threshold
threshold = config['threshold'] tensor_array[tensor_array < -threshold] = -threshold
assert isinstance(threshold, (int, float)), "threshold must be number"
tensor_array[tensor_array > threshold] = threshold
tensor_array[tensor_array < threshold] = -threshold
return tensor_array return tensor_array
...@@ -168,7 +157,7 @@ def _quant_embedding_abs_max(graph, scope, place, config): ...@@ -168,7 +157,7 @@ def _quant_embedding_abs_max(graph, scope, place, config):
scale = np.max(np.abs(tensor_array)).astype("float32") scale = np.max(np.abs(tensor_array)).astype("float32")
quanted_tensor = np.round(tensor_array / scale * ( quanted_tensor = np.round(tensor_array / scale * (
(1 << (bit_length - 1)) - 1)) (1 << (bit_length - 1)) - 1))
return scale, quanted_tensor.astype(np.int8) return scale, quanted_tensor.astype(config['dtype'])
def _insert_dequant_abs_max_op(graph, scope, var_node, scale_node, config): def _insert_dequant_abs_max_op(graph, scope, var_node, scale_node, config):
""" """
...@@ -205,7 +194,8 @@ def _quant_embedding_abs_max(graph, scope, place, config): ...@@ -205,7 +194,8 @@ def _quant_embedding_abs_max(graph, scope, place, config):
# find embedding var node by 'params_name' # find embedding var node by 'params_name'
embedding_node = graph._find_node_by_name(all_var_nodes, var_name) embedding_node = graph._find_node_by_name(all_var_nodes, var_name)
embedding_tensor = _get_var_tensor(scope, var_name) embedding_tensor = _get_var_tensor(scope, var_name)
embedding_tensor = _clip_tensor(embedding_tensor, config) if 'threshold' in config.keys():
embedding_tensor = _clip_tensor(embedding_tensor, config['threshold'])
# get scale and quanted tensor # get scale and quanted tensor
scale, quanted_tensor = _quant_abs_max(embedding_tensor, config) scale, quanted_tensor = _quant_abs_max(embedding_tensor, config)
...@@ -242,6 +232,21 @@ def _quant_embedding_abs_max(graph, scope, place, config): ...@@ -242,6 +232,21 @@ def _quant_embedding_abs_max(graph, scope, place, config):
def quant_embedding(program, scope, place, config): def quant_embedding(program, scope, place, config):
"""
Quantize the parameters of lookup_table ops.
Args:
program(fluid.Program): the inference program.
scope(fluid.Scope): the scope that stores the variables, usually fluid.global_scope().
place(fluid.CPUPlace or fluid.CUDAPlace): the place (CPU or CUDA) to use.
config(dict): quantization config. The keys are 'params_name', 'quantize_type', \
'quantize_bits', 'dtype', 'threshold'. \
'params_name': name of the parameter to quantize; must be set.
'quantize_type': quantization type; supported types are ['abs_max']. Default is 'abs_max'.
'quantize_bits': number of quantization bits; supported values are [8]. Default is 8.
'dtype': dtype of the quantized tensor; supported dtypes are ['int8']. Default is 'int8'.
'threshold': threshold used to clip the tensor before quantization. When 'threshold' is not set, \
the tensor is not clipped.
"""
assert isinstance(config, dict), "config must be dict" assert isinstance(config, dict), "config must be dict"
config = _merge_config(copy.deepcopy(default_config), config) config = _merge_config(copy.deepcopy(default_config), config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册