# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import math
import unittest

import numpy as np
from eager_op_test import OpTest


def round_c_single_element(val):
    dtype = type(val)
    if val >= 0:
        return dtype(np.floor(val + 0.5))
    return dtype(np.ceil(val - 0.5))


# Round to nearest, with ties away from zero.
round_c = np.vectorize(round_c_single_element)
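# For example, round_c(np.array([-2.5, -0.5, 0.5, 2.5])) gives
# [-3., -1., 1., 3.], whereas np.round (ties to even) gives [-2., -0., 0., 2.].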


# float16 tensors are quantized with float32 intermediates to avoid precision
# loss; float32 and float64 compute in their native dtype.
def get_compute_type(dtype):
    assert dtype in [np.float16, np.float32, np.float64]
    if dtype == np.float16:
        return np.float32
    return dtype


class TestFakeQuantizeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_quantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_quantize_abs_max(
        self, dtype, input_shape, distribution, round_type='TiesAwayFromZero'
    ):
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        scale = np.max(np.abs(input_data)).flatten()
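        # bnt is the largest positive quantized value, e.g. 127 for 8 bits.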
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        # Guard against zero or denormal scales (see the underflow tests below).
        inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale
        if round_type == 'TiesToEven':
            round_out = np.round(
                input_data.astype(compute_type) * inv_scale * bnt
            )
            output_data = np.clip(round_out, -bnt - 1, bnt)
            self.attrs['round_type'] = 0
        else:
            output_data = round_c(
                input_data.astype(compute_type) * inv_scale * bnt
            )
            self.attrs['round_type'] = 1
        self.inputs = {'X': input_data}
        self.outputs = {'Out': output_data, 'OutScale': scale}
        self.dtype = dtype
        self.check_output(check_dygraph=False)

    def test_fake_quantize_abs_max(self):
        self._fake_quantize_abs_max(np.float32, (124, 240), np.random.random)

    def test_fake_quantize_abs_max_round1(self):
        self._fake_quantize_abs_max(
            np.float32, (124, 240), np.random.random, round_type='TiesToEven'
        )

    def test_fake_quantize_abs_max_float16(self):
        self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random)

    def test_fake_quantize_abs_max_underflow(self):
        self._fake_quantize_abs_max(np.float32, (10, 10), np.zeros)

    def test_fake_quantize_abs_max_underflow2(self):
        self._fake_quantize_abs_max(
            np.float32, (10, 10), lambda shape: np.full(shape, 1e-40)
        )


class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_channel_wise_quantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_channel_wise_quantize_abs_max(
        self,
        dtype,
        input_shape,
        quant_axis,
        distribution,
        round_type='TiesToEven',
    ):
        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        compute_axis = tuple(
            i for i in range(len(input_shape)) if i != quant_axis
        )
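        # np.abs is omitted from the reduction: these tests feed nonnegative
        # inputs (np.random.random), so the plain max equals the abs-max.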
        scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True)
        if round_type == 'TiesToEven':
            round_out = np.round(
                input_data.astype(compute_type) / scale_broadcast * bnt
            )
            output_data = np.clip(round_out, -bnt - 1, bnt)
            self.attrs['round_type'] = 0
        else:
            output_data = round_c(
                bnt * input_data.astype(compute_type) / scale_broadcast
            )
            self.attrs['round_type'] = 1
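        # Move the channel axis to the front so the per-channel scales can be
        # read off the first dimension.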
        if quant_axis == 1:
            scale_broadcast = np.transpose(scale_broadcast, (1,) + compute_axis)
        scale = scale_broadcast.reshape(input_shape[quant_axis], -1)[:, 0]
        self.inputs = {'X': input_data}
        self.outputs = {'Out': output_data, 'OutScale': scale}
        self.dtype = dtype
        self.attrs['quant_axis'] = quant_axis
        self.check_output(check_dygraph=False)

    def test_fake_channel_wise_quantize_abs_max(self):
        dtype_options = [np.float32, np.float16]
        input_shape_quant_axis_options = [
            [(20, 15, 6, 6), 0],
            [(20, 15, 6, 6), 1],
            [(30, 30), 0],
            [(30, 30), 1],
        ]
        round_type_options = ['TiesToEven', 'TiesAwayFromZero']
        for dtype, input_shape_quant_axis, round_type in itertools.product(
            dtype_options, input_shape_quant_axis_options, round_type_options
        ):
            input_shape, quant_axis = input_shape_quant_axis
            with self.subTest(
                dtype=dtype,
                input_shape=input_shape,
                quant_axis=quant_axis,
                round_type=round_type,
            ):
                self._fake_channel_wise_quantize_abs_max(
                    dtype, input_shape, quant_axis, np.random.random, round_type
                )


class TestFakeQuantizeRangeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_quantize_range_abs_max'
        self.attrs = {'bit_length': 5, 'window_size': 1}

    def _fake_quantize_range_abs_max(
        self,
        dtype,
        input_shape,
        distribution,
        is_test=False,
        round_type='TiesToEven',
    ):
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        in_scale = np.zeros(1).astype(dtype)
        out_scale = np.zeros(self.attrs['window_size']).astype(dtype)
        out_scale[0] = np.max(np.abs(input_data))
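        # In test mode, shrink the scale below the true abs-max so that some
        # inputs fall outside [-scale, scale] and the clipping path is tested.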
        if is_test:
            out_scale[0] = in_scale[0] = out_scale[0] - 1.0
        if round_type == 'TiesToEven':
            round_out = np.round(
                input_data.astype(compute_type) / out_scale[0] * bnt
            )
            self.attrs['round_type'] = 0
            output_data = np.clip(round_out, -bnt - 1, bnt)
        else:
            if is_test:
                clip_data = np.clip(input_data, -in_scale, in_scale)
            else:
                clip_data = input_data
            output_data = round_c(
                clip_data.astype(compute_type) / out_scale[0] * bnt
            )
            self.attrs['round_type'] = 1
        self.inputs = {
            'X': input_data,
            'Iter': np.zeros(1).astype(np.int64),
            'InScale': in_scale,
        }
        self.outputs = {
            'Out': output_data,
            'OutScale': np.array([], dtype) if is_test else out_scale,
            'OutScales': np.array([], dtype) if is_test else out_scale,
        }
        self.dtype = dtype
        self.attrs['is_test'] = is_test
        self.check_output(check_dygraph=False)

    def test_fake_quantize_range_abs_max(self):
        dtype_options = [np.float16, np.float32]
        is_test_options = [False, True]
        round_type_options = ['TiesToEven', 'TiesAwayFromZero']
        for dtype, is_test, round_type in itertools.product(
            dtype_options, is_test_options, round_type_options
        ):
            self.attrs['bit_length'] = 8 if is_test else 5
            with self.subTest(
                dtype=dtype, is_test=is_test, round_type=round_type
            ):
                self._fake_quantize_range_abs_max(
                    dtype,
                    (8, 16, 6, 6),
                    lambda shape: (np.random.random(shape) - 0.4) * 10,
                    is_test=is_test,
                    round_type=round_type,
                )


class TestMovingAverageAbsMaxScaleOp(OpTest):
    def setUp(self):
        self.op_type = 'moving_average_abs_max_scale'
        self.attrs = {'moving_rate': 0.9, 'is_test': False}

    def _moving_average_abs_max_scale(self, dtype, input_shape, distribution):
        input_data = distribution(input_shape).astype(dtype)
        in_accum = np.ones(1).astype(dtype)
        in_state = np.ones(1).astype(dtype)
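        # Exponential moving average of the abs-max:
        #   accum = rate * accum + max(|x|);  state = rate * state + 1;
        #   scale = accum / state.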
        out_accum = self.attrs['moving_rate'] * in_accum + np.max(
            np.abs(input_data)
        )
        out_state = self.attrs['moving_rate'] * in_state + 1.0
        out_scale = out_accum / out_state
        self.inputs = {
            'X': input_data,
            'InAccum': in_accum,
            'InState': in_state,
        }
        self.outputs = {
            'Out': input_data,
            'OutAccum': out_accum,
            'OutState': out_state,
            'OutScale': out_scale,
        }
        self.dtype = dtype
        self.check_output(check_dygraph=False)

    def test_moving_average_abs_max(self):
        self._moving_average_abs_max_scale(
            np.float32, (8, 16, 7, 7), np.random.random
        )


class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_quantize_moving_average_abs_max'
        self.attrs = {'bit_length': 5, 'moving_rate': 0.9, 'is_test': False}

    def _fake_quantize_moving_average_abs_max(
        self,
        dtype,
        input_shape,
        distribution,
        dequantize=False,
        with_gradient=False,
        round_type='TiesAwayFromZero',
    ):
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        in_accum = np.ones(1).astype(dtype)
        in_state = np.ones(1).astype(dtype)
        in_scale = np.array([0.001]).astype(dtype)
        out_accum = self.attrs['moving_rate'] * in_accum + np.max(
            np.abs(input_data)
        )
        out_state = self.attrs['moving_rate'] * in_state + 1.0
        out_scale = out_accum / out_state
        if round_type == 'TiesToEven':
            round_out = np.round(
                input_data.astype(compute_type) / out_scale * bnt
            )
            quant_data = np.clip(round_out, -bnt - 1, bnt)
            self.attrs['round_type'] = 0
        else:
            quant_data = round_c(
                input_data.astype(compute_type) / out_scale * bnt
            )
            self.attrs['round_type'] = 1
        if dequantize:
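            # Fake quant-dequant: x_hat = q * scale / bnt maps the quantized
            # values back to the float range, leaving only quantization error.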
            output_data = (quant_data * out_scale / bnt).astype(dtype)
            self.op_type = 'fake_quantize_dequantize_moving_average_abs_max'
        else:
            output_data = quant_data.astype(dtype)
        self.inputs = {
            'X': input_data,
            'InScale': in_scale,
            'InAccum': in_accum,
            'InState': in_state,
        }
        self.outputs = {
            'Out': output_data,
            'OutAccum': out_accum,
            'OutState': out_state,
            'OutScale': out_scale,
        }
        self.dtype = dtype
        self.check_output(check_dygraph=False)
        if with_gradient:
            # The expected gradient of mean(Out) w.r.t. X under the
            # straight-through estimator is 1/numel everywhere.
            gradient = [
                np.ones(input_data.shape) / np.prod(input_data.shape)
            ]
            self.check_grad(['X'], 'Out', user_defined_grads=gradient)

    def test_fake_quantize_moving_average_abs_max(self):
        self._fake_quantize_moving_average_abs_max(
            np.float32, (8, 16, 7, 7), np.random.random
        )

    def test_fake_quantize_moving_average_abs_max_float16(self):
        self._fake_quantize_moving_average_abs_max(
            np.float16, (8, 16, 7, 7), np.random.random
        )

    def test_fake_quantize_moving_average_abs_max_round1(self):
        self._fake_quantize_moving_average_abs_max(
            np.float32, (8, 16, 7, 7), np.random.random, round_type='TiesToEven'
        )

    def test_fake_quantize_dequantize_moving_average_abs_max(self):
        self._fake_quantize_moving_average_abs_max(
            np.float32,
            (8, 16, 7, 7),
            np.random.random,
            dequantize=True,
            with_gradient=True,
        )


class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_quantize_dequantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_quantize_dequantize_abs_max(
        self, dtype, input_shape, distribution, round_type='TiesAwayFromZero'
    ):
        input_data = distribution(input_shape).astype(dtype)
        scale = np.max(np.abs(input_data)).flatten().astype(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        if round_type == 'TiesToEven':
            round_out = np.round(input_data / scale * bnt)
            output_data = np.clip(round_out, -bnt - 1, bnt) * scale / bnt
            self.attrs['round_type'] = 0
        else:
            output_data = round_c(input_data / scale * bnt) * scale / bnt
            self.attrs['round_type'] = 1
        self.inputs = {'X': input_data}
        self.outputs = {
            'Out': output_data,
            'OutScale': np.array(scale).astype(dtype),
        }
        self.dtype = dtype
        self.check_output(check_dygraph=False)
        gradient = [np.ones(input_data.shape) / np.prod(input_data.shape)]
        self.check_grad(['X'], 'Out', user_defined_grads=gradient)

    def test_fake_quantize_dequantize_abs_max(self):
        self._fake_quantize_dequantize_abs_max(
            np.float32, (124, 240), np.random.random
        )

    def test_fake_quantize_dequantize_abs_max_round1(self):
        self._fake_quantize_dequantize_abs_max(
            np.float32, (124, 240), np.random.random, round_type='TiesToEven'
        )


class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_channel_wise_quantize_dequantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_channel_wise_quantize_dequantize_abs_max(
        self,
        dtype,
        input_shape,
        quant_axis,
        distribution,
        round_type='TiesToEven',
    ):
        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        output_data = input_data.copy().astype(compute_type)
        compute_axis = tuple(
            i for i in range(len(input_shape)) if i != quant_axis
        )
        scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True)
        if round_type == 'TiesToEven':
            round_out = np.round(bnt * output_data / scale_broadcast)
            output_data = (
                np.clip(round_out, -bnt - 1, bnt) * scale_broadcast / bnt
            )
            self.attrs['round_type'] = 0
        else:
            output_data = (
                round_c(bnt * output_data / scale_broadcast)
                * scale_broadcast
                / bnt
            )
            self.attrs['round_type'] = 1
        if quant_axis == 1:
            scale_broadcast = np.transpose(scale_broadcast, (1,) + compute_axis)
        scale = scale_broadcast.reshape(input_shape[quant_axis], -1)[:, 0]
        self.inputs = {'X': input_data}
        self.outputs = {'Out': output_data, 'OutScale': scale}
        self.dtype = dtype
        self.attrs['quant_axis'] = quant_axis
        self.check_output(check_dygraph=False)
        gradient = [np.ones(input_data.shape) / np.prod(input_data.shape)]
        self.check_grad(['X'], 'Out', user_defined_grads=gradient)

    def test_channel_wise_fake_quant_dequant_abs_max(self):
        input_shape_quant_axis_options = [
            [(3, 4, 64, 64), 0],
            [(15, 20, 5, 5), 1],
            [(30, 15), 0],
            [(30, 15), 1],
        ]
        round_type_options = ['TiesToEven', 'TiesAwayFromZero']
        for input_shape_quant_axis, round_type in itertools.product(
            input_shape_quant_axis_options, round_type_options
        ):
            input_shape, quant_axis = input_shape_quant_axis
            with self.subTest(
                input_shape=input_shape,
                quant_axis=quant_axis,
                round_type=round_type,
            ):
                self._fake_channel_wise_quantize_dequantize_abs_max(
                    np.float32,
                    input_shape,
                    quant_axis,
                    np.random.random,
                    round_type=round_type,
                )


def quantize_max_abs(x, max_range):
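    # For example, with max |x| == 2.0 and max_range == 127, this returns
    # y == np.round(x * 63.5) and scale == 2.0.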
    scale = np.max(np.abs(x).flatten())
    y = np.round(x / scale * max_range)
    return y, scale


def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0):
    assert quant_axis in [0, 1], "The quant_axis should be 0 or 1."
    scales = []
    y = x.copy()
    max_range = math.pow(2, quant_bit - 1) - 1
    if quant_axis == 0:
        for i in range(x.shape[0]):
            scale = np.max(np.abs(x[i])).astype("float32")
            scales.append(scale)
            y[i] = np.round(x[i] * max_range / scale)
    elif quant_axis == 1:
        for i in range(x.shape[1]):
            scale = np.max(np.abs(x[:, i])).astype("float32")
            scales.append(scale)
            y[:, i] = np.round(x[:, i] * max_range / scale)
    return y, scales


class TestChannelWiseQuantizeOp(OpTest):
    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 0

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
        yq, scale = channel_wise_quantize_max_abs(
            x, self.bit_length, self.quant_axis
        )
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")

        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
        }
        self.outputs = {'Y': yq}

    def test_check_output(self):
        self.check_output(check_dygraph=False)


class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp):
    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 1


class TestChannelWiseQuantizeOpTrain(OpTest):
    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 0
        self.is_test = False

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
        yq, scale = channel_wise_quantize_max_abs(
            x, self.bit_length, self.quant_axis
        )
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")

        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
            'is_test': self.is_test,
        }
        self.outputs = {'Y': yq, 'OutScale': scale}

    def test_check_output(self):
        self.check_output(check_dygraph=False)


class TestQuantizeOp(OpTest):
    def set_args(self):
        self.bit_length = 8
        self.quant_axis = -1
        self.max_range = math.pow(2, self.bit_length - 1) - 1
        self.data_type = "float32"

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(31, 65).astype(self.data_type)
        yq, scale = quantize_max_abs(x, self.max_range)
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")

        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
        }
        self.outputs = {'Y': yq}

    def test_check_output(self):
        self.check_output(check_dygraph=False)


class TestQuantizeOpTrain(TestQuantizeOp):
    def set_args(self):
        self.bit_length = 8
        self.quant_axis = -1
        self.max_range = math.pow(2, self.bit_length - 1) - 1
        self.data_type = "float32"
        self.is_test = False

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
            'moving_rate': 0.9,
            'is_test': self.is_test,
        }

        x = np.random.randn(31, 65).astype(self.data_type)
        scale = np.array([0.001]).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")
        in_accum = np.ones(1).astype(self.data_type)
        in_state = np.ones(1).astype(self.data_type)
        out_accum = self.attrs['moving_rate'] * in_accum + np.max(np.abs(x))
        out_state = self.attrs['moving_rate'] * in_state + 1.0
        out_scale = out_accum / out_state

        # Quantize with ties-to-even and clip to [-2^(b-1), 2^(b-1) - 1],
        # i.e. [-128, 127] for 8 bits.
        round_out = np.round(x / out_scale * self.max_range)
        quant_data = np.clip(round_out, -self.max_range - 1, self.max_range)

        self.inputs = {
            'X': x,
            'Scale': scale,
            'ZeroPoint': zero_point,
            'InAccum': in_accum,
            'InState': in_state,
        }
        self.outputs = {
            'Y': quant_data,
            'OutScale': out_scale,
            'OutAccum': out_accum,
            'OutState': out_state,
        }

    def test_check_output(self):
        self.check_output(check_dygraph=False)


if __name__ == '__main__':
    unittest.main()