// mkldnn_inplace_pass.cc
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/mkldnn_inplace_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {

// Makes MKL-DNN ops matched by patterns::MKLDNNInPlace write their output
// into their input variable.
//
// For every match of the pattern (nodes named current_op_in, current_op,
// current_op_out, next_op and next_op_out below) the handler:
//   - skips ops that lack `use_mkldnn == true`, have no registered
//     InplaceInferer, or whose InplaceInferer reports no in->out slot pair;
//   - skips the match unless the inferer's input slot lists current_op_in and
//     its output slot lists current_op_out;
//   - skips the match if current_op_in is consumed by more than one op, or if
//     a consumer of current_op_out without an InplaceInferer writes a var
//     named like current_op_in;
//   - otherwise renames current_op_out to current_op_in's name, rebinds
//     current_op's in-place output slot, and makes next_op read the new name.
//     When next_op is itself in-placed on the old name, its output slot,
//     next_op_out and next_op_out's consumers are updated as well.
//
// @param graph graph to transform in place; must not be null
//        (PADDLE_ENFORCE_NOT_NULL raises InvalidArgument otherwise).
void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(graph,
                          platform::errors::InvalidArgument(
                              "Pointer to graph argument should not be NULL."));
  // Key: current_op's node name + current_op_in's var name.
  // Value: the output var name the op had when its input and output names
  // still differed. After in-placing, that name is gone from current_op, but
  // next_op (and its consumers) may still reference it.
  std::unordered_map<std::string, std::string> original_output_names;

  GraphPatternDetector gpd;
  patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(),
                                         "mkldnn_inplace"};
  mkldnn_inplace();

  int found_inplace_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(3) << "Start to handle MKL-DNN In-Place pass";

    GET_IR_NODE_FROM_SUBGRAPH(current_op, inplace_to_be_op, mkldnn_inplace);
    GET_IR_NODE_FROM_SUBGRAPH(current_op_in, inplace_to_be_op_in,
                              mkldnn_inplace);
    GET_IR_NODE_FROM_SUBGRAPH(current_op_out, inplace_to_be_op_out,
                              mkldnn_inplace);
    GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, mkldnn_inplace);
    GET_IR_NODE_FROM_SUBGRAPH(next_op_out, next_op_out, mkldnn_inplace);

    if ((current_op->Op()->HasAttr("use_mkldnn") == false) ||
        (boost::get<bool>(current_op->Op()->GetAttr("use_mkldnn")) == false)) {
      VLOG(3) << "do not perform mkl-dnn inplace: use_mkldnn missing or set to "
                 "false";
      return;
    }

    auto& infer_inplace =
        OpInfoMap::Instance().Get(current_op->Op()->Type()).infer_inplace_;
    if (!infer_inplace) {
      VLOG(3) << "do not perform mkl-dnn inplace: missing InplaceInferer";
      return;
    }

    VLOG(3) << "DNNL Inplace op(" << current_op->id() << ") "
            << "Curr Node In: " << current_op_in->Name()
            << " Curr Node out: " << current_op_out->Name();

    VLOG(3) << "DNNL Inplace next op(" << next_op->id() << ") "
            << " next Node out: " << next_op_out->Name();

    const auto in_to_outs = infer_inplace(false);  // strictly no CUDA for MKL-DNN
    // Dereferencing begin() of an empty mapping would be undefined behaviour.
    if (in_to_outs.empty()) {
      VLOG(3) << "DNNL in-place pass SKIP pattern: InplaceInferer returned no "
                 "input/output pair";
      return;
    }
    const std::string& in_slot = in_to_outs.begin()->first;
    const std::string& out_slot = in_to_outs.begin()->second;

    // Use find() rather than operator[]: a missing slot must not be
    // default-inserted and then indexed as if it had elements.
    const auto& inputs = current_op->Op()->Inputs();
    const auto& outputs = current_op->Op()->Outputs();
    const auto in_it = inputs.find(in_slot);
    const auto out_it = outputs.find(out_slot);

    // If InferInplace pattern does not contain input node then skip
    if (in_it == inputs.end() ||
        std::find(in_it->second.begin(), in_it->second.end(),
                  current_op_in->Name()) == in_it->second.end()) {
      VLOG(3) << "DNNL in-place pass SKIP pattern ";
      return;
    }
    // Likewise the matched output node must be in the in-place output slot.
    // Otherwise an unrelated output would be renamed onto the input var and
    // the in-place slot overwritten below.
    if (out_it == outputs.end() ||
        std::find(out_it->second.begin(), out_it->second.end(),
                  current_op_out->Name()) == out_it->second.end()) {
      VLOG(3) << "DNNL in-place pass SKIP pattern: output node not in "
                 "in-place output slot";
      return;
    }

    // Both vectors are non-empty here (each contains a matched name).
    VLOG(3) << "DNNL InplaceInferer op(" << current_op->id() << ") "
            << in_slot << ": " << in_it->second[0] << " " << out_slot << ": "
            << out_it->second[0];

    // Checking if this particular node (to be inplaced, overwritten)
    // is used anywhere else apart from inplaced op
    if (current_op_in->outputs.size() > 1) {
      VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
                 "be an input to multiple operators";
      return;
    }

    // If this op was already inplaced in previous pass placements
    // then we need to update input of next op
    // but original name to be changed is gone, so we need to remember it
    // on first time given op is to be inplaced
    const std::string original_name_key =
        current_op->Name() + current_op_in->Name();
    if (current_op_in->Name() != current_op_out->Name()) {
      original_output_names[original_name_key] = current_op_out->Name();
    } else {
      VLOG(3) << "DNNL Inplace: Current op already inplaced! ";
    }

    // It may be that next op is reusing some of vars, we need to
    // make sure that unwanted inplace is not created
    for (auto* n : current_op_out->outputs) {
      auto& n_op_infer_inplace =
          OpInfoMap::Instance().Get(n->Op()->Type()).infer_inplace_;
      if (n_op_infer_inplace != nullptr) {
        continue;
      }
      for (auto* m : n->outputs) {
        if (m->Name() == current_op_in->Name()) {
          VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
                     "be an output to non-inplaced next op";
          return;
        }
      }
    }

    // An entry exists only if this key was seen with distinct input/output
    // names. Otherwise fall back to "" (the value operator[] used to yield)
    // without inserting an empty entry into the map.
    const auto name_it = original_output_names.find(original_name_key);
    const std::string original_name = name_it != original_output_names.end()
                                          ? name_it->second
                                          : std::string();
    current_op_out->RenameVar(current_op_in->Name());

    // Rebind the in-place output slot to the (renamed) output var.
    current_op->Op()->SetOutput(
        out_slot, std::vector<std::string>({current_op_out->Name()}));

    // If next op in a line is doing inplace
    // then we need to update its output as well.
    // Get inferer of next op; if there is none (or it reports no pair),
    // only next_op's input is renamed below.
    auto& next_op_infer_inplace =
        OpInfoMap::Instance().Get(next_op->Op()->Type()).infer_inplace_;
    if (next_op_infer_inplace) {
      const auto next_in_to_outs = next_op_infer_inplace(false);
      if (!next_in_to_outs.empty()) {
        const std::string& next_in_slot = next_in_to_outs.begin()->first;
        const std::string& next_out_slot = next_in_to_outs.begin()->second;
        auto* op = next_op->Op();
        const auto& next_inputs = op->Inputs();
        const auto& next_outputs = op->Outputs();
        const auto next_in_it = next_inputs.find(next_in_slot);
        const auto next_out_it = next_outputs.find(next_out_slot);
        // Check if in-place happened
        // for variable we changed (original name)
        // TODO(jczaja): make recursive propagation of inplace
        const bool next_op_inplaced_on_original =
            next_in_it != next_inputs.end() &&
            next_out_it != next_outputs.end() &&
            next_in_it->second == next_out_it->second &&
            std::find(next_in_it->second.begin(), next_in_it->second.end(),
                      original_name) != next_in_it->second.end();
        if (next_op_inplaced_on_original) {
          VLOG(3) << "DNNL InPlace: Next Op is in-placed , updating its "
                     "input "
                     "and output var!";
          op->SetOutput(next_out_slot,
                        std::vector<std::string>({current_op_out->Name()}));
          next_op_out->RenameVar(current_op_in->Name());
          // Get ops that next_op_out is linked to and update their input
          for (auto* c : next_op_out->outputs) {  // Has to be ops
            c->Op()->RenameInput(original_name, current_op_out->Name());
          }
        }
      }
    }

    next_op->Op()->RenameInput(original_name, current_op_out->Name());

    found_inplace_count++;
    VLOG(3) << "DNNL InPlace applied!";
  };

  gpd(graph, handler);
  VLOG(3) << "DNNL InPlace pass in-placed " << found_inplace_count << " op(s)";
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Registers MKLDNNInPlacePass under the name `mkldnn_inplace_pass` via the
// REGISTER_PASS macro (defined outside this file).
REGISTER_PASS(mkldnn_inplace_pass, paddle::framework::ir::MKLDNNInPlacePass);