# test_weight_quantization_mobilenetv1.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tarfile
import time
import unittest

from paddle.dataset.common import download, DATA_HOME
from paddle.fluid.contrib.slim.quantization import WeightQuantization


class TestWeightQuantization(unittest.TestCase):
    """Shared helpers for weight-quantization tests.

    Subclasses provide a model archive (name, URL, md5) and call
    ``run_test``. This class downloads the archive, unpacks it once into a
    cache folder under ``DATA_HOME``, and runs ``WeightQuantization`` on
    the unpacked model.
    """

    def setUp(self):
        """Compute the cache folder that holds downloaded model archives."""
        self.weight_quantization_dir = 'weight_quantization'
        self.cache_folder = os.path.join(DATA_HOME,
                                         self.weight_quantization_dir)

    def download_model(self, model_name, data_url, data_md5):
        """Download the archive at ``data_url`` and unpack it if needed.

        Args:
            model_name: Name of the sub-folder of ``self.cache_folder`` that
                receives the unpacked archive.
            data_url: URL of the model archive. Its last path component is
                taken as the local file name inside the cache folder.
            data_md5: Expected md5 of the archive, forwarded to ``download``.

        Returns:
            Path of the folder that contains the unpacked archive.
        """
        download(data_url, self.weight_quantization_dir, data_md5)
        file_name = data_url.split('/')[-1]
        file_path = os.path.join(self.cache_folder, file_name)
        print(model_name + ' is downloaded at ' + file_path)

        unziped_path = os.path.join(self.cache_folder, model_name)
        self.cache_unzipping(unziped_path, file_path)
        print(model_name + ' is unziped at ' + unziped_path)
        return unziped_path

    def cache_unzipping(self, target_folder, zip_path):
        """Extract the tar archive ``zip_path`` into ``target_folder`` once.

        Nothing is done when ``target_folder`` already exists; its presence
        is treated as "already extracted".

        Args:
            target_folder: Folder to create and extract into.
            zip_path: Path of the (optionally compressed) tar archive.

        Raises:
            tarfile.TarError, OSError: If the archive cannot be read or
                extracted. ``target_folder`` is removed before re-raising.
        """
        if os.path.exists(target_folder):
            return

        os.makedirs(target_folder)
        try:
            # Mode 'r' auto-detects the compression, like ``tar xf`` did.
            with tarfile.open(zip_path) as archive:
                archive.extractall(target_folder)
        except Exception:
            # A half-populated folder would be mistaken for a valid cache
            # entry on the next run, so remove it before propagating.
            shutil.rmtree(target_folder, ignore_errors=True)
            raise

    def run_test(self, model_name, model_data_url, model_data_md5, weight_bits,
                 quantizable_op_type, threshold_rate):
        """Quantize the weights of a downloaded model and clean up.

        Args:
            model_name: Name used for the cache sub-folder and as the prefix
                of the temporary output folder.
            model_data_url: URL of the model archive.
            model_data_md5: Expected md5 of the model archive.
            weight_bits: Forwarded to ``quantize_weight_to_int``.
            quantizable_op_type: Forwarded to ``quantize_weight_to_int``.
            threshold_rate: Forwarded to ``quantize_weight_to_int``.
        """
        model_dir = self.download_model(model_name, model_data_url,
                                        model_data_md5)

        # The timestamp keeps output folders of separate runs apart.
        timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
        save_model_dir = os.path.join(
            os.getcwd(),
            model_name + "_wq_" + str(weight_bits) + "_" + timestamp)

        try:
            weight_quant = WeightQuantization(
                model_dir=os.path.join(model_dir, "model"))
            weight_quant.quantize_weight_to_int(
                save_model_dir=save_model_dir,
                weight_bits=weight_bits,
                quantizable_op_type=quantizable_op_type,
                threshold_rate=threshold_rate)
            print("finish weight quantization for " + model_name + "\n")
        finally:
            # Remove the output folder even when quantization raised, so a
            # failing run does not leave artifacts in the working directory.
            if os.path.isdir(save_model_dir):
                try:
                    shutil.rmtree(save_model_dir)
                except OSError as e:
                    print("Failed to delete {} due to {}".format(
                        save_model_dir, str(e)))


class TestWeightQuantizationMobilenetv1(TestWeightQuantization):
    """Runs weight quantization on the MobileNetV1 int8 model archive."""

    model_name = "mobilenetv1"
    model_data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz"
    model_data_md5 = "13892b0716d26443a8cdea15b3c6438b"

    def test_weight_quantization_mobilenetv1_8bit(self):
        """Quantize with weight_bits=8 and threshold_rate=0.0."""
        self.run_test(
            model_name=self.model_name,
            model_data_url=self.model_data_url,
            model_data_md5=self.model_data_md5,
            weight_bits=8,
            quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
            threshold_rate=0.0)

    def test_weight_quantization_mobilenetv1_16bit(self):
        """Quantize with weight_bits=16 and threshold_rate=1e-9."""
        self.run_test(
            model_name=self.model_name,
            model_data_url=self.model_data_url,
            model_data_md5=self.model_data_md5,
            weight_bits=16,
            quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
            threshold_rate=1e-9)


# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()