未验证 提交 dec9094d 编写于 作者: C Chen Weihang 提交者: GitHub

polish batch norm sig (#40746)

上级 67ffb86e
......@@ -18,10 +18,17 @@ namespace phi {
KernelSignature BatchNormOpArgumentMapping(const ArgumentMappingContext& ctx) {
bool is_test = paddle::any_cast<bool>(ctx.Attr("is_test"));
bool use_global_stats = paddle::any_cast<bool>(ctx.Attr("use_global_stats"));
bool use_global_stats =
ctx.HasAttr("use_global_stats")
? paddle::any_cast<bool>(ctx.Attr("use_global_stats"))
: false;
bool trainable_statistics =
paddle::any_cast<bool>(ctx.Attr("trainable_statistics"));
bool fuse_with_relu = paddle::any_cast<bool>(ctx.Attr("fuse_with_relu"));
ctx.HasAttr("trainable_statistics")
? paddle::any_cast<bool>(ctx.Attr("trainable_statistics"))
: false;
bool fuse_with_relu = ctx.HasAttr("fuse_with_relu")
? paddle::any_cast<bool>(ctx.Attr("fuse_with_relu"))
: false;
// Dispenable `MomentumTensor` is useless now
if (is_test && !use_global_stats && !trainable_statistics &&
!fuse_with_relu) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册