test_kernels.cc 5.5 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/test_kernels.h"

17 18 19 20
#include <mlir/IR/Builders.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
Y
Yan Chunwei 已提交
21

22 23
namespace infrt {
namespace dialect {
Y
Yan Chunwei 已提交
24 25 26 27 28 29 30 31 32 33
//===----------------------------------------------------------------------===//
// BenchmarkOp
//===----------------------------------------------------------------------===//

// Parse the BenchmarkOp in the following format
// infrt.benchmark "add.i32"(%c : i32, %d : f32)
//       max_count = 100, duration_secs = 1 {
// ...
// }

34 35 36 37
static mlir::ParseResult parseBenchmarkOp(
    mlir::OpAsmParser &parser,       // NOLINT
    mlir::OperationState &result) {  // NOLINT
  mlir::StringAttr nameAttr;
Y
Yan Chunwei 已提交
38
  if (parser.parseAttribute(nameAttr, "name", result.attributes))
39
    return mlir::failure();
Y
Yan Chunwei 已提交
40 41

  // Parse the operands, e.g. (%c : i32, %d : f32)
42
  if (parser.parseLParen()) return mlir::failure();
Y
Yan Chunwei 已提交
43

44 45
  llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands;
  llvm::SmallVector<mlir::Type, 4> types;
Y
Yan Chunwei 已提交
46 47 48 49 50 51
  llvm::SMLoc type_loc = parser.getCurrentLocation();

  if (parser.parseOptionalRParen()) {
    // Parse non-empty operands
    do {
      // Parse %c : i32,
52 53
      mlir::OpAsmParser::OperandType operand;
      mlir::Type type;
Y
Yan Chunwei 已提交
54 55

      if (parser.parseOperand(operand) || parser.parseColonType(type))
56
        return mlir::failure();
Y
Yan Chunwei 已提交
57 58 59 60 61

      operands.push_back(operand);
      types.push_back(type);
    } while (succeeded(parser.parseOptionalComma()));

62
    if (parser.parseRParen()) return mlir::failure();
Y
Yan Chunwei 已提交
63 64 65
  }

  if (parser.resolveOperands(operands, types, type_loc, result.operands))
66
    return mlir::failure();
Y
Yan Chunwei 已提交
67 68 69

  // Parse the keyword attribute, e.g. max_count = 100, duration_secs = 1
  do {
70 71
    mlir::StringRef attr;
    mlir::Attribute resultAttr;
Y
Yan Chunwei 已提交
72 73 74 75 76
    if (parser.parseKeyword(&attr) || parser.parseEqual() ||
        parser.parseAttribute(resultAttr,
                              parser.getBuilder().getIntegerType(32),
                              attr,
                              result.attributes))
77 78
      return mlir::failure();
  } while (mlir::succeeded(parser.parseOptionalComma()));
Y
Yan Chunwei 已提交
79 80 81 82

  // Set the default attribute num_warmup_runs to 1 if unset
  auto setDefaultAttrIfUnset = [&](const char *attr_name, int value) {
    bool found = llvm::any_of(result.attributes,
83 84
                              [attr_name](const mlir::NamedAttribute &attr) {
                                return attr.getName() == attr_name;
Y
Yan Chunwei 已提交
85 86
                              });
    if (!found) {
87 88
      mlir::IntegerAttr default_val =
          parser.getBuilder().getI32IntegerAttr(value);
Y
Yan Chunwei 已提交
89 90 91 92 93
      result.addAttribute(attr_name, default_val);
    }
  };
  setDefaultAttrIfUnset("num_warmup_runs", 1);

94
  mlir::Region *target = result.addRegion();
Y
Yan Chunwei 已提交
95 96 97 98 99 100 101 102 103 104 105
  return parser.parseRegion(*target,
                            operands,
                            types,
                            /*enableNameShadowing=*/true);
}

// Print the BenchmarkOp in the following format
// infrt.benchmark "add.i32"(%c : i32, %d : f32)
//       max_count = 100, duration_secs = 1 {
// ...
// }
106
static void print(mlir::OpAsmPrinter &p, BenchmarkOp op) {  // NOLINT
Y
Yan Chunwei 已提交
107 108 109
  p << "infrt.benchmark ";

  // Print the name attribute, e.g "add.i32"
110
  auto name_attr = op->getAttr("name");
Y
Yan Chunwei 已提交
111 112 113 114 115 116 117 118 119 120 121 122 123
  p << name_attr;

  // Print the operands and types, e.g. (%c : i32, %d : f32)
  p << '(';
  llvm::interleaveComma(llvm::zip(op.getOperands(), op.getOperandTypes()),
                        p,
                        [&](const auto &it) {
                          p << std::get<0>(it) << " : " << std::get<1>(it);
                        });
  p << ") ";

  bool need_comma = false;
  // Print the attributes, e.g. max_count = 100, duration_secs = 1
124 125
  for (auto &name_attr : op->getAttrs()) {
    auto id = name_attr.getName();
Y
Yan Chunwei 已提交
126 127
    if (id == "name") continue;
    if (need_comma) p << ", ";
128
    auto attr = name_attr.getValue();
Y
Yan Chunwei 已提交
129
    p << id << " = ";
130
    if (auto int_attr = attr.dyn_cast<mlir::IntegerAttr>()) {
Y
Yan Chunwei 已提交
131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
      int_attr.getValue().print(p.getStream(), /*isSigned=*/false);
    } else {
      op.emitOpError("Unexpected attribute");
    }
    need_comma = true;
  }
  p << ' ';

  // Print the region
  // Reuse the argument names provided to the op for the bbarg names within
  // the region.
  p.shadowRegionArgs(op.region(), op.getOperands());
  p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
}

146
// Verifies the structure of a BenchmarkOp: the last operation of the region's
// first block must be an "Infrt.return" carrying exactly one operand.
//
// Returns success when that holds; otherwise emits an op error on `op` and
// returns failure.
//
// NOTE(review): the expected terminator name is spelled "Infrt.return"
// (capital 'I') while the printer emits "infrt.benchmark"; confirm that this
// matches the registered name of the return op.
static mlir::LogicalResult verify(BenchmarkOp op) {
  auto &region = op.region();

  // A region without any block, or whose first block holds no operation, has
  // no terminator to inspect. Report it instead of calling front()/back() on
  // an empty list.
  if (region.empty() || region.front().empty()) {
    return op.emitOpError("missing return statement");
  }

  // Verify that the target benchmark region has exactly one return value.
  auto &last_op = region.front().back();
  if (last_op.getName().getStringRef() != "Infrt.return") {
    return op.emitOpError("missing return statement");
  }
  if (last_op.getNumOperands() != 1) {
    return op.emitOpError(
        "incorrect number of return values. One return value is expected");
  }

  return mlir::success();
}
160 161
}  // namespace dialect
}  // namespace infrt
Y
Yan Chunwei 已提交
162 163 164

#define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.cpp.inc"