Unverified · Commit 0ae8a2d6 authored by Leo Chen, committed by GitHub

Fix the underflow of fp16 fake quantize operators (#43088)

Co-authored-by: Ryan Jeng <rjeng@nvidia.com>
Parent 4700a08e
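In short, the kernels below used to run the whole clip-and-quantize chain in T, so with T = float16 intermediates such as inverse(scale) or bin_cnt * inv_s * v can leave fp16's representable range. The change promotes the intermediates to QuantizeDataType<T>::type (float for float16 inputs; the test's get_compute_type mirrors this by using np.float32 as the compute type for np.float16) and casts back to T only when writing the result. Below is a minimal NumPy sketch of the kind of range problem involved; it is illustrative only, not the operator code, and the numbers are chosen just to demonstrate the effect.

    import numpy as np

    # The fake-quantize formula is roughly:
    #   out = round(bin_cnt * (1 / scale) * clip(x, -scale, scale))
    # With a tiny scale, 1 / scale already exceeds fp16's largest finite
    # value (~65504), so the all-fp16 chain degenerates to inf.
    bin_cnt = 127
    scale, x = np.float16(1e-5), np.float16(5e-6)

    inv_s = np.float16(1.0) / scale        # inf in fp16 (NumPy may warn about overflow)
    print(bin_cnt * inv_s * x)             # inf

    # Doing the intermediate math in fp32 and casting only the final,
    # bounded result back to fp16 stays finite.
    inv_s32 = np.float32(1.0) / np.float32(scale)
    print(np.float16(np.round(bin_cnt * inv_s32 * np.float32(x))))  # ~64.0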
@@ -217,16 +217,18 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
   int tid = threadIdx.x;
-  T s = scale[0];
-  T inv_s = inverse(s);
-  T bin_cnt_t = static_cast<T>(bin_cnt);
+  using ComputeDataType = typename QuantizeDataType<T>::type;
+  ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
+  ComputeDataType inv_s = inverse(s);
+  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
-    T x = in[i];
-    T v = x > s ? s : x;
+    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
+    ComputeDataType v = x > s ? s : x;
     v = v < -s ? -s : v;
     v = bin_cnt_t * inv_s * v;
-    out[i] = static_cast<T>(
-        round(static_cast<typename QuantizeDataType<T>::type>(v)));
+    out[i] = static_cast<T>(round(v));
   }
 }
@@ -237,18 +239,19 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
   int tid = threadIdx.x;
-  T s = scale[0];
-  T inv_s = inverse(s);
-  T bin_cnt_t = static_cast<T>(bin_cnt);
+  using ComputeDataType = typename QuantizeDataType<T>::type;
+  ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
+  ComputeDataType inv_s = inverse(s);
+  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
-    T x = in[i];
+    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
     x = x > s ? s : x;
     x = x < -s ? -s : x;
     x = bin_cnt_t * inv_s * x;
-    x = static_cast<T>(
-        round(static_cast<typename QuantizeDataType<T>::type>(x)));
-    out[i] = (x * s) / bin_cnt_t;
+    x = round(x);
+    out[i] = static_cast<T>((x * s) / bin_cnt_t);
   }
 }
@@ -302,17 +305,18 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
   const T* in_c = in + blockIdx.x * channel_size;
   T* out_c = out + blockIdx.x * channel_size;
-  T s = scale[blockIdx.x];
-  T inv_s = inverse(s);
-  T bin_cnt_t = static_cast<T>(bin_cnt);
+  using ComputeDataType = typename QuantizeDataType<T>::type;
+  ComputeDataType s = static_cast<ComputeDataType>(scale[blockIdx.x]);
+  ComputeDataType inv_s = inverse(s);
+  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
   for (int64_t i = tid; i < channel_size; i += blockDim.x) {
-    T x = in_c[i];
-    T v = x > s ? s : x;
+    ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
+    ComputeDataType v = x > s ? s : x;
     v = v < -s ? -s : v;
     v = bin_cnt_t * inv_s * v;
-    out_c[i] = static_cast<T>(
-        round(static_cast<typename QuantizeDataType<T>::type>(v)));
+    out_c[i] = static_cast<T>(round(v));
   }
 }
@@ -322,16 +326,17 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN(
     const T* in, const T* scale, const int bin_cnt, const int64_t n,
     const int nScale, const int quant_stride, T* out) {
   int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
-  T bin_cnt_t = static_cast<T>(bin_cnt);
+  using ComputeDataType = typename QuantizeDataType<T>::type;
+  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
   for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) {
-    T s = scale[(i / quant_stride) % nScale];
-    T inv_s = inverse(s);
-    T x = in[i];
-    T v = x > s ? s : x;
+    ComputeDataType s =
+        static_cast<ComputeDataType>(scale[(i / quant_stride) % nScale]);
+    ComputeDataType inv_s = inverse(s);
+    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
+    ComputeDataType v = x > s ? s : x;
     v = v < -s ? -s : v;
     v = bin_cnt_t * inv_s * v;
-    out[i] = static_cast<T>(
-        round(static_cast<typename QuantizeDataType<T>::type>(v)));
+    out[i] = static_cast<T>(round(v));
   }
 }
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,439 +15,312 @@
from __future__ import print_function
import unittest
import math
import itertools
import numpy as np
import math
from op_test import OpTest
import paddle.fluid.core as core
# numpy.round has different behavior in comparison to c++ round function
# so we use round_c instead of numpy.round to align the output data
def round_c_single_element(x):
dtype = type(x)
if x >= 0:
return dtype(np.floor(x + 0.5))
else:
return dtype(np.ceil(x - 0.5))
def round_c_single_element(val):
dtype = type(val)
if val >= 0:
return dtype(np.floor(val + 0.5))
return dtype(np.ceil(val - 0.5))
round_c = np.vectorize(round_c_single_element)
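# Illustrative aside (not part of the original file): numpy.round uses
# round-half-to-even, whereas C++'s round() rounds halves away from zero;
# round_c reproduces the C++ behaviour so the Python reference matches the
# CUDA kernels.
_halves = np.array([0.5, 1.5, 2.5, -0.5, -2.5])
assert (np.round(_halves) == [0., 2., 2., -0., -2.]).all()   # half to even
assert (round_c(_halves) == [1., 2., 3., -1., -3.]).all()    # half away from zero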
class TestFakeQuantizeOp(OpTest):
def setUp(self):
self.set_dtype()
self.op_type = "fake_quantize_abs_max"
self.attrs = {'bit_length': 8}
self.inputs = {'X': np.random.random((124, 240)).astype(self.dtype), }
scale = np.max(np.abs(self.inputs['X'])).astype(self.dtype)
self.outputs = {
'Out': round_c(self.inputs['X'] / scale * (
(1 << (self.attrs['bit_length'] - 1)) - 1)),
'OutScale': np.array(scale).astype(self.dtype),
}
def set_dtype(self):
self.dtype = np.float32
def get_compute_type(dtype):
assert dtype in [np.float16, np.float32, np.float64]
if dtype == np.float16:
return np.float32
return dtype
def test_check_output(self):
self.check_output()
class TestFakeQuantizeOpFloat16(TestFakeQuantizeOp):
def set_dtype(self):
self.dtype = np.float16
class TestFakeQuantizeOp1(OpTest):
class TestFakeQuantizeAbsMaxOp(OpTest):
def setUp(self):
self.op_type = "fake_quantize_abs_max"
self.op_type = 'fake_quantize_abs_max'
self.attrs = {'bit_length': 8}
self.inputs = {'X': np.zeros((10, 10)).astype("float32"), }
scale = np.max(np.abs(self.inputs['X'])).astype("float32")
inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale
self.outputs = {
'Out': np.round(self.inputs['X'] * inv_scale * (
(1 << (self.attrs['bit_length'] - 1)) - 1)),
'OutScale': np.array(scale).astype("float32"),
}
def test_check_output(self):
self.check_output()
class TestFakeQuantizeOp2(OpTest):
def setUp(self):
self.op_type = "fake_quantize_abs_max"
self.attrs = {'bit_length': 8}
self.inputs = {'X': np.full((10, 10), 1e-40).astype("float32"), }
scale = np.max(np.abs(self.inputs['X'])).astype("float32")
def _fake_quantize_abs_max(self, dtype, input_shape, distribution):
input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype)
scale = np.max(np.abs(input_data))
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale
self.outputs = {
'Out': np.round(self.inputs['X'] * inv_scale * (
(1 << (self.attrs['bit_length'] - 1)) - 1)),
'OutScale': np.array(scale).astype("float32"),
}
def test_check_output(self):
output_data = round_c(input_data.astype(compute_type) * inv_scale * bnt)
self.inputs = {'X': input_data}
self.outputs = {'Out': output_data, 'OutScale': scale}
self.dtype = dtype
self.check_output()
def test_fake_quantize_abs_max(self):
self._fake_quantize_abs_max(np.float32, (124, 240), np.random.random)
class TestFakeChannelWiseQuantizeOp(OpTest):
def setUp(self):
self.set_dtype()
self.set_arg()
assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1."
def test_fake_quantize_abs_max_float16(self):
self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random)
self.op_type = "fake_channel_wise_quantize_abs_max"
self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis}
def test_fake_quantize_abs_max_underflow(self):
self._fake_quantize_abs_max(np.float32, (10, 10), np.zeros)
scales = []
outputs = self.inputs['X'].copy()
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
if self.quant_axis == 0:
for i in range(self.inputs['X'].shape[0]):
scale_v = np.max(np.abs(self.inputs['X'][i])).astype(self.dtype)
scales.append(scale_v)
outputs[i] = round_c(
self.dtype(bnt) * (self.dtype(1.0) / scale_v) * outputs[i])
elif self.quant_axis == 1:
for i in range(self.inputs['X'].shape[1]):
scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype(
self.dtype)
scales.append(scale_v)
outputs[:, i] = round_c(
self.dtype(bnt) * (self.dtype(1.0) / scale_v) *
outputs[:, i])
self.outputs = {
'Out': outputs,
'OutScale': np.array(scales).astype(self.dtype),
}
def test_fake_quantize_abs_max_underflow2(self):
self._fake_quantize_abs_max(np.float32, (10, 10),
lambda shape: np.full(shape, 1e-40))
def set_arg(self):
self.quant_axis = 0
self.inputs = {
'X': np.random.random((20, 15, 6, 6)).astype(self.dtype),
}
def set_dtype(self):
self.dtype = np.float32
class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
def setUp(self):
self.op_type = 'fake_channel_wise_quantize_abs_max'
self.attrs = {'bit_length': 8}
def test_check_output(self):
def _fake_channel_wise_quantize_abs_max(self, dtype, input_shape,
quant_axis, distribution):
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
compute_axis = tuple(
i for i in range(len(input_shape)) if i != quant_axis)
scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True)
output_data = round_c(bnt * input_data.astype(compute_type) /
scale_broadcast)
if quant_axis == 1:
scale_broadcast = np.transpose(scale_broadcast,
(1, ) + compute_axis)
scale = scale_broadcast.reshape(input_shape[quant_axis], -1)[:, 0]
self.inputs = {'X': input_data}
self.outputs = {'Out': output_data, 'OutScale': scale}
self.dtype = dtype
self.attrs['quant_axis'] = quant_axis
self.check_output()
class TestFakeChannelWiseQuantizeOpFloat16(TestFakeChannelWiseQuantizeOp):
def set_dtype(self):
self.dtype = np.float16
class TestFakeChannelWiseQuantizeOp1(TestFakeChannelWiseQuantizeOp):
def set_quant_axis(self):
self.quant_axis = 1
self.inputs = {
'X': np.random.random((15, 20, 5, 5)).astype(self.dtype),
}
class TestFakeChannelWiseQuantizeOp1Float16(TestFakeChannelWiseQuantizeOp1):
def set_dtype(self):
self.dtype = np.float16
class TestFakeChannelWiseQuantizeOp2(TestFakeChannelWiseQuantizeOp):
def set_quant_axis(self):
self.quant_axis = 0
self.inputs = {'X': np.random.random((30, 15)).astype(self.dtype), }
class TestFakeChannelWiseQuantizeOp3(TestFakeChannelWiseQuantizeOp):
def set_quant_axis(self):
self.quant_axis = 1
self.inputs = {'X': np.random.random((30, 15)).astype(self.dtype), }
def test_fake_channel_wise_quantize_abs_max(self):
dtype_options = [np.float32, np.float16]
input_shape_quant_axis_options = [[(20, 15, 6, 6), 0],
[(15, 20, 5, 5), 1], [(30, 15), 0],
[(30, 15), 1]]
for dtype, input_shape_quant_axis in itertools.product(
dtype_options, input_shape_quant_axis_options):
input_shape, quant_axis = input_shape_quant_axis
with self.subTest(
dtype=dtype, input_shape=input_shape,
quant_axis=quant_axis):
self._fake_channel_wise_quantize_abs_max(
dtype, input_shape, quant_axis, np.random.random)
class TestFakeQuantizeRangeAbsMaxOp(OpTest):
def setUp(self):
self.set_dtype()
self.op_type = "fake_quantize_range_abs_max"
self.attrs = {
'bit_length': int(5),
'window_size': int(1),
'is_test': False
}
x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10
x = x.astype(self.dtype)
self.op_type = 'fake_quantize_range_abs_max'
self.attrs = {'bit_length': 5, 'window_size': 1}
def _fake_quantize_range_abs_max(self,
dtype,
input_shape,
distribution,
is_test=False):
input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
in_scale = np.zeros(1).astype(dtype)
out_scale = np.zeros(self.attrs['window_size']).astype(dtype)
out_scale[0] = np.max(np.abs(input_data))
if is_test:
out_scale[0] = in_scale[0] = out_scale[0] - 1.0
clip_data = np.clip(input_data, -in_scale, in_scale)
else:
clip_data = input_data
output_data = round_c(
clip_data.astype(compute_type) / out_scale[0] * bnt)
self.inputs = {
'X': x,
'Iter': np.zeros(1).astype("int64"),
'InScale': np.zeros(1).astype(self.dtype)
'X': input_data,
'Iter': np.zeros(1).astype(np.int64),
'InScale': in_scale
}
scale = np.max(np.abs(self.inputs['X'])).astype(self.dtype)
out_scales = np.zeros(self.attrs['window_size']).astype(self.dtype)
out_scales[0] = scale
self.outputs = {
'Out': round_c(
self.dtype((1 << (self.attrs['bit_length'] - 1)) - 1) *
(self.dtype(1.0) / scale) * self.inputs['X']),
'OutScale': scale,
'OutScales': out_scales,
'Out': output_data,
'OutScale': out_scale[0],
'OutScales': out_scale
}
def set_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.dtype = dtype
self.attrs['is_test'] = is_test
self.check_output()
class TestFakeQuantizeRangeAbsMaxOpFloat16(TestFakeQuantizeRangeAbsMaxOp):
def set_dtype(self):
self.dtype = np.float16
def test_fake_quantize_range_abs_max(self):
dtype_options = [np.float32, np.float16]
is_test_options = [False, True]
for dtype, is_test in itertools.product(dtype_options, is_test_options):
self.attrs['bit_length'] = 8 if is_test else 5
with self.subTest(dtype=dtype, is_test=is_test):
self._fake_quantize_range_abs_max(
dtype, (8, 16, 7, 7),
lambda shape: (np.random.random(shape) - 0.5) * 10,
is_test=is_test)
class TestMovingAverageAbsMaxScaleOp(OpTest):
def setUp(self):
self.op_type = "moving_average_abs_max_scale"
self.op_type = 'moving_average_abs_max_scale'
self.attrs = {'moving_rate': float(0.9), 'is_test': False}
accum = np.zeros(1).astype("float32")
accum[0] = 1
state = np.zeros(1).astype("float32")
state[0] = 1
x = np.random.random((8, 16, 7, 7)).astype("float32")
self.inputs = {
'X': x,
'InAccum': accum,
'InState': state,
}
out = x
out_accum = np.zeros(1).astype("float32")
out_state = np.zeros(1).astype("float32")
out_scale = np.zeros(1).astype("float32")
out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max(
np.abs(self.inputs['X'])).astype("float32")
out_state[0] = self.attrs['moving_rate'] * state[0] + 1
def _moving_average_abs_max_scale(self, dtype, input_shape, distribution):
input_data = distribution(input_shape).astype(dtype)
in_accum = np.ones(1).astype(dtype)
in_state = np.ones(1).astype(dtype)
out_accum = self.attrs['moving_rate'] * in_accum[0] + np.max(
np.abs(input_data))
out_state = self.attrs['moving_rate'] * in_state[0] + 1.0
out_scale = out_accum / out_state
self.inputs = {
'X': input_data,
'InAccum': in_accum,
'InState': in_state
}
self.outputs = {
'Out': out,
'Out': input_data,
'OutAccum': out_accum,
'OutState': out_state,
'OutScale': out_scale,
'OutScale': out_scale
}
def test_check_output(self):
self.dtype = dtype
self.check_output()
def test_moving_average_abs_max(self):
self._moving_average_abs_max_scale(np.float32, (8, 16, 7, 7),
np.random.random)
class TestFakeQuantizeRangeAbsMaxOp2(OpTest):
def setUp(self):
self.set_dtype()
self.op_type = "fake_quantize_range_abs_max"
self.attrs = {
'bit_length': int(8),
'window_size': int(1),
'is_test': True
}
x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10
x = x.astype(self.dtype)
scale = np.array([np.max(np.abs(x)).astype(self.dtype) - 1.0])
out_scales = np.zeros(self.attrs['window_size']).astype(self.dtype)
out_scales[0] = scale.astype(self.dtype)
self.inputs = {
'X': x,
'Iter': np.zeros(1).astype("int64"),
'InScale': scale.astype(self.dtype)
}
xs = np.clip(x, -scale, scale).astype(self.dtype)
qs = round_c(
self.dtype(
self.dtype((1 << (self.attrs['bit_length'] - 1)) - 1) * (
self.dtype(1.0) / scale) * xs))
self.outputs = {
'Out': qs,
'OutScale': scale.astype(self.dtype),
'OutScales': out_scales,
}
def set_dtype(self):
self.dtype = np.float32
def test_check_output(self):
self.check_output(no_check_set=set(['OutScale', 'OutScales']))
class TestFakeQuantizeRangeAbsMaxOp2Float16(TestFakeQuantizeRangeAbsMaxOp2):
def set_dtype(self):
self.dtype = np.float16
class TestMovingOpBase(OpTest):
class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
def setUp(self):
self.set_dtype()
self.init_type()
self.attrs = {
'bit_length': int(5),
'moving_rate': float(0.9),
'is_test': False
}
accum = np.zeros(1).astype(self.dtype)
accum[0] = 1
state = np.zeros(1).astype(self.dtype)
state[0] = self.dtype(1.0)
scale = np.zeros(1).astype(self.dtype)
scale[0] = 0.001
self.op_type = 'fake_quantize_moving_average_abs_max'
self.attrs = {'bit_length': 5, 'moving_rate': 0.9, 'is_test': False}
def _fake_quantize_moving_average_abs_max(self,
dtype,
input_shape,
distribution,
dequantize=False,
with_gradient=False):
input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
in_accum = np.ones(1).astype(dtype)
in_state = np.ones(1).astype(dtype)
in_scale = np.array([0.001]).astype(dtype)
out_accum = np.zeros(1).astype(dtype)
out_state = np.zeros(1).astype(dtype)
out_scale = np.zeros(1).astype(dtype)
out_accum[0] = self.attrs['moving_rate'] * in_accum[0] + np.max(
np.abs(input_data))
out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0
out_scale = out_accum / out_state
round_data = round_c(input_data.astype(compute_type) / out_scale * bnt)
if dequantize:
output_data = (round_data * out_scale / bnt).astype(dtype)
self.op_type = 'fake_quantize_dequantize_moving_average_abs_max'
else:
output_data = round_data.astype(dtype)
self.inputs = {
'X': np.random.random((8, 16, 7, 7)).astype(self.dtype),
'InScale': scale,
'InAccum': accum,
'InState': state,
'X': input_data,
'InScale': in_scale,
'InAccum': in_accum,
'InState': in_state
}
out_accum = np.zeros(1).astype(self.dtype)
out_state = np.zeros(1).astype(self.dtype)
out_scale = np.zeros(1).astype(self.dtype)
out_accum[0] = self.dtype(self.attrs['moving_rate']) * self.dtype(accum[
0]) + np.max(np.abs(self.inputs['X'])).astype(self.dtype)
out_state[0] = self.dtype(self.attrs['moving_rate']) * self.dtype(state[
0]) + self.dtype(1.0)
out_scale = self.dtype(self.dtype(out_accum) / self.dtype(out_state))
out_data = self.calc_output(out_scale)
self.outputs = {
'Out': out_data,
'Out': output_data,
'OutAccum': out_accum,
'OutState': out_state,
'OutScale': out_scale,
'OutScale': out_scale
}
def set_dtype(self):
self.dtype = np.float32
def init_type(self):
self.op_type = "fake_quantize_moving_average_abs_max"
def calc_output(self, out_scale):
return round_c(self.inputs['X'] / out_scale * (
(1 << (self.attrs['bit_length'] - 1)) - 1))
def test_check_output(self):
self.dtype = dtype
self.check_output()
if with_gradient:
gradient = [
np.ones(input_data.shape) / np.product(input_data.shape)
]
self.check_grad(['X'], 'Out', user_defined_grads=gradient)
def test_fake_quantize_moving_average_abs_max(self):
self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7),
np.random.random)
class TestMovingOpBaseFloat16(TestMovingOpBase):
def set_dtype(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output(atol=1e-2)
def test_fake_quantize_moving_average_abs_max_float16(self):
self._fake_quantize_moving_average_abs_max(np.float16, (8, 16, 7, 7),
np.random.random)
def test_fake_quantize_dequantize_moving_average_abs_max(self):
self._fake_quantize_moving_average_abs_max(
np.float32, (8, 16, 7, 7),
np.random.random,
dequantize=True,
with_gradient=True)
class TestFakeQuantDequantMovingOp(TestMovingOpBase):
def init_type(self):
self.op_type = "fake_quantize_dequantize_moving_average_abs_max"
def calc_output(self, out_scale):
range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
return np.round(self.inputs['X'] / out_scale *
range_v) * out_scale / range_v
def test_check_grad(self):
x = self.inputs["X"]
gradient = [np.ones(x.shape) / np.product(x.shape)]
self.check_grad(["X"], "Out", user_defined_grads=gradient)
class TestFakeQuantDequantAbsOp(OpTest):
class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
def setUp(self):
self.op_type = "fake_quantize_dequantize_abs_max"
self.op_type = 'fake_quantize_dequantize_abs_max'
self.attrs = {'bit_length': 8}
self.inputs = {'X': np.random.random((124, 240)).astype("float32"), }
scale = np.max(np.abs(self.inputs['X'])).astype("float32")
out_data = self.calc_output(scale)
def _fake_quantize_dequantize_abs_max(self, dtype, input_shape,
distribution):
input_data = distribution(input_shape).astype(dtype)
scale = np.max(np.abs(input_data)).astype(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
output_data = round_c(input_data / scale * bnt) * scale / bnt
self.inputs = {'X': input_data}
self.outputs = {
'Out': out_data,
'OutScale': np.array(scale).astype("float32"),
'Out': output_data,
'OutScale': np.array(scale).astype(dtype)
}
def calc_output(self, scale):
range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
return np.round(self.inputs['X'] / scale * range_v) * scale / range_v
def test_check_output(self):
self.dtype = dtype
self.check_output()
gradient = [np.ones(input_data.shape) / np.product(input_data.shape)]
self.check_grad(['X'], 'Out', user_defined_grads=gradient)
def test_check_grad(self):
x = self.inputs["X"]
gradient = [np.ones(x.shape) / np.product(x.shape)]
self.check_grad(["X"], "Out", user_defined_grads=gradient)
def test_fake_quantize_dequantize_abs_max(self):
self._fake_quantize_dequantize_abs_max(np.float32, (124, 240),
np.random.random)
class TestChannelWiseFakeQuantDequantOp(OpTest):
class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
def setUp(self):
self.set_arg()
assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1."
self.op_type = "fake_channel_wise_quantize_dequantize_abs_max"
self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis}
scales = []
outputs = self.inputs['X'].copy()
range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
if self.quant_axis == 0:
for i in range(self.inputs['X'].shape[0]):
scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32")
scales.append(scale_v)
outputs[i] = np.round(outputs[i] * range_v /
scale_v) * scale_v / range_v
elif self.quant_axis == 1:
for i in range(self.inputs['X'].shape[1]):
scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype(
"float32")
scales.append(scale_v)
outputs[:, i] = np.round(outputs[:, i] * range_v /
scale_v) * scale_v / range_v
self.outputs = {
'Out': outputs,
'OutScale': np.array(scales).astype("float32"),
}
def set_arg(self):
self.quant_axis = 0
self.inputs = {
'X': np.random.random((3, 4, 64, 64)).astype("float32"),
}
self.op_type = 'fake_channel_wise_quantize_dequantize_abs_max'
self.attrs = {'bit_length': 8}
def test_check_output(self):
def _fake_channel_wise_quantize_dequantize_abs_max(
self, dtype, input_shape, quant_axis, distribution):
assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
input_data = distribution(input_shape).astype(dtype)
compute_type = get_compute_type(dtype)
bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
output_data = input_data.copy().astype(compute_type)
compute_axis = tuple(
i for i in range(len(input_shape)) if i != quant_axis)
scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True)
output_data = round_c(bnt * output_data /
scale_broadcast) * scale_broadcast / bnt
if quant_axis == 1:
scale_broadcast = np.transpose(scale_broadcast,
(1, ) + compute_axis)
scale = scale_broadcast.reshape(input_shape[quant_axis], -1)[:, 0]
self.inputs = {'X': input_data}
self.outputs = {'Out': output_data, 'OutScale': scale}
self.dtype = dtype
self.attrs['quant_axis'] = quant_axis
self.check_output()
gradient = [np.ones(input_data.shape) / np.product(input_data.shape)]
self.check_grad(['X'], 'Out', user_defined_grads=gradient)
def test_check_grad(self):
x = self.inputs["X"]
gradient = [np.ones(x.shape) / np.product(x.shape)]
self.check_grad(["X"], "Out", user_defined_grads=gradient)
class TestChannelWiseFakeQuantDequantOp1(TestChannelWiseFakeQuantDequantOp):
def set_arg(self):
self.quant_axis = 1
self.inputs = {
'X': np.random.random((15, 20, 5, 5)).astype("float32"),
}
class TestChannelWiseFakeQuantDequantOp2(TestChannelWiseFakeQuantDequantOp):
def set_arg(self):
self.quant_axis = 0
self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }
class TestChannelWiseFakeQuantDequantOp3(TestChannelWiseFakeQuantDequantOp):
def set_arg(self):
self.quant_axis = 1
self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }
def test_channel_wise_fake_quant_dequant_abs_max(self):
input_shape_quant_axis_options = [[(3, 4, 64, 64), 0], [(
15, 20, 5, 5), 1], [(30, 15), 0], [(30, 15), 1]]
for input_shape, quant_axis in input_shape_quant_axis_options:
with self.subTest(input_shape=input_shape, quant_axis=quant_axis):
self._fake_channel_wise_quantize_dequantize_abs_max(
np.float32, input_shape, quant_axis, np.random.random)
def quantize_max_abs(x, max_range):
@@ -589,5 +462,5 @@ class TestquantizeOpTrain(TestquantizeOp):
self.check_output()
if __name__ == "__main__":
if __name__ == '__main__':
unittest.main()