未验证 提交 d7985052 编写于 作者: C ceci3 提交者: GitHub

fix bn/in/squeeze/syncbn extra (#35502)

* fix bn/in/squeeze/syncbn extra

* update bn

* update

* update
上级 3896bdbd
......@@ -257,13 +257,16 @@ void BatchNormOpMaker::Make() {
AddOutput("ReserveSpace",
"Reserve GPU space for triggering the new semi-persistent "
"NHWC kernel")
.AsDispensable();
.AsDispensable()
.AsExtra();
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
.SetDefault(false)
.AsExtra();
AddAttr<bool>("fuse_with_relu",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
.SetDefault(false)
.AsExtra();
AddAttr<bool>("use_global_stats",
"(bool, default false) Whether to use global mean and "
"variance. In inference or test mode, set use_global_stats "
......
......@@ -149,11 +149,13 @@ void InstanceNormOpMaker::Make() {
AddOutput("SavedMean",
"Mean of the current mini batch, "
"will apply to output when training")
.AsIntermediate();
.AsIntermediate()
.AsExtra();
AddOutput("SavedVariance",
"Variance of the current mini batch, "
"will apply to output when training")
.AsIntermediate();
.AsIntermediate()
.AsExtra();
AddComment(R"DOC(
Instance Normalization.
......
......@@ -47,7 +47,9 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
mkldnn::batch_normalization_backward>(
mkldnn_engine, ctx.GetPlace()) {
const float epsilon = ctx.Attr<float>("epsilon");
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
const bool fuse_with_relu = ctx.HasAttr("fuse_with_relu")
? ctx.Attr<bool>("fuse_with_relu")
: false;
std::vector<std::string> DataLayout_error_msg = {"kNHWC", "kNCHW",
"kAnyLayout", "kMKLDNN"};
......
......@@ -225,7 +225,7 @@ class Squeeze2Op : public framework::OperatorWithKernel {
ctx->ShareLoD("X", "Out");
}
OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Squeeze2");
if (!ctx->HasOutput("XShape")) return;
std::vector<int64_t> xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
......@@ -323,7 +323,8 @@ class Squeeze2OpMaker : public SqueezeOpMaker {
AddOutput("XShape",
"XShape is just used to store the shape and lod of X, which will "
"be used in SqueezeGradOp.")
.AsIntermediate();
.AsIntermediate()
.AsExtra();
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册