# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import math
import unittest

import numpy as np
from eager_op_test import OpTest


def round_c_single_element(val):
    dtype = type(val)
    if val >= 0:
        return dtype(np.floor(val + 0.5))
    return dtype(np.ceil(val - 0.5))


# rounding to nearest ties away from zero
round_c = np.vectorize(round_c_single_element)
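# e.g. round_c(np.array([0.5, 1.5, -0.5])) gives [1., 2., -1.], whereas
# np.round uses ties-to-even and gives [0., 2., -0.].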


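# The reference computations below run in float32 for float16 inputs, since
# float16 alone lacks the precision for the scale and rounding math.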
def get_compute_type(dtype):
    assert dtype in [np.float16, np.float32, np.float64]
    if dtype == np.float16:
        return np.float32
    return dtype

class TestFakeQuantizeAbsMaxOp(OpTest):
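    # abs-max scheme: scale = max|x| and out = round(x / scale * bnt), with
    # bnt = 2^(bit_length - 1) - 1 (127 for 8 bits); e.g. x = 1.0 with
    # scale = 2.0 maps to round(1.0 / 2.0 * 127) = 64.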
    def setUp(self):
        self.op_type = 'fake_quantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_quantize_abs_max(
        self, dtype, input_shape, distribution, round_type='TiesAwayFromZero'
    ):
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        scale = np.max(np.abs(input_data))
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale
        if round_type == 'TiesToEven':
            round_out = np.round(
                input_data.astype(compute_type) * inv_scale * bnt
            )
            output_data = np.clip(round_out, -bnt - 1, bnt)
            self.attrs['round_type'] = 0
        else:
            output_data = round_c(
                input_data.astype(compute_type) * inv_scale * bnt
            )
            self.attrs['round_type'] = 1
        self.inputs = {'X': input_data}
        self.outputs = {'Out': output_data, 'OutScale': scale}
        self.dtype = dtype
        self.check_output(check_dygraph=False)

    def test_fake_quantize_abs_max(self):
        self._fake_quantize_abs_max(np.float32, (124, 240), np.random.random)

    def test_fake_quantize_abs_max_round1(self):
        self._fake_quantize_abs_max(
            np.float32, (124, 240), np.random.random, round_type='TiesToEven'
        )

    def test_fake_quantize_abs_max_float16(self):
        self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random)

    def test_fake_quantize_abs_max_underflow(self):
        self._fake_quantize_abs_max(np.float32, (10, 10), np.zeros)

    def test_fake_quantize_abs_max_underflow2(self):
        self._fake_quantize_abs_max(
            np.float32, (10, 10), lambda shape: np.full(shape, 1e-40)
        )


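# Channel-wise variant: one scale per slice along quant_axis, computed as
# the max over all remaining axes (compute_axis below).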
class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_channel_wise_quantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_channel_wise_quantize_abs_max(
        self,
        dtype,
        input_shape,
        quant_axis,
        distribution,
        round_type='TiesToEven',
    ):
        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        compute_axis = tuple(
            i for i in range(len(input_shape)) if i != quant_axis
        )
        scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True)
        if round_type == 'TiesToEven':
            round_out = np.round(
                input_data.astype(compute_type) / scale_broadcast * bnt
            )
            output_data = np.clip(round_out, -bnt - 1, bnt)
            self.attrs['round_type'] = 0
        else:
            output_data = round_c(
                bnt * input_data.astype(compute_type) / scale_broadcast
            )
            self.attrs['round_type'] = 1
        if quant_axis == 1:
            scale_broadcast = np.transpose(scale_broadcast, (1,) + compute_axis)
        scale = scale_broadcast.reshape(input_shape[quant_axis], -1)[:, 0]
        self.inputs = {'X': input_data}
        self.outputs = {'Out': output_data, 'OutScale': scale}
        self.dtype = dtype
        self.attrs['quant_axis'] = quant_axis
        self.check_output(check_dygraph=False)

    def test_fake_channel_wise_quantize_abs_max(self):
        dtype_options = [np.float32, np.float16]
        input_shape_quant_axis_options = [
            [(20, 15, 6, 6), 0],
            [(20, 15, 6, 6), 1],
            [(30, 30), 0],
            [(30, 30), 1],
        ]
        round_type_options = ['TiesToEven', 'TiesAwayFromZero']
        for dtype, input_shape_quant_axis, round_type in itertools.product(
            dtype_options, input_shape_quant_axis_options, round_type_options
        ):
            input_shape, quant_axis = input_shape_quant_axis
            with self.subTest(
                dtype=dtype,
                input_shape=input_shape,
                quant_axis=quant_axis,
                round_type=round_type,
            ):
                self._fake_channel_wise_quantize_abs_max(
                    dtype, input_shape, quant_axis, np.random.random, round_type
                )


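# fake_quantize_range_abs_max keeps a window of recent scales (window_size
# entries in OutScales); in inference (is_test=True) inputs are clipped to
# the recorded InScale before quantizing.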
class TestFakeQuantizeRangeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_quantize_range_abs_max'
        self.attrs = {'bit_length': 5, 'window_size': 1}

    def _fake_quantize_range_abs_max(
        self,
        dtype,
        input_shape,
        distribution,
        is_test=False,
        round_type='TiesToEven',
    ):
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        in_scale = np.zeros(1).astype(dtype)
        out_scale = np.zeros(self.attrs['window_size']).astype(dtype)
        out_scale[0] = np.max(np.abs(input_data))
        if is_test:
            out_scale[0] = in_scale[0] = out_scale[0] - 1.0
        if round_type == 'TiesToEven':
            round_out = np.round(
                input_data.astype(compute_type) / out_scale[0] * bnt
            )
            self.attrs['round_type'] = 0
            output_data = np.clip(round_out, -bnt - 1, bnt)
        else:
            if is_test:
                clip_data = np.clip(input_data, -in_scale, in_scale)
            else:
                clip_data = input_data
            output_data = round_c(
                clip_data.astype(compute_type) / out_scale[0] * bnt
            )
            self.attrs['round_type'] = 1
        self.inputs = {
            'X': input_data,
            'Iter': np.zeros(1).astype(np.int64),
            'InScale': in_scale,
        }
        self.outputs = {
            'Out': output_data,
            'OutScale': out_scale[0],
            'OutScales': out_scale,
        }
        self.dtype = dtype
        self.attrs['is_test'] = is_test
        self.check_output(check_dygraph=False)

    def test_fake_quantize_range_abs_max(self):
        dtype_options = [np.float16, np.float32]
        is_test_options = [False, True]
        round_type_options = ['TiesToEven', 'TiesAwayFromZero']
        for dtype, is_test, round_type in itertools.product(
            dtype_options, is_test_options, round_type_options
        ):
            self.attrs['bit_length'] = 8 if is_test else 5
            with self.subTest(
                dtype=dtype, is_test=is_test, round_type=round_type
            ):
                self._fake_quantize_range_abs_max(
                    dtype,
                    (8, 16, 6, 6),
                    lambda shape: (np.random.random(shape) - 0.4) * 10,
                    is_test=is_test,
                    round_type=round_type,
                )


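# moving_average_abs_max tracks the scale with an exponential moving average:
#   accum' = rate * accum + max|x|,  state' = rate * state + 1,
#   scale  = accum' / state'
# e.g. with rate = 0.9, accum = state = 1 and max|x| = 0.5:
# accum' = 1.4, state' = 1.9, scale = 1.4 / 1.9 ~= 0.737.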
class TestMovingAverageAbsMaxScaleOp(OpTest):
    def setUp(self):
        self.op_type = 'moving_average_abs_max_scale'
        self.attrs = {'moving_rate': float(0.9), 'is_test': False}

    def _moving_average_abs_max_scale(self, dtype, input_shape, distribution):
        input_data = distribution(input_shape).astype(dtype)
        in_accum = np.ones(1).astype(dtype)
        in_state = np.ones(1).astype(dtype)
        out_accum = self.attrs['moving_rate'] * in_accum[0] + np.max(
            np.abs(input_data)
        )
        out_state = self.attrs['moving_rate'] * in_state[0] + 1.0
        out_scale = out_accum / out_state
        self.inputs = {
            'X': input_data,
            'InAccum': in_accum,
            'InState': in_state,
        }
        self.outputs = {
            'Out': input_data,
            'OutAccum': out_accum,
            'OutState': out_state,
            'OutScale': out_scale,
        }
        self.dtype = dtype
        self.check_output(check_dygraph=False)

    def test_moving_average_abs_max(self):
        self._moving_average_abs_max_scale(
            np.float32, (8, 16, 7, 7), np.random.random
        )


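# Same moving-average scale as above, but the input itself is fake-quantized;
# with dequantize=True the op maps the levels back to float, and its gradient
# is checked against a straight-through estimate (ones / numel) below.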
class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_quantize_moving_average_abs_max'
        self.attrs = {'bit_length': 5, 'moving_rate': 0.9, 'is_test': False}

    def _fake_quantize_moving_average_abs_max(
        self,
        dtype,
        input_shape,
        distribution,
        dequantize=False,
        with_gradient=False,
        round_type='TiesAwayFromZero',
    ):
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        in_accum = np.ones(1).astype(dtype)
        in_state = np.ones(1).astype(dtype)
        in_scale = np.array([0.001]).astype(dtype)
        out_accum = np.zeros(1).astype(dtype)
        out_state = np.zeros(1).astype(dtype)
        out_scale = np.zeros(1).astype(dtype)
        out_accum[0] = self.attrs['moving_rate'] * in_accum[0] + np.max(
            np.abs(input_data)
        )
        out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0
        out_scale = out_accum / out_state
        if round_type == 'TiesToEven':
            round_out = np.round(
                input_data.astype(compute_type) / out_scale * bnt
            )
            quant_data = np.clip(round_out, -bnt - 1, bnt)
            self.attrs['round_type'] = 0
        else:
            quant_data = round_c(
                input_data.astype(compute_type) / out_scale * bnt
            )
            self.attrs['round_type'] = 1
        if dequantize:
            output_data = (quant_data * out_scale / bnt).astype(dtype)
            self.op_type = 'fake_quantize_dequantize_moving_average_abs_max'
        else:
            output_data = quant_data.astype(dtype)
        self.inputs = {
            'X': input_data,
            'InScale': in_scale,
            'InAccum': in_accum,
            'InState': in_state,
        }
        self.outputs = {
            'Out': output_data,
            'OutAccum': out_accum,
            'OutState': out_state,
            'OutScale': out_scale,
        }
        self.dtype = dtype
        self.check_output(check_dygraph=False)
        if with_gradient:
            gradient = [
                np.ones(input_data.shape) / np.prod(input_data.shape)
            ]
            self.check_grad(['X'], 'Out', user_defined_grads=gradient)

    def test_fake_quantize_moving_average_abs_max(self):
        self._fake_quantize_moving_average_abs_max(
            np.float32, (8, 16, 7, 7), np.random.random
        )

    def test_fake_quantize_moving_average_abs_max_float16(self):
        self._fake_quantize_moving_average_abs_max(
            np.float16, (8, 16, 7, 7), np.random.random
        )

    def test_fake_quantize_moving_average_abs_max_round1(self):
        self._fake_quantize_moving_average_abs_max(
            np.float32, (8, 16, 7, 7), np.random.random, round_type='TiesToEven'
        )

    def test_fake_quantize_dequantize_moving_average_abs_max(self):
        self._fake_quantize_moving_average_abs_max(
            np.float32,
            (8, 16, 7, 7),
            np.random.random,
            dequantize=True,
            with_gradient=True,
        )


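# Quant-dequant round trip: out = round(x / scale * bnt) * scale / bnt snaps
# x to the nearest representable level; e.g. with bnt = 127 and scale = 1.0,
# x = 0.25 becomes round(31.75) / 127 = 32 / 127 ~= 0.252.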
class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_quantize_dequantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_quantize_dequantize_abs_max(
        self, dtype, input_shape, distribution, round_type='TiesAwayFromZero'
    ):
        input_data = distribution(input_shape).astype(dtype)
        scale = np.max(np.abs(input_data)).astype(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        if round_type == 'TiesToEven':
            round_out = np.round(input_data / scale * bnt)
            output_data = np.clip(round_out, -bnt - 1, bnt) * scale / bnt
            self.attrs['round_type'] = 0
        else:
            output_data = round_c(input_data / scale * bnt) * scale / bnt
            self.attrs['round_type'] = 1
        self.inputs = {'X': input_data}
        self.outputs = {
            'Out': output_data,
            'OutScale': np.array(scale).astype(dtype),
        }
        self.dtype = dtype
        self.check_output(check_dygraph=False)
        gradient = [np.ones(input_data.shape) / np.prod(input_data.shape)]
        self.check_grad(['X'], 'Out', user_defined_grads=gradient)

    def test_fake_quantize_dequantize_abs_max(self):
        self._fake_quantize_dequantize_abs_max(
            np.float32, (124, 240), np.random.random
        )

    def test_fake_quantize_dequantize_abs_max_round1(self):
        self._fake_quantize_dequantize_abs_max(
            np.float32, (124, 240), np.random.random, round_type='TiesToEven'
        )


class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_channel_wise_quantize_dequantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_channel_wise_quantize_dequantize_abs_max(
        self,
        dtype,
        input_shape,
        quant_axis,
        distribution,
        round_type='TiesToEven',
    ):
        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        output_data = input_data.copy().astype(compute_type)
        compute_axis = tuple(
            i for i in range(len(input_shape)) if i != quant_axis
        )
        scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True)
        if round_type == 'TiesToEven':
            round_out = np.round(bnt * output_data / scale_broadcast)
            output_data = (
                np.clip(round_out, -bnt - 1, bnt) * scale_broadcast / bnt
            )
            self.attrs['round_type'] = 0
        else:
            output_data = (
                round_c(bnt * output_data / scale_broadcast)
                * scale_broadcast
                / bnt
            )
            self.attrs['round_type'] = 1
        if quant_axis == 1:
            scale_broadcast = np.transpose(scale_broadcast, (1,) + compute_axis)
        scale = scale_broadcast.reshape(input_shape[quant_axis], -1)[:, 0]
        self.inputs = {'X': input_data}
        self.outputs = {'Out': output_data, 'OutScale': scale}
        self.dtype = dtype
        self.attrs['quant_axis'] = quant_axis
        self.check_output(check_dygraph=False)
        gradient = [np.ones(input_data.shape) / np.prod(input_data.shape)]
        self.check_grad(['X'], 'Out', user_defined_grads=gradient)

    def test_channel_wise_fake_quant_dequant_abs_max(self):
        input_shape_quant_axis_options = [
            [(3, 4, 64, 64), 0],
            [(15, 20, 5, 5), 1],
            [(30, 15), 0],
            [(30, 15), 1],
        ]
        round_type_options = ['TiesToEven', 'TiesAwayFromZero']
        for input_shape_quant_axis, round_type in itertools.product(
            input_shape_quant_axis_options, round_type_options
        ):
            input_shape, quant_axis = input_shape_quant_axis
            with self.subTest(
                input_shape=input_shape,
                quant_axis=quant_axis,
                round_type=round_type,
            ):
                self._fake_channel_wise_quantize_dequantize_abs_max(
                    np.float32,
                    input_shape,
                    quant_axis,
                    np.random.random,
                    round_type=round_type,
                )


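# Reference helpers for the quantize_linear tests below: plain abs-max
# quantization to integer levels, per tensor and per channel, e.g.
# quantize_max_abs(np.array([-1.0, 0.5]), 127) -> ([-127., 64.], 1.0).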
def quantize_max_abs(x, max_range):
    scale = np.max(np.abs(x).flatten())
    y = np.round(x / scale * max_range)
    return y, scale


def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0):
    assert quant_axis in [0, 1], "The quant_axis should be 0 or 1."
    scales = []
    y = x.copy()
    max_range = math.pow(2, quant_bit - 1) - 1
    if quant_axis == 0:
        for i in range(x.shape[0]):
            scale = np.max(np.abs(x[i])).astype("float32")
            scales.append(scale)
            y[i] = np.round(x[i] * max_range / scale)
    elif quant_axis == 1:
        for i in range(x.shape[1]):
            scale = np.max(np.abs(x[:, i])).astype("float32")
            scales.append(scale)
            y[:, i] = np.round(x[:, i] * max_range / scale)
    return y, scales


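# quantize_linear with quant_axis 0 or 1 quantizes per channel; Scale and
# ZeroPoint carry one entry per channel.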
class TestChannelWiseQuantizeOp(OpTest):
    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 0

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
493 494 495
        yq, scale = channel_wise_quantize_max_abs(
            x, self.bit_length, self.quant_axis
        )
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")

        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
502
            'quant_axis': self.quant_axis,
        }
        self.outputs = {'Y': yq}

    def test_check_output(self):
        self.check_output(check_dygraph=False)


class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp):
    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 1


class TestChannelWiseQuantizeOpTrain(OpTest):
    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 0
        self.is_test = False

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
        yq, scale = channel_wise_quantize_max_abs(
            x, self.bit_length, self.quant_axis
        )
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")

        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
            'is_test': self.is_test,
        }
        self.outputs = {'Y': yq, 'OutScale': scale}

    def test_check_output(self):
        self.check_output(check_dygraph=False)


class TestquantizeOp(OpTest):
    def set_args(self):
        self.bit_length = 8
        self.quant_axis = -1
        self.max_range = math.pow(2, self.bit_length - 1) - 1
        self.data_type = "float32"

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(31, 65).astype(self.data_type)
        yq, scale = quantize_max_abs(x, self.max_range)
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")

        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
        }
        self.outputs = {'Y': yq}

    def test_check_output(self):
        self.check_output(check_dygraph=False)


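# Training-mode quantize_linear (is_test=False) updates the scale with the
# same moving-average rule used by the fake-quantize ops above.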
class TestquantizeOpTrain(TestquantizeOp):
    def set_args(self):
        self.bit_length = 8
        self.quant_axis = -1
        self.max_range = math.pow(2, self.bit_length - 1) - 1
        self.data_type = "float32"
        self.is_test = False

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
            'moving_rate': 0.9,
            'is_test': self.is_test,
        }

        x = np.random.randn(31, 65).astype(self.data_type)
        scale = np.array([0.001]).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")
        in_accum = np.ones(1).astype(self.data_type)
        in_state = np.ones(1).astype(self.data_type)
        out_accum = np.zeros(1).astype(self.data_type)
        out_state = np.zeros(1).astype(self.data_type)
        out_accum[0] = self.attrs['moving_rate'] * in_accum[0] + np.max(
            np.abs(x)
        )
        out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0
        out_scale = out_accum / out_state

        round_out = np.round(x / out_scale * self.max_range)
        quant_data = np.clip(round_out, -self.max_range - 1, self.max_range)

        self.inputs = {
            'X': x,
            'Scale': scale,
            'ZeroPoint': zero_point,
            'InAccum': in_accum,
            'InState': in_state,
        }
        self.outputs = {
            'Y': quant_data,
            'OutScale': out_scale,
            'OutAccum': out_accum,
            'OutState': out_state,
        }

    def test_check_output(self):
        self.check_output(check_dygraph=False)


if __name__ == '__main__':
    unittest.main()