Unverified commit d7985052, authored by ceci3, committed by GitHub

fix bn/in/squeeze/syncbn extra (#35502)

* fix bn/in/squeeze/syncbn extra

* update bn

* update

* update
Parent 3896bdbd
@@ -257,13 +257,16 @@ void BatchNormOpMaker::Make() {
   AddOutput("ReserveSpace",
             "Reserve GPU space for triggering the new semi-persistent "
             "NHWC kernel")
-      .AsDispensable();
+      .AsDispensable()
+      .AsExtra();
   AddAttr<bool>("use_mkldnn",
                 "(bool, default false) Only used in mkldnn kernel")
-      .SetDefault(false);
+      .SetDefault(false)
+      .AsExtra();
   AddAttr<bool>("fuse_with_relu",
                 "(bool, default false) Only used in mkldnn kernel")
-      .SetDefault(false);
+      .SetDefault(false)
+      .AsExtra();
   AddAttr<bool>("use_global_stats",
                 "(bool, default false) Whether to use global mean and "
                 "variance. In inference or test mode, set use_global_stats "
...
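The chained qualifiers above work because each call returns a reference to the same variable or attribute descriptor, so .AsDispensable(), .AsIntermediate(), .SetDefault(...) and the newly added .AsExtra() compose freely. Below is a minimal, framework-free C++ sketch of that fluent-builder style; the VarDesc class and its flags are invented for illustration and are not Paddle's real OpProto machinery.

#include <iostream>
#include <string>

// Toy descriptor standing in for an op-definition entry; NOT Paddle's API.
class VarDesc {
 public:
  explicit VarDesc(std::string name) : name_(std::move(name)) {}
  // Each qualifier sets a flag and returns *this so calls can be chained.
  VarDesc& AsDispensable() { dispensable_ = true; return *this; }
  VarDesc& AsIntermediate() { intermediate_ = true; return *this; }
  VarDesc& AsExtra() { extra_ = true; return *this; }
  void Print() const {
    std::cout << name_ << ": dispensable=" << dispensable_
              << " intermediate=" << intermediate_
              << " extra=" << extra_ << "\n";
  }
 private:
  std::string name_;
  bool dispensable_ = false, intermediate_ = false, extra_ = false;
};

int main() {
  VarDesc reserve_space("ReserveSpace");
  reserve_space.AsDispensable().AsExtra();  // same chaining style as the diff
  reserve_space.Print();
  return 0;
}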
@@ -149,11 +149,13 @@ void InstanceNormOpMaker::Make() {
   AddOutput("SavedMean",
             "Mean of the current mini batch, "
             "will apply to output when training")
-      .AsIntermediate();
+      .AsIntermediate()
+      .AsExtra();
   AddOutput("SavedVariance",
             "Variance of the current mini batch, "
             "will apply to output when training")
-      .AsIntermediate();
+      .AsIntermediate()
+      .AsExtra();
   AddComment(R"DOC(
 Instance Normalization.
...
@@ -47,7 +47,9 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
                                mkldnn::batch_normalization_backward>(
             mkldnn_engine, ctx.GetPlace()) {
     const float epsilon = ctx.Attr<float>("epsilon");
-    const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
+    const bool fuse_with_relu = ctx.HasAttr("fuse_with_relu")
+                                    ? ctx.Attr<bool>("fuse_with_relu")
+                                    : false;
     std::vector<std::string> DataLayout_error_msg = {"kNHWC", "kNCHW",
                                                      "kAnyLayout", "kMKLDNN"};
...
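Since fuse_with_relu is now registered as an extra attribute, a program description may omit it entirely; the kernel therefore guards the read with HasAttr and falls back to false. A self-contained sketch of that guard pattern follows, assuming a toy AttrMap type that merely stands in for the real execution context.

#include <iostream>
#include <string>
#include <unordered_map>

// Mock attribute container standing in for an execution context.
struct AttrMap {
  std::unordered_map<std::string, bool> attrs;
  bool HasAttr(const std::string& name) const { return attrs.count(name) > 0; }
  bool Attr(const std::string& name) const { return attrs.at(name); }
};

// Read an attribute only if it exists; otherwise use a safe default.
bool GetBoolAttrOr(const AttrMap& ctx, const std::string& name, bool def) {
  return ctx.HasAttr(name) ? ctx.Attr(name) : def;
}

int main() {
  AttrMap with_attr{{{"fuse_with_relu", true}}};
  AttrMap without_attr{};  // e.g. a program saved before the attribute existed
  std::cout << GetBoolAttrOr(with_attr, "fuse_with_relu", false) << "\n";    // 1
  std::cout << GetBoolAttrOr(without_attr, "fuse_with_relu", false) << "\n"; // 0
  return 0;
}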
@@ -225,7 +225,7 @@ class Squeeze2Op : public framework::OperatorWithKernel {
       ctx->ShareLoD("X", "Out");
     }
-    OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Squeeze2");
+    if (!ctx->HasOutput("XShape")) return;
     std::vector<int64_t> xshape_dims(x_dims.size() + 1);
     xshape_dims[0] = 0;
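With the hard OP_INOUT_CHECK replaced by an early return, XShape becomes an optional output: shape inference simply skips the auxiliary shape when the graph does not declare it. The sketch below shows this optional-output early-return flow using a mock Ctx type rather than Paddle's real InferShapeContext.

#include <cstdint>
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Mock shape-inference context; only tracks which outputs are declared.
struct Ctx {
  std::set<std::string> outputs;
  bool HasOutput(const std::string& name) const { return outputs.count(name) > 0; }
};

// Skip the auxiliary XShape computation when the output is not declared.
void InferXShape(const Ctx& ctx, const std::vector<int64_t>& x_dims) {
  if (!ctx.HasOutput("XShape")) return;  // optional output: just bail out
  std::vector<int64_t> xshape_dims(x_dims.size() + 1);
  xshape_dims[0] = 0;  // leading 0 marks the stored-shape convention
  for (size_t i = 0; i < x_dims.size(); ++i) xshape_dims[i + 1] = x_dims[i];
  std::cout << "XShape rank: " << xshape_dims.size() << "\n";
}

int main() {
  InferXShape(Ctx{{"Out"}}, {3, 1, 4});            // prints nothing
  InferXShape(Ctx{{"Out", "XShape"}}, {3, 1, 4});  // prints "XShape rank: 4"
  return 0;
}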
@@ -323,7 +323,8 @@ class Squeeze2OpMaker : public SqueezeOpMaker {
     AddOutput("XShape",
               "XShape is just used to store the shape and lod of X, which will "
               "be used in SqueezeGradOp.")
-        .AsIntermediate();
+        .AsIntermediate()
+        .AsExtra();
   }
 };
...