Commit cc80c766 authored by: panfengfeng, committed by: chenzomi

add quantization gpu op

Parent 54481c30
......@@ -26,3 +26,7 @@ from .squeeze_grad import SqueezeGrad, gpu_schedule_SqueezeGrad
from .mean import SimpleMean, gpu_schedule_SimpleMean
from .mean_grad import SimpleMeanGrad, gpu_schedule_SimpleMeanGrad
from .mul import Mul, gpu_schedule_Mul
from .hsigmoid import Hsigmoid, gpu_schedule_Hsigmoid
from .hsigmoid_grad import HsigmoidGrad, gpu_schedule_HsigmoidGrad
from .hswish import Hswish, gpu_schedule_Hswish
from .hswish_grad import HswishGrad, gpu_schedule_HswishGrad
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""hsigmoid"""
import _akg.topi as topi
import _akg.tvm as tvm
from _akg.topi import tag
@tvm.tag_scope(tag=tag.ELEMWISE)
def topi_nn_hsigmoid(x):
"""
topi hsigmoid
Args:
x:
Returns:
"""
return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0,
tvm.if_then_else(x(*i) >= 3, 1,
(x(*i) + 3) / 6)))
def Hsigmoid(x):
"""
Hsigmoid
Args:
x:
Returns:
"""
return topi_nn_hsigmoid(x)
def gpu_schedule_Hsigmoid(outs):
"""
gpu schedule Hsigmoid
Args:
outs:
Returns:
"""
device = 'cuda'
ctx = tvm.context(device, 0)
if not ctx.exist:
raise SystemError("Skip because %s is not enabled" % device)
with tvm.target.create(device):
sch = topi.cuda.schedule_elemwise(outs)
return sch
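For reference only, a minimal NumPy sketch of the same piecewise hard-sigmoid definition that topi_nn_hsigmoid computes above; the name np_hsigmoid and the NumPy dependency are illustrative and not part of this commit.

import numpy as np

def np_hsigmoid(x):
    # 0 where x <= -3, 1 where x >= 3, (x + 3) / 6 otherwise
    return np.clip((x + 3.0) / 6.0, 0.0, 1.0)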
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hsigmoid grad"""
import _akg.topi as topi
import _akg.tvm as tvm
def HsigmoidGrad(y_grad, x):
"""
HsigmoidGrad
Args:
y_grad:
x:
Returns:
"""
return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0,
tvm.if_then_else(x(*i) >= 3, 0,
y_grad(*i) / 6)))
def gpu_schedule_HsigmoidGrad(outs):
"""
gpu schedule ReLU6Grad
Args:
outs:
Returns:
"""
device = 'cuda'
ctx = tvm.context(device, 0)
if not ctx.exist:
raise SystemError("Skip because %s is not enabled" % device)
with tvm.target.create(device):
sch = topi.cuda.schedule_elemwise(outs)
return sch
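A matching NumPy sketch of the gradient that HsigmoidGrad computes (illustrative only; np_hsigmoid_grad is not part of this commit).

import numpy as np

def np_hsigmoid_grad(y_grad, x):
    # the derivative of hsigmoid is 1/6 on (-3, 3) and 0 elsewhere
    return np.where((x > -3.0) & (x < 3.0), y_grad / 6.0, 0.0)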
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""hswish"""
import _akg.topi as topi
import _akg.tvm as tvm
from _akg.topi import tag
@tvm.tag_scope(tag=tag.ELEMWISE)
def topi_nn_hswish(x):
"""
topi hswish
Args:
x:
Returns:
"""
return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0,
tvm.if_then_else(x(*i) >= 3, x(*i),
x(*i) * (x(*i) + 3) / 6)))
def Hswish(x):
"""
Hswish
Args:
x:
Returns:
"""
return topi_nn_hswish(x)
def gpu_schedule_Hswish(outs):
"""
gpu schedule Hswish
Args:
outs:
Returns:
"""
device = 'cuda'
ctx = tvm.context(device, 0)
if not ctx.exist:
raise SystemError("Skip because %s is not enabled" % device)
with tvm.target.create(device):
sch = topi.cuda.schedule_elemwise(outs)
return sch
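For reference only, a minimal NumPy sketch of the hard-swish definition used by topi_nn_hswish above; np_hswish is an illustrative name.

import numpy as np

def np_hswish(x):
    # 0 where x <= -3, x where x >= 3, x * (x + 3) / 6 otherwise
    return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0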
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HswishGrad"""
import _akg.topi as topi
import _akg.tvm as tvm
def HswishGrad(y_grad, x):
"""
HswishGrad
Args:
y_grad:
x:
Returns:
"""
shape = x.shape
res0 = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0, y_grad(*i) * (2 * x(*i) + 3) / 6))
res6 = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) >= 3, y_grad(*i), res0(*i)))
return res6
def gpu_schedule_HswishGrad(outs):
"""
gpu schedule HswishGrad
Args:
outs:
Returns:
"""
device = 'cuda'
ctx = tvm.context(device, 0)
if not ctx.exist:
raise SystemError("Skip because %s is not enabled" % device)
with tvm.target.create(device):
sch = topi.cuda.schedule_elemwise(outs)
return sch
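A matching NumPy sketch of the gradient that HswishGrad computes (illustrative only; np_hswish_grad is not part of this commit).

import numpy as np

def np_hswish_grad(y_grad, x):
    # 0 where x <= -3, y_grad where x >= 3, y_grad * (2 * x + 3) / 6 otherwise
    mid = y_grad * (2.0 * x + 3.0) / 6.0
    return np.where(x >= 3.0, y_grad, np.where(x <= -3.0, 0.0, mid))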
......@@ -300,6 +300,13 @@ class ParamValidator:
for arg, value in args.items():
ParamValidator.check_subclass(arg, value, mstype.tensor)
@staticmethod
def check_bool(arg_name, arg_value):
"""Check arg isintance of bool"""
if not isinstance(arg_value, bool):
raise ValueError(f'The `{arg_name}` should be isintance of bool, but got {arg_value}.')
return arg_value
@staticmethod
def check_type(arg_name, arg_value, valid_types):
"""Type checking."""
......
......@@ -473,6 +473,7 @@ if(ENABLE_GPU)
gpu_cuda_lib
gpu_queue
cublas
${CUDA_PATH}/lib64/libcurand.so
${CUDNN_PATH}/lib64/libcudnn.so
${CUDA_PATH}/lib64/libcudart.so
${CUDA_PATH}/lib64/stubs/libcuda.so)
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdint.h>
#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#include <thrust/reduce.h>
#include <thrust/system/cuda/execution_policy.h>
#include "batchnorm_fold2_impl.cuh"
#include "batchnorm_fold_impl.cuh"
#include "include/cuda_runtime.h"
template <typename T>
__global__ void BatchNormFold2Kernel(const T *x, const T *beta, const T *gamma, const T *batch_std, const T *batch_mean,
const T *running_std, const T *running_mean, const int *global_step, T *y,
int freeze_bn, size_t N, size_t C, size_t H, size_t W) {
int c = 0;
size_t num_count = N * C * H * W;
if (*global_step < freeze_bn) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
c = i / (H * W) % C;
y[i] = x[i] * running_std[c] / batch_std[c] + beta[c] - gamma[c] * batch_mean[c] / batch_std[c];
}
} else {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
c = i / (H * W) % C;
y[i] = x[i] + beta[c] - gamma[c] * running_mean[c] / running_std[c];
}
}
}
template <typename T>
__global__ void BatchNormFold2GradReduce1(const T *dout, T *tmp, const T *x, T *tmp2, size_t N, size_t C, size_t HW) {
int n = 0;
int c = 0;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N * C; i += blockDim.x * gridDim.x) {
n = i / C;
c = i % C;
tmp[c * N + n] = thrust::reduce(thrust::seq, dout + i * HW, dout + (i + 1) * HW, 0.f, thrust::plus<T>());
tmp2[c * N + n] = thrust::reduce(thrust::seq, x + i * HW, x + (i + 1) * HW, 0.f, thrust::plus<T>());
}
}
template <typename T>
__global__ void BatchNormFold2GradReduce2(const T *tmp, T *d_beta, const T *tmp2, T *reduce_x, size_t N, size_t C) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < C; i += blockDim.x * gridDim.x) {
d_beta[i] = thrust::reduce(thrust::seq, tmp + i * N, tmp + (i + 1) * N, 0.f, thrust::plus<T>());
reduce_x[i] = thrust::reduce(thrust::seq, tmp2 + i * N, tmp2 + (i + 1) * N, 0.f, thrust::plus<T>());
}
}
template <typename T>
__global__ void BatchNormFold2GradNotFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
T *d_batch_mean, T *d_batch_std, size_t C) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < C; i += blockDim.x * gridDim.x) {
d_gamma[i] = -d_beta[i] * batch_mean[i] / batch_std[i];
d_batch_mean[i] = -d_beta[i] * gamma[i] / batch_std[i];
d_batch_std[i] =
(d_beta[i] * gamma[i] * batch_mean[i] - reduce_x[i] * running_std[i]) / batch_std[i] / batch_std[i];
}
}
template <typename T>
__global__ void BatchNormFold2GradFreeze(const T *d_beta, const T *running_mean, const T *running_std, T *d_gamma,
size_t C) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < C; i += blockDim.x * gridDim.x) {
d_gamma[i] = -d_beta[i] * running_mean[i] / running_std[i];
}
}
template <typename T>
__global__ void BatchNormFold2GradMul(const T *dout, const T *x, T *tmp_x, size_t NCHW) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < NCHW; i += blockDim.x * gridDim.x) {
tmp_x[i] = dout[i] * x[i];
}
}
template <typename T>
__global__ void DxMul(size_t N, size_t C, size_t HW, const T *batch_std, const T *running_std, T *d_x) {
int c = 0;
size_t num_count = N * C * HW;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
c = (i / HW) % C;
d_x[i] = d_x[i] * running_std[c] / batch_std[c];
}
}
template <typename T>
void BatchNormFold2Forward(const T *x, const T *beta, const T *gamma, const T *batch_std, const T *batch_mean,
const T *running_std, const T *running_mean, const int *global_step, T *y, int freeze_bn,
size_t N, size_t C, size_t H, size_t W, cudaStream_t cuda_stream) {
auto num_count = N * C * H * W;
BatchNormFold2Kernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(
x, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, y, freeze_bn, N, C, H, W);
}
template void BatchNormFold2Forward<float>(const float *x, const float *beta, const float *gamma,
const float *batch_std, const float *batch_mean, const float *running_std,
const float *running_mean, const int *global_step, float *y, int freeze_bn,
size_t N, size_t C, size_t H, size_t W, cudaStream_t cuda_stream);
template <typename T>
void BatchNormFold2GradReduce(const T *dout, const T *x, T *d_beta, T *tmp, T *reduce_x, T *tmp2, T *tmp_x, size_t N,
size_t C, size_t H, size_t W, cudaStream_t cuda_stream) {
auto hw = H * W;
auto num_count = N * C * H * W;
BatchNormFold2GradMul<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(dout, x, tmp_x, num_count);
BatchNormFold2GradReduce1<<<GET_BLOCKS(N * C), GET_THREADS, 0, cuda_stream>>>(dout, tmp, tmp_x, tmp2, N, C, hw);
BatchNormFold2GradReduce2<<<GET_BLOCKS(C), GET_THREADS, 0, cuda_stream>>>(tmp, d_beta, tmp2, reduce_x, N, C);
}
template void BatchNormFold2GradReduce<float>(const float *dout, const float *x, float *d_beta, float *tmp,
float *reduce_x, float *tmp2, float *tmp_x, size_t N, size_t C, size_t H,
size_t W, cudaStream_t cuda_stream);
template <typename T>
void CalBatchNormFold2GradNotFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream) {
BatchNormFold2GradNotFreeze<<<GET_BLOCKS(C), GET_THREADS, 0, cuda_stream>>>(
d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma, d_batch_mean, d_batch_std, C);
}
template void CalBatchNormFold2GradNotFreeze<float>(const float *d_beta, const float *reduce_x, const float *batch_mean,
const float *batch_std, const float *running_mean,
const float *running_std, const float *gamma, float *d_gamma,
float *d_batch_mean, float *d_batch_std, size_t C,
cudaStream_t cuda_stream);
template <typename T>
void CalBatchNormFold2GradFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream) {
BatchNormFold2GradFreeze<<<GET_BLOCKS(C), GET_THREADS, 0, cuda_stream>>>(d_beta, running_mean, running_std, d_gamma,
C);
ThrustFillWith(d_batch_mean, C, (T)0.f, cuda_stream);
ThrustFillWith(d_batch_std, C, (T)0.f, cuda_stream);
}
template void CalBatchNormFold2GradFreeze<float>(const float *d_beta, const float *reduce_x, const float *batch_mean,
const float *batch_std, const float *running_mean,
const float *running_std, const float *gamma, float *d_gamma,
float *d_batch_mean, float *d_batch_std, size_t C,
cudaStream_t cuda_stream);
template <typename T>
void CalBatchNormFold2GradNotFreezeDxMul(const T *batch_std, const T *running_std, T *d_x, size_t N, size_t C, size_t H,
size_t W, cudaStream_t cuda_stream) {
DxMul<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(N, C, H * W, batch_std, running_std, d_x);
}
template void CalBatchNormFold2GradNotFreezeDxMul<float>(const float *batch_std, const float *running_std, float *d_x,
size_t N, size_t C, size_t H, size_t W,
cudaStream_t cuda_stream);
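As a host-side reference for the forward math in BatchNormFold2Kernel above, a minimal NumPy sketch (the function name and the (N, C, H, W) broadcasting layout are illustrative, matching how the kernel indexes channels).

import numpy as np

def np_batchnorm_fold2_forward(x, beta, gamma, batch_std, batch_mean,
                               running_std, running_mean, global_step, freeze_bn):
    # x: (N, C, H, W); the per-channel vectors broadcast over N, H, W
    rs = lambda v: v.reshape(1, -1, 1, 1)
    if global_step < freeze_bn:
        return x * rs(running_std) / rs(batch_std) + rs(beta) - rs(gamma) * rs(batch_mean) / rs(batch_std)
    return x + rs(beta) - rs(gamma) * rs(running_mean) / rs(running_std)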
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_
#include "device/gpu/cuda_common.h"
template <typename T>
void BatchNormFold2Forward(const T *x, const T *beta, const T *gamma, const T *batch_std, const T *batch_mean,
const T *running_std, const T *running_mean, const int *global_step, T *y, int freeze_bn,
size_t N, size_t C, size_t H, size_t W, cudaStream_t cuda_stream);
template <typename T>
void CalBatchNormFold2GradNotFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream);
template <typename T>
void CalBatchNormFold2GradFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream);
template <typename T>
void BatchNormFold2GradReduce(const T *dout, const T *x, T *d_beta, T *tmp, T *reduce_x, T *tmp2, T *tmp_x, size_t N,
size_t C, size_t H, size_t W, cudaStream_t cuda_stream);
template <typename T>
void CalBatchNormFold2GradNotFreezeDxMul(const T *batch_std, const T *running_std, T *d_x, size_t N, size_t C, size_t H,
size_t W, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/device_ptr.h>
#include <thrust/fill.h>
#include <thrust/system/cuda/execution_policy.h>
#include "batchnorm_fold_impl.cuh"
#include "device/gpu/cuda_common.h"
template <typename T>
__global__ void UpdateRunningStd(int channel_size, const double epsilon, T* running_std) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) {
running_std[i] = sqrtf(running_std[i] + epsilon);
}
return;
}
template <typename T>
__global__ void UpdateBatchStd(int channel_size, T* batch_std) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) {
batch_std[i] = 1 / batch_std[i];
}
return;
}
template <typename T>
__global__ void CalDx(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean, const T* batch_std,
int batch_size, int channel_size, int height, int width, T* dx) {
int n = batch_size * channel_size * height * width;
int normal_size = batch_size * height * width;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
int channel_index = i / (height * width) % channel_size;
dx[i] = d_batch_mean[channel_index] / normal_size +
d_batch_std[channel_index] * (x[i] - batch_mean[channel_index]) / batch_std[channel_index] / normal_size;
}
return;
}
template <typename T>
void CalUpdateRunningStd(int channel_size, double epsilon, T* running_std, cudaStream_t cuda_stream) {
UpdateRunningStd<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(channel_size, epsilon, running_std);
return;
}
template void CalUpdateRunningStd<float>(int channel_size, double epsilon, float* running_std,
cudaStream_t cuda_stream);
template <typename T>
void CalUpdateBatchStd(int channel_size, T* batch_std, cudaStream_t cuda_stream) {
UpdateBatchStd<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(channel_size, batch_std);
return;
}
template void CalUpdateBatchStd<float>(int channel_size, float* batch_std, cudaStream_t cuda_stream);
template <typename T>
void CalBatchNormFoldGrad(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean,
const T* batch_std, int batch_size, int channel_size, int height, int width, T* dx,
cudaStream_t cuda_stream) {
CalDx<<<GET_BLOCKS(batch_size * channel_size * height * width), GET_THREADS, 0, cuda_stream>>>(
d_batch_mean, d_batch_std, x, batch_mean, batch_std, batch_size, channel_size, height, width, dx);
}
template void CalBatchNormFoldGrad<float>(const float* d_batch_mean, const float* d_batch_std, const float* x,
const float* batch_mean, const float* batch_std, int batch_size,
int channel_size, int height, int width, float* dx, cudaStream_t cuda_stream);
template <typename T>
void ThrustFillWith(T* array, int size, T tofill, cudaStream_t cuda_stream) {
thrust::device_ptr<T> dev_ptr(array);
thrust::fill(thrust::cuda::par.on(cuda_stream), dev_ptr, dev_ptr + size, tofill);
}
template void ThrustFillWith<float>(float* array, int size, float tofill, cudaStream_t cuda_stream);
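For reference, a minimal NumPy sketch of the gradient computed by CalDx above (np_batchnorm_fold_grad is an illustrative name; the (N, C, H, W) layout mirrors the kernel's channel indexing).

import numpy as np

def np_batchnorm_fold_grad(d_batch_mean, d_batch_std, x, batch_mean, batch_std):
    # mirrors CalDx: per-channel vectors broadcast over N, H, W; normal_size = N * H * W
    n, _, h, w = x.shape
    normal_size = n * h * w
    rs = lambda v: v.reshape(1, -1, 1, 1)
    return rs(d_batch_mean) / normal_size + rs(d_batch_std) * (x - rs(batch_mean)) / (rs(batch_std) * normal_size)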
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORM_FOLD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORM_FOLD_H_
#include "device/gpu/cuda_common.h"
template <typename T>
void CalUpdateRunningStd(int channel_size, double epsilon, T* running_std, cudaStream_t cuda_stream);
template <typename T>
void CalUpdateBatchStd(int channel_size, T* batch_std, cudaStream_t cuda_stream);
template <typename T>
void CalBatchNormFoldGrad(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean,
const T* batch_std, int batch_size, int channel_size, int height, int width, T* dx,
cudaStream_t cuda_stream);
template <typename T>
void ThrustFillWith(T* array, int size, T tofill, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORM_FOLD_H_
......@@ -41,3 +41,4 @@ template void CalConcatV2(const size_t size, const int w1, const int w2, const i
int* output, cudaStream_t cuda_stream);
template void CalConcatV2(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2,
half* output, cudaStream_t cuda_stream);
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/reduce.h>
#include "correction_mul_impl.cuh"
#include "device/gpu/cuda_common.h"
template <typename T>
__global__ void CorrectionMul(const T* weight, const T* gamma, const T* running_std, const int batchsize, const int chw,
T* output) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batchsize * chw; i += blockDim.x * gridDim.x) {
int n = i / chw;
output[i] = weight[i] * gamma[n] / running_std[n];
}
return;
}
template <typename T>
__global__ void Mul(int N, const T* a, const T* b, T* c) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
c[i] = a[i] * b[i];
}
return;
}
template <typename T>
__global__ void Reduce(int N, int CHW, const T* tmp, const T* running_std, T* d_gamma) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
d_gamma[i] = thrust::reduce(thrust::seq, tmp + i * CHW, tmp + (i + 1) * CHW, 0.f, thrust::plus<T>());
d_gamma[i] = d_gamma[i] / running_std[i];
}
return;
}
template <typename T>
void CalCorrectionMul(const T* weight, const T* gamma, const T* running_std, int N, int C, int H, int W, T* output,
cudaStream_t cuda_stream) {
CorrectionMul<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(weight, gamma, running_std, N, C * H * W,
output);
}
template void CalCorrectionMul<float>(const float* weight, const float* gamma, const float* running_std, int N, int C,
int H, int W, float* output, cudaStream_t cuda_stream);
template <typename T>
void CalCorrectionMulGrad(const T* d_out, const T* weight, const T* running_std, int N, int C, int H, int W, T* d_gamma,
T* tmp, cudaStream_t cuda_stream) {
Mul<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(N * C * H * W, d_out, weight, tmp);
Reduce<<<GET_BLOCKS(N), GET_THREADS, 0, cuda_stream>>>(N, C * H * W, tmp, running_std, d_gamma);
}
template void CalCorrectionMulGrad<float>(const float* d_out, const float* weight, const float* running_std, int N,
int C, int H, int W, float* d_gamma, float* tmp, cudaStream_t cuda_stream);
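For reference, a minimal NumPy sketch of the correction-mul forward and grad math above, assuming weight has shape (N, C, H, W) with gamma and running_std per output channel along the first axis (the names are illustrative, not part of this commit).

import numpy as np

def np_correction_mul(weight, gamma, running_std):
    # scale each output channel of weight by gamma / running_std
    return weight * (gamma / running_std).reshape(-1, 1, 1, 1)

def np_correction_mul_grad(d_out, weight, running_std):
    # reduce d_out * weight over C, H, W per output channel, then divide by running_std
    return (d_out * weight).sum(axis=(1, 2, 3)) / running_std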
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CORRECTIONMUL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CORRECTIONMUL_H_
#include "device/gpu/cuda_common.h"
template <typename T>
void CalCorrectionMul(const T* weight, const T* gamma, const T* running_std, int batch_size, int channel_size,
int height, int width, T* output, cudaStream_t cuda_stream);
template <typename T>
void CalCorrectionMulGrad(const T* d_out, const T* weight, const T* running_std, int batch_size, int channel_size,
int height, int width, T* d_gamma, T* tmp, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CORRECTIONMUL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdint.h>
#include "cross_entropy_cuda_impl.cuh"
#include "include/cuda_runtime.h"
__global__ void CalCrossEntropyWithGradKernel(const float *softmax_logits, const float *log_softmax_logits,
const float *labels, const int batch_size, const int num_classes,
float *loss, float *dx) {
extern __shared__ float loss_shared[];
const float mean_scale = 1.0f / static_cast<float>(batch_size);
loss_shared[threadIdx.x] = 0;
for (int i = threadIdx.x * num_classes; i < (threadIdx.x + 1) * num_classes; ++i) {
loss_shared[threadIdx.x] -= log_softmax_logits[i] * labels[i];
dx[i] = (softmax_logits[i] - labels[i]) * mean_scale;
}
__syncthreads();
if (threadIdx.x == 0) {
*loss = 0;
for (int i = 0; i < batch_size; i++) {
*loss += loss_shared[i];
}
*loss *= mean_scale;
}
}
void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels,
const int batch_size, const int num_classes, float *loss, float *dx,
cudaStream_t cuda_stream) {
CalCrossEntropyWithGradKernel<<<1, batch_size, batch_size * sizeof(float), cuda_stream>>>(
softmax_logits, log_softmax_logits, labels, batch_size, num_classes, loss, dx);
}
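For reference, a minimal NumPy sketch of the loss and gradient computed by CalCrossEntropyWithGradKernel above (the function name is illustrative).

import numpy as np

def np_cross_entropy_with_grad(softmax_logits, log_softmax_logits, labels):
    # all inputs: (batch_size, num_classes); labels are one-hot or soft labels
    batch_size = labels.shape[0]
    loss = -(log_softmax_logits * labels).sum() / batch_size
    dx = (softmax_logits - labels) / batch_size
    return loss, dx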
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_
#include "device/gpu/cuda_common.h"
void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels,
const int batch_size, const int num_classes, float *loss, float *dx,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdint.h>
#include "dropout_impl.cuh"
#include "include/cuda_runtime.h"
__global__ void DropoutForwardKernel(const float *input, float *mask, float *output, size_t num_count,
float drop_prob) {
float scale = 1.f / (1.f - drop_prob);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
mask[i] = mask[i] > drop_prob;
output[i] = scale * input[i] * mask[i];
}
}
void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float drop_prob,
cudaStream_t cuda_stream) {
DropoutForwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(input, mask, output, num_count,
drop_prob);
}
__global__ void DropoutBackwardKernel(const float *dy, const float *mask, float *dx, size_t num_count,
float drop_prob) {
float scale = 1.f / (1.f - drop_prob);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
dx[i] = scale * dy[i] * mask[i];
}
}
void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob,
cudaStream_t cuda_stream) {
DropoutBackwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(dy, mask, dx, num_count, drop_prob);
}
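For reference, a minimal NumPy sketch of the dropout forward and backward math above; the names are illustrative, and uniform_sample stands in for the U(0, 1) buffer that the kernel wrapper fills with curand before calling DropoutForward.

import numpy as np

def np_dropout_forward(x, uniform_sample, drop_prob):
    # the returned keep mask plays the role of the overwritten mask buffer
    keep = (uniform_sample > drop_prob).astype(x.dtype)
    return x * keep / (1.0 - drop_prob), keep

def np_dropout_backward(dy, keep, drop_prob):
    return dy * keep / (1.0 - drop_prob)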
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
#include "device/gpu/cuda_common.h"
void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float drop_prob,
cudaStream_t cuda_stream);
void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/extrema.h>
#include <thrust/device_vector.h>
#include <thrust/pair.h>
#include "device/gpu/cuda_common.h"
#include "fake_quant_impl.cuh"
__global__ void FakeQuantize(const float* input, float* output, const int size, const float* nudge_min,
const float* nudge_max, const float* scale, bool symmetric) {
float input_x = 0.f;
int nudge_input = 0;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
input_x = input[i];
// clamp input x
if (input_x < nudge_min[0]) {
input_x = nudge_min[0];
}
if (input_x > nudge_max[0]) {
input_x = nudge_max[0];
}
// clamp shift
nudge_input = floor((input_x - nudge_min[0]) / scale[0] + 0.5f);
// quantize
output[i] = nudge_input * scale[0] + nudge_min[0];
}
return;
}
__global__ void FakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size,
const float* nudge_min, const float* nudge_max) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
if (input[i] < nudge_min[0] || input[i] > nudge_max[0]) {
output[i] = 0;
} else {
output[i] = gradient[i];
}
}
return;
}
__global__ void NudgeMinMax(const float* input_min, const float* input_max, const float quant_min,
const float quant_max, float* nudge_min, float* nudge_max, float* scale) {
float zp_from_min = 0.f;
if ((quant_max - quant_min) == 0 || (*input_max - *input_min) == 0) {
*scale = 0.f;
zp_from_min = 0.f;
} else {
*scale = (*input_max - *input_min) / (quant_max - quant_min);
zp_from_min = quant_min - *input_min / *scale;
}
float nudge_zp = 0.f;
if (zp_from_min <= quant_min) {
nudge_zp = quant_min;
} else if (zp_from_min >= quant_max) {
nudge_zp = quant_max;
} else {
nudge_zp = round(zp_from_min);
}
*nudge_min = (quant_min - nudge_zp) * (*scale);
*nudge_max = (quant_max - nudge_zp) * (*scale);
return;
}
__global__ void UpdateInputMinMaxWithEMA(float* input_min, float* input_max, const float min, const float max,
const float decay) {
*input_min = decay * (min) + (1 - decay) * (*input_min);
*input_min = *input_min > 0 ? 0 : *input_min;
*input_max = decay * (max) + (1 - decay) * (*input_max);
*input_max = *input_max < 0 ? 0 : *input_max;
return;
}
__global__ void UpdateInputMinMax(float* input_min, float* input_max, const float min, const float max) {
*input_min = min;
*input_max = max;
}
void CalFakeQuantize(const float* input, float* output, const int size, const float* nudge_min, const float* nudge_max,
const float* scale, bool symmetric, cudaStream_t cuda_stream) {
FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale,
symmetric);
return;
}
void CalFakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size,
const float* nudge_min, const float* nudge_max, cudaStream_t cuda_stream) {
FakeQuantizeGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, size, nudge_min,
nudge_max);
return;
}
void CalNudge(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
float* nudge_min, float* nudge_max, float* scale, cudaStream_t cuda_stream) {
NudgeMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale);
return;
}
void CalMinMax(float* input, float* input_min, float* input_max, const int size, const float ema_decay, const bool ema,
cudaStream_t cuda_stream) {
float minel = 0.f;
float maxel = 0.f;
thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> tuple;
tuple = thrust::minmax_element(thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + size);
minel = tuple.first[0];
maxel = tuple.second[0];
if (ema) {
UpdateInputMinMaxWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel, ema_decay);
} else {
UpdateInputMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel);
}
return;
}
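For reference, a minimal NumPy sketch of the nudge and fake-quantize math implemented by NudgeMinMax and FakeQuantize above (np_nudge and np_fake_quantize are illustrative names; rounding of the zero point is written as floor(x + 0.5) to mirror the kernel for the usual non-negative case).

import numpy as np

def np_nudge(input_min, input_max, quant_min, quant_max):
    # derive the scale and a nudged zero point, then the nudged range
    if (quant_max - quant_min) == 0 or (input_max - input_min) == 0:
        scale, zp_from_min = 0.0, 0.0
    else:
        scale = (input_max - input_min) / (quant_max - quant_min)
        zp_from_min = quant_min - input_min / scale
    if zp_from_min <= quant_min:
        nudge_zp = quant_min
    elif zp_from_min >= quant_max:
        nudge_zp = quant_max
    else:
        nudge_zp = np.floor(zp_from_min + 0.5)
    return (quant_min - nudge_zp) * scale, (quant_max - nudge_zp) * scale, scale

def np_fake_quantize(x, nudge_min, nudge_max, scale):
    # assumes a non-degenerate range (scale > 0)
    clamped = np.clip(x, nudge_min, nudge_max)
    return np.floor((clamped - nudge_min) / scale + 0.5) * scale + nudge_min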
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
#include "device/gpu/cuda_common.h"
void CalFakeQuantize(const float* input, float* output, const int size, const float* nudge_min, const float* nudge_max,
const float* scale, bool symmetric, cudaStream_t cuda_stream);
void CalFakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size,
const float* nudge_min, const float* nudge_max, cudaStream_t cuda_stream);
void CalNudge(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
float* nudge_min, float* nudge_max, float* scale, cudaStream_t cuda_stream);
void CalMinMax(float* input, float* input_min, float* input_max, const int size, const float ema_decay, const bool ema,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/extrema.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/pair.h>
#include "fake_quant_per_channel_impl.cuh"
#include "device/gpu/cuda_common.h"
/**
 * Find the nudged min, max and scale values per channel.
 * @param input_min array of per-channel input minimums
 * @param input_max array of per-channel input maximums
 * @param quant_min lower bound of the quantized value range (e.g. 0)
 * @param quant_max upper bound of the quantized value range (e.g. (1 << bit_width) - 1)
 * @param nudge_min output array of nudged minimums
 * @param nudge_max output array of nudged maximums
 * @param scale output array of per-channel scales
 * @param channel_num number of channels
 * @return
 */
__global__ void NudgeMinMaxPerChannel(const float* input_min, const float* input_max, const float quant_min,
const float quant_max, float* nudge_min, float* nudge_max, float* scale,
int channel_num) {
float zp_from_min = 0.f;
float nudge_zp = 0.f;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_num; i += blockDim.x * gridDim.x) {
if ((quant_max - quant_min) == 0 || (input_max[i] - input_min[i]) == 0) {
scale[i] = 0.f;
zp_from_min = 0.f;
} else {
scale[i] = (input_max[i] - input_min[i]) / (quant_max - quant_min);
zp_from_min = quant_min - input_min[i] / scale[i];
}
if (zp_from_min <= quant_min) {
nudge_zp = quant_min;
} else if (zp_from_min >= quant_max) {
nudge_zp = quant_max;
} else {
nudge_zp = round(zp_from_min);
}
nudge_min[i] = (quant_min - nudge_zp) * (scale[i]);
nudge_max[i] = (quant_max - nudge_zp) * (scale[i]);
}
}
void CalNudgePerChannel(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
float* nudge_min, float* nudge_max, float* scale, const int channel_num,
cudaStream_t cuda_stream) {
NudgeMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num);
}
/**
 * Calculate the fake quant output according to nudge_min, nudge_max and scale.
 * @param input - array
 * @param output - array
 * @param total_size - int, total element count, used to derive the per-channel element count
 * @param channel_size - int, number of channels
 * @param nudge_min - array
 * @param nudge_max - array
 * @param scale - array
 * @return
 */
__global__ void FakeQuantizePerChannel(const float* input, float* output, const int total_size, const int channel_size,
const float* nudge_min, const float* nudge_max, const float* scale,
bool symmetric) {
float input_x = 0.f;
int nudge_input = 0;
int channel_idx = 0;
int per_channel_num = total_size / channel_size;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total_size; i += blockDim.x * gridDim.x) {
input_x = input[i];
channel_idx = floor(static_cast<double>(i) / static_cast<double>(per_channel_num));
// clamp input x
if (input_x < nudge_min[channel_idx]) {
input_x = nudge_min[channel_idx];
}
if (input_x > nudge_max[channel_idx]) {
input_x = nudge_max[channel_idx];
}
// clamp shift
nudge_input = floor((input_x - nudge_min[channel_idx]) / scale[channel_idx] + 0.5f);
// quantize
output[i] = nudge_input * scale[channel_idx] + nudge_min[channel_idx];
}
}
void CalFakeQuantizePerChannel(const float* input, float* output, const int total_size, const int channel_size,
const float* nudge_min, const float* nudge_max, const float* scale, bool symmetric,
cudaStream_t cuda_stream) {
FakeQuantizePerChannel<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>(
input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric);
}
/**
 * Update input_min and input_max per channel, optionally smoothed with EMA.
 * @param input_min array of per-channel minimums, updated in place
 * @param input_max array of per-channel maximums, updated in place
 * @param input input data, laid out channel-major
 * @param channels number of channels
 * @param per_channel_nums element count per channel
 * @param ema whether to apply EMA smoothing
 * @param ema_decay EMA decay factor
 * @return
 */
__global__ void UpdateInputMinMaxPerChannel(float* input_min, float* input_max, float* input, int channels,
int per_channel_nums, bool ema, float ema_decay) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
thrust::pair<float*, float*> sum =
thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
if (ema) {
input_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i];
input_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i];
} else {
input_min[i] = sum.first[0];
input_max[i] = sum.second[0];
}
}
}
__global__ void UpdateInputMinMaxPerChannelWithEMA(float* input_min, float* input_max, float min, float max,
const float decay) {
*input_min = decay * (min) + (1 - decay) * (*input_min);
*input_max = decay * (max) + (1 - decay) * (*input_max);
}
void CalMinMaxPerChannel(float* input, float* input_min, float* input_max, const int total_size, const int channel_size,
const float ema_decay, const bool ema, cudaStream_t cuda_stream) {
int per_channel_num = total_size / channel_size;
UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(
input_min, input_max, input, channel_size, per_channel_num, ema, ema_decay);
}
__global__ void FakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output,
const int total_size, const int channel_size, const float* nudge_min,
const float* nudge_max) {
int channel_idx = 0;
int per_channel_num = total_size / channel_size;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total_size; i += blockDim.x * gridDim.x) {
channel_idx = floor(static_cast<double>(i) / static_cast<double>(per_channel_num));
if (input[i] < nudge_min[channel_idx] || input[i] > nudge_max[channel_idx]) {
output[i] = 0;
} else {
output[i] = gradient[i];
}
}
}
void CalFakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, const int total_num,
const int channel_num, const float* nudge_min, const float* nudge_max,
cudaStream_t cuda_stream) {
FakeQuantizePerChannelGrad<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
input, gradient, output, total_num, channel_num, nudge_min, nudge_max);
}
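For reference, a minimal NumPy sketch of the per-channel fake-quantize step in FakeQuantizePerChannel above, assuming the channel-major layout the kernel indexes by (the function name is illustrative).

import numpy as np

def np_fake_quantize_per_channel(x, nudge_min, nudge_max, scale):
    # x: (channel_size, per_channel_num); one nudge_min/nudge_max/scale entry per channel
    out = np.empty_like(x)
    for c in range(x.shape[0]):
        clamped = np.clip(x[c], nudge_min[c], nudge_max[c])
        out[c] = np.floor((clamped - nudge_min[c]) / scale[c] + 0.5) * scale[c] + nudge_min[c]
    return out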
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FAKEQUANTIZEPERCHANNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FAKEQUANTIZEPERCHANNEL_H_
#include "device/gpu/cuda_common.h"
void CalNudgePerChannel(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
float* nudge_min, float* nudge_max, float* scale, const int channel_num,
cudaStream_t cuda_stream);
void CalFakeQuantizePerChannel(const float* input, float* output, const int total_num, const int channel_num,
const float* nudge_min, const float* nudge_max, const float* scale, bool symmetric,
cudaStream_t cuda_stream);
void CalMinMaxPerChannel(float* input, float* input_min, float* input_max, const int total_num, const int channel_num,
const float ema_decay, const bool ema, cudaStream_t cuda_stream);
void CalFakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, const int total_num,
const int channel_num, const float* nudge_min, const float* nudge_max,
cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FAKEQUANTIZEPERCHANNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <stdint.h>
#include "sparse_cross_entropy_cuda_impl.cuh"
#include "include/cuda_runtime.h"
template <typename T>
__global__ void CalCrossEntropyKernel(const float *logits, T *labels, const int batch_size, const int class_num,
float *loss) {
float total_loss = 0.0;
float epsilon = 1e-6;
for (int i = 0; i < batch_size; ++i) {
float logit = logits[i * class_num + labels[i]];
if (logit <= 0) {
logit += epsilon;
}
float single_loss = -logf(logit);
total_loss += single_loss;
}
total_loss /= batch_size;
loss[0] = total_loss;
return;
}
template <typename T>
__global__ void CalCrossEntropyGradKernel(const float *logits, T *labels, const int batch_size, const int class_num,
float *grad) {
for (int i = 0; i < batch_size; i++) {
for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < class_num; j += blockDim.x * gridDim.x) {
if (labels[i] == j) {
grad[i * class_num + j] = (logits[i * class_num + j] - 1) / batch_size;
} else {
grad[i * class_num + j] = logits[i * class_num + j] / batch_size;
}
}
}
return;
}
template <typename T>
void CalCrossEntropy(const float *logits, T *labels, const int batch_size, const int class_num, float *loss,
cudaStream_t cuda_stream) {
CalCrossEntropyKernel<<<1, 1, 0, cuda_stream>>>(logits, labels, batch_size, class_num, loss);
return;
}
template <typename T>
void CalCrossEntropyGrad(const float *logits, T *labels, const int batch_size, const int class_num, float *grad,
cudaStream_t cuda_stream) {
CalCrossEntropyGradKernel<<<GET_BLOCKS(class_num), GET_THREADS, 0, cuda_stream>>>(logits, labels, batch_size,
class_num, grad);
return;
}
template void CalCrossEntropy<int>(const float *logits, int *labels, const int batch_size, const int class_num,
float *loss, cudaStream_t cuda_stream);
template void CalCrossEntropy<uint64_t>(const float *logits, uint64_t *labels, const int batch_size,
const int class_num, float *loss, cudaStream_t cuda_stream);
template void CalCrossEntropyGrad<int>(const float *logits, int *labels, const int batch_size, const int class_num,
float *grad, cudaStream_t cuda_stream);
template void CalCrossEntropyGrad<uint64_t>(const float *logits, uint64_t *labels, const int batch_size,
const int class_num, float *grad, cudaStream_t cuda_stream);
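For reference, a minimal NumPy sketch of the sparse cross-entropy loss and gradient computed by the kernels above (the function names are illustrative; logits are taken as probabilities, matching the kernel).

import numpy as np

def np_sparse_cross_entropy(logits, labels):
    # logits: (batch_size, class_num) probabilities; labels: (batch_size,) class indices
    picked = logits[np.arange(labels.shape[0]), labels]
    picked = np.where(picked <= 0, picked + 1e-6, picked)
    return -np.log(picked).mean()

def np_sparse_cross_entropy_grad(logits, labels):
    batch_size = labels.shape[0]
    grad = logits / batch_size
    grad[np.arange(batch_size), labels] -= 1.0 / batch_size
    return grad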
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_
#include "device/gpu/cuda_common.h"
template <typename T>
void CalCrossEntropy(const float *logits, T *labels, const int batch_size, const int class_num, float *loss,
cudaStream_t cuda_stream);
template <typename T>
void CalCrossEntropyGrad(const float *logits, T *labels, const int batch_size, const int class_num, float *grad,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/nn/dropout_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/dropout_impl.cuh"
namespace mindspore {
namespace kernel {
DropoutGpuFwdKernel::DropoutGpuFwdKernel()
: cudnn_handle_(nullptr),
is_null_input_(false),
num_count_(0),
drop_prob_(0.0),
states_init_(false),
mask_generator_(nullptr) {}
DropoutGpuFwdKernel::~DropoutGpuFwdKernel() { DestroyResource(); }
const std::vector<size_t> &DropoutGpuFwdKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &DropoutGpuFwdKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &DropoutGpuFwdKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool DropoutGpuFwdKernel::Init(const CNodePtr &kernel_node) {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but DropoutGpuFwdKernel needs 1.";
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
InitSizeLists();
return true;
}
num_count_ = 1;
for (size_t x : input_shape) {
num_count_ *= x;
}
drop_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("drop_prob"));
InitSizeLists();
return true;
}
void DropoutGpuFwdKernel::InitResource() {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
}
void DropoutGpuFwdKernel::DestroyResource() noexcept {}
void DropoutGpuFwdKernel::InitSizeLists() {
size_t input_size = num_count_ * sizeof(float);
size_t workspace_size = 0;
input_size_list_.push_back(input_size);
output_size_list_.push_back(input_size); // output size: the same with input size
output_size_list_.push_back(input_size); // mask size: the same with input size
workspace_size_list_.push_back(workspace_size);
}
bool DropoutGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
if (is_null_input_) {
return true;
}
auto *input = reinterpret_cast<float *>(inputs[0]->addr);
auto *output = reinterpret_cast<float *>(outputs[0]->addr);
auto *mask = reinterpret_cast<float *>(outputs[1]->addr);
if (!states_init_) {
curandCreateGenerator(&mask_generator_, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(mask_generator_, time(NULL));
states_init_ = true;
}
curandGenerateUniform(mask_generator_, mask, num_count_);
DropoutForward(input, mask, output, num_count_, drop_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "include/curand.h"
namespace mindspore {
namespace kernel {
class DropoutGpuFwdKernel : public GpuKernel {
public:
DropoutGpuFwdKernel();
~DropoutGpuFwdKernel() override;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override;
bool Init(const CNodePtr &kernel_node) override;
protected:
void InitResource() override;
void InitSizeLists() override;
private:
void DestroyResource() noexcept;
cudnnHandle_t cudnn_handle_;
bool is_null_input_;
size_t num_count_;
float drop_prob_;
bool states_init_;
curandGenerator_t mask_generator_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
MS_REG_GPU_KERNEL(Dropout, DropoutGpuFwdKernel)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GPU_KERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/nn/dropout_grad_kernel.h"
#include "kernel/gpu/cuda_impl/dropout_impl.cuh"
namespace mindspore {
namespace kernel {
DropoutGradGpuFwdKernel::DropoutGradGpuFwdKernel()
: cudnn_handle_(nullptr), is_null_input_(false), num_count_(0), drop_prob_(0.0) {}
DropoutGradGpuFwdKernel::~DropoutGradGpuFwdKernel() { DestroyResource(); }
const std::vector<size_t> &DropoutGradGpuFwdKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &DropoutGradGpuFwdKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &DropoutGradGpuFwdKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool DropoutGradGpuFwdKernel::Init(const CNodePtr &kernel_node) {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but DropoutGradGpuFwdKernel needs 2.";
return false;
}
auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
InitSizeLists();
return true;
}
num_count_ = 1;
for (size_t x : input_shape) {
num_count_ *= x;
}
drop_prob_ = GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("drop_prob"));
InitSizeLists();
return true;
}
void DropoutGradGpuFwdKernel::InitResource() {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
}
void DropoutGradGpuFwdKernel::DestroyResource() noexcept {}
void DropoutGradGpuFwdKernel::InitSizeLists() {
size_t dy_size = num_count_ * sizeof(float);
size_t mask_size = dy_size;
size_t dx_size = dy_size;
size_t workspace_size = 0;
input_size_list_.push_back(dy_size);
input_size_list_.push_back(mask_size);
output_size_list_.push_back(dx_size);
workspace_size_list_.push_back(workspace_size);
}
bool DropoutGradGpuFwdKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
if (is_null_input_) {
return true;
}
auto *dy = reinterpret_cast<float *>(inputs[0]->addr);
auto *mask = reinterpret_cast<float *>(inputs[1]->addr);
auto *dx = reinterpret_cast<float *>(outputs[0]->addr);
DropoutBackward(dy, mask, dx, num_count_, drop_prob_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class DropoutGradGpuFwdKernel : public GpuKernel {
public:
DropoutGradGpuFwdKernel();
~DropoutGradGpuFwdKernel() override;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override;
bool Init(const CNodePtr &kernel_node) override;
protected:
void InitResource() override;
void InitSizeLists() override;
private:
void DestroyResource() noexcept;
cudnnHandle_t cudnn_handle_;
bool is_null_input_;
size_t num_count_;
float drop_prob_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
MS_REG_GPU_KERNEL(DropoutGrad, DropoutGradGpuFwdKernel)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_DROPOUT_GRAD_KERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/batchnorm_fold2_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(BatchNormFold2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
BatchNormFold2GpuKernel, float)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class BatchNormFold2GpuKernel : public GpuKernel {
public:
BatchNormFold2GpuKernel()
: cudnn_handle_(nullptr),
is_null_input_(false),
batch_size_(0),
channel_(0),
height_(0),
width_(0),
freeze_bn_(0) {}
~BatchNormFold2GpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
if (is_null_input_) {
return true;
}
auto *input = GetDeviceAddress<T>(inputs, 0);
auto *beta = GetDeviceAddress<T>(inputs, 1);
auto *gamma = GetDeviceAddress<T>(inputs, 2);
auto *batch_std = GetDeviceAddress<T>(inputs, 3);
auto *batch_mean = GetDeviceAddress<T>(inputs, 4);
auto *running_std = GetDeviceAddress<T>(inputs, 5);
auto *running_mean = GetDeviceAddress<T>(inputs, 6);
auto *global_step = GetDeviceAddress<int32_t>(inputs, 7);
auto *output = GetDeviceAddress<T>(outputs, 0);
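// BatchNormFold2Forward applies the folded batch-norm correction to the input;
// whether batch or running statistics take effect is decided inside the CUDA
// kernel by comparing *global_step against freeze_bn_.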
BatchNormFold2Forward(input, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, output,
freeze_bn_, batch_size_, channel_, height_, width_,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 8) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GpuKernel needs 8.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "BatchNormFold2GpuKernel input is null";
InitSizeLists();
return true;
}
if (input_shape.size() != 4) {
MS_LOG(ERROR) << "BatchNormFold2GpuKernel input shape needs (N,C,H,W).";
return false;
}
batch_size_ = input_shape[0];
channel_ = input_shape[1];
height_ = input_shape[2];
width_ = input_shape[3];
freeze_bn_ = GetValue<int32_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn"));
InitSizeLists();
return true;
}
protected:
void InitResource() { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
void InitSizeLists() {
size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
size_t weight_size = channel_ * sizeof(T);
input_size_list_.push_back(input_size);
input_size_list_.push_back(weight_size); // beta
input_size_list_.push_back(weight_size); // gamma
input_size_list_.push_back(weight_size); // batch_std
input_size_list_.push_back(weight_size); // batch_mean
input_size_list_.push_back(weight_size); // running_std
input_size_list_.push_back(weight_size); // running_mean
input_size_list_.push_back(sizeof(int32_t)); // global_step
output_size_list_.push_back(input_size);
size_t workspace_size = 0;
workspace_size_list_.push_back(workspace_size);
}
private:
void DestroyResource() noexcept {}
cudnnHandle_t cudnn_handle_;
bool is_null_input_;
size_t batch_size_;
size_t channel_;
size_t height_;
size_t width_;
size_t freeze_bn_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GPU_KERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/batchnorm_fold2_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(BatchNormFold2Grad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
BatchNormFold2GradGpuKernel, float)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/batchnorm_fold2_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class BatchNormFold2GradGpuKernel : public GpuKernel {
public:
BatchNormFold2GradGpuKernel()
: cudnn_handle_(nullptr),
is_null_input_(false),
batch_size_(0),
channel_(0),
height_(0),
width_(0),
freeze_bn_(0) {}
~BatchNormFold2GradGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
if (is_null_input_) {
return true;
}
auto *dout = GetDeviceAddress<T>(inputs, 0);
auto *x = GetDeviceAddress<T>(inputs, 1);
auto *gamma = GetDeviceAddress<T>(inputs, 2);
auto *batch_std = GetDeviceAddress<T>(inputs, 3);
auto *batch_mean = GetDeviceAddress<T>(inputs, 4);
auto *running_std = GetDeviceAddress<T>(inputs, 5);
auto *running_mean = GetDeviceAddress<T>(inputs, 6);
auto *global_step = GetDeviceAddress<int32_t>(inputs, 7);
auto *d_batch_std = GetDeviceAddress<T>(outputs, 0);
auto *d_batch_mean = GetDeviceAddress<T>(outputs, 1);
auto *d_beta = GetDeviceAddress<T>(outputs, 2);
auto *d_gamma = GetDeviceAddress<T>(outputs, 3);
auto *d_x = GetDeviceAddress<T>(outputs, 4);
auto *tmp = GetDeviceAddress<T>(workspace, 0);
auto *tmp2 = GetDeviceAddress<T>(workspace, 1);
auto *reduce_x = GetDeviceAddress<T>(workspace, 2);
auto *tmp_x = GetDeviceAddress<T>(workspace, 3);
int32_t current_step_host[1];
size_t x_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, global_step, sizeof(int32_t), cudaMemcpyDeviceToHost),
"Failed to copy gpu memory.");
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(d_x, dout, x_size, cudaMemcpyDeviceToDevice), "Failed to copy gpu memory.");
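// First accumulate the per-channel reductions of dout (d_beta) and dout * x
// (reduce_x) needed by both branches, then pick the not-frozen or frozen backward
// path depending on whether global_step is still below freeze_bn_.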
BatchNormFold2GradReduce(dout, x, d_beta, tmp, reduce_x, tmp2, tmp_x, batch_size_, channel_, height_, width_,
reinterpret_cast<cudaStream_t>(stream_ptr));
if (current_step_host[0] < freeze_bn_) {
CalBatchNormFold2GradNotFreezeDxMul(batch_std, running_std, d_x, batch_size_, channel_, height_, width_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalBatchNormFold2GradNotFreeze(d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma,
d_batch_mean, d_batch_std, channel_, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CalBatchNormFold2GradFreeze(d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma,
d_batch_mean, d_batch_std, channel_, reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}
bool Init(const CNodePtr &kernel_node) {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 8) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but BatchNormFold2GradGpuKernel needs 8.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape);
if (is_null_input_) {
MS_LOG(WARNING) << "BatchNormFold2GradGpuKernel input is null";
InitSizeLists();
return true;
}
if (input_shape.size() != 4) {
MS_LOG(ERROR) << "BatchNormFold2GradGpuKernel input shape needs (N,C,H,W).";
return false;
}
batch_size_ = input_shape[0];
channel_ = input_shape[1];
height_ = input_shape[2];
width_ = input_shape[3];
freeze_bn_ = GetValue<int32_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn"));
InitSizeLists();
return true;
}
protected:
void InitResource() { cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); }
void InitSizeLists() {
size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
size_t weight_size = channel_ * sizeof(T);
size_t workspace_size = batch_size_ * channel_ * sizeof(T);
input_size_list_.push_back(input_size); // dout
input_size_list_.push_back(input_size); // x
input_size_list_.push_back(weight_size); // gamma
input_size_list_.push_back(weight_size); // batch_std
input_size_list_.push_back(weight_size); // batch_mean
input_size_list_.push_back(weight_size); // running_std
input_size_list_.push_back(weight_size); // running_mean
input_size_list_.push_back(sizeof(int32_t)); // global_step
output_size_list_.push_back(weight_size); // d_batch_std
output_size_list_.push_back(weight_size); // d_batch_mean
output_size_list_.push_back(weight_size); // d_beta
output_size_list_.push_back(weight_size); // d_gamma
output_size_list_.push_back(input_size); // d_x
workspace_size_list_.push_back(workspace_size); // tmp
workspace_size_list_.push_back(workspace_size); // tmp2
workspace_size_list_.push_back(weight_size); // reduce_x
workspace_size_list_.push_back(input_size); // tmp_x
}
private:
void DestroyResource() noexcept {}
cudnnHandle_t cudnn_handle_;
bool is_null_input_;
size_t batch_size_;
size_t channel_;
size_t height_;
size_t width_;
int32_t freeze_bn_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_BATCHNORMFOLD2_GRAD_GPU_KERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/batchnorm_fold_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(BatchNormFold,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
BatchNormFoldGpuKernel, float)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/kernel_constants.h"
#include "kernel/gpu/cuda_impl/batchnorm_fold_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class BatchNormFoldGpuKernel : public GpuKernel {
public:
BatchNormFoldGpuKernel()
: input_size_(0),
output_size_(0),
exp_avg_factor_(0.9),
epsilon_(1e-12),
is_training_(true),
freeze_bn_(0),
batch_(0),
channel_(0),
height_(0),
width_(0),
mode_(CUDNN_BATCHNORM_SPATIAL),
x_desc_(nullptr),
scale_bias_mean_var_desc_(nullptr),
handle_(nullptr) {}
~BatchNormFoldGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
(void)workspace;
auto x = reinterpret_cast<T *>(inputs[0]->addr);
auto mean = reinterpret_cast<T *>(inputs[1]->addr);
auto variance = reinterpret_cast<T *>(inputs[2]->addr);
int *current_step = reinterpret_cast<int *>(inputs[3]->addr);
int current_step_host[1];
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost),
"Copy gpu memoy failed.");
if (x == nullptr) {
MS_LOG(ERROR) << "BatchNormFoldGpuKernel x is null.";
return false;
}
if (mean == nullptr) {
MS_LOG(ERROR) << "BatchNormFoldGpuKernel mean is null.";
return false;
}
if (variance == nullptr) {
MS_LOG(ERROR) << "BatchNormFoldGpuKernel variance is null.";
return false;
}
if (current_step == nullptr) {
MS_LOG(ERROR) << "BatchNormFoldGpuKernel current_step is null.";
return false;
}
auto batch_mean = reinterpret_cast<T *>(outputs[0]->addr);
auto batch_std = reinterpret_cast<T *>(outputs[1]->addr);
auto running_mean = reinterpret_cast<T *>(outputs[2]->addr);
auto running_std = reinterpret_cast<T *>(outputs[3]->addr);
auto y = reinterpret_cast<T *>(workspace[0]->addr);
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(running_mean, mean, output_size_, cudaMemcpyDeviceToDevice),
"Failed to copy gpu memory.");
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(running_std, variance, output_size_, cudaMemcpyDeviceToDevice),
"Failed to copy gpu memory.");
CalUpdateRunningStd(channel_, epsilon_, running_std, reinterpret_cast<cudaStream_t>(stream_ptr));
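// When not training, or once global_step has reached freeze_bn_, batch statistics
// are pinned to mean 0 / std 1 so the fold depends only on the running statistics.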
if (!is_training_ || current_step_host[0] >= freeze_bn_) {
CHECK_CUDA_RET_WITH_ERROR(cudaMemset(batch_mean, 0, output_size_), "Failed to set gpu memory.");
ThrustFillWith(batch_std, channel_, 1.f, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
const T alpha = 1;
const T beta = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationForwardTraining(
handle_, mode_, &alpha, &beta, x_desc_, x, x_desc_, y, scale_bias_mean_var_desc_,
mean, mean, exp_avg_factor_, mean, variance, epsilon_, batch_mean, batch_std),
"Failed to launch kernel.")
CalUpdateBatchStd(channel_, batch_std, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) {
MS_LOG(ERROR) << "Input number is " << input_num << " but BatchNormFold GpuKernel OP needs 4 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 4) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but BatchNormFold GpuKernel OP needs 4 output.";
return false;
}
T momentum = GetValue<T>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("momentum"));
exp_avg_factor_ = 1.0 - momentum;
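// cuDNN's exponentialAverageFactor weights the new batch statistics, so a momentum
// that weights the old running value maps to 1 - momentum here.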
epsilon_ = GetValue<T>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("epsilon"));
is_training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("is_training"));
freeze_bn_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn"));
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() != 4) {
MS_LOG(ERROR) << "Input shape is " << input_shape.size()
<< ", but BatchNormFold GpuKernel OP needs 4DTensor input.";
return false;
}
batch_ = input_shape[0];
channel_ = input_shape[1];
height_ = input_shape[2];
width_ = input_shape[3];
input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_;
output_size_ = sizeof(T) * channel_;
cudnnDataType_t cudnnDataType = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(x_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, batch_, channel_, height_, width_),
"Set x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(scale_bias_mean_var_desc_, CUDNN_TENSOR_NCHW, cudnnDataType, 1, channel_, 1, 1),
"Set para desc failed");
InitSizeLists();
return true;
}
protected:
void InitSizeLists() {
// x, mean, variance, current_step
input_size_list_.push_back(input_size_);
input_size_list_.push_back(output_size_);
input_size_list_.push_back(output_size_);
input_size_list_.push_back(sizeof(int));
// batch_mean, batch_std, running_mean, running_std
output_size_list_.push_back(output_size_);
output_size_list_.push_back(output_size_);
output_size_list_.push_back(output_size_);
output_size_list_.push_back(output_size_);
// store y
workspace_size_list_.push_back(input_size_);
}
void InitResource() {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_mean_var_desc_), "Create para desc failed");
}
private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_mean_var_desc_), "Destroy para desc failed");
}
size_t input_size_;
size_t output_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
double exp_avg_factor_;
double epsilon_;
bool is_training_;
int freeze_bn_;
int batch_;
int channel_;
int height_;
int width_;
cudnnBatchNormMode_t mode_;
cudnnTensorDescriptor_t x_desc_;
cudnnTensorDescriptor_t scale_bias_mean_var_desc_;
cudnnHandle_t handle_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GPUKERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/batchnorm_fold_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(BatchNormFoldGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32),
BatchNormFoldGradGpuKernel, float)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/batchnorm_fold_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class BatchNormFoldGradGpuKernel : public GpuKernel {
public:
BatchNormFoldGradGpuKernel()
: input_size_(0),
channel_size_(0),
workspace_size_(0),
momentum_(0.1),
epsilon_(1e-12),
is_training_(true),
freeze_bn_(0),
current_step_(0),
batch_(0),
channel_(0),
height_(0),
width_(0) {}
~BatchNormFoldGradGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
(void)workspace;
// 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step'
T *d_batch_mean = GetDeviceAddress<T>(inputs, 0);
T *d_batch_std = GetDeviceAddress<T>(inputs, 1);
T *x = GetDeviceAddress<T>(inputs, 2);
T *batch_mean = GetDeviceAddress<T>(inputs, 3);
T *batch_std = GetDeviceAddress<T>(inputs, 4);
int *current_step = GetDeviceAddress<int>(inputs, 5);
int current_step_host[1];
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(current_step_host, current_step, sizeof(int), cudaMemcpyDeviceToHost),
"Copy gpu memoy failed.");
if (d_batch_mean == nullptr) {
MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_mean is null.";
return false;
}
if (d_batch_std == nullptr) {
MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel d_batch_std is null.";
return false;
}
if (x == nullptr) {
MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel x is null.";
return false;
}
if (batch_mean == nullptr) {
MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel batch_mean is null.";
return false;
}
if (batch_std == nullptr) {
MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel batch_std is null.";
return false;
}
if (current_step == nullptr) {
MS_LOG(ERROR) << "BatchNormFoldGradGpuKernel current_step is null.";
return false;
}
T *dx = reinterpret_cast<T *>(outputs[0]->addr);
if (!is_training_ || current_step_host[0] >= freeze_bn_) {
ThrustFillWith(dx, batch_ * channel_ * height_ * width_, 0.f, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
CalBatchNormFoldGrad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, batch_, channel_, height_, width_, dx,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 6) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but BatchNormFoldGrad GpuKernel OP needs 6 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but BatchNormFoldGrad GpuKernel OP needs 4 output.";
return false;
}
epsilon_ = GetValue<T>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("epsilon"));
is_training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("is_training"));
freeze_bn_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("freeze_bn"));
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
if (input_shape.size() != 4) {
MS_LOG(ERROR) << "Input shape is " << input_shape.size()
<< ", but BatchNormFoldGrad GpuKernel OP needs 4DTensor input.";
return false;
}
batch_ = input_shape[0];
channel_ = input_shape[1];
height_ = input_shape[2];
width_ = input_shape[3];
input_size_ = sizeof(T) * batch_ * channel_ * height_ * width_;
channel_size_ = sizeof(T) * channel_;
InitSizeLists();
return true;
}
protected:
void InitSizeLists() {
// 'd_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'current_step'
input_size_list_.push_back(channel_size_);
input_size_list_.push_back(channel_size_);
input_size_list_.push_back(input_size_);
input_size_list_.push_back(channel_size_);
input_size_list_.push_back(channel_size_);
input_size_list_.push_back(sizeof(int));
// 'dx'
output_size_list_.push_back(input_size_);
workspace_size_list_.push_back(workspace_size_);
}
private:
size_t input_size_;
size_t channel_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
T momentum_;
T epsilon_;
bool is_training_;
int freeze_bn_;
int current_step_;
int batch_;
int channel_;
int height_;
int width_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_BATCHNORM_FOLD_GRAD_GPUKERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/correction_mul_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(CorrectionMul,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CorrectionMulGpuKernel, float)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/correction_mul_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class CorrectionMulGpuKernel : public GpuKernel {
public:
CorrectionMulGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {}
~CorrectionMulGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
auto *weight = GetDeviceAddress<T>(inputs, 0);
auto *gamma = GetDeviceAddress<T>(inputs, 1);
auto *running_std = GetDeviceAddress<T>(inputs, 2);
auto *output = GetDeviceAddress<T>(outputs, 0);
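// CalCorrectionMul is expected to scale each weight element by the per-channel
// batch-norm correction factor gamma / running_std.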
CalCorrectionMul(weight, gamma, running_std, batch_size_, channel_, height_, width_, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but CorrectionMulGpuKernel needs 3.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() != 4) {
MS_LOG(ERROR) << "CorrectionMulGpuKernel input shape needs (N,C,H,W).";
return false;
}
batch_size_ = input_shape[0];
channel_ = input_shape[1];
height_ = input_shape[2];
width_ = input_shape[3];
InitSizeLists();
return true;
}
protected:
void InitSizeLists() {
size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
size_t weight_size = batch_size_ * sizeof(T);
input_size_list_.push_back(input_size); // weight
input_size_list_.push_back(weight_size); // gamma
input_size_list_.push_back(weight_size); // running_std
size_t workspace_size = 0;
output_size_list_.push_back(input_size);
workspace_size_list_.push_back(workspace_size);
}
void InitResource() {}
private:
void DestroyResource() noexcept {}
size_t batch_size_;
size_t channel_;
size_t height_;
size_t width_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMUL_GPUKERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/correction_mul_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/correction_mul_impl.cuh"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(CorrectionMulGrad,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CorrectionMulGradGpuKernel, float)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/correction_mul_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class CorrectionMulGradGpuKernel : public GpuKernel {
public:
CorrectionMulGradGpuKernel() : batch_size_(0), channel_(0), height_(0), width_(0) {}
~CorrectionMulGradGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
auto *d_out = GetDeviceAddress<T>(inputs, 0);
auto *weight = GetDeviceAddress<T>(inputs, 1);
auto *gamma = GetDeviceAddress<T>(inputs, 2);
auto *running_std = GetDeviceAddress<T>(inputs, 3);
auto *d_weight = GetDeviceAddress<T>(outputs, 0);
auto *d_gamma = GetDeviceAddress<T>(outputs, 1);
auto *tmp = GetDeviceAddress<T>(workspace, 0);
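// Assumed backward rule for output = weight * gamma / running_std: d_weight reuses
// the forward scaling on d_out, while d_gamma is a per-channel reduction of
// d_out * weight / running_std staged through the tmp workspace buffer.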
CalCorrectionMul(d_out, gamma, running_std, batch_size_, channel_, height_, width_, d_weight,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalCorrectionMulGrad(d_out, weight, running_std, batch_size_, channel_, height_, width_, d_gamma, tmp,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) {
InitResource();
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) {
MS_LOG(ERROR) << "Argument number is " << input_num << ", but CorrectionMulGradGpuKernel needs 4.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
if (input_shape.size() != 4) {
MS_LOG(ERROR) << "CorrectionMulGradGpuKernel input shape needs (N,C,H,W).";
return false;
}
batch_size_ = input_shape[0];
channel_ = input_shape[1];
height_ = input_shape[2];
width_ = input_shape[3];
InitSizeLists();
return true;
}
protected:
void InitSizeLists() {
size_t input_size = batch_size_ * channel_ * height_ * width_ * sizeof(T);
size_t weight_size = batch_size_ * sizeof(T);
input_size_list_.push_back(input_size); // d_out
input_size_list_.push_back(input_size); // weight
input_size_list_.push_back(weight_size); // gamma
input_size_list_.push_back(weight_size); // running_std
output_size_list_.push_back(input_size); // d_weight
output_size_list_.push_back(weight_size); // d_gamma
workspace_size_list_.push_back(input_size); // tmp d_out * weight
}
void InitResource() {}
private:
void DestroyResource() noexcept {}
size_t batch_size_;
size_t channel_;
size_t height_;
size_t width_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CORRECTIONMULGRAD_GPUKERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
#include <cuda_runtime_api.h>
namespace mindspore {
namespace kernel {
FakeQuantGpuKernel::FakeQuantGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_num_(0),
quant_delay_(0),
ema_(false),
ema_decay_(0),
global_step_(0),
training_(false),
narrow_range_(false),
symmetric_(false) {}
const std::vector<size_t> &FakeQuantGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 output.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuant GpuKernel OP needs 1 output.";
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = 1.0 - GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << "is less then 0, require larger than 0.";
}
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
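// Derive the integer quantization range from num_bits: symmetric mode centres the
// range on zero, asymmetric mode uses [0, 2^num_bits - 1]; narrow_range then drops
// the lowest code point.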
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (narrow_range_) {
quant_min_++;
}
if (quant_num_ == 0) {
quant_num_ = 1;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); ++i) {
quant_num_ *= SizeToInt(input_shape[i]);
}
input_size_ = sizeof(float);
min_size_ = sizeof(float);
max_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(min_size_); // min
input_size_list_.push_back(max_size_); // max
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(workspace_size_);
}
bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
(void)workspace;
float *output = GetDeviceAddress<float>(outputs, 0);
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input x is null.";
}
if (input_min == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input min is null.";
}
if (input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGpuKernel input max is null.";
}
// Allocate space for device copies
int size = sizeof(float);
float *d_scale = nullptr;
float *d_nudge_min = nullptr;
float *d_nudge_max = nullptr;
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed");
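// Pipeline: (training only) refresh input_min/input_max, optionally smoothed by EMA;
// nudge them onto the integer grid to obtain d_nudge_min/d_nudge_max and d_scale;
// then run the quantize-dequantize pass. Before quant_delay steps the input is
// simply copied through.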
if (training_) {
// Calculate the input min and max according to the ema and ema_decay parameters.
CalMinMax(input, input_min, input_max, quant_num_, ema_decay_, ema_, reinterpret_cast<cudaStream_t>(stream_ptr));
// control flow for quant_delay
if (global_step_ >= quant_delay_) {
// real launch
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice),
"Copy gpu memory failed");
}
global_step_++;
} else {
// real launch
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantize(input, output, quant_num_, d_nudge_min, d_nudge_max, d_scale, symmetric_,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
// Cleanup
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMax, FakeQuantGpuKernel)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class FakeQuantGpuKernel : public GpuKernel {
public:
FakeQuantGpuKernel();
~FakeQuantGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override;
bool Init(const CNodePtr &kernel) override;
protected:
void InitSizeLists() override;
private:
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_num_;
int quant_delay_;
bool ema_;
float ema_decay_;
int global_step_;
bool training_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
namespace mindspore {
namespace kernel {
FakeQuantGradGpuKernel::FakeQuantGradGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_size_(0),
quant_delay_(0),
global_step_(0) {}
const std::vector<size_t> &FakeQuantGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantGradGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output.";
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0.";
}
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
if (quant_size_ == 0) {
quant_size_ = 1;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < input_shape.size(); ++i) {
quant_size_ *= SizeToInt(input_shape[i]);
}
input_size_ = sizeof(float);
min_size_ = sizeof(float);
max_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantGradGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(min_size_); // min
input_size_list_.push_back(max_size_); // max
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(workspace_size_);
}
bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
(void)workspace;
float *output = GetDeviceAddress<float>(outputs, 0);
float *gradient = GetDeviceAddress<float>(inputs, 0);
float *input = GetDeviceAddress<float>(inputs, 1);
float *input_min = GetDeviceAddress<float>(inputs, 2);
float *input_max = GetDeviceAddress<float>(inputs, 3);
if (gradient == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel gradient is null";
}
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input is null.";
}
if (input_min == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input min is null.";
}
if (input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantGradGpuKernel input max is null.";
}
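// Straight-through estimator: after quant_delay steps the incoming gradient is
// passed through only where the input falls inside the nudged [min, max] window
// (CalFakeQuantizeGrad is assumed to zero it elsewhere); before that the gradient
// is copied through unchanged.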
if (global_step_ >= quant_delay_) {
float *d_scale = nullptr;
float *d_nudge_min = nullptr;
float *d_nudge_max = nullptr;
int size = sizeof(float);
// Allocate space for device copies
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), size), "Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), size), "Malloc gpu memory failed");
CalNudge(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizeGrad(input, gradient, output, quant_size_, d_nudge_min, d_nudge_max,
reinterpret_cast<cudaStream_t>(stream_ptr));
// Cleanup
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
} else {
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, gradient, input_size_, cudaMemcpyDeviceToDevice),
"Copy gpu memory failed.");
}
global_step_++;
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxGrad, FakeQuantGradGpuKernel)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class FakeQuantGradGpuKernel : public GpuKernel {
public:
FakeQuantGradGpuKernel();
~FakeQuantGradGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override;
bool Init(const CNodePtr &kernel_node) override;
protected:
void InitSizeLists() override;
private:
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_size_;
int quant_delay_;
int global_step_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
#include <cuda_runtime_api.h>
namespace mindspore {
namespace kernel {
FakeQuantPerChannelGpuKernel::FakeQuantPerChannelGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
quant_delay_(0),
ema_(false),
ema_decay_(0),
global_step_(0),
training_(false),
channel_out_(0),
narrow_range_(false),
symmetric_(false) {}
const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantPerChannelGpuKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 3) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuant GpuKernel OP needs 3 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << " but FakeQuant GpuKernel OP needs 1 output.";
return false;
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
ema_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema"));
ema_decay_ = 1.0 - GetValue<float>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ema_decay"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << "is out of range, expected between 2 and 16.";
return false;
}
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay\' " << num_bits_ << " is less then 0, require larger than 0.";
return false;
}
training_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("training"));
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (narrow_range_) {
quant_min_++;
}
// shape info for gpu
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
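// Per-channel quantization: one min/max pair (and one scale) per output channel along axis 0.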
channel_out_ = SizeToInt(input_shape[0]);
min_size_ = sizeof(float) * channel_out_;
max_size_ = sizeof(float) * channel_out_;
input_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantPerChannelGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(min_size_); // min
input_size_list_.push_back(max_size_); // max
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(workspace_size_);
}
bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
(void)workspace;
float *output = GetDeviceAddress<float>(outputs, 0);
float *input = GetDeviceAddress<float>(inputs, 0);
float *input_min = GetDeviceAddress<float>(inputs, 1);
float *input_max = GetDeviceAddress<float>(inputs, 2);
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input is null.";
}
if (input_min == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input min is null.";
}
if (input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGpuKernel input max is null.";
}
// Allocate space for device copies
float *d_scale = nullptr;
float *d_nudge_min = nullptr;
float *d_nudge_max = nullptr;
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), sizeof(float) * channel_out_),
"Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), sizeof(float) * channel_out_),
"Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), sizeof(float) * channel_out_),
"Malloc gpu memory failed");
int total_size = input_size_ / sizeof(float);
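// The symmetric flag passed to the kernel launch is fixed to false here; the 'symmetric' attribute only affects the quantized range computed in Init().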
bool symmetric = false;
if (training_) {
// Calculate the input min and max according to the parameters ema and ema_decay.
CalMinMaxPerChannel(input, input_min, input_max, total_size, channel_out_, ema_decay_, ema_,
reinterpret_cast<cudaStream_t>(stream_ptr));
// control flow for quant_delay
if (global_step_ >= quant_delay_) {
// real launch
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannel(input, output, total_size, channel_out_, d_nudge_min, d_nudge_max, d_scale, symmetric,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
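// Before quant_delay steps have elapsed, pass the input through unchanged.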
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, input, input_size_, cudaMemcpyDeviceToDevice),
"Copy gpu memory failed.");
}
global_step_++;
} else {
// real launch
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannel(input, output, total_size, channel_out_, d_nudge_min, d_nudge_max, d_scale, symmetric,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
// Cleanup
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannel, FakeQuantPerChannelGpuKernel)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class FakeQuantPerChannelGpuKernel : public GpuKernel {
public:
FakeQuantPerChannelGpuKernel();
~FakeQuantPerChannelGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override;
bool Init(const CNodePtr &kernel_node) override;
protected:
void InitSizeLists() override;
private:
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int quant_delay_;
bool ema_;
float ema_decay_;
int global_step_;
bool training_;
int channel_out_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GPUKERNEL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_per_channel_impl.cuh"
namespace mindspore {
namespace kernel {
FakeQuantPerChannelGradGpuKernel::FakeQuantPerChannelGradGpuKernel()
: input_size_(0),
min_size_(0),
max_size_(0),
output_size_(0),
workspace_size_(0),
num_bits_(0),
quant_min_(0),
quant_max_(0),
channel_out_(0),
quant_delay_(0),
global_step_(0),
narrow_range_(false),
symmetric_(false) {}
const std::vector<size_t> &FakeQuantPerChannelGradGpuKernel::GetInputSizeList() const { return input_size_list_; }
const std::vector<size_t> &FakeQuantPerChannelGradGpuKernel::GetOutputSizeList() const { return output_size_list_; }
const std::vector<size_t> &FakeQuantPerChannelGradGpuKernel::GetWorkspaceSizeList() const {
return workspace_size_list_;
}
bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 4) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but FakeQuantGrad GpuKernel OP needs 4 output.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but FakeQuantGrad GpuKernel OP needs 1 output.";
}
num_bits_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("num_bits"));
if (num_bits_ <= 2 || num_bits_ >= 16) {
MS_LOG(EXCEPTION) << "Attr \'num_bits\' " << num_bits_ << " is out of range, expected between 2 and 16.";
}
quant_delay_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("quant_delay"));
if (quant_delay_ < 0) {
MS_LOG(EXCEPTION) << "Attr \'quant_delay_\' " << quant_delay_ << " is less then 0, require larger than 0.";
}
symmetric_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("symmetric"));
if (symmetric_) {
quant_min_ = 0 - (1 << (num_bits_ - 1));
quant_max_ = (1 << (num_bits_ - 1)) - 1;
} else {
quant_min_ = 0;
quant_max_ = (1 << num_bits_) - 1;
}
narrow_range_ = GetValue<bool>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("narrow_range"));
if (narrow_range_) {
quant_min_++;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
channel_out_ = SizeToInt(input_shape[0]);
min_size_ = sizeof(float) * channel_out_;
max_size_ = sizeof(float) * channel_out_;
input_size_ = sizeof(float);
for (size_t i = 0; i < input_shape.size(); i++) {
input_size_ *= input_shape[i];
}
output_size_ = input_size_;
InitSizeLists();
return true;
}
void FakeQuantPerChannelGradGpuKernel::InitSizeLists() {
input_size_list_.push_back(input_size_); // gradient
input_size_list_.push_back(input_size_); // input
input_size_list_.push_back(min_size_); // min
input_size_list_.push_back(max_size_); // max
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(workspace_size_);
}
bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
(void)workspace;
float *output = GetDeviceAddress<float>(outputs, 0);
float *gradient = GetDeviceAddress<float>(inputs, 0);
float *input = GetDeviceAddress<float>(inputs, 1);
float *input_min = GetDeviceAddress<float>(inputs, 2);
float *input_max = GetDeviceAddress<float>(inputs, 3);
if (gradient == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel gradient is null";
}
if (input == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input is null";
}
if (input_min == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input min is null";
}
if (input_max == nullptr) {
MS_LOG(EXCEPTION) << "FakeQuantPerChannelGradGpuKernel input max is null";
}
int total_size = input_size_ / sizeof(float);
if (global_step_ >= quant_delay_) {
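// quant_delay reached: nudge min/max per channel, then compute the fake-quantize gradient.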
float *d_scale = nullptr;
float *d_nudge_min = nullptr;
float *d_nudge_max = nullptr;
// Allocate space for device copies
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_scale), channel_out_ * sizeof(float)),
"Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_min), channel_out_ * sizeof(float)),
"Malloc gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaMalloc(reinterpret_cast<void **>(&d_nudge_max), channel_out_ * sizeof(float)),
"Malloc gpu memory failed");
CalNudgePerChannel(input_min, input_max, quant_min_, quant_max_, d_nudge_min, d_nudge_max, d_scale, channel_out_,
reinterpret_cast<cudaStream_t>(stream_ptr));
CalFakeQuantizePerChannelGrad(input, gradient, output, total_size, channel_out_, d_nudge_min, d_nudge_max,
reinterpret_cast<cudaStream_t>(stream_ptr));
// Cleanup
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_scale), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_min), "Free gpu memory failed");
CHECK_CUDA_RET_WITH_ERROR(cudaFree(d_nudge_max), "Free gpu memory failed");
} else {
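// Before quant_delay steps have elapsed, pass the gradient through unchanged.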
CHECK_CUDA_RET_WITH_ERROR(cudaMemcpy(output, gradient, input_size_, cudaMemcpyDeviceToDevice),
"Copy gpu memory failed.");
}
global_step_++;
return true;
}
MS_REG_GPU_KERNEL(FakeQuantWithMinMaxPerChannelGrad, FakeQuantPerChannelGradGpuKernel)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
public:
FakeQuantPerChannelGradGpuKernel();
~FakeQuantPerChannelGradGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override;
const std::vector<size_t> &GetOutputSizeList() const override;
const std::vector<size_t> &GetWorkspaceSizeList() const override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override;
bool Init(const CNodePtr &kernel_node) override;
protected:
void InitSizeLists() override;
private:
size_t input_size_;
size_t min_size_;
size_t max_size_;
size_t output_size_;
size_t workspace_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
int num_bits_;
float quant_min_;
float quant_max_;
int channel_out_;
int quant_delay_;
int global_step_;
bool narrow_range_;
bool symmetric_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_PER_CHANNEL_GRAD_GPUKERNEL_H_
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HSigmoid op"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "HSigmoid",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
def _hsigmoid_akg():
"""HSigmoid AutoDiff register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HSigmoidGrad op"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "HSigmoidGrad",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "y_grad"
},
{
"index": 1,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
def _hsigmoid_grad_akg():
"""HSigmoidGrad AutoDiff register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HSwish op"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "HSwish",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
def _hswish_akg():
"""HSwish AutoDiff register"""
return
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HSwishGrad op"""
from mindspore.ops.op_info_register import op_info_register
@op_info_register("""{
"op_name": "HSwishGrad",
"imply_type": "AutoDiff",
"fusion_type": "OPAQUE",
"processor": "cuda",
"attr": [
],
"inputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "y_grad"
},
{
"index": 1,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "x"
}
],
"outputs": [
{
"index": 0,
"dtype": [
"float32", "float16"
],
"format": [
"DefaultFormat", "DefaultFormat"
],
"name": "output"
}
]
}""")
def _hswish_grad_akg():
"""HSwishGrad AutoDiff register"""
return