diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu index 75c5eacb25bb46e45f5275b0f901dd1ece49c31d..0e762b7dc405f72204cff663b8855ebd84e20062 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu @@ -20,7 +20,6 @@ #include #include #include "fake_quant_perchannel_impl.cuh" -#include "device/gpu/cuda_common.h" /** * Find the nudge min, max and scale value as output. @@ -34,13 +33,17 @@ * @param channel_num * @return */ -__global__ void NudgeMinMaxPerChannel(const float *input_min, const float *input_max, const float quant_min, - const float quant_max, float *nudge_min, float *nudge_max, float *scale, - int channel_num) { +__global__ void NudgeMinMaxPerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, int channel_num, + const bool symmetric) { float zp_from_min = 0.f; float nudge_zp = 0.f; for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_num; i += blockDim.x * gridDim.x) { + if (symmetric) { /* in-place: channel i -> [-m, m], m = max(-min, max), for min <= 0 <= max */ + input_max[i] = abs(input_min[i]) < input_max[i] ? input_max[i] : -input_min[i]; + input_min[i] = abs(input_min[i]) < input_max[i] ? 
-input_max[i] : input_min[i]; + } if ((quant_max - quant_min) == 0 || (input_max[i] - input_min[i]) == 0) { scale[i] = 0.f; zp_from_min = 0.f; @@ -62,11 +65,11 @@ __global__ void NudgeMinMaxPerChannel(const float *input_min, const float *input } } -void CalNudgePerChannel(const float *input_min, const float *input_max, const float quant_min, const float quant_max, - float *nudge_min, float *nudge_max, float *scale, const int channel_num, +void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric, cudaStream_t cuda_stream) { NudgeMinMaxPerChannel<<>>( - input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num); + input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num, symmetric); } /** @@ -80,9 +83,8 @@ void CalNudgePerChannel(const float *input_min, const float *input_max, const fl * @param scale - array * @return */ -__global__ void FakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size, - const float *nudge_min, const float *nudge_max, const float *scale, - bool symmetric) { +__global__ void FakeQuantPerChannel(const float *input, float *output, const int total_size, const int channel_size, + const float *nudge_min, const float *nudge_max, const float *scale) { float input_x = 0.f; int nudge_input = 0; int channel_idx = 0; @@ -106,16 +108,15 @@ __global__ void FakeQuantizePerChannel(const float *input, float *output, const } } -void CalFakeQuantizePerChannel(const float *input, float *output, const int total_size, const int channel_size, - const float *nudge_min, const float *nudge_max, const float *scale, bool symmetric, - cudaStream_t cuda_stream) { - FakeQuantizePerChannel<<>>( - input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric); +void CalFakeQuantPerChannel(const float *input, float 
*output, const int total_size, const int channel_size, + const float *nudge_min, const float *nudge_max, const float *scale, + cudaStream_t cuda_stream) { + FakeQuantPerChannel<<>>(input, output, total_size, channel_size, + nudge_min, nudge_max, scale); } -__global__ void FakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output, - const int total_size, const int channel_size, const float *nudge_min, - const float *nudge_max) { +__global__ void FakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_size, + const int channel_size, const float *nudge_min, const float *nudge_max) { int channel_idx = 0; int per_channel_num = total_size / channel_size; @@ -129,9 +130,9 @@ __global__ void FakeQuantizePerChannelGrad(const float *input, const float *grad } } -void CalFakeQuantizePerChannelGrad(const float *input, const float *gradient, float *output, const int total_num, - const int channel_num, const float *nudge_min, const float *nudge_max, - cudaStream_t cuda_stream) { - FakeQuantizePerChannelGrad<<>>( - input, gradient, output, total_num, channel_num, nudge_min, nudge_max); +void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num, + const int channel_num, const float *nudge_min, const float *nudge_max, + cudaStream_t cuda_stream) { + FakeQuantPerChannelGrad<<>>(input, gradient, output, total_num, + channel_num, nudge_min, nudge_max); } diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh index 3dff7156a7725533414bf611e4986c97a4e95d90..ad2e387b082f74c57e6f59280e8648db6fdab1dd 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh @@ -14,22 +14,21 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_ +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_ -void CalNudgePerChannel(const float* input_min, const float* input_max, const float quant_min, const float quant_max, - float* nudge_min, float* nudge_max, float* scale, const int channel_num, - cudaStream_t cuda_stream); +#include "device/gpu/cuda_common.h" -void CalFakeQuantizePerChannel(const float* input, float* output, const int total_num, const int channel_num, - const float* nudge_min, const float* nudge_max, const float* scale, bool symmetric, - cudaStream_t cuda_stream); +void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric, + cudaStream_t cuda_stream); -void CalMinMaxPerChannel(float* input, float* input_min, float* input_max, const int total_num, const int channel_num, - const float ema_decay, const bool ema, cudaStream_t cuda_stream); +void CalFakeQuantPerChannel(const float *input, float *output, const int total_num, const int channel_num, + const float *nudge_min, const float *nudge_max, const float *scale, + cudaStream_t cuda_stream); -void CalFakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, const int total_num, - const int channel_num, const float* nudge_min, const float* nudge_max, - cudaStream_t cuda_stream); +void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num, + const int channel_num, const float *nudge_min, const float *nudge_max, + cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_ +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_ diff --git 
a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu index 11a25ba2947a365605c931ae205de3147aa76688..f527d33df959313cd01f24eedaf448371c0bbeeb 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu @@ -17,11 +17,10 @@ #include #include #include -#include "device/gpu/cuda_common.h" #include "fake_quant_perlayer_impl.cuh" -__global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min, - const float *nudge_max, const float *scale) { +__global__ void FakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min, + const float *nudge_max, const float *scale) { float input_x = 0.f; int nudge_input = 0; @@ -43,8 +42,8 @@ __global__ void FakeQuantize(const float *input, float *output, const int size, return; } -__global__ void FakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size, - const float *nudge_min, const float *nudge_max) { +__global__ void FakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size, + const float *nudge_min, const float *nudge_max) { for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) { if (input[i] < nudge_min[0] || input[i] > nudge_max[0]) { output[i] = 0; @@ -55,12 +54,18 @@ __global__ void FakeQuantizeGrad(const float *input, const float *gradient, floa return; } -__global__ void NudgeMinMax(const float *input_min, const float *input_max, const float quant_min, - const float quant_max, float *nudge_min, float *nudge_max, float *scale) { +__global__ void NudgeMinMaxPerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, const bool symmetric) { float zp_from_min = 0.f; scale[0] = 0.f; nudge_max[0] = 0.f; 
nudge_min[0] = 0.f; + + if (symmetric) { + input_max[0] = abs(input_min[0]) < input_max[0] ? input_max[0] : -input_min[0]; + input_min[0] = abs(input_min[0]) < input_max[0] ? -input_max[0] : input_min[0]; + } + if ((quant_max - quant_min) == 0 || (input_max[0] - input_min[0]) == 0) { scale[0] = 0.f; zp_from_min = 0.f; @@ -83,53 +88,24 @@ __global__ void NudgeMinMax(const float *input_min, const float *input_max, cons return; } -__global__ void UpdateInputMinMaxWithEMA(float *input_min, float *input_max, const float min, const float max, - const float decay) { - input_min[0] = decay * (min) + (1 - decay) * (input_min[0]); - input_min[0] = input_min[0] > 0 ? 0 : input_min[0]; - input_max[0] = decay * (max) + (1 - decay) * (input_max[0]); - input_max[0] = input_max[0] < 0 ? 0 : input_max[0]; - return; -} - -__global__ void UpdateInputMinMax(float *input_min, float *input_max, const float min, const float max) { - input_min[0] = min > 0 ? 0 : min; - input_max[0] = max < 0 ? 0 : max; -} - -void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max, - const float *scale, bool symmetric, cudaStream_t cuda_stream) { - FakeQuantize<<>>(input, output, size, nudge_min, nudge_max, scale); - return; -} - -void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size, - const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) { - FakeQuantizeGrad<<>>(input, gradient, output, size, nudge_min, - nudge_max); +void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min, + const float *nudge_max, const float *scale, cudaStream_t cuda_stream) { + FakeQuantPerLayer<<>>(input, output, size, nudge_min, nudge_max, + scale); return; } -void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max, - float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream) { - NudgeMinMax<<<1, 
1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale); +void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size, + const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) { + FakeQuantPerLayerGrad<<>>(input, gradient, output, size, nudge_min, + nudge_max); return; } -void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema, - cudaStream_t cuda_stream) { - float minel = 0.f; - float maxel = 0.f; - auto policy = thrust::cuda::par.on(cuda_stream); - thrust::pair, thrust::device_ptr> tuple; - tuple = thrust::minmax_element(policy, thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + size); - minel = tuple.first[0]; - maxel = tuple.second[0]; - - if (ema) { - UpdateInputMinMaxWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel, ema_decay); - } else { - UpdateInputMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel); - } +void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, const bool symmetric, + cudaStream_t cuda_stream) { + NudgeMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, + symmetric); return; } diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh index 27c39dead1b8831a5ae21989e2a42ddec7d1173e..dda95ed781f56b79e6f0116f2fcc82700e09e144 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh @@ -14,19 +14,18 @@ * limitations under the License. 
*/ -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_ +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ -void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max, - const float *scale, bool symmetric, cudaStream_t cuda_stream); +#include "device/gpu/cuda_common.h" -void CalFakeQuantizeGrad(const float *input, const float *gradient, float *output, const int size, - const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream); +void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max, + float *nudge_min, float *nudge_max, float *scale, const bool symmetric, cudaStream_t cuda_stream); -void CalNudge(const float *input_min, const float *input_max, const float quant_min, const float quant_max, - float *nudge_min, float *nudge_max, float *scale, cudaStream_t cuda_stream); +void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min, + const float *nudge_max, const float *scale, cudaStream_t cuda_stream); -void CalMinMax(float *input, float *input_min, float *input_max, const int size, const float ema_decay, const bool ema, - cudaStream_t cuda_stream); +void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size, + const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream); -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_ +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_ diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc index ffed550fbbc74cbec50c711853554c706166060f..8db6ddd84877690d5f41e61e3ad61a7fe221b741 100644 --- 
a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc @@ -102,9 +102,9 @@ void FakeQuantPerChannelGpuKernel::InitSizeLists() { void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, float *nudge_min, float *nudge_max, float *scale, void *stream_ptr) { CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_, - reinterpret_cast(stream_ptr)); - CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale, - symmetric_, reinterpret_cast(stream_ptr)); + symmetric_, reinterpret_cast(stream_ptr)); + CalFakeQuantPerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale, + reinterpret_cast(stream_ptr)); } bool FakeQuantPerChannelGpuKernel::Launch(const std::vector &inputs, diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc index a57516eb2c75c5440d32db1eee774453529ee6d7..5c774c05edd90f90e26d5eb450b897a0a5382287 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc @@ -119,9 +119,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector &inp int total_size = input_size_ / sizeof(float); if (global_step_ >= quant_delay_) { CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_, - reinterpret_cast(stream_ptr)); - CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max, - reinterpret_cast(stream_ptr)); + symmetric_, reinterpret_cast(stream_ptr)); + CalFakeQuantPerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max, + 
reinterpret_cast(stream_ptr)); } else { CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), diff --git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc index 845fb5b923d3e1952989150242bc394f15cd3045..44869983eb8988919e9d4472f2e308f271b1911c 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc @@ -117,10 +117,10 @@ bool FakeQuantPerLayerGpuKernel::Launch(const std::vector &inputs, c // control flow for quant_delay if (global_step_ >= quant_delay_) { // real launch - CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, - reinterpret_cast(stream_ptr)); - CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_, - reinterpret_cast(stream_ptr)); + CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_, + reinterpret_cast(stream_ptr)); + CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale, + reinterpret_cast(stream_ptr)); } else { CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), @@ -129,10 +129,10 @@ bool FakeQuantPerLayerGpuKernel::Launch(const std::vector &inputs, c global_step_++; } else { // real launch - CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, - reinterpret_cast(stream_ptr)); - CalFakeQuantize(input, output, quant_num_, nudge_min, nudge_max, scale, symmetric_, - reinterpret_cast(stream_ptr)); + CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_, + reinterpret_cast(stream_ptr)); + CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale, + reinterpret_cast(stream_ptr)); } return true; diff 
--git a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc index 9c6584e23965774c67320caac09d62280220b512..c8d57b2bb1fa6cce3e840923b182545fe48fd129 100644 --- a/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc @@ -115,10 +115,10 @@ bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector &input } if (global_step_ >= quant_delay_) { - CalNudge(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, - reinterpret_cast(stream_ptr)); - CalFakeQuantizeGrad(input, gradient, output, quant_num_, nudge_min, nudge_max, - reinterpret_cast(stream_ptr)); + CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_, + reinterpret_cast(stream_ptr)); + CalFakeQuantPerLayerGrad(input, gradient, output, quant_num_, nudge_min, nudge_max, + reinterpret_cast(stream_ptr)); } else { CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast(stream_ptr)), diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index 937e54a7e48c427f2df5dc174099932c6d77c0b7..46b3cd1934f540745d925f3164b05807a4d38999 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -150,7 +150,7 @@ class ConvertToQuantNetwork: prefix = name add_quant = _AddFakeQuantAfterSubCell(prim_op, num_bits=self.act_bits, - quant_delay=self.act_delay, + quant_delay=self.act_qdelay, per_channel=self.act_channel, symmetric=self.act_symmetric, narrow_range=self.act_range) @@ -408,19 +408,19 @@ def convert_quant_network(network, Args: network (Cell): Obtain a pipeline through network for saving graph summary. - quant_delay (int or tuple): Number of steps after which weights and activations are quantized during - eval. 
The first element represent weights and second element represent data flow. Default: (0, 0) bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False. freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 0. - num_bits (int or tuple): Number of bits to use for quantizing weights and activations. The first + quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during + eval. The first element represent weights and second element represent data flow. Default: (0, 0) + num_bits (int, list or tuple): Number of bits to use for quantizing weights and activations. The first element represent weights and second element represent data flow. Default: (8, 8) - per_channel (int or tuple): Quantization granularity based on layer or on channel. If `True` + per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True` then base on per channel otherwise base on per layer. The first element represent weights and second element represent data flow. Default: (False, False) - symmetric (int or tuple): Quantization algorithm use symmetric or not. If `True` then base on - symmetric otherwise base on assymmetric. The first element represent weights and second + symmetric (bool, list or tuple): Quantization algorithm use symmetric or not. If `True` then base on + symmetric otherwise base on asymmetric. The first element represent weights and second element represent data flow. Default: (False, False) - narrow_range (int or tuple): Quantization algorithm use narrow range or not. If `True` then base + narrow_range (bool, list or tuple): Quantization algorithm use narrow range or not. If `True` then base on narrow range otherwise base on off narrow range. The first element represent weights and second element represent data flow. 
Default: (False, False) diff --git a/tests/st/ops/gpu/test_fake_quant_perchannel.py b/tests/st/ops/gpu/test_fake_quant_perchannel.py new file mode 100644 index 0000000000000000000000000000000000000000..caa7d4b7e80acd9c89a1781c8c79d4cb5829adc4 --- /dev/null +++ b/tests/st/ops/gpu/test_fake_quant_perchannel.py @@ -0,0 +1,625 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore.common.tensor import Tensor +from mindspore import nn +from mindspore.ops.operations import _quant_ops as Q + +context.set_context(device_target='GPU', device_id=0) + + +class Net(nn.Cell): + def __init__(self, num_bits=8, symmetric=False, narrow_range=False, channel_axis=1): + super(Net, self).__init__() + self.op = Q.FakeQuantPerChannel(num_bits=num_bits, + symmetric=symmetric, + narrow_range=narrow_range, + channel_axis=channel_axis) + + def construct(self, x, minq, maxq): + return self.op(x, minq, maxq) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel1(): + # WithVarsPerChannel_ZeroMinAndMax + x = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32) + min_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32) + max_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32) + expect = np.array([0.0, 0.0, 0.0, 
0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False, channel_axis=0) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel2(): + # WithVarsPerChannelDim1NudgedDown_RegularRange + # scale 1/4, zp 0.4, nudge 0. nudged ranges [0.0, 63.75] + x = np.array([-0.1, 0.0, 63.75, 63.8]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32) + max_val = np.array([63.65, 63.65, 63.65, 63.65]).astype(np.float32) + expect = np.array([0.0, 0.0, 63.75, 63.75]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False, channel_axis=0) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel3(): + # WithVarsPerChannelDim1NudgedDown_NarrowRange + # scale 1/4, zp 1.4, nudge 1. 
nudged ranges [0.0, 63.5] + x = np.array([-0.1, 0.0, 63.5, 63.6]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32) + max_val = np.array([63.4, 63.4, 63.4, 63.4]).astype(np.float32) + expect = np.array([0.0, 0.0, 63.5, 63.5]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True, channel_axis=0) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel4(): + # WithVarsPerChannelDim1NudgedUp_RegularRange + # [-0.125, 63.625] + # scale 1/4, zp: 0.5, nudge 1. nudged range [-0.25, 63.5] + x = np.array([-0.26, -0.25, -0.24, 63.6]).astype(np.float32) + expect = np.array([-0.25, -0.25, -0.25, 63.5]).astype(np.float32) + min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32) + max_val = np.array([63.625, 63.625, 63.625, 63.625]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False, channel_axis=0) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel5(): + # WithVarsPerChannelDim1NudgedUp_NarrowRange + # scale 1/4, zp: 1.5, nudge 2. 
nudged range [-0.25, 63.25] + x = np.array([-0.26, -0.25, -0.24, 63.3]).astype(np.float32) + expect = np.array([-0.25, -0.25, -0.25, 63.25]).astype(np.float32) + min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32) + max_val = np.array([63.375, 63.375, 63.375, 63.375]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True, channel_axis=0) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel6(): + # WithVarsPerChannelDim2NudgedDown_RegularRange + # scale 1/4, zp: 0.4, nudge 0. nudged range [0.0, 63.75] + x = np.array([-0.1, 0.0, 0.1, 0.25, 63.75, 63.80] + ).reshape(2, 3).astype(np.float32) + expect = np.array([-0.0, 0.0, 0.0, 0.25, 63.75, 63.75]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32) + max_val = np.array([63.65, 63.65, 63.65]).reshape(3).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel7(): + # WithVarsPerChannelDim2NudgedDown_NarrowRange + # scale 1/4, zp: 1.4, nudge 1. 
nudged range [0.0, 63.5] + x = np.array([-0.1, 0.0, 0.1, 0.25, 63.5, 63.6] + ).reshape(2, 3).astype(np.float32) + expect = np.array([0.0, 0.0, 0.0, 0.25, 63.5, 63.5]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32) + max_val = np.array([63.4, 63.4, 63.4]).reshape(3).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel8(): + # WithVarsPerChannelDim2NudgedUp_RegularRange + # scale 1/4, zp: 0.5, nudge 1. nudged range [-0.25, 63.5] + x = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6] + ).reshape(2, 3).astype(np.float32) + expect = np.array([-0.25, -0.25, -0.25, 0.0, 63.5, 63.5] + ).astype(np.float32) + min_val = np.array([-0.125, -0.125, -0.125]).reshape(3).astype(np.float32) + max_val = np.array([63.625, 63.625, 63.625]).reshape(3).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel9(): + # WithVarsPerChannelDim2NudgedUp_NarrowRange + # scale 1/4, zp: 1.5, nudge 2. 
nudged range [-0.25, 63.25] + x = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3] + ).reshape(2, 3).astype(np.float32) + expect = np.array( + [-0.25, -0.25, -0.25, 0.0, 63.25, 63.25]).astype(np.float32) + min_val = np.array([-0.125, -0.125, -0.125]).reshape(3).astype(np.float32) + max_val = np.array([63.375, 63.375, 63.375]).reshape(3).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel10(): + # WithVarsPerChannelDim4NudgedDown_RegularRange + # scale 1/4, zp: 0.4, nudge 0. nudged range [0.0, 63.75] + x = np.array([-0.1, 0.0, 0.1, 0.25, 0.5, 0.75, + 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, + 63.0, 63.25, 63.5, 63.7, 63.75, 63.8, + 63.9, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32) + expect = np.array([0.0, 0.0, 0.0, 0.25, 0.5, 0.75, + 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, + 63.0, 63.25, 63.5, 63.75, 63.75, 63.75, + 63.75, 63.75, 63.75, 63.75, 63.75, 63.75]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32) + max_val = np.array([63.65, 63.65, 63.65, 63.65] + ).reshape(4).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel11(): + # WithVarsPerChannelDim4NudgedDown_NarrowRange + # scale 1/4, zp: 1.4, nudge 1. 
nudged range [0.0, 63.25] + x = np.array([-0.1, 0.0, 0.1, 0.25, 0.5, 0.75, + 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, + 63.0, 63.25, 63.3, 63.4, 63.5, 63.6, + 63.7, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32) + expect = np.array([0.0, 0.0, 0.0, 0.25, 0.5, 0.75, + 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, + 63.0, 63.25, 63.25, 63.5, 63.5, 63.5, + 63.5, 63.5, 63.5, 63.5, 63.5, 63.5]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32) + max_val = np.array([63.4, 63.4, 63.4, 63.4]).reshape(4).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel12(): + # WithVarsPerChannelDim4NudgedUp_RegularRange + # scale 1/4, zp: 0.5, nudge 2. 
nudged range [-0.25, 63.25] + x = np.array([-0.3, -0.25, -0.2, 0.0, 0.25, 0.5, + 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, + 63.0, 63.25, 63.4, 63.5, 63.6, 63.7, + 100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32) + expect = np.array([-0.25, -0.25, -0.25, 0.0, 0.25, 0.5, + 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, + 63.0, 63.25, 63.5, 63.5, 63.5, 63.5, + 63.5, 63.5, 63.5, 63.5, 63.5, 63.5]).astype(np.float32) + min_val = np.array([-0.125, -0.125, -0.125, -0.125] + ).reshape(4).astype(np.float32) + max_val = np.array([63.625, 63.625, 63.625, 63.625] + ).reshape(4).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel13(): + # WithVarsPerChannelDim4NudgedUp_NarrowRange + # scale 1/4, zp: 0.5, nudge 2. 
nudged range [-0.25, 63.25] + x = np.array([-0.3, -0.25, -0.2, 0.0, 0.25, 0.5, + 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, + 63.0, 63.2, 63.25, 63.3, 63.4, 63.5, + 100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32) + expect = np.array([-0.25, -0.25, -0.25, 0.0, 0.25, 0.5, + 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, + 63.0, 63.25, 63.25, 63.25, 63.25, 63.25, + 63.25, 63.25, 63.25, 63.25, 63.25, 63.25]).astype(np.float32) + min_val = np.array([-0.125, -0.125, -0.125, -0.125] + ).reshape(4).astype(np.float32) + max_val = np.array([63.375, 63.375, 63.375, 63.375] + ).reshape(4).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel14(): + # WithVarsPerChannelDim1NudgedDown_4Bits_RegularRange + # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25] + x = np.array([-0.1, 0.0, 7.5, 7.6]).reshape(4).astype(np.float32) + expect = np.array([0.0, 0.0, 7.5, 7.5]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32) + max_val = np.array([7.4, 7.4, 7.4, 7.4]).reshape(4).astype(np.float32) + + net = Net(num_bits=4, narrow_range=False, channel_axis=0) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel15(): + # WithVarsPerChannelDim1NudgedDown_4Bits_NarrowRange + # scale 1/4, zp: 0.5, nudge 2. 
nudged range [-0.25, 63.25] + x = np.array([-0.1, 0.0, 7.0, 7.1]).reshape(4).astype(np.float32) + expect = np.array([0.0, 0.0, 7.0, 7.0]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32) + max_val = np.array([6.9, 6.9, 6.9, 6.9]).reshape(4).astype(np.float32) + + net = Net(num_bits=4, narrow_range=True, channel_axis=0) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel16(): + # WithVarsPerChannelDim1NudgedUp_4Bits_RegularRange + # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25] + x = np.array([-0.6, -0.5, 7.0, 7.1]).reshape(4).astype(np.float32) + expect = np.array([-0.5, -0.5, 7.0, 7.0]).astype(np.float32) + min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32) + max_val = np.array([7.1, 7.1, 7.1, 7.1]).reshape(4).astype(np.float32) + + net = Net(num_bits=4, narrow_range=False, channel_axis=0) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel17(): + # WithVarsPerChannelDim1NudgedUp_4Bits_NarrowRange + # scale 1/4, zp: 0.5, nudge 2. 
nudged range [-0.25, 63.25] + x = np.array([-0.6, -0.5, 6.5, 6.6]).reshape(4).astype(np.float32) + expect = np.array([-0.5, -0.5, 6.5, 6.5]).astype(np.float32) + min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32) + max_val = np.array([6.6, 6.6, 6.6, 6.6]).reshape(4).astype(np.float32) + + net = Net(num_bits=4, narrow_range=True, channel_axis=0) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel18(): + # WithVarsPerChannelDim2NudgedDown_4Bits_RegularRange + # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25] + x = np.array([-0.1, 0.0, 0.1, 0.5, 7.5, 7.6] + ).reshape(2, 3).astype(np.float32) + expect = np.array([0.0, 0.0, 0.0, 0.5, 7.5, 7.5]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32) + max_val = np.array([7.4, 7.4, 7.4]).reshape(3).astype(np.float32) + + net = Net(num_bits=4, narrow_range=False, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel19(): + # WithVarsPerChannelDim2NudgedDown_4Bits_NarrowRange + # scale 1/4, zp: 0.5, nudge 2. 
nudged range [-0.25, 63.25] + x = np.array([-0.1, 0.0, 0.1, 0.5, 7.0, 7.1] + ).reshape(2, 3).astype(np.float32) + expect = np.array([0.0, 0.0, 0.0, 0.5, 7.0, 7.0]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1]).reshape(3).astype(np.float32) + max_val = np.array([6.9, 6.9, 6.9]).reshape(3).astype(np.float32) + + net = Net(num_bits=4, narrow_range=True, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel20(): + # WithVarsPerChannelDim2NudgedUp_4Bits_RegularRange + # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25] + x = np.array([-0.51, -0.5, -0.24, 0.0, 7.0, 7.1] + ).reshape(2, 3).astype(np.float32) + expect = np.array([-0.5, -0.5, 0.0, 0.0, 7.0, 7.0]).astype(np.float32) + min_val = np.array([-0.4, -0.4, -0.4]).reshape(3).astype(np.float32) + max_val = np.array([7.1, 7.1, 7.1]).reshape(3).astype(np.float32) + + net = Net(num_bits=4, narrow_range=False, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel21(): + # WithVarsPerChannelDim2NudgedUp_4Bits_NarrowRange + # scale 1/4, zp: 0.5, nudge 2. 
nudged range [-0.25, 63.25] + x = np.array([-0.6, -0.5, -0.24, 0.0, 6.5, 6.6] + ).reshape(2, 3).astype(np.float32) + expect = np.array([-0.5, -0.5, 0.0, 0.0, 6.5, 6.5]).astype(np.float32) + min_val = np.array([-0.4, -0.4, -0.4]).reshape(3).astype(np.float32) + max_val = np.array([6.6, 6.6, 6.6]).reshape(3).astype(np.float32) + + net = Net(num_bits=4, narrow_range=True, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel22(): + # WithVarsPerChannelDim4NudgedDown_4Bits_RegularRange + # scale 1/4, zp: 0.5, nudge 2. nudged range [-0.25, 63.25] + x = np.array([-0.1, 0.0, 0.1, 0.5, 1.0, 1.5, + 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, + 6.0, 6.5, 7.0, 7.4, 7.5, 7.7, + 7.8, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32) + expect = np.array([0.0, 0.0, 0.0, 0.5, 1.0, 1.5, + 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, + 6.0, 6.5, 7.0, 7.5, 7.5, 7.5, + 7.5, 7.5, 7.5, 7.5, 7.5, 7.5]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32) + max_val = np.array([7.4, 7.4, 7.4, 7.4]).reshape(4).astype(np.float32) + + net = Net(num_bits=4, narrow_range=False, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel23(): + # WithVarsPerChannelDim4NudgedDown_4Bits_NarrowRange + # scale 1/4, zp: 0.5, nudge 2. 
nudged range [-0.25, 63.25] + x = np.array([-0.1, 0.0, 0.1, 0.5, 1.0, 1.5, + 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, + 6.0, 6.5, 6.8, 6.9, 7.0, 7.1, + 7.2, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32) + expect = np.array([0.0, 0.0, 0.0, 0.5, 1.0, 1.5, + 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, + 6.0, 6.5, 7.0, 7.0, 7.0, 7.0, + 7.0, 7.0, 7.0, 7.0, 7.0, 7.0]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1, -0.1]).reshape(4).astype(np.float32) + max_val = np.array([6.9, 6.9, 6.9, 6.9]).reshape(4).astype(np.float32) + + net = Net(num_bits=4, narrow_range=True, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel24(): + # WithVarsPerChannelDim4NudgedUp_4Bits_RegularRange + # scale 1/4, zp: 0.5, nudge 2. 
nudged range [-0.25, 63.25] + x = np.array([-0.6, -0.5, -0.4, 0.0, 0.5, 1.0, + 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, + 6.0, 6.5, 6.9, 7.0, 7.1, 7.7, + 100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32) + expect = np.array([-0.5, -0.5, -0.5, 0.0, 0.5, 1.0, + 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, + 6.0, 6.5, 7.0, 7.0, 7.0, 7.0, + 7.0, 7.0, 7.0, 7.0, 7.0, 7.0]).astype(np.float32) + min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32) + max_val = np.array([7.1, 7.1, 7.1, 7.1]).reshape(4).astype(np.float32) + + net = Net(num_bits=4, narrow_range=False, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_perchannel25(): + # WithVarsPerChannelDim4NudgedUp_4Bits_NarrowRange + # scale 1/4, zp: 0.5, nudge 2. 
nudged range [-0.25, 63.25] + x = np.array([-0.6, -0.5, -0.4, 0.0, 0.5, 1.0, + 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, + 5.5, 6.0, 6.4, 6.5, 6.6, 6.7, + 100.0, 100.0, 100.0, 100.0, 100.0, 1000.0]).reshape(1, 4, 2, 3).astype(np.float32) + expect = np.array([-0.5, -0.5, -0.5, 0.0, 0.5, 1.0, + 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, + 5.5, 6.0, 6.5, 6.5, 6.5, 6.5, + 6.5, 6.5, 6.5, 6.5, 6.5, 6.5]).astype(np.float32) + min_val = np.array([-0.4, -0.4, -0.4, -0.4]).reshape(4).astype(np.float32) + max_val = np.array([6.6, 6.6, 6.6, 6.6]).reshape(4).astype(np.float32) + + net = Net(num_bits=4, narrow_range=True, channel_axis=1) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) diff --git a/tests/st/ops/gpu/test_fake_quant_perchannel_grad.py b/tests/st/ops/gpu/test_fake_quant_perchannel_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb90b680434ce1afa1c7415710b5bfa627b0a3d --- /dev/null +++ b/tests/st/ops/gpu/test_fake_quant_perchannel_grad.py @@ -0,0 +1,373 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +import pytest +from mindspore import Tensor +import mindspore.nn as nn +import mindspore.context as context +from mindspore.ops.operations import _quant_ops as Q + +context.set_context(device_target='GPU', device_id=0) + + +class Net(nn.Cell): + def __init__(self, num_bits=8, narrow_range=False): + super(Net, self).__init__() + self.op = Q.FakeQuantPerChannelGrad( + num_bits=num_bits, narrow_range=narrow_range) + + def construct(self, dout, x, minq, maxq): + return self.op(dout, x, minq, maxq) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_grad1(): + # WithVarsPerChannelDim1GradientNudgedDown_ZeroMinAndMax + dout = np.random.uniform(-1, 1, size=[4]).astype('float32') + x = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32) + min_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32) + max_val = np.array([0.0, 0.0, 0.0, 0.0]).astype(np.float32) + expect = dout + + net = Net(num_bits=8, narrow_range=False) + output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("=" * 40) + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_grad2(): + # WithVarsPerChannelDim1GradientNudgedDown_RegularRange + dout = np.random.uniform(-1, 1, size=[4]).astype('float32') + x = np.array([-0.1, 0.0, 63.75, 63.8]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32) + max_val = np.array([63.65, 63.65, 63.65, 63.65]).astype(np.float32) + expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False) + output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) + + 
error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("=" * 40) + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_grad3(): + # WithVarsPerChannelDim1GradientNudgedDown_NarrowRange + dout = np.random.uniform(-1, 1, size=[4]).astype('float32') + x = np.array([-0.1, 0.0, 63.5, 63.6]).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32) + max_val = np.array([63.4, 63.4, 63.4, 63.4]).astype(np.float32) + expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True) + output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("=" * 40) + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_grad4(): + # WithVarsPerChannelDim1GradientNudgedUp_RegularRange + dout = np.random.uniform(-1, 1, size=[4]).astype('float32') + x = np.array([-0.3, -0.25, 63.5, 63.6]).astype(np.float32) + min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32) + max_val = np.array([63.625, 63.625, 63.625, 63.625]).astype(np.float32) + expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False) + output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("=" * 40) + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_grad5(): + # 
WithVarsPerChannelDim1GradientNudgedUp_NarrowRange + dout = np.random.uniform(-1, 1, size=[4]).astype('float32') + x = np.array([-0.3, -0.25, 63.25, 63.3]).astype(np.float32) + min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32) + max_val = np.array([63.375, 63.375, 63.375, 63.375]).astype(np.float32) + expect = np.array([0.0, dout[1], dout[2], 0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True) + output = net(Tensor(dout), Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("=" * 40) + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_grad6(): + # WithVarsPerChannelDim2GradientNudgedDown_RegularRange + read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32') + x = np.array([-0.1, 0.0, 0.1, 0.25, 63.75, 63.8] + ).reshape(3, 2).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1]).astype(np.float32) + max_val = np.array([63.65, 63.65, 63.65]).astype(np.float32) + dout = read_dout.flatten() + expect = np.array([0.0, dout[1], dout[2], dout[3], + dout[4], 0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True) + output = net(Tensor(read_dout), Tensor( + x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("=" * 40) + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_grad7(): + # WithVarsPerChannelDim2GradientNudgedDown_NarrowRange + read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32') + x = np.array([-0.1, 0.0, 0.1, 0.25, 63.5, 63.6] + ).reshape(3, 2).astype(np.float32) + min_val = np.array([-0.1, -0.1, 
-0.1]).astype(np.float32) + max_val = np.array([63.4, 63.4, 63.4]).astype(np.float32) + dout = read_dout.flatten() + expect = np.array([0.0, dout[1], dout[2], dout[3], + dout[4], 0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True) + output = net(Tensor(read_dout), Tensor( + x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("=" * 40) + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_grad8(): + # WithVarsPerChannelDim2GradientNudgedUp_RegularRange + read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32') + x = np.array([-0.3, -0.25, -0.2, 0.0, 63.5, 63.6] + ).reshape(3, 2).astype(np.float32) + min_val = np.array([-0.125, -0.125, -0.125]).astype(np.float32) + max_val = np.array([63.625, 63.625, 63.625]).astype(np.float32) + dout = read_dout.flatten() + expect = np.array([0.0, dout[1], dout[2], dout[3], + dout[4], 0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False) + output = net(Tensor(read_dout), Tensor( + x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("=" * 40) + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_grad9(): + # WithVarsPerChannelDim2GradientNudgedUp_NarrowRange + read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32') + x = np.array([-0.3, -0.25, -0.2, 0.0, 63.25, 63.3] + ).reshape(3, 2).astype(np.float32) + min_val = np.array([-0.125, -0.125, -0.125]).astype(np.float32) + max_val = np.array([63.375, 63.375, 63.375]).astype(np.float32) + dout = read_dout.flatten() + expect = np.array([0.0, dout[1], dout[2], 
dout[3], + dout[4], 0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True) + output = net(Tensor(read_dout), Tensor( + x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("=" * 40) + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_grad10(): + # WithVarsPerChannelDim4GradientNudgedDown_RegularRange + read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32') + x = np.array([-0.1, 0.0, 63.75, 63.8, -0.1, 0.0, + 63.75, 63.8, -0.1, 0.0, 63.75, 63.8, + -0.1, 0.0, 63.75, 63.8, -0.1, 0.0, + 63.75, 63.8, -0.1, 0.0, 63.75, 63.8]).reshape(4, 3, 2, 1).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32) + max_val = np.array([63.65, 63.65, 63.65, 63.65]).astype(np.float32) + dout = read_dout.flatten() + expect = np.array([0.0, dout[1], dout[2], 0.0, + 0.0, dout[5], dout[6], 0.0, + 0.0, dout[9], dout[10], 0.0, + 0.0, dout[13], dout[14], 0.0, + 0.0, dout[17], dout[18], 0.0, + 0.0, dout[21], dout[22], 0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False) + output = net(Tensor(read_dout), Tensor( + x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("=" * 40) + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_grad11(): + # WithVarsPerChannelDim4GradientNudgedDown_NarrowRange + read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32') + x = np.array([-0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, + 63.6, -0.1, 0.0, 63.5, 63.6, -0.1, 0.0, 63.5, 63.6]).reshape(4, 3, 2, 
1).astype(np.float32) + min_val = np.array([-0.1, -0.1, -0.1, -0.1]).astype(np.float32) + max_val = np.array([63.4, 63.4, 63.4, 63.4]).astype(np.float32) + dout = read_dout.flatten() + expect = np.array([0.0, dout[1], dout[2], 0.0, + 0.0, dout[5], dout[6], 0.0, + 0.0, dout[9], dout[10], 0.0, + 0.0, dout[13], dout[14], 0.0, + 0.0, dout[17], dout[18], 0.0, + 0.0, dout[21], dout[22], 0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True) + output = net(Tensor(read_dout), Tensor( + x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("=" * 40) + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_grad12(): + # WithVarsPerChannelDim4GradientNudgedUp_RegularRange + read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32') + x = np.array([-0.3, -0.25, 63.5, 63.6, -0.3, -0.25, + 63.5, 63.6, -0.3, -0.25, 63.5, 63.6, + -0.3, -0.25, 63.5, 63.6, -0.3, -0.25, + 63.5, 63.6, -0.3, -0.25, 63.5, 63.6]).reshape(4, 3, 2, 1).astype(np.float32) + min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32) + max_val = np.array([63.625, 63.625, 63.625, 63.625]).astype(np.float32) + dout = read_dout.flatten() + expect = np.array([0.0, dout[1], dout[2], 0.0, + 0.0, dout[5], dout[6], 0.0, + 0.0, dout[9], dout[10], 0.0, + 0.0, dout[13], dout[14], 0.0, + 0.0, dout[17], dout[18], 0.0, + 0.0, dout[21], dout[22], 0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False) + output = net(Tensor(read_dout), Tensor( + x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("=" * 40) + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 
+@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant_grad13(): + # WithVarsPerChannelDim4GradientNudgedUp_NarrowRange + read_dout = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32') + x = np.array([-0.3, -0.25, 63.25, 63.3, -0.3, -0.25, + 63.25, 63.3, -0.3, -0.25, 63.25, 63.3, + -0.3, -0.25, 63.25, 63.3, -0.3, -0.25, + 63.25, 63.3, -0.3, -0.25, 63.25, 63.3]).reshape(4, 3, 2, 1).astype(np.float32) + min_val = np.array([-0.125, -0.125, -0.125, -0.125]).astype(np.float32) + max_val = np.array([63.375, 63.375, 63.375, 63.375]).astype(np.float32) + dout = read_dout.flatten() + expect = np.array([0.0, dout[1], dout[2], 0.0, + 0.0, dout[5], dout[6], 0.0, + 0.0, dout[9], dout[10], 0.0, + 0.0, dout[13], dout[14], 0.0, + 0.0, dout[17], dout[18], 0.0, + 0.0, dout[21], dout[22], 0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True) + output = net(Tensor(read_dout), Tensor( + x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("=" * 40) + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) diff --git a/tests/st/ops/gpu/test_fake_quant_perlayer.py b/tests/st/ops/gpu/test_fake_quant_perlayer.py new file mode 100644 index 0000000000000000000000000000000000000000..661cea09253249ac6ee6db963b3cfaf5ab45f707 --- /dev/null +++ b/tests/st/ops/gpu/test_fake_quant_perlayer.py @@ -0,0 +1,386 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +from mindspore.common.tensor import Tensor +import mindspore.nn as nn +from mindspore.ops.operations import _quant_ops as Q + +context.set_context(device_target='GPU', device_id=0) + + +class Net(nn.Cell): + def __init__(self, + num_bits=8, + quant_delay=0, + symmetric=False, + narrow_range=False, + training=True): + super(Net, self).__init__() + self.fake_quant = Q.FakeQuantPerLayer(num_bits=num_bits, + quant_delay=quant_delay, + symmetric=symmetric, + narrow_range=narrow_range, + training=training) + + def construct(self, x, minq, maxq): + return self.fake_quant(x, minq, maxq) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant1(): + # (8, false, 0.0f, 0.0f, TensorShape({2, 3}), + # {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, + # {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + x = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).reshape(2, 3).astype(np.float32) + min_val = np.array([0]).reshape(1).astype(np.float32) + max_val = np.array([0]).reshape(1).astype(np.float32) + expect = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant2(): + # 8, false, -10.0f, 53.75f, TensorShape({2, 3}), + # {-10.1f, -10.0f, -9.9f, -9.75f, 53.75f, 53.8f}, + # {-10.0f, -10.0f, -10.0f, -9.75f, 53.75f, 53.75f}); + x = np.array([-10.1, -10.0, -9.9, -9.75, 53.75, 53.8]).reshape(2, 
3).astype(np.float32) + min_val = np.array([-10.0]).reshape(1).astype(np.float32) + max_val = np.array([53.75]).reshape(1).astype(np.float32) + expect = np.array([-10.0, -10.0, -10.0, -9.75, 53.75, 53.75]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant3(): + # WithVarsNoNudging_NarrowRange + x = np.array([-10.1, -10.0, -9.90, -9.75, 53.5, 53.6]).reshape(2, 3).astype(np.float32) + min_val = np.array([-10.0]).reshape(1).astype(np.float32) + max_val = np.array([53.5]).reshape(1).astype(np.float32) + expect = np.array([-10.0, -10.0, -10.0, -9.75, 53.5, 53.5]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=True) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_fake_quant4(): + # WithVarsNudgedDown_RegularRange + x = np.array([-0.1, 0.0, 0.1, 0.25, 63.75, 63.8]).reshape(2, 3).astype(np.float32) + min_val = np.array([-0.1]).reshape(1).astype(np.float32) + max_val = np.array([63.65]).reshape(1).astype(np.float32) + expect = np.array([-0.0, 0.0, 0.0, 0.25, 63.75, 63.75]).astype(np.float32) + + net = Net(num_bits=8, narrow_range=False) + output = net(Tensor(x), Tensor(min_val), Tensor(max_val)) + + error = np.ones(shape=expect.shape) * 1.0e-5 + diff = output.asnumpy().flatten() - expect + print("output: ", output) + print("expect: ", expect) + assert np.all(np.abs(diff) < error) + + 
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant5():
    """Case WithVarsNudgedDown_NarrowRange: 8 bits, narrow_range=True, min=-0.1, max=63.4."""
    data = np.array([[-0.1, 0.0, 0.1],
                     [0.25, 63.5, 63.6]], dtype=np.float32)
    lower = np.array([-0.1], dtype=np.float32)
    upper = np.array([63.4], dtype=np.float32)
    expected = np.array([-0.0, 0.0, 0.0, 0.25, 63.5, 63.5], dtype=np.float32)

    output = Net(num_bits=8, narrow_range=True)(Tensor(data), Tensor(lower), Tensor(upper))

    # Element-wise absolute tolerance of 1e-5 against the flattened result.
    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant6():
    """Case WithVarsNudgedUp_RegularRange: 8 bits, narrow_range=False, min=-0.125, max=63.625."""
    data = np.array([[-0.26, -0.25, -0.24],
                     [0.0, 63.5, 63.6]], dtype=np.float32)
    lower = np.array([-0.125], dtype=np.float32)
    upper = np.array([63.625], dtype=np.float32)
    expected = np.array([-0.25, -0.25, -0.25, 0.0, 63.5, 63.5], dtype=np.float32)

    output = Net(num_bits=8, narrow_range=False)(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant7():
    """Case WithVarsNudgedUp_NarrowRange: 8 bits, narrow_range=True, min=-0.125, max=63.375."""
    data = np.array([[-0.26, -0.25, -0.24],
                     [0.0, 63.25, 63.3]], dtype=np.float32)
    lower = np.array([-0.125], dtype=np.float32)
    upper = np.array([63.375], dtype=np.float32)
    expected = np.array([-0.25, -0.25, -0.25, 0.0, 63.25, 63.25], dtype=np.float32)

    output = Net(num_bits=8, narrow_range=True)(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant8():
    """Case WithVarsNudgedZeroIs255_RegularRange: 8 bits, narrow_range=False, min=-63.65, max=0.1."""
    data = np.array([[-63.80, -63.75, -63.70],
                     [-63.5, 0.0, 0.1]], dtype=np.float32)
    lower = np.array([-63.65], dtype=np.float32)
    upper = np.array([0.1], dtype=np.float32)
    expected = np.array([-63.75, -63.75, -63.75, -63.5, 0.0, 0.0], dtype=np.float32)

    output = Net(num_bits=8, narrow_range=False)(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant9():
    """Case WithVarsNudgedZeroIs255_NarrowRange: 8 bits, narrow_range=True, min=-63.4, max=0.1."""
    data = np.array([[-63.6, -63.5, -63.4],
                     [-63.25, 0.0, 0.1]], dtype=np.float32)
    lower = np.array([-63.4], dtype=np.float32)
    upper = np.array([0.1], dtype=np.float32)
    expected = np.array([-63.5, -63.5, -63.5, -63.25, 0.0, 0.0], dtype=np.float32)

    output = Net(num_bits=8, narrow_range=True)(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant10():
    """Case WithVarsNoNudging_4Bits_RegularRange: 4 bits, narrow_range=False, min=-6.0, max=1.5."""
    data = np.array([[-6.1, -6.0, -5.9],
                     [-5.5, 1.5, 1.6]], dtype=np.float32)
    lower = np.array([-6.0], dtype=np.float32)
    upper = np.array([1.5], dtype=np.float32)
    expected = np.array([-6.0, -6.0, -6.0, -5.5, 1.5, 1.5], dtype=np.float32)

    output = Net(num_bits=4, narrow_range=False)(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant11():
    """Case WithVarsNoNudging_4Bits_NarrowRange: 4 bits, narrow_range=True, min=-6.0, max=1.0."""
    data = np.array([[-6.1, -6.0, -5.9],
                     [-5.5, 1.0, 1.1]], dtype=np.float32)
    lower = np.array([-6.0], dtype=np.float32)
    upper = np.array([1.0], dtype=np.float32)
    expected = np.array([-6.0, -6.0, -6.0, -5.5, 1.0, 1.0], dtype=np.float32)

    output = Net(num_bits=4, narrow_range=True)(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant12():
    """Case WithVarsNudgedDown_4Bits_RegularRange: 4 bits, narrow_range=False, min=-0.1, max=7.4."""
    data = np.array([[-0.1, 0.0, 0.1],
                     [0.5, 7.5, 7.6]], dtype=np.float32)
    lower = np.array([-0.1], dtype=np.float32)
    upper = np.array([7.4], dtype=np.float32)
    expected = np.array([-0.0, 0.0, 0.0, 0.5, 7.5, 7.5], dtype=np.float32)

    output = Net(num_bits=4, narrow_range=False)(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant13():
    """Case WithVarsNudgedDown_4Bits_NarrowRange: 4 bits, narrow_range=True, min=-0.1, max=6.9."""
    data = np.array([[-0.1, 0.0, 0.1],
                     [0.5, 7.0, 7.1]], dtype=np.float32)
    lower = np.array([-0.1], dtype=np.float32)
    upper = np.array([6.9], dtype=np.float32)
    expected = np.array([-0.0, 0.0, 0.0, 0.5, 7.0, 7.0], dtype=np.float32)

    output = Net(num_bits=4, narrow_range=True)(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant14():
    """Case WithVarsNudgedUp_4Bits_RegularRange: 4 bits, narrow_range=False, min=-0.4, max=7.1."""
    data = np.array([[-0.6, -0.5, -0.24],
                     [0.0, 7.0, 7.1]], dtype=np.float32)
    lower = np.array([-0.4], dtype=np.float32)
    upper = np.array([7.1], dtype=np.float32)
    expected = np.array([-0.5, -0.5, -0.00, 0.0, 7.0, 7.0], dtype=np.float32)

    output = Net(num_bits=4, narrow_range=False)(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant15():
    """Case WithVarsNudgedUp_4Bits_NarrowRange: 4 bits, narrow_range=True, min=-0.4, max=6.6."""
    data = np.array([[-0.6, -0.5, -0.24],
                     [0.0, 6.5, 6.6]], dtype=np.float32)
    lower = np.array([-0.4], dtype=np.float32)
    upper = np.array([6.6], dtype=np.float32)
    expected = np.array([-0.5, -0.5, -0.00, 0.0, 6.5, 6.5], dtype=np.float32)

    output = Net(num_bits=4, narrow_range=True)(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant16():
    """Case WithVarsNudgedZero15_4Bits_RegularRange: 4 bits, narrow_range=False, min=-7.3, max=0.2."""
    data = np.array([[-7.6, -7.5, -7.4],
                     [-7.2, 0.0, 0.1]], dtype=np.float32)
    lower = np.array([-7.3], dtype=np.float32)
    upper = np.array([0.2], dtype=np.float32)
    expected = np.array([-7.5, -7.5, -7.5, -7.0, 0.0, 0.0], dtype=np.float32)

    output = Net(num_bits=4, narrow_range=False)(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant17():
    """Case WithVarsNudgedZero15_4Bits_NarrowRange: 4 bits, narrow_range=True, min=-6.8, max=0.2."""
    data = np.array([[-7.1, -7.0, -6.9],
                     [-6.5, 0.0, 0.1]], dtype=np.float32)
    lower = np.array([-6.8], dtype=np.float32)
    upper = np.array([0.2], dtype=np.float32)
    expected = np.array([-7.0, -7.0, -7.0, -6.5, 0.0, 0.0], dtype=np.float32)

    output = Net(num_bits=4, narrow_range=True)(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


# diff --git a/tests/st/ops/gpu/test_fake_quant_perlayer_grad.py b/tests/st/ops/gpu/test_fake_quant_perlayer_grad.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..f8330eff9f19f5931eb1442d3f3cc9082aada8fb
# --- /dev/null
# +++ b/tests/st/ops/gpu/test_fake_quant_perlayer_grad.py
# @@ -0,0 +1,221 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in
# compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.operations import _quant_ops as Q

context.set_context(device_target='GPU', device_id=0)


class Net(nn.Cell):
    """Minimal cell whose forward pass is a single ``Q.FakeQuantPerLayerGrad`` call."""

    def __init__(self, num_bits=8, narrow_range=False):
        super(Net, self).__init__()
        # Both constructor arguments are handed to the operator unchanged.
        grad_settings = dict(num_bits=num_bits, narrow_range=narrow_range)
        self.op = Q.FakeQuantPerLayerGrad(**grad_settings)

    def construct(self, dout, x, minq, maxq):
        # Arguments are forwarded in the same order they are received.
        return self.op(dout, x, minq, maxq)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad1():
    """Case "WithArgsGradient RegularRange": 8 bits, narrow_range=False, min=-0.125, max=63.625."""
    upstream = np.random.uniform(-1, 1, size=[6]).astype('float32')
    data = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6], dtype=np.float32)
    lower = np.array([-0.125], dtype=np.float32)
    upper = np.array([63.625], dtype=np.float32)
    # Expected gradient: the upstream values with the first and last entry zeroed.
    expected = np.array([0.0, *upstream[1:5], 0.0], dtype=np.float32)

    output = Net(num_bits=8, narrow_range=False)(Tensor(upstream), Tensor(data), Tensor(lower), Tensor(upper))

    # Element-wise absolute tolerance of 1e-5 against the flattened result.
    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad2():
    """Case "WithArgsGradient NarrowRange": 8 bits, narrow_range=True, min=-0.125, max=63.375."""
    upstream = np.random.uniform(-1, 1, size=[6]).astype('float32')
    data = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3], dtype=np.float32)
    lower = np.array([-0.125], dtype=np.float32)
    upper = np.array([63.375], dtype=np.float32)
    expected = np.array([0.0, *upstream[1:5], 0.0], dtype=np.float32)

    output = Net(num_bits=8, narrow_range=True)(Tensor(upstream), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad3():
    """Case WithArgsGradient_4Bits_RegularRange: 4 bits, narrow_range=False, min=-0.4, max=7.1."""
    upstream = np.random.uniform(-1, 1, size=[6]).astype('float32')
    data = np.array([-0.6, -0.5, -0.4, 0.0, 7.0, 7.1], dtype=np.float32)
    lower = np.array([-0.4], dtype=np.float32)
    upper = np.array([7.1], dtype=np.float32)
    expected = np.array([0.0, *upstream[1:5], 0.0], dtype=np.float32)

    output = Net(num_bits=4, narrow_range=False)(Tensor(upstream), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad4():
    """Case WithArgsGradient_4Bits_NarrowRange: 4 bits, narrow_range=True, min=-0.4, max=6.6."""
    upstream = np.random.uniform(-1, 1, size=[6]).astype('float32')
    data = np.array([-0.6, -0.5, -0.4, 0.0, 6.5, 6.6], dtype=np.float32)
    lower = np.array([-0.4], dtype=np.float32)
    upper = np.array([6.6], dtype=np.float32)
    expected = np.array([0.0, *upstream[1:5], 0.0], dtype=np.float32)

    output = Net(num_bits=4, narrow_range=True)(Tensor(upstream), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad5():
    """Case FakeQuantWithMinMaxVarsGradient: all-zero input, min == max == 0, 8 bits, narrow_range=True."""
    upstream = np.random.uniform(-1, 1, size=[6]).astype('float32')
    data = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32)
    lower = np.array([0.0], dtype=np.float32)
    upper = np.array([0.0], dtype=np.float32)
    # Here the whole upstream gradient is expected to come back unchanged.
    expected = upstream

    output = Net(num_bits=8, narrow_range=True)(Tensor(upstream), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad6():
    """Case WithVarsGradient_RegularRange: 8 bits, narrow_range=False, min=-0.125, max=63.625."""
    upstream = np.random.uniform(-1, 1, size=[6]).astype('float32')
    data = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6], dtype=np.float32)
    lower = np.array([-0.125], dtype=np.float32)
    upper = np.array([63.625], dtype=np.float32)
    expected = np.array([0.0, *upstream[1:5], 0.0], dtype=np.float32)

    output = Net(num_bits=8, narrow_range=False)(Tensor(upstream), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad7():
    """Case WithVarsGradient_NarrowRange: 8 bits, narrow_range=True, min=-0.125, max=63.375."""
    upstream = np.random.uniform(-1, 1, size=[6]).astype('float32')
    data = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3], dtype=np.float32)
    lower = np.array([-0.125], dtype=np.float32)
    upper = np.array([63.375], dtype=np.float32)
    expected = np.array([0.0, *upstream[1:5], 0.0], dtype=np.float32)

    output = Net(num_bits=8, narrow_range=True)(Tensor(upstream), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad8():
    """Case WithVarsGradient_4Bits_RegularRange: 4 bits, narrow_range=False, min=-0.4, max=7.1."""
    upstream = np.random.uniform(-1, 1, size=[6]).astype('float32')
    data = np.array([-0.6, -0.5, -0.4, 0.0, 7.0, 7.1], dtype=np.float32)
    lower = np.array([-0.4], dtype=np.float32)
    upper = np.array([7.1], dtype=np.float32)
    expected = np.array([0.0, *upstream[1:5], 0.0], dtype=np.float32)

    output = Net(num_bits=4, narrow_range=False)(Tensor(upstream), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad9():
    """Case WithVarsGradient_4Bits_NarrowRange: 4 bits, narrow_range=True, min=-0.4, max=6.6."""
    upstream = np.random.uniform(-1, 1, size=[6]).astype('float32')
    data = np.array([-0.6, -0.5, -0.4, 0.0, 6.5, 6.6], dtype=np.float32)
    lower = np.array([-0.4], dtype=np.float32)
    upper = np.array([6.6], dtype=np.float32)
    expected = np.array([0.0, *upstream[1:5], 0.0], dtype=np.float32)

    output = Net(num_bits=4, narrow_range=True)(Tensor(upstream), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = output.asnumpy().flatten() - expected
    print("output: ", output)
    print("expect: ", expected)
    assert np.all(np.abs(deviation) < tolerance)