未验证 提交 dec9094d 编写于 作者: C Chen Weihang 提交者: GitHub

polish batch norm sig (#40746)

上级 67ffb86e
...@@ -18,10 +18,17 @@ namespace phi { ...@@ -18,10 +18,17 @@ namespace phi {
KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) { KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
bool is_test = paddle::any_cast<bool>(ctx.Attr("is_test")); bool is_test = paddle::any_cast<bool>(ctx.Attr("is_test"));
bool use_global_stats = paddle::any_cast<bool>(ctx.Attr("use_global_stats")); bool use_global_stats =
ctx.HasAttr("use_global_stats")
? paddle::any_cast<bool>(ctx.Attr("use_global_stats"))
: false;
bool trainable_statistics = bool trainable_statistics =
paddle::any_cast<bool>(ctx.Attr("trainable_statistics")); ctx.HasAttr("trainable_statistics")
bool fuse_with_relu = paddle::any_cast<bool>(ctx.Attr("fuse_with_relu")); ? paddle::any_cast<bool>(ctx.Attr("trainable_statistics"))
: false;
bool fuse_with_relu = ctx.HasAttr("fuse_with_relu")
? paddle::any_cast<bool>(ctx.Attr("fuse_with_relu"))
: false;
// Dispenable `MomentumTensor` is useless now // Dispenable `MomentumTensor` is useless now
if (is_test && !use_global_stats && !trainable_statistics && if (is_test && !use_global_stats && !trainable_statistics &&
!fuse_with_relu) { !fuse_with_relu) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册