未验证 提交 90e9a486 编写于 作者: W wenbin 提交者: GitHub

CE fix (#38324)

* CE fix

* more format
上级 c9bc2758
......@@ -130,7 +130,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/})
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("affine_channel"))
......@@ -148,7 +148,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
.IsTensor()
.End()
.AddAttr("data_layout")
.IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/})
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
......@@ -197,6 +197,13 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
GET_CONV_BN_NODES(conv_ac_pattern);
auto data_format = conv->Op()->GetAttrIfExists<std::string>("data_format");
if (data_format == "AnyLayout") {
LOG_FIRST_N(WARNING, 1) << "conv_affine_channel_fuse_pass is enabled, "
"it's wrong if data_format of conv is not "
"NCHW.";
}
// Get affine_channel bias for resizing eltwise_y!
auto* ac_bias_tensor =
scope->FindVar(ac_bias->Name())->GetMutable<LoDTensor>();
......@@ -282,7 +289,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/})
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("affine_channel"))
.AddInput("X")
......@@ -299,7 +306,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
.IsTensor()
.End()
.AddAttr("data_layout")
.IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/})
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
......@@ -347,6 +354,12 @@ void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
VLOG(4) << "handle ConvBN fuse";
GET_CONV_BN_NODES(conv_ac_pattern);
auto data_format = conv->Op()->GetAttrIfExists<std::string>("data_format");
if (data_format == "AnyLayout") {
LOG_FIRST_N(WARNING, 1) << "conv_eltwiseadd_affine_channel_fuse_pass is "
"enabled, it's wrong if data_format of conv "
"is not NCHW.";
}
// OPERATORS
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_ac_pattern);
// BIAS inputs
......
......@@ -77,7 +77,7 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() {
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NHWC", "NCHW"})
.IsStringIn({"NHWC", "NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
......
......@@ -57,7 +57,7 @@ ConvElementwiseAddFusePass::ConvElementwiseAddFusePass() {
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/})
.IsStringIn({"NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
......@@ -97,6 +97,13 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
GET_NODES;
auto base_op_desc = *conv_op->Op()->Proto();
auto data_format =
conv_op->Op()->GetAttrIfExists<std::string>("data_format");
if (data_format == "AnyLayout") {
LOG_FIRST_N(WARNING, 1) << "conv_elementwise_add_fuse_pass is enabled, "
"it's wrong if data_format of conv is not "
"NCHW.";
}
std::string bias_name = elementwise_add_in_y->Name();
std::string output_name = elementwise_add_out->Name();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册