Commit ed7e33e7 authored by itminner

add some parameter checks

Parent: cd738054
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from quanter import quant_aware, quant_post, convert
+from .quanter import quant_aware, quant_post, convert
 from .quant_embedding import quant_embedding
@@ -22,7 +22,9 @@ from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
 from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
 from paddle.fluid import core
-QUANTIZATION_TYPES=['abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max']
+WEIGHT_QUANTIZATION_TYPES = ['abs_max', 'channel_wise_abs_max']
+ACTIVATION_QUANTIZATION_TYPES = ['abs_max', 'range_abs_max', 'moving_average_abs_max']
+VALID_DTYPES = ['int8']
 quant_config_default = {
     # weight quantize type, default is 'abs_max'
@@ -61,19 +63,23 @@ def _parse_configs(user_config):
     configs.update(user_config)
     # check configs is valid
-    assert configs['weight_quantize_type'] in QUANTIZATION_TYPES, \
-        "Unknown weight_quantize_type: '%s'. It can only be " \
-        "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
+    assert configs['weight_quantize_type'] in WEIGHT_QUANTIZATION_TYPES, \
+        "Unknown weight_quantize_type: '%s'. It can only be %s." \
+        % (configs['weight_quantize_type'], " ".join(WEIGHT_QUANTIZATION_TYPES))
-    assert configs['activation_quantize_type'] in QUANTIZATION_TYPES, \
-        "Unknown activation_quantize_type: '%s'. It can only be " \
-        "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
+    assert configs['activation_quantize_type'] in ACTIVATION_QUANTIZATION_TYPES, \
+        "Unknown activation_quantize_type: '%s'. It can only be %s." \
+        % (configs['activation_quantize_type'], " ".join(ACTIVATION_QUANTIZATION_TYPES))
     assert isinstance(configs['weight_bits'], int), \
-        "weight_bits must be int value, such as 8, 16, 32, etc"
+        "weight_bits must be int value."
+    assert configs['weight_bits'] >= 1 and configs['weight_bits'] <= 16, \
+        "weight_bits should be between 1 and 16."
     assert isinstance(configs['activation_bits'], int), \
-        "activation_bits must be int value, such as 8, 16, 32, etc"
+        "activation_bits must be int value."
+    assert configs['activation_bits'] >= 1 and configs['activation_bits'] <= 16, \
+        "activation_bits should be between 1 and 16."
     assert isinstance(configs['not_quant_pattern'], list), \
         "not_quant_pattern must be a list"
@@ -82,7 +88,10 @@ def _parse_configs(user_config):
         "quantize_op_types must be a list"
     assert isinstance(configs['dtype'], str), \
-        "dtype must be a str, it can be config as 'int8', 'uint8', 'int16', etc."
+        "dtype must be a str."
+    assert configs['dtype'] in VALID_DTYPES, \
+        "dtype can only be " + " ".join(VALID_DTYPES)
     assert isinstance(configs['window_size'], int), \
         "window_size must be int value, window size for 'range_abs_max' quantization, default is 10000."
@@ -104,10 +113,10 @@ def quant_aware(program, scope, place, config, for_test=False):
     add trainable quantization ops in program.
     Args:
         program(fluid.Program): program
-        scope(fluid.Scope): the scope to store var, when is None will use fluid.global_scope()
+        scope(fluid.Scope): the scope that stores the program's variables, usually fluid.global_scope().
         place(fluid.CPUPlace or fluid.CUDAPlace): place
         config(dict): configs for quantization, default values are in quant_config_default dict.
-        for_test: is for test program.
+        for_test(bool): True if program is a test program, else False.
     Return:
         fluid.Program: user can finetune this quantization program to enhance the accuracy.
     """
@@ -122,17 +131,18 @@ def quant_aware(program, scope, place, config, for_test=False):
     main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
     transform_pass = QuantizationTransformPass(
-        scope=scope, place=place,
+        scope=scope,
+        place=place,
         weight_bits=config['weight_bits'],
         activation_bits=config['activation_bits'],
         activation_quantize_type=config['activation_quantize_type'],
         weight_quantize_type=config['weight_quantize_type'],
         window_size=config['window_size'],
         moving_rate=config['moving_rate'],
+        quantizable_op_type=config['quantize_op_types'],
         skip_pattern=''  # not_quant_pattern
     )
     transform_pass.apply(main_graph)
     if for_test:
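For context, a minimal usage sketch of quant_aware. This is not part of the commit: build_net is a toy network written here for illustration, and the import path assumes the module is packaged as paddleslim.quant.

import paddle.fluid as fluid
from paddleslim.quant import quant_aware

def build_net():
    # toy float32 network: one fc layer plus a cross-entropy loss
    img = fluid.layers.data(name='img', shape=[784], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    out = fluid.layers.fc(input=img, size=10, act='softmax')
    return fluid.layers.mean(fluid.layers.cross_entropy(input=out, label=label))

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    loss = build_net()

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)  # initialize parameters into the global scope

# partial config; _parse_configs fills the remaining keys from quant_config_default
config = {'weight_quantize_type': 'abs_max',
          'activation_quantize_type': 'moving_average_abs_max',
          'weight_bits': 8,
          'activation_bits': 8}
# inserts fake quantize/dequantize ops; the returned program is still trainable
quant_prog = quant_aware(main_prog, fluid.global_scope(), place, config, for_test=False)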
@@ -146,7 +156,7 @@ def quant_post(program, scope, place, config):
     add quantization ops in program. the program returned is not trainable.
     Args:
         program(fluid.Program): program
-        scope(fluid.Scope): the scope to store var, when is None will use fluid.global_scope()
+        scope(fluid.Scope): the scope that stores the program's variables, usually fluid.global_scope().
         place(fluid.CPUPlace or fluid.CUDAPlace): place
         config(dict): configs for quantization, default values are in quant_config_default dict.
         for_test: is for test program.
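The post-training entry point takes the same arguments minus for_test; continuing the sketch above (illustrative only, same assumed import path):

from paddleslim.quant import quant_post

# adds quantization ops without the trainable fake-quant machinery;
# per the docstring, the returned program is for inference, not finetuning
post_prog = quant_post(main_prog, fluid.global_scope(), place, config)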
@@ -186,6 +196,7 @@ def convert(program, scope, place, config, save_int8=False):
         fluid.Program: frozen program which can be used for inference.
                        parameters are float32 type, but their values are in the int8 range.
         fluid.Program: frozen int8 program which can be used for inference.
+                       if save_int8 is False, this value is None.
     """
     test_graph = IrGraph(core.Graph(program.desc), for_test=True)
...
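Finally, a sketch of the freeze step this hunk documents (illustrative only; it assumes quant_prog is the test-mode counterpart of the program trained in the quant_aware sketch, since convert builds its graph with for_test=True):

from paddleslim.quant import convert

# freeze for inference; with save_int8=True both programs come back
float_prog, int8_prog = convert(quant_prog, fluid.global_scope(), place,
                                config, save_int8=True)
# float_prog: float32 parameters whose values lie in the int8 range
# int8_prog: int8 weights for deployment; None when save_int8=False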