// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"

#include <glog/logging.h>
#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/OperationSupport.h>
#include <list>
#include <unordered_set>
#include <vector>

#include "paddle/infrt/common/string.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/ir/phi_base.h"
#include "paddle/infrt/dialect/phi/ir/phi_kernels.h"
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/ops/compat/signatures.h"

namespace {
class PhiOpConvertPass
    : public mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass> {
 public:
  ::llvm::StringRef getName() const override { return "PhiOpConvertPass"; }
  void runOnFunction() override;
  PhiOpConvertPass();
  explicit PhiOpConvertPass(const std::vector<infrt::Place> &valid_places)
      : valid_places_(valid_places) {}

  PhiOpConvertPass(const PhiOpConvertPass &other)
      : mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass>(other),
        valid_places_(other.valid_places_) {}

  ::llvm::StringRef getArgument() const override { return "phi-op-convert"; }
  void getDependentDialects(mlir::DialectRegistry &registry) const override;

 private:
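  // Stage 1: rewrite every op in the pd dialect into a generic
  // infrt::KernelOp carrying the mapped phi kernel name.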
  void convertStage();
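  // Stage 2: replace each infrt::KernelOp with a concrete phi kernel op,
  // materializing device contexts and tensor casts around it.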
  void dispatchStage();

  // Command-line option listing the valid target places.
  Option<std::string> valid_places_options_{
      *this,
      "valid-targets",
      llvm::cl::desc("Set the valid target places, e.g. CPU-FP32-NCHW")};

  std::vector<infrt::Place> valid_places_;
};
// Implementation of the PhiOpConvertPass.
void PhiOpConvertPass::runOnFunction() {
  convertStage();
  dispatchStage();
}

void PhiOpConvertPass::convertStage() {
  mlir::Block &body = getFunction().front();
  std::vector<mlir::Operation *> worklist;
  for (auto &op : body.without_terminator()) {
    worklist.push_back(&op);
  }
  mlir::OpBuilder builder(&body, body.begin());
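  // Ops are popped from the back of the worklist, so they are visited in
  // reverse program order; each op is erased once all its uses are rewritten.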
  while (!worklist.empty()) {
    auto *op = worklist.back();
    worklist.pop_back();
    if (!op) continue;

    auto op_name = op->getName().getIdentifier().str();

    // Only convert ops in the pd dialect.
    if (op_name.substr(0, 3) != "pd.") continue;
    op_name = op_name.substr(3);
    if (pd_dialect_inputs_info_map_.find(op_name) ==
            pd_dialect_inputs_info_map_.end() ||
        pd_dialect_outputs_info_map_.find(op_name) ==
            pd_dialect_outputs_info_map_.end()) {
      LOG(WARNING) << "No op info found for " << op_name;
      continue;
    }
    auto loc = getFunction().getLoc();
    builder.setInsertionPoint(op);
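    // If phi has a kernel registered under a compatible name, the operands
    // already line up; only the op name needs translating.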
    if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(op_name)) {
      std::string kernel_name = phi::TransToPhiKernelName(op_name);
      auto kernel_op = builder.create<infrt::KernelOp>(loc,
                                                       op->getResultTypes(),
                                                       op->getOperands(),
                                                       kernel_name,
                                                       op->getAttrDictionary());
      op->replaceAllUsesWith(kernel_op.getResults());
    } else {
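      // Otherwise, consult the op's argument mapping function and reorder
      // inputs and outputs to match the phi kernel signature.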
      ::phi::KernelSignature kernel_sign =
          ::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
              infrt::ProtoArgumentMappingContext(op));
      // Re-sort inputs and outputs according to kernel_sign.
      ::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
      ::llvm::SmallVector<mlir::Type, 4> output_types;
      for (const std::string &str : std::get<0>(kernel_sign.args)) {
        if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) {
          LOG(ERROR) << "No input info for Op " << op_name << " and argument "
                     << str;
          return;
        }
        uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str);
        inputs.push_back(op->getOperands()[index]);
      }

      for (const std::string &str : std::get<2>(kernel_sign.args)) {
        if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) {
          LOG(ERROR) << "No output info for Op " << op_name << " and argument "
                     << str;
          return;
        }
        uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str);
        output_types.push_back(op->getResultTypes()[index]);
        ori_output.push_back(op->getResult(index));
      }
      auto kernel_op = builder.create<infrt::KernelOp>(
          loc, output_types, inputs, kernel_sign.name, op->getAttrDictionary());
      for (size_t index = 0; index < ori_output.size(); ++index) {
        ori_output[index].replaceAllUsesWith(kernel_op.getResult(index));
      }
    }
    CHECK(op->use_empty());
    op->erase();
  }
}

void PhiOpConvertPass::dispatchStage() {
  std::vector<infrt::KernelOp> worklist;
  mlir::Block &block = getFunction().front();
  for (auto &op : block) {
    infrt::KernelOp kernel_op = ::llvm::dyn_cast_or_null<infrt::KernelOp>(&op);
    if (nullptr != kernel_op) worklist.push_back(kernel_op);
  }

  mlir::OpBuilder builder(&block, block.begin());
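  // Cache one context-creation op per target so each device context is
  // materialized at most once per function.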
  std::map<infrt::TargetType, mlir::Value> phi_context;
  for (infrt::KernelOp kernel_op : worklist) {
    std::string kernel_name = kernel_op.name().str();
    std::vector<infrt::PhiKernelDesc> candidates =
        GetCandidateKernels(kernel_name, valid_places_);
    if (candidates.empty()) {
      LOG(FATAL) << "No candidate kernels for op:" << kernel_name;
      continue;
    }
    builder.setInsertionPoint(kernel_op);

    // TODO: implement a concrete kernel-selection strategy.
    const infrt::PhiKernelDesc &phi_kernel_desc = candidates.front();

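    // Mangle the dispatched kernel name: target prefix first, then the
    // precision and layout suffixes.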
    kernel_name =
        infrt::getPhiTargetPrefix(phi_kernel_desc.kernel_type.target) +
        kernel_name +
        infrt::getPhiPrecisionSuffix(phi_kernel_desc.kernel_type.precision) +
        infrt::getPhiLayoutSuffix(phi_kernel_desc.kernel_type.layout);

    mlir::OperationName operation_name(kernel_name, kernel_op.getContext());
    mlir::OperationState operation_state(kernel_op.getLoc(), operation_name);

    if (phi_context.find(phi_kernel_desc.kernel_type.target) ==
        phi_context.end()) {
      switch (phi_kernel_desc.kernel_type.target) {
        case infrt::TargetType::CPU: {
          auto context_value =
              builder
                  .create<infrt::phi::CreateCPUContextOp>(
                      kernel_op.getLoc(),
                      infrt::phi::ContextType::get(kernel_op.getContext(),
                                                   infrt::TargetType::CPU))
                  .output();
          phi_context[infrt::TargetType::CPU] = context_value;
        } break;
        case infrt::TargetType::GPU:
        case infrt::TargetType::UNK:
        default:
          LOG(FATAL) << "Unsupported TargetType";
          break;
      }
    }
    operation_state.addOperands(
        phi_context.at(phi_kernel_desc.kernel_type.target));

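    // Cast each operand to the dense tensor type the selected kernel expects.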
    for (size_t index = 0; index < phi_kernel_desc.input_types.size();
         ++index) {
      mlir::Value input = kernel_op.getOperand(index);
      auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
          kernel_op.getLoc(),
          infrt::DenseTensorType::get(
              kernel_op.getContext(),
              phi_kernel_desc.input_types[index].target,
              phi_kernel_desc.input_types[index].precision,
              phi_kernel_desc.input_types[index].layout),
          input);
      operation_state.addOperands(cvt_tensor_type_op.output());
    }

    for (size_t index = 0; index < phi_kernel_desc.output_types.size();
         ++index) {
      operation_state.addTypes(infrt::DenseTensorType::get(
          kernel_op.getContext(),
          phi_kernel_desc.output_types[index].target,
          phi_kernel_desc.output_types[index].precision,
          phi_kernel_desc.output_types[index].layout));
    }
    operation_state.addAttributes(kernel_op.attrsAttr().getValue());
    mlir::Operation *phi_operation = builder.createOperation(operation_state);
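    // Cast the results back to the original result types and reroute all
    // uses of the old kernel op.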
    for (size_t index = 0; index < phi_kernel_desc.output_types.size();
         ++index) {
      mlir::Value input = phi_operation->getResult(index);
      auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
          kernel_op.getLoc(), kernel_op.getResultTypes()[index], input);
      kernel_op.getResult(index).replaceAllUsesWith(
          cvt_tensor_type_op.output());
    }
    kernel_op.erase();
  }
}

PhiOpConvertPass::PhiOpConvertPass() {
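  // With no places given on the command line, default to CPU-FP32-NCHW.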
  if (!valid_places_options_.hasValue()) {
    valid_places_.emplace_back(infrt::TargetType::CPU,
                               infrt::PrecisionType::FLOAT32,
                               infrt::LayoutType::NCHW);
    return;
  }

  LOG(FATAL) << "To be done for specifying places in command line";
}

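// Declare every dialect whose ops this pass may create so that MLIR loads
// them before the pass runs.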
void PhiOpConvertPass::getDependentDialects(
    mlir::DialectRegistry &registry) const {
  registry.insert<infrt::InfrtDialect>();
  registry.insert<infrt::phi::PHIDialect>();
  registry.insert<infrt::phi::PHIDenseTensorDialect>();
  registry.insert<infrt::phi::PHICPUKernelDialect>();
  registry.insert<infrt::phi::PHIGPUKernelDialect>();
}

}  // namespace

mlir::PassRegistration<PhiOpConvertPass> phi_op_convert;

std::unique_ptr<mlir::Pass> infrt::createPhiOpCvtPass(
    std::vector<Place> valid_places) {
  return std::make_unique<PhiOpConvertPass>(valid_places);
}

std::unique_ptr<mlir::Pass> infrt::createPhiOpCvtPass() {
  return std::make_unique<PhiOpConvertPass>();
}
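
// Usage sketch (a minimal, illustrative example, not part of this file's API;
// it assumes an initialized mlir::MLIRContext, a parsed mlir::ModuleOp, and
// the pre-func-dialect mlir::FuncOp API this file is written against):
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::FuncOp>(infrt::createPhiOpCvtPass());
//   if (mlir::failed(pm.run(module))) {
//     // handle pass failure
//   }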