// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"

#include <string>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/data_type.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace {
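// Element-wise cast of a tensor's data from T1 to T2 on the CPU, replacing
// the tensor's buffer in place. Used below to widen fp16 conv weights to
// fp32 before BN folding, e.g. ConvertTensorType<float16, float>(tensor),
// and to narrow them back afterwards.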
template <typename T1, typename T2>
void ConvertTensorType(paddle::framework::LoDTensor* tensor) {
  paddle::framework::Tensor tmp_tensor;
  tmp_tensor.set_type(paddle::experimental::CppTypeToDataType<T2>::Type());
  tmp_tensor.Resize(tensor->dims());
  auto* tmp_data = tmp_tensor.mutable_data<T2>(paddle::platform::CPUPlace());
  auto* data = tensor->mutable_data<T1>(paddle::platform::CPUPlace());
  for (int i = 0; i < tensor->numel(); i++) {
    tmp_data[i] = static_cast<T2>(data[i]);
  }
  tensor->clear();
  paddle::framework::TensorCopySync(
      tmp_tensor, paddle::platform::CPUPlace(), tensor);
}
}  // namespace

namespace paddle {
namespace framework {
namespace ir {

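// Binds every node of a matched conv + batch_norm subgraph to a local
// variable of the same name: the two operators, the conv weight and output,
// and the BN inputs/outputs.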
#define GET_CONV_BN_NODES(pattern_name)                                      \
  /* OPERATORS */                                                            \
  GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name);                       \
  GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, pattern_name);           \
  /* CONV inputs */                                                          \
  GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name);         \
  /* CONV outputs */                                                         \
  GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name);               \
  /* BN inputs */                                                            \
  GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, pattern_name);               \
  GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, pattern_name);                 \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, pattern_name);                 \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, pattern_name);         \
  /* BN outputs */                                                           \
  GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, pattern_name); /* Out */         \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, pattern_name);         \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, pattern_name); \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name);     \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)

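// Folds the batch norm statistics into the convolution. With
//   alpha = bn_scale / sqrt(bn_variance + epsilon)
// the fused parameters are
//   conv_bias'   = (conv_bias - bn_mean) * alpha + bn_bias
//   conv_weight' = conv_weight * alpha  (broadcast per output channel)
// eltwise_y_in_tensor holds the conv bias on entry and the fused bias on
// return; the conv weights are updated in place in the scope.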
void recompute_bias_and_weights(const Scope* scope,
                                ir::Node* conv_weight,            //
                                const ir::Node& bn_scale,         //
                                const LoDTensor& bn_bias_tensor,  //
                                const ir::Node& bn_mean,          //
                                const ir::Node& bn_variance,      //
                                LoDTensor* eltwise_y_in_tensor,   //
                                float epsilon,
                                const std::string& conv_type) {
  using EigenVectorArrayMap =
      Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
  using ConstEigenVectorArrayMap =
      Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
  using EigenMatrixArrayMap = Eigen::Map<
      Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

  // Re-compute bias of conv2d from BN
  PADDLE_ENFORCE_EQ(
      eltwise_y_in_tensor->dims(),
      bn_bias_tensor.dims(),
      platform::errors::InvalidArgument("Tensor elementwise y(%d) and batch "
                                        "norm bias(%d) must have same dims.",
                                        eltwise_y_in_tensor->dims().size(),
                                        bn_bias_tensor.dims().size()));

  auto* scale_tensor = scope->FindVar(bn_scale.Name())->GetMutable<LoDTensor>();
  auto* variance_tensor =
      scope->FindVar(bn_variance.Name())->GetMutable<LoDTensor>();
  auto* mean_tensor = scope->FindVar(bn_mean.Name())->GetMutable<LoDTensor>();

  ConstEigenVectorArrayMap scale_array(
      scale_tensor->data<float>(), scale_tensor->numel(), 1);
  EigenVectorArrayMap variance_array(
      variance_tensor->mutable_data<float>(platform::CPUPlace()),
      variance_tensor->numel(),
      1);
  ConstEigenVectorArrayMap mean_array(
      mean_tensor->data<float>(), mean_tensor->numel(), 1);
  ConstEigenVectorArrayMap bn_bias_array(
      bn_bias_tensor.data<float>(), bn_bias_tensor.numel(), 1);

  // variance is not needed afterwards, so reuse its buffer: first as
  // sqrt(variance + epsilon), then as scale / sqrt(variance + epsilon)
  variance_array += epsilon;
  variance_array = variance_array.sqrt();
  variance_array = scale_array / variance_array;
  for (int i = 0; i < variance_tensor->numel(); i++) {
    PADDLE_ENFORCE_EQ(std::isfinite(variance_array[i]),
                      true,
                      platform::errors::InvalidArgument(
                          "The inverse of the fused batch norm variance "
                          "should be finite. Found nonfinite values! "
                          "Please check %s.",
                          bn_variance.Name()));
  }
  EigenVectorArrayMap eltwise_y_in_array(
      eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
      eltwise_y_in_tensor->numel(),
      1);

  eltwise_y_in_array =
      ((eltwise_y_in_array - mean_array) * variance_array) + bn_bias_array;
  for (int i = 0; i < eltwise_y_in_tensor->numel(); i++) {
    PADDLE_ENFORCE_EQ(std::isfinite(eltwise_y_in_array[i]),
                      true,
                      platform::errors::InvalidArgument(
                          "Fused batch norm bias should be "
                          "finite. Found nonfinite values! "
                          "Please check %s and related variables.",
                          bn_variance.Name()));
  }

  // Re-compute weight of conv2d from BN
  auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
  auto weights_shape = weights->dims();
  auto weights_data = weights->mutable_data<float>(platform::CPUPlace());

  // ConvTranspose weights are in IOHW format
  if (conv_type == "conv2d_transpose") {
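    // Walk the IOHW buffer: j indexes the output channel, so each
    // kernel_size-long slice is scaled by that channel's BN factor.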
    int kernel_size = weights_shape[2] * weights_shape[3];
    for (int i = 0; i < weights->numel();) {
      for (int j = 0; j < weights_shape[1]; ++j) {
        for (int k = 0; k < kernel_size; ++k, ++i) {
          weights_data[i] *= variance_array[j];
        }
      }
    }
  } else {
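    // Conv weights are in OIHW format: flatten to [O, I * H * W] and scale
    // each output-channel row by its BN factor.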
    auto weights_shape_2d = phi::flatten_to_2d(weights_shape, 1);

    EigenMatrixArrayMap weights_array_2d(
        weights_data, weights_shape_2d[0], weights_shape_2d[1]);

    weights_array_2d.colwise() *= variance_array;
  }
}

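// Op-compat guards: the fuse handlers below bail out unless the matched
// conv2d, batch_norm and elementwise_add ops satisfy these input/output and
// attribute constraints.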
ConvBNFusePass::ConvBNFusePass() {
  AddOpCompat(OpCompat("conv2d"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .IsOptional()
      .End()
      .AddInput("ResidualData")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("paddings")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();

  AddOpCompat(OpCompat("batch_norm"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Scale")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .End()
      .AddInput("Mean")
      .IsTensor()
      .End()
      .AddInput("Variance")
      .IsTensor()
      .End()
      .AddOutput("MeanOut")
      .IsTensor()
      .End()
      .AddOutput("VarianceOut")
      .IsTensor()
      .End()
      .AddOutput("SavedMean")
      .IsTensor()
      .End()
      .AddOutput("SavedVariance")
      .IsTensor()
      .End()
      .AddOutput("Y")
      .IsTensor()
      .End()
      .AddOutput("ReserveSpace")
      .IsTensor()
      .IsOptional()
      .End()
      .AddAttr("epsilon")
      .IsNumLE(0.001f)
      .IsNumGE(0.0f)
      .End();

  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsNumEQ(1)
      .End();
}

void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope_, graph);

  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input(conv_type(), "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bn_pattern(conv_input, conv_type(), false /*with_eltwise_add*/);

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    if (!IsCompat(subgraph, g)) {
      LOG(WARNING) << "Pass in op compat failed.";
      return;
    }
    VLOG(4) << "handle " + conv_type() + " BN fuse";
    // conv, batch_norm,
    // conv_weight, conv_out,
    // bn_scale, bn_bias, bn_mean, bn_variance,
    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,
    // bn_saved_variance
    GET_CONV_BN_NODES(conv_bn_pattern);

    // check if fuse can be done and if MKL-DNN should be used
    FuseOptions fuse_option = FindFuseOption(*conv, *batch_norm);
    if (fuse_option == DO_NOT_FUSE) {
      VLOG(3) << "do not perform " + conv_type() + " bn fuse";
      return;
    }

    // conv_weight fp16 --> fp32
    auto* conv_weight_tensor =
        scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
    auto tensor_type = conv_weight_tensor->dtype();

    if (tensor_type == paddle::experimental::DataType::FLOAT16) {
      ConvertTensorType<float16, float>(conv_weight_tensor);
    }

    // Get batch norm bias
    auto* bn_bias_tensor =
        scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();

    // Create eltwise_y (conv bias) variable
    VarDesc eltwise_y_in_desc(
        patterns::PDNodeName("fuse_conv_bn", conv_type() + "_eltwise_y_in"));
    eltwise_y_in_desc.SetShape(phi::vectorize(bn_bias_tensor->dims()));
    eltwise_y_in_desc.SetDataType(
        framework::TransToProtoVarType(bn_bias_tensor->dtype()));
    eltwise_y_in_desc.SetLoDLevel(bn_bias->Var()->GetLoDLevel());
    eltwise_y_in_desc.SetPersistable(true);
    auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
    auto* eltwise_y_in_tensor =
        scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();

    // Initialize eltwise_y
    eltwise_y_in_tensor->Resize(bn_bias_tensor->dims());
    std::fill_n(eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
                eltwise_y_in_tensor->numel(),
                0.0f);
                eltwise_y_in_tensor->numel(),
                0.0f);
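    // Starting from this zero bias, recompute_bias_and_weights() yields
    // (0 - bn_mean) * alpha + bn_bias as the new conv bias.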

    // update weights and biases
    float epsilon =
        PADDLE_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));
    recompute_bias_and_weights(scope,
                               conv_weight,
                               *bn_scale,
                               *bn_bias_tensor,
                               *bn_mean,
                               *bn_variance,
                               eltwise_y_in_tensor,
                               epsilon,
                               conv_type());

    if (tensor_type == paddle::experimental::DataType::FLOAT16) {
      ConvertTensorType<float, float16>(conv_weight_tensor);
      ConvertTensorType<float, float16>(eltwise_y_in_tensor);
    }

    // with MKL-DNN fuse conv+bn into conv with bias
    // without MKL-DNN fuse conv+bn into conv+elementwise_add
    if (fuse_option == FUSE_MKLDNN) {
      auto input_names = conv->Op()->InputNames();
      bool has_bias =
          std::find(input_names.begin(), input_names.end(), "Bias") !=
          input_names.end();
      if (has_bias && conv->Op()->Input("Bias").size() > 0) {
        // reuse existing conv bias node
        auto conv_bias_names = conv->Op()->Input("Bias");
        PADDLE_ENFORCE_EQ(
            conv_bias_names.size(),
            1UL,
            platform::errors::InvalidArgument(
                "The conv op should have exactly one input var Bias."));
        auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
        auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
        PADDLE_ENFORCE_EQ(
            conv_bias_tensor->dims(),
            eltwise_y_in_tensor->dims(),
            platform::errors::InvalidArgument(
                "Tensor convolution bias(%d) and elementwise y(%d) "
                "must have same dims.",
                conv_bias_tensor->dims().size(),
                eltwise_y_in_tensor->dims().size()));

        auto eigen_conv_bias = EigenVector<float>::From(*conv_bias_tensor);
        eigen_conv_bias += EigenVector<float>::From(*eltwise_y_in_tensor);
      } else {
        // add new conv_bias node
        conv->Op()->SetInput(
            "Bias", std::vector<std::string>({eltwise_y_in_node->Name()}));
        IR_NODE_LINK_TO(eltwise_y_in_node, conv);
      }
      conv->Op()->SetOutput("Output",
                            std::vector<std::string>({bn_out->Name()}));
      if (!IsCompat(*conv->Op())) {
        LOG(WARNING) << "conv_bn fuse pass in out conv op compat failed.";
        return;
      }
      GraphSafeRemoveNodes(graph,
                           {conv_out,
                            bn_scale,
                            bn_bias,
                            bn_mean,
                            bn_variance,
                            batch_norm,
                            bn_mean_out,
                            bn_variance_out,
                            bn_saved_mean,
                            bn_saved_variance});

      IR_NODE_LINK_TO(conv, bn_out);
      found_conv_bn_count++;
    } else {  // fuse_option == FUSE_NATIVE
      // create an elementwise add node.
      OpDesc desc;
      desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
      desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
      desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
      desc.SetType("elementwise_add");
      desc.SetAttr("axis", 1);
      if (!IsCompat(desc)) {
        LOG(WARNING)
            << "conv_bn fuse pass in out elementwise_add op compat failed.";
        return;
      }
      auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.

      GraphSafeRemoveNodes(graph,
                           {bn_scale,
                            bn_bias,
                            bn_mean,
                            bn_variance,
                            batch_norm,
                            bn_mean_out,
                            bn_variance_out,
                            bn_saved_mean,
                            bn_saved_variance});

      IR_NODE_LINK_TO(conv_out, eltwise_op);
      IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
      IR_NODE_LINK_TO(eltwise_op, bn_out);
      found_conv_bn_count++;
    }
  };

  gpd(graph, handler);

  AddStatis(found_conv_bn_count);
}

ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() {
  AddOpCompat(OpCompat("conv2d"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .IsOptional()
      .End()
      .AddInput("ResidualData")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("paddings")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("padding_algorithm")
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .IsOptional()
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();

  AddOpCompat(OpCompat("batch_norm"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Scale")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .End()
      .AddInput("Mean")
      .IsTensor()
      .End()
      .AddInput("Variance")
      .IsTensor()
      .End()
      .AddOutput("MeanOut")
      .IsTensor()
      .End()
      .AddOutput("VarianceOut")
      .IsTensor()
      .End()
      .AddOutput("SavedMean")
      .IsTensor()
      .End()
      .AddOutput("SavedVariance")
      .IsTensor()
      .End()
      .AddOutput("Y")
      .IsTensor()
      .End()
      .AddOutput("ReserveSpace")
      .IsTensor()
      .IsOptional()
      .End()
      .AddAttr("epsilon")
      .IsNumLE(0.001f)
      .IsNumGE(0.0f)
      .End();

  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsNumEQ(1)
      .End();
}

void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope_, graph);

  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input(conv_type(), "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bn_pattern(conv_input, conv_type(), true /*with_eltwise_add*/);

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    if (!IsCompat(subgraph, g)) {
      LOG(WARNING) << "Pass in op compat failed.";
      return;
    }
    VLOG(4) << "handle " + conv_type() + " BN fuse";
    // conv, batch_norm,
    // conv_weight, conv_out,
    // bn_scale, bn_bias, bn_mean, bn_variance,
    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance
    GET_CONV_BN_NODES(conv_bn_pattern);
    // OPERATORS
    GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bn_pattern);
    // BIAS inputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_bn_pattern);
    // BIAS outputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bn_pattern);

    // Get eltwise_y (conv bias) variable
    auto* eltwise_y_in_tensor =
        scope->FindVar(eltwise_y_in->Name())->GetMutable<LoDTensor>();

    // Get batch norm bias
    auto* bn_bias_tensor =
        scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();

    // update weights and biases
    float epsilon =
        PADDLE_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));

    // conv_weight fp16 --> fp32
    auto* conv_weight_tensor =
        scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
    auto tensor_type = conv_weight_tensor->dtype();

    if (tensor_type == paddle::experimental::DataType::FLOAT16) {
      ConvertTensorType<float16, float>(conv_weight_tensor);
      ConvertTensorType<float16, float>(eltwise_y_in_tensor);
    }

    // if bias is an input to other ops as well then we cannot overwrite it
    // so we create separate elementwise Y in nodes
    if (eltwise_y_in->outputs.size() > 1) {
      // Make a copy of eltwise Y input tensor
      // Create eltwise_y (conv bias) variable
      VarDesc eltwise_y_in_desc(patterns::PDNodeName(
          name_scope_, "eltwise_y_in" + std::to_string(found_conv_bn_count)));
      eltwise_y_in_desc.SetShape(phi::vectorize(eltwise_y_in_tensor->dims()));
      eltwise_y_in_desc.SetDataType(
          framework::TransToProtoVarType(eltwise_y_in_tensor->dtype()));
      eltwise_y_in_desc.SetLoDLevel(eltwise_y_in->Var()->GetLoDLevel());
      eltwise_y_in_desc.SetPersistable(true);
      auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
      auto* eltwise_y_in_tensor_ex =
          scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();

      // Initialize eltwise_y
      TensorCopy(
          *eltwise_y_in_tensor, platform::CPUPlace(), eltwise_y_in_tensor_ex);

      recompute_bias_and_weights(scope,
                                 conv_weight,
                                 *bn_scale,
                                 *bn_bias_tensor,
                                 *bn_mean,
                                 *bn_variance,
                                 eltwise_y_in_tensor_ex,
                                 epsilon,
                                 conv_type());
      // Set new var
      eltwise->Op()->RenameInput(eltwise_y_in->Name(),
                                 eltwise_y_in_node->Name());
      // Link new bias node to eltwise
      IR_NODE_LINK_TO(eltwise_y_in_node, eltwise);
      // unlink original bias from eltwise_op
      eltwise_y_in->outputs.erase(
          std::remove_if(eltwise_y_in->outputs.begin(),
                         eltwise_y_in->outputs.end(),
                         [&](Node*& n) {
                           return n->id() == eltwise->id() ? true : false;
                         }),
          eltwise_y_in->outputs.end());
    } else {
      recompute_bias_and_weights(scope,
                                 conv_weight,
                                 *bn_scale,
                                 *bn_bias_tensor,
                                 *bn_mean,
                                 *bn_variance,
                                 eltwise_y_in_tensor,
                                 epsilon,
                                 conv_type());
    }

    if (tensor_type == paddle::experimental::DataType::FLOAT16) {
      ConvertTensorType<float, float16>(conv_weight_tensor);
      ConvertTensorType<float, float16>(eltwise_y_in_tensor);
    }

    // Update the elementwise_add node
    eltwise->Op()->SetAttr("axis", 1);
    eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
    if (!IsCompat(*eltwise->Op())) {
      LOG(WARNING)
          << "conv_eltwise_bn fuse pass in out eltwise op compat failed.";
      return;
    }
    GraphSafeRemoveNodes(graph,
                         {bn_scale,
                          bn_bias,
                          bn_mean,
                          bn_variance,
                          batch_norm,
                          bn_mean_out,
                          bn_variance_out,
                          bn_saved_mean,
                          bn_saved_variance,
                          eltwise_out});

    IR_NODE_LINK_TO(eltwise, bn_out);

    found_conv_bn_count++;
  };

  gpd(graph, handler);

  AddStatis(found_conv_bn_count);
}

ConvTransposeBNFusePass::ConvTransposeBNFusePass() {
  AddOpCompat(OpCompat("conv2d_transpose"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("output_padding")
      .IsType<std::vector<int>>()
      .IsOptional()
      .End()
      .AddAttr("output_size")
      .IsType<std::vector<int>>()
      .IsOptional()
      .End()
      .AddAttr("groups")
      .IsNumEQ(1)
      .End()
      .AddAttr("dilations")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("strides")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("paddings")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "AnyLayout"})
      .End();
}

ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
  AddOpCompat(OpCompat("conv2d_transpose"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("output_padding")
      .IsType<std::vector<int>>()
      .IsOptional()
      .End()
      .AddAttr("output_size")
      .IsType<std::vector<int>>()
      .IsOptional()
      .End()
      .AddAttr("groups")
      .IsNumEQ(1)
      .End()
      .AddAttr("dilations")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("strides")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("paddings")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "AnyLayout"})
      .End();
}

DepthwiseConvBNFusePass::DepthwiseConvBNFusePass() {
  AddOpCompat(OpCompat("depthwise_conv2d"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .IsOptional()
      .End()
      .AddInput("ResidualData")
      .IsTensor()
W
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("paddings")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("padding_algorithm")
      .IsOptional()
806
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
W
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(conv_bn_fuse_pass, paddle::framework::ir::ConvBNFusePass);
REGISTER_PASS(conv_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::ConvEltwiseAddBNFusePass);
REGISTER_PASS(conv_transpose_bn_fuse_pass,
              paddle::framework::ir::ConvTransposeBNFusePass);
REGISTER_PASS(conv_transpose_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::ConvTransposeEltwiseAddBNFusePass);
REGISTER_PASS(depthwise_conv_bn_fuse_pass,
              paddle::framework::ir::DepthwiseConvBNFusePass);
REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass);
REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d", 1)
            .EQ("batch_norm", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d", 1)
            .LE("elementwise_add", 1)
            .EQ("batch_norm", 0));
REGISTER_PASS_CAPABILITY(conv_transpose_eltwiseadd_bn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d_transpose", 2)
            .LE("elementwise_add", 1)
            .EQ("batch_norm", 0));
REGISTER_PASS_CAPABILITY(conv_transpose_bn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d_transpose", 2)
            .EQ("batch_norm", 0));
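
// A minimal usage sketch (an illustration, not part of this file): a
// registered pass is fetched from the PassRegistry and applied to an
// ir::Graph whose parameter Scope has been attached, e.g.
//
//   auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
//       "conv_bn_fuse_pass");
//   graph.reset(pass->Apply(graph.release()));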