提交 19d50fea 编写于 作者: L lizhenyu

add FusedBatchNormGradEx gpu kernel

上级 98565d8b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/fused_batch_norm_grad_ex_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // dy
.AddInputAttr(kNumberTypeFloat32) // x
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // save_mean
.AddInputAttr(kNumberTypeFloat32) // save_variance
.AddInputAttr(kNumberTypeFloat32) // reserve
.AddOutputAttr(kNumberTypeFloat32) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
FusedBatchNormGradExGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradEx,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16) // dy
.AddInputAttr(kNumberTypeFloat16) // x
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // save_mean
.AddInputAttr(kNumberTypeFloat32) // save_variance
.AddInputAttr(kNumberTypeFloat32) // reserve
.AddOutputAttr(kNumberTypeFloat16) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
FusedBatchNormGradExGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // dy
.AddInputAttr(kNumberTypeFloat32) // x
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // save_mean
.AddInputAttr(kNumberTypeFloat32) // save_variance
.AddInputAttr(kNumberTypeFloat32) // reserve
.AddInputAttr(kNumberTypeFloat32) // b
.AddInputAttr(kNumberTypeFloat32) // y
.AddOutputAttr(kNumberTypeFloat32) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
FusedBatchNormGradExGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithActivation,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16) // dy
.AddInputAttr(kNumberTypeFloat16) // x
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // save_mean
.AddInputAttr(kNumberTypeFloat32) // save_variance
.AddInputAttr(kNumberTypeFloat32) // reserve
.AddInputAttr(kNumberTypeFloat32) // b
.AddInputAttr(kNumberTypeFloat16) // y
.AddOutputAttr(kNumberTypeFloat16) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32), // dbias
FusedBatchNormGradExGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32) // dy
.AddInputAttr(kNumberTypeFloat32) // x
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // save_mean
.AddInputAttr(kNumberTypeFloat32) // save_variance
.AddInputAttr(kNumberTypeFloat32) // reserve
.AddInputAttr(kNumberTypeFloat32) // b
.AddInputAttr(kNumberTypeFloat32) // y
.AddOutputAttr(kNumberTypeFloat32) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32) // dbias
.AddOutputAttr(kNumberTypeFloat32), // dz
FusedBatchNormGradExGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(FusedBatchNormGradExWithAddAndActivation,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16) // dy
.AddInputAttr(kNumberTypeFloat16) // x
.AddInputAttr(kNumberTypeFloat32) // scale
.AddInputAttr(kNumberTypeFloat32) // save_mean
.AddInputAttr(kNumberTypeFloat32) // save_variance
.AddInputAttr(kNumberTypeFloat32) // reserve
.AddInputAttr(kNumberTypeFloat32) // b
.AddInputAttr(kNumberTypeFloat16) // y
.AddOutputAttr(kNumberTypeFloat16) // dx
.AddOutputAttr(kNumberTypeFloat32) // dscale
.AddOutputAttr(kNumberTypeFloat32) // dbias
.AddOutputAttr(kNumberTypeFloat16), // dz
FusedBatchNormGradExGpuKernel, half)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GRAD_EX_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GRAD_EX_GPU_KERNEL_H_
#include <vector>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "utils/utils.h"
namespace mindspore {
namespace kernel {
template <typename T>
class FusedBatchNormGradExGpuKernel : public GpuKernel {
public:
FusedBatchNormGradExGpuKernel()
: x_size_(0),
para_size_(0),
workspace_size_(0),
reserve_size_(0),
mode_(CUDNN_BATCHNORM_SPATIAL),
bn_ops_(CUDNN_BATCHNORM_OPS_BN),
epsilon_(10e-5),
is_null_input_(false),
x_desc_(nullptr),
y_desc_(nullptr),
dy_desc_(nullptr),
dx_desc_(nullptr),
dz_desc_(nullptr),
scale_bias_diff_desc_(nullptr),
activation_desc_(nullptr),
handle_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT) {}
~FusedBatchNormGradExGpuKernel() override { DestroyResource(); }
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
VARIABLE_NOT_USED(workspace);
VARIABLE_NOT_USED(stream_ptr);
if (is_null_input_) {
return true;
}
auto dy = GetDeviceAddress<T>(inputs, 0);
auto x = GetDeviceAddress<T>(inputs, 1);
auto scale = GetDeviceAddress<float>(inputs, 2);
auto save_mean = GetDeviceAddress<float>(inputs, 3);
auto save_variance = GetDeviceAddress<float>(inputs, 4);
auto reserve_addr = GetDeviceAddress<float>(inputs, 5);
reserve_size_ = inputs[5]->size;
void *bias = nullptr;
T *y = nullptr;
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
bias = GetDeviceAddress<float>(inputs, 6);
y = GetDeviceAddress<T>(inputs, 7);
}
auto dx = GetDeviceAddress<T>(outputs, 0);
auto dscale = GetDeviceAddress<float>(outputs, 1);
auto dbias = GetDeviceAddress<float>(outputs, 2);
T *dz = nullptr;
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
dz = GetDeviceAddress<T>(outputs, 3);
}
void *workspace_addr = nullptr;
if (workspace_size_ != 0) {
workspace_addr = GetDeviceAddress<T>(workspace, 0);
}
const float alpha_data_diff = 1;
const float beta_data_diff = 0;
const float alpha_param_diff = 1;
const float beta_param_diff = 0;
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnBatchNormalizationBackwardEx(
handle_, mode_, bn_ops_, &alpha_data_diff, &beta_data_diff, &alpha_param_diff,
&beta_param_diff, x_desc_, x, y_desc_, y, dy_desc_, dy, dz_desc_, dz, dx_desc_, dx,
scale_bias_diff_desc_, scale, bias, dscale, dbias, epsilon_, save_mean, save_variance,
activation_desc_, workspace_addr, workspace_size_, reserve_addr, reserve_size_),
"Kernel launch failed");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == kFusedBatchNormGradEx) {
bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
} else if (kernel_name == kFusedBatchNormGradExWithActivation) {
bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
} else if (kernel_name == kFusedBatchNormGradExWithAddAndActivation) {
bn_ops_ = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
} else {
MS_LOG(EXCEPTION) << "Invalid kernel name: " << kernel_name;
}
InitResource();
mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
epsilon_ = GetAttr<float>(kernel_node, "epsilon");
cudnn_data_type_ = GetCudnnDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN) {
if (input_num != 6) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", " << kernel_name << " should be 6";
}
} else {
if (input_num != 8) {
MS_LOG(EXCEPTION) << "input tensor size is " << input_num << ", " << kernel_name << " should be 8";
}
}
auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (shape.size() != 4) {
MS_LOG(EXCEPTION) << "tensor shape is " << shape.size() << ", FusedBatchNormGradExGpuKernel should be 4";
}
is_null_input_ = CHECK_NULL_INPUT(shape);
if (is_null_input_) {
MS_LOG(WARNING) << "FusedBatchNormGradExGpuKernel input is null";
InitSizeLists();
return true;
}
std::string format = AnfAlgo::GetInputFormat(kernel_node, 0);
SetTensorDescriptor(format, shape);
InitSizeLists();
return true;
}
protected:
void InitResource() override {
handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&x_desc_), "Create x desc failed");
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&y_desc_), "Create y desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateActivationDescriptor(&activation_desc_),
"Create activation descriptor failed");
}
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dy_desc_), "Create dy desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dx_desc_), "Create dx desc failed");
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&dz_desc_), "Create dz desc failed");
}
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&scale_bias_diff_desc_), "Create para desc failed");
}
void InitSizeLists() override {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(x_desc_, &x_size_), "Get x size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(scale_bias_diff_desc_, &para_size_),
"Get para size failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
handle_, mode_, bn_ops_, x_desc_, y_desc_, dy_desc_, dz_desc_, dx_desc_,
scale_bias_diff_desc_, activation_desc_, &workspace_size_),
"cudnnGetBatchNormalizationBackwardExWorkspaceSize failed");
}
input_size_list_.push_back(x_size_);
input_size_list_.push_back(x_size_);
input_size_list_.push_back(para_size_);
input_size_list_.push_back(para_size_);
input_size_list_.push_back(para_size_);
input_size_list_.push_back(reserve_size_);
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
input_size_list_.push_back(para_size_);
input_size_list_.push_back(x_size_);
}
output_size_list_.push_back(x_size_);
output_size_list_.push_back(para_size_);
output_size_list_.push_back(para_size_);
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
output_size_list_.push_back(x_size_);
}
workspace_size_list_.push_back(workspace_size_);
}
private:
void SetTensorDescriptor(const std::string &format, const std::vector<size_t> &shape) {
cudnnTensorFormat_t cudnn_format;
int batch, channel, height, width;
if (format == kOpFormat_NHWC) {
batch = SizeToInt(shape[0]);
height = SizeToInt(shape[1]);
width = SizeToInt(shape[2]);
channel = SizeToInt(shape[3]);
cudnn_format = CUDNN_TENSOR_NHWC;
} else {
batch = SizeToInt(shape[0]);
channel = SizeToInt(shape[1]);
height = SizeToInt(shape[2]);
width = SizeToInt(shape[3]);
cudnn_format = CUDNN_TENSOR_NCHW;
}
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(x_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set x desc failed");
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(y_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set z desc failed");
}
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(dy_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set dy desc failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(dx_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set dx desc failed");
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(dz_desc_, cudnn_format, cudnn_data_type_, batch, channel, height, width),
"Set z desc failed");
}
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(scale_bias_diff_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, channel, 1, 1),
"Set para desc failed");
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetActivationDescriptor(activation_desc_, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0),
"cudnnSetActivationDescriptor failed");
}
}
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
if (bn_ops_ != CUDNN_BATCHNORM_OPS_BN) {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyActivationDescriptor(activation_desc_),
"Destroy activation descriptor failed");
}
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dy_desc_), "Destroy dy desc failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dx_desc_), "Destroy dx desc failed");
if (bn_ops_ == CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION) {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(dz_desc_), "Destroy z desc failed");
}
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(scale_bias_diff_desc_), "Destroy para desc failed");
}
size_t x_size_;
size_t para_size_;
size_t workspace_size_;
size_t reserve_size_;
cudnnBatchNormMode_t mode_;
cudnnBatchNormOps_t bn_ops_;
double epsilon_;
bool is_null_input_;
cudnnTensorDescriptor_t x_desc_;
cudnnTensorDescriptor_t y_desc_;
cudnnTensorDescriptor_t dy_desc_;
cudnnTensorDescriptor_t dx_desc_;
cudnnTensorDescriptor_t dz_desc_;
cudnnTensorDescriptor_t scale_bias_diff_desc_;
cudnnActivationDescriptor_t activation_desc_;
cudnnHandle_t handle_;
cudnnDataType_t cudnn_data_type_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_FUSED_BATCH_NORM_GRAD_EX_GPU_KERNEL_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册