conv_bn_fuse_pass.cc 17.1 KB
Newer Older
S
Sylwester Fraczek 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"

#include <algorithm>
#include <cmath>
#include <string>
#include <vector>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/enforce.h"

// Forward declarations of the framework types used by the helpers below.
namespace paddle {
namespace framework {
class LoDTensor;
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {

// Declares one local Node* per node of a patterns::ConvBN match: the conv and
// batch_norm ops, the conv weight and conv output, the four BN parameter
// inputs (scale, bias, mean, variance) and the five BN outputs (out, mean_out,
// variance_out, saved_mean, saved_variance).
// `pattern_name` is the patterns::ConvBN instance the nodes are looked up in.
// The last statement has no trailing semicolon, so the call site supplies it.
// NOTE(review): GET_IR_NODE_FROM_SUBGRAPH presumably reads a `subgraph`
// variable in the enclosing scope (both call sites are inside handlers taking
// `subgraph`) — confirm against its definition.
#define GET_CONV_BN_NODES(pattern_name)                                      \
  /* OPERATORS */                                                            \
  GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name);                       \
  GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, pattern_name);           \
  /* CONV inputs */                                                          \
  GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name);         \
  /* CONV outputs */                                                         \
  GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name);               \
  /* BN inputs */                                                            \
  GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, pattern_name);               \
  GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, pattern_name);                 \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, pattern_name);                 \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, pattern_name);         \
  /* BN outputs */                                                           \
  GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, pattern_name); /* Out */         \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, pattern_name);         \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, pattern_name); \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name);     \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)

void recompute_bias_and_weights(const Scope* scope,
                                ir::Node* conv_weight,            //
                                const ir::Node& bn_scale,         //
                                const LoDTensor& bn_bias_tensor,  //
                                const ir::Node& bn_mean,          //
                                const ir::Node& bn_variance,      //
62
                                LoDTensor* eltwise_y_in_tensor,   //
63
                                float epsilon, const std::string& conv_type) {
64 65 66 67 68 69 70
  using EigenVectorArrayMap =
      Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
  using ConstEigenVectorArrayMap =
      Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
  using EigenMatrixArrayMap = Eigen::Map<
      Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

S
Sylwester Fraczek 已提交
71
  // Re-compute bias of conv2d from BN
72 73 74 75 76 77
  PADDLE_ENFORCE_EQ(
      eltwise_y_in_tensor->dims(), bn_bias_tensor.dims(),
      platform::errors::InvalidArgument("Tensor elementwise y(%d) and batch "
                                        "norm bias(%d) must have same dims.",
                                        eltwise_y_in_tensor->dims().size(),
                                        bn_bias_tensor.dims().size()));
S
Sylwester Fraczek 已提交
78 79 80 81 82 83

  auto* scale_tensor = scope->FindVar(bn_scale.Name())->GetMutable<LoDTensor>();
  auto* variance_tensor =
      scope->FindVar(bn_variance.Name())->GetMutable<LoDTensor>();
  auto* mean_tensor = scope->FindVar(bn_mean.Name())->GetMutable<LoDTensor>();

84 85 86 87 88 89 90 91 92
  ConstEigenVectorArrayMap scale_array(scale_tensor->data<float>(),
                                       scale_tensor->numel(), 1);
  EigenVectorArrayMap variance_array(
      variance_tensor->mutable_data<float>(platform::CPUPlace()),
      variance_tensor->numel(), 1);
  ConstEigenVectorArrayMap mean_array(mean_tensor->data<float>(),
                                      mean_tensor->numel(), 1);
  ConstEigenVectorArrayMap bn_bias_array(bn_bias_tensor.data<float>(),
                                         bn_bias_tensor.numel(), 1);
S
Sylwester Fraczek 已提交
93

94 95 96 97
  // variance will not be used anymore, so make it std_array and then tmp_array
  variance_array += epsilon;
  variance_array = variance_array.sqrt();
  variance_array = scale_array / variance_array;
98 99 100 101 102 103
  for (int i = 0; i < variance_tensor->numel(); i++) {
    PADDLE_ENFORCE_EQ(
        isfinite(variance_array[i]), true,
        platform::errors::InvalidArgument("fuse batch norm variance should be "
                                          "finite. Found nonfinite values!"));
  }
104 105 106
  EigenVectorArrayMap eltwise_y_in_array(
      eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
      eltwise_y_in_tensor->numel(), 1);
107

108 109
  eltwise_y_in_array =
      ((eltwise_y_in_array - mean_array) * variance_array) + bn_bias_array;
110 111 112 113 114 115
  for (int i = 0; i < eltwise_y_in_tensor->numel(); i++) {
    PADDLE_ENFORCE_EQ(
        isfinite(eltwise_y_in_array[i]), true,
        platform::errors::InvalidArgument("fused batch norm bias should be "
                                          "finite. Found nonfinite values!"));
  }
S
Sylwester Fraczek 已提交
116 117

  // Re-compute weight of conv2d from BN
118 119
  auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
  auto weights_shape = weights->dims();
120 121 122 123 124 125 126 127 128 129 130 131 132 133
  auto weights_data = weights->mutable_data<float>(platform::CPUPlace());

  // ConvTranspose weights are in IOHW format
  if (conv_type == "conv2d_transpose") {
    int kernel_size = weights_shape[2] * weights_shape[3];
    for (int i = 0; i < weights->numel();) {
      for (int j = 0; j < weights_shape[1]; ++j) {
        for (int k = 0; k < kernel_size; ++k, ++i) {
          weights_data[i] *= variance_array[j];
        }
      }
    }
  } else {
    auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
134

135 136
    EigenMatrixArrayMap weights_array_2d(weights_data, weights_shape_2d[0],
                                         weights_shape_2d[1]);
137

138 139
    weights_array_2d.colwise() *= variance_array;
  }
S
Sylwester Fraczek 已提交
140 141
}

142
// Folds batch_norm ops that directly follow a `conv_type()` op into that conv.
// For each match, FindFuseOption() decides whether to fuse and which variant
// to use:
//   * FUSE_MKLDNN: the folded bias becomes the conv's "Bias" input, or is
//     added to an existing one, and the conv writes straight to the BN output.
//   * FUSE_NATIVE: the batch_norm is replaced by an elementwise_add(axis=1)
//     of the conv output and the folded bias.
void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope_, graph);

  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input(conv_type(), "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bn_pattern(conv_input, conv_type(), false /*with_eltwise_add*/);

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle " + conv_type() + " BN fuse";

    // conv, batch_norm,
    // conv_weight, conv_out,
    // bn_scale, bn_bias, bn_mean, bn_variance,
    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,
    // bn_saved_variance
    GET_CONV_BN_NODES(conv_bn_pattern);

    // check if fuse can be done and if MKL-DNN should be used
    FuseOptions fuse_option = FindFuseOption(*conv, *batch_norm);
    if (fuse_option == DO_NOT_FUSE) {
      VLOG(3) << "do not perform " + conv_type() + " bn fuse";
      return;
    }

    // Get batch norm bias
    auto* bn_bias_var = scope->FindVar(bn_bias->Name());
    PADDLE_ENFORCE_NOT_NULL(
        bn_bias_var,
        platform::errors::NotFound(
            "Batch norm bias variable %s is not found in the scope.",
            bn_bias->Name()));
    auto* bn_bias_tensor = bn_bias_var->GetMutable<LoDTensor>();

    // Create eltwise_y (conv bias) variable
    VarDesc eltwise_y_in_desc(
        patterns::PDNodeName("fuse_conv_bn", conv_type() + "_eltwise_y_in"));
    eltwise_y_in_desc.SetShape(framework::vectorize(bn_bias_tensor->dims()));
    eltwise_y_in_desc.SetDataType(bn_bias_tensor->type());
    eltwise_y_in_desc.SetLoDLevel(bn_bias->Var()->GetLoDLevel());
    eltwise_y_in_desc.SetPersistable(true);
    auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
    auto* eltwise_y_in_tensor =
        scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();

    // Initialize eltwise_y to zeros; recompute_bias_and_weights() below turns
    // it into the folded bias.
    eltwise_y_in_tensor->Resize(bn_bias_tensor->dims());
    std::fill_n(eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
                eltwise_y_in_tensor->numel(), 0.0f);

    // update weights and biases
    float epsilon =
        BOOST_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));
    recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                               *bn_mean, *bn_variance, eltwise_y_in_tensor,
                               epsilon, conv_type());

    // with MKL-DNN fuse conv+bn into conv with bias
    // without MKL-DNN fuse conv+bn into conv+elementwise_add
    if (fuse_option == FUSE_MKLDNN) {
      auto input_names = conv->Op()->InputNames();
      bool has_bias = std::find(input_names.begin(), input_names.end(),
                                "Bias") != input_names.end();
      if (has_bias && conv->Op()->Input("Bias").size() > 0) {
        // reuse existing conv bias node
        auto conv_bias_names = conv->Op()->Input("Bias");
        PADDLE_ENFORCE_EQ(
            conv_bias_names.size(), 1UL,
            platform::errors::InvalidArgument("Find input var Bias error."));
        auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
        PADDLE_ENFORCE_NOT_NULL(
            conv_bias_var,
            platform::errors::NotFound(
                "Convolution bias variable %s is not found in the scope.",
                conv_bias_names[0]));
        auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
        PADDLE_ENFORCE_EQ(
            conv_bias_tensor->dims(), eltwise_y_in_tensor->dims(),
            platform::errors::InvalidArgument(
                "Tensor convolution bias(%d) and elementwise y(%d) "
                "must have same dims.",
                conv_bias_tensor->dims().size(),
                eltwise_y_in_tensor->dims().size()));

        auto eigen_conv_bias = EigenVector<float>::From(*conv_bias_tensor);
        eigen_conv_bias += EigenVector<float>::From(*eltwise_y_in_tensor);

        // The folded bias was accumulated into the existing one, so the node
        // created above is never linked. Remove it rather than leave an
        // isolated persistable variable node in the graph.
        GraphSafeRemoveNodes(graph, {eltwise_y_in_node});
      } else {
        // add new conv_bias node
        conv->Op()->SetInput(
            "Bias", std::vector<std::string>({eltwise_y_in_node->Name()}));
        IR_NODE_LINK_TO(eltwise_y_in_node, conv);
      }
      conv->Op()->SetOutput("Output",
                            std::vector<std::string>({bn_out->Name()}));
      GraphSafeRemoveNodes(
          graph,
          {conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
           bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance});

      IR_NODE_LINK_TO(conv, bn_out);
      found_conv_bn_count++;
    } else {  // fuse_option == FUSE_NATIVE
      // create an elementwise add node.
      OpDesc desc;
      desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
      desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
      desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
      desc.SetType("elementwise_add");
      desc.SetAttr("axis", 1);
      auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.

      GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance,
                                   batch_norm, bn_mean_out, bn_variance_out,
                                   bn_saved_mean, bn_saved_variance});

      IR_NODE_LINK_TO(conv_out, eltwise_op);
      IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
      IR_NODE_LINK_TO(eltwise_op, bn_out);
      found_conv_bn_count++;
    }
  };

  gpd(graph, handler);

  AddStatis(found_conv_bn_count);
}

// Folds batch_norm ops that follow a `conv_type()` + elementwise_add pair.
// The BN parameters are folded into the conv weights and into the
// elementwise_add's Y input (the conv bias), and the elementwise_add is made
// to write directly to the batch_norm output. If that Y variable also feeds
// other ops, a private copy is created so the shared original is not modified.
void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope_, graph);

  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input(conv_type(), "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bn_pattern(conv_input, conv_type(), true /*with_eltwise_add*/);

  // Erases from `nodes` every entry whose id equals target's id.
  auto erase_node = [](std::vector<Node*>* nodes, const Node* target) {
    nodes->erase(std::remove_if(nodes->begin(), nodes->end(),
                                [target](const Node* n) {
                                  return n->id() == target->id();
                                }),
                 nodes->end());
  };

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle " + conv_type() + " BN fuse";

    // conv, batch_norm,
    // conv_weight, conv_out,
    // bn_scale, bn_bias, bn_mean, bn_variance,
    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,bn_saved_variance
    GET_CONV_BN_NODES(conv_bn_pattern);
    // OPERATORS
    GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bn_pattern);
    // BIAS inputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_bn_pattern);
    // BIAS outputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bn_pattern);

    // Get eltwise_y (conv bias) variable
    auto* eltwise_y_in_var = scope->FindVar(eltwise_y_in->Name());
    PADDLE_ENFORCE_NOT_NULL(
        eltwise_y_in_var,
        platform::errors::NotFound(
            "Elementwise y variable %s is not found in the scope.",
            eltwise_y_in->Name()));
    auto* eltwise_y_in_tensor = eltwise_y_in_var->GetMutable<LoDTensor>();

    // Get batch norm bias
    auto* bn_bias_var = scope->FindVar(bn_bias->Name());
    PADDLE_ENFORCE_NOT_NULL(
        bn_bias_var,
        platform::errors::NotFound(
            "Batch norm bias variable %s is not found in the scope.",
            bn_bias->Name()));
    auto* bn_bias_tensor = bn_bias_var->GetMutable<LoDTensor>();

    // update weights and biases
    float epsilon =
        BOOST_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));

    // if bias is an input to other ops as well then we cannot overwrite it
    // so we create separate elementwise Y in nodes
    if (eltwise_y_in->outputs.size() > 1) {
      // Make a copy of eltwise Y input tensor
      // Create eltwise_y (conv bias) variable
      VarDesc eltwise_y_in_desc(patterns::PDNodeName(
          name_scope_, "eltwise_y_in" + std::to_string(found_conv_bn_count)));
      eltwise_y_in_desc.SetShape(
          framework::vectorize(eltwise_y_in_tensor->dims()));
      eltwise_y_in_desc.SetDataType(eltwise_y_in_tensor->type());
      eltwise_y_in_desc.SetLoDLevel(eltwise_y_in->Var()->GetLoDLevel());
      eltwise_y_in_desc.SetPersistable(true);
      auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
      auto* eltwise_y_in_tensor_ex =
          scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();

      // Initialize eltwise_y
      TensorCopy(*eltwise_y_in_tensor, platform::CPUPlace(),
                 eltwise_y_in_tensor_ex);

      recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                                 *bn_mean, *bn_variance, eltwise_y_in_tensor_ex,
                                 epsilon, conv_type());
      // Set new var
      eltwise->Op()->RenameInput(eltwise_y_in->Name(),
                                 eltwise_y_in_node->Name());
      // Link new bias node to eltwise
      IR_NODE_LINK_TO(eltwise_y_in_node, eltwise);
      // Unlink the original bias from eltwise_op in BOTH directions. Leaving it
      // in eltwise->inputs would keep a one-sided edge to a variable the op no
      // longer reads after RenameInput above.
      erase_node(&eltwise_y_in->outputs, eltwise);
      erase_node(&eltwise->inputs, eltwise_y_in);
    } else {
      recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                                 *bn_mean, *bn_variance, eltwise_y_in_tensor,
                                 epsilon, conv_type());
    }

    // Update the elementwise_add node
    eltwise->Op()->SetAttr("axis", 1);
    eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));

    GraphSafeRemoveNodes(
        graph,
        {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
         bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});

    IR_NODE_LINK_TO(eltwise, bn_out);

    found_conv_bn_count++;
  };

  gpd(graph, handler);

  AddStatis(found_conv_bn_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Conv + batch_norm fuse passes for plain, transposed and depthwise convs,
// each with and without an elementwise_add between the conv and the BN.
REGISTER_PASS(conv_bn_fuse_pass, paddle::framework::ir::ConvBNFusePass);
REGISTER_PASS(conv_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::ConvEltwiseAddBNFusePass);
REGISTER_PASS(conv_transpose_bn_fuse_pass,
              paddle::framework::ir::ConvTransposeBNFusePass);
REGISTER_PASS(conv_transpose_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::ConvTransposeEltwiseAddBNFusePass);
REGISTER_PASS(depthwise_conv_bn_fuse_pass,
              paddle::framework::ir::DepthwiseConvBNFusePass);
REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass);

// Op-version combinations declared for the two conv2d-based passes.
REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d", 1)
            .EQ("batch_norm", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d", 1)
            .LE("elementwise_add", 1)
            .EQ("batch_norm", 0));