// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"

#include <string>
#include <vector>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
class LoDTensor;
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {

#define GET_CONV_BN_NODES(pattern_name)                                      \
  /* OPERATORS */                                                            \
  GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name);                       \
  GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, pattern_name);           \
  /* CONV inputs */                                                          \
  GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name);         \
  /* CONV outputs */                                                         \
  GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name);               \
  /* BN inputs */                                                            \
  GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, pattern_name);               \
  GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, pattern_name);                 \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, pattern_name);                 \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, pattern_name);         \
  /* BN outputs */                                                           \
  GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, pattern_name); /* Out */         \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, pattern_name);         \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, pattern_name); \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name);     \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)

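// Folds the batch_norm that follows a conv into the conv itself. With BN
// computing y = scale * (x - mean) / sqrt(variance + epsilon) + bn_bias,
// the fused per-output-channel parameters are, in effect:
//   W'    = W * scale / sqrt(variance + epsilon)
//   bias' = (bias - mean) * scale / sqrt(variance + epsilon) + bn_bias
// where bias is the incoming eltwise y tensor (zero-filled when the conv
// had no bias to start from).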
void recompute_bias_and_weights(const Scope* scope,
                                ir::Node* conv_weight,            //
                                const ir::Node& bn_scale,         //
                                const LoDTensor& bn_bias_tensor,  //
                                const ir::Node& bn_mean,          //
                                const ir::Node& bn_variance,      //
                                LoDTensor* eltwise_y_in_tensor,   //
                                float epsilon, const std::string& conv_type) {
  using EigenVectorArrayMap =
      Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
  using ConstEigenVectorArrayMap =
      Eigen::Map<const Eigen::Array<float, Eigen::Dynamic, 1>>;
  using EigenMatrixArrayMap = Eigen::Map<
      Eigen::Array<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

  // Re-compute bias of conv2d from BN
  PADDLE_ENFORCE_EQ(
      eltwise_y_in_tensor->dims(), bn_bias_tensor.dims(),
      platform::errors::InvalidArgument("Tensor elementwise y(%d) and batch "
                                        "norm bias(%d) must have same dims.",
                                        eltwise_y_in_tensor->dims().size(),
                                        bn_bias_tensor.dims().size()));

  auto* scale_tensor = scope->FindVar(bn_scale.Name())->GetMutable<LoDTensor>();
  auto* variance_tensor =
      scope->FindVar(bn_variance.Name())->GetMutable<LoDTensor>();
  auto* mean_tensor = scope->FindVar(bn_mean.Name())->GetMutable<LoDTensor>();

  ConstEigenVectorArrayMap scale_array(scale_tensor->data<float>(),
                                       scale_tensor->numel(), 1);
  EigenVectorArrayMap variance_array(
      variance_tensor->mutable_data<float>(platform::CPUPlace()),
      variance_tensor->numel(), 1);
  ConstEigenVectorArrayMap mean_array(mean_tensor->data<float>(),
                                      mean_tensor->numel(), 1);
  ConstEigenVectorArrayMap bn_bias_array(bn_bias_tensor.data<float>(),
                                         bn_bias_tensor.numel(), 1);

  // variance is not needed in its raw form afterwards, so reuse the buffer:
  // first as the std array sqrt(variance + epsilon), then as the
  // per-channel ratio scale / std applied to both bias and weights
  variance_array += epsilon;
  variance_array = variance_array.sqrt();
  variance_array = scale_array / variance_array;

  EigenVectorArrayMap eltwise_y_in_array(
      eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
      eltwise_y_in_tensor->numel(), 1);

  eltwise_y_in_array =
      ((eltwise_y_in_array - mean_array) * variance_array) + bn_bias_array;

  // Re-compute weight of conv2d from BN
  auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
  auto weights_shape = weights->dims();
  auto weights_data = weights->mutable_data<float>(platform::CPUPlace());

  // ConvTranspose weights are in IOHW format: the output-channel index is
  // dimension 1, so walk the flat buffer in [I][O][H*W] order and scale
  // every element of output channel j by variance_array[j]
  if (conv_type == "conv2d_transpose") {
    int kernel_size = weights_shape[2] * weights_shape[3];
    for (int i = 0; i < weights->numel();) {
      for (int j = 0; j < weights_shape[1]; ++j) {
        for (int k = 0; k < kernel_size; ++k, ++i) {
          weights_data[i] *= variance_array[j];
        }
      }
    }
  } else {
    auto weights_shape_2d = flatten_to_2d(weights_shape, 1);

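    // conv2d-like weights are OIHW, so flattening at axis 1 gives a
    // [out_channels, in_channels * kh * kw] view; the colwise() broadcast
    // below scales row i (output channel i) by variance_array[i]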
    EigenMatrixArrayMap weights_array_2d(weights_data, weights_shape_2d[0],
                                         weights_shape_2d[1]);

    weights_array_2d.colwise() *= variance_array;
  }
}

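// Detects the conv -> batch_norm subgraph and folds the BN into the conv:
// with MKL-DNN the recomputed bias becomes the conv's Bias input, otherwise
// a new elementwise_add op is created to carry it.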
void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope_, graph);

  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input(conv_type(), "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bn_pattern(conv_input, conv_type(), false /*with_eltwise_add*/);

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle " + conv_type() + " BN fuse";

    // conv, batch_norm,
    // conv_weight, conv_out,
    // bn_scale, bn_bias, bn_mean, bn_variance,
    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,
    // bn_saved_variance
    GET_CONV_BN_NODES(conv_bn_pattern);

    // check if fuse can be done and if MKL-DNN should be used
    FuseOptions fuse_option = FindFuseOption(*conv, *batch_norm);
    if (fuse_option == DO_NOT_FUSE) {
      VLOG(3) << "do not perform " + conv_type() + " bn fuse";
      return;
    }

    // Get batch norm bias
    auto* bn_bias_tensor =
        scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();

    // Create eltwise_y (conv bias) variable
    VarDesc eltwise_y_in_desc(
        patterns::PDNodeName("fuse_conv_bn", conv_type() + "_eltwise_y_in"));
    eltwise_y_in_desc.SetShape(framework::vectorize(bn_bias_tensor->dims()));
    eltwise_y_in_desc.SetDataType(bn_bias_tensor->type());
    eltwise_y_in_desc.SetLoDLevel(bn_bias->Var()->GetLoDLevel());
    eltwise_y_in_desc.SetPersistable(true);
    auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
    auto* eltwise_y_in_tensor =
        scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();

    // Initialize eltwise_y
    eltwise_y_in_tensor->Resize(bn_bias_tensor->dims());
    std::fill_n(eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
                eltwise_y_in_tensor->numel(), 0.0f);

    // update weights and biases
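    // (with the zero-filled eltwise_y, the recomputed bias evaluates to
    //  bn_bias - mean * scale / sqrt(variance + epsilon))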
    float epsilon =
        BOOST_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));
    recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                               *bn_mean, *bn_variance, eltwise_y_in_tensor,
                               epsilon, conv_type());

    // with MKL-DNN fuse conv+bn into conv with bias
    // without MKL-DNN fuse conv+bn into conv+elementwise_add
    if (fuse_option == FUSE_MKLDNN) {
      auto input_names = conv->Op()->InputNames();
      bool has_bias = std::find(input_names.begin(), input_names.end(),
                                "Bias") != input_names.end();
      if (has_bias && conv->Op()->Input("Bias").size() > 0) {
        // reuse existing conv bias node
        auto conv_bias_names = conv->Op()->Input("Bias");
        PADDLE_ENFORCE_EQ(
            conv_bias_names.size(), 1UL,
            platform::errors::InvalidArgument(
                "Expected exactly one input var Bias."));
        auto* conv_bias_var = scope->FindVar(conv_bias_names[0]);
        auto* conv_bias_tensor = conv_bias_var->GetMutable<LoDTensor>();
        PADDLE_ENFORCE_EQ(
            conv_bias_tensor->dims(), eltwise_y_in_tensor->dims(),
            platform::errors::InvalidArgument(
                "Tensor convolution bias(%d) and elementwise y(%d) "
                "must have same dims.",
                conv_bias_tensor->dims().size(),
                eltwise_y_in_tensor->dims().size()));

        auto eigen_conv_bias = EigenVector<float>::From(*conv_bias_tensor);
        eigen_conv_bias += EigenVector<float>::From(*eltwise_y_in_tensor);
      } else {
        // add new conv_bias node
        conv->Op()->SetInput(
            "Bias", std::vector<std::string>({eltwise_y_in_node->Name()}));
        IR_NODE_LINK_TO(eltwise_y_in_node, conv);
      }
      conv->Op()->SetOutput("Output",
                            std::vector<std::string>({bn_out->Name()}));
      GraphSafeRemoveNodes(
          graph,
          {conv_out, bn_scale, bn_bias, bn_mean, bn_variance, batch_norm,
           bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance});

      IR_NODE_LINK_TO(conv, bn_out);
      found_conv_bn_count++;
    } else {  // fuse_option == FUSE_NATIVE
      // create an elementwise add node.
      OpDesc desc;
      desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
      desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
      desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
      desc.SetType("elementwise_add");
      desc.SetAttr("axis", 1);
      auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.

      GraphSafeRemoveNodes(graph, {bn_scale, bn_bias, bn_mean, bn_variance,
                                   batch_norm, bn_mean_out, bn_variance_out,
                                   bn_saved_mean, bn_saved_variance});

      IR_NODE_LINK_TO(conv_out, eltwise_op);
      IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
      IR_NODE_LINK_TO(eltwise_op, bn_out);
      found_conv_bn_count++;
    }
  };

  gpd(graph, handler);

  AddStatis(found_conv_bn_count);
}

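// Variant for the conv -> elementwise_add -> batch_norm subgraph: BN is
// folded into the existing conv weights and eltwise bias, so no new op is
// needed; a bias shared with other ops is copied first so the original
// tensor is left untouched.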
void ConvEltwiseAddBNFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(name_scope_, graph);

  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input(conv_type(), "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bn_pattern(conv_input, conv_type(), true /*with_eltwise_add*/);

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle " + conv_type() + " BN fuse";

    // conv, batch_norm,
    // conv_weight, conv_out,
    // bn_scale, bn_bias, bn_mean, bn_variance,
    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance
    GET_CONV_BN_NODES(conv_bn_pattern);
    // OPERATORS
    GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bn_pattern);
    // BIAS inputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_bn_pattern);
    // BIAS outputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bn_pattern);

    // Get eltwise_y (conv bias) variable
    auto* eltwise_y_in_tensor =
        scope->FindVar(eltwise_y_in->Name())->GetMutable<LoDTensor>();

    // Get batch norm bias
    auto* bn_bias_tensor =
        scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();

    // update weights and biases
    float epsilon =
        BOOST_GET_CONST(float, batch_norm->Op()->GetAttr("epsilon"));

    // if the bias is also an input to other ops we must not overwrite it,
    // so create a separate elementwise y node instead
    if (eltwise_y_in->outputs.size() > 1) {
      // Make a copy of eltwise Y input tensor
      // Create eltwise_y (conv bias) variable
      VarDesc eltwise_y_in_desc(patterns::PDNodeName(
          name_scope_, "eltwise_y_in" + std::to_string(found_conv_bn_count)));
      eltwise_y_in_desc.SetShape(
          framework::vectorize(eltwise_y_in_tensor->dims()));
      eltwise_y_in_desc.SetDataType(eltwise_y_in_tensor->type());
      eltwise_y_in_desc.SetLoDLevel(eltwise_y_in->Var()->GetLoDLevel());
      eltwise_y_in_desc.SetPersistable(true);
      auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
      auto* eltwise_y_in_tensor_ex =
          scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();

      // Initialize eltwise_y
      TensorCopy(*eltwise_y_in_tensor, platform::CPUPlace(),
                 eltwise_y_in_tensor_ex);

      recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                                 *bn_mean, *bn_variance, eltwise_y_in_tensor_ex,
                                 epsilon, conv_type());
      // Set new var
      eltwise->Op()->RenameInput(eltwise_y_in->Name(),
                                 eltwise_y_in_node->Name());
      // Link new bias node to eltwise
      IR_NODE_LINK_TO(eltwise_y_in_node, eltwise);
      // unlink original bias from eltwise_op
      eltwise_y_in->outputs.erase(
          std::remove_if(eltwise_y_in->outputs.begin(),
                         eltwise_y_in->outputs.end(),
                         [&](Node*& n) { return n->id() == eltwise->id(); }),
          eltwise_y_in->outputs.end());
    } else {
      recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                                 *bn_mean, *bn_variance, eltwise_y_in_tensor,
                                 epsilon, conv_type());
    }

    // Update the elementwise_add node
    eltwise->Op()->SetAttr("axis", 1);
    eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));

    GraphSafeRemoveNodes(
        graph,
        {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
         bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});

    IR_NODE_LINK_TO(eltwise, bn_out);

    found_conv_bn_count++;
  };

  gpd(graph, handler);

  AddStatis(found_conv_bn_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(conv_bn_fuse_pass, paddle::framework::ir::ConvBNFusePass);
REGISTER_PASS(conv_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::ConvEltwiseAddBNFusePass);
REGISTER_PASS(conv_transpose_bn_fuse_pass,
              paddle::framework::ir::ConvTransposeBNFusePass);
REGISTER_PASS(conv_transpose_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::ConvTransposeEltwiseAddBNFusePass);
REGISTER_PASS(depthwise_conv_bn_fuse_pass,
              paddle::framework::ir::DepthwiseConvBNFusePass);
REGISTER_PASS(depthwise_conv_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::DepthwiseConvEltwiseAddBNFusePass);
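
// The fuses above are only valid for the op definitions they were written
// against: conv2d at op version <= 1, batch_norm and elementwise_add at
// version 0.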
REGISTER_PASS_CAPABILITY(conv_bn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d", 1)
            .EQ("batch_norm", 0));
REGISTER_PASS_CAPABILITY(conv_eltwiseadd_bn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("conv2d", 1)
            .EQ("elementwise_add", 0)
            .EQ("batch_norm", 0));