// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"

#include <glog/logging.h>
#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/OperationSupport.h>
#include <list>
#include <unordered_set>
#include <vector>

#include "paddle/infrt/common/string.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/ir/phi_base.h"
#include "paddle/infrt/dialect/phi/ir/phi_kernels.h"
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/ops/compat/signatures.h"

namespace {
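// PhiOpConvertPass lowers ops in the pd dialect to PHI kernels in two
// stages: convertStage rewrites each pd op into a generic infrt.kernel op
// keyed by its PHI kernel signature, and dispatchStage replaces each
// infrt.kernel op with a concrete PHI kernel op for one of the valid
// places, inserting context-creation ops and tensor casts as needed.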
class PhiOpConvertPass
    : public mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass> {
 public:
  ::llvm::StringRef getName() const override { return "PhiOpConvertPass"; }
  void runOnFunction() override;
  PhiOpConvertPass();
  explicit PhiOpConvertPass(const std::vector<infrt::Place> &valid_places)
      : valid_places_(valid_places) {}

  PhiOpConvertPass(const PhiOpConvertPass &other)
      : mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass>(other),
        valid_places_(other.valid_places_) {}

  ::llvm::StringRef getArgument() const override { return "phi-op-convert"; }
  void getDependentDialects(mlir::DialectRegistry &registry) const override;

 private:
  void convertStage();
  void dispatchStage();

  // Valid target places, e.g. CPU-FP32-NCHW, parsed from the command line.
  Option<std::string> valid_places_options_{
      *this,
      "valid-targets",
      llvm::cl::desc("Set the valid target, [CPU-FP32-NCHW]")};

  std::vector<infrt::Place> valid_places_;
};
// Implementation of the PhiOpConvertPass.
void PhiOpConvertPass::runOnFunction() {
  convertStage();
  dispatchStage();
}

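// Stage 1: rewrite every pd dialect op into a generic infrt.kernel op whose
// kernel name and operand order follow the PHI kernel signature.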
void PhiOpConvertPass::convertStage() {
  mlir::Block &body = getFunction().front();
  std::vector<mlir::Operation *> worklist;
  for (auto &op : body.without_terminator()) {
    worklist.push_back(&op);
  }
  mlir::OpBuilder builder(&body, body.begin());
  while (!worklist.empty()) {
    auto *op = worklist.back();
    worklist.pop_back();
    if (!op) continue;

    auto op_name = op->getName().getIdentifier().str();

    // Only convert ops in the pd dialect.
    if (op_name.substr(0, 3) != "pd.") continue;
    op_name = op_name.substr(3);
    if (pd_dialect_inputs_info_map_.find(op_name) ==
            pd_dialect_inputs_info_map_.end() ||
        pd_dialect_outputs_info_map_.find(op_name) ==
            pd_dialect_outputs_info_map_.end()) {
      LOG(WARNING) << "No op info found for " << op_name;
      continue;
    }

    ::phi::KernelSignature kernel_sign =
        ::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
            infrt::ProtoArgumentMappingContext(op));
    // Reorder inputs and outputs to match kernel_sign.
    ::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
    ::llvm::SmallVector<mlir::Type, 4> output_types;
    for (const std::string &str : std::get<0>(kernel_sign.args)) {
      if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) {
        LOG(ERROR) << "No input info for Op " << op_name << " and argument "
                   << str;
        return;
      }
      uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str);
      inputs.push_back(op->getOperands()[index]);
    }

    for (const std::string &str : std::get<2>(kernel_sign.args)) {
      if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) {
        LOG(ERROR) << "No output info for Op " << op_name << " and argument "
                   << str;
        return;
      }
      uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str);
      output_types.push_back(op->getResultTypes()[index]);
      ori_output.push_back(op->getResult(index));
    }

    auto loc = getFunction().getLoc();
    builder.setInsertionPoint(op);
    auto kernel_op = builder.create<infrt::KernelOp>(
        loc, output_types, inputs, kernel_sign.name, op->getAttrDictionary());
    for (size_t index = 0; index < ori_output.size(); ++index) {
      ori_output[index].replaceAllUsesWith(kernel_op.getResult(index));
    }

    CHECK(op->use_empty());
    op->erase();
  }
}

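// Stage 2: replace each infrt.kernel op with a concrete PHI kernel op,
// creating device contexts and casting tensors to the kernel's expected
// types along the way.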
void PhiOpConvertPass::dispatchStage() {
  std::vector<infrt::KernelOp> worklist;
  mlir::Block &block = getFunction().front();
  for (auto &op : block) {
    infrt::KernelOp kernel_op = ::llvm::dyn_cast_or_null<infrt::KernelOp>(&op);
    if (nullptr != kernel_op) worklist.push_back(kernel_op);
  }

  mlir::OpBuilder builder(&block, block.begin());
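  // One context-creation op is cached per target so it is emitted at most
  // once per function.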
  std::map<infrt::TargetType, mlir::Value> phi_context;
  for (infrt::KernelOp kernel_op : worklist) {
    std::string kernel_name = kernel_op.name().str();
    std::vector<infrt::PhiKernelDesc> candidates =
        GetCandidateKernels(kernel_name, valid_places_);
    if (candidates.empty()) {
      LOG(FATAL) << "No candidate kernels for op: " << kernel_name;
      continue;
    }
    builder.setInsertionPoint(kernel_op);

    // TODO: Implement a concrete kernel selection strategy; for now, take
    // the first candidate.
    const infrt::PhiKernelDesc &phi_kernel_desc = candidates.front();

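    // Mangle the kernel name with a target prefix and precision/layout
    // suffixes; for a CPU FP32 NCHW kernel named "add" this yields
    // something like "phi_cpu.add.float32.nchw" (the exact spelling is
    // determined by the getPhi*Prefix/Suffix helpers).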
    kernel_name =
        infrt::getPhiTargetPrefix(phi_kernel_desc.kernel_type.target) +
        kernel_name +
        infrt::getPhiPrecisionSuffix(phi_kernel_desc.kernel_type.precision) +
        infrt::getPhiLayoutSuffix(phi_kernel_desc.kernel_type.layout);

    mlir::OperationName operation_name(kernel_name, kernel_op.getContext());
    mlir::OperationState operation_state(kernel_op.getLoc(), operation_name);

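    // Create the device context required by the kernel's target if it has
    // not been created yet.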
    if (phi_context.find(phi_kernel_desc.kernel_type.target) ==
        phi_context.end()) {
      switch (phi_kernel_desc.kernel_type.target) {
        case infrt::TargetType::CPU: {
          auto context_value =
              builder
                  .create<infrt::phi::CreateCPUContextOp>(
                      kernel_op.getLoc(),
                      infrt::phi::ContextType::get(kernel_op.getContext(),
                                                   infrt::TargetType::CPU))
                  .output();
          phi_context[infrt::TargetType::CPU] = context_value;
        } break;
        case infrt::TargetType::GPU:
        case infrt::TargetType::UNK:
        default:
          LOG(FATAL) << "Unsupported TargetType";
          break;
      }
    }
    operation_state.addOperands(
        phi_context.at(phi_kernel_desc.kernel_type.target));

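    // Cast every input to the dense tensor type the selected kernel expects
    // and pass the casts as operands.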
    for (size_t index = 0; index < phi_kernel_desc.input_types.size();
         ++index) {
      mlir::Value input = kernel_op.getOperand(index);
      auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
          kernel_op.getLoc(),
          infrt::DenseTensorType::get(
              kernel_op.getContext(),
              phi_kernel_desc.input_types[index].target,
              phi_kernel_desc.input_types[index].precision,
              phi_kernel_desc.input_types[index].layout),
          input);
      operation_state.addOperands(cvt_tensor_type_op.output());
    }

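    // Declare result types matching the kernel's output tensor descriptions.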
    for (size_t index = 0; index < phi_kernel_desc.output_types.size();
         ++index) {
      operation_state.addTypes(infrt::DenseTensorType::get(
          kernel_op.getContext(),
          phi_kernel_desc.output_types[index].target,
          phi_kernel_desc.output_types[index].precision,
          phi_kernel_desc.output_types[index].layout));
    }
    operation_state.addAttributes(kernel_op.attrsAttr().getValue());
    mlir::Operation *phi_operation = builder.createOperation(operation_state);
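    // Cast each kernel result back to the original result type and redirect
    // all uses of the old infrt.kernel op to the casts.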
    for (size_t index = 0; index < phi_kernel_desc.output_types.size();
         ++index) {
      mlir::Value input = phi_operation->getResult(index);
      auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
          kernel_op.getLoc(), kernel_op.getResultTypes()[index], input);
      kernel_op.getResult(index).replaceAllUsesWith(
          cvt_tensor_type_op.output());
    }
    kernel_op.erase();
  }
}

PhiOpConvertPass::PhiOpConvertPass() {
  if (!valid_places_options_.hasValue()) {
    valid_places_.emplace_back(infrt::TargetType::CPU,
                               infrt::PrecisionType::FLOAT32,
                               infrt::LayoutType::NCHW);
    return;
  }

  LOG(FATAL) << "Parsing valid places from the command line is not "
                "implemented yet";
}

void PhiOpConvertPass::getDependentDialects(
    mlir::DialectRegistry &registry) const {
  registry.insert<infrt::InfrtDialect>();
  registry.insert<infrt::phi::PHIDialect>();
  registry.insert<infrt::phi::PHIDenseTensorDialect>();
  registry.insert<infrt::phi::PHICPUKernelDialect>();
  registry.insert<infrt::phi::PHIGPUKernelDialect>();
}

}  // namespace

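// Register the pass so tools can invoke it via its "phi-op-convert"
// argument.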
mlir::PassRegistration<PhiOpConvertPass> phi_op_convert;

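// A minimal usage sketch (an assumption, not part of this file: an
// already-configured mlir::PassManager named `pm`; the nesting target may
// differ in practice):
//   pm.addNestedPass<mlir::FuncOp>(
//       infrt::createPhiOpCvtPass({infrt::Place(infrt::TargetType::CPU,
//                                               infrt::PrecisionType::FLOAT32,
//                                               infrt::LayoutType::NCHW)}));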
std::unique_ptr<mlir::Pass> infrt::createPhiOpCvtPass(
    std::vector<Place> valid_places) {
  return std::make_unique<PhiOpConvertPass>(valid_places);
}

std::unique_ptr<mlir::Pass> infrt::createPhiOpCvtPass() {
  return std::make_unique<PhiOpConvertPass>();
}