// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"

#include <glog/logging.h>
#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/OperationSupport.h>
#include <list>
#include <unordered_set>
#include <vector>

#include "paddle/infrt/common/string.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/ir/phi_base.h"
#include "paddle/infrt/dialect/phi/ir/phi_kernels.h"
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/ops/compat/signatures.h"

namespace {

infrt::Place ParsePlaceFromStr(const std::string &key) {
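  // Expects a place string of the form "TARGET-PRECISION-LAYOUT",
  // e.g. "CPU-FP32-NCHW".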
  size_t first_index = key.find_first_of('-');
  size_t second_index = key.find_last_of('-');
  if (first_index != second_index) {
    llvm::Optional<infrt::TargetType> tar =
        infrt::GetTargetType(key.substr(0, first_index));
    llvm::Optional<infrt::PrecisionType> pre = infrt::GetPrecisionType(
        key.substr(first_index + 1, second_index - first_index - 1));
    llvm::Optional<infrt::LayoutType> lay =
        infrt::GetLayoutType(key.substr(second_index + 1));
    if (tar && pre && lay) {
      return infrt::Place(tar.getValue(), pre.getValue(), lay.getValue());
    }
  }
  LOG(FATAL) << "Can't parse infrt::Place from string:" << key;
  return infrt::Place();
}

class PhiOpConvertPass
    : public mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass> {
 public:
  ::llvm::StringRef getName() const override { return "PhiOpConvertPass"; }
  void runOnFunction() override;

  /// Initialize valid_places_ from valid_places_options_ when the option
  /// has values.
  mlir::LogicalResult initialize(mlir::MLIRContext *context) override;
  PhiOpConvertPass() {}
  explicit PhiOpConvertPass(const std::vector<infrt::Place> &valid_places)
      : valid_places_(valid_places) {}

  PhiOpConvertPass(const PhiOpConvertPass &other)
      : mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass>(other),
        valid_places_(other.valid_places_) {}

  ::llvm::StringRef getArgument() const override { return "phi-op-convert"; }
  void getDependentDialects(mlir::DialectRegistry &registry) const override;

 private:
  void convertStage();
  void dispatchStage();

  ListOption<std::string> valid_places_options_{
      *this,
      "valid-targets",
      llvm::cl::desc(
          "Set the valid targets, such as: CPU-FP32-NCHW,GPU-FP32-NCHW"),
      llvm::cl::MiscFlags::CommaSeparated};

  std::vector<infrt::Place> valid_places_;
};
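
// Example opt-style invocation (the driver name here is hypothetical; the
// argument and option names come from getArgument() and "valid-targets"):
//   infrt-opt --phi-op-convert=valid-targets=CPU-FP32-NCHW,GPU-FP32-NCHW in.mlir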

/// Initialize the pass: parse the valid places from the command-line option
/// when it is set.
mlir::LogicalResult PhiOpConvertPass::initialize(mlir::MLIRContext *context) {
  if (valid_places_options_.hasValue()) {
    VLOG(4) << "Start parsing valid_places from the command line:";
    if (!valid_places_.empty()) {
      LOG(WARNING) << "Found valid places on the command line; the current "
                      "value will be overwritten.";
      valid_places_.clear();
    }
    for (auto &val : *valid_places_options_) {
      VLOG(4) << "place string:" << val;
      valid_places_.emplace_back(ParsePlaceFromStr(val));
    }
    VLOG(4) << "End parsing valid_places from the command line.";
  }
  return mlir::success();
}

// Implementation of the PhiOpConvertPass.
void PhiOpConvertPass::runOnFunction() {
  convertStage();
  dispatchStage();
}

void PhiOpConvertPass::convertStage() {
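  // Rewrite every op in the pd dialect into a generic infrt.kernel op that
  // carries the PHI kernel name, with operands and results reordered to
  // match the PHI kernel signature when one is registered.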
  mlir::Block &body = getFunction().front();
  std::vector<mlir::Operation *> worklist;
  for (auto &op : body.without_terminator()) {
    worklist.push_back(&op);
  }
  mlir::OpBuilder builder(&body, body.begin());
  while (!worklist.empty()) {
    auto *op = worklist.back();
    worklist.pop_back();
    if (!op) continue;

    auto op_name = op->getName().getIdentifier().str();

    // Only convert ops in the pd dialect.
    if (op_name.substr(0, 3) != "pd.") continue;
    op_name = op_name.substr(3);
    if (pd_dialect_inputs_info_map_.find(op_name) ==
            pd_dialect_inputs_info_map_.end() ||
        pd_dialect_outputs_info_map_.find(op_name) ==
            pd_dialect_outputs_info_map_.end()) {
      LOG(WARNING) << "No op info found for " << op_name;
      continue;
    }
    auto loc = getFunction().getLoc();
    builder.setInsertionPoint(op);

    if (!::phi::OpUtilsMap::Instance().HasArgumentMappingFn(op_name)) {
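      // No argument mapping function is registered for this op, so fall back
      // to the PHI kernel name and keep the operand order unchanged.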
      op_name = phi::TransToPhiKernelName(op_name);
      auto kernel_op = builder.create<infrt::KernelOp>(loc,
                                                       op->getResultTypes(),
                                                       op->getOperands(),
                                                       op_name,
                                                       op->getAttrDictionary());
      op->replaceAllUsesWith(kernel_op.getResults());
    } else {
      ::phi::KernelSignature kernel_sign =
          ::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
              infrt::ProtoArgumentMappingContext(op));
      VLOG(3) << "IncompatiblePhiKernel: op(" << op_name << "), kernel("
              << kernel_sign.name << ")";
      // Reorder the inputs and outputs according to kernel_sign.
      ::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
      ::llvm::SmallVector<mlir::Type, 4> output_types;
      for (const std::string &str : std::get<0>(kernel_sign.args)) {
        if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) {
          LOG(ERROR) << "No input info for Op " << op_name << " and argument "
                     << str;
          return;
        }
        uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str);
        inputs.push_back(op->getOperands()[index]);
      }

      for (const std::string &str : std::get<2>(kernel_sign.args)) {
        if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) {
          LOG(ERROR) << "No output info for Op " << op_name << " and argument "
                     << str;
          return;
        }
        uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str);
        output_types.push_back(op->getResultTypes()[index]);
        ori_output.push_back(op->getResult(index));
      }
      auto kernel_op = builder.create<infrt::KernelOp>(
          loc, output_types, inputs, kernel_sign.name, op->getAttrDictionary());
      for (size_t index = 0; index < ori_output.size(); ++index) {
        ori_output[index].replaceAllUsesWith(kernel_op.getResult(index));
      }
    }
    CHECK(op->use_empty());
    op->erase();
  }
}

void PhiOpConvertPass::dispatchStage() {
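  // For each infrt.kernel op, pick a concrete PHI kernel matching one of the
  // valid places, create the context for its target on demand, and bridge
  // operand/result types with tensor cast ops.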
  std::vector<infrt::KernelOp> worklist;
  mlir::Block &block = getFunction().front();
  for (auto &op : block) {
    infrt::KernelOp kernel_op = ::llvm::dyn_cast_or_null<infrt::KernelOp>(&op);
    if (nullptr != kernel_op) worklist.push_back(kernel_op);
  }

  mlir::OpBuilder builder(&block, block.begin());
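  // Cache one context value per target so each context is created at most
  // once per function.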
  std::map<infrt::TargetType, mlir::Value> phi_context;
  for (infrt::KernelOp kernel_op : worklist) {
    std::string kernel_name = kernel_op.name().str();
    std::vector<infrt::PhiKernelDesc> candidates =
        GetCandidateKernels(kernel_name, valid_places_);
    if (candidates.empty()) {
      LOG(FATAL) << "No candidate kernels for op:" << kernel_name;
      continue;
    }
    builder.setInsertionPoint(kernel_op);

    // TODO: Implement a concrete kernel-picking strategy; for now the first
    // candidate is taken.
    const infrt::PhiKernelDesc &phi_kernel_desc = candidates.front();
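
    // Assemble the concrete kernel op name from the target prefix and the
    // precision/layout suffixes.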
    kernel_name =
        infrt::getPhiTargetPrefix(phi_kernel_desc.kernel_type.target) +
        kernel_name +
        infrt::getPhiPrecisionSuffix(phi_kernel_desc.kernel_type.precision) +
        infrt::getPhiLayoutSuffix(phi_kernel_desc.kernel_type.layout);

    mlir::OperationName operation_name(kernel_name, kernel_op.getContext());
    mlir::OperationState operation_state(kernel_op.getLoc(), operation_name);

    if (phi_context.find(phi_kernel_desc.kernel_type.target) ==
        phi_context.end()) {
      switch (phi_kernel_desc.kernel_type.target) {
        case infrt::TargetType::CPU: {
          auto context_value =
              builder
                  .create<infrt::phi::CreateCPUContextOp>(
                      kernel_op.getLoc(),
                      infrt::phi::ContextType::get(kernel_op.getContext(),
                                                   infrt::TargetType::CPU))
                  .output();
          phi_context[infrt::TargetType::CPU] = context_value;
        } break;
        case infrt::TargetType::GPU: {
          auto context_value =
              builder
                  .create<infrt::phi::CreateGPUContextOp>(
                      kernel_op.getLoc(),
                      infrt::phi::ContextType::get(kernel_op.getContext(),
                                                   infrt::TargetType::GPU))
                  .output();
          phi_context[infrt::TargetType::GPU] = context_value;
        } break;
        case infrt::TargetType::UNK:
        default:
          LOG(FATAL) << "Unsupported TargetType";
          break;
      }
    }
    operation_state.addOperands(
        phi_context.at(phi_kernel_desc.kernel_type.target));

    for (size_t index = 0; index < phi_kernel_desc.input_types.size();
         ++index) {
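      // Cast the operand to the dense tensor type the selected kernel
      // expects.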
      mlir::Value input = kernel_op.getOperand(index);
      auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
          kernel_op.getLoc(),
          infrt::DenseTensorType::get(
              kernel_op.getContext(),
              phi_kernel_desc.input_types[index].target,
              phi_kernel_desc.input_types[index].precision,
              phi_kernel_desc.input_types[index].layout),
          input);
      operation_state.addOperands(cvt_tensor_type_op.output());
    }

    for (size_t index = 0; index < phi_kernel_desc.output_types.size();
         ++index) {
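      // Declare the result types as the kernel's dense tensor types.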
      operation_state.addTypes(infrt::DenseTensorType::get(
          kernel_op.getContext(),
          phi_kernel_desc.output_types[index].target,
          phi_kernel_desc.output_types[index].precision,
          phi_kernel_desc.output_types[index].layout));
    }
    operation_state.addAttributes(kernel_op.attrsAttr().getValue());
    mlir::Operation *phi_operation = builder.createOperation(operation_state);
    for (size_t index = 0; index < phi_kernel_desc.output_types.size();
         ++index) {
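      // Cast each result back to the original result type and redirect all
      // uses of the old kernel op.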
      mlir::Value input = phi_operation->getResult(index);
      auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
          kernel_op.getLoc(), kernel_op.getResultTypes()[index], input);
      kernel_op.getResult(index).replaceAllUsesWith(
          cvt_tensor_type_op.output());
    }
    kernel_op.erase();
  }
}

void PhiOpConvertPass::getDependentDialects(
    mlir::DialectRegistry &registry) const {
  registry.insert<infrt::InfrtDialect>();
  registry.insert<infrt::phi::PHIDialect>();
  registry.insert<infrt::phi::PHIDenseTensorDialect>();
  registry.insert<infrt::phi::PHICPUKernelDialect>();
  registry.insert<infrt::phi::PHIGPUKernelDialect>();
}

}  // namespace

mlir::PassRegistration<PhiOpConvertPass> phi_op_convert;

std::unique_ptr<mlir::Pass> infrt::CreatePhiOpCvtPass(
    std::vector<Place> valid_places) {
  return std::make_unique<PhiOpConvertPass>(valid_places);
}