Unverified commit bbe0fdb0, authored by Yao Zihang and committed by GitHub

Fix cudnn error for BatchNorm1D kernel (#43072)

Parent b2912939
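For context, the failure this commit fixes appears when `BatchNorm1D` runs on a 2-D input with a very large batch dimension, which drives cuDNN into its PER_ACTIVATION mode. Below is a minimal, illustrative repro sketch, not part of the commit: the shape `[200000, 4]` is borrowed from the unit test added at the end of this diff, and it assumes a CUDA build of Paddle.

```python
import numpy as np
import paddle

# Illustrative repro sketch (not part of the commit). A 2-D input this large
# previously triggered a cuDNN error in the BatchNorm1D forward/backward pass;
# the shape is taken from the unit test added at the end of this diff.
x = paddle.to_tensor(np.random.randn(200000, 4).astype("float32"))
x.stop_gradient = False
bn = paddle.nn.BatchNorm1D(4)
y = bn(x)
y.backward()  # with this fix, large 2-D inputs fall back to the native CUDA kernels
```

With the change, such inputs bypass the cuDNN `*Ex` routines and use the native `BNForwardTraining`/`BNBackward` kernels instead.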
@@ -542,70 +542,60 @@ void BatchNormGradRawKernel(const Context &ctx,
// This branch calls CUDNN APIs
if (d_x && d_scale && d_bias) {
bool called = false;
#if CUDNN_VERSION_MIN(7, 4, 1)
called = true;
size_t workspace_size = 0;
void *workspace_ptr = nullptr;
DenseTensor workspace_tensor;
auto reserve_space_size = reserve_space->memory_size();
// --------------- cudnn batchnorm workspace ---------------
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::
cudnnGetBatchNormalizationBackwardExWorkspaceSize(
/*handle=*/ctx.cudnn_handle(),
/*mode=*/mode_,
/*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
/*xDesc=*/data_desc_,
/*yDesc=*/data_desc_,
/*dyDesc=*/data_desc_,
/*dzDesc=*/nullptr,
/*dxDesc=*/data_desc_,
/*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
/*activationDesc=*/nullptr,
/*sizeInBytes=*/&workspace_size));
workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
workspace_ptr =
static_cast<void *>(ctx.template Alloc<uint8_t>(&workspace_tensor));
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnBatchNormalizationBackwardEx(
/*handle=*/ctx.cudnn_handle(),
/*mode=*/mode_,
/*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
/*alphaDataDiff=*/CudnnDataType<T>::kOne(),
/*betaDataDiff=*/CudnnDataType<T>::kZero(),
/*alphaParamDiff=*/CudnnDataType<T>::kOne(),
/*betaParamDiff=*/CudnnDataType<T>::kZero(),
/*xDesc=*/data_desc_,
/*xData=*/transformed_x.template data<T>(),
/*yDesc=*/nullptr,
/*yData=*/nullptr,
/*dyDesc=*/data_desc_,
/*dyData=*/transformed_d_y.template data<T>(),
/*dzDesc=*/nullptr,
/*dzData=*/nullptr,
/*dxDesc=*/data_desc_,
/*dxData=*/ctx.template Alloc<T>(&transformed_d_x),
/*dBnScaleBiasDesc=*/bn_param_desc_,
/*bnScaleData=*/scale.template data<BatchNormParamType<T>>(),
/*bnBiasData=*/nullptr,
/*dBnScaleData=*/
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
/*dBnBiasData=*/ctx.template Alloc<BatchNormParamType<T>>(d_bias),
/*epsilon=*/epsilon,
/*savedMean=*/saved_mean_data,
/*savedInvVariance=*/saved_var_data,
/*activationDesc=*/nullptr,
/*workspace=*/workspace_ptr,
/*workSpaceSizeInBytes=*/workspace_size,
/*reserveSpace=*/
const_cast<uint8_t *>(reserve_space->template data<uint8_t>()),
/*reserveSpaceSizeInBytes=*/reserve_space_size));
#endif // CUDNN_VERSION_MIN(7, 4, 1)
if (!called) {
#ifdef PADDLE_WITH_HIP
if (compute_format == DataLayout::kNCHW) {
BNBackward<T, block, DataLayout::kNCHW>
<<<grid2, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
scale.template data<BatchNormParamType<T>>(),
saved_mean_data,
saved_var_data,
C,
N,
H * W * D,
epsilon,
transformed_d_x.template data<T>(),
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias));
} else {
BNBackward<T, block, DataLayout::kNHWC>
<<<grid2, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
scale.template data<BatchNormParamType<T>>(),
saved_mean_data,
saved_var_data,
C,
N,
H * W * D,
epsilon,
transformed_d_x.template data<T>(),
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias));
}
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenBatchNormalizationBackward(
// dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
// CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
// CudnnDataType<T>::kZero(), data_desc_,
// transformed_x.template data<T>(), data_desc_,
// transformed_d_y.template data<T>(), data_desc_,
// transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
// bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
// d_scale->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace()),
// d_bias->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace()),
// epsilon, saved_mean_data, saved_var_data));
#else
// CUDNN PER_ACTIVATION mode only supports small batch sizes
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const bool use_native_kernel =
(x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD);
if (use_native_kernel) {
if (compute_format == DataLayout::kNCHW) {
BNBackward<T, block, DataLayout::kNCHW>
<<<grid2, block, 0, ctx.stream()>>>(
@@ -637,22 +627,67 @@ void BatchNormGradRawKernel(const Context &ctx,
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias));
}
} else {
#if CUDNN_VERSION_MIN(7, 4, 1)
size_t workspace_size = 0;
void *workspace_ptr = nullptr;
DenseTensor workspace_tensor;
auto reserve_space_size = reserve_space->memory_size();
// --------------- cudnn batchnorm workspace ---------------
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::
cudnnGetBatchNormalizationBackwardExWorkspaceSize(
/*handle=*/ctx.cudnn_handle(),
/*mode=*/mode_,
/*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
/*xDesc=*/data_desc_,
/*yDesc=*/data_desc_,
/*dyDesc=*/data_desc_,
/*dzDesc=*/nullptr,
/*dxDesc=*/data_desc_,
/*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
/*activationDesc=*/nullptr,
/*sizeInBytes=*/&workspace_size));
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenBatchNormalizationBackward(
// dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
// CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
// CudnnDataType<T>::kZero(), data_desc_,
// transformed_x.template data<T>(), data_desc_,
// transformed_d_y.template data<T>(), data_desc_,
// transformed_d_x.template mutable_data<T>(ctx.GetPlace()),
// bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
// d_scale->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace()),
// d_bias->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace()),
// epsilon, saved_mean_data, saved_var_data));
workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
workspace_ptr =
static_cast<void *>(ctx.template Alloc<uint8_t>(&workspace_tensor));
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnBatchNormalizationBackwardEx(
/*handle=*/ctx.cudnn_handle(),
/*mode=*/mode_,
/*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
/*alphaDataDiff=*/CudnnDataType<T>::kOne(),
/*betaDataDiff=*/CudnnDataType<T>::kZero(),
/*alphaParamDiff=*/CudnnDataType<T>::kOne(),
/*betaParamDiff=*/CudnnDataType<T>::kZero(),
/*xDesc=*/data_desc_,
/*xData=*/transformed_x.template data<T>(),
/*yDesc=*/nullptr,
/*yData=*/nullptr,
/*dyDesc=*/data_desc_,
/*dyData=*/transformed_d_y.template data<T>(),
/*dzDesc=*/nullptr,
/*dzData=*/nullptr,
/*dxDesc=*/data_desc_,
/*dxData=*/ctx.template Alloc<T>(&transformed_d_x),
/*dBnScaleBiasDesc=*/bn_param_desc_,
/*bnScaleData=*/scale.template data<BatchNormParamType<T>>(),
/*bnBiasData=*/nullptr,
/*dBnScaleData=*/
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
/*dBnBiasData=*/
ctx.template Alloc<BatchNormParamType<T>>(d_bias),
/*epsilon=*/epsilon,
/*savedMean=*/saved_mean_data,
/*savedInvVariance=*/saved_var_data,
/*activationDesc=*/nullptr,
/*workspace=*/workspace_ptr,
/*workSpaceSizeInBytes=*/workspace_size,
/*reserveSpace=*/
const_cast<uint8_t *>(reserve_space->template data<uint8_t>()),
/*reserveSpaceSizeInBytes=*/reserve_space_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnBatchNormalizationBackward(
@@ -675,8 +710,9 @@ void BatchNormGradRawKernel(const Context &ctx,
epsilon,
saved_mean_data,
saved_var_data));
#endif
#endif // CUDNN_VERSION_MIN(7, 4, 1)
}
#endif
if (data_layout == DataLayout::kNHWC &&
compute_format == DataLayout::kNCHW) {
......
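The structural change in the backward kernel above is that the cuDNN call is no longer attempted unconditionally: for 2-D inputs (the BatchNorm1D case, which cuDNN handles in PER_ACTIVATION mode) whose batch dimension has reached CUDNN_PER_ACTIVATION_THRESHOLD, the native BNBackward kernel runs instead of cudnnBatchNormalizationBackwardEx. A minimal sketch of that dispatch condition, written in Python purely for illustration (the constant and the check are copied from the C++ above):

```python
# Illustrative only -- mirrors the use_native_kernel condition in the C++ above.
CUDNN_PER_ACTIVATION_THRESHOLD = 131070  # same constant as in the kernel source


def use_native_kernel(x_shape):
    # A 2-D input (N, C) is the BatchNorm1D case run in PER_ACTIVATION mode;
    # once the batch dimension reaches the threshold, fall back to the native kernel.
    return len(x_shape) == 2 and x_shape[0] >= CUDNN_PER_ACTIVATION_THRESHOLD


print(use_native_kernel([200000, 4]))     # True  -> native BNBackward / BNForwardTraining
print(use_native_kernel([1000, 4]))       # False -> cuDNN *Ex path
print(use_native_kernel([8, 3, 32, 32]))  # False -> 4-D inputs keep the cuDNN path
```

The forward-training kernel in the next file applies the same check before calling cudnnBatchNormalizationForwardTrainingEx.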
@@ -446,90 +446,81 @@ void BatchNormKernel(const Context &ctx,
paddle::framework::TensorCopy(x, ctx.GetPlace(), y);
} else {
double this_factor = 1. - momentum;
bool called = false;
#if CUDNN_VERSION_MIN(7, 4, 1)
called = true;
size_t workspace_size = 0;
size_t reserve_space_size = 0;
void *reserve_space_ptr = nullptr;
void *workspace_ptr = nullptr;
DenseTensor workspace_tensor;
DenseTensor reserve_space_tensor;
// Create reserve space and workspace for batch norm.
// Create a tensor for each batch_norm op; it will be used in the
// backward pass, so this tensor shouldn't be temporary.
// auto *reserve_space = ctx.Output<Tensor>("ReserveSpace");
if (reserve_space == nullptr) {
reserve_space = &reserve_space_tensor;
}
PADDLE_ENFORCE_NOT_NULL(
reserve_space,
phi::errors::NotFound(
"The argument ReserveSpace of batch_norm op is not found."));
// --------------- cudnn batchnorm workspace ---------------
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
/*handle=*/handle,
/*mode=*/mode_,
/*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
/*xDesc=*/data_desc_,
/*zDesc=*/nullptr,
/*yDesc=*/data_desc_,
/*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
/*activationDesc=*/nullptr,
/*sizeInBytes=*/&workspace_size));
// -------------- cudnn batchnorm reserve space --------------
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::
cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
/*handle=*/handle,
/*mode=*/mode_,
/*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
/*activationDesc=*/nullptr,
/*xDesc=*/data_desc_,
/*sizeInBytes=*/&reserve_space_size));
reserve_space->Resize({static_cast<int64_t>(reserve_space_size)});
reserve_space_ptr =
static_cast<void *>(ctx.template Alloc<uint8_t>(reserve_space));
workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
workspace_ptr =
static_cast<void *>(ctx.template Alloc<uint8_t>(&workspace_tensor));
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
handle,
mode_,
CUDNN_BATCHNORM_OPS_BN,
CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(),
data_desc_,
transformed_x.template data<T>(),
nullptr,
nullptr,
data_desc_,
transformed_y.template data<T>(),
bn_param_desc_,
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
this_factor,
ctx.template Alloc<BatchNormParamType<T>>(mean_out),
ctx.template Alloc<BatchNormParamType<T>>(variance_out),
epsilon,
ctx.template Alloc<BatchNormParamType<T>>(saved_mean),
ctx.template Alloc<BatchNormParamType<T>>(saved_variance),
nullptr,
workspace_ptr,
workspace_size,
reserve_space_ptr,
reserve_space_size));
#endif // CUDNN_VERSION_MIN(7, 4, 1)
if (!called) {
#ifdef PADDLE_WITH_HIP
const int num = transformed_x.numel();
const int block = 256;
const int num = transformed_x.numel();
const int block = 256;
const int max_threads = ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
const int grid = std::min(C, max_blocks);
if (compute_format == DataLayout::kNCHW) {
BNForwardTraining<T, block, DataLayout::kNCHW>
<<<grid, block, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
this_factor,
transformed_y.template data<T>(),
mean_out->template data<BatchNormParamType<T>>(),
variance_out->template data<BatchNormParamType<T>>(),
saved_mean->template data<BatchNormParamType<T>>(),
saved_variance->template data<BatchNormParamType<T>>());
} else {
BNForwardTraining<T, block, DataLayout::kNHWC>
<<<grid, block, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
this_factor,
transformed_y.template data<T>(),
mean_out->template data<BatchNormParamType<T>>(),
variance_out->template data<BatchNormParamType<T>>(),
saved_mean->template data<BatchNormParamType<T>>(),
saved_variance->template data<BatchNormParamType<T>>());
}
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenBatchNormalizationForwardTraining(
// handle, mode_, const_cast<void *>(static_cast<const void *>(
// CudnnDataType<T>::kOne())),
// const_cast<void *>(
// static_cast<const void *>(CudnnDataType<T>::kZero())),
// data_desc_,
// static_cast<const void *>(transformed_x.template data<T>()),
// data_desc_,
// static_cast<void *>(
// transformed_y.template mutable_data<T>(ctx.GetPlace())),
// bn_param_desc_,
// const_cast<void *>(static_cast<const void *>(
// scale->template data<BatchNormParamType<T>>())),
// const_cast<void *>(static_cast<const void *>(
// bias->template data<BatchNormParamType<T>>())),
// this_factor,
// static_cast<void *>(
// mean_out->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace())),
// static_cast<void *>(variance_out->template mutable_data<
// BatchNormParamType<T>>(ctx.GetPlace())),
// epsilon,
// static_cast<void *>(
// saved_mean->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace())),
// static_cast<void *>(saved_variance->template mutable_data<
// BatchNormParamType<T>>(ctx.GetPlace()))));
#else
// CUDNN PER_ACTIVATION mode only supports small batch sizes
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const bool use_native_kernel =
(x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD);
if (use_native_kernel) {
const int block = 512;
const int max_threads = ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
const int grid = std::min(C, max_blocks);
@@ -566,35 +557,83 @@ void BatchNormKernel(const Context &ctx,
saved_mean->template data<BatchNormParamType<T>>(),
saved_variance->template data<BatchNormParamType<T>>());
}
// TODO(wangran16): wait for MIOpen to improve the performance of BN
// PADDLE_ENFORCE_GPU_SUCCESS(
// platform::dynload::miopenBatchNormalizationForwardTraining(
// handle, mode_, const_cast<void *>(static_cast<const void *>(
// CudnnDataType<T>::kOne())),
// const_cast<void *>(
// static_cast<const void *>(CudnnDataType<T>::kZero())),
// data_desc_,
// static_cast<const void *>(transformed_x.template data<T>()),
// data_desc_,
// static_cast<void *>(
// transformed_y.template mutable_data<T>(ctx.GetPlace())),
// bn_param_desc_,
// const_cast<void *>(static_cast<const void *>(
// scale->template data<BatchNormParamType<T>>())),
// const_cast<void *>(static_cast<const void *>(
// bias->template data<BatchNormParamType<T>>())),
// this_factor,
// static_cast<void *>(
// mean_out->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace())),
// static_cast<void *>(variance_out->template mutable_data<
// BatchNormParamType<T>>(ctx.GetPlace())),
// epsilon,
// static_cast<void *>(
// saved_mean->template mutable_data<BatchNormParamType<T>>(
// ctx.GetPlace())),
// static_cast<void *>(saved_variance->template mutable_data<
// BatchNormParamType<T>>(ctx.GetPlace()))));
} else {
#if CUDNN_VERSION_MIN(7, 4, 1)
size_t workspace_size = 0;
size_t reserve_space_size = 0;
void *reserve_space_ptr = nullptr;
void *workspace_ptr = nullptr;
DenseTensor workspace_tensor;
DenseTensor reserve_space_tensor;
// Create reserve space and workspace for batch norm.
// Create a tensor for each batch_norm op; it will be used in the
// backward pass, so this tensor shouldn't be temporary.
// auto *reserve_space = ctx.Output<Tensor>("ReserveSpace");
if (reserve_space == nullptr) {
reserve_space = &reserve_space_tensor;
}
PADDLE_ENFORCE_NOT_NULL(
reserve_space,
phi::errors::NotFound(
"The argument ReserveSpace of batch_norm op is not found."));
// --------------- cudnn batchnorm workspace ---------------
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
/*handle=*/handle,
/*mode=*/mode_,
/*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
/*xDesc=*/data_desc_,
/*zDesc=*/nullptr,
/*yDesc=*/data_desc_,
/*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
/*activationDesc=*/nullptr,
/*sizeInBytes=*/&workspace_size));
// -------------- cudnn batchnorm reserve space --------------
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::
cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
/*handle=*/handle,
/*mode=*/mode_,
/*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
/*activationDesc=*/nullptr,
/*xDesc=*/data_desc_,
/*sizeInBytes=*/&reserve_space_size));
reserve_space->Resize({static_cast<int64_t>(reserve_space_size)});
reserve_space_ptr =
static_cast<void *>(ctx.template Alloc<uint8_t>(reserve_space));
workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
workspace_ptr =
static_cast<void *>(ctx.template Alloc<uint8_t>(&workspace_tensor));
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
handle,
mode_,
CUDNN_BATCHNORM_OPS_BN,
CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(),
data_desc_,
transformed_x.template data<T>(),
nullptr,
nullptr,
data_desc_,
transformed_y.template data<T>(),
bn_param_desc_,
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
this_factor,
ctx.template Alloc<BatchNormParamType<T>>(mean_out),
ctx.template Alloc<BatchNormParamType<T>>(variance_out),
epsilon,
ctx.template Alloc<BatchNormParamType<T>>(saved_mean),
ctx.template Alloc<BatchNormParamType<T>>(saved_variance),
nullptr,
workspace_ptr,
workspace_size,
reserve_space_ptr,
reserve_space_size));
#else
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnBatchNormalizationForwardTraining(
@@ -615,8 +654,9 @@ void BatchNormKernel(const Context &ctx,
epsilon,
ctx.template Alloc<BatchNormParamType<T>>(saved_mean),
ctx.template Alloc<BatchNormParamType<T>>(saved_variance)));
#endif
#endif // CUDNN_VERSION_MIN(7, 4, 1)
}
#endif
}
}
......
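One detail from the forward kernel above that is easy to miss: both the cuDNN call and the native BNForwardTraining kernel receive this_factor = 1. - momentum as the exponential-average factor for the running statistics, so the running mean and variance follow the usual momentum-style update. A quick numeric sketch in plain Python, illustrative only and assuming the standard exponential-average convention running = running * (1 - factor) + batch * factor:

```python
# Running-statistics update implied by this_factor = 1 - momentum (assumed
# convention: running = running * (1 - factor) + batch * factor).
momentum = 0.9
this_factor = 1.0 - momentum

running_mean, batch_mean = 0.5, 0.8
running_mean = running_mean * (1.0 - this_factor) + batch_mean * this_factor
print(running_mean)  # ~0.53, i.e. 0.9 * 0.5 + 0.1 * 0.8
```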
@@ -110,11 +110,43 @@ class TestBatchNorm(unittest.TestCase):
y.backward()
return y.numpy(), x1.gradient()
x = np.random.randn(*shape).astype("float32")
y1, g1 = compute_v1(x)
y2, g2 = compute_v2(x)
self.assertTrue(np.allclose(g1, g2))
self.assertTrue(np.allclose(y1, y2))
x = np.random.randn(*shape).astype("float32")
y1, g1 = compute_v1(x)
y2, g2 = compute_v2(x)
self.assertTrue(np.allclose(g1, g2))
self.assertTrue(np.allclose(y1, y2))
def test_eager_api_1d(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
shape = [200000, 4]
def compute_v1(x):
with fluid.dygraph.guard(p):
bn = fluid.dygraph.BatchNorm(shape[1])
x1 = paddle.to_tensor(x)
x1.stop_gradient = False
y = bn(x1)
y.backward()
return y.numpy(), x1.gradient()
def compute_v2(x):
with fluid.dygraph.guard(p):
with _test_eager_guard():
bn = paddle.nn.BatchNorm1D(shape[1])
x1 = paddle.to_tensor(x)
x1.stop_gradient = False
y = bn(x1)
y.backward()
return y.numpy(), x1.gradient()
x = np.random.randn(*shape).astype("float32")
y1, g1 = compute_v1(x)
y2, g2 = compute_v2(x)
self.assertTrue(np.allclose(g1, g2))
self.assertTrue(np.allclose(y1, y2))
def test_dygraph(self):
places = [fluid.CPUPlace()]
......