// phi_op_convert_pass.cc
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

王明冬 已提交
15
#include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"
16 17 18 19 20

#include <glog/logging.h>
#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
21 22
#include <mlir/IR/Operation.h>
#include <mlir/IR/OperationSupport.h>
23 24 25 26
#include <list>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

王明冬 已提交
27
#include "paddle/infrt/common/string.h"
28
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
29
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
王明冬 已提交
30 31
#include "paddle/infrt/dialect/phi/ir/phi_base.h"
#include "paddle/infrt/dialect/phi/ir/phi_kernels.h"
32 33 34
#include "paddle/infrt/dialect/phi/pass/kernel_op_desc.h"
#include "paddle/infrt/dialect/phi/pass/proto_arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h"
35
#include "paddle/phi/core/kernel_factory.h"
36
#include "paddle/phi/ops/compat/signatures.h"
37 38

namespace {
王明冬 已提交
39 40
// Function pass that converts pd dialect ops into phi kernel ops.
// It works in two stages (see runOnFunction()): pd ops are first rewritten
// into generic infrt::KernelOp, which are then dispatched to concrete phi
// kernels chosen for one of valid_places_.
class PhiOpConvertPass
    : public mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass> {
 public:
  ::llvm::StringRef getName() const override { return "PhiOpConvertPass"; }

  // Runs convertStage() followed by dispatchStage() on the current function.
  void runOnFunction() override;

  // Default constructor; populates valid_places_ (defined out of line).
  PhiOpConvertPass();

  // Constructs the pass with an explicit list of places kernels may be
  // dispatched to. Taken by value and moved into the member.
  explicit PhiOpConvertPass(std::vector<infrt::Place> valid_places)
      : valid_places_(std::move(valid_places)) {}

  // Copy constructor, used when the pass is cloned. The base class must be
  // copied from `other`: copying it from `*this` would read the base of the
  // object that is still under construction.
  PhiOpConvertPass(const PhiOpConvertPass &other)
      : mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass>(other),
        valid_places_(other.valid_places_) {}

  ::llvm::StringRef getArgument() const override { return "phi-op-convert"; }
  void getDependentDialects(mlir::DialectRegistry &registry) const override;

 private:
  // Stage 1: rewrite pd dialect ops into infrt::KernelOp.
  void convertStage();

  // Stage 2: replace each infrt::KernelOp with a concrete phi kernel op.
  void dispatchStage();

  // Command-line option ("valid-targets") for the places kernels may be
  // dispatched to.
  Option<std::string> valid_places_options_{
      *this,
      "valid-targets",
      llvm::cl::desc("Set the valid target, [CPU-FP32-NCHW]")};

  // Places (target / precision / layout) candidate kernels are selected for.
  std::vector<infrt::Place> valid_places_;
};
王明冬 已提交
67 68
// Implementation of the PhiOpConvertPass.
void PhiOpConvertPass::runOnFunction() {
69
  convertStage();
王明冬 已提交
70
  dispatchStage();
71
}
王明冬 已提交
72 73

void PhiOpConvertPass::convertStage() {
74 75 76 77 78 79 80 81 82
  mlir::Block &body = getFunction().front();
  std::vector<mlir::Operation *> worklist;
  for (auto &op : body.without_terminator()) {
    worklist.push_back(&op);
  }
  mlir::OpBuilder builder(&body, body.begin());
  while (!worklist.empty()) {
    auto *op = worklist.back();
    worklist.pop_back();
王明冬 已提交
83
    if (!op) continue;
84

王明冬 已提交
85
    auto op_name = op->getName().getIdentifier().str();
86 87 88 89 90 91 92 93

    // only convert op in pd dialect.
    if (op_name.substr(0, 3) != "pd.") continue;
    op_name = op_name.substr(3);
    if (pd_dialect_inputs_info_map_.find(op_name) ==
            pd_dialect_inputs_info_map_.end() ||
        pd_dialect_outputs_info_map_.find(op_name) ==
            pd_dialect_outputs_info_map_.end()) {
王明冬 已提交
94
      LOG(WARNING) << "No op info found for " << op_name;
95 96 97
      // Todo: print log
      continue;
    }
98 99
    auto loc = getFunction().getLoc();
    builder.setInsertionPoint(op);
100 101 102

    if (!::phi::OpUtilsMap::Instance().HasArgumentMappingFn(op_name)) {
      op_name = phi::TransToPhiKernelName(op_name);
103 104 105
      auto kernel_op = builder.create<infrt::KernelOp>(loc,
                                                       op->getResultTypes(),
                                                       op->getOperands(),
106
                                                       op_name,
107 108 109 110 111 112
                                                       op->getAttrDictionary());
      op->replaceAllUsesWith(kernel_op.getResults());
    } else {
      ::phi::KernelSignature kernel_sign =
          ::phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_name)(
              infrt::ProtoArgumentMappingContext(op));
113 114
      VLOG(3) << "IncompatiblePhiKernel: op(" << op_name << "), kernel("
              << kernel_sign.name << ")";
115 116 117 118 119 120 121 122 123 124 125
      // resort input&output according to kernel_sign
      ::llvm::SmallVector<mlir::Value, 4> inputs, ori_output;
      ::llvm::SmallVector<mlir::Type, 4> output_types;
      for (const std::string &str : std::get<0>(kernel_sign.args)) {
        if (pd_dialect_inputs_info_map_.at(op_name).count(str) == 0) {
          LOG(ERROR) << "No input info for Op " << op_name << " and argument "
                     << str;
          return;
        }
        uint8_t index = pd_dialect_inputs_info_map_.at(op_name).at(str);
        inputs.push_back(op->getOperands()[index]);
126 127
      }

128 129 130 131 132 133 134 135 136 137 138 139 140 141
      for (const std::string &str : std::get<2>(kernel_sign.args)) {
        if (pd_dialect_outputs_info_map_.at(op_name).count(str) == 0) {
          LOG(ERROR) << "No output info for Op " << op_name << " and argument "
                     << str;
          return;
        }
        uint8_t index = pd_dialect_outputs_info_map_.at(op_name).at(str);
        output_types.push_back(op->getResultTypes()[index]);
        ori_output.push_back(op->getResult(index));
      }
      auto kernel_op = builder.create<infrt::KernelOp>(
          loc, output_types, inputs, kernel_sign.name, op->getAttrDictionary());
      for (size_t index = 0; index < ori_output.size(); ++index) {
        ori_output[index].replaceAllUsesWith(kernel_op.getResult(index));
142 143
      }
    }
王明冬 已提交
144
    CHECK(op->use_empty());
145 146 147
    op->erase();
  }
}
王明冬 已提交
148 149

void PhiOpConvertPass::dispatchStage() {
150 151 152 153 154 155
  std::vector<infrt::KernelOp> worklist;
  mlir::Block &block = getFunction().front();
  for (auto &op : block) {
    infrt::KernelOp kernel_op = ::llvm::dyn_cast_or_null<infrt::KernelOp>(&op);
    if (nullptr != kernel_op) worklist.push_back(kernel_op);
  }
156 157

  mlir::OpBuilder builder(&block, block.begin());
158
  std::map<infrt::TargetType, mlir::Value> phi_context;
159 160
  for (infrt::KernelOp kernel_op : worklist) {
    std::string kernel_name = kernel_op.name().str();
161
    std::vector<infrt::PhiKernelDesc> candidates =
王明冬 已提交
162
        GetCandidateKernels(kernel_name, valid_places_);
163 164 165 166 167 168 169
    if (candidates.empty()) {
      LOG(FATAL) << "No candidate kernels for op:" << kernel_name;
      continue;
    }
    builder.setInsertionPoint(kernel_op);

    // Todo: Implimentation the concrete pass pick strategy
170
    const infrt::PhiKernelDesc &phi_kernel_desc = candidates.front();
171

172
    kernel_name =
王明冬 已提交
173
        infrt::getPhiTargetPrefix(phi_kernel_desc.kernel_type.target) +
174
        kernel_name +
王明冬 已提交
175 176
        infrt::getPhiPrecisionSuffix(phi_kernel_desc.kernel_type.precision) +
        infrt::getPhiLayoutSuffix(phi_kernel_desc.kernel_type.layout);
177 178 179 180

    mlir::OperationName operation_name(kernel_name, kernel_op.getContext());
    mlir::OperationState operation_state(kernel_op.getLoc(), operation_name);

王明冬 已提交
181
    if (phi_context.find(phi_kernel_desc.kernel_type.target) ==
182
        phi_context.end()) {
王明冬 已提交
183
      switch (phi_kernel_desc.kernel_type.target) {
184
        case infrt::TargetType::CPU: {
185 186
          auto context_value =
              builder
187
                  .create<infrt::phi::CreateCPUContextOp>(
188
                      kernel_op.getLoc(),
189 190
                      infrt::phi::ContextType::get(kernel_op.getContext(),
                                                   infrt::TargetType::CPU))
191
                  .output();
192
          phi_context[infrt::TargetType::CPU] = context_value;
193
        } break;
194 195
        case infrt::TargetType::GPU:
        case infrt::TargetType::UNK:
196 197 198 199 200 201
        default:
          LOG(FATAL) << "Unsupported TargetType";
          break;
      }
    }
    operation_state.addOperands(
王明冬 已提交
202 203 204 205
        phi_context.at(phi_kernel_desc.kernel_type.target));

    for (size_t index = 0; index < phi_kernel_desc.input_types.size();
         ++index) {
206
      mlir::Value input = kernel_op.getOperand(index);
王明冬 已提交
207
      auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
208
          kernel_op.getLoc(),
209 210
          infrt::DenseTensorType::get(
              kernel_op.getContext(),
王明冬 已提交
211 212 213
              phi_kernel_desc.input_types[index].target,
              phi_kernel_desc.input_types[index].precision,
              phi_kernel_desc.input_types[index].layout),
214 215 216
          input);
      operation_state.addOperands(cvt_tensor_type_op.output());
    }
王明冬 已提交
217 218

    for (size_t index = 0; index < phi_kernel_desc.output_types.size();
219
         ++index) {
220 221
      operation_state.addTypes(infrt::DenseTensorType::get(
          kernel_op.getContext(),
王明冬 已提交
222 223 224
          phi_kernel_desc.output_types[index].target,
          phi_kernel_desc.output_types[index].precision,
          phi_kernel_desc.output_types[index].layout));
225 226 227
    }
    operation_state.addAttributes(kernel_op.attrsAttr().getValue());
    mlir::Operation *phi_operation = builder.createOperation(operation_state);
王明冬 已提交
228
    for (size_t index = 0; index < phi_kernel_desc.output_types.size();
229 230
         ++index) {
      mlir::Value input = phi_operation->getResult(index);
王明冬 已提交
231
      auto cvt_tensor_type_op = builder.create<infrt::TensorCastOp>(
232 233 234 235 236
          kernel_op.getLoc(), kernel_op.getResultTypes()[index], input);
      kernel_op.getResult(index).replaceAllUsesWith(
          cvt_tensor_type_op.output());
    }
    kernel_op.erase();
237 238
  }
}
239

王明冬 已提交
240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259
// Default constructor: seeds valid_places_ with a single CPU / FLOAT32 / NCHW
// place. Specifying places through the "valid-targets" option is not
// implemented yet and is treated as fatal.
// NOTE(review): MLIR presumably parses pass options after the pass is
// constructed, so valid_places_options_ may never hold a value here. Confirm.
PhiOpConvertPass::PhiOpConvertPass() {
  if (valid_places_options_.hasValue()) {
    LOG(FATAL) << "To be done for specifying places in command line";
    return;
  }
  valid_places_.emplace_back(infrt::TargetType::CPU,
                             infrt::PrecisionType::FLOAT32,
                             infrt::LayoutType::NCHW);
}

// Declares the dialects this pass depends on: the infrt dialect and the phi
// base, dense-tensor, CPU-kernel and GPU-kernel dialects.
void PhiOpConvertPass::getDependentDialects(
    mlir::DialectRegistry &registry) const {
  registry.insert<infrt::InfrtDialect,
                  infrt::phi::PHIDialect,
                  infrt::phi::PHIDenseTensorDialect,
                  infrt::phi::PHICPUKernelDialect,
                  infrt::phi::PHIGPUKernelDialect>();
}

260 261
}  // namespace

王明冬 已提交
262 263
// Registers PhiOpConvertPass with MLIR's pass registry when this translation
// unit is statically initialized.
mlir::PassRegistration<PhiOpConvertPass> phi_op_convert{};

264
std::unique_ptr<mlir::Pass> infrt::CreatePhiOpCvtPass(
265
    std::vector<Place> valid_places) {
王明冬 已提交
266 267 268
  return std::make_unique<PhiOpConvertPass>(valid_places);
}

269
std::unique_ptr<mlir::Pass> infrt::CreatePhiOpCvtPass() {
王明冬 已提交
270
  return std::make_unique<PhiOpConvertPass>();
271
}