// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"

#include <string>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/common/data_type.h"

namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace {
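// Element-wise cast helper: copies a DenseTensor's payload from T1 to T2
// through a temporary CPU tensor. The fuse passes below use it to round-trip
// float16 weights through float32 so the folding math can run in fp32.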
template <typename T1, typename T2>
void ConvertTensorType(phi::DenseTensor* tensor) {
  phi::DenseTensor tmp_tensor;
  tmp_tensor.set_type(paddle::experimental::CppTypeToDataType<T2>::Type());
  tmp_tensor.Resize(tensor->dims());
  auto* tmp_data = tmp_tensor.mutable_data<T2>(paddle::platform::CPUPlace());
  auto* data = tensor->mutable_data<T1>(paddle::platform::CPUPlace());
  for (int i = 0; i < tensor->numel(); i++) {
    tmp_data[i] = static_cast<T2>(data[i]);
  }
  tensor->clear();
  paddle::framework::TensorCopySync(
      tmp_tensor, paddle::platform::CPUPlace(), tensor);
}
}  // namespace

namespace paddle {
namespace framework {
namespace ir {

#define GET_CONV_BN_NODES(pattern_name)                                      \
  /* OPERATORS */                                                            \
  GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name);                       \
  GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, pattern_name);           \
  /* CONV inputs */                                                          \
  GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name);         \
  /* CONV outputs */                                                         \
  GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name);               \
  /* BN inputs */                                                            \
  GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, pattern_name);               \
  GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, pattern_name);                 \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, pattern_name);                 \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, pattern_name);         \
  /* BN outputs */                                                           \
  GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, pattern_name); /* Out */         \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, pattern_name);         \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, pattern_name); \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name);     \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)

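// Folds the BN statistics into the convolution. With
//   BN(x) = bn_scale * (x - mean) / sqrt(variance + epsilon) + bn_bias,
// the per-output-channel multiplier is
//   alpha = bn_scale / sqrt(variance + epsilon),
// every conv weight of output channel c is scaled by alpha[c], and the new
// bias becomes (eltwise_y_in - mean) * alpha + bn_bias.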
void recompute_bias_and_weights(const Scope* scope,
                                ir::Node* conv_weight,                   //
                                const ir::Node& bn_scale,                //
                                const phi::DenseTensor& bn_bias_tensor,  //
                                const ir::Node& bn_mean,                 //
                                const ir::Node& bn_variance,             //
                                phi::DenseTensor* eltwise_y_in_tensor,   //
                                float epsilon,
                                const std::string& conv_type) {
  using EigenVectorArrayMap =
      Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
  using ConstEigenVectorArrayMap =
      Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
  using EigenMatrixArrayMap = Eigen::Map<
      Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

  // Re-compute bias of conv2d from BN
  PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(),
                    bn_bias_tensor.dims(),
                    platform::errors::InvalidArgument(
                        "phi::DenseTensor elementwise y(%d) and batch "
                        "norm bias(%d) must have the same dims.",
                        eltwise_y_in_tensor->dims().size(),
                        bn_bias_tensor.dims().size()));

  auto* scale_tensor =
      scope->FindVar(bn_scale.Name())->GetMutable<phi::DenseTensor>();
  auto* variance_tensor =
      scope->FindVar(bn_variance.Name())->GetMutable<phi::DenseTensor>();
  auto* mean_tensor =
      scope->FindVar(bn_mean.Name())->GetMutable<phi::DenseTensor>();

  ConstEigenVectorArrayMap scale_array(
      scale_tensor->data<float>(), scale_tensor->numel(), 1);
  EigenVectorArrayMap variance_array(
      variance_tensor->mutable_data<float>(platform::CPUPlace()),
      variance_tensor->numel(),
      1);
  ConstEigenVectorArrayMap mean_array(
      mean_tensor->data<float>(), mean_tensor->numel(), 1);
  ConstEigenVectorArrayMap bn_bias_array(
      bn_bias_tensor.data<float>(), bn_bias_tensor.numel(), 1);

  // variance is not needed afterwards, so reuse its buffer: first for the
  // standard deviation sqrt(variance + epsilon), then for the per-channel
  // multiplier scale / std
  variance_array += epsilon;
  variance_array = variance_array.sqrt();
  variance_array = scale_array / variance_array;
  for (int i = 0; i < variance_tensor->numel(); i++) {
    PADDLE_ENFORCE_EQ(std::isfinite(variance_array[i]),
                      true,
                      platform::errors::InvalidArgument(
                          "The inverse of Fused batch norm variance "
                          "should be finite. Found nonfinite values! "
                          "Please check %s ",
                          bn_variance.Name()));
  }
  EigenVectorArrayMap eltwise_y_in_array(
      eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
      eltwise_y_in_tensor->numel(),
      1);

  eltwise_y_in_array =
      ((eltwise_y_in_array - mean_array) * variance_array) + bn_bias_array;
  for (int i = 0; i < eltwise_y_in_tensor->numel(); i++) {
    PADDLE_ENFORCE_EQ(std::isfinite(eltwise_y_in_array[i]),
                      true,
                      platform::errors::InvalidArgument(
                          "Fused batch norm bias should be "
                          "finite. Found nonfinite values! "
                          "Please check %s and related variables.",
                          bn_variance.Name()));
  }

  // Re-compute weight of conv2d from BN
  auto* weights =
      scope->FindVar(conv_weight->Name())->GetMutable<phi::DenseTensor>();
  auto weights_shape = weights->dims();
  auto weights_data = weights->mutable_data<float>(platform::CPUPlace());

  // ConvTranspose weights are in IOHW format
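  // (dim 1 is the output-channel axis here, so each contiguous run of
  // kernel_size values below shares the multiplier of output channel j)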
  if (conv_type == "conv2d_transpose") {
    int kernel_size = weights_shape[2] * weights_shape[3];
    for (int i = 0; i < weights->numel();) {
      for (int j = 0; j < weights_shape[1]; ++j) {
        for (int k = 0; k < kernel_size; ++k, ++i) {
          weights_data[i] *= variance_array[j];
        }
      }
    }
  } else {
    auto weights_shape_2d = phi::flatten_to_2d(weights_shape, 1);

    EigenMatrixArrayMap weights_array_2d(
        weights_data, weights_shape_2d[0], weights_shape_2d[1]);

    weights_array_2d.colwise() *= variance_array;
  }
}

ConvBNFusePass::ConvBNFusePass() {
  AddOpCompat(OpCompat("conv2d"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .IsOptional()
      .End()
      .AddInput("ResidualData")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("paddings")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();
  AddOpCompat(OpCompat("fused_conv2d"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .IsOptional()
      .End()
      .AddInput("ResidualData")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("paddings")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();

  AddOpCompat(OpCompat("batch_norm"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Scale")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .End()
      .AddInput("Mean")
      .IsTensor()
      .End()
      .AddInput("Variance")
      .IsTensor()
      .End()
      .AddOutput("MeanOut")
      .IsTensor()
      .End()
      .AddOutput("VarianceOut")
      .IsTensor()
      .End()
      .AddOutput("SavedMean")
      .IsTensor()
      .End()
      .AddOutput("SavedVariance")
      .IsTensor()
      .End()
      .AddOutput("Y")
      .IsTensor()
      .End()
      .AddOutput("ReserveSpace")
      .IsTensor()
      .IsOptional()
      .End()
      .AddAttr("epsilon")
      .IsNumLE(0.001f)
      .IsNumGE(0.0f)
      .End();

  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsNumEQ(1)
      .End();
}

void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope_, graph);

  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input(conv_type(), "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bn_pattern(conv_input, conv_type(), false /*with_eltwise_add*/);

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    if (!IsCompat(subgraph, g)) {
      LOG(WARNING) << "Pass in op compat failed.";
      return;
    }
    VLOG(4) << "handle " + conv_type() + " BN fuse";
    // conv, batch_norm,
    // conv_weight, conv_out,
    // bn_scale, bn_bias, bn_mean, bn_variance,
    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,
    // bn_saved_variance
    GET_CONV_BN_NODES(conv_bn_pattern);

    // check if fuse can be done and if MKL-DNN should be used
    FuseOptions fuse_option = FindFuseOption(*conv, *batch_norm);
    if (fuse_option == DO_NOT_FUSE) {
      VLOG(3) << "do not perform " + conv_type() + " bn fuse";
      return;
    }

    // conv_weight fp16 --> fp32
    auto* conv_weight_tensor =
        scope->FindVar(conv_weight->Name())->GetMutable<phi::DenseTensor>();
    auto tensor_type = conv_weight_tensor->dtype();

    if (tensor_type == paddle::experimental::DataType::FLOAT16) {
      ConvertTensorType<float16, float>(conv_weight_tensor);
    }

    // Get batch norm bias
    auto* bn_bias_tensor =
        scope->FindVar(bn_bias->Name())->GetMutable<phi::DenseTensor>();

    float epsilon =
        PADDLE_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));

    bool is_mkldnn = fuse_option == FUSE_MKLDNN;
    auto input_names = conv->Op()->InputNames();
    bool has_bias = std::find(input_names.begin(), input_names.end(), "Bias") !=
                        input_names.end() &&
                    conv->Op()->Input("Bias").size() > 0;
    bool mkldnn_with_bias = is_mkldnn && has_bias;

    // Create eltwise_y (conv bias) variable
    phi::DenseTensor* eltwise_y_in_tensor;
    Node* eltwise_y_in_node;
    if (!mkldnn_with_bias) {
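      // No reusable conv bias exists, so create a fresh zero-filled tensor;
      // recompute_bias_and_weights() then fills it with the folded BN bias.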
      VarDesc eltwise_y_in_desc(
          patterns::PDNodeName("fuse_conv_bn", conv_type() + "_eltwise_y_in"));
      eltwise_y_in_desc.SetShape(phi::vectorize(bn_bias_tensor->dims()));
      eltwise_y_in_desc.SetDataType(
          framework::TransToProtoVarType(bn_bias_tensor->dtype()));
      eltwise_y_in_desc.SetLoDLevel(bn_bias->Var()->GetLoDLevel());
      eltwise_y_in_desc.SetPersistable(true);
      eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
      eltwise_y_in_tensor =
          scope->Var(eltwise_y_in_node->Name())->GetMutable<phi::DenseTensor>();

      // Initialize eltwise_y
      eltwise_y_in_tensor->Resize(bn_bias_tensor->dims());
      std::fill_n(
          eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
          eltwise_y_in_tensor->numel(),
          0.0f);

      // update weights and biases
      recompute_bias_and_weights(scope,
                                 conv_weight,
                                 *bn_scale,
                                 *bn_bias_tensor,
                                 *bn_mean,
                                 *bn_variance,
                                 eltwise_y_in_tensor,
                                 epsilon,
                                 conv_type());

      if (tensor_type == paddle::experimental::DataType::FLOAT16) {
        ConvertTensorType<float, float16>(conv_weight_tensor);
        ConvertTensorType<float, float16>(eltwise_y_in_tensor);
      }
    }

    // with MKL-DNN fuse conv+bn into conv with bias
    // without MKL-DNN fuse conv+bn into conv+elementwise_add
    if (is_mkldnn) {
      if (conv->Op()->Type() == "conv2d" ||
          conv->Op()->Type() == "depthwise_conv2d") {
        conv->Op()->SetType("fused_conv2d");
      }
      if (mkldnn_with_bias) {
        // reuse existing conv bias node
        auto conv_bias_names = conv->Op()->Input("Bias");
        PADDLE_ENFORCE_EQ(
            conv_bias_names.size(),
            1UL,
            phi::errors::InvalidArgument("Failed to find input variable Bias."));
        auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
        auto* conv_bias_tensor = conv_bias_var->GetMutable<phi::DenseTensor>();
        PADDLE_ENFORCE_EQ(conv_bias_tensor->dims(),
                          bn_bias_tensor->dims(),
                          phi::errors::InvalidArgument(
                              "phi::DenseTensor convolution bias(%d) and batch "
                              "normalization bias (%d) "
                              "must have the same dims.",
                              conv_bias_tensor->dims().size(),
                              bn_bias_tensor->dims().size()));

        recompute_bias_and_weights(scope,
                                   conv_weight,
                                   *bn_scale,
                                   *bn_bias_tensor,
                                   *bn_mean,
                                   *bn_variance,
                                   conv_bias_tensor,
                                   epsilon,
                                   conv_type());

        if (tensor_type == paddle::experimental::DataType::FLOAT16) {
          ConvertTensorType<float, float16>(conv_weight_tensor);
          ConvertTensorType<float, float16>(conv_bias_tensor);
        }

      } else {
        // add new conv_bias node
        conv->Op()->SetInput(
            "Bias", std::vector<std::string>({eltwise_y_in_node->Name()}));
        IR_NODE_LINK_TO(eltwise_y_in_node, conv);
      }
      conv->Op()->SetOutput("Output",
                            std::vector<std::string>({bn_out->Name()}));
      if (!IsCompat(*conv->Op())) {
        LOG(WARNING) << "conv_bn fuse pass in out conv op compat failed.";
        return;
      }
      GraphSafeRemoveNodes(graph,
                           {conv_out,
                            bn_scale,
                            bn_bias,
                            bn_mean,
                            bn_variance,
                            batch_norm,
                            bn_mean_out,
                            bn_variance_out,
                            bn_saved_mean,
                            bn_saved_variance});

      IR_NODE_LINK_TO(conv, bn_out);
      found_conv_bn_count++;
    } else {  // fuse_option == FUSE_NATIVE
              // create an elementwise add node.
      OpDesc desc;
      desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
      desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
      desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
      desc.SetType("elementwise_add");
      desc.SetAttr("axis", 1);
      if (!IsCompat(desc)) {
        LOG(WARNING)
            << "conv_bn fuse pass in out elementwise_add op compat failed.";
        return;
      }
      auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.

      GraphSafeRemoveNodes(graph,
                           {bn_scale,
                            bn_bias,
                            bn_mean,
                            bn_variance,
                            batch_norm,
                            bn_mean_out,
                            bn_variance_out,
                            bn_saved_mean,
                            bn_saved_variance});

      IR_NODE_LINK_TO(conv_out, eltwise_op);
      IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
      IR_NODE_LINK_TO(eltwise_op, bn_out);
      found_conv_bn_count++;
    }
  };

  gpd(graph, handler);

  AddStatis(found_conv_bn_count);
}

ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() {
  AddOpCompat(OpCompat("conv2d"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .IsOptional()
      .End()
      .AddInput("ResidualData")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("paddings")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("padding_algorithm")
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .IsOptional()
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();

  AddOpCompat(OpCompat("batch_norm"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Scale")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .End()
      .AddInput("Mean")
      .IsTensor()
      .End()
      .AddInput("Variance")
      .IsTensor()
      .End()
      .AddOutput("MeanOut")
      .IsTensor()
      .End()
      .AddOutput("VarianceOut")
      .IsTensor()
      .End()
      .AddOutput("SavedMean")
      .IsTensor()
      .End()
      .AddOutput("SavedVariance")
      .IsTensor()
      .End()
      .AddOutput("Y")
      .IsTensor()
      .End()
      .AddOutput("ReserveSpace")
      .IsTensor()
      .IsOptional()
      .End()
      .AddAttr("epsilon")
      .IsNumLE(0.001f)
      .IsNumGE(0.0f)
      .End();

  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsNumEQ(1)
      .End();
}

void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope_, graph);

  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input(conv_type(), "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bn_pattern(conv_input, conv_type(), true /*with_eltwise_add*/);

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    if (!IsCompat(subgraph, g)) {
      LOG(WARNING) << "Pass in op compat failed.";
      return;
    }
    VLOG(4) << "handle " + conv_type() + " BN fuse";
    // conv, batch_norm,
    // conv_weight, conv_out,
    // bn_scale, bn_bias, bn_mean, bn_variance,
    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,bn_saved_variance
    GET_CONV_BN_NODES(conv_bn_pattern);
    // OPERATORS
    GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bn_pattern);
    // BIAS inputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_bn_pattern);
    // BIAS outputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bn_pattern);

    // Get eltwise_y (conv bias) variable
    auto* eltwise_y_in_tensor =
        scope->FindVar(eltwise_y_in->Name())->GetMutable<phi::DenseTensor>();

    // Get batch norm bias
    auto* bn_bias_tensor =
        scope->FindVar(bn_bias->Name())->GetMutable<phi::DenseTensor>();

    // update weights and biases
    float epsilon =
        PADDLE_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));

    // conv_weight fp16 --> fp32
    auto* conv_weight_tensor =
        scope->FindVar(conv_weight->Name())->GetMutable<phi::DenseTensor>();
    auto tensor_type = conv_weight_tensor->dtype();

    if (tensor_type == paddle::experimental::DataType::FLOAT16) {
      ConvertTensorType<float16, float>(conv_weight_tensor);
      ConvertTensorType<float16, float>(eltwise_y_in_tensor);
    }

    // if bias is an input to other ops as well then we cannot overwrite it
    // so we create separate elementwise Y in nodes
    if (eltwise_y_in->outputs.size() > 1) {
      // Make a copy of eltwise Y input tensor
      // Create eltwise_y (conv bias) variable
      VarDesc eltwise_y_in_desc(patterns::PDNodeName(
          name_scope_, "eltwise_y_in" + std::to_string(found_conv_bn_count)));
      eltwise_y_in_desc.SetShape(phi::vectorize(eltwise_y_in_tensor->dims()));
      eltwise_y_in_desc.SetDataType(
          framework::TransToProtoVarType(eltwise_y_in_tensor->dtype()));
      eltwise_y_in_desc.SetLoDLevel(eltwise_y_in->Var()->GetLoDLevel());
      eltwise_y_in_desc.SetPersistable(true);
      auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
      auto* eltwise_y_in_tensor_ex =
          scope->Var(eltwise_y_in_node->Name())->GetMutable<phi::DenseTensor>();

      // Initialize eltwise_y with a copy of the original bias values
      TensorCopy(
          *eltwise_y_in_tensor, platform::CPUPlace(), eltwise_y_in_tensor_ex);

      recompute_bias_and_weights(scope,
                                 conv_weight,
                                 *bn_scale,
                                 *bn_bias_tensor,
                                 *bn_mean,
                                 *bn_variance,
                                 eltwise_y_in_tensor_ex,
                                 epsilon,
                                 conv_type());
      // Set new var
      eltwise->Op()->RenameInput(eltwise_y_in->Name(),
                                 eltwise_y_in_node->Name());
      // Link new bias node to eltwise
      IR_NODE_LINK_TO(eltwise_y_in_node, eltwise);
      // unlink original bias from eltwise_op
      eltwise_y_in->outputs.erase(
          std::remove_if(eltwise_y_in->outputs.begin(),
                         eltwise_y_in->outputs.end(),
                         [&](Node*& n) { return n->id() == eltwise->id(); }),
          eltwise_y_in->outputs.end());
    } else {
      recompute_bias_and_weights(scope,
                                 conv_weight,
                                 *bn_scale,
                                 *bn_bias_tensor,
                                 *bn_mean,
                                 *bn_variance,
                                 eltwise_y_in_tensor,
                                 epsilon,
                                 conv_type());
    }

    if (tensor_type == paddle::experimental::DataType::FLOAT16) {
      ConvertTensorType<float, float16>(conv_weight_tensor);
      ConvertTensorType<float, float16>(eltwise_y_in_tensor);
    }

    // Update the elementwise_add node
    eltwise->Op()->SetAttr("axis", 1);
    eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
    if (!IsCompat(*eltwise->Op())) {
      LOG(WARNING)
          << "conv_eltwise_bn fuse pass in out eltwise op compat failed.";
      return;
    }
    GraphSafeRemoveNodes(graph,
                         {bn_scale,
                          bn_bias,
                          bn_mean,
                          bn_variance,
                          batch_norm,
                          bn_mean_out,
                          bn_variance_out,
                          bn_saved_mean,
                          bn_saved_variance,
                          eltwise_out});

    IR_NODE_LINK_TO(eltwise, bn_out);

    found_conv_bn_count++;
  };

  gpd(graph, handler);

  AddStatis(found_conv_bn_count);
}

ConvTransposeBNFusePass::ConvTransposeBNFusePass() {
  AddOpCompat(OpCompat("conv2d_transpose"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("output_padding")
      .IsType<std::vector<int>>()
      .IsOptional()
      .End()
      .AddAttr("output_size")
      .IsType<std::vector<int>>()
      .IsOptional()
      .End()
      .AddAttr("groups")
      .IsNumEQ(1)
      .End()
      .AddAttr("dilations")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("strides")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("paddings")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "AnyLayout"})
      .End();
}

ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
  AddOpCompat(OpCompat("conv2d_transpose"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("output_padding")
      .IsType<std::vector<int>>()
      .IsOptional()
      .End()
      .AddAttr("output_size")
      .IsType<std::vector<int>>()
      .IsOptional()
      .End()
      .AddAttr("groups")
      .IsNumEQ(1)
      .End()
      .AddAttr("dilations")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("strides")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("paddings")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "AnyLayout"})
      .End();
}

DepthwiseConvBNFusePass::DepthwiseConvBNFusePass() {
  AddOpCompat(OpCompat("depthwise_conv2d"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .IsOptional()
      .End()
      .AddInput("ResidualData")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("paddings")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();
  AddOpCompat(OpCompat("fused_conv2d"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("Bias")
      .IsTensor()
      .IsOptional()
      .End()
      .AddInput("ResidualData")
      .IsTensor()
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("paddings")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .IsType<std::vector<int>>()
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(conv_bn_fuse_pass, paddle::framework::ir::ConvBNFusePass);
REGISTER_PASS(conv_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::ConvEltwiseAddBNFusePass);
REGISTER_PASS(conv_transpose_bn_fuse_pass,
              paddle::framework::ir::ConvTransposeBNFusePass);
REGISTER_PASS(conv_transpose_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::ConvTransposeEltwiseAddBNFusePass);
REGISTER_PASS(depthwise_conv_bn_fuse_pass,
              paddle::framework::ir::DepthwiseConvBNFusePass);
REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass);
REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d", 1)
            .EQ("batch_norm", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d", 1)
            .LE("elementwise_add", 1)
            .EQ("batch_norm", 0));
REGISTER_PASS_CAPABILITY(conv_transpose_eltwiseadd_bn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d_transpose", 2)
            .LE("elementwise_add", 1)
            .EQ("batch_norm", 0));
REGISTER_PASS_CAPABILITY(conv_transpose_bn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d_transpose", 2)
            .EQ("batch_norm", 0));
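
// A minimal usage sketch (illustrative only, not part of this file):
// inference clients typically enable these passes by name through the IR
// pass pipeline, e.g., assuming the standard AnalysisConfig API:
//   paddle::AnalysisConfig config;
//   config.pass_builder()->AppendPass("conv_bn_fuse_pass");
//   config.pass_builder()->AppendPass("conv_eltwiseadd_bn_fuse_pass");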