# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import unittest
import tempfile
sys.path.append("../../")
import paddle
from paddle.vision.models import resnet18
from paddleslim.quant import SlimQuantConfig as QuantConfig
from paddleslim.quant import SlimQAT
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
from paddleslim.quant.nn.conv_bn import QuantedConv2DBatchNorm
from paddleslim.quant.constraints import FreezedConvBNConstraint
from test_qat import TestQuantAwareTraining, load_model_and_count_layer


class TestConvBNConstraintsBaseCase(TestQuantAwareTraining):
    """ Common cases for testing 'quantize', 'convert' and 'jit.save' function."""

    def extra_qconfig(self, qconfig):
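        # Fuse each Conv2D + BatchNorm pair and freeze the batch norm
        # statistics after a single training step (freeze_bn_delay=1).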
        qconfig.add_constraints(FreezedConvBNConstraint(freeze_bn_delay=1))


class TestConvBNConstraints(unittest.TestCase):
    """ More special cases on convolution and batch norm constraints."""

    def setUp(self):
        paddle.set_device("cpu")
        self.temp_dir = tempfile.TemporaryDirectory(dir="./")
        self.path = os.path.join(self.temp_dir.name, 'conv_bn_constraints')

    def tearDown(self):
        self.temp_dir.cleanup()

    def _count_layers(self, model, layer_type):
        count = 0
        for _layer in model.sublayers(True):
            if isinstance(_layer, layer_type):
                count += 1
        return count

    def _get_one_layer(self, model, layer_type):
        for _layer in model.sublayers(True):
            if isinstance(_layer, layer_type):
                return _layer
        return None

    def test_conv_bn(self):
        model = resnet18()
        conv_count = self._count_layers(model, paddle.nn.Conv2D)
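        # Quantize both activations and weights with a moving-average
        # abs-max fake quanter.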
        quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
        q_config = QuantConfig(activation=quanter, weight=quanter)
        # Freeze the batch normalization statistics after 'freeze_bn_delay' training steps.
        q_config.add_constraints(FreezedConvBNConstraint(freeze_bn_delay=1))

        qat = SlimQAT(q_config)
        x = paddle.rand([1, 3, 224, 224])
        paddle.jit.save(model, os.path.join(self.path, 'test_model'), input_spec=[x])
        quant_model = qat.quantize(model, inplace=True, inputs=x)
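        # 'quantize' rewrites the model in place, replacing each Conv2D
        # followed by a BatchNorm with a single QuantedConv2DBatchNorm layer.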

        # check freeze_bn_delay
        qat_conv_bn_layer = self._get_one_layer(quant_model,
                                                QuantedConv2DBatchNorm)
        self.assertIsNotNone(qat_conv_bn_layer)
        self.assertFalse(qat_conv_bn_layer._freeze_bn)
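        # With freeze_bn_delay=1 the flag should flip to True once the
        # delayed step has elapsed, i.e. during the second training pass.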
        quant_model.train()
        out = quant_model(x)
        out.backward()
        out = quant_model(x)
        out.backward()
        self.assertTrue(qat_conv_bn_layer._freeze_bn)

        # check the count of QAT layers in QAT model
        qat_layer_count = self._count_layers(quant_model,
                                             QuantedConv2DBatchNorm)
        self.assertEqual(qat_layer_count, conv_count)

        # check the count of convolution and batch norm in saved static graph
        quant_model.eval()
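        # 'convert' folds the frozen batch norm parameters into the convolution
        # weights, so the exported graph should contain conv2d ops only.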
        infer_model = qat.convert(quant_model, inplace=True)
        save_path = os.path.join(self.path, 'infer_model')
        paddle.jit.save(infer_model, save_path, input_spec=[x])
        layer2count = load_model_and_count_layer(save_path,
                                                 ['conv2d', 'batch_norm'])
        self.assertEqual(layer2count['conv2d'], conv_count)
        self.assertEqual(layer2count['batch_norm'], 0)


if __name__ == '__main__':
    unittest.main()