未验证 提交 90e9a486 编写于 作者: W wenbin 提交者: GitHub

CE fix (#38324)

* CE fix

* more format
上级 c9bc2758
...@@ -130,7 +130,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() { ...@@ -130,7 +130,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
.IsType<std::vector<int>>() .IsType<std::vector<int>>()
.End() .End()
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) .IsStringIn({"NCHW", "AnyLayout"})
.End(); .End();
AddOpCompat(OpCompat("affine_channel")) AddOpCompat(OpCompat("affine_channel"))
...@@ -148,7 +148,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() { ...@@ -148,7 +148,7 @@ ConvAffineChannelFusePass::ConvAffineChannelFusePass() {
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("data_layout") .AddAttr("data_layout")
.IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) .IsStringIn({"NCHW", "AnyLayout"})
.End(); .End();
AddOpCompat(OpCompat("elementwise_add")) AddOpCompat(OpCompat("elementwise_add"))
...@@ -197,6 +197,13 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -197,6 +197,13 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
GET_CONV_BN_NODES(conv_ac_pattern); GET_CONV_BN_NODES(conv_ac_pattern);
auto data_format = conv->Op()->GetAttrIfExists<std::string>("data_format");
if (data_format == "AnyLayout") {
LOG_FIRST_N(WARNING, 1) << "conv_affine_channel_fuse_pass is enabled, "
"it's wrong if data_format of conv is not "
"NCHW.";
}
// Get affine_channel bias for resizing eltwise_y! // Get affine_channel bias for resizing eltwise_y!
auto* ac_bias_tensor = auto* ac_bias_tensor =
scope->FindVar(ac_bias->Name())->GetMutable<LoDTensor>(); scope->FindVar(ac_bias->Name())->GetMutable<LoDTensor>();
...@@ -282,7 +289,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() { ...@@ -282,7 +289,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
.IsType<std::vector<int>>() .IsType<std::vector<int>>()
.End() .End()
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) .IsStringIn({"NCHW", "AnyLayout"})
.End(); .End();
AddOpCompat(OpCompat("affine_channel")) AddOpCompat(OpCompat("affine_channel"))
.AddInput("X") .AddInput("X")
...@@ -299,7 +306,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() { ...@@ -299,7 +306,7 @@ ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() {
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("data_layout") .AddAttr("data_layout")
.IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) .IsStringIn({"NCHW", "AnyLayout"})
.End(); .End();
AddOpCompat(OpCompat("elementwise_add")) AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X") .AddInput("X")
...@@ -347,6 +354,12 @@ void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -347,6 +354,12 @@ void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
VLOG(4) << "handle ConvBN fuse"; VLOG(4) << "handle ConvBN fuse";
GET_CONV_BN_NODES(conv_ac_pattern); GET_CONV_BN_NODES(conv_ac_pattern);
auto data_format = conv->Op()->GetAttrIfExists<std::string>("data_format");
if (data_format == "AnyLayout") {
LOG_FIRST_N(WARNING, 1) << "conv_eltwiseadd_affine_channel_fuse_pass is "
"enabled, it's wrong if data_format of conv "
"is not NCHW.";
}
// OPERATORS // OPERATORS
GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_ac_pattern); GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_ac_pattern);
// BIAS inputs // BIAS inputs
......
...@@ -77,7 +77,7 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() { ...@@ -77,7 +77,7 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() {
.AddAttr("dilations") .AddAttr("dilations")
.End() .End()
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NHWC", "NCHW"}) .IsStringIn({"NHWC", "NCHW", "AnyLayout"})
.End(); .End();
AddOpCompat(OpCompat("elementwise_add")) AddOpCompat(OpCompat("elementwise_add"))
......
...@@ -57,7 +57,7 @@ ConvElementwiseAddFusePass::ConvElementwiseAddFusePass() { ...@@ -57,7 +57,7 @@ ConvElementwiseAddFusePass::ConvElementwiseAddFusePass() {
.AddAttr("dilations") .AddAttr("dilations")
.End() .End()
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NCHW" /*, "NHWC", "AnyLayout"*/}) .IsStringIn({"NCHW", "AnyLayout"})
.End(); .End();
AddOpCompat(OpCompat("elementwise_add")) AddOpCompat(OpCompat("elementwise_add"))
...@@ -97,6 +97,13 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -97,6 +97,13 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
GET_NODES; GET_NODES;
auto base_op_desc = *conv_op->Op()->Proto(); auto base_op_desc = *conv_op->Op()->Proto();
auto data_format =
conv_op->Op()->GetAttrIfExists<std::string>("data_format");
if (data_format == "AnyLayout") {
LOG_FIRST_N(WARNING, 1) << "conv_elementwise_add_fuse_pass is enabled, "
"it's wrong if data_format of conv is not "
"NCHW.";
}
std::string bias_name = elementwise_add_in_y->Name(); std::string bias_name = elementwise_add_in_y->Name();
std::string output_name = elementwise_add_out->Name(); std::string output_name = elementwise_add_out->Name();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册