#   copyright (c) 2019 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
#     http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.

import unittest
import os
import sys
import argparse
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass
from paddle.fluid import core

paddle.enable_static()


def parse_args():
    """Parse the command-line options of the Quant -> INT8 conversion script.

    Returns:
        tuple: ``(test_args, remaining_argv)`` where ``test_args`` carries the
        recognized options and ``remaining_argv`` is ``sys.argv[:1]`` plus any
        unrecognized arguments, suitable for forwarding to the unittest runner.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--quant_model_path',
        type=str,
        default='',
        help='A path to a Quant model.')
    parser.add_argument(
        '--int8_model_save_path',
        type=str,
        default='',
        help='Saved optimized and quantized INT8 model')
    parser.add_argument(
        '--ops_to_quantize',
        type=str,
        default='',
        help='A comma separated list of operators to quantize. Only quantizable operators are taken into account. If the option is not used, an attempt to quantize all quantizable operators will be made.'
    )
    parser.add_argument(
        '--op_ids_to_skip',
        type=str,
        default='',
        help='A comma separated list of operator ids to skip in quantization.')
    parser.add_argument(
        '--debug',
        action='store_true',
        help='If used, the graph of Quant model is drawn.')

    # NOTE: the parsed options are stored as attributes on the ``unittest``
    # module object (used as the argparse namespace) so the surrounding test
    # infrastructure can see them; unknown args are returned for unittest.
    test_args, args = parser.parse_known_args(namespace=unittest)
    return test_args, sys.argv[:1] + args


def transform_and_save_int8_model(original_path,
                                  save_path,
                                  ops_to_quantize='',
                                  op_ids_to_skip='',
                                  debug=False):
    """Load a Quant model, run Quant2Int8MkldnnPass and save the INT8 model.

    Args:
        original_path (str): directory holding the Quant model, either in the
            combined ``__model__`` format or as separate 'model'/'params'
            files.
        save_path (str): directory to write the optimized INT8 model to.
        ops_to_quantize (str): comma-separated operator names to quantize; an
            empty string means "try to quantize all quantizable operators".
        op_ids_to_skip (str): comma-separated operator ids excluded from
            quantization; empty means "skip none".
        debug (bool): when True, draw the original graph to the current
            directory before transformation.
    """
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    inference_scope = fluid.executor.global_scope()
    with fluid.scope_guard(inference_scope):
        # Support both the combined and the separated persistence formats.
        if os.path.exists(os.path.join(original_path, '__model__')):
            [inference_program, feed_target_names,
             fetch_targets] = fluid.io.load_inference_model(original_path, exe)
        else:
            [inference_program, feed_target_names,
             fetch_targets] = fluid.io.load_inference_model(original_path, exe,
                                                            'model', 'params')

        ops_to_quantize_set = set()
        if len(ops_to_quantize) > 0:
            ops_to_quantize_set = set(ops_to_quantize.split(','))

        # -1 is a sentinel id that never matches a real operator, so the
        # default set effectively skips nothing.
        op_ids_to_skip_set = set([-1])
        if len(op_ids_to_skip) > 0:
            op_ids_to_skip_set = set(map(int, op_ids_to_skip.split(',')))

        graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
        if debug:
            graph.draw('.', 'quant_orig', graph.all_op_nodes())
        transform_to_mkldnn_int8_pass = Quant2Int8MkldnnPass(
            ops_to_quantize_set,
            _op_ids_to_skip=op_ids_to_skip_set,
            _scope=inference_scope,
            _place=place,
            _core=core,
            _debug=debug)
        graph = transform_to_mkldnn_int8_pass.apply(graph)
        inference_program = graph.to_program()
        # Already inside the scope guard above; the nested guard present in
        # the original was redundant.
        fluid.io.save_inference_model(save_path, feed_target_names,
                                      fetch_targets, exe, inference_program)
        print(
            "Success! INT8 model obtained from the Quant model can be found at {}\n"
            .format(save_path))


if __name__ == '__main__':
    # parse_args() stores options on the ``unittest`` module namespace; the
    # ``global test_args`` declaration of the original was a no-op at module
    # scope and has been dropped.
    test_args, remaining_args = parse_args()
    transform_and_save_int8_model(
        test_args.quant_model_path, test_args.int8_model_save_path,
        test_args.ops_to_quantize, test_args.op_ids_to_skip, test_args.debug)