mlir_to_runtime_translate.cc 23.9 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/host_context/mlir_to_runtime_translate.h"

17
#include <glog/logging.h>
Y
Yan Chunwei 已提交
18 19
#include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
W
Wilber 已提交
20
#include <mlir/IR/BuiltinAttributes.h>
21 22
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
Y
Yan Chunwei 已提交
23 24 25 26 27 28 29 30 31 32 33 34
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h>

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/infrt/common/string.h"
35
#include "paddle/infrt/dialect/dense_tensor.h"
Y
Yan Chunwei 已提交
36 37 38 39 40 41 42 43 44 45
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensor_shape.h"
#include "paddle/infrt/host_context/core_runtime.h"
#include "paddle/infrt/host_context/kernel_frame.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include "paddle/infrt/host_context/op_executable.h"
#include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/tensor/tensor_shape.h"

W
Wilber 已提交
46 47 48 49 50 51 52
#ifdef INFRT_WITH_PHI
#ifdef INFRT_WITH_TRT
#include "paddle/infrt/kernel/tensorrt/trt_kernels.h"
#endif
#include "paddle/phi/core/dense_tensor.h"
#endif

53 54
namespace infrt {
namespace host_context {
Y
Yan Chunwei 已提交
55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85

template <typename T>
std::string DumpToString(T& op) {  // NOLINT
  std::string buffer;
  llvm::raw_string_ostream os(buffer);
  op.print(os);
  os.flush();
  return buffer;
}

// Private state of MlirToRuntimeTranslator.
struct MlirToRuntimeTranslator::Impl {
  // Module being translated.
  mlir::ModuleOp module;
  // Builder that receives the emitted op executables. The translator stores
  // the pointer it is given and never deletes it.
  CoreRuntimeBuilder* runtime = nullptr;
  // Executable of the op currently being emitted. EmitGeneralOp/EmitCallOp
  // overwrite it before appending arguments, attributes and results.
  OpExecutableBuilder* cur_op = nullptr;

  // Name of the function currently being translated.
  std::string cur_func_name;

  // Function definitions, indexed by function name.
  std::unordered_map<std::string, mlir::FuncOp> func_defs;

  // Values produced at translation time by ops such as constants and
  // `ts.build_shape`, keyed by the producing op.
  std::unordered_map<const mlir::Operation*, std::vector<ValueRef>> op_results;
  // Runtime value bound to each MLIR SSA value.
  llvm::DenseMap<mlir::Value, ValueRef> value_map;
};

// Evaluates an `infrt.constant*` op at translation time and records its value
// in `op_results` so later operands can look it up via GetOpResult.
//
// Supported `value` attributes: f32/f64 FloatAttr and i1/i32/i64 IntegerAttr.
// The integer payload is read through the APInt, so both signless and signed
// integer types are accepted. Returns false for any other op; aborts on a
// missing or unsupported attribute.
bool MlirToRuntimeTranslator::EmitConstantOp(mlir::Operation* op) {
  const std::string op_name = op->getName().getStringRef().str();
  if (!infrt::Startswith(op_name, "infrt.constant")) return false;
  VLOG(3) << "Emitting constant op [" << op_name << "]";

  auto attr = op->getAttr("value");
  // Without this check a missing attribute is dereferenced as a null Attribute.
  CHECK(attr) << "Constant op [" << op_name << "] has no `value` attribute";

  if (auto float_attr = attr.dyn_cast<mlir::FloatAttr>()) {
    const double value = float_attr.getValueAsDouble();
    if (float_attr.getType().isF32()) {
      impl_->op_results[op] = {ValueRef(static_cast<float>(value))};
    } else if (float_attr.getType().isF64()) {
      impl_->op_results[op] = {ValueRef(value)};
    } else {
      LOG(FATAL) << "Not supported attribute type";
    }
    return true;
  }

  if (auto int_attr = attr.dyn_cast<mlir::IntegerAttr>()) {
    // `isInteger(N)` matches any signedness, while getSInt()/getInt() assert a
    // specific one; reading the APInt directly is valid for all of them.
    const mlir::Type type = int_attr.getType();
    if (type.isInteger(32)) {
      impl_->op_results[op] = {ValueRef(
          static_cast<int32_t>(int_attr.getValue().getSExtValue()))};
    } else if (type.isInteger(64)) {
      impl_->op_results[op] = {ValueRef(
          static_cast<int64_t>(int_attr.getValue().getSExtValue()))};
    } else if (type.isInteger(1)) {
      impl_->op_results[op] = {ValueRef(int_attr.getValue().getBoolValue())};
    } else {
      LOG(FATAL) << "Not supported attribute type";
    }
    return true;
  }

  LOG(FATAL) << "Not supported constant attribute type";
  return true;
}

template <>
126
paddle::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute(
127
    const mlir::Attribute& attr) {
128
  if (!attr.isa<mlir::IntegerAttr>()) return paddle::none;
129 130
  if (attr.isa<mlir::IntegerAttr>()) {
    auto val = attr.cast<mlir::IntegerAttr>();
Y
Yan Chunwei 已提交
131
    if (val.getType().isInteger(32)) {
132
      return val.getValue().getSExtValue();
Y
Yan Chunwei 已提交
133 134
    }
  }
135
  return paddle::none;
Y
Yan Chunwei 已提交
136 137
}
template <>
138
paddle::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute(
139
    const mlir::Attribute& attr) {
140
  if (!attr.isa<mlir::IntegerAttr>()) return paddle::none;
141 142
  if (attr.isa<mlir::IntegerAttr>()) {
    auto val = attr.cast<mlir::IntegerAttr>();
Y
Yan Chunwei 已提交
143
    if (val.getType().isInteger(64)) {
144
      return val.getValue().getSExtValue();
Y
Yan Chunwei 已提交
145 146
    }
  }
147
  return paddle::none;
Y
Yan Chunwei 已提交
148 149 150 151
}

// TODO(Superjomn) Make double and float parsing share some thing.
template <>
152
paddle::optional<float> MlirToRuntimeTranslator::EmitAttribute(
153
    const mlir::Attribute& attr) {
154
  if (!attr.isa<mlir::FloatAttr>()) return paddle::none;
155 156
  if (attr.isa<mlir::FloatAttr>()) {
    auto val = attr.cast<mlir::FloatAttr>();
Y
Yan Chunwei 已提交
157 158
    if (val.getType().isF32()) return val.getValueAsDouble();
  }
159
  return paddle::none;
Y
Yan Chunwei 已提交
160 161
}

162
template <>
163
paddle::optional<bool> MlirToRuntimeTranslator::EmitAttribute(
164
    const mlir::Attribute& attr) {
165
  if (!attr.isa<mlir::BoolAttr>()) return paddle::none;
166 167 168 169
  if (attr.isa<mlir::BoolAttr>()) {
    auto val = attr.cast<mlir::BoolAttr>();
    return val.getValue();
  }
170
  return paddle::none;
171 172
}

Y
Yan Chunwei 已提交
173
template <>
174
paddle::optional<double> MlirToRuntimeTranslator::EmitAttribute(
175
    const mlir::Attribute& attr) {
176
  if (!attr.isa<mlir::FloatAttr>()) return paddle::none;
177 178
  if (attr.isa<mlir::FloatAttr>()) {
    auto val = attr.cast<mlir::FloatAttr>();
Y
Yan Chunwei 已提交
179 180
    if (val.getType().isF64()) return val.getValueAsDouble();
  }
181
  return paddle::none;
Y
Yan Chunwei 已提交
182 183
}

184
template <>
185
paddle::optional<::infrt::TargetType> MlirToRuntimeTranslator::EmitAttribute(
186
    const mlir::Attribute& attr) {
187
  if (!attr.isa<::infrt::TargetAttr>()) return paddle::none;
188 189 190
  if (attr.isa<::infrt::TargetAttr>()) {
    return attr.cast<::infrt::TargetAttr>().getTarget();
  }
191
  return paddle::none;
192 193 194
}

template <>
195
paddle::optional<::infrt::LayoutType> MlirToRuntimeTranslator::EmitAttribute(
196
    const mlir::Attribute& attr) {
197
  if (!attr.isa<::infrt::LayoutAttr>()) return paddle::none;
198 199 200
  if (attr.isa<::infrt::LayoutAttr>()) {
    return attr.cast<::infrt::LayoutAttr>().getLayout();
  }
201
  return paddle::none;
202 203 204
}

template <>
205
paddle::optional<::infrt::PrecisionType> MlirToRuntimeTranslator::EmitAttribute(
206
    const mlir::Attribute& attr) {
207
  if (!attr.isa<::infrt::PrecisionAttr>()) return paddle::none;
208 209 210
  if (attr.isa<::infrt::PrecisionAttr>()) {
    return attr.cast<::infrt::PrecisionAttr>().getPrecision();
  }
211
  return paddle::none;
212 213
}

Y
Yan Chunwei 已提交
214
template <>
215
paddle::optional<std::string> MlirToRuntimeTranslator::EmitAttribute(
216
    const mlir::Attribute& attr) {
217
  if (!attr.isa<mlir::StringAttr>()) return paddle::none;
218
  return attr.cast<mlir::StringAttr>().getValue().str();
Y
Yan Chunwei 已提交
219 220
}

221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
#define PROCESS_ARRAY_INT(type__, bits__)                                   \
  template <>                                                               \
  paddle::optional<std::vector<type__>>                                     \
  MlirToRuntimeTranslator::EmitAttribute(const mlir::Attribute& attr) {     \
    if (!attr.isa<mlir::ArrayAttr>()) return paddle::none;                  \
    auto array = attr.cast<mlir::ArrayAttr>();                              \
    CHECK(!array.empty());                                                  \
                                                                            \
    if (!array[0].getType().isInteger(bits__)) {                            \
      return paddle::none;                                                  \
    }                                                                       \
                                                                            \
    std::vector<type__> res;                                                \
    for (auto& v : array) {                                                 \
      res.push_back(v.cast<mlir::IntegerAttr>().getValue().getSExtValue()); \
    }                                                                       \
    return res;                                                             \
Y
Yan Chunwei 已提交
238 239
  }

240
PROCESS_ARRAY_INT(bool, 1);
Y
Yan Chunwei 已提交
241 242 243 244 245
PROCESS_ARRAY_INT(int16_t, 16);
PROCESS_ARRAY_INT(int32_t, 32);
PROCESS_ARRAY_INT(int64_t, 64);

template <>
246
paddle::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute(
247
    const mlir::Attribute& attr) {
248
  if (!attr.isa<mlir::ArrayAttr>()) return paddle::none;
249
  auto array = attr.cast<mlir::ArrayAttr>();
Y
Yan Chunwei 已提交
250 251
  CHECK(!array.empty());

252
  if (!array[0].getType().isF32()) return paddle::none;
Y
Yan Chunwei 已提交
253 254 255 256 257 258 259 260 261

  std::vector<float> res;
  for (auto& v : array) {
    res.push_back(v.cast<mlir::FloatAttr>().getValueAsDouble());
  }
  return res;
}

template <>
262
paddle::optional<std::vector<double>> MlirToRuntimeTranslator::EmitAttribute(
263
    const mlir::Attribute& attr) {
264
  if (!attr.isa<mlir::ArrayAttr>()) return paddle::none;
265
  auto array = attr.cast<mlir::ArrayAttr>();
Y
Yan Chunwei 已提交
266 267
  CHECK(!array.empty());

268
  if (!array[0].getType().isF64()) return paddle::none;
Y
Yan Chunwei 已提交
269 270 271 272 273 274 275 276 277

  std::vector<double> res;
  for (auto& v : array) {
    res.push_back(v.cast<mlir::FloatAttr>().getValueAsDouble());
  }
  return res;
}

// True iff `op` is the `infrt.return` terminator.
static bool IsReturn(mlir::Operation* op) {
  const llvm::StringRef name = op->getName().getStringRef();
  return name == "infrt.return";
}

281 282
// Emits a generic kernel call for `op`. It creates a new op executable on the
// current runtime and fills in:
//   - arguments, from the op's operands;
//   - attributes, reordered to the kernel's registered attribute order when
//     one is registered;
//   - an optional region-as-function attribute;
//   - results.
//
// `trt.create_engine` is special-cased. The op itself, together with the
// runtime's symbol table, is passed as the first argument, and tensor operands
// are fed into the runtime as named inputs. This branch only exists when built
// with INFRT_WITH_TRT.
//
// Always returns true; unsupported input aborts via CHECK/LOG(FATAL).
bool MlirToRuntimeTranslator::EmitGeneralOp(
    mlir::Operation* op, const KernelRegistry& kernel_registry) {
  CHECK(impl_->runtime);
  impl_->cur_op =
      impl_->runtime->NewOpExecutable(op->getName().getStringRef().str());

  VLOG(3) << "processing general op : " << op->getName().getStringRef().str();
  // TODO(wilber): Find a more appropriate way to handle special cases.
  if (op->getName().getStringRef() == "trt.create_engine") {
#ifdef INFRT_WITH_TRT
    auto* symbols = impl_->runtime->symbol_table();
    ::infrt::kernel::tensorrt::MlirOperationWithInfrtSymbol mlir_operation;
    mlir_operation.operation = op;
    mlir_operation.symbol_table = symbols;
    impl_->cur_op->AppendArgument(new Value(mlir_operation));
    // TODO(wilber): how to pass DenseTensor to create_engine op? temporarily
    // add a naive implement.
    for (int i = 0, e = op->getNumOperands(); i < e; ++i) {
      auto operand = op->getOperand(i);
      Value* arg_value{nullptr};
      if (operand.isa<mlir::BlockArgument>()) {
        mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>();
        arg_value = GetValue(arg);
      } else {
        arg_value = GetValue(operand);
        if (!arg_value) {
          auto upstream_op = operand.getDefiningOp();
          arg_value = GetOpResult(upstream_op);
        }
      }
      // Guard the dereference below. Without it, an unbound operand crashes
      // with a null-pointer access instead of a diagnosable message.
      CHECK(arg_value) << "No-exist argument value found: "
                       << DumpToString(operand);
      if (arg_value->is_type<::Tensor>()) {
        impl_->runtime->FeedInArgs(
            std::make_pair(std::to_string(i), ValueRef(arg_value)));
      }
    }
#else
    CHECK(false) << "should not reach here";
#endif
  } else {
    // process operands
    for (int i = 0, e = op->getNumOperands(); i < e; i++) {
      // function argument as value
      auto operand = op->getOperand(i);
      if (operand.isa<mlir::BlockArgument>()) {
        mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>();
        Value* arg_value = GetValue(arg);
        impl_->cur_op->AppendArgument(arg_value);
        VLOG(3) << "* op mlir operand: " << DumpToString(arg) << " "
                << GetValue(arg);
        continue;
      }

      // normal value
      Value* arg_value = GetValue(operand);
      if (!arg_value) {
        auto upstream_op = operand.getDefiningOp();
        arg_value = GetOpResult(upstream_op);
      }
      CHECK(arg_value) << "No-exist argument value found: "
                       << DumpToString(operand);
      impl_->cur_op->AppendArgument(arg_value);

      VLOG(3) << "* op mlir operand: " << DumpToString(operand) << " "
              << GetValue(operand) << " vs " << arg_value;
    }
  }

  // process attributes
  auto attrs = op->getAttrs();

  // MLIR's underlying attr storage type is `Builtin_Dictionary`, and its
  // elements are sorted by name. The following code adapts the order of
  // function signatures of the phi operator library.
  llvm::SmallVector<Value*, 4> tmp;
  tmp.resize(attrs.size());
  // Held by value rather than as a const reference bound to the temporary
  // returned by str().
  const std::string kernel_name = op->getName().getStringRef().str();
  const auto& attr_names = kernel_registry.GetAttrNameList(kernel_name);
  if (attrs.size()) {
    if (attr_names.empty()) {
      LOG(WARNING) << "The kernel `" << kernel_name
                   << "` has not been registered with attributes order ";
    } else {
      CHECK_EQ(attr_names.size(), attrs.size())
          << "The number of kernel `" << kernel_name
          << "` attributes specified by mlir (" << attrs.size()
          << ") is inconsistent with the registration (" << attr_names.size()
          << ").";
    }
  }

  // Position of `attr` in the registered attribute order, or -1 if absent.
  auto get_offset = [](const char* attr,
                       const std::vector<const char*>& names,
                       const std::string& kernel_name) -> int {
    for (size_t i = 0; i < names.size(); ++i) {
      if (!std::strcmp(attr, names[i])) {
        return i;
      }
    }
    LOG(WARNING) << "The attribute `" << attr << "` of kernel `" << kernel_name
                 << "` is not properly register";
    return -1;
  };

  for (size_t i = 0; i < attrs.size(); i++) {
    auto& attr = attrs[i];
    int offset{};
    if (attr_names.size()) {
      offset = get_offset(attr.getName().data(), attr_names, kernel_name);
    } else {
      // No registered order: keep MLIR's (name-sorted) order.
      offset = i;
    }
    CHECK_GT(offset, -1);
    // Try each supported attribute type in turn; the first match wins.
    if (auto v = EmitAttribute<int32_t>(attr.getValue())) {
      tmp[offset] = new Value(*v);
    } else if (auto v = EmitAttribute<int64_t>(attr.getValue())) {
      tmp[offset] = new Value(*v);
    } else if (auto v = EmitAttribute<float>(attr.getValue())) {
      tmp[offset] = new Value(*v);
    } else if (auto v = EmitAttribute<double>(attr.getValue())) {
      tmp[offset] = new Value(*v);
    } else if (auto v = EmitAttribute<std::string>(attr.getValue())) {
      tmp[offset] = new Value(std::move(*v));
    } else if (auto v = EmitAttribute<bool>(attr.getValue())) {
      tmp[offset] = new Value(*v);
    } else if (auto v = EmitAttribute<::infrt::TargetType>(attr.getValue())) {
      tmp[offset] = new Value(*v);
    } else if (auto v =
                   EmitAttribute<::infrt::PrecisionType>(attr.getValue())) {
      tmp[offset] = new Value(*v);
    } else if (auto v = EmitAttribute<::infrt::LayoutType>(attr.getValue())) {
      tmp[offset] = new Value(*v);
    } else if (auto v = EmitAttribute<std::vector<int16_t>>(attr.getValue())) {
      tmp[offset] = new Value(std::move(*v));
    } else if (auto v = EmitAttribute<std::vector<int32_t>>(attr.getValue())) {
      tmp[offset] = new Value(std::move(*v));
    } else if (auto v = EmitAttribute<std::vector<int64_t>>(attr.getValue())) {
      tmp[offset] = new Value(std::move(*v));
    } else if (auto v = EmitAttribute<std::vector<float>>(attr.getValue())) {
      tmp[offset] = new Value(std::move(*v));
    } else if (auto v = EmitAttribute<std::vector<double>>(attr.getValue())) {
      tmp[offset] = new Value(std::move(*v));
    } else {
      LOG(FATAL) << "Not supported attribute type";
    }
  }

  for (size_t i = 0; i < tmp.size(); i++) {
    impl_->cur_op->AppendAttribute(tmp[i]);
  }

  // process regions, we treat regions as attribute.
  auto num_regions = op->getNumRegions();
  if (num_regions > 0) {
    CHECK_EQ(num_regions, 1UL)
        << "op with more than one region is not supported yet.";
    auto& region = op->getRegions().front();
    auto num_blocks = region.getBlocks().size();
    CHECK_EQ(num_blocks, 1UL)
        << "region with more than one block is not supported yet.";

    // process arguments
    llvm::SmallVector<mlir::Type, 4> inputs;
    auto& block = region.getBlocks().front();
    for (auto arg : block.getArguments()) inputs.push_back(arg.getType());

    // process results
    // NOTE: if an op contains a region, we simply ignore the region's return
    // values, or its return values will conflict with op's return values.
    llvm::SmallVector<mlir::Type, 0> results;

    auto func_type =
        mlir::FunctionType::get(region.getContext(), inputs, results);
    auto* function = impl_->cur_op->CreateFunctionExecutable(
        &region, func_type, &impl_->func_defs);
    impl_->cur_op->AppendAttribute(new Value(function));
  }

  // process results
  llvm::SmallVector<Value*, 4> res_values;
  for (int i = 0, e = op->getNumResults(); i < e; i++) {
    auto res = op->getResult(i);
    if (res.getType().isa<::infrt::DenseTensorType>()) {
      auto r =
          impl_->value_map.try_emplace(res, ValueRef(new Value{::Tensor()}));
      CHECK(r.second) << "Duplicate add mlir value [" << DumpToString(res)
                      << "]";
      res_values.push_back(r.first->second.get());
    } else {
      res_values.push_back(AddValue(res));
    }

    VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res);
  }
  impl_->cur_op->SetResults(res_values);

#ifdef INFRT_DEBUG
  {
    VLOG(3) << "check result";
    for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) {
      VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i];
    }
  }
#endif

  return true;
}

bool MlirToRuntimeTranslator::EmitReturnOp(
    mlir::Operation* op, llvm::SmallVectorImpl<mlir::Value>* results) {
  CHECK(results);
  if (!IsReturn(op)) return false;

  // Hand every returned operand back to the caller, in operand order.
  for (mlir::Value operand : op->getOperands()) {
    results->push_back(operand);
  }
  return true;
}

bool MlirToRuntimeTranslator::EmitFunctions() {
  for (auto func_op : impl_->module.getOps<mlir::FuncOp>()) {
    EmitFunction(func_op);
  }
  return true;
}

void MlirToRuntimeTranslator::EmitFunction(mlir::FuncOp op) {
  // A later definition with the same name replaces the earlier one.
  const std::string name = op.getName().str();
  impl_->func_defs[name] = op;
}

// Returns the first value recorded for `op` in `op_results` (filled by
// EmitConstantOp / EmitBuildShapeOp), or nullptr when nothing usable is
// recorded. Callers CHECK the result, so returning nullptr for an empty entry
// yields a diagnosable failure instead of undefined behavior from front().
Value* MlirToRuntimeTranslator::GetOpResult(mlir::Operation* op) {
  auto it = impl_->op_results.find(op);
  if (it == impl_->op_results.end() || it->second.empty()) return nullptr;
  return it->second.front().get();
}

Value* MlirToRuntimeTranslator::GetValue(mlir::Value value) {
  // nullptr when no runtime value has been bound to `value` yet.
  auto found = impl_->value_map.find(value);
  if (found == impl_->value_map.end()) return nullptr;
  return found->second.get();
}

Value* MlirToRuntimeTranslator::AddValue(mlir::Value value) {
  // Bind a freshly allocated, default-constructed runtime Value to `value`.
  // Binding the same MLIR value twice is a fatal error.
  auto inserted = impl_->value_map.try_emplace(value, ValueRef(new Value));
  const bool is_new = inserted.second;
  CHECK(is_new) << "Duplicate add mlir value [" << DumpToString(value) << "]";
  return inserted.first->second.get();
}

// Defined out of line, where Impl is a complete type. NOTE(review): this
// assumes `impl_` is a smart pointer that frees Impl — confirm in the header.
MlirToRuntimeTranslator::~MlirToRuntimeTranslator() = default;

void MlirToRuntimeTranslator::UpdateCurFuncName(const std::string& name) {
  // Record which function is currently being translated.
  impl_->cur_func_name = name;
}

MlirToRuntimeTranslator::MlirToRuntimeTranslator(mlir::ModuleOp module,
                                                 CoreRuntimeBuilder* runtime)
    : impl_(new Impl) {
  // The runtime builder is mandatory: op executables are created through it.
  CHECK(runtime);
  impl_->runtime = runtime;
  impl_->module = module;
}

// Evaluates a `ts.build_shape` op at translation time. It reads the integer
// ArrayAttr `value` into a tensor::TensorShape and records it in `op_results`.
// Returns false for any other op; aborts on a missing or non-array `value`.
bool MlirToRuntimeTranslator::EmitBuildShapeOp(mlir::Operation* op) {
  if (op->getName().getStringRef() != "ts.build_shape") return false;

  auto value = op->getAttr("value");
  // Without this check a missing attribute is dereferenced as a null Attribute.
  CHECK(value) << "ts.build_shape op has no `value` attribute";
  CHECK(value.isa<mlir::ArrayAttr>());
  auto values = value.cast<mlir::ArrayAttr>().getValue();

  std::vector<int64_t> dims;
  dims.reserve(values.size());
  for (auto& attr_v : values) {
    // Read the APInt directly: getInt() asserts a signless type, whereas this
    // accepts signless and signed integers alike.
    dims.push_back(attr_v.cast<mlir::IntegerAttr>().getValue().getSExtValue());
  }
  impl_->op_results[op] = {
      ValueRef(new Value(tensor::TensorShape(llvm::ArrayRef<int64_t>(dims))))};

  return true;
}

// Emits an `infrt.call` op. The resulting op executable has:
//   - the call operands as arguments;
//   - a single attribute holding a function executable built from the
//     callee's definition;
//   - freshly bound runtime values as results.
//
// The callee is looked up in `function_table` when it is non-null, and
// otherwise in the translator's own function definitions. Returns false, and
// emits nothing, for any other op.
bool MlirToRuntimeTranslator::EmitCallOp(mlir::Operation* op,
                                         function_defs_t* function_table) {
  CHECK(op);
  // NOTE: `function_table` may be null. The lookup below explicitly falls back
  // to impl_->func_defs, which a CHECK on it would make unreachable.
  if (op->getName().getStringRef() != "infrt.call") return false;
  CHECK(impl_->runtime);

  impl_->cur_op =
      impl_->runtime->NewOpExecutable(op->getName().getStringRef().str());

  // Without this check, a missing or non-symbol `callee` attribute would be
  // dereferenced as a null FlatSymbolRefAttr.
  auto callee_name = op->getAttrOfType<mlir::FlatSymbolRefAttr>("callee");
  CHECK(callee_name) << "infrt.call op without a valid `callee` attribute: "
                     << DumpToString(*op);
  const std::string callee = callee_name.getValue().str();

  // process arguments
  for (size_t i = 0; i < op->getNumOperands(); i++) {
    auto operand = op->getOperand(i);
    auto* arg_value = GetValue(operand);

    if (!arg_value) {
      auto upstream_op = operand.getDefiningOp();
      arg_value = GetOpResult(upstream_op);
    }
    CHECK(arg_value) << "No-exist argument value found: "
                     << DumpToString(operand);
    impl_->cur_op->AppendArgument(arg_value);
  }

  // process attribute
  auto& table = function_table ? *function_table : impl_->func_defs;
  {
    // lookup the callee function
    auto it = table.find(callee);
    CHECK(it != table.end()) << "can't find function [" << callee << "]";
    auto* function =
        impl_->cur_op->CreateFunctionExecutable(it->second, &impl_->func_defs);
    impl_->cur_op->AppendAttribute(new Value(function));
  }

  // process results
  llvm::SmallVector<Value*, 4> res_values;
  for (int i = 0, e = op->getNumResults(); i < e; i++) {
    auto res = op->getResult(i);
    res_values.push_back(AddValue(res));
  }
  impl_->cur_op->SetResults(res_values);

  VLOG(3) << "Emit call " << callee << " " << impl_->cur_op->frame();
  return true;
}

MlirToRuntimeTranslator::MlirToRuntimeTranslator(CoreRuntimeBuilder* runtime)
    : impl_(new Impl) {
  // Module-less translator: only the (mandatory) runtime builder is recorded.
  CHECK(runtime);
  impl_->runtime = runtime;
}

Value* MlirToRuntimeTranslator::AddValue(mlir::Value mlir_value, Value* value) {
  // Bind the caller-provided runtime `value` to `mlir_value`; the map keeps a
  // ValueRef to it. Binding the same MLIR value twice is a fatal error.
  const bool inserted =
      impl_->value_map.try_emplace(mlir_value, ValueRef(value)).second;
  CHECK(inserted) << "duplicate add value " << DumpToString(mlir_value);
  return value;
}

// Translates `module` into op executables on `runtime`.
void MlirToRuntimeTranslate(mlir::ModuleOp module,
                            CoreRuntimeBuilder* runtime) {
  MlirToRuntimeTranslator translator(module, runtime);
  translator.Run();
}

/**
 * Execute the mlir program in test mode -- print some debug information to
 * stdout.
 */
class MlirProgramTestExecutor : public MlirToRuntimeTranslator {
 public:
  CoreRuntimeBuilder core_runtime;

  MlirProgramTestExecutor(mlir::ModuleOp module, KernelRegistry* registry)
      : MlirToRuntimeTranslator(module, &core_runtime),
        core_runtime(registry),
        registry(registry) {
    CHECK(registry);
  }

  void Run() {
    EmitFunctions();

    CHECK(registry);
    for (auto func_op : impl_->module.getOps<mlir::FuncOp>()) {
      VLOG(3) << "Running function " << func_op.getName().str();
      EmitAndRunFuncWithoutArguments(func_op);
    }
  }

 protected:
  std::unordered_map<std::string, mlir::FuncOp> func_def_table;

  void EmitFunction(mlir::FuncOp op) override {
    CHECK(!impl_->func_defs.count(op.getName().str()))
        << "Duplicate function defition found for function ["
        << op.getName().str();
    impl_->func_defs.emplace(op.getName().str(), op);
  }

 private:
  void EmitAndRunFuncWithoutArguments(mlir::FuncOp func) {
    // print the function name for llvm FileChecker macro, CHECK-LABEL
    std::cout << '@' << func.getName().str() << std::endl;
    if (func.getNumArguments() ==
        0) {  // an entry function, execute it immediately
      VLOG(3) << "executing function " << func.getName().str();
      // Emit and execute each function
      CoreRuntimeBuilder runtime(registry);
      impl_->runtime = &runtime;

      auto& blocks = func.getBlocks();
      CHECK_EQ(blocks.size(), 1UL)
          << "function with more than one block is not supported yet";

      for (auto& op : blocks.front()) {
        if (EmitConstantOp(&op)) continue;
        if (EmitBuildShapeOp(&op)) continue;
        llvm::SmallVector<mlir::Value, 3> results;
        if (EmitReturnOp(&op, &results)) continue;
        if (EmitCallOp(&op, &impl_->func_defs)) continue;
686
        if (EmitGeneralOp(&op, *registry)) continue;
Y
Yan Chunwei 已提交
687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705
        LOG(FATAL) << "Not supported op: " << DumpToString(op);
      }

      runtime.Execute();

    } else {
      VLOG(2) << "get an callable function: " << func.getName().str();
    }
  }

 private:
  KernelRegistry* registry{};
};

// Runs `module` in test mode: every argument-less function is executed.
void TestMlir(mlir::ModuleOp module, KernelRegistry* registry) {
  MlirProgramTestExecutor(module, registry).Run();
}

706 707
}  // namespace host_context
}  // namespace infrt