未验证 提交 0294ab41 编写于 作者: Z zhangkaihuo 提交者: GitHub

Update threshold of bn1d (#49734)

上级 609b50a8
...@@ -18,6 +18,10 @@ limitations under the License. */ ...@@ -18,6 +18,10 @@ limitations under the License. */
namespace phi { namespace phi {
namespace funcs { namespace funcs {
#define CUDNN_PER_ACTIVATION_THRESHOLD 10240
#define CUDNN_SPATIAL_THRESHOLD_TRAIN 880801
#define CUDNN_SPATIAL_THRESHOLD_EVAL 65535
inline void ExtractNCWHD(const phi::DDim &dims, inline void ExtractNCWHD(const phi::DDim &dims,
const DataLayout &data_layout, const DataLayout &data_layout,
int *N, int *N,
......
...@@ -907,15 +907,12 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -907,15 +907,12 @@ void BatchNormGradRawKernel(const Context &ctx,
#else #else
} }
// CUDNN only support small batch size // CUDNN only support small batch size
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
bool use_native_nhwc = bool use_native_nhwc =
d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC) d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
: false; : false;
const bool use_native_kernel = const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) || ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD)); (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
if (use_native_nhwc || (d_x && d_scale && d_bias)) { if (use_native_nhwc || (d_x && d_scale && d_bias)) {
if (use_native_kernel || use_native_nhwc) { if (use_native_kernel || use_native_nhwc) {
if (x_dims.size() == 2 || use_native_nhwc) { if (x_dims.size() == 2 || use_native_nhwc) {
......
...@@ -722,9 +722,6 @@ void BatchNormKernel(const Context &ctx, ...@@ -722,9 +722,6 @@ void BatchNormKernel(const Context &ctx,
auto handle = ctx.cudnn_handle(); auto handle = ctx.cudnn_handle();
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
// Now, depending on whether we are running test or not, we have two paths. // Now, depending on whether we are running test or not, we have two paths.
// It is training mode when it's not reference AND not using pre-trained // It is training mode when it's not reference AND not using pre-trained
// model. // model.
...@@ -829,7 +826,7 @@ void BatchNormKernel(const Context &ctx, ...@@ -829,7 +826,7 @@ void BatchNormKernel(const Context &ctx,
#else #else
const bool use_native_kernel = const bool use_native_kernel =
(x_dims.size() == 2 || (x_dims.size() == 2 ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD)); (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_EVAL));
if (use_native_kernel) { if (use_native_kernel) {
const int block_size = 256; const int block_size = 256;
const int grid_size = (N * C * H * W * D + block_size - 1) / block_size; const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
...@@ -1005,7 +1002,7 @@ void BatchNormKernel(const Context &ctx, ...@@ -1005,7 +1002,7 @@ void BatchNormKernel(const Context &ctx,
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070; // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const bool use_native_kernel = const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) || ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD)); (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
if (use_native_kernel) { if (use_native_kernel) {
dim3 block; dim3 block;
dim3 grid; dim3 grid;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册