提交 684ecac9 编写于 作者: C chenzomi

rebase master to r0.5 for quantization aware training

上级 412e4580
......@@ -252,7 +252,7 @@ checkopts()
done
}
checkopts "$@"
echo "---------------- mindspore: build start ----------------"
echo "---------------- mindSpore: build start ----------------"
mkdir -pv "${BUILD_PATH}/package/mindspore/lib"
git submodule update --init graphengine
if [[ "X$ENABLE_AKG" = "Xon" ]] && [[ "X$ENABLE_D" = "Xon" ]]; then
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/extrema.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/pair.h>
#include "fake_quant_perchannel_impl.cuh"
/**
 * Compute the nudged quantization range and the scale of every channel.
 *
 * For channel i the scale is (input_max[i] - input_min[i]) / (quant_max - quant_min). The zero point
 * implied by input_min[i] is clamped to [quant_min, quant_max] and rounded to an integer level, and
 * the range that this level/scale pair can represent is written to nudge_min[i] / nudge_max[i].
 * If either the quantized range or the input range is empty, scale[i] is set to 0.
 *
 * @param input_min   per-channel minimum (channel_num elements); rewritten in place when symmetric is true
 * @param input_max   per-channel maximum (channel_num elements); rewritten in place when symmetric is true
 * @param quant_min   lowest quantized level
 * @param quant_max   highest quantized level
 * @param nudge_min   output, nudged minimum per channel
 * @param nudge_max   output, nudged maximum per channel
 * @param scale       output, quantization step per channel
 * @param channel_num number of channels to process
 * @param symmetric   when true, first widen [input_min, input_max] so it is symmetric around zero
 * @return
 */
__global__ void NudgeMinMaxPerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
                                      float *nudge_min, float *nudge_max, float *scale, int channel_num,
                                      const bool symmetric) {
  const float quant_range = quant_max - quant_min;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_num; i += blockDim.x * gridDim.x) {
    if (symmetric) {
      // Each channel must be compared against its own minimum; the previous code read input_min[0]
      // here, so every channel was widened according to channel 0.
      input_max[i] = fabsf(input_min[i]) < input_max[i] ? input_max[i] : -input_min[i];
      input_min[i] = fabsf(input_min[i]) < input_max[i] ? -input_max[i] : input_min[i];
    }

    float channel_scale = 0.f;
    float zp_from_min = 0.f;
    const float input_range = input_max[i] - input_min[i];
    if (quant_range != 0 && input_range != 0) {
      channel_scale = input_range / quant_range;
      zp_from_min = quant_min - input_min[i] / channel_scale;
    }
    scale[i] = channel_scale;

    // Snap the zero point onto an integer level inside [quant_min, quant_max].
    float nudge_zp = 0.f;
    if (zp_from_min <= quant_min) {
      nudge_zp = quant_min;
    } else if (zp_from_min >= quant_max) {
      nudge_zp = quant_max;
    } else {
      nudge_zp = roundf(zp_from_min);
    }

    nudge_min[i] = (quant_min - nudge_zp) * channel_scale;
    nudge_max[i] = (quant_max - nudge_zp) * channel_scale;
  }
}
// Host entry point: enqueues NudgeMinMaxPerChannel on cuda_stream with a launch
// configuration derived from channel_num. All pointer arguments are forwarded
// unchanged to the kernel.
void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
                        float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric,
                        cudaStream_t cuda_stream) {
  const auto grid_dim = GET_BLOCKS(channel_num);
  const auto block_dim = GET_THREADS;
  NudgeMinMaxPerChannel<<<grid_dim, block_dim, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max,
                                                                 nudge_min, nudge_max, scale, channel_num,
                                                                 symmetric);
}
/**
 * Calculate the fake-quantized output from the per-channel nudge min, nudge max and scale.
 *
 * Element i belongs to channel i / (total_size / channel_size). Each element is clamped to
 * [nudge_min, nudge_max] of its channel, converted to the nearest quantization step and mapped back
 * to a float value.
 *
 * @param input        - array of total_size elements
 * @param output       - array of total_size elements
 * @param total_size   - int, number of elements in input/output
 * @param channel_size - int, number of channels; total_size / channel_size elements belong to each channel
 * @param nudge_min    - array, nudged minimum per channel
 * @param nudge_max    - array, nudged maximum per channel
 * @param scale        - array, quantization step per channel
 * @return
 */
__global__ void FakeQuantPerChannel(const float *input, float *output, const int total_size, const int channel_size,
                                    const float *nudge_min, const float *nudge_max, const float *scale) {
  // Without at least one element per channel there is no valid channel index to compute.
  if (channel_size <= 0) {
    return;
  }
  const int per_channel_num = total_size / channel_size;
  if (per_channel_num <= 0) {
    return;
  }

  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total_size; i += blockDim.x * gridDim.x) {
    // Integer division yields the same index as the former double-precision floor() for
    // non-negative i, without double arithmetic in a float kernel.
    int channel_idx = i / per_channel_num;
    if (channel_idx >= channel_size) {
      // Only reachable when total_size is not a multiple of channel_size; keep the read in bounds.
      channel_idx = channel_size - 1;
    }
    const float lower = nudge_min[channel_idx];
    const float upper = nudge_max[channel_idx];
    const float step = scale[channel_idx];

    // clamp input x
    float input_x = input[i];
    if (input_x < lower) {
      input_x = lower;
    }
    if (input_x > upper) {
      input_x = upper;
    }

    // NudgeMinMaxPerChannel stores a zero scale for a degenerate range. Dividing by it would
    // produce NaN, and converting NaN to int is not a well-defined conversion, so emit the
    // zero-step result (nudge_min) directly.
    if (step == 0.f) {
      output[i] = lower;
      continue;
    }

    // clamp shift: index of the nearest quantization step above nudge_min
    const int nudge_input = static_cast<int>(floorf((input_x - lower) / step + 0.5f));
    // quantize
    output[i] = nudge_input * step + lower;
  }
}
// Host entry point: enqueues FakeQuantPerChannel on cuda_stream, sizing the
// launch from total_size (one logical thread per tensor element).
void CalFakeQuantPerChannel(const float *input, float *output, const int total_size, const int channel_size,
                            const float *nudge_min, const float *nudge_max, const float *scale,
                            cudaStream_t cuda_stream) {
  const auto grid_dim = GET_BLOCKS(total_size);
  const auto block_dim = GET_THREADS;
  FakeQuantPerChannel<<<grid_dim, block_dim, 0, cuda_stream>>>(input, output, total_size, channel_size, nudge_min,
                                                               nudge_max, scale);
}
/**
 * Backward pass of per-channel fake quantization.
 *
 * The incoming gradient is passed through for elements whose input lies inside
 * [nudge_min, nudge_max] of their channel and is set to 0 for elements outside that range.
 * Element i belongs to channel i / (total_size / channel_size).
 *
 * @param input        - array of total_size elements, the forward input
 * @param gradient     - array of total_size elements, gradient w.r.t. the forward output
 * @param output       - array of total_size elements, gradient w.r.t. the forward input
 * @param total_size   - int, number of elements
 * @param channel_size - int, number of channels
 * @param nudge_min    - array, nudged minimum per channel
 * @param nudge_max    - array, nudged maximum per channel
 * @return
 */
__global__ void FakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_size,
                                        const int channel_size, const float *nudge_min, const float *nudge_max) {
  // Without at least one element per channel there is no valid channel index to compute.
  if (channel_size <= 0) {
    return;
  }
  const int per_channel_num = total_size / channel_size;
  if (per_channel_num <= 0) {
    return;
  }

  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total_size; i += blockDim.x * gridDim.x) {
    // Integer division replaces the former double-precision floor(); identical for non-negative i.
    int channel_idx = i / per_channel_num;
    if (channel_idx >= channel_size) {
      // Only reachable when total_size is not a multiple of channel_size; keep the read in bounds.
      channel_idx = channel_size - 1;
    }
    const float x = input[i];
    const bool clipped = x < nudge_min[channel_idx] || x > nudge_max[channel_idx];
    output[i] = clipped ? 0.f : gradient[i];
  }
}
// Host entry point: enqueues FakeQuantPerChannelGrad on cuda_stream.
// The kernel iterates over all total_num elements, so the grid is sized from
// total_num (matching CalFakeQuantPerChannel). It was previously sized from
// channel_num, which left the grid-stride loop to cover the tensor with far
// fewer threads than intended.
void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
                                const int channel_num, const float *nudge_min, const float *nudge_max,
                                cudaStream_t cuda_stream) {
  FakeQuantPerChannelGrad<<<GET_BLOCKS(total_num), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, total_num,
                                                                                  channel_num, nudge_min, nudge_max);
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
#include "device/gpu/cuda_common.h"
/**
 * Launch the per-channel nudge kernel on cuda_stream.
 * Writes scale, nudge_min and nudge_max for channel_num channels; when symmetric is true the
 * kernel also rewrites input_min / input_max in place.
 */
void CalNudgePerChannel(float *input_min, float *input_max, const float quant_min, const float quant_max,
float *nudge_min, float *nudge_max, float *scale, const int channel_num, const bool symmetric,
cudaStream_t cuda_stream);
/**
 * Launch the per-channel fake-quantization kernel on cuda_stream.
 * total_num is the element count of input/output and channel_num the number of channels; the
 * kernel assigns total_num / channel_num consecutive elements to each channel.
 */
void CalFakeQuantPerChannel(const float *input, float *output, const int total_num, const int channel_num,
const float *nudge_min, const float *nudge_max, const float *scale,
cudaStream_t cuda_stream);
/**
 * Launch the per-channel fake-quantization gradient kernel on cuda_stream.
 * output receives gradient where input lies inside [nudge_min, nudge_max] of its channel and 0
 * elsewhere.
 */
void CalFakeQuantPerChannelGrad(const float *input, const float *gradient, float *output, const int total_num,
const int channel_num, const float *nudge_min, const float *nudge_max,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/extrema.h>
#include <thrust/device_vector.h>
#include <thrust/pair.h>
#include "fake_quant_perlayer_impl.cuh"
/**
 * Fake-quantize a tensor using a single (per-layer) range.
 *
 * Every element is clamped to [nudge_min[0], nudge_max[0]], converted to the nearest quantization
 * step of width scale[0] and mapped back to a float value.
 *
 * @param input     - array of size elements
 * @param output    - array of size elements
 * @param size      - int, number of elements
 * @param nudge_min - pointer to the nudged minimum (element 0 is used)
 * @param nudge_max - pointer to the nudged maximum (element 0 is used)
 * @param scale     - pointer to the quantization step (element 0 is used)
 * @return
 */
__global__ void FakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min,
                                  const float *nudge_max, const float *scale) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
    const float lower = nudge_min[0];
    const float upper = nudge_max[0];
    const float step = scale[0];

    // clamp input x
    float input_x = input[i];
    if (input_x < lower) {
      input_x = lower;
    }
    if (input_x > upper) {
      input_x = upper;
    }

    // NudgeMinMaxPerLayer stores a zero scale for a degenerate range. Dividing by it would
    // produce NaN, and converting NaN to int is not a well-defined conversion, so emit the
    // zero-step result (nudge_min) directly.
    if (step == 0.f) {
      output[i] = lower;
      continue;
    }

    // clamp shift: index of the nearest quantization step above nudge_min
    const int nudge_input = static_cast<int>(roundf((input_x - lower) / step));
    // quantize
    output[i] = nudge_input * step + lower;
  }
}
// Backward pass of per-layer fake quantization: the gradient is forwarded for
// elements whose input lies inside [nudge_min[0], nudge_max[0]] and replaced
// by 0 for elements outside that range.
__global__ void FakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size,
                                      const float *nudge_min, const float *nudge_max) {
  const int stride = blockDim.x * gridDim.x;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride) {
    const bool out_of_range = input[idx] < nudge_min[0] || input[idx] > nudge_max[0];
    if (out_of_range) {
      output[idx] = 0;
      continue;
    }
    output[idx] = gradient[idx];
  }
}
// Derive the nudged range and the scale for one (per-layer) min/max pair.
// Only element 0 of every array is read or written. When symmetric is set,
// input_min / input_max are first rewritten so the range is symmetric around
// zero. A zero quantized range or a zero input range yields scale 0.
__global__ void NudgeMinMaxPerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max,
                                    float *nudge_min, float *nudge_max, float *scale, const bool symmetric) {
  // Begin with zeroed outputs.
  scale[0] = 0.f;
  nudge_max[0] = 0.f;
  nudge_min[0] = 0.f;

  if (symmetric) {
    if (!(abs(input_min[0]) < input_max[0])) {
      input_max[0] = -input_min[0];
    }
    if (abs(input_min[0]) < input_max[0]) {
      input_min[0] = -input_max[0];
    }
  }

  float zero_point = 0.f;
  const float level_span = quant_max - quant_min;
  if (level_span != 0 && (input_max[0] - input_min[0]) != 0) {
    scale[0] = (input_max[0] - input_min[0]) / level_span;
    zero_point = quant_min - input_min[0] / scale[0];
  } else {
    scale[0] = 0.f;
    zero_point = 0.f;
  }

  // Snap the zero point onto an integer level inside [quant_min, quant_max].
  const float snapped_zp =
    zero_point <= quant_min ? quant_min : (zero_point >= quant_max ? quant_max : round(zero_point));

  nudge_min[0] = (quant_min - snapped_zp) * (scale[0]);
  nudge_max[0] = (quant_max - snapped_zp) * (scale[0]);
}
// Host entry point: enqueues FakeQuantPerLayer on cuda_stream with a launch
// configuration derived from size.
void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min,
                          const float *nudge_max, const float *scale, cudaStream_t cuda_stream) {
  const auto grid_dim = GET_BLOCKS(size);
  const auto block_dim = GET_THREADS;
  FakeQuantPerLayer<<<grid_dim, block_dim, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale);
}
// Host entry point: enqueues FakeQuantPerLayerGrad on cuda_stream with a
// launch configuration derived from size.
void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size,
                              const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream) {
  const auto grid_dim = GET_BLOCKS(size);
  const auto block_dim = GET_THREADS;
  FakeQuantPerLayerGrad<<<grid_dim, block_dim, 0, cuda_stream>>>(input, gradient, output, size, nudge_min,
                                                                 nudge_max);
}
// Host entry point: enqueues NudgeMinMaxPerLayer on cuda_stream. The kernel
// works on a single min/max pair, so it is launched with one block of one
// thread.
void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max,
                      float *nudge_min, float *nudge_max, float *scale, const bool symmetric,
                      cudaStream_t cuda_stream) {
  const int single_block = 1;
  const int single_thread = 1;
  NudgeMinMaxPerLayer<<<single_block, single_thread, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max,
                                                                       nudge_min, nudge_max, scale, symmetric);
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_
#include "device/gpu/cuda_common.h"
/**
 * Launch the per-layer nudge kernel (one block, one thread) on cuda_stream.
 * Writes element 0 of scale, nudge_min and nudge_max; when symmetric is true the kernel also
 * rewrites element 0 of input_min / input_max in place.
 */
void CalNudgePerLayer(float *input_min, float *input_max, const float quant_min, const float quant_max,
float *nudge_min, float *nudge_max, float *scale, const bool symmetric, cudaStream_t cuda_stream);
/**
 * Launch the per-layer fake-quantization kernel on cuda_stream for size elements.
 * Element 0 of nudge_min, nudge_max and scale is applied to every input element.
 */
void CalFakeQuantPerLayer(const float *input, float *output, const int size, const float *nudge_min,
const float *nudge_max, const float *scale, cudaStream_t cuda_stream);
/**
 * Launch the per-layer fake-quantization gradient kernel on cuda_stream for size elements.
 * output receives gradient where input lies inside [nudge_min[0], nudge_max[0]] and 0 elsewhere.
 */
void CalFakeQuantPerLayerGrad(const float *input, const float *gradient, float *output, const int size,
const float *nudge_min, const float *nudge_max, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/extrema.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/pair.h>
#include "minmax_update_impl.cuh"
#include "device/gpu/cuda_common.h"
// Blend the newly observed min/max of a layer into the running min/max.
//
// decay weights the new observation (min / max) and (1 - decay) weights the
// running value (input_min[0] / input_max[0]). The blended minimum is then
// clamped to be at most 0 and the blended maximum to be at least 0, so the
// resulting range always contains zero.
//
// The previous code clamped input_min / input_max instead of the blended
// value, which overwrote the moving average whenever the output buffers were
// distinct from the input buffers. The blend is computed into locals first so
// the result is the same whether or not the output aliases the input.
__global__ void UpdateInputMinMaxPerLayerWithEMA(const float *input_min, const float *input_max, float *output_min,
                                                 float *output_max, const float min, const float max,
                                                 const float decay) {
  const float ema_min = decay * (min) + (1 - decay) * (input_min[0]);
  const float ema_max = decay * (max) + (1 - decay) * (input_max[0]);
  output_min[0] = ema_min > 0 ? 0 : ema_min;
  output_max[0] = ema_max < 0 ? 0 : ema_max;
}
// Store the observed min/max of a layer without averaging. The minimum is
// capped at 0 from above and the maximum at 0 from below, so the stored range
// always contains zero.
__global__ void UpdateInputMinMaxPerLayer(float *output_min, float *output_max, const float min, const float max) {
  float lower = min;
  if (min > 0) {
    lower = 0;
  }
  float upper = max;
  if (max < 0) {
    upper = 0;
  }
  output_min[0] = lower;
  output_max[0] = upper;
}
// Update the per-channel min/max statistics from the current input tensor.
//
// Channel i covers the per_channel_nums consecutive elements starting at
// input + i * per_channel_nums. Its observed min/max is either stored directly
// or, when ema is true, blended with the running value: ema_decay weights the
// new observation and (1 - ema_decay) weights input_min[i] / input_max[i].
// The result is clamped so that min <= 0 <= max before being written.
//
// The previous code clamped input_min / input_max instead of the freshly
// computed value, which discarded the update whenever the output buffers were
// distinct from the input buffers.
__global__ void UpdateInputMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min,
                                            float *output_max, int channels, int per_channel_nums, bool ema,
                                            float ema_decay) {
  // An empty channel has no extrema; dereferencing the end iterator would read out of bounds.
  if (per_channel_nums <= 0) {
    return;
  }
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
    // Compute the offset in size_t so large tensors do not overflow int arithmetic.
    float *channel_begin = input + static_cast<size_t>(i) * per_channel_nums;
    float *channel_end = channel_begin + per_channel_nums;
    thrust::pair<float *, float *> extrema = thrust::minmax_element(thrust::device, channel_begin, channel_end);

    float channel_min = extrema.first[0];
    float channel_max = extrema.second[0];
    if (ema) {
      channel_min = ema_decay * channel_min + (1 - ema_decay) * input_min[i];
      channel_max = ema_decay * channel_max + (1 - ema_decay) * input_max[i];
    }

    // Keep zero inside the stored range.
    output_min[i] = channel_min > 0 ? 0 : channel_min;
    output_max[i] = channel_max < 0 ? 0 : channel_max;
  }
}
// Host entry point: enqueues UpdateInputMinMaxPerChannel on cuda_stream.
// The number of elements per channel is total_num / channel_num, and the
// launch configuration is derived from channel_num.
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
                         const int total_num, const int channel_num, const float ema_decay, const bool ema,
                         cudaStream_t cuda_stream) {
  const int elems_per_channel = total_num / channel_num;
  const auto grid_dim = GET_BLOCKS(channel_num);
  const auto block_dim = GET_THREADS;
  UpdateInputMinMaxPerChannel<<<grid_dim, block_dim, 0, cuda_stream>>>(input, input_min, input_max, output_min,
                                                                       output_max, channel_num, elems_per_channel,
                                                                       ema, ema_decay);
}
// Host entry point for the per-layer min/max update.
// Finds the minimum and maximum of the total_num input elements with
// thrust::minmax_element on cuda_stream, reads both values back to the host,
// and then launches a one-thread kernel that stores them either directly or
// blended with the running values (when ema is true).
void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
                       const int total_num, const float ema_decay, const bool ema, cudaStream_t cuda_stream) {
  auto exec_policy = thrust::cuda::par.on(cuda_stream);
  thrust::device_ptr<float> first = thrust::device_pointer_cast(input);
  thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> extrema =
    thrust::minmax_element(exec_policy, first, thrust::device_pointer_cast(input) + total_num);

  // Indexing a device_ptr copies the element back to the host.
  const float minel = extrema.first[0];
  const float maxel = extrema.second[0];

  if (!ema) {
    UpdateInputMinMaxPerLayer<<<1, 1, 0, cuda_stream>>>(output_min, output_max, minel, maxel);
    return;
  }
  UpdateInputMinMaxPerLayerWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, output_min, output_max, minel,
                                                             maxel, ema_decay);
}
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
#include "device/gpu/cuda_common.h"
/**
 * Launch the per-channel min/max update kernel on cuda_stream.
 * Each of the channel_num channels covers total_num / channel_num consecutive input elements.
 * When ema is true the observed extrema are blended with input_min / input_max using ema_decay
 * as the weight of the new observation.
 */
void CalMinMaxPerChannel(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int total_num, const int channel_num, const float ema_decay, const bool ema,
cudaStream_t cuda_stream);
/**
 * Compute the min/max of size input elements and store them in output_min / output_max.
 * When ema is true the observed extrema are blended with input_min / input_max using ema_decay
 * as the weight of the new observation.
 */
void CalMinMaxPerLayer(float *input, float *input_min, float *input_max, float *output_min, float *output_max,
const int size, const float ema_decay, const bool ema, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
......@@ -25,21 +25,15 @@ namespace mindspore {
namespace kernel {
FakeQuantPerChannelGpuKernel::FakeQuantPerChannelGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_channels_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_delay_(0),
ema_(false),
ema_decay_(0),
global_step_(0),
training_(false),
channel_out_(0),
symmetric_(false),
narrow_range_(false),
symmetric_(false) {}
quant_delay_(0),
quant_min_(0),
quant_max_(0),
global_step_(0) {}
const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
......@@ -60,91 +54,57 @@ bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
return false;
}
// get attribute
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = 1.0 - GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << "is out of range, expected between 2 and 16.";
return false;
}
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << " is less then 0, require larger than 0.";
return false;
}
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
// quant min and max value
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
// shape info for gpu
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
channel_out_ = SizeToInt(input_shape[0]);
min_size_ = sizeof(float) * channel_out_;
max_size_ = sizeof(float) * channel_out_;
num_channels_ = SizeToInt(input_shape[0]);
input_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantPerChannelGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // input in tensor
input_size_list_.push_back(min_size_); // min one scalar
input_size_list_.push_back(max_size_); // max on scalar
output_size_list_.push_back(output_size_); // output in tensor
workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel
workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel
workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel
}
void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForTraining(float *input, float *output, float *input_min,
float *input_max, float *d_nudge_min, float *d_nudge_max,
float *d_scale, void *stream_ptr) {
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel(input, input_min, input_max, input_size_ / sizeof(float), channel_out_, ema_decay_, ema_,
reinterpret_cast<cudaStream_t>(stream_ptr));
// control flow for quant_delay
if (global_step_ >= quant_delay_) {
// real launch
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max,
d_scale, symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDA_RET_WITH_ERROR(
cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
"Copy gpu memory failed.");
}
global_step_++;
input_size_list_.push_back(input_size_); // input in tensor
input_size_list_.push_back(sizeof(float) * num_channels_); // min one scalar
input_size_list_.push_back(sizeof(float) * num_channels_); // max on scalar
output_size_list_.push_back(input_size_); // output in tensor
workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel
workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel
workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel
}
void FakeQuantPerChannelGpuKernel::CalFakeQuantizeForInfer(float *input, float *output, float *input_min,
float *input_max, float *d_nudge_min, float *d_nudge_max,
float *d_scale, void *stream_ptr) {
// real launch
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannel(input, output, input_size_ / sizeof(float), channel_out_, d_nudge_min, d_nudge_max, d_scale,
symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
void FakeQuantPerChannelGpuKernel::CalFakeQuantize(float *input, float *output, float *input_min, float *input_max,
float *nudge_min, float *nudge_max, float *scale, void *stream_ptr) {
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantPerChannel(input, output, input_size_ / sizeof(float), num_channels_, nudge_min, nudge_max, scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
......@@ -155,9 +115,9 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
float *d_scale = GetDeviceAddress<float>(workspace, 0);
float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
float *scale = GetDeviceAddress<float>(workspace, 0);
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null.";
......@@ -167,9 +127,16 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
}
if (training_) {
CalFakeQuantizeForTraining(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
if (global_step_ >= quant_delay_) {
CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr);
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"Copy gpu memory failed.");
}
global_step_++;
} else {
CalFakeQuantizeForInfer(input, output, input_min, input_max, d_nudge_min, d_nudge_max, d_scale, stream_ptr);
CalFakeQuantize(input, output, input_min, input_max, nudge_min, nudge_max, scale, stream_ptr);
}
return true;
......
......@@ -39,31 +39,23 @@ class FakeQuantPerChannelGpuKernel : public GpuKernel {
void InitSizeLists() override;
private:
void CalFakeQuantizeForTraining(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
float *d_nudge_max, float *d_scale, void *stream_ptr);
void CalFakeQuantizeForInfer(float *input, float *output, float *input_min, float *input_max, float *d_nudge_min,
float *d_nudge_max, float *d_scale, void *stream_ptr);
void CalFakeQuantize(float *input, float *output, float *input_min, float *input_max, float *nudge_min,
float *nudge_max, float *scale, void *stream_ptr);
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_channels_;
int num_bits_;
bool training_;
bool symmetric_;
bool narrow_range_;
int quant_delay_;
float quant_min_;
float quant_max_;
int quant_delay_;
bool ema_;
float ema_decay_;
int global_step_;
bool training_;
int channel_out_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore
......
......@@ -14,21 +14,17 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
namespace mindspore {
namespace kernel {
FakeQuantPerChannelGradGpuKernel::FakeQuantPerChannelGradGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
channel_out_(0),
num_channels_(0),
quant_delay_(0),
global_step_(0),
narrow_range_(false),
......@@ -64,42 +60,34 @@ bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) {
}
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
// quant min and max value
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
channel_out_ = SizeToInt(input_shape[0]);
min_size_ = sizeof(float) * channel_out_;
max_size_ = sizeof(float) * channel_out_;
num_channels_ = SizeToInt(input_shape[0]);
input_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantPerChannelGradGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(min_size_); // min
input_size_list_.push_back(max_size_); // max
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(sizeof(float) * channel_out_); // scale in channel
workspace_size_list_.push_back(sizeof(float) * channel_out_); // min in channel
workspace_size_list_.push_back(sizeof(float) * channel_out_); // max in channel
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(sizeof(float) * num_channels_); // min
input_size_list_.push_back(sizeof(float) * num_channels_); // max
output_size_list_.push_back(input_size_); // output
workspace_size_list_.push_back(sizeof(float) * num_channels_); // scale in channel
workspace_size_list_.push_back(sizeof(float) * num_channels_); // min in channel
workspace_size_list_.push_back(sizeof(float) * num_channels_); // max in channel
}
bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
......@@ -111,9 +99,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
float *input = GetDeviceAddress<float>(inputs, 1);
float *input_min = GetDeviceAddress<float>(inputs, 2);
float *input_max = GetDeviceAddress<float>(inputs, 3);
float *d_scale = GetDeviceAddress<float>(workspace, 0);
float *d_nudge_min = GetDeviceAddress<float>(workspace, 1);
float *d_nudge_max = GetDeviceAddress<float>(workspace, 2);
float *scale = GetDeviceAddress<float>(workspace, 0);
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
if (gradient == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null";
......@@ -130,10 +118,10 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
int total_size = input_size_ / sizeof(float);
if (global_step_ >= quant_delay_) {
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, channel_out_, d_nudge_min, d_nudge_max,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, num_channels_,
symmetric_, reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantPerChannelGrad(input, gradient, output, total_size, num_channels_, nudge_min, nudge_max,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
......
......@@ -40,10 +40,6 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
private:
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
......@@ -51,7 +47,7 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
int num_bits_;
float quant_min_;
float quant_max_;
int channel_out_;
int num_channels_;
int quant_delay_;
int global_step_;
bool narrow_range_;
......
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
......@@ -23,31 +23,25 @@
namespace mindspore {
namespace kernel {
FakeQuantGpuKernel::FakeQuantGpuKernel()
FakeQuantPerLayerGpuKernel::FakeQuantPerLayerGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_num_(0),
quant_delay_(0),
ema_(false),
ema_decay_(0),
quant_num_(1),
global_step_(0),
num_bits_(0),
quant_delay_(0),
training_(false),
narrow_range_(false),
symmetric_(false) {}
const std::vector<size_t> &FakeQuantGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
bool FakeQuantPerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
......@@ -59,96 +53,74 @@ bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << "is less then 0, require larger than 0.";
}
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
// quant min and max value
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
if (quant_num_ == 0) {
quant_num_ = 1;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); ++i) {
quant_num_ *= SizeToInt(input_shape[i]);
}
input_size_ = sizeof(float);
min_size_ = sizeof(float);
max_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(min_size_); // min
input_size_list_.push_back(max_size_); // max
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(workspace_size_);
void FakeQuantPerLayerGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // x
input_size_list_.push_back(sizeof(float)); // min
input_size_list_.push_back(sizeof(float)); // max
output_size_list_.push_back(input_size_); // y
workspace_size_list_.push_back(sizeof(float)); // scale
workspace_size_list_.push_back(sizeof(float)); // nudge_min
workspace_size_list_.push_back(sizeof(float)); // nudge_max
}
bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
bool FakeQuantPerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
float *output = GetDeviceAddress<float>(outputs, 0);
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
float *scale = GetDeviceAddress<float>(workspace, 0);
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input x is null.";
}
if (input_min == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input min is null.";
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input x is null.";
}
if (input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input max is null.";
if (input_min == nullptr || input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGpuKernel input min or input max is null.";
}
// Allocate space for device copies
int size = sizeof(float);
float *d_scale = nullptr;
float *d_nudge_min = nullptr;
float *d_nudge_max = nullptr;
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed");
if (training_) {
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMax(input, input_min, input_max, quant_num_, ema_decay_, ema_, reinterpret_cast<cudaStream_t>(stream_ptr));
// control flow for quant_delay
if (global_step_ >= quant_delay_) {
// real launch
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, input, input_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
......@@ -157,20 +129,15 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
global_step_++;
} else {
// real launch
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantPerLayer(input, output, quant_num_, nudge_min, nudge_max, scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
// Cleanup
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
return true;
}
MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerLayer, FakeQuantPerLayerGpuKernel)
} // namespace kernel
} // namespace mindspore
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
......@@ -23,10 +23,10 @@
namespace mindspore {
namespace kernel {
class FakeQuantGpuKernel : public GpuKernel {
class FakeQuantPerLayerGpuKernel : public GpuKernel {
public:
FakeQuantGpuKernel();
~FakeQuantGpuKernel() = default;
FakeQuantPerLayerGpuKernel();
~FakeQuantPerLayerGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
......@@ -40,22 +40,16 @@ class FakeQuantGpuKernel : public GpuKernel {
private:
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_num_;
int quant_delay_;
bool ema_;
float ema_decay_;
int global_step_;
int num_bits_;
int quant_delay_;
bool training_;
bool narrow_range_;
bool symmetric_;
......@@ -63,4 +57,4 @@ class FakeQuantGpuKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GPUKERNEL_H_
......@@ -14,33 +14,30 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh"
namespace mindspore {
namespace kernel {
FakeQuantGradGpuKernel::FakeQuantGradGpuKernel()
FakeQuantPerLayerGradGpuKernel::FakeQuantPerLayerGradGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_size_(0),
quant_num_(1),
quant_delay_(0),
global_step_(0),
narrow_range_(false),
symmetric_(false) {}
const std::vector<size_t> &FakeQuantGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
const std::vector<size_t> &FakeQuantPerLayerGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
bool FakeQuantPerLayerGradGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output.";
......@@ -62,87 +59,66 @@ bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
}
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
// quant min and max value
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (narrow_range_) {
quant_min_++;
}
if (quant_size_ == 0) {
quant_size_ = 1;
}
// init size
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); ++i) {
quant_size_ *= SizeToInt(input_shape[i]);
quant_num_ *= SizeToInt(input_shape[i]);
}
input_size_ = sizeof(float);
min_size_ = sizeof(float);
max_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantGradGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(min_size_); // min
input_size_list_.push_back(max_size_); // max
output_size_list_.push_back(output_size_);
void FakeQuantPerLayerGradGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(sizeof(float)); // min
input_size_list_.push_back(sizeof(float)); // max
output_size_list_.push_back(input_size_); // output
workspace_size_list_.push_back(sizeof(float)); // scale
workspace_size_list_.push_back(sizeof(float)); // nudge_min
workspace_size_list_.push_back(sizeof(float)); // nudge_max
}
bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
bool FakeQuantPerLayerGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
float *output = GetDeviceAddress<float>(outputs, 0);
float *gradient = GetDeviceAddress<float>(inputs, 0);
float *input = GetDeviceAddress<float>(inputs, 1);
float *input_min = GetDeviceAddress<float>(inputs, 2);
float *input_max = GetDeviceAddress<float>(inputs, 3);
float *scale = GetDeviceAddress<float>(workspace, 0);
float *nudge_min = GetDeviceAddress<float>(workspace, 1);
float *nudge_max = GetDeviceAddress<float>(workspace, 2);
if (gradient == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel gradient is null";
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel gradient is null";
}
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input is null.";
}
if (input_min == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input min is null.";
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input is null.";
}
if (input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input max is null.";
if (input_min == nullptr || input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerLayerGradGpuKernel input min or max is null.";
}
if (global_step_ >= quant_delay_) {
float *d_scale = nullptr;
float *d_nudge_min = nullptr;
float *d_nudge_max = nullptr;
int size = sizeof(float);
// Allocate space for device copies
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed");
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizeGrad(input, gradient, output, quant_size_, d_nudge_min, d_nudge_max,
reinterpret_cast<cudaStream_t>(stream_ptr));
// Cleanup
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
CalNudgePerLayer(input_min, input_max, quant_min_, quant_max_, nudge_min, nudge_max, scale, symmetric_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantPerLayerGrad(input, gradient, output, quant_num_, nudge_min, nudge_max,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpyAsync(output, gradient, input_size_, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
......@@ -152,6 +128,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
return true;
}
MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantGradGpuKernel)
MS_REG_GPU_KERNEL(FakeQuantPerLayerGrad, FakeQuantPerLayerGradGpuKernel)
} // namespace kernel
} // namespace mindspore
......@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
......@@ -23,10 +23,10 @@
namespace mindspore {
namespace kernel {
class FakeQuantGradGpuKernel : public GpuKernel {
class FakeQuantPerLayerGradGpuKernel : public GpuKernel {
public:
FakeQuantGradGpuKernel();
~FakeQuantGradGpuKernel() = default;
FakeQuantPerLayerGradGpuKernel();
~FakeQuantPerLayerGradGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
......@@ -40,9 +40,6 @@ class FakeQuantGradGpuKernel : public GpuKernel {
private:
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
......@@ -51,7 +48,7 @@ class FakeQuantGradGpuKernel : public GpuKernel {
int num_bits_;
float quant_min_;
float quant_max_;
int quant_size_;
int quant_num_;
int quant_delay_;
int global_step_;
bool narrow_range_;
......@@ -60,4 +57,4 @@ class FakeQuantGradGpuKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PERLAYER_GRAD_GPUKERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
#include <cuda_runtime_api.h>
namespace mindspore {
namespace kernel {
// Default state: zero sizes, EMA disabled, no channels. quant_num_ starts at 1 because Init() multiplies the
// input dimensions into it.
MinMaxUpdatePerChannelGpuKernel::MinMaxUpdatePerChannelGpuKernel()
    : input_size_(0),
      quant_num_(1),
      ema_(false),
      ema_decay_(0),
      num_channels_(0) {}
// Byte size of every input buffer, in the order InitSizeLists() pushed them.
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetInputSizeList() const {
  return input_size_list_;
}

// Byte size of every output buffer, in the order InitSizeLists() pushed them.
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetOutputSizeList() const {
  return output_size_list_;
}

// Byte size of every workspace buffer; InitSizeLists() adds none for this kernel.
const std::vector<size_t> &MinMaxUpdatePerChannelGpuKernel::GetWorkspaceSizeList() const {
  return workspace_size_list_;
}
/**
 * Check the node's input/output counts, read the EMA attributes and derive the buffer sizes from the inferred
 * shape of input 0.
 * @param kernel_node graph node this kernel instance is initialised from
 * @return true once the size lists are filled; invalid nodes are reported through MS_LOG(EXCEPTION)
 */
bool MinMaxUpdatePerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 3) {
    MS_LOG(EXCEPTION) << "Input number is " << input_num
                      << ", but MinMaxUpdatePerChannel GpuKernel OP needs 3 inputs.";
  }

  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  if (output_num != 2) {
    MS_LOG(EXCEPTION) << "Output number is " << output_num
                      << ", but MinMaxUpdatePerChannel GpuKernel OP needs 2 outputs.";
  }

  ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
  ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));

  // init size
  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  if (input_shape.empty()) {
    // Dimension 0 is read below, so a rank-0 input cannot be handled.
    MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannel GpuKernel OP needs an input x with at least 1 dimension.";
  }
  // NOTE(review): the channel count is always taken from dimension 0; no "channel_axis" attribute is read
  // here - confirm that this is intended.
  num_channels_ = SizeToInt(input_shape[0]);

  // Accumulate the element count and the byte size of x in a single pass over its dimensions.
  input_size_ = sizeof(float);
  for (size_t i = 0; i < input_shape.size(); ++i) {
    quant_num_ *= SizeToInt(input_shape[i]);
    input_size_ *= input_shape[i];
  }

  InitSizeLists();
  return true;
}
// Append the byte sizes of the three inputs (x, min, max) and the two outputs (min, max) to the size lists.
// No workspace entries are added.
void MinMaxUpdatePerChannelGpuKernel::InitSizeLists() {
  // min and max each hold one float per channel.
  const size_t per_channel_bytes = sizeof(float) * num_channels_;

  // input, min, max
  input_size_list_.insert(input_size_list_.end(), {input_size_, per_channel_bytes, per_channel_bytes});
  // output min, output max
  output_size_list_.insert(output_size_list_.end(), {per_channel_bytes, per_channel_bytes});
}
bool MinMaxUpdatePerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
float *output_min = GetDeviceAddress<float>(outputs, 0);
float *output_max = GetDeviceAddress<float>(outputs, 1);
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input x is null.";
}
if (input_min == nullptr || input_max == nullptr) {
MS_LOG(EXCEPTION) << "MinMaxUpdatePerChannelGpuKernel input min or input max is null.";
}
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel(input, input_min, input_max, output_min, output_max, input_size_ / sizeof(float), num_channels_,
ema_decay_, ema_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
// Presumably registers MinMaxUpdatePerChannelGpuKernel as the GPU implementation of the op named
// MinMaxUpdatePerChannel; the macro is defined outside this file - confirm its exact semantics there.
MS_REG_GPU_KERNEL(MinMaxUpdatePerChannel, MinMaxUpdatePerChannelGpuKernel)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
// GPU kernel class for the MinMaxUpdatePerChannel op. Its Launch() takes three inputs (x, min, max) and writes
// two outputs (updated min, updated max), with one min/max value per channel.
class MinMaxUpdatePerChannelGpuKernel : public GpuKernel {
public:
MinMaxUpdatePerChannelGpuKernel();
~MinMaxUpdatePerChannelGpuKernel() = default;
// Accessors for the per-buffer byte sizes collected by InitSizeLists().
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
// Validates the device buffers and runs the min/max update with the given stream handle.
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
// Reads the node's attributes and input shape and fills in the members below.
bool Init(const CNodePtr &kernel) override;
protected:
// Fills the input and output size lists; called at the end of Init().
void InitSizeLists() override;
private:
// Size of input x in bytes: sizeof(float) times the product of its dimensions.
size_t input_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
// Never appended to by this kernel's InitSizeLists().
std::vector<size_t> workspace_size_list_;
// Product of the dimensions of x (element count), accumulated in Init().
int quant_num_;
// Values of the node attributes "ema" and "ema_decay".
bool ema_;
float ema_decay_;
// Dimension 0 of the shape of x.
int num_channels_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
#include <cuda_runtime_api.h>
namespace mindspore {
namespace kernel {
// Default state: zero input size and EMA disabled. quant_num_ starts at 1 because Init() multiplies the
// input dimensions into it.
MinMaxUpdatePerLayerGpuKernel::MinMaxUpdatePerLayerGpuKernel()
    : input_size_(0),
      quant_num_(1),
      ema_(false),
      ema_decay_(0) {}
// Byte size of every input buffer, in the order InitSizeLists() pushed them.
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetInputSizeList() const {
  return input_size_list_;
}

// Byte size of every output buffer, in the order InitSizeLists() pushed them.
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetOutputSizeList() const {
  return output_size_list_;
}

// Byte size of every workspace buffer; InitSizeLists() adds none for this kernel.
const std::vector<size_t> &MinMaxUpdatePerLayerGpuKernel::GetWorkspaceSizeList() const {
  return workspace_size_list_;
}
/**
 * Check the node's input/output counts, read the EMA attributes and derive the buffer sizes from the inferred
 * shape of input 0.
 * @param kernel_node graph node this kernel instance is initialised from
 * @return true once the size lists are filled; invalid nodes are reported through MS_LOG(EXCEPTION)
 */
bool MinMaxUpdatePerLayerGpuKernel::Init(const CNodePtr &kernel_node) {
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 3) {
    MS_LOG(EXCEPTION) << "Input number is " << input_num
                      << ", but MinMaxUpdatePerLayer GpuKernel OP needs 3 inputs.";
  }

  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  if (output_num != 2) {
    MS_LOG(EXCEPTION) << "Output number is " << output_num
                      << ", but MinMaxUpdatePerLayer GpuKernel OP needs 2 outputs.";
  }

  ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
  ema_decay_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));

  // init size
  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);

  // Accumulate the element count and the byte size of x in a single pass over its dimensions.
  input_size_ = sizeof(float);
  for (size_t i = 0; i < input_shape.size(); ++i) {
    quant_num_ *= SizeToInt(input_shape[i]);
    input_size_ *= input_shape[i];
  }

  InitSizeLists();
  return true;
}
// Append the byte sizes of the three inputs (x, min, max) and the two outputs (min, max) to the size lists.
// No workspace entries are added.
void MinMaxUpdatePerLayerGpuKernel::InitSizeLists() {
  // min and max are single floats for the whole layer.
  const size_t scalar_bytes = sizeof(float);

  // input, input min, input max
  input_size_list_.insert(input_size_list_.end(), {input_size_, scalar_bytes, scalar_bytes});
  // output min, output max
  output_size_list_.insert(output_size_list_.end(), {scalar_bytes, scalar_bytes});
}
bool MinMaxUpdatePerLayerGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
float *output_min = GetDeviceAddress<float>(outputs, 0);
float *output_max = GetDeviceAddress<float>(outputs, 1);
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input x is null.";
}
if (input_min == nullptr || input_max == nullptr) {
MS_LOG(EXCEPTION) << "MinMaxUpdatePerLayerGpuKernel input min or input max is null.";
}
CalMinMaxPerLayer(input, input_min, input_max, output_min, output_max, quant_num_, ema_decay_, ema_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
// Presumably registers MinMaxUpdatePerLayerGpuKernel as the GPU implementation of the op named
// MinMaxUpdatePerLayer; the macro is defined outside this file - confirm its exact semantics there.
MS_REG_GPU_KERNEL(MinMaxUpdatePerLayer, MinMaxUpdatePerLayerGpuKernel)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
// GPU kernel class for the MinMaxUpdatePerLayer op. Its Launch() takes three inputs (x, min, max) and writes
// two outputs (updated min, updated max), each min/max being a single float.
class MinMaxUpdatePerLayerGpuKernel : public GpuKernel {
public:
MinMaxUpdatePerLayerGpuKernel();
~MinMaxUpdatePerLayerGpuKernel() = default;
// Accessors for the per-buffer byte sizes collected by InitSizeLists().
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
// Validates the device buffers and runs the min/max update with the given stream handle.
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
// Reads the node's attributes and input shape and fills in the members below.
bool Init(const CNodePtr &kernel) override;
protected:
// Fills the input and output size lists; called at the end of Init().
void InitSizeLists() override;
private:
// Size of input x in bytes: sizeof(float) times the product of its dimensions.
size_t input_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
// Never appended to by this kernel's InitSizeLists().
std::vector<size_t> workspace_size_list_;
// Product of the dimensions of x (element count), accumulated in Init().
int quant_num_;
// Values of the node attributes "ema" and "ema_decay".
bool ema_;
float ema_decay_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
......@@ -324,6 +324,7 @@ class FakeQuantWithMinMax(Cell):
validator.check_type("min_init", min_init, [int, float])
validator.check_type("max_init", max_init, [int, float])
validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
self.min_init = min_init
self.max_init = max_init
self.num_bits = num_bits
......
......@@ -106,7 +106,7 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
channel_axis (int): Quantization by channel axis, support 0 and 1. Default: 1.
channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
......@@ -123,12 +123,13 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
"""
support_quant_bit = [4, 7, 8]
support_x_rank = [2, 4]
ascend_support_x_rank = [2, 4]
@prim_attr_register
def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if context.get_context('device_target') == "Ascend":
self.is_ascend = context.get_context('device_target') == "Ascend"
if self.is_ascend:
from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
if ema and not ema_decay:
raise ValueError(
......@@ -137,15 +138,18 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
self.ema_decay = validator.check_number_range(
'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
self.channel_axis = validator.check_int_range(
'channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
if self.is_ascend:
self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
else:
self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(
inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])
def infer_shape(self, x_shape, min_shape, max_shape):
if len(x_shape) not in self.support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.support_x_rank}'")
validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
if not self.is_ascend:
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape",
max_shape, Rel.EQ, self.name)
validator.check_integer("min shape", len(
......@@ -317,7 +321,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Quantization by channel axis, support 0 and 1. Default: 1.
channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.
Inputs:
- **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor.
......@@ -335,7 +339,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
>>> result = fake_quant(input_x, _min, _max)
"""
support_quant_bit = [4, 7, 8]
support_x_rank = [2, 4]
ascend_support_x_rank = [2, 4]
@prim_attr_register
def __init__(self,
......@@ -348,7 +352,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
training=True,
channel_axis=1):
"""init FakeQuantPerChannel OP"""
if context.get_context('device_target') == "Ascend":
self.is_ascend = context.get_context('device_target') == "Ascend"
if self.is_ascend:
from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
if num_bits not in self.support_quant_bit:
raise ValueError(
......@@ -370,13 +375,17 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
'num_bits', num_bits, 0, Rel.GT, self.name)
self.quant_delay = validator.check_integer(
'quant_delay', quant_delay, 0, Rel.GE, self.name)
self.channel_axis = validator.check_int_range(
'channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
if self.is_ascend:
self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
else:
self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name)
self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])
def infer_shape(self, x_shape, min_shape, max_shape):
if len(x_shape) not in self.support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.support_x_rank}'")
if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
if not self.is_ascend:
validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
validator.check_integer(
"min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
......
......@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""aware quantization."""
"""quantization aware."""
import copy
import re
import numpy as np
import mindspore.context as context
from ... import log as logger
from ... import nn, ops
......@@ -32,6 +33,7 @@ from ...ops.operations import _inner_ops as inner
from ...train import serialization
from . import quant_utils
_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
nn.ReLU6: quant.ReLU6Quant,
nn.HSigmoid: quant.HSigmoidQuant,
......@@ -46,7 +48,7 @@ class _AddFakeQuantInput(nn.Cell):
def __init__(self, network, quant_delay=0):
super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
self.fake_quant_input = quant.FakeQuantWithMinMax(min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
self.fake_quant_input.update_parameters_name('fake_quant_input')
self.fake_quant_input.update_parameters_name('fake_quant_input.')
self.network = network
def construct(self, data):
......@@ -165,8 +167,8 @@ class ConvertToQuantNetwork:
convert Conv2d cell to quant cell
"""
conv_inner = subcell.conv
bn_inner = subcell.batchnorm
if subcell.has_bn and self.bn_fold:
bn_inner = subcell.batchnorm
conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
conv_inner.out_channels,
kernel_size=conv_inner.kernel_size,
......@@ -421,26 +423,26 @@ def convert_quant_network(network,
Args:
network (Cell): Obtain a pipeline through network for saving graph summary.
quant_delay (int): Number of steps after which weights and activations are quantized during
eval. The first element represent weights and second element represent data flow. Default: [0, 0]
bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 0.
num_bits (list of int): Number of bits to use for quantizing weights and activations. The first
element represent weights and second element represent data flow. Default: [8, 8]
per_channel (list of bool): Quantization granularity based on layer or on channel. If `True`
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
eval. The first element represent weights and second element represent data flow. Default: (0, 0)
num_bits (int, list or tuple): Number of bits to use for quantizing weights and activations. The first
element represent weights and second element represent data flow. Default: (8, 8)
per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
then base on per channel otherwise base on per layer. The first element represent weights
and second element represent data flow. Default: [False, False]
symmetric (list of bool): Quantization algorithm use symmetric or not. If `True` then base on
and second element represent data flow. Default: (False, False)
symmetric (bool, list or tuple): Quantization algorithm use symmetric or not. If `True` then base on
symmetric otherwise base on asymmetric. The first element represent weights and second
element represent data flow. Default: [False, False]
narrow_range (list of bool): Quantization algorithm use narrow range or not. If `True` then base
element represent data flow. Default: (False, False)
narrow_range (bool, list or tuple): Quantization algorithm use narrow range or not. If `True` then base
on narrow range otherwise base on off narrow range. The first element represent weights and
second element represent data flow. Default: [False, False]
second element represent data flow. Default: (False, False)
Returns:
Cell, Network which has change to aware quantization training network cell.
Cell, Network which has change to quantization aware training network cell.
"""
support_device = ["Ascend", "GPU"]
def convert2list(name, value):
if not isinstance(value, list) and not isinstance(value, tuple):
value = [value]
......@@ -454,6 +456,9 @@ def convert_quant_network(network,
symmetric = convert2list("symmetric", symmetric)
narrow_range = convert2list("narrow range", narrow_range)
if context.get_context('device_target') not in support_device:
raise KeyError("Not support {} backend.".format(context.get_context('device_target')))
net = ConvertToQuantNetwork(network=network,
quant_delay=quant_delay,
bn_fold=bn_fold,
......
此差异已折叠。
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.operations import _quant_ops as Q
context.set_context(device_target='GPU', device_id=0)
class Net(nn.Cell):
    """Minimal cell that forwards its inputs to ``Q.FakeQuantPerChannelGrad``.

    Args:
        num_bits: forwarded to the operator as ``num_bits``.
        narrow_range: forwarded to the operator as ``narrow_range``.
    """

    def __init__(self, num_bits=8, narrow_range=False):
        super().__init__()
        self.op = Q.FakeQuantPerChannelGrad(num_bits=num_bits, narrow_range=narrow_range)

    def construct(self, dout, x, minq, maxq):
        """Return the operator's result for (dout, x, minq, maxq)."""
        grad = self.op(dout, x, minq, maxq)
        return grad
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad1():
    """All-zero x/min/max on a length-4 input: the gradient is expected back unchanged.

    Case label: WithVarsPerChannelDim1GradientNudgedDown_ZeroMinAndMax.
    """
    grad_in = np.random.uniform(-1, 1, size=[4]).astype('float32')
    data = np.zeros(4, dtype=np.float32)
    lower = np.zeros(4, dtype=np.float32)
    upper = np.zeros(4, dtype=np.float32)
    # The expectation is the incoming gradient itself (same object, as before).
    expected = grad_in

    model = Net(num_bits=8, narrow_range=False)
    result = model(Tensor(grad_in), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("=" * 40)
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad2():
    """8-bit, narrow_range=False, length-4 input: first and last gradient entries must be zeroed.

    Case label: WithVarsPerChannelDim1GradientNudgedDown_RegularRange.
    """
    grad_in = np.random.uniform(-1, 1, size=[4]).astype('float32')
    data = np.array([-0.1, 0.0, 63.75, 63.8], dtype=np.float32)
    lower = np.full(4, -0.1, dtype=np.float32)
    upper = np.full(4, 63.65, dtype=np.float32)
    # Gradient passes through at indices 1 and 2 only.
    expected = grad_in.copy()
    expected[[0, 3]] = 0.0

    model = Net(num_bits=8, narrow_range=False)
    result = model(Tensor(grad_in), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("=" * 40)
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad3():
    """8-bit, narrow_range=True, length-4 input: first and last gradient entries must be zeroed.

    Case label: WithVarsPerChannelDim1GradientNudgedDown_NarrowRange.
    """
    grad_in = np.random.uniform(-1, 1, size=[4]).astype('float32')
    data = np.array([-0.1, 0.0, 63.5, 63.6], dtype=np.float32)
    lower = np.full(4, -0.1, dtype=np.float32)
    upper = np.full(4, 63.4, dtype=np.float32)
    # Gradient passes through at indices 1 and 2 only.
    expected = grad_in.copy()
    expected[[0, 3]] = 0.0

    model = Net(num_bits=8, narrow_range=True)
    result = model(Tensor(grad_in), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("=" * 40)
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad4():
    """8-bit, narrow_range=False, min=-0.125/max=63.625: first and last gradient entries zeroed.

    Case label: WithVarsPerChannelDim1GradientNudgedUp_RegularRange.
    """
    grad_in = np.random.uniform(-1, 1, size=[4]).astype('float32')
    data = np.array([-0.3, -0.25, 63.5, 63.6], dtype=np.float32)
    lower = np.full(4, -0.125, dtype=np.float32)
    upper = np.full(4, 63.625, dtype=np.float32)
    # Gradient passes through at indices 1 and 2 only.
    expected = grad_in.copy()
    expected[[0, 3]] = 0.0

    model = Net(num_bits=8, narrow_range=False)
    result = model(Tensor(grad_in), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("=" * 40)
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad5():
    """8-bit, narrow_range=True, min=-0.125/max=63.375: first and last gradient entries zeroed.

    Case label: WithVarsPerChannelDim1GradientNudgedUp_NarrowRange.
    """
    grad_in = np.random.uniform(-1, 1, size=[4]).astype('float32')
    data = np.array([-0.3, -0.25, 63.25, 63.3], dtype=np.float32)
    lower = np.full(4, -0.125, dtype=np.float32)
    upper = np.full(4, 63.375, dtype=np.float32)
    # Gradient passes through at indices 1 and 2 only.
    expected = grad_in.copy()
    expected[[0, 3]] = 0.0

    model = Net(num_bits=8, narrow_range=True)
    result = model(Tensor(grad_in), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("=" * 40)
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad6():
    """8-bit, regular range, (3, 2) input: first and last gradient entries must be zeroed.

    Case label: WithVarsPerChannelDim2GradientNudgedDown_RegularRange.

    The data here (63.75 / 63.8 with max 63.65) is the regular-range data also
    used by test_fake_quant_grad2 and test_fake_quant_grad10, so the operator is
    built with narrow_range=False. It was previously built with
    narrow_range=True, which made this case exercise the narrow-range
    configuration instead of the regular-range one its label names.
    """
    read_dout = np.random.uniform(-1, 1, size=[3, 2]).astype('float32')
    x = np.array([-0.1, 0.0, 0.1, 0.25, 63.75, 63.8]).reshape(3, 2).astype(np.float32)
    min_val = np.array([-0.1, -0.1, -0.1]).astype(np.float32)
    max_val = np.array([63.65, 63.65, 63.65]).astype(np.float32)
    dout = read_dout.flatten()
    # Only the first and last flattened positions are expected to be blocked.
    expect = np.array([0.0, dout[1], dout[2], dout[3],
                       dout[4], 0.0]).astype(np.float32)

    # Regular range, matching the case label and the other RegularRange tests.
    net = Net(num_bits=8, narrow_range=False)
    output = net(Tensor(read_dout), Tensor(x), Tensor(min_val), Tensor(max_val))

    error = np.ones(shape=expect.shape) * 1.0e-5
    diff = output.asnumpy().flatten() - expect
    print("=" * 40)
    print("output: ", output)
    print("expect: ", expect)
    assert np.all(np.abs(diff) < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad7():
    """8-bit, narrow_range=True, (3, 2) input: first and last flattened gradient entries zeroed.

    Case label: WithVarsPerChannelDim2GradientNudgedDown_NarrowRange.
    """
    grad_in = np.random.uniform(-1, 1, size=[3, 2]).astype('float32')
    data = np.array([[-0.1, 0.0], [0.1, 0.25], [63.5, 63.6]], dtype=np.float32)
    lower = np.full(3, -0.1, dtype=np.float32)
    upper = np.full(3, 63.4, dtype=np.float32)
    # Gradient passes through everywhere except flattened indices 0 and 5.
    expected = grad_in.flatten()
    expected[[0, 5]] = 0.0

    model = Net(num_bits=8, narrow_range=True)
    result = model(Tensor(grad_in), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("=" * 40)
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad8():
    """8-bit, narrow_range=False, (3, 2) input: first and last flattened gradient entries zeroed.

    Case label: WithVarsPerChannelDim2GradientNudgedUp_RegularRange.
    """
    grad_in = np.random.uniform(-1, 1, size=[3, 2]).astype('float32')
    data = np.array([[-0.3, -0.25], [-0.2, 0.0], [63.5, 63.6]], dtype=np.float32)
    lower = np.full(3, -0.125, dtype=np.float32)
    upper = np.full(3, 63.625, dtype=np.float32)
    # Gradient passes through everywhere except flattened indices 0 and 5.
    expected = grad_in.flatten()
    expected[[0, 5]] = 0.0

    model = Net(num_bits=8, narrow_range=False)
    result = model(Tensor(grad_in), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("=" * 40)
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad9():
    """8-bit, narrow_range=True, (3, 2) input: first and last flattened gradient entries zeroed.

    Case label: WithVarsPerChannelDim2GradientNudgedUp_NarrowRange.
    """
    grad_in = np.random.uniform(-1, 1, size=[3, 2]).astype('float32')
    data = np.array([[-0.3, -0.25], [-0.2, 0.0], [63.25, 63.3]], dtype=np.float32)
    lower = np.full(3, -0.125, dtype=np.float32)
    upper = np.full(3, 63.375, dtype=np.float32)
    # Gradient passes through everywhere except flattened indices 0 and 5.
    expected = grad_in.flatten()
    expected[[0, 5]] = 0.0

    model = Net(num_bits=8, narrow_range=True)
    result = model(Tensor(grad_in), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("=" * 40)
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad10():
    """8-bit, narrow_range=False, (4, 3, 2, 1) input built from a repeating 4-value pattern.

    Case label: WithVarsPerChannelDim4GradientNudgedDown_RegularRange.
    In every group of four flattened values, the gradient is expected to be
    zeroed at the first and last position and passed through in between.
    """
    grad_in = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32')
    data = np.tile([-0.1, 0.0, 63.75, 63.8], 6).reshape(4, 3, 2, 1).astype(np.float32)
    lower = np.full(4, -0.1, dtype=np.float32)
    upper = np.full(4, 63.65, dtype=np.float32)
    flat_grad = grad_in.flatten()
    passes = np.tile([False, True, True, False], 6)
    expected = np.where(passes, flat_grad, 0.0).astype(np.float32)

    model = Net(num_bits=8, narrow_range=False)
    result = model(Tensor(grad_in), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("=" * 40)
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad11():
    """8-bit, narrow_range=True, (4, 3, 2, 1) input built from a repeating 4-value pattern.

    Case label: WithVarsPerChannelDim4GradientNudgedDown_NarrowRange.
    In every group of four flattened values, the gradient is expected to be
    zeroed at the first and last position and passed through in between.
    """
    grad_in = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32')
    data = np.tile([-0.1, 0.0, 63.5, 63.6], 6).reshape(4, 3, 2, 1).astype(np.float32)
    lower = np.full(4, -0.1, dtype=np.float32)
    upper = np.full(4, 63.4, dtype=np.float32)
    flat_grad = grad_in.flatten()
    passes = np.tile([False, True, True, False], 6)
    expected = np.where(passes, flat_grad, 0.0).astype(np.float32)

    model = Net(num_bits=8, narrow_range=True)
    result = model(Tensor(grad_in), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("=" * 40)
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad12():
    """8-bit, narrow_range=False, (4, 3, 2, 1) input, min=-0.125/max=63.625.

    Case label: WithVarsPerChannelDim4GradientNudgedUp_RegularRange.
    In every group of four flattened values, the gradient is expected to be
    zeroed at the first and last position and passed through in between.
    """
    grad_in = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32')
    data = np.tile([-0.3, -0.25, 63.5, 63.6], 6).reshape(4, 3, 2, 1).astype(np.float32)
    lower = np.full(4, -0.125, dtype=np.float32)
    upper = np.full(4, 63.625, dtype=np.float32)
    flat_grad = grad_in.flatten()
    passes = np.tile([False, True, True, False], 6)
    expected = np.where(passes, flat_grad, 0.0).astype(np.float32)

    model = Net(num_bits=8, narrow_range=False)
    result = model(Tensor(grad_in), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("=" * 40)
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad13():
    """8-bit, narrow_range=True, (4, 3, 2, 1) input, min=-0.125/max=63.375.

    Case label: WithVarsPerChannelDim4GradientNudgedUp_NarrowRange.
    In every group of four flattened values, the gradient is expected to be
    zeroed at the first and last position and passed through in between.
    """
    grad_in = np.random.uniform(-1, 1, size=[4, 3, 2, 1]).astype('float32')
    data = np.tile([-0.3, -0.25, 63.25, 63.3], 6).reshape(4, 3, 2, 1).astype(np.float32)
    lower = np.full(4, -0.125, dtype=np.float32)
    upper = np.full(4, 63.375, dtype=np.float32)
    flat_grad = grad_in.flatten()
    passes = np.tile([False, True, True, False], 6)
    expected = np.where(passes, flat_grad, 0.0).astype(np.float32)

    model = Net(num_bits=8, narrow_range=True)
    result = model(Tensor(grad_in), Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("=" * 40)
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
from mindspore.ops.operations import _quant_ops as Q
context.set_context(device_target='GPU', device_id=0)
class Net(nn.Cell):
    """Minimal cell that forwards its inputs to ``Q.FakeQuantPerLayer``.

    Every constructor argument is passed to the operator under the same name.
    """

    def __init__(self, num_bits=8, quant_delay=0, symmetric=False, narrow_range=False, training=True):
        super().__init__()
        op_config = dict(num_bits=num_bits,
                         quant_delay=quant_delay,
                         symmetric=symmetric,
                         narrow_range=narrow_range,
                         training=training)
        self.fake_quant = Q.FakeQuantPerLayer(**op_config)

    def construct(self, x, minq, maxq):
        """Return the operator's result for (x, minq, maxq)."""
        quantized = self.fake_quant(x, minq, maxq)
        return quantized
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant1():
    """8-bit, narrow_range=False, min=max=0 on an all-zero (2, 3) input: output stays all zero.

    Original case note: (8, false, 0.0f, 0.0f, TensorShape({2, 3}), zeros -> zeros).
    """
    data = np.zeros((2, 3), dtype=np.float32)
    lower = np.zeros(1, dtype=np.float32)
    upper = np.zeros(1, dtype=np.float32)
    expected = np.zeros(6, dtype=np.float32)

    model = Net(num_bits=8, narrow_range=False)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant2():
    """8-bit, narrow_range=False, min=-10.0, max=53.75 on a (2, 3) input.

    Original case note: 8, false, -10.0f, 53.75f, TensorShape({2, 3}),
    {-10.1, -10.0, -9.9, -9.75, 53.75, 53.8} -> {-10.0, -10.0, -10.0, -9.75, 53.75, 53.75}.
    """
    data = np.array([[-10.1, -10.0, -9.9], [-9.75, 53.75, 53.8]], dtype=np.float32)
    lower = np.array([-10.0], dtype=np.float32)
    upper = np.array([53.75], dtype=np.float32)
    expected = np.array([-10.0, -10.0, -10.0, -9.75, 53.75, 53.75], dtype=np.float32)

    model = Net(num_bits=8, narrow_range=False)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant3():
    """8-bit, narrow_range=True, min=-10.0, max=53.5 on a (2, 3) input.

    Case label: WithVarsNoNudging_NarrowRange.
    """
    data = np.array([[-10.1, -10.0, -9.90], [-9.75, 53.5, 53.6]], dtype=np.float32)
    lower = np.array([-10.0], dtype=np.float32)
    upper = np.array([53.5], dtype=np.float32)
    expected = np.array([-10.0, -10.0, -10.0, -9.75, 53.5, 53.5], dtype=np.float32)

    model = Net(num_bits=8, narrow_range=True)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant4():
    """8-bit, narrow_range=False, min=-0.1, max=63.65 on a (2, 3) input.

    Case label: WithVarsNudgedDown_RegularRange.
    """
    data = np.array([[-0.1, 0.0, 0.1], [0.25, 63.75, 63.8]], dtype=np.float32)
    lower = np.array([-0.1], dtype=np.float32)
    upper = np.array([63.65], dtype=np.float32)
    expected = np.array([-0.0, 0.0, 0.0, 0.25, 63.75, 63.75], dtype=np.float32)

    model = Net(num_bits=8, narrow_range=False)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant5():
    """8-bit, narrow_range=True, min=-0.1, max=63.4 on a (2, 3) input.

    Case label: WithVarsNudgedDown_NarrowRange.
    """
    data = np.array([[-0.1, 0.0, 0.1], [0.25, 63.5, 63.6]], dtype=np.float32)
    lower = np.array([-0.1], dtype=np.float32)
    upper = np.array([63.4], dtype=np.float32)
    expected = np.array([-0.0, 0.0, 0.0, 0.25, 63.5, 63.5], dtype=np.float32)

    model = Net(num_bits=8, narrow_range=True)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant6():
    """8-bit, narrow_range=False, min=-0.125, max=63.625 on a (2, 3) input.

    Case label: WithVarsNudgedUp_RegularRange.
    """
    data = np.array([[-0.26, -0.25, -0.24], [0.0, 63.5, 63.6]], dtype=np.float32)
    lower = np.array([-0.125], dtype=np.float32)
    upper = np.array([63.625], dtype=np.float32)
    expected = np.array([-0.25, -0.25, -0.25, 0.0, 63.5, 63.5], dtype=np.float32)

    model = Net(num_bits=8, narrow_range=False)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant7():
    """8-bit, narrow_range=True, min=-0.125, max=63.375 on a (2, 3) input.

    Case label: WithVarsNudgedUp_NarrowRange.
    """
    data = np.array([[-0.26, -0.25, -0.24], [0.0, 63.25, 63.3]], dtype=np.float32)
    lower = np.array([-0.125], dtype=np.float32)
    upper = np.array([63.375], dtype=np.float32)
    expected = np.array([-0.25, -0.25, -0.25, 0.0, 63.25, 63.25], dtype=np.float32)

    model = Net(num_bits=8, narrow_range=True)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant8():
    """8-bit, narrow_range=False, min=-63.65, max=0.1 on a (2, 3) input.

    Case label: WithVarsNudgedZeroIs255_RegularRange.
    """
    data = np.array([[-63.80, -63.75, -63.70], [-63.5, 0.0, 0.1]], dtype=np.float32)
    lower = np.array([-63.65], dtype=np.float32)
    upper = np.array([0.1], dtype=np.float32)
    expected = np.array([-63.75, -63.75, -63.75, -63.5, 0.0, 0.0], dtype=np.float32)

    model = Net(num_bits=8, narrow_range=False)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant9():
    """8-bit, narrow_range=True, min=-63.4, max=0.1 on a (2, 3) input.

    Case label: WithVarsNudgedZeroIs255_NarrowRange.
    """
    data = np.array([[-63.6, -63.5, -63.4], [-63.25, 0.0, 0.1]], dtype=np.float32)
    lower = np.array([-63.4], dtype=np.float32)
    upper = np.array([0.1], dtype=np.float32)
    expected = np.array([-63.5, -63.5, -63.5, -63.25, 0.0, 0.0], dtype=np.float32)

    model = Net(num_bits=8, narrow_range=True)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant10():
    """4-bit, narrow_range=False, min=-6.0, max=1.5 on a (2, 3) input.

    Case label: WithVarsNoNudging_4Bits_RegularRange.
    """
    data = np.array([[-6.1, -6.0, -5.9], [-5.5, 1.5, 1.6]], dtype=np.float32)
    lower = np.array([-6.0], dtype=np.float32)
    upper = np.array([1.5], dtype=np.float32)
    expected = np.array([-6.0, -6.0, -6.0, -5.5, 1.5, 1.5], dtype=np.float32)

    model = Net(num_bits=4, narrow_range=False)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant11():
    """4-bit, narrow_range=True, min=-6.0, max=1.0 on a (2, 3) input.

    Case label: WithVarsNoNudging_4Bits_NarrowRange.
    """
    data = np.array([[-6.1, -6.0, -5.9], [-5.5, 1.0, 1.1]], dtype=np.float32)
    lower = np.array([-6.0], dtype=np.float32)
    upper = np.array([1.0], dtype=np.float32)
    expected = np.array([-6.0, -6.0, -6.0, -5.5, 1.0, 1.0], dtype=np.float32)

    model = Net(num_bits=4, narrow_range=True)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant12():
    """4-bit, narrow_range=False, min=-0.1, max=7.4 on a (2, 3) input.

    Case label: WithVarsNudgedDown_4Bits_RegularRange.
    """
    data = np.array([[-0.1, 0.0, 0.1], [0.5, 7.5, 7.6]], dtype=np.float32)
    lower = np.array([-0.1], dtype=np.float32)
    upper = np.array([7.4], dtype=np.float32)
    expected = np.array([-0.0, 0.0, 0.0, 0.5, 7.5, 7.5], dtype=np.float32)

    model = Net(num_bits=4, narrow_range=False)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant13():
    """4-bit, narrow_range=True, min=-0.1, max=6.9 on a (2, 3) input.

    Case label: WithVarsNudgedDown_4Bits_NarrowRange.
    """
    data = np.array([[-0.1, 0.0, 0.1], [0.5, 7.0, 7.1]], dtype=np.float32)
    lower = np.array([-0.1], dtype=np.float32)
    upper = np.array([6.9], dtype=np.float32)
    expected = np.array([-0.0, 0.0, 0.0, 0.5, 7.0, 7.0], dtype=np.float32)

    model = Net(num_bits=4, narrow_range=True)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant14():
    """4-bit, narrow_range=False, min=-0.4, max=7.1 on a (2, 3) input.

    Case label: WithVarsNudgedUp_4Bits_RegularRange.
    """
    data = np.array([[-0.6, -0.5, -0.24], [0.0, 7.0, 7.1]], dtype=np.float32)
    lower = np.array([-0.4], dtype=np.float32)
    upper = np.array([7.1], dtype=np.float32)
    expected = np.array([-0.5, -0.5, -0.0, 0.0, 7.0, 7.0], dtype=np.float32)

    model = Net(num_bits=4, narrow_range=False)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant15():
    """4-bit, narrow_range=True, min=-0.4, max=6.6 on a (2, 3) input.

    Case label: WithVarsNudgedUp_4Bits_NarrowRange.
    """
    data = np.array([[-0.6, -0.5, -0.24], [0.0, 6.5, 6.6]], dtype=np.float32)
    lower = np.array([-0.4], dtype=np.float32)
    upper = np.array([6.6], dtype=np.float32)
    expected = np.array([-0.5, -0.5, -0.0, 0.0, 6.5, 6.5], dtype=np.float32)

    model = Net(num_bits=4, narrow_range=True)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant16():
    """4-bit, narrow_range=False, min=-7.3, max=0.2 on a (2, 3) input.

    Case label: WithVarsNudgedZero15_4Bits_RegularRange.
    """
    data = np.array([[-7.6, -7.5, -7.4], [-7.2, 0.0, 0.1]], dtype=np.float32)
    lower = np.array([-7.3], dtype=np.float32)
    upper = np.array([0.2], dtype=np.float32)
    expected = np.array([-7.5, -7.5, -7.5, -7.0, 0.0, 0.0], dtype=np.float32)

    model = Net(num_bits=4, narrow_range=False)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant17():
    """4-bit, narrow_range=True, min=-6.8, max=0.2 on a (2, 3) input.

    Case label: WithVarsNudgedZero15_4Bits_NarrowRange.
    """
    data = np.array([[-7.1, -7.0, -6.9], [-6.5, 0.0, 0.1]], dtype=np.float32)
    lower = np.array([-6.8], dtype=np.float32)
    upper = np.array([0.2], dtype=np.float32)
    expected = np.array([-7.0, -7.0, -7.0, -6.5, 0.0, 0.0], dtype=np.float32)

    model = Net(num_bits=4, narrow_range=True)
    result = model(Tensor(data), Tensor(lower), Tensor(upper))

    tolerance = np.full(expected.shape, 1.0e-5)
    deviation = result.asnumpy().flatten() - expected
    print("output: ", result)
    print("expect: ", expected)
    assert (np.abs(deviation) < tolerance).all()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore import Tensor
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops.operations import _quant_ops as Q
# Module-level setup: select the GPU device target and device id 0.
context.set_context(
    device_target='GPU',
    device_id=0,
)
class Net(nn.Cell):
    """Minimal cell wrapping ``Q.FakeQuantPerLayerGrad`` for the tests below.

    Args:
        num_bits: passed to the operator as ``num_bits`` (default 8).
        narrow_range: passed to the operator as ``narrow_range`` (default False).
    """

    def __init__(self, num_bits=8, narrow_range=False):
        super().__init__()
        grad_op = Q.FakeQuantPerLayerGrad(num_bits=num_bits, narrow_range=narrow_range)
        self.op = grad_op

    def construct(self, dout, x, minq, maxq):
        # Forward the four inputs to the wrapped operator unchanged.
        return self.op(dout, x, minq, maxq)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad1():
    """Gradient check; case label: WithArgsGradient RegularRange.

    Uses a random float32 ``dout`` of six values with 8-bit, non-narrow
    settings. The expected result equals ``dout`` at indices 1-4 and is zero
    at indices 0 and 5; each element must match to strictly within 1e-5.
    """
    dout = np.random.uniform(-1, 1, size=[6]).astype(np.float32)
    data = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6], dtype=np.float32)
    lower = np.array([-0.125], dtype=np.float32)
    upper = np.array([63.625], dtype=np.float32)

    # Start from a copy of dout and zero the first and last positions.
    expect = np.array(dout, dtype=np.float32)
    expect[0] = 0.0
    expect[5] = 0.0

    cell = Net(num_bits=8, narrow_range=False)
    output = cell(Tensor(dout), Tensor(data), Tensor(lower), Tensor(upper))

    # Per-element float64 tolerance with the same shape as the expectation.
    tolerance = np.full(expect.shape, 1.0e-5)
    deviation = np.abs(output.asnumpy().flatten() - expect)
    print("output: ", output)
    print("expect: ", expect)
    assert (deviation < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad2():
    """Gradient check; case label: WithArgsGradient NarrowRange.

    Uses a random float32 ``dout`` of six values with 8-bit, narrow-range
    settings. The expected result equals ``dout`` at indices 1-4 and is zero
    at indices 0 and 5; each element must match to strictly within 1e-5.
    """
    dout = np.random.uniform(-1, 1, size=[6]).astype(np.float32)
    data = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3], dtype=np.float32)
    lower = np.array([-0.125], dtype=np.float32)
    upper = np.array([63.375], dtype=np.float32)

    # Start from a copy of dout and zero the first and last positions.
    expect = np.array(dout, dtype=np.float32)
    expect[0] = 0.0
    expect[5] = 0.0

    cell = Net(num_bits=8, narrow_range=True)
    output = cell(Tensor(dout), Tensor(data), Tensor(lower), Tensor(upper))

    # Per-element float64 tolerance with the same shape as the expectation.
    tolerance = np.full(expect.shape, 1.0e-5)
    deviation = np.abs(output.asnumpy().flatten() - expect)
    print("output: ", output)
    print("expect: ", expect)
    assert (deviation < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad3():
    """Gradient check; case label: WithArgsGradient_4Bits_RegularRange.

    Uses a random float32 ``dout`` of six values with 4-bit, non-narrow
    settings. The expected result equals ``dout`` at indices 1-4 and is zero
    at indices 0 and 5; each element must match to strictly within 1e-5.
    """
    dout = np.random.uniform(-1, 1, size=[6]).astype(np.float32)
    data = np.array([-0.6, -0.5, -0.4, 0.0, 7.0, 7.1], dtype=np.float32)
    lower = np.array([-0.4], dtype=np.float32)
    upper = np.array([7.1], dtype=np.float32)

    # Start from a copy of dout and zero the first and last positions.
    expect = np.array(dout, dtype=np.float32)
    expect[0] = 0.0
    expect[5] = 0.0

    cell = Net(num_bits=4, narrow_range=False)
    output = cell(Tensor(dout), Tensor(data), Tensor(lower), Tensor(upper))

    # Per-element float64 tolerance with the same shape as the expectation.
    tolerance = np.full(expect.shape, 1.0e-5)
    deviation = np.abs(output.asnumpy().flatten() - expect)
    print("output: ", output)
    print("expect: ", expect)
    assert (deviation < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad4():
    """Gradient check; case label: WithArgsGradient_4Bits_NarrowRange.

    Uses a random float32 ``dout`` of six values with 4-bit, narrow-range
    settings. The expected result equals ``dout`` at indices 1-4 and is zero
    at indices 0 and 5; each element must match to strictly within 1e-5.
    """
    dout = np.random.uniform(-1, 1, size=[6]).astype(np.float32)
    data = np.array([-0.6, -0.5, -0.4, 0.0, 6.5, 6.6], dtype=np.float32)
    lower = np.array([-0.4], dtype=np.float32)
    upper = np.array([6.6], dtype=np.float32)

    # Start from a copy of dout and zero the first and last positions.
    expect = np.array(dout, dtype=np.float32)
    expect[0] = 0.0
    expect[5] = 0.0

    cell = Net(num_bits=4, narrow_range=True)
    output = cell(Tensor(dout), Tensor(data), Tensor(lower), Tensor(upper))

    # Per-element float64 tolerance with the same shape as the expectation.
    tolerance = np.full(expect.shape, 1.0e-5)
    deviation = np.abs(output.asnumpy().flatten() - expect)
    print("output: ", output)
    print("expect: ", expect)
    assert (deviation < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad5():
    """Gradient check; case label: FakeQuantWithMinMaxVarsGradient.

    All six inputs and both min/max values are zero, with 8-bit narrow-range
    settings. The expected result is ``dout`` itself; each element must match
    to strictly within 1e-5.
    """
    dout = np.random.uniform(-1, 1, size=[6]).astype(np.float32)
    data = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32)
    lower = np.array([0.0], dtype=np.float32)
    upper = np.array([0.0], dtype=np.float32)
    # The expectation is the incoming gradient, unmodified (same array object).
    expect = dout

    cell = Net(num_bits=8, narrow_range=True)
    output = cell(Tensor(dout), Tensor(data), Tensor(lower), Tensor(upper))

    # Per-element float64 tolerance with the same shape as the expectation.
    tolerance = np.full(expect.shape, 1.0e-5)
    deviation = np.abs(output.asnumpy().flatten() - expect)
    print("output: ", output)
    print("expect: ", expect)
    assert (deviation < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad6():
    """Gradient check; case label: WithVarsGradient_RegularRange.

    Uses a random float32 ``dout`` of six values with 8-bit, non-narrow
    settings. The expected result equals ``dout`` at indices 1-4 and is zero
    at indices 0 and 5; each element must match to strictly within 1e-5.
    """
    dout = np.random.uniform(-1, 1, size=[6]).astype(np.float32)
    data = np.array([-0.26, -0.25, -0.24, 0.0, 63.5, 63.6], dtype=np.float32)
    lower = np.array([-0.125], dtype=np.float32)
    upper = np.array([63.625], dtype=np.float32)

    # Start from a copy of dout and zero the first and last positions.
    expect = np.array(dout, dtype=np.float32)
    expect[0] = 0.0
    expect[5] = 0.0

    cell = Net(num_bits=8, narrow_range=False)
    output = cell(Tensor(dout), Tensor(data), Tensor(lower), Tensor(upper))

    # Per-element float64 tolerance with the same shape as the expectation.
    tolerance = np.full(expect.shape, 1.0e-5)
    deviation = np.abs(output.asnumpy().flatten() - expect)
    print("output: ", output)
    print("expect: ", expect)
    assert (deviation < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad7():
    """Gradient check; case label: WithVarsGradient_NarrowRange.

    Uses a random float32 ``dout`` of six values with 8-bit, narrow-range
    settings. The expected result equals ``dout`` at indices 1-4 and is zero
    at indices 0 and 5; each element must match to strictly within 1e-5.
    """
    dout = np.random.uniform(-1, 1, size=[6]).astype(np.float32)
    data = np.array([-0.26, -0.25, -0.24, 0.0, 63.25, 63.3], dtype=np.float32)
    lower = np.array([-0.125], dtype=np.float32)
    upper = np.array([63.375], dtype=np.float32)

    # Start from a copy of dout and zero the first and last positions.
    expect = np.array(dout, dtype=np.float32)
    expect[0] = 0.0
    expect[5] = 0.0

    cell = Net(num_bits=8, narrow_range=True)
    output = cell(Tensor(dout), Tensor(data), Tensor(lower), Tensor(upper))

    # Per-element float64 tolerance with the same shape as the expectation.
    tolerance = np.full(expect.shape, 1.0e-5)
    deviation = np.abs(output.asnumpy().flatten() - expect)
    print("output: ", output)
    print("expect: ", expect)
    assert (deviation < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad8():
    """Gradient check; case label: WithVarsGradient_4Bits_RegularRange.

    Uses a random float32 ``dout`` of six values with 4-bit, non-narrow
    settings. The expected result equals ``dout`` at indices 1-4 and is zero
    at indices 0 and 5; each element must match to strictly within 1e-5.
    """
    dout = np.random.uniform(-1, 1, size=[6]).astype(np.float32)
    data = np.array([-0.6, -0.5, -0.4, 0.0, 7.0, 7.1], dtype=np.float32)
    lower = np.array([-0.4], dtype=np.float32)
    upper = np.array([7.1], dtype=np.float32)

    # Start from a copy of dout and zero the first and last positions.
    expect = np.array(dout, dtype=np.float32)
    expect[0] = 0.0
    expect[5] = 0.0

    cell = Net(num_bits=4, narrow_range=False)
    output = cell(Tensor(dout), Tensor(data), Tensor(lower), Tensor(upper))

    # Per-element float64 tolerance with the same shape as the expectation.
    tolerance = np.full(expect.shape, 1.0e-5)
    deviation = np.abs(output.asnumpy().flatten() - expect)
    print("output: ", output)
    print("expect: ", expect)
    assert (deviation < tolerance).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_fake_quant_grad9():
    """Gradient check; case label: WithVarsGradient_4Bits_NarrowRange.

    Uses a random float32 ``dout`` of six values with 4-bit, narrow-range
    settings. The expected result equals ``dout`` at indices 1-4 and is zero
    at indices 0 and 5; each element must match to strictly within 1e-5.
    """
    dout = np.random.uniform(-1, 1, size=[6]).astype(np.float32)
    data = np.array([-0.6, -0.5, -0.4, 0.0, 6.5, 6.6], dtype=np.float32)
    lower = np.array([-0.4], dtype=np.float32)
    upper = np.array([6.6], dtype=np.float32)

    # Start from a copy of dout and zero the first and last positions.
    expect = np.array(dout, dtype=np.float32)
    expect[0] = 0.0
    expect[5] = 0.0

    cell = Net(num_bits=4, narrow_range=True)
    output = cell(Tensor(dout), Tensor(data), Tensor(lower), Tensor(upper))

    # Per-element float64 tolerance with the same shape as the expectation.
    tolerance = np.full(expect.shape, 1.0e-5)
    deviation = np.abs(output.asnumpy().flatten() - expect)
    print("output: ", output)
    print("expect: ", expect)
    assert (deviation < tolerance).all()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册