Unverified commit 6f1ec935, author: zhangkaihuo, committer: GitHub

Fix bn performance degradation (#50287)

* fix bn performance degradation
Parent 8c14b02b
@@ -783,7 +783,8 @@ void BatchNormGradRawKernel(const Context &ctx,
   }
   // CUDNN only support small batch size
   bool use_native_nhwc =
-      d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
+      d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC &&
+             H * W >= CUDNN_SPATIAL_THRESHOLD_EVAL)
           : false;
   const bool use_native_kernel =
       ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
......
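
To read the changed predicate in isolation, here is a minimal standalone sketch. `UseNativeNhwc`, its parameters, and the local `DataLayout` enum are illustrative stand-ins rather than Paddle code, and the threshold is passed in as an argument because the value of `CUDNN_SPATIAL_THRESHOLD_EVAL` is not shown in this diff.

```cpp
#include <cstdint>

// Hypothetical stand-in for the layout enum referenced in the diff.
enum class DataLayout { kNCHW, kNHWC };

// Mirrors the updated condition from the diff. `spatial_threshold` plays the
// role of CUDNN_SPATIAL_THRESHOLD_EVAL (actual value not shown in the diff).
// `has_d_x` corresponds to the `d_x` pointer being non-null.
bool UseNativeNhwc(bool has_d_x, int rank, DataLayout compute_format,
                   std::int64_t H, std::int64_t W,
                   std::int64_t spatial_threshold) {
  return has_d_x ? (rank == 4 && compute_format == DataLayout::kNHWC &&
                    H * W >= spatial_threshold)
                 : false;
}
```

With an assumed threshold of 64, an 8x8 NHWC input still satisfies the predicate, while a 4x4 input no longer qualifies for the native NHWC path. Before this change, both would have qualified.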