# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import itertools
import numpy as np
import math
from op_test import OpTest


def round_c_single_element(val):
    """Round ``val`` half away from zero, like the C/C++ ``round`` function.

    ``numpy.round`` rounds exact .5 ties to the nearest even value, which
    differs from C++ ``round``. This helper is used instead so the reference
    data matches the C++ kernels. The result is cast back to ``type(val)``.
    """
    cast = type(val)
    rounded = np.floor(val + 0.5) if val >= 0 else np.ceil(val - 0.5)
    return cast(rounded)


# Element-wise (array) version of ``round_c_single_element``. Without
# ``otypes``, numpy.vectorize picks the output dtype from the result of the
# first element.
round_c = np.vectorize(pyfunc=round_c_single_element)


def get_compute_type(dtype):
    """Return the dtype in which reference values for ``dtype`` are computed.

    float16 inputs are promoted to float32. float32 and float64 are returned
    unchanged. Any other dtype fails the assertion.
    """
    assert dtype in [np.float16, np.float32, np.float64]
    return np.float32 if dtype == np.float16 else dtype


class TestFakeQuantizeAbsMaxOp(OpTest):
    """Checks ``fake_quantize_abs_max`` against a NumPy reference."""

    def setUp(self):
        self.op_type = 'fake_quantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_quantize_abs_max(self, dtype, input_shape, distribution):
        """Build inputs/expected outputs for one configuration and check them.

        Args:
            dtype: numpy dtype of the input tensor.
            input_shape: shape passed to ``distribution``.
            distribution: callable ``shape -> ndarray`` producing input data.
        """
        x = distribution(input_shape).astype(dtype)
        x_compute = x.astype(get_compute_type(dtype))
        scale = np.max(np.abs(x))
        max_bin = (1 << (self.attrs['bit_length'] - 1)) - 1
        # A (near-)zero scale, e.g. all-zero or 1e-40 inputs, would blow up
        # the reciprocal, so the denominator is nudged by 1e-6 in that case.
        if scale < 1e-30:
            inv_scale = 1.0 / (scale + 1e-6)
        else:
            inv_scale = 1.0 / scale
        self.dtype = dtype
        self.inputs = {'X': x}
        self.outputs = {
            'Out': round_c(x_compute * inv_scale * max_bin),
            'OutScale': scale
        }
        self.check_output()

    def test_fake_quantize_abs_max(self):
        self._fake_quantize_abs_max(np.float32, (124, 240), np.random.random)

    def test_fake_quantize_abs_max_float16(self):
        self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random)

    def test_fake_quantize_abs_max_underflow(self):
        self._fake_quantize_abs_max(np.float32, (10, 10), np.zeros)

    def test_fake_quantize_abs_max_underflow2(self):
        tiny_values = lambda shape: np.full(shape, 1e-40)
        self._fake_quantize_abs_max(np.float32, (10, 10), tiny_values)


class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
    """Reference test for ``fake_channel_wise_quantize_abs_max``."""

    def setUp(self):
        self.op_type = 'fake_channel_wise_quantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_channel_wise_quantize_abs_max(self, dtype, input_shape,
                                            quant_axis, distribution):
        """Check per-channel abs-max quantization along ``quant_axis``.

        Args:
            dtype: numpy dtype of the input tensor.
            input_shape: shape passed to ``distribution``.
            quant_axis: channel axis, 0 or 1.
            distribution: callable ``shape -> ndarray`` producing input data.
        """
        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        compute_axis = tuple(
            i for i in range(len(input_shape)) if i != quant_axis)
        # Per-channel scale = max |x| over every axis except quant_axis.
        # np.abs is required: without it, negative inputs give a wrong scale.
        scale_broadcast = np.amax(
            np.abs(input_data), axis=compute_axis, keepdims=True)
        output_data = round_c(bnt * input_data.astype(compute_type) /
                              scale_broadcast)
        # keepdims leaves quant_axis as the only non-unit dimension, so
        # flattening yields the per-channel scales in channel order.
        scale = scale_broadcast.reshape(-1)
        self.inputs = {'X': input_data}
        self.outputs = {'Out': output_data, 'OutScale': scale}
        self.dtype = dtype
        self.attrs['quant_axis'] = quant_axis
        self.check_output()

    def test_fake_channel_wise_quantize_abs_max(self):
        dtype_options = [np.float32, np.float16]
        input_shape_quant_axis_options = [[(20, 15, 6, 6), 0],
                                          [(15, 20, 5, 5), 1], [(30, 15), 0],
                                          [(30, 15), 1]]
        for dtype, input_shape_quant_axis in itertools.product(
                dtype_options, input_shape_quant_axis_options):
            input_shape, quant_axis = input_shape_quant_axis
            with self.subTest(
                    dtype=dtype, input_shape=input_shape,
                    quant_axis=quant_axis):
                self._fake_channel_wise_quantize_abs_max(
                    dtype, input_shape, quant_axis, np.random.random)


class TestFakeQuantizeRangeAbsMaxOp(OpTest):
    """Reference test for the ``fake_quantize_range_abs_max`` operator."""

    def setUp(self):
        self.op_type = 'fake_quantize_range_abs_max'
        self.attrs = {'bit_length': 5, 'window_size': 1}

    def _fake_quantize_range_abs_max(self,
                                     dtype,
                                     input_shape,
                                     distribution,
                                     is_test=False):
        """Check one (dtype, is_test) configuration of the operator.

        With ``is_test=False`` the scale is the abs-max of the input. With
        ``is_test=True``, ``InScale`` and the first ``OutScales`` entry are
        both set one below that abs-max, and the input is clipped to
        ``[-InScale, InScale]`` before quantization.
        """
        x = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        max_bin = (1 << (self.attrs['bit_length'] - 1)) - 1
        in_scale = np.zeros(1).astype(dtype)
        out_scales = np.zeros(self.attrs['window_size']).astype(dtype)
        out_scales[0] = np.max(np.abs(x))
        clip_data = x
        if is_test:
            reduced_scale = out_scales[0] - 1.0
            out_scales[0] = reduced_scale
            in_scale[0] = reduced_scale
            clip_data = np.clip(x, -in_scale, in_scale)
        quantized = round_c(
            clip_data.astype(compute_type) / out_scales[0] * max_bin)
        self.inputs = {
            'X': x,
            'Iter': np.zeros(1).astype(np.int64),
            'InScale': in_scale
        }
        self.outputs = {
            'Out': quantized,
            'OutScale': out_scales[0],
            'OutScales': out_scales
        }
        self.dtype = dtype
        self.attrs['is_test'] = is_test
        self.check_output()

    def test_fake_quantize_range_abs_max(self):
        centered_data = lambda shape: (np.random.random(shape) - 0.5) * 10
        for dtype in [np.float32, np.float16]:
            for is_test in [False, True]:
                # Training mode uses 5 bits, test mode uses 8 bits.
                self.attrs['bit_length'] = 8 if is_test else 5
                with self.subTest(dtype=dtype, is_test=is_test):
                    self._fake_quantize_range_abs_max(
                        dtype, (8, 16, 7, 7), centered_data, is_test=is_test)


class TestMovingAverageAbsMaxScaleOp(OpTest):
    """Reference test for ``moving_average_abs_max_scale``."""

    def setUp(self):
        self.op_type = 'moving_average_abs_max_scale'
        self.attrs = {'moving_rate': float(0.9), 'is_test': False}

    def _moving_average_abs_max_scale(self, dtype, input_shape, distribution):
        """Check the moving-average scale update for one configuration.

        ``Out`` is expected to equal the input unchanged. The accumulators
        are decayed by ``moving_rate``. ``OutAccum`` then adds the input's
        abs-max, ``OutState`` adds 1, and ``OutScale`` is their ratio.
        """
        x = distribution(input_shape).astype(dtype)
        accum_in = np.ones(1).astype(dtype)
        state_in = np.ones(1).astype(dtype)
        rate = self.attrs['moving_rate']
        accum_out = rate * accum_in[0] + np.max(np.abs(x))
        state_out = rate * state_in[0] + 1.0
        self.inputs = {'X': x, 'InAccum': accum_in, 'InState': state_in}
        self.outputs = {
            'Out': x,
            'OutAccum': accum_out,
            'OutState': state_out,
            'OutScale': accum_out / state_out
        }
        self.dtype = dtype
        self.check_output()

    def test_moving_average_abs_max(self):
        self._moving_average_abs_max_scale(
            np.float32, (8, 16, 7, 7), np.random.random)


class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
    """Reference tests for the moving-average abs-max fake-quantize ops.

    Covers ``fake_quantize_moving_average_abs_max`` and, when
    ``dequantize=True``, ``fake_quantize_dequantize_moving_average_abs_max``.
    """

    def setUp(self):
        self.op_type = 'fake_quantize_moving_average_abs_max'
        self.attrs = {'bit_length': 5, 'moving_rate': 0.9, 'is_test': False}

    def _fake_quantize_moving_average_abs_max(self,
                                              dtype,
                                              input_shape,
                                              distribution,
                                              dequantize=False,
                                              with_gradient=False):
        """Build reference inputs/outputs for one configuration and check them.

        Args:
            dtype: numpy dtype of the input tensor.
            input_shape: shape passed to ``distribution``.
            distribution: callable ``shape -> ndarray`` producing the input.
            dequantize: if True, switch ``self.op_type`` to the
                quantize-dequantize variant and expect de-quantized output.
            with_gradient: if True, also check the gradient of ``Out`` with
                respect to ``X`` against a uniform ``1 / numel`` gradient.
        """
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        in_accum = np.ones(1).astype(dtype)
        in_state = np.ones(1).astype(dtype)
        in_scale = np.array([0.001]).astype(dtype)
        out_accum = np.zeros(1).astype(dtype)
        out_state = np.zeros(1).astype(dtype)
        # Moving-average update: decay the previous accumulators by
        # moving_rate, then add this batch's abs-max (accum) and 1 (state).
        out_accum[0] = self.attrs['moving_rate'] * in_accum[0] + np.max(
            np.abs(input_data))
        out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0
        out_scale = out_accum / out_state
        round_data = round_c(input_data.astype(compute_type) / out_scale * bnt)
        if dequantize:
            output_data = (round_data * out_scale / bnt).astype(dtype)
            self.op_type = 'fake_quantize_dequantize_moving_average_abs_max'
        else:
            output_data = round_data.astype(dtype)
        self.inputs = {
            'X': input_data,
            'InScale': in_scale,
            'InAccum': in_accum,
            'InState': in_state
        }
        self.outputs = {
            'Out': output_data,
            'OutAccum': out_accum,
            'OutState': out_state,
            'OutScale': out_scale
        }
        self.dtype = dtype
        self.check_output()
        if with_gradient:
            # np.prod: the np.product alias was removed in NumPy 2.0.
            gradient = [np.ones(input_data.shape) / np.prod(input_data.shape)]
            self.check_grad(['X'], 'Out', user_defined_grads=gradient)

    def test_fake_quantize_moving_average_abs_max(self):
        self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7),
                                                   np.random.random)

    def test_fake_quantize_moving_average_abs_max_float16(self):
        self._fake_quantize_moving_average_abs_max(np.float16, (8, 16, 7, 7),
                                                   np.random.random)

    def test_fake_quantize_dequantize_moving_average_abs_max(self):
        self._fake_quantize_moving_average_abs_max(
            np.float32, (8, 16, 7, 7),
            np.random.random,
            dequantize=True,
            with_gradient=True)


class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
    """Reference test for ``fake_quantize_dequantize_abs_max``."""

    def setUp(self):
        self.op_type = 'fake_quantize_dequantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_quantize_dequantize_abs_max(self, dtype, input_shape,
                                          distribution):
        """Check quantize-then-dequantize output and its gradient.

        Args:
            dtype: numpy dtype of the input tensor.
            input_shape: shape passed to ``distribution``.
            distribution: callable ``shape -> ndarray`` producing input data.
        """
        input_data = distribution(input_shape).astype(dtype)
        # Compute in the same promoted type the sibling tests use; this is a
        # no-op for float32 inputs.
        compute_type = get_compute_type(dtype)
        scale = np.max(np.abs(input_data)).astype(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        output_data = round_c(
            input_data.astype(compute_type) / scale * bnt) * scale / bnt
        self.inputs = {'X': input_data}
        self.outputs = {
            'Out': output_data,
            'OutScale': np.array(scale).astype(dtype)
        }
        self.dtype = dtype
        self.check_output()
        # np.prod: the np.product alias was removed in NumPy 2.0.
        gradient = [np.ones(input_data.shape) / np.prod(input_data.shape)]
        self.check_grad(['X'], 'Out', user_defined_grads=gradient)

    def test_fake_quantize_dequantize_abs_max(self):
        self._fake_quantize_dequantize_abs_max(np.float32, (124, 240),
                                               np.random.random)


class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
    """Reference test for ``fake_channel_wise_quantize_dequantize_abs_max``."""

    def setUp(self):
        self.op_type = 'fake_channel_wise_quantize_dequantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_channel_wise_quantize_dequantize_abs_max(
            self, dtype, input_shape, quant_axis, distribution):
        """Check per-channel quantize-dequantize output and its gradient.

        Args:
            dtype: numpy dtype of the input tensor.
            input_shape: shape passed to ``distribution``.
            quant_axis: channel axis, 0 or 1.
            distribution: callable ``shape -> ndarray`` producing input data.
        """
        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        compute_axis = tuple(
            i for i in range(len(input_shape)) if i != quant_axis)
        # Per-channel scale = max |x| over every axis except quant_axis.
        # np.abs is required: without it, negative inputs give a wrong scale.
        scale_broadcast = np.amax(
            np.abs(input_data), axis=compute_axis, keepdims=True)
        # astype already returns a new array, so no extra copy is needed.
        output_data = round_c(bnt * input_data.astype(compute_type) /
                              scale_broadcast) * scale_broadcast / bnt
        # keepdims leaves quant_axis as the only non-unit dimension, so
        # flattening yields the per-channel scales in channel order.
        scale = scale_broadcast.reshape(-1)
        self.inputs = {'X': input_data}
        self.outputs = {'Out': output_data, 'OutScale': scale}
        self.dtype = dtype
        self.attrs['quant_axis'] = quant_axis
        self.check_output()
        # np.prod: the np.product alias was removed in NumPy 2.0.
        gradient = [np.ones(input_data.shape) / np.prod(input_data.shape)]
        self.check_grad(['X'], 'Out', user_defined_grads=gradient)

    def test_channel_wise_fake_quant_dequant_abs_max(self):
        input_shape_quant_axis_options = [[(3, 4, 64, 64), 0],
                                          [(15, 20, 5, 5), 1], [(30, 15), 0],
                                          [(30, 15), 1]]
        for input_shape, quant_axis in input_shape_quant_axis_options:
            with self.subTest(input_shape=input_shape, quant_axis=quant_axis):
                self._fake_channel_wise_quantize_dequantize_abs_max(
                    np.float32, input_shape, quant_axis, np.random.random)


def quantize_max_abs(x, max_range):
    """Quantize ``x`` with a single abs-max scale.

    Returns ``(y, scale)``. ``scale`` is ``max(|x|)`` and
    ``y = np.round(x / scale * max_range)``. Note that ``np.round`` rounds
    .5 ties to even.
    """
    abs_max = np.abs(x).max()
    quantized = np.round(x / abs_max * max_range)
    return quantized, abs_max


def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0):
    """Quantize ``x`` per channel along ``quant_axis`` (0 or 1).

    Each channel slice is scaled by its own ``max(|slice|)`` (cast to
    float32) and rounded with ``np.round`` onto ``2**(quant_bit-1) - 1``
    levels. Returns ``(y, scales)``: ``y`` is a copy of ``x`` with the same
    dtype, and ``scales`` is a list with one scale per channel.
    """
    assert quant_axis in [0, 1], "The quant_axis should be 0 or 1."
    max_range = math.pow(2, quant_bit - 1) - 1
    y = x.copy()
    scales = []
    for channel in range(x.shape[quant_axis]):
        # Index tuple selecting the channel-th slice along quant_axis.
        index = (channel, ) if quant_axis == 0 else (slice(None), channel)
        channel_data = x[index]
        channel_scale = np.max(np.abs(channel_data)).astype("float32")
        scales.append(channel_scale)
        y[index] = np.round(channel_data * max_range / channel_scale)
    return y, scales


class TestChannelWiseQuantizeOp(OpTest):
    """``quantize_linear`` with per-channel scales along ``quant_axis``."""

    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 0

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
        y_ref, channel_scales = channel_wise_quantize_max_abs(
            x, self.bit_length, self.quant_axis)
        scale = np.array(channel_scales).astype(self.data_type)
        # One all-zero int32 zero point per channel scale.
        self.inputs = {
            'X': x,
            'Scale': scale,
            'ZeroPoint': np.zeros(scale.shape, dtype="int32")
        }
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis
        }
        self.outputs = {'Y': y_ref}

    def test_check_output(self):
        self.check_output()


class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp):
    """Variant of TestChannelWiseQuantizeOp quantizing along axis 1."""

    def set_args(self):
        self.quant_axis = 1
        self.data_type = "float32"
        self.bit_length = 8


class TestChannelWiseQuantizeOpTrain(OpTest):
    """Per-channel ``quantize_linear`` with ``is_test=False``.

    In this configuration the op is also expected to emit ``OutScale``, equal
    to the input ``Scale``.
    """

    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 0
        self.is_test = False

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
        y_ref, channel_scales = channel_wise_quantize_max_abs(
            x, self.bit_length, self.quant_axis)
        scale = np.array(channel_scales).astype(self.data_type)
        self.inputs = {
            'X': x,
            'Scale': scale,
            'ZeroPoint': np.zeros(scale.shape, dtype="int32")
        }
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
            'is_test': self.is_test
        }
        self.outputs = {'Y': y_ref, 'OutScale': scale}

    def test_check_output(self):
        self.check_output()


class TestquantizeOp(OpTest):
    """``quantize_linear`` with ``quant_axis=-1`` and one scalar scale."""

    def set_args(self):
        self.bit_length = 8
        self.quant_axis = -1
        self.max_range = math.pow(2, self.bit_length - 1) - 1
        self.data_type = "float32"

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(31, 65).astype(self.data_type)
        y_ref, abs_max = quantize_max_abs(x, self.max_range)
        scale = np.array(abs_max).astype(self.data_type)
        self.inputs = {
            'X': x,
            'Scale': scale,
            'ZeroPoint': np.zeros(scale.shape, dtype="int32")
        }
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
        }
        self.outputs = {'Y': y_ref}

    def test_check_output(self):
        self.check_output()


class TestquantizeOpTrain(TestquantizeOp):
    """TestquantizeOp with ``is_test=False``, which also expects OutScale.

    Reuses the parent's ``setUp``, which calls this class's ``set_args``.
    It then adds only the training-specific attribute and output instead
    of duplicating the whole setup and ``test_check_output``.
    """

    def set_args(self):
        self.bit_length = 8
        self.quant_axis = -1
        self.max_range = math.pow(2, self.bit_length - 1) - 1
        self.data_type = "float32"
        self.is_test = False

    def setUp(self):
        super().setUp()
        self.attrs['is_test'] = self.is_test
        # Expected OutScale is the same array fed in as the Scale input.
        self.outputs['OutScale'] = self.inputs['Scale']


if __name__ == "__main__":
    # Run every TestCase defined in this module.
    unittest.main()