未验证 提交 c58f1a5c 编写于 作者: J Juncheng 提交者: GitHub

Fix cudnn bn param desc (#3253)

* use cudnnDeriveBNTensorDescriptor

* CudnnTensorDescHelper

* fix bn min eps
Co-authored-by: Nguo ran <360112263@qq.com>
上级 993b6d37
......@@ -11,7 +11,7 @@ namespace {
void InferDimSizeAndDataFormat(const ShapeView& x_shape, const int32_t axis, int32_t* n, int32_t* c,
int32_t* h, int32_t* w, cudnnTensorFormat_t* format) {
if (axis != 0 && x_shape.Count(axis + 1) == 1 && CUDNN_VERSION >= 7605) {
if (axis != 0 && x_shape.Count(axis + 1) == 1) {
*n = x_shape.At(0);
*h = x_shape.Count(1, axis);
*w = 1;
......@@ -26,35 +26,65 @@ void InferDimSizeAndDataFormat(const ShapeView& x_shape, const int32_t axis, int
}
}
DataType GetParamDataType(const DataType x_data_type) {
return x_data_type == DataType::kFloat16 ? DataType::kFloat : x_data_type;
void InferXYCudnnTensorDesc(const ShapeView& xy_shape, const DataType& data_type,
const int32_t axis, cudnnTensorDescriptor_t xy_desc) {
int32_t n, c, h, w;
cudnnTensorFormat_t format;
InferDimSizeAndDataFormat(xy_shape, axis, &n, &c, &h, &w, &format);
CudaCheck(cudnnSetTensor4dDescriptor(xy_desc, format, GetCudnnDataType(data_type), n, c, h, w));
}
std::function<void(const user_op::Tensor* tensor)> MakeCheckParamTensorFn(
const int32_t param_dim_size, const DataType param_data_type) {
return [=](const user_op::Tensor* tensor) {
CHECK_EQ(tensor->shape().NumAxes(), 1);
CHECK_EQ(tensor->shape().At(0), param_dim_size);
CHECK_EQ(tensor->data_type(), param_data_type);
};
void InferParamCudnnTensorDesc(const cudnnTensorDescriptor_t xy_desc, cudnnBatchNormMode_t mode,
cudnnTensorDescriptor_t param_desc) {
CudaCheck(cudnnDeriveBNTensorDescriptor(param_desc, xy_desc, mode));
}
class CudnnTensorDescHelper final {
public:
OF_DISALLOW_COPY_AND_MOVE(CudnnTensorDescHelper);
CudnnTensorDescHelper(const ShapeView& xy_shape, const DataType& data_type, const int32_t axis,
cudnnBatchNormMode_t mode) {
CudaCheck(cudnnCreateTensorDescriptor(&xy_desc_));
InferXYCudnnTensorDesc(xy_shape, data_type, axis, xy_desc_);
CudaCheck(cudnnCreateTensorDescriptor(&param_desc_));
InferParamCudnnTensorDesc(xy_desc_, mode, param_desc_);
int n, c, h, w, n_stride, c_stride, h_stride, w_stride;
CudaCheck(cudnnGetTensor4dDescriptor(param_desc_, &param_data_type_, &n, &c, &h, &w, &n_stride,
&c_stride, &h_stride, &w_stride));
param_size_ = c;
}
~CudnnTensorDescHelper() {
CudaCheck(cudnnDestroyTensorDescriptor(param_desc_));
CudaCheck(cudnnDestroyTensorDescriptor(xy_desc_));
}
cudnnTensorDescriptor_t xy_desc() const { return xy_desc_; }
cudnnTensorDescriptor_t param_desc() const { return param_desc_; }
void CheckParamTensor(const user_op::Tensor* tensor) const {
CHECK_EQ(tensor->shape().NumAxes(), 1);
CHECK_EQ(tensor->shape().At(0), param_size_);
CHECK_EQ(GetCudnnDataType(tensor->data_type()), param_data_type_);
}
private:
cudnnTensorDescriptor_t xy_desc_ = nullptr;
cudnnTensorDescriptor_t param_desc_ = nullptr;
cudnnDataType_t param_data_type_;
int32_t param_size_ = 0;
};
size_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_type,
const int32_t axis) {
#if defined(BN_ENABLE_EX_API)
int32_t n, c, h, w;
cudnnTensorFormat_t format;
InferDimSizeAndDataFormat(x_shape, axis, &n, &c, &h, &w, &format);
CudnnTensorDesc xy_desc(format, data_type, n, c, h, w);
CudnnTensorDesc param_desc(format, GetParamDataType(data_type), 1, c, 1, 1);
const CudnnTensorDescHelper desc_helper(x_shape, data_type, axis,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
size_t size_in_bytes;
cudnnHandle_t handle;
CudaCheck(cudnnCreate(&handle));
CudaCheck(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN, xy_desc.Get(), nullptr,
xy_desc.Get(), param_desc.Get(), nullptr, &size_in_bytes));
handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(),
nullptr, desc_helper.xy_desc(), desc_helper.param_desc(), nullptr, &size_in_bytes));
CudaCheck(cudnnDestroy(handle));
return std::max(size_in_bytes, static_cast<size_t>(1));
#else
......@@ -71,17 +101,15 @@ size_t InferTrainTmpSize(user_op::InferContext* ctx) {
size_t InferGradWorkspaceSize(const ShapeView& x_shape, const DataType data_type,
const int32_t axis) {
#if defined(BN_ENABLE_EX_API)
int32_t n, c, h, w;
cudnnTensorFormat_t format;
InferDimSizeAndDataFormat(x_shape, axis, &n, &c, &h, &w, &format);
CudnnTensorDesc xy_desc(format, data_type, n, c, h, w);
CudnnTensorDesc param_desc(format, GetParamDataType(data_type), 1, c, 1, 1);
const CudnnTensorDescHelper desc_helper(x_shape, data_type, axis,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
size_t size_in_bytes;
cudnnHandle_t handle;
CudaCheck(cudnnCreate(&handle));
CudaCheck(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN, xy_desc.Get(), nullptr,
xy_desc.Get(), nullptr, xy_desc.Get(), param_desc.Get(), nullptr, &size_in_bytes));
handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(),
nullptr, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(), desc_helper.param_desc(),
nullptr, &size_in_bytes));
CudaCheck(cudnnDestroy(handle));
return std::max(size_in_bytes, static_cast<size_t>(1));
#else
......@@ -120,24 +148,17 @@ class NormalizationInferenceKernel final : public user_op::OpKernel {
CHECK_GE(axis, 0);
CHECK_LT(axis, x->shape().NumAxes());
int32_t n, c, h, w;
cudnnTensorFormat_t format;
InferDimSizeAndDataFormat(x->shape(), axis, &n, &c, &h, &w, &format);
CudnnTensorDesc xy_desc(format, data_type, n, c, h, w);
const DataType param_data_type = GetParamDataType(data_type);
const auto CheckParamTensor = MakeCheckParamTensorFn(c, param_data_type);
CheckParamTensor(gamma);
CheckParamTensor(beta);
CheckParamTensor(moving_mean);
CheckParamTensor(moving_variance);
CudnnTensorDesc param_desc(format, param_data_type, 1, c, 1, 1);
const CudnnTensorDescHelper desc_helper(x->shape(), data_type, axis, CUDNN_BATCHNORM_SPATIAL);
desc_helper.CheckParamTensor(gamma);
desc_helper.CheckParamTensor(beta);
desc_helper.CheckParamTensor(moving_mean);
desc_helper.CheckParamTensor(moving_variance);
CudaCheck(cudnnBatchNormalizationForwardInference(
ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL, CudnnSPOnePtr<T>(),
CudnnSPZeroPtr<T>(), xy_desc.Get(), x->dptr(), xy_desc.Get(), y->mut_dptr(),
param_desc.Get(), gamma->dptr(), beta->dptr(), moving_mean->dptr(), moving_variance->dptr(),
epsilon));
CudnnSPZeroPtr<T>(), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(),
desc_helper.param_desc(), gamma->dptr(), beta->dptr(), moving_mean->dptr(),
moving_variance->dptr(), epsilon));
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
......@@ -184,53 +205,50 @@ class NormalizationTrainKernel final : public user_op::OpKernel {
CHECK_GE(axis, 0);
CHECK_LT(axis, x->shape().NumAxes());
int32_t n, c, h, w;
cudnnTensorFormat_t format;
InferDimSizeAndDataFormat(x->shape(), axis, &n, &c, &h, &w, &format);
CudnnTensorDesc xy_desc(format, data_type, n, c, h, w);
const DataType param_data_type = GetParamDataType(data_type);
const auto CheckParamTensor = MakeCheckParamTensorFn(c, param_data_type);
CheckParamTensor(gamma);
CheckParamTensor(beta);
CheckParamTensor(moving_mean);
CheckParamTensor(moving_variance);
CheckParamTensor(mean);
CheckParamTensor(inv_variance);
CudnnTensorDesc param_desc(format, param_data_type, 1, c, 1, 1);
const CudnnTensorDescHelper desc_helper(x->shape(), data_type, axis,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
desc_helper.CheckParamTensor(gamma);
desc_helper.CheckParamTensor(beta);
desc_helper.CheckParamTensor(moving_mean);
desc_helper.CheckParamTensor(moving_variance);
desc_helper.CheckParamTensor(mean);
desc_helper.CheckParamTensor(inv_variance);
#if defined(BN_ENABLE_EX_API)
size_t workspace_size;
CudaCheck(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT,
CUDNN_BATCHNORM_OPS_BN, xy_desc.Get(), nullptr, xy_desc.Get(), param_desc.Get(), nullptr,
&workspace_size));
CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(),
desc_helper.param_desc(), nullptr, &workspace_size));
size_t reserve_space_size;
CudaCheck(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT,
CUDNN_BATCHNORM_OPS_BN, nullptr, xy_desc.Get(), &reserve_space_size));
CUDNN_BATCHNORM_OPS_BN, nullptr, desc_helper.xy_desc(), &reserve_space_size));
auto* workspace = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
if (reserve_space_size == 0 && workspace_size <= workspace->shape().elem_cnt()) {
CudaCheck(cudnnBatchNormalizationForwardTrainingEx(
ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT,
CUDNN_BATCHNORM_OPS_BN, CudnnSPOnePtr<T>(), CudnnSPZeroPtr<T>(), xy_desc.Get(), x->dptr(),
nullptr, nullptr, xy_desc.Get(), y->mut_dptr(), param_desc.Get(), gamma->dptr(),
beta->dptr(), 1.0 - momentum, moving_mean->mut_dptr(), moving_variance->mut_dptr(),
epsilon, mean->mut_dptr(), inv_variance->mut_dptr(), nullptr, workspace->mut_dptr(),
workspace->shape().elem_cnt(), nullptr, 0));
CUDNN_BATCHNORM_OPS_BN, CudnnSPOnePtr<T>(), CudnnSPZeroPtr<T>(), desc_helper.xy_desc(),
x->dptr(), nullptr, nullptr, desc_helper.xy_desc(), y->mut_dptr(),
desc_helper.param_desc(), gamma->dptr(), beta->dptr(), 1.0 - momentum,
moving_mean->mut_dptr(), moving_variance->mut_dptr(), epsilon, mean->mut_dptr(),
inv_variance->mut_dptr(), nullptr, workspace->mut_dptr(), workspace->shape().elem_cnt(),
nullptr, 0));
} else {
CudaCheck(cudnnBatchNormalizationForwardTraining(
ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CudnnSPOnePtr<T>(),
CudnnSPZeroPtr<T>(), xy_desc.Get(), x->dptr(), xy_desc.Get(), y->mut_dptr(),
param_desc.Get(), gamma->dptr(), beta->dptr(), 1.0 - momentum, moving_mean->mut_dptr(),
moving_variance->mut_dptr(), epsilon, mean->mut_dptr(), inv_variance->mut_dptr()));
CudnnSPZeroPtr<T>(), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(),
y->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(), 1.0 - momentum,
moving_mean->mut_dptr(), moving_variance->mut_dptr(), epsilon, mean->mut_dptr(),
inv_variance->mut_dptr()));
}
#else
CudaCheck(cudnnBatchNormalizationForwardTraining(
ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CudnnSPOnePtr<T>(),
CudnnSPZeroPtr<T>(), xy_desc.Get(), x->dptr(), xy_desc.Get(), y->mut_dptr(),
param_desc.Get(), gamma->dptr(), beta->dptr(), 1.0 - momentum, moving_mean->mut_dptr(),
moving_variance->mut_dptr(), epsilon, mean->mut_dptr(), inv_variance->mut_dptr()));
CudnnSPZeroPtr<T>(), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(),
desc_helper.param_desc(), gamma->dptr(), beta->dptr(), 1.0 - momentum,
moving_mean->mut_dptr(), moving_variance->mut_dptr(), epsilon, mean->mut_dptr(),
inv_variance->mut_dptr()));
#endif
}
......@@ -264,53 +282,49 @@ class NormalizationGradUserKernel final : public user_op::OpKernel {
CHECK_GE(axis, 0);
CHECK_LT(axis, x->shape().NumAxes());
int32_t n, c, h, w;
cudnnTensorFormat_t format;
InferDimSizeAndDataFormat(x->shape(), axis, &n, &c, &h, &w, &format);
CudnnTensorDesc xy_desc(format, data_type, n, c, h, w);
const DataType param_data_type = GetParamDataType(data_type);
const auto CheckParamTensor = MakeCheckParamTensorFn(c, param_data_type);
CheckParamTensor(gamma);
CheckParamTensor(gamma_diff);
CheckParamTensor(beta_diff);
CudnnTensorDesc param_desc(format, param_data_type, 1, c, 1, 1);
const CudnnTensorDescHelper desc_helper(x->shape(), data_type, axis,
CUDNN_BATCHNORM_SPATIAL_PERSISTENT);
desc_helper.CheckParamTensor(gamma);
desc_helper.CheckParamTensor(gamma_diff);
desc_helper.CheckParamTensor(beta_diff);
desc_helper.CheckParamTensor(mean);
desc_helper.CheckParamTensor(inv_variance);
#if defined(BN_ENABLE_EX_API)
size_t workspace_size;
CudaCheck(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT,
CUDNN_BATCHNORM_OPS_BN, xy_desc.Get(), nullptr, xy_desc.Get(), nullptr, xy_desc.Get(),
param_desc.Get(), nullptr, &workspace_size));
CUDNN_BATCHNORM_OPS_BN, desc_helper.xy_desc(), nullptr, desc_helper.xy_desc(), nullptr,
desc_helper.xy_desc(), desc_helper.param_desc(), nullptr, &workspace_size));
size_t reserve_space_size;
CudaCheck(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT,
CUDNN_BATCHNORM_OPS_BN, nullptr, xy_desc.Get(), &reserve_space_size));
CUDNN_BATCHNORM_OPS_BN, nullptr, desc_helper.xy_desc(), &reserve_space_size));
auto* workspace = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
if (reserve_space_size == 0 && workspace_size <= workspace->shape().elem_cnt()) {
CudaCheck(cudnnBatchNormalizationBackwardEx(
ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT,
CUDNN_BATCHNORM_OPS_BN, CudnnSPOnePtr<T>(), CudnnSPZeroPtr<T>(), CudnnSPOnePtr<T>(),
CudnnSPZeroPtr<T>(), xy_desc.Get(), x->dptr(), nullptr, nullptr, xy_desc.Get(),
dy->dptr(), nullptr, nullptr, xy_desc.Get(), dx->mut_dptr(), param_desc.Get(),
gamma->dptr(), nullptr, gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon,
mean->dptr(), inv_variance->dptr(), nullptr, workspace->mut_dptr(),
workspace->shape().elem_cnt(), nullptr, 0));
CudnnSPZeroPtr<T>(), desc_helper.xy_desc(), x->dptr(), nullptr, nullptr,
desc_helper.xy_desc(), dy->dptr(), nullptr, nullptr, desc_helper.xy_desc(),
dx->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), nullptr, gamma_diff->mut_dptr(),
beta_diff->mut_dptr(), epsilon, mean->dptr(), inv_variance->dptr(), nullptr,
workspace->mut_dptr(), workspace->shape().elem_cnt(), nullptr, 0));
} else {
CudaCheck(cudnnBatchNormalizationBackward(
ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CudnnSPOnePtr<T>(),
CudnnSPZeroPtr<T>(), CudnnSPOnePtr<T>(), CudnnSPZeroPtr<T>(), xy_desc.Get(), x->dptr(),
xy_desc.Get(), dy->dptr(), xy_desc.Get(), dx->mut_dptr(), param_desc.Get(), gamma->dptr(),
gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, mean->dptr(),
inv_variance->dptr()));
CudnnSPZeroPtr<T>(), CudnnSPOnePtr<T>(), CudnnSPZeroPtr<T>(), desc_helper.xy_desc(),
x->dptr(), desc_helper.xy_desc(), dy->dptr(), desc_helper.xy_desc(), dx->mut_dptr(),
desc_helper.param_desc(), gamma->dptr(), gamma_diff->mut_dptr(), beta_diff->mut_dptr(),
epsilon, mean->dptr(), inv_variance->dptr()));
}
#else
CudaCheck(cudnnBatchNormalizationBackward(
ctx->device_ctx()->cudnn_handle(), CUDNN_BATCHNORM_SPATIAL_PERSISTENT, CudnnSPOnePtr<T>(),
CudnnSPZeroPtr<T>(), CudnnSPOnePtr<T>(), CudnnSPZeroPtr<T>(), xy_desc.Get(), x->dptr(),
xy_desc.Get(), dy->dptr(), xy_desc.Get(), dx->mut_dptr(), param_desc.Get(), gamma->dptr(),
gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, mean->dptr(),
inv_variance->dptr()));
CudnnSPZeroPtr<T>(), CudnnSPOnePtr<T>(), CudnnSPZeroPtr<T>(), desc_helper.xy_desc(),
x->dptr(), desc_helper.xy_desc(), dy->dptr(), desc_helper.xy_desc(), dx->mut_dptr(),
desc_helper.param_desc(), gamma->dptr(), gamma_diff->mut_dptr(), beta_diff->mut_dptr(),
epsilon, mean->dptr(), inv_variance->dptr()));
#endif
}
......
......@@ -346,7 +346,7 @@ def test_nn_batchnorm(test_case):
arg_dict["input_shape"] = [(2, 4, 3, 5)]
arg_dict["data_type"] = ["float32"]
arg_dict["axis"] = [1, -1]
arg_dict["epsilon"] = [1e-5, 1e-4]
arg_dict["epsilon"] = [1.001e-5, 1e-4]
for arg in GenArgDict(arg_dict):
CompareNnBnWithTensorFlow(**arg)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册