// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"

#include <glog/logging.h>
#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/OperationSupport.h>

#include <list>
#include <unordered_set>
#include <vector>

#include "paddle/infrt/common/string.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/ir/phi_base.h"
#include "paddle/infrt/dialect/phi/ir/phi_kernels.h"
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/ops/compat/signatures.h"

namespace {

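// Parse an infrt::Place from a "Target-Precision-Layout" string such as
// "CPU-FP32-NCHW", i.e. the element format accepted by the valid-targets
// option below.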
infrt::Place ParsePlaceFromStr(const std::string &key) {
  size_t first_index = key.find_first_of('-');
  size_t second_index = key.find_last_of('-');
  if (first_index != second_index) {
    llvm::Optional<infrt::TargetType> tar =
        infrt::GetTargetType(key.substr(0, first_index));
    llvm::Optional<infrt::PrecisionType> pre = infrt::GetPrecisionType(
        key.substr(first_index + 1, second_index - first_index - 1));
    llvm::Optional<infrt::LayoutType> lay =
        infrt::GetLayoutType(key.substr(second_index + 1));
    if (tar && pre && lay) {
      return infrt::Place(tar.getValue(), pre.getValue(), lay.getValue());
    }
  }
  LOG(FATAL) << "Can't parse infrt::Place from string: " << key;
  return infrt::Place();
}

class PhiOpConvertPass
    : public mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass> {
 public:
  ::llvm::StringRef getName() const override { return "PhiOpConvertPass"; }
  void runOnFunction() override;

  /// Initialize valid_places_ from valid_places_options_ when
  /// valid_places_options_ has values.
  mlir::LogicalResult initialize(mlir::MLIRContext *context) override;
  PhiOpConvertPass() {}
  explicit PhiOpConvertPass(const std::vector<infrt::Place> &valid_places)
      : valid_places_(valid_places) {}

  PhiOpConvertPass(const PhiOpConvertPass &other)
      : mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass>(other),
        valid_places_(other.valid_places_) {}

  ::llvm::StringRef getArgument() const override { return "phi-op-convert"; }
  void getDependentDialects(mlir::DialectRegistry &registry) const override;

 private:
  void updateInputsAndResults(infrt::TargetType target);
  void convertStage();
  void dispatchStage();

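  // Command-line form of valid_places_. Under MLIR's textual pass-pipeline
  // syntax this would be set as, e.g. (an assumed invocation):
  //   phi-op-convert{valid-targets=CPU-FP32-NCHW,GPU-FP32-NCHW}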
  ListOption<std::string> valid_places_options_{
      *this,
      "valid-targets",
      llvm::cl::desc(
          "Set the valid targets, such as: CPU-FP32-NCHW,GPU-FP32-NCHW"),
      llvm::cl::MiscFlags::CommaSeparated};

  std::vector<infrt::Place> valid_places_;
};

/// Initialize the pass: when the valid-targets option is set on the command
/// line, parse it into valid_places_, overriding any value passed in C++.
mlir::LogicalResult PhiOpConvertPass::initialize(mlir::MLIRContext *context) {
  if (valid_places_options_.hasValue()) {
    VLOG(4) << "Start parsing valid_places from the command line.";
    if (!valid_places_.empty()) {
      LOG(WARNING) << "Found valid places on the command line; the current "
                      "value will be overwritten.";
      valid_places_.clear();
    }
    for (auto &val : *valid_places_options_) {
      VLOG(4) << "place string: " << val;
      valid_places_.emplace_back(ParsePlaceFromStr(val));
    }
    VLOG(4) << "End parsing valid_places from the command line.";
  }
  return mlir::success();
}

// Run the pass in three stages: retarget the function's inputs and results,
// convert pd ops to generic kernel ops, then dispatch to concrete phi kernels.
void PhiOpConvertPass::runOnFunction() {
  CHECK(!valid_places_.empty()) << "PhiOpConvertPass: no valid places given.";
  updateInputsAndResults(valid_places_[0].target);
  convertStage();
  dispatchStage();
}

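// Rewrite the function signature for the given target: dense tensor
// arguments keep their precision, while results and tensors fetched from the
// dense tensor map are fixed to FP32 with NCHW layout.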
void PhiOpConvertPass::updateInputsAndResults(infrt::TargetType target) {
  mlir::Block &body = getFunction().front();
  auto loc = getFunction().getLoc();
  mlir::Operation &operation = body.front();
  mlir::MLIRContext *context = operation.getContext();
  size_t num_input = body.getNumArguments();

  // Step 1: retarget the function's dense tensor arguments.
  for (size_t index = 0; index < num_input; index++) {
    auto argument = body.getArgument(index);
    if (auto t = argument.getType().dyn_cast<::infrt::DenseTensorType>()) {
      mlir::Type replace_type = infrt::DenseTensorType::get(
          context, target, t.getPrecision(), infrt::LayoutType::NCHW);
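      // Insert the retargeted argument at the same index, redirect all uses
      // of the old argument to it, then erase the old argument (now shifted
      // to index + 1).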
      getFunction().insertArgument(index, replace_type, {}, loc);
      argument.replaceAllUsesWith(getFunction().getArgument(index));
      getFunction().eraseArgument(index + 1);
    }
  }
  // Step 2: retarget the function's results, fixing them to FP32/NCHW.
  unsigned int num_result = getFunction().getNumResults();
  for (unsigned int index = 0; index < num_result; index++) {
    mlir::Type replace_type =
        infrt::DenseTensorType::get(context,
                                    target,
                                    infrt::PrecisionType::FLOAT32,
                                    infrt::LayoutType::NCHW);
    getFunction().eraseResult(index);
    getFunction().insertResult(index, replace_type, {});
  }
  // Step 3: retarget tensors fetched from the dense tensor map.
  mlir::Type replace_type = infrt::DenseTensorType::get(
      context, target, infrt::PrecisionType::FLOAT32, infrt::LayoutType::NCHW);

  for (auto &op : body.without_terminator()) {
    if (op.getName().getIdentifier().str() == "phi_dt.tensor_map_get_tensor")
      op.getResult(0).setType(replace_type);
  }
}

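// Rewrite every pd-dialect op into a generic infrt::KernelOp. When phi
// provides an argument-mapping function for the op, operands and results are
// reordered to match the phi kernel signature; otherwise the op name is
// mapped directly via phi::TransToPhiKernelName.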
void PhiOpConvertPass::convertStage() {
  mlir::Block &body = getFunction().front();
  std::vector<mlir::Operation *> worklist;
  for (auto &op : body.without_terminator()) {
    worklist.push_back(&op);
  }
  mlir::OpBuilder builder(&body, body.begin());
  while (!worklist.empty()) {
    auto *op = worklist.back();
    worklist.pop_back();
    if (!op) continue;

    auto op_name = op->getName().getIdentifier().str();

    // Only convert ops in the pd dialect.
    if (op_name.substr(0, 3) != "pd.") continue;
    op_name = op_name.substr(3);
    if (pd_dialect_inputs_info_map_.find(op_name) ==
            pd_dialect_inputs_info_map_.end() ||
        pd_dialect_outputs_info_map_.find(op_name) ==
            pd_dialect_outputs_info_map_.end()) {
      LOG(WARNING) << "No op info found for " << op_name;
      continue;
    }
    auto loc = getFunction().getLoc();
    builder.setInsertionPoint(op);

    if (!::phi::OpUtilsMap::Instance().HasArgumentMappingFn(op_name)) {
      op_name = phi::TransToPhiKernelName(op_name);
      auto kernel_op = builder.create<infrt::KernelOp>(loc,
                                                       op->getResultTypes(),
                                                       op->getOperands(),
                                                       op_name,
                                                       op->getAttrDictionary());
      op->replaceAllUsesWith(kernel_op.getResults());
    } else {
      ::phi::KernelSignature kernel_sign =
          (*::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name))(
              infrt::ProtoArgumentMappingContext(op));
      VLOG(3) << "IncompatiblePhiKernel: op(" << op_name << "), kernel("
              << kernel_sign.name << ")";
      // Reorder inputs and outputs according to kernel_sign.
      ::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
      ::llvm::SmallVector<mlir::Type, 4> output_types;
      for (const std::string &str : kernel_sign.input_names) {
        if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) {
          LOG(ERROR) << "No input info for Op " << op_name << " and argument "
                     << str;
          return;
        }
        uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str);
        inputs.push_back(op->getOperands()[index]);
      }

      for (const std::string &str : kernel_sign.output_names) {
        if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) {
          LOG(ERROR) << "No output info for Op " << op_name << " and argument "
                     << str;
          return;
        }
        uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str);
        output_types.push_back(op->getResultTypes()[index]);
        ori_output.push_back(op->getResult(index));
      }
      auto kernel_op = builder.create<infrt::KernelOp>(
          loc, output_types, inputs, kernel_sign.name, op->getAttrDictionary());
      for (size_t index = 0; index < ori_output.size(); ++index) {
        ori_output[index].replaceAllUsesWith(kernel_op.getResult(index));
      }
    }
    CHECK(op->use_empty());
    op->erase();
  }
}

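// Replace each generic infrt::KernelOp with a concrete phi kernel op chosen
// from the candidate kernels for valid_places_, creating one phi context per
// target on demand and copying CPU inputs to the GPU where a kernel expects
// device memory.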
void PhiOpConvertPass::dispatchStage() {
  std::vector<infrt::KernelOp> worklist;
  mlir::Block &block = getFunction().front();
  for (auto &op : block) {
    infrt::KernelOp kernel_op = ::llvm::dyn_cast_or_null<infrt::KernelOp>(&op);
    if (nullptr != kernel_op) worklist.push_back(kernel_op);
  }

  mlir::OpBuilder builder(&block, block.begin());
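  // Lazily-created phi context value per target, shared by every kernel op
  // in this function.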
  std::map<infrt::TargetType, mlir::Value> phi_context;

  for (infrt::KernelOp kernel_op : worklist) {
    std::string kernel_name = kernel_op.name().str();
    std::vector<infrt::PhiKernelDesc> candidates =
        GetCandidateKernels(kernel_name, valid_places_);
    if (candidates.empty()) {
      LOG(FATAL) << "No candidate kernels for op:" << kernel_name;
      continue;
    }
    builder.setInsertionPoint(kernel_op);

    // TODO: implement a concrete kernel-picking strategy; for now take the
    // first candidate.
    const infrt::PhiKernelDesc &phi_kernel_desc = candidates.front();

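    // Compose the concrete kernel op name from the chosen kernel's target
    // prefix plus its precision and layout suffixes.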
    kernel_name =
        infrt::getPhiTargetPrefix(phi_kernel_desc.kernel_type.target) +
        kernel_name +
        infrt::getPhiPrecisionSuffix(phi_kernel_desc.kernel_type.precision) +
        infrt::getPhiLayoutSuffix(phi_kernel_desc.kernel_type.layout);

    mlir::OperationName operation_name(kernel_name, kernel_op.getContext());
    mlir::OperationState operation_state(kernel_op.getLoc(), operation_name);

    if (phi_context.find(phi_kernel_desc.kernel_type.target) ==
        phi_context.end()) {
      switch (phi_kernel_desc.kernel_type.target) {
        case infrt::TargetType::CPU: {
          auto context_value =
              builder
                  .create<infrt::phi::CreateCPUContextOp>(
                      kernel_op.getLoc(),
                      infrt::phi::ContextType::get(kernel_op.getContext(),
                                                   infrt::TargetType::CPU))
                  .output();
          phi_context[infrt::TargetType::CPU] = context_value;
        } break;
        case infrt::TargetType::GPU: {
          auto context_value =
              builder
                  .create<infrt::phi::CreateGPUContextOp>(
                      kernel_op.getLoc(),
                      infrt::phi::ContextType::get(kernel_op.getContext(),
                                                   infrt::TargetType::GPU))
                  .output();
          phi_context[infrt::TargetType::GPU] = context_value;
        } break;
        case infrt::TargetType::UNK:
        default:
          LOG(FATAL) << "Unsupported TargetType";
          break;
      }
    }
    operation_state.addOperands(
        phi_context.at(phi_kernel_desc.kernel_type.target));

    for (size_t index = 0; index < phi_kernel_desc.input_types.size();
         ++index) {
      mlir::Value input = kernel_op.getOperand(index);
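      // Insert a host-to-device copy (d2h = false) when the input lives on
      // the CPU but this kernel argument expects GPU memory.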
      if (input.getType().dyn_cast<::infrt::DenseTensorType>().getTarget() ==
              ::infrt::TargetType::CPU &&
          phi_kernel_desc.input_types[index].target ==
              ::infrt::TargetType::GPU) {
        auto cvt_tensor_type_op = builder.create<infrt::phi::GpuMemCopyOp>(
            kernel_op.getLoc(),
            infrt::DenseTensorType::get(
                kernel_op.getContext(),
                phi_kernel_desc.input_types[index].target,
                phi_kernel_desc.input_types[index].precision,
                phi_kernel_desc.input_types[index].layout),
            input,
            phi_context[infrt::TargetType::GPU],
            mlir::BoolAttr::get(kernel_op.getContext(), /*d2h*/ false));

        operation_state.addOperands(cvt_tensor_type_op.output());
      } else {
        operation_state.addOperands(input);
      }
    }

    for (size_t index = 0; index < phi_kernel_desc.output_types.size();
         ++index) {
      operation_state.addTypes(infrt::DenseTensorType::get(
          kernel_op.getContext(),
          phi_kernel_desc.output_types[index].target,
          phi_kernel_desc.output_types[index].precision,
          phi_kernel_desc.output_types[index].layout));
    }
    operation_state.addAttributes(kernel_op.attrsAttr().getValue());
    mlir::Operation *phi_operation = builder.createOperation(operation_state);
    for (size_t index = 0; index < phi_kernel_desc.output_types.size();
         ++index) {
      kernel_op.getResult(index).replaceAllUsesWith(
          phi_operation->getResult(index));
    }
    kernel_op.erase();
  }
}

void PhiOpConvertPass::getDependentDialects(
    mlir::DialectRegistry &registry) const {
  registry.insert<infrt::InfrtDialect>();
  registry.insert<infrt::phi::PHIDialect>();
  registry.insert<infrt::phi::PHIDenseTensorDialect>();
  registry.insert<infrt::phi::PHICPUKernelDialect>();
  registry.insert<infrt::phi::PHIGPUKernelDialect>();
}

}  // namespace

mlir::PassRegistration<PhiOpConvertPass> phi_op_convert;

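// Usage sketch (assumed host code, not part of this file): build the pass
// with an explicit place and run it over the functions of a module.
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::FuncOp>(infrt::CreatePhiOpCvtPass(
//       {infrt::Place(infrt::TargetType::CPU,
//                     infrt::PrecisionType::FLOAT32,
//                     infrt::LayoutType::NCHW)}));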
std::unique_ptr<mlir::Pass> infrt::CreatePhiOpCvtPass(
    std::vector<Place> valid_places) {
  return std::make_unique<PhiOpConvertPass>(valid_places);
}