#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import math
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core


# numpy.round behaves differently from the C++ round function,
# so we use round_c instead of numpy.round to align the output data.
def round_c_single_element(x):
    dtype = type(x)
    if x >= 0:
        return dtype(np.floor(x + 0.5))
    else:
        return dtype(np.ceil(x - 0.5))


round_c = np.vectorize(round_c_single_element)
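
# For example, numpy.round rounds half to even (np.round(0.5) == 0.0,
# np.round(2.5) == 2.0), whereas round_c rounds half away from zero like
# C++ std::round (round_c(0.5) == 1.0, round_c(2.5) == 3.0).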


class TestFakeQuantizeOp(OpTest):
    def setUp(self):
        self.set_dtype()
        self.op_type = "fake_quantize_abs_max"
        self.attrs = {'bit_length': 8}
        self.inputs = {'X': np.random.random((124, 240)).astype(self.dtype), }
        scale = np.max(np.abs(self.inputs['X'])).astype(self.dtype)
        self.outputs = {
            'Out': round_c(self.inputs['X'] / scale * (
                (1 << (self.attrs['bit_length'] - 1)) - 1)),
            'OutScale': np.array(scale).astype(self.dtype),
        }

    def set_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output()
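

# A minimal reference sketch of the abs_max computation tested above.
# `_abs_max_quantize` is a hypothetical helper introduced for illustration;
# it is not part of the operator implementation and is not used by the tests.
def _abs_max_quantize(x, bit_length=8):
    # Symmetric quantization: scale by max(|x|), then round onto the signed
    # integer grid [-(2^(b-1) - 1), 2^(b-1) - 1].
    scale = np.max(np.abs(x))
    bnt = (1 << (bit_length - 1)) - 1  # 127 for 8 bits
    return round_c(x / scale * bnt), scale

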
class TestFakeQuantizeOpFloat16(TestFakeQuantizeOp):
    def set_dtype(self):
        self.dtype = np.float16


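# The next two cases exercise the tiny-scale safeguard of the abs_max op:
# an all-zero input and a denormal-level input both give scale < 1e-30, where
# the reference output divides by (scale + 1e-6) instead of scale.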
class TestFakeQuantizeOp1(OpTest):
    def setUp(self):
        self.op_type = "fake_quantize_abs_max"
        self.attrs = {'bit_length': 8}
        self.inputs = {'X': np.zeros((10, 10)).astype("float32"), }
        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
        inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale
        self.outputs = {
            'Out': np.round(self.inputs['X'] * inv_scale * (
                (1 << (self.attrs['bit_length'] - 1)) - 1)),
            'OutScale': np.array(scale).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()


class TestFakeQuantizeOp2(OpTest):
    def setUp(self):
        self.op_type = "fake_quantize_abs_max"
        self.attrs = {'bit_length': 8}
        self.inputs = {'X': np.full((10, 10), 1e-40).astype("float32"), }
        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
        inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale
        self.outputs = {
            'Out': np.round(self.inputs['X'] * inv_scale * (
                (1 << (self.attrs['bit_length'] - 1)) - 1)),
            'OutScale': np.array(scale).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()


class TestFakeChannelWiseQuantizeOp(OpTest):
    def setUp(self):
        self.set_dtype()
        self.set_arg()
        assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1."

        self.op_type = "fake_channel_wise_quantize_abs_max"
        self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis}

        scales = []
        outputs = self.inputs['X'].copy()
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        if self.quant_axis == 0:
            for i in range(self.inputs['X'].shape[0]):
                scale_v = np.max(np.abs(self.inputs['X'][i])).astype(self.dtype)
                scales.append(scale_v)
                outputs[i] = round_c(
                    self.dtype(bnt) * (self.dtype(1.0) / scale_v) * outputs[i])
        elif self.quant_axis == 1:
            for i in range(self.inputs['X'].shape[1]):
                scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype(
                    self.dtype)
                scales.append(scale_v)
                outputs[:, i] = round_c(
                    self.dtype(bnt) * (self.dtype(1.0) / scale_v) *
                    outputs[:, i])

        self.outputs = {
            'Out': outputs,
            'OutScale': np.array(scales).astype(self.dtype),
        }

    def set_arg(self):
        self.quant_axis = 0
        self.inputs = {
            'X': np.random.random((20, 15, 6, 6)).astype(self.dtype),
        }

    def set_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output()
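

# The per-channel scales above can equivalently be computed with one
# vectorized reduction. `_channel_abs_max_scales` is a hypothetical helper
# shown for illustration only; it is not used by the tests.
def _channel_abs_max_scales(x, quant_axis):
    # Take max(|x|) over every axis except the quantization axis.
    reduce_axes = tuple(i for i in range(x.ndim) if i != quant_axis)
    return np.max(np.abs(x), axis=reduce_axes)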


class TestFakeChannelWiseQuantizeOpFloat16(TestFakeChannelWiseQuantizeOp):
    def set_dtype(self):
        self.dtype = np.float16


class TestFakeChannelWiseQuantizeOp1(TestFakeChannelWiseQuantizeOp):
    def set_arg(self):
        self.quant_axis = 1
        self.inputs = {
            'X': np.random.random((15, 20, 5, 5)).astype(self.dtype),
        }


class TestFakeChannelWiseQuantizeOp1Float16(TestFakeChannelWiseQuantizeOp1):
    def set_dtype(self):
        self.dtype = np.float16


class TestFakeChannelWiseQuantizeOp2(TestFakeChannelWiseQuantizeOp):
    def set_arg(self):
        self.quant_axis = 0
        self.inputs = {'X': np.random.random((30, 15)).astype(self.dtype), }


class TestFakeChannelWiseQuantizeOp3(TestFakeChannelWiseQuantizeOp):
    def set_arg(self):
        self.quant_axis = 1
        self.inputs = {'X': np.random.random((30, 15)).astype(self.dtype), }


class TestFakeQuantizeRangeAbsMaxOp(OpTest):
    def setUp(self):
        self.set_dtype()
        self.op_type = "fake_quantize_range_abs_max"
        self.attrs = {
            'bit_length': int(5),
            'window_size': int(1),
            'is_test': False
        }
        x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10
        x = x.astype(self.dtype)
        self.inputs = {
            'X': x,
            'Iter': np.zeros(1).astype("int64"),
            'InScale': np.zeros(1).astype(self.dtype)
        }
        scale = np.max(np.abs(self.inputs['X'])).astype(self.dtype)

        out_scales = np.zeros(self.attrs['window_size']).astype(self.dtype)
        out_scales[0] = scale
        self.outputs = {
            'Out': round_c(
                self.dtype((1 << (self.attrs['bit_length'] - 1)) - 1) *
                (self.dtype(1.0) / scale) * self.inputs['X']),
            'OutScale': scale,
            'OutScales': out_scales,
        }

    def set_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output()
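
# In training mode (is_test=False), the range_abs_max op records the current
# max(|X|) into the OutScales window; with window_size=1 the expected scale
# reduces to max(|X|) of this single batch, as constructed above.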


class TestFakeQuantizeRangeAbsMaxOpFloat16(TestFakeQuantizeRangeAbsMaxOp):
    def set_dtype(self):
        self.dtype = np.float16


class TestMovingAverageAbsMaxScaleOp(OpTest):
    def setUp(self):
        self.op_type = "moving_average_abs_max_scale"
        self.attrs = {'moving_rate': float(0.9), 'is_test': False}
        accum = np.zeros(1).astype("float32")
        accum[0] = 1
        state = np.zeros(1).astype("float32")
        state[0] = 1
        x = np.random.random((8, 16, 7, 7)).astype("float32")
        self.inputs = {
            'X': x,
            'InAccum': accum,
            'InState': state,
        }

        out = x
        out_accum = np.zeros(1).astype("float32")
        out_state = np.zeros(1).astype("float32")
        out_scale = np.zeros(1).astype("float32")
        out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max(
            np.abs(self.inputs['X'])).astype("float32")
        out_state[0] = self.attrs['moving_rate'] * state[0] + 1
        out_scale = out_accum / out_state
        self.outputs = {
            'Out': out,
            'OutAccum': out_accum,
            'OutState': out_state,
            'OutScale': out_scale,
        }

    def test_check_output(self):
        self.check_output()
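
# The moving-average scale update checked above follows
#     accum_t = rate * accum_{t-1} + max(|x_t|)
#     state_t = rate * state_{t-1} + 1
#     scale_t = accum_t / state_t
# i.e. an exponential moving average of max(|x|) with bias correction through
# the running state counter.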


class TestFakeQuantizeRangeAbsMaxOp2(OpTest):
    def setUp(self):
        self.set_dtype()
        self.op_type = "fake_quantize_range_abs_max"
        self.attrs = {
            'bit_length': int(8),
            'window_size': int(1),
            'is_test': True
        }
        x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10
        x = x.astype(self.dtype)
        scale = np.array([np.max(np.abs(x)).astype(self.dtype) - 1.0])
        out_scales = np.zeros(self.attrs['window_size']).astype(self.dtype)
        out_scales[0] = scale.astype(self.dtype)
        self.inputs = {
            'X': x,
            'Iter': np.zeros(1).astype("int64"),
            'InScale': scale.astype(self.dtype)
        }
        xs = np.clip(x, -scale, scale).astype(self.dtype)
        qs = round_c(
            self.dtype(
                self.dtype((1 << (self.attrs['bit_length'] - 1)) - 1) * (
                    self.dtype(1.0) / scale) * xs))
        self.outputs = {
            'Out': qs,
            'OutScale': scale.astype(self.dtype),
            'OutScales': out_scales,
        }

    def set_dtype(self):
        self.dtype = np.float32

    def test_check_output(self):
        self.check_output(no_check_set=set(['OutScale', 'OutScales']))


class TestFakeQuantizeRangeAbsMaxOp2Float16(TestFakeQuantizeRangeAbsMaxOp2):
    def set_dtype(self):
        self.dtype = np.float16


class TestMovingOpBase(OpTest):
    def setUp(self):
        self.set_dtype()
        self.init_type()
        self.attrs = {
            'bit_length': int(5),
            'moving_rate': float(0.9),
            'is_test': False
        }
        accum = np.zeros(1).astype(self.dtype)
        accum[0] = 1
        state = np.zeros(1).astype(self.dtype)
        state[0] = self.dtype(1.0)
        scale = np.zeros(1).astype(self.dtype)
        scale[0] = 0.001
        self.inputs = {
            'X': np.random.random((8, 16, 7, 7)).astype(self.dtype),
            'InScale': scale,
            'InAccum': accum,
            'InState': state,
        }

        out_accum = np.zeros(1).astype(self.dtype)
        out_state = np.zeros(1).astype(self.dtype)
        out_scale = np.zeros(1).astype(self.dtype)
        out_accum[0] = self.dtype(self.attrs['moving_rate']) * self.dtype(accum[
            0]) + np.max(np.abs(self.inputs['X'])).astype(self.dtype)
        out_state[0] = self.dtype(self.attrs['moving_rate']) * self.dtype(state[
            0]) + self.dtype(1.0)
        out_scale = self.dtype(self.dtype(out_accum) / self.dtype(out_state))
        out_data = self.calc_output(out_scale)
        self.outputs = {
            'Out': out_data,
            'OutAccum': out_accum,
            'OutState': out_state,
            'OutScale': out_scale,
        }

    def set_dtype(self):
        self.dtype = np.float32

    def init_type(self):
        self.op_type = "fake_quantize_moving_average_abs_max"

    def calc_output(self, out_scale):
        return round_c(self.inputs['X'] / out_scale * (
            (1 << (self.attrs['bit_length'] - 1)) - 1))

    def test_check_output(self):
        self.check_output()


class TestMovingOpBaseFloat16(TestMovingOpBase):
    def set_dtype(self):
        self.dtype = np.float16

    def test_check_output(self):
        self.check_output(atol=1e-2)


class TestFakeQuantDequantMovingOp(TestMovingOpBase):
    def init_type(self):
        self.op_type = "fake_quantize_dequantize_moving_average_abs_max"

    def calc_output(self, out_scale):
        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
        return np.round(self.inputs['X'] / out_scale *
                        range_v) * out_scale / range_v

    def test_check_grad(self):
        x = self.inputs["X"]
        gradient = [np.ones(x.shape) / np.prod(x.shape)]
        self.check_grad(["X"], "Out", user_defined_grads=gradient)
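
# The user-defined gradients in the quant-dequant tests encode the
# straight-through estimator: round() is treated as identity in the backward
# pass, so for a mean-style objective the expected gradient w.r.t. X is
# 1 / numel(X) at every element.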


class TestFakeQuantDequantAbsOp(OpTest):
    def setUp(self):
        self.op_type = "fake_quantize_dequantize_abs_max"
        self.attrs = {'bit_length': 8}
        self.inputs = {'X': np.random.random((124, 240)).astype("float32"), }
        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
        out_data = self.calc_output(scale)
        self.outputs = {
            'Out': out_data,
            'OutScale': np.array(scale).astype("float32"),
        }

    def calc_output(self, scale):
        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
        return np.round(self.inputs['X'] / scale * range_v) * scale / range_v

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        x = self.inputs["X"]
        gradient = [np.ones(x.shape) / np.prod(x.shape)]
        self.check_grad(["X"], "Out", user_defined_grads=gradient)
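
# Note: quant-dequant snaps each element to the nearest multiple of
# scale / range_v, so the reconstruction error is at most
# 0.5 * scale / range_v per element.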


class TestChannelWiseFakeQuantDequantOp(OpTest):
    def setUp(self):
        self.set_arg()
        assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1."

        self.op_type = "fake_channel_wise_quantize_dequantize_abs_max"
        self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis}

        scales = []
        outputs = self.inputs['X'].copy()
        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
        if self.quant_axis == 0:
            for i in range(self.inputs['X'].shape[0]):
                scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32")
                scales.append(scale_v)
                outputs[i] = np.round(outputs[i] * range_v /
                                      scale_v) * scale_v / range_v
        elif self.quant_axis == 1:
            for i in range(self.inputs['X'].shape[1]):
                scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype(
                    "float32")
                scales.append(scale_v)
                outputs[:, i] = np.round(outputs[:, i] * range_v /
                                         scale_v) * scale_v / range_v

        self.outputs = {
            'Out': outputs,
            'OutScale': np.array(scales).astype("float32"),
        }

    def set_arg(self):
        self.quant_axis = 0
        self.inputs = {
            'X': np.random.random((3, 4, 64, 64)).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        x = self.inputs["X"]
        gradient = [np.ones(x.shape) / np.prod(x.shape)]
        self.check_grad(["X"], "Out", user_defined_grads=gradient)


class TestChannelWiseFakeQuantDequantOp1(TestChannelWiseFakeQuantDequantOp):
    def set_arg(self):
        self.quant_axis = 1
        self.inputs = {
            'X': np.random.random((15, 20, 5, 5)).astype("float32"),
        }


class TestChannelWiseFakeQuantDequantOp2(TestChannelWiseFakeQuantDequantOp):
    def set_arg(self):
        self.quant_axis = 0
        self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }


class TestChannelWiseFakeQuantDequantOp3(TestChannelWiseFakeQuantDequantOp):
    def set_arg(self):
        self.quant_axis = 1
        self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }


def quantize_max_abs(x, max_range):
    scale = np.max(np.abs(x).flatten())
    y = np.round(x / scale * max_range)
    return y, scale
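

# The matching dequantization, shown for illustration. `dequantize_max_abs`
# is a hypothetical helper and is not used by the tests below:
def dequantize_max_abs(y, scale, max_range):
    # Invert quantize_max_abs: map integer levels back to real values.
    return y * scale / max_range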


def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0):
    assert quant_axis in [0, 1], "The quant_axis should be 0 or 1."
    scales = []
    y = x.copy()
    max_range = math.pow(2, quant_bit - 1) - 1
    if quant_axis == 0:
        for i in range(x.shape[0]):
            scale = np.max(np.abs(x[i])).astype("float32")
            scales.append(scale)
            y[i] = np.round(x[i] * max_range / scale)
    elif quant_axis == 1:
        for i in range(x.shape[1]):
            scale = np.max(np.abs(x[:, i])).astype("float32")
            scales.append(scale)
            y[:, i] = np.round(x[:, i] * max_range / scale)
    return y, scales


class TestChannelWiseQuantizeOp(OpTest):
    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 0

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
        yq, scale = channel_wise_quantize_max_abs(x, self.bit_length,
                                                  self.quant_axis)
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")

        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis
        }
        self.outputs = {'Y': yq}

    def test_check_output(self):
        self.check_output()
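
# The reference output assumes the symmetric form of quantize_linear,
#     Y = round(X / scale * (2^(bit_length - 1) - 1)),
# with an all-zero ZeroPoint, matching channel_wise_quantize_max_abs above.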


class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp):
    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 1


class TestChannelWiseQuantizeOpTrain(OpTest):
    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 0
        self.is_test = False

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
        yq, scale = channel_wise_quantize_max_abs(x, self.bit_length,
                                                  self.quant_axis)
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")

        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
            'is_test': self.is_test
        }
        self.outputs = {'Y': yq, 'OutScale': scale}

    def test_check_output(self):
        self.check_output()


class TestQuantizeOp(OpTest):
    def set_args(self):
        self.bit_length = 8
        self.quant_axis = -1
        self.max_range = math.pow(2, self.bit_length - 1) - 1
        self.data_type = "float32"

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(31, 65).astype(self.data_type)
        yq, scale = quantize_max_abs(x, self.max_range)
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")

        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
        }
        self.outputs = {'Y': yq}

    def test_check_output(self):
        self.check_output()


class TestQuantizeOpTrain(TestQuantizeOp):
    def set_args(self):
        self.bit_length = 8
        self.quant_axis = -1
        self.max_range = math.pow(2, self.bit_length - 1) - 1
        self.data_type = "float32"
        self.is_test = False

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(31, 65).astype(self.data_type)
        yq, scale = quantize_max_abs(x, self.max_range)
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")

        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
            'is_test': self.is_test
        }
        self.outputs = {'Y': yq, 'OutScale': scale}

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    unittest.main()