未验证 提交 0294ab41 编写于 作者: Z zhangkaihuo 提交者: GitHub

Update threshold of bn1d (#49734)

上级 609b50a8
......@@ -18,6 +18,10 @@ limitations under the License. */
namespace phi {
namespace funcs {
#define CUDNN_PER_ACTIVATION_THRESHOLD 10240
#define CUDNN_SPATIAL_THRESHOLD_TRAIN 880801
#define CUDNN_SPATIAL_THRESHOLD_EVAL 65535
inline void ExtractNCWHD(const phi::DDim &dims,
const DataLayout &data_layout,
int *N,
......
......@@ -907,15 +907,12 @@ void BatchNormGradRawKernel(const Context &ctx,
#else
}
// CUDNN only support small batch size
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
bool use_native_nhwc =
d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
: false;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
if (use_native_nhwc || (d_x && d_scale && d_bias)) {
if (use_native_kernel || use_native_nhwc) {
if (x_dims.size() == 2 || use_native_nhwc) {
......
......@@ -722,9 +722,6 @@ void BatchNormKernel(const Context &ctx,
auto handle = ctx.cudnn_handle();
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
// Now, depending on whether we are running test or not, we have two paths.
// It is training mode when it's not reference AND not using pre-trained
// model.
......@@ -829,7 +826,7 @@ void BatchNormKernel(const Context &ctx,
#else
const bool use_native_kernel =
(x_dims.size() == 2 ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_EVAL));
if (use_native_kernel) {
const int block_size = 256;
const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
......@@ -1005,7 +1002,7 @@ void BatchNormKernel(const Context &ctx,
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
if (use_native_kernel) {
dim3 block;
dim3 grid;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册