Unverified commit 7879477f authored by ronnywang, committed by GitHub

[ROCM] add CUDA kernel for batch_norm_op (#32393)

Parent 49773f36
@@ -32,6 +32,12 @@ namespace cub = hipcub;
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
+#ifdef __HIPCC__
+#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
+#else
+#define LAUNCH_BOUNDS(BlockDim)
+#endif
namespace paddle {
namespace operators {
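The new LAUNCH_BOUNDS macro is the core of the patch. `__launch_bounds__(N)` promises the compiler that a kernel will never be launched with more than N threads per block, so it can budget registers for that block size; on ROCm the HIP compiler otherwise assumes the architectural maximum of 1024 threads, which can force register spills or lower occupancy. The CUDA branch expands to nothing, leaving NVIDIA builds unchanged. A minimal sketch of the expansion on a toy kernel (illustrative only, not code from this patch):

```cpp
// Toy kernel, for illustration. On a HIP build, instantiated with
// BlockDim = 256, the declaration becomes
//   __global__ __launch_bounds__(256) void Scale(...)
// On a CUDA build LAUNCH_BOUNDS expands to nothing and the plain
// __global__ declaration remains.
template <typename T, int BlockDim>
__global__ LAUNCH_BOUNDS(BlockDim) void Scale(const T *in, T *out, int n) {
  // Grid-stride loop; launches must use at most BlockDim threads per
  // block for the bound to be honored.
  for (int i = blockIdx.x * BlockDim + threadIdx.x; i < n;
       i += gridDim.x * BlockDim) {
    out[i] = in[i] * static_cast<T>(2);
  }
}
```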
@@ -58,12 +64,10 @@ using DataLayout = framework::DataLayout;
// axis=(n,h,w)))
template <typename T, int BlockDim, framework::DataLayout layout>
-__global__ void DoubleGradComputeDX(const T *x, const T *mean,
-                                    const T *variance, const T *ddx,
-                                    const T *dy, const T *scale,
-                                    const T *ddscale, const int N, const int C,
-                                    const int sample_size, const double epsilon,
-                                    T *dx) {
+__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDX(
+    const T *x, const T *mean, const T *variance, const T *ddx, const T *dy,
+    const T *scale, const T *ddscale, const int N, const int C,
+    const int sample_size, const double epsilon, T *dx) {
const int outer_size = C;
const int inner_size = N * sample_size;
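For context, all four rewritten kernels share the same reduction shape: one block walks the channels, and its BlockDim threads stride over the N * sample_size elements of each channel, combining partial sums with cub::BlockReduce (aliased to hipcub on ROCm at the top of this file). That is why BlockDim is a compile-time template parameter, and why LAUNCH_BOUNDS can reuse it. A hedged sketch of the pattern; ChannelMean, its arguments, and the NCHW indexing are illustrative, not code from this patch:

```cpp
template <typename T, int BlockDim>
__global__ LAUNCH_BOUNDS(BlockDim) void ChannelMean(const T *x, const int N,
                                                    const int C,
                                                    const int sample_size,
                                                    T *mean) {
  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  const int inner_size = N * sample_size;
  for (int c = blockIdx.x; c < C; c += gridDim.x) {
    T sum = static_cast<T>(0);
    for (int i = threadIdx.x; i < inner_size; i += BlockDim) {
      const int n = i / sample_size;   // batch index
      const int hw = i % sample_size;  // spatial offset within the sample
      sum += x[(n * C + c) * sample_size + hw];  // NCHW element of channel c
    }
    __syncthreads();  // temp_storage is reused across channel iterations
    sum = BlockReduce(temp_storage).Sum(sum);
    if (threadIdx.x == 0) {
      mean[c] = sum / static_cast<T>(inner_size);
    }
  }
}
```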
@@ -160,12 +164,10 @@ __global__ void DoubleGradComputeDX(const T *x, const T *mean,
// scale * inv_var * (ddx - (x - mean) * inv_var.pow(2) *
// np.mean(ddx * (x - mean), axis=(n,h,w)))
template <typename T, int BlockDim, framework::DataLayout layout>
-__global__ void DoubleGradComputeDDY(const T *x, const T *mean,
-                                     const T *variance, const T *ddscale,
-                                     const T *ddbias, const T *ddx,
-                                     const T *scale, const int N, const int C,
-                                     const int sample_size,
-                                     const double epsilon, T *ddy) {
+__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDDY(
+    const T *x, const T *mean, const T *variance, const T *ddscale,
+    const T *ddbias, const T *ddx, const T *scale, const int N, const int C,
+    const int sample_size, const double epsilon, T *ddy) {
const int outer_size = C;
const int inner_size = N * sample_size;
@@ -238,11 +240,10 @@ __global__ void DoubleGradComputeDDY(const T *x, const T *mean,
// inv_var.pow(2) * np.mean(dy * (x-mean), axis=(n,h,w)))) *
// ddx
template <typename T, int BlockDim, framework::DataLayout layout>
-__global__ void DoubleGradComputeDScale(const T *x, const T *mean,
-                                        const T *variance, const T *ddx,
-                                        const T *dy, const int N, const int C,
-                                        const int sample_size,
-                                        const double epsilon, T *dscale) {
+__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDScale(
+    const T *x, const T *mean, const T *variance, const T *ddx, const T *dy,
+    const int N, const int C, const int sample_size, const double epsilon,
+    T *dscale) {
const int outer_size = C;
const int inner_size = N * sample_size;
@@ -302,7 +303,7 @@ __global__ void DoubleGradComputeDScale(const T *x, const T *mean,
// math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var
template <typename T, int BlockDim, framework::DataLayout layout>
-__global__ void DoubleGradComputeDScaleWithGlobal(
+__global__ LAUNCH_BOUNDS(BlockDim) void DoubleGradComputeDScaleWithGlobal(
const T *ddx, const T *variance, const T *dy, const double epsilon,
const int N, const int C, const int sample_size, T *dscale) {
int outer_size = C;
@@ -422,8 +423,11 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
}
const T *scale_data = Scale ? Scale->data<T>() : scale_tmp.data<T>();
+#ifdef __HIPCC__
+const int block = 256;
+#else
const int block = 512;
+#endif
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(max_threads / block, 1);
int grid = std::min(C, max_blocks);
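The HIP build also halves the per-block thread count from 512 to 256. The commit does not state the rationale, but 256 threads (four 64-lane wavefronts on AMD hardware, versus 32-lane warps on NVIDIA) is a common ROCm choice and eases register pressure together with the new launch bound; max_blocks and grid are derived from it, so the launch respects the device's physical thread count while capping the grid at one block per channel. A hedged sketch of a call site, reusing the illustrative ChannelMean above (x_data and mean_data are placeholder names; the patch's real call sites pass the batch-norm tensors):

```cpp
// `block` has a constant initializer, so the same value serves as the
// runtime block size and as the BlockDim template argument, keeping the
// <<<grid, block>>> configuration consistent with LAUNCH_BOUNDS(BlockDim).
ChannelMean<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
    x_data, N, C, sample_size, mean_data);
```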
@@ -532,6 +536,5 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
}
}
}
} // namespace operators
} // namespace paddle