#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import math
from op_test import OpTest
import paddle.fluid.core as core


class TestFakeQuantizeOp(OpTest):
    """Check ``fake_quantize_abs_max`` against a NumPy reference.

    The reference uses one scale for the whole tensor, ``max(|X|)``, and
    computes ``round(X / scale * (2^(bit_length-1) - 1))``.
    """

    def setUp(self):
        self.op_type = "fake_quantize_abs_max"
        self.attrs = {'bit_length': 8}
        self.inputs = {'X': np.random.random((124, 240)).astype("float32")}
        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
        # Largest quantized magnitude, e.g. 127 for 8 bits.
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        self.outputs = {
            'Out': np.round(self.inputs['X'] / scale * bnt),
            'OutScale': np.array(scale).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()


class TestFakeQuantizeOp1(OpTest):
    """fake_quantize_abs_max on an all-zero input, where max(|X|) == 0."""

    def setUp(self):
        self.op_type = "fake_quantize_abs_max"
        self.attrs = {'bit_length': 8}
        data = np.zeros((10, 10)).astype("float32")
        self.inputs = {'X': data}
        max_abs = np.max(np.abs(data)).astype("float32")
        # A (near-)zero scale is padded by 1e-6 before taking the reciprocal,
        # so the reference never divides by zero.
        if max_abs < 1e-30:
            reciprocal = 1.0 / (max_abs + 1e-6)
        else:
            reciprocal = 1.0 / max_abs
        levels = (1 << (self.attrs['bit_length'] - 1)) - 1
        self.outputs = {
            'Out': np.round(data * reciprocal * levels),
            'OutScale': np.array(max_abs).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()


class TestFakeQuantizeOp2(OpTest):
    """fake_quantize_abs_max on a tiny constant input (1e-40, below 1e-30)."""

    def setUp(self):
        self.op_type = "fake_quantize_abs_max"
        self.attrs = {'bit_length': 8}
        data = np.full((10, 10), 1e-40).astype("float32")
        self.inputs = {'X': data}
        max_abs = np.max(np.abs(data)).astype("float32")
        # Scales below 1e-30 are padded by 1e-6 before the reciprocal.
        if max_abs < 1e-30:
            reciprocal = 1.0 / (max_abs + 1e-6)
        else:
            reciprocal = 1.0 / max_abs
        levels = (1 << (self.attrs['bit_length'] - 1)) - 1
        self.outputs = {
            'Out': np.round(data * reciprocal * levels),
            'OutScale': np.array(max_abs).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()


class TestFakeChannelWiseQuantizeOp(OpTest):
    """Check ``fake_channel_wise_quantize_abs_max`` against a NumPy reference.

    Each slice along ``quant_axis`` (0 or 1) gets its own scale
    ``max(|slice|)`` and becomes ``round(slice / scale * bnt)``, where
    ``bnt = 2^(bit_length-1) - 1``. Subclasses override ``set_arg`` to pick
    ``quant_axis`` and the input tensor.
    """

    def setUp(self):
        self.set_arg()
        assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1."

        self.op_type = "fake_channel_wise_quantize_abs_max"
        self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis}

        scales = []
        outputs = self.inputs['X'].copy()
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        if self.quant_axis == 0:
            for i in range(self.inputs['X'].shape[0]):
                scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32")
                scales.append(scale_v)
                outputs[i] = np.round(outputs[i] / scale_v * bnt)
        elif self.quant_axis == 1:
            for i in range(self.inputs['X'].shape[1]):
                scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype(
                    "float32")
                scales.append(scale_v)
                outputs[:, i] = np.round(outputs[:, i] / scale_v * bnt)

        self.outputs = {
            'Out': outputs,
            'OutScale': np.array(scales).astype("float32"),
        }

    def set_arg(self):
        """Default case: a 4-D input quantized along axis 0."""
        self.quant_axis = 0
        self.inputs = {
            'X': np.random.random((20, 15, 6, 6)).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()


class TestFakeChannelWiseQuantizeOp1(TestFakeChannelWiseQuantizeOp):
    """Per-channel quantization of a 4-D input along axis 1."""

    # The base setUp calls ``self.set_arg()``. The hook used to be named
    # ``set_quant_axis``, which was never called, so this case silently
    # re-ran the base configuration.
    def set_arg(self):
        self.quant_axis = 1
        self.inputs = {
            'X': np.random.random((15, 20, 5, 5)).astype("float32"),
        }

    # Backward-compatible alias for the old hook name.
    set_quant_axis = set_arg


class TestFakeChannelWiseQuantizeOp2(TestFakeChannelWiseQuantizeOp):
    """Per-channel quantization of a 2-D input along axis 0."""

    # The base setUp calls ``self.set_arg()``. The old ``set_quant_axis``
    # override was never invoked, so this 2-D case was never exercised.
    def set_arg(self):
        self.quant_axis = 0
        self.inputs = {'X': np.random.random((30, 15)).astype("float32")}

    # Backward-compatible alias for the old hook name.
    set_quant_axis = set_arg


class TestFakeChannelWiseQuantizeOp3(TestFakeChannelWiseQuantizeOp):
    """Per-channel quantization of a 2-D input along axis 1."""

    # The base setUp calls ``self.set_arg()``. The old ``set_quant_axis``
    # override was never invoked, so this axis-1 case was never exercised.
    def set_arg(self):
        self.quant_axis = 1
        self.inputs = {'X': np.random.random((30, 15)).astype("float32")}

    # Backward-compatible alias for the old hook name.
    set_quant_axis = set_arg


class TestFakeQuantizeRangeAbsMaxOp(OpTest):
    """Check ``fake_quantize_range_abs_max`` in training mode (is_test=False).

    InScale starts at zero and the window holds one entry. The expected
    scale is therefore this batch's ``max(|X|)``, and OutScales holds that
    value in slot 0.
    """

    def setUp(self):
        self.op_type = "fake_quantize_range_abs_max"
        self.attrs = {
            'bit_length': int(5),
            'window_size': int(1),
            'is_test': False
        }
        # Uniform values in [-5, 5).
        x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10
        x = x.astype("float32")
        self.inputs = {
            'X': x,
            'Iter': np.zeros(1).astype("int64"),
            'InScale': np.zeros(1).astype("float32")
        }
        scale = np.max(np.abs(self.inputs['X'])).astype("float32")

        out_scales = np.zeros(self.attrs['window_size']).astype("float32")
        out_scales[0] = scale
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        self.outputs = {
            'Out': np.round(self.inputs['X'] / scale * bnt),
            'OutScale': scale,
            'OutScales': out_scales,
        }

    def test_check_output(self):
        self.check_output()


class TestMovingAverageAbsMaxScaleOp(OpTest):
    """Check ``moving_average_abs_max_scale`` against a NumPy reference.

    The reference passes X through unchanged and updates the running stats:
        accum' = rate * accum + max(|X|)
        state' = rate * state + 1
        scale' = accum' / state'
    """

    def setUp(self):
        self.op_type = "moving_average_abs_max_scale"
        self.attrs = {'moving_rate': float(0.9), 'is_test': False}
        accum = np.ones(1, dtype="float32")
        state = np.ones(1, dtype="float32")
        x = np.random.random((8, 16, 7, 7)).astype("float32")
        self.inputs = {
            'X': x,
            'InAccum': accum,
            'InState': state,
        }

        rate = self.attrs['moving_rate']
        out_accum = np.zeros(1).astype("float32")
        out_state = np.zeros(1).astype("float32")
        out_accum[0] = rate * accum[0] + np.max(
            np.abs(self.inputs['X'])).astype("float32")
        out_state[0] = rate * state[0] + 1
        out_scale = out_accum / out_state
        self.outputs = {
            'Out': x,
            'OutAccum': out_accum,
            'OutState': out_state,
            'OutScale': out_scale,
        }

    def test_check_output(self):
        self.check_output()


class TestFakeQuantizeRangeAbsMaxOp2(OpTest):
    """Check ``fake_quantize_range_abs_max`` in test mode with a fixed InScale.

    InScale is set 1.0 below ``max(|X|)``. The reference therefore clips X
    to ``[-InScale, InScale]`` before quantizing, which exercises the
    saturating path.
    """

    def setUp(self):
        self.op_type = "fake_quantize_range_abs_max"
        self.attrs = {
            'bit_length': int(8),
            'window_size': int(1),
            'is_test': True
        }
        # Uniform values in [-5, 5).
        x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10
        x = x.astype("float32")
        scale = np.array([np.max(np.abs(x)).astype("float32") - 1.0])
        out_scales = np.zeros(self.attrs['window_size']).astype("float32")
        # Take the element explicitly. Assigning a shape-(1,) array into a
        # scalar slot is deprecated in recent NumPy releases.
        out_scales[0] = scale[0]
        self.inputs = {
            'X': x,
            'Iter': np.zeros(1).astype("int64"),
            'InScale': scale.astype("float32")
        }
        xs = np.clip(x, -scale, scale)
        qs = np.round(xs / scale * ((1 << (self.attrs['bit_length'] - 1)) - 1))
        self.outputs = {
            'Out': qs,
            'OutScale': scale.astype("float32"),
            'OutScales': out_scales,
        }

    def test_check_output(self):
        # Only Out is compared; the scale outputs are excluded.
        self.check_output(no_check_set=set(['OutScale', 'OutScales']))


class TestMovingOpBase(OpTest):
    """Check ``fake_quantize_moving_average_abs_max`` against NumPy.

    The expected running stats are:
        accum' = rate * accum + max(|X|)
        state' = rate * state + 1
        scale' = accum' / state'
    Subclasses override ``init_type`` to select the operator and
    ``calc_output`` to derive the expected ``Out`` from ``scale'``.
    """

    def setUp(self):
        self.init_type()
        self.attrs = {
            'bit_length': int(5),
            'moving_rate': float(0.9),
            'is_test': False
        }
        accum = np.ones(1, dtype="float32")
        state = np.ones(1, dtype="float32")
        # NOTE(review): InScale is fed to the op but not used by the
        # reference computation below.
        scale = np.array([0.001], dtype="float32")
        self.inputs = {
            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
            'InScale': scale,
            'InAccum': accum,
            'InState': state,
        }

        rate = self.attrs['moving_rate']
        out_accum = np.zeros(1).astype("float32")
        out_state = np.zeros(1).astype("float32")
        out_accum[0] = rate * accum[0] + np.max(
            np.abs(self.inputs['X'])).astype("float32")
        out_state[0] = rate * state[0] + 1
        out_scale = out_accum / out_state
        out_data = self.calc_output(out_scale)
        self.outputs = {
            'Out': out_data,
            'OutAccum': out_accum,
            'OutState': out_state,
            'OutScale': out_scale,
        }

    def init_type(self):
        """Select the operator under test."""
        self.op_type = "fake_quantize_moving_average_abs_max"

    def calc_output(self, out_scale):
        """Return round(X / out_scale * (2^(bit_length-1) - 1))."""
        return np.round(self.inputs['X'] / out_scale * (
            (1 << (self.attrs['bit_length'] - 1)) - 1))

    def test_check_output(self):
        self.check_output()


class TestFakeQuantDequantMovingOp(TestMovingOpBase):
    """Quantize-then-dequantize variant of the moving-average abs-max test."""

    def init_type(self):
        self.op_type = "fake_quantize_dequantize_moving_average_abs_max"

    def calc_output(self, out_scale):
        """Return X quantized with ``out_scale`` and mapped back to floats."""
        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
        return np.round(self.inputs['X'] / out_scale *
                        range_v) * out_scale / range_v

    def test_check_grad(self):
        x = self.inputs["X"]
        # Expected dX is a uniform 1/N per element. This presumes OpTest
        # reduces Out by its mean and the op passes the gradient straight
        # through -- TODO confirm against OpTest.check_grad.
        # np.prod replaces np.product, which NumPy 2.0 removed.
        gradient = [np.ones(x.shape) / np.prod(x.shape)]
        self.check_grad(["X"], "Out", user_defined_grads=gradient)


class TestFakeQuantDequantAbsOp(OpTest):
    """Check ``fake_quantize_dequantize_abs_max`` with scale = max(|X|)."""

    def setUp(self):
        self.op_type = "fake_quantize_dequantize_abs_max"
        self.attrs = {'bit_length': 8}
        self.inputs = {'X': np.random.random((124, 240)).astype("float32")}
        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
        out_data = self.calc_output(scale)
        self.outputs = {
            'Out': out_data,
            'OutScale': np.array(scale).astype("float32"),
        }

    def calc_output(self, scale):
        """Return X quantized with ``scale`` and mapped back to floats."""
        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
        return np.round(self.inputs['X'] / scale * range_v) * scale / range_v

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        x = self.inputs["X"]
        # np.prod replaces np.product, which NumPy 2.0 removed.
        gradient = [np.ones(x.shape) / np.prod(x.shape)]
        self.check_grad(["X"], "Out", user_defined_grads=gradient)

class TestChannelWiseFakeQuantDequantOp(OpTest):
    """Check ``fake_channel_wise_quantize_dequantize_abs_max`` against NumPy.

    Each slice along ``quant_axis`` (0 or 1) is quantized with its own
    ``max(|slice|)`` scale and then mapped back to floats. Subclasses
    override ``set_arg`` to pick ``quant_axis`` and the input tensor.
    """

    def setUp(self):
        self.set_arg()
        assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1."

        self.op_type = "fake_channel_wise_quantize_dequantize_abs_max"
        self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis}

        scales = []
        outputs = self.inputs['X'].copy()
        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
        if self.quant_axis == 0:
            for i in range(self.inputs['X'].shape[0]):
                scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32")
                scales.append(scale_v)
                outputs[i] = np.round(outputs[i] * range_v /
                                      scale_v) * scale_v / range_v
        elif self.quant_axis == 1:
            for i in range(self.inputs['X'].shape[1]):
                scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype(
                    "float32")
                scales.append(scale_v)
                outputs[:, i] = np.round(outputs[:, i] * range_v /
                                         scale_v) * scale_v / range_v

        self.outputs = {
            'Out': outputs,
            'OutScale': np.array(scales).astype("float32"),
        }

    def set_arg(self):
        """Default case: a 4-D input processed along axis 0."""
        self.quant_axis = 0
        self.inputs = {
            'X': np.random.random((3, 4, 64, 64)).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        x = self.inputs["X"]
        # np.prod replaces np.product, which NumPy 2.0 removed.
        gradient = [np.ones(x.shape) / np.prod(x.shape)]
        self.check_grad(["X"], "Out", user_defined_grads=gradient)


class TestChannelWiseFakeQuantDequantOp1(TestChannelWiseFakeQuantDequantOp):
    """Channel-wise quantize-dequantize of a 4-D input along axis 1."""

    def set_arg(self):
        self.quant_axis = 1
        data = np.random.random((15, 20, 5, 5)).astype("float32")
        self.inputs = {'X': data}


class TestChannelWiseFakeQuantDequantOp2(TestChannelWiseFakeQuantDequantOp):
    """Channel-wise quantize-dequantize of a 2-D input along axis 0."""

    def set_arg(self):
        self.quant_axis = 0
        data = np.random.random((30, 15)).astype("float32")
        self.inputs = {'X': data}


class TestChannelWiseFakeQuantDequantOp3(TestChannelWiseFakeQuantDequantOp):
    """Channel-wise quantize-dequantize of a 2-D input along axis 1."""

    def set_arg(self):
        self.quant_axis = 1
        data = np.random.random((30, 15)).astype("float32")
        self.inputs = {'X': data}


def quantize_max_abs(x, max_range):
    """Quantize ``x`` with a single scale equal to its largest magnitude.

    Args:
        x: NumPy array to quantize.
        max_range: quantized value that ``max(|x|)`` maps to.

    Returns:
        ``(y, scale)``, where ``scale`` is ``max(|x|)`` over all elements and
        ``y`` is ``round(x / scale * max_range)``.
    """
    peak = np.abs(x).max()
    return np.round(x / peak * max_range), peak


def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0):
    """Quantize ``x`` per channel, using each channel's max magnitude as scale.

    Args:
        x: NumPy array; channels are the slices along ``quant_axis``.
        quant_bit: bit width; the largest quantized value is
            ``2^(quant_bit-1) - 1``.
        quant_axis: channel axis, must be 0 or 1.

    Returns:
        ``(y, scales)``. ``y`` is a copy of ``x`` (same shape and dtype) in
        which every channel holds ``round(channel * max_range / scale)``.
        ``scales`` is a list with one float32 ``max(|channel|)`` per channel.

    Raises:
        AssertionError: if ``quant_axis`` is not 0 or 1.
    """
    assert quant_axis in [0, 1], "The quant_axis should be 0 or 1."
    bound = math.pow(2, quant_bit - 1) - 1
    quantized = x.copy()
    # moveaxis returns views, so writing out_channels[k] fills `quantized`.
    in_channels = np.moveaxis(x, quant_axis, 0)
    out_channels = np.moveaxis(quantized, quant_axis, 0)
    channel_scales = []
    for k, channel in enumerate(in_channels):
        channel_scale = np.max(np.abs(channel)).astype("float32")
        channel_scales.append(channel_scale)
        out_channels[k] = np.round(channel * bound / channel_scale)
    return quantized, channel_scales


class TestChannelWiseQuantizeOp(OpTest):
    """``quantize_linear`` with per-channel scales and zero zero-points.

    The expected output comes from ``channel_wise_quantize_max_abs``.
    Subclasses override ``set_args`` to change the bit width, dtype or axis.
    """

    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 0

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        data = np.random.randn(4, 3, 64, 64).astype(self.data_type)
        expected, channel_scales = channel_wise_quantize_max_abs(
            data, self.bit_length, self.quant_axis)
        channel_scales = np.array(channel_scales).astype(self.data_type)
        self.inputs = {
            'X': data,
            'Scale': channel_scales,
            'ZeroPoint': np.zeros(channel_scales.shape, dtype="int32"),
        }
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis
        }
        self.outputs = {'Y': expected}

    def test_check_output(self):
        self.check_output()


class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp):
    """Per-channel ``quantize_linear`` along axis 1."""

    def set_args(self):
        self.quant_axis = 1
        self.data_type = "float32"
        self.bit_length = 8


class TestChannelWiseQuantizeOpTrain(OpTest):
    """Per-channel ``quantize_linear`` with ``is_test=False``.

    In addition to Y, this case expects an ``OutScale`` output equal to the
    per-channel scales.
    """

    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 0
        self.is_test = False

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        data = np.random.randn(4, 3, 64, 64).astype(self.data_type)
        expected, channel_scales = channel_wise_quantize_max_abs(
            data, self.bit_length, self.quant_axis)
        channel_scales = np.array(channel_scales).astype(self.data_type)
        self.inputs = {
            'X': data,
            'Scale': channel_scales,
            'ZeroPoint': np.zeros(channel_scales.shape, dtype="int32"),
        }
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
            'is_test': self.is_test
        }
        self.outputs = {'Y': expected, 'OutScale': channel_scales}

    def test_check_output(self):
        self.check_output()


class TestquantizeOp(OpTest):
    """``quantize_linear`` with one per-tensor scale (``quant_axis = -1``).

    The expected output comes from ``quantize_max_abs``.
    """

    def set_args(self):
        self.bit_length = 8
        self.quant_axis = -1
        self.max_range = math.pow(2, self.bit_length - 1) - 1
        self.data_type = "float32"

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        data = np.random.randn(31, 65).astype(self.data_type)
        expected, tensor_scale = quantize_max_abs(data, self.max_range)
        tensor_scale = np.array(tensor_scale).astype(self.data_type)
        self.inputs = {
            'X': data,
            'Scale': tensor_scale,
            'ZeroPoint': np.zeros(tensor_scale.shape, dtype="int32"),
        }
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
        }
        self.outputs = {'Y': expected}

    def test_check_output(self):
        self.check_output()


class TestquantizeOpTrain(TestquantizeOp):
    """Per-tensor ``quantize_linear`` with ``is_test=False``.

    In addition to Y, this case expects an ``OutScale`` output equal to the
    input scale.
    """

    def set_args(self):
        self.bit_length = 8
        self.quant_axis = -1
        self.max_range = math.pow(2, self.bit_length - 1) - 1
        self.data_type = "float32"
        self.is_test = False

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        data = np.random.randn(31, 65).astype(self.data_type)
        expected, tensor_scale = quantize_max_abs(data, self.max_range)
        tensor_scale = np.array(tensor_scale).astype(self.data_type)
        self.inputs = {
            'X': data,
            'Scale': tensor_scale,
            'ZeroPoint': np.zeros(tensor_scale.shape, dtype="int32"),
        }
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
            'is_test': self.is_test
        }
        self.outputs = {'Y': expected, 'OutScale': tensor_scale}

    def test_check_output(self):
        self.check_output()


if __name__ == "__main__":
    # Run the unittest test cases defined in this module.
    unittest.main()