// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/dense_tensor.h"

#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>

#include <tuple>

#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/tensor_shape.h"

namespace infrt {
namespace dt {
// Register all operations generated from dense_tensor.td (via the
// GET_OP_LIST section of dense_tensor.cpp.inc) with the DT dialect.
void DTDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/dense_tensor.cpp.inc"
      >();
}

// Map a case-insensitive target keyword ("x86" / "cuda") to its TargetType.
// Any other keyword yields llvm::None.
llvm::Optional<TargetType> GetTargetType(mlir::StringRef key) {
  if (key.equals_insensitive("x86")) return TargetType::X86;
  if (key.equals_insensitive("cuda")) return TargetType::CUDA;
  return llvm::None;
}

// Map a case-insensitive layout keyword ("nchw" / "nhwc") to its LayoutType.
// Any other keyword yields llvm::None.
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key) {
  if (key.equals_insensitive("nchw")) return LayoutType::NCHW;
  if (key.equals_insensitive("nhwc")) return LayoutType::NHWC;
  return llvm::None;
}

// Map a case-insensitive precision keyword ("i32" / "f32") to its
// PrecisionType. Any other keyword yields llvm::None.
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key) {
  if (key.equals_insensitive("i32")) return PrecisionType::I32;
  if (key.equals_insensitive("f32")) return PrecisionType::F32;
  return llvm::None;
}

// Build (or fetch the uniqued) TensorType for the given target/layout/
// precision triple, using the process-wide MLIR context.
TensorType TensorType::get(TargetType target,
                           LayoutType layout,
                           PrecisionType precision) {
  auto *context = ::infrt::Global::getMLIRContext();
  return Base::get(context, target, layout, precision);
}

// Target (device) component of this tensor type.
TargetType TensorType::target() { return getImpl()->target_; }

// Data-layout component of this tensor type.
LayoutType TensorType::layout() { return getImpl()->layout_; }

// Element-precision component of this tensor type.
PrecisionType TensorType::precision() { return getImpl()->precision_; }

82
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TensorType tensorType) {
Y
Yan Chunwei 已提交
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103
  os << "TensorType<" << tensorType.target() << ", " << tensorType.layout()
     << ", " << tensorType.precision() << ">";
  return os;
}

// TensorMapType singleton bound to the process-wide MLIR context.
TensorMapType TensorMapType::get() {
  auto *context = ::infrt::Global::getMLIRContext();
  return Base::get(context);
}

// TensorMapType singleton bound to an explicitly supplied MLIR context.
TensorMapType TensorMapType::get(mlir::MLIRContext *context) {
  return Base::get(context);
}

// StringType singleton bound to the process-wide MLIR context.
StringType StringType::get() {
  auto *context = ::infrt::Global::getMLIRContext();
  return Base::get(context);
}

// StringType singleton bound to an explicitly supplied MLIR context.
StringType StringType::get(mlir::MLIRContext *context) {
  return Base::get(context);
}

104
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type) {
Y
Yan Chunwei 已提交
105 106 107 108 109 110 111 112 113 114 115 116 117
  switch (type) {
    case (TargetType::X86):
      os << "X86";
      break;
    case (TargetType::CUDA):
      os << "CUDA";
      break;
    default:
      os << "Unsupported";
  }
  return os;
}

118
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type) {
Y
Yan Chunwei 已提交
119 120 121 122 123 124 125 126 127 128 129 130 131
  switch (type) {
    case (LayoutType::NCHW):
      os << "NCHW";
      break;
    case (LayoutType::NHWC):
      os << "NHWC";
      break;
    default:
      os << "Unsupported";
  }
  return os;
}

132
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type) {
Y
Yan Chunwei 已提交
133 134 135 136 137 138 139 140 141 142 143 144 145
  switch (type) {
    case (PrecisionType::I32):
      os << "I32";
      break;
    case (PrecisionType::F32):
      os << "F32";
      break;
    default:
      os << "Unsupported";
  }
  return os;
}

146 147 148
static mlir::Type getTensorType(mlir::MLIRContext *context) {
  auto t_dialect = mlir::Identifier::get("t", context);
  return mlir::OpaqueType::get(t_dialect, "tensor");
Y
Yan Chunwei 已提交
149 150
}

151 152 153
static mlir::ParseResult parseCreateUninitTensorOp(
    mlir::OpAsmParser &parser,       // NOLINT
    mlir::OperationState &result) {  // NOLINT
Y
Yan Chunwei 已提交
154
  auto loc = parser.getCurrentLocation();
155 156
  mlir::Type outputRawTypes[1];
  ::llvm::ArrayRef<mlir::Type> outputTypes(outputRawTypes);
Y
Yan Chunwei 已提交
157 158 159 160 161 162

  mlir::ArrayAttr shapeAttr;
  if (parser.parseAttribute(shapeAttr,
                            parser.getBuilder().getI64Type(),
                            "shape",
                            result.attributes))
163 164
    return mlir::failure();
  if (parser.parseOptionalAttrDict(result.attributes)) return mlir::failure();
Y
Yan Chunwei 已提交
165

166 167
  if (parser.parseArrow()) return mlir::failure();
  if (parser.parseType(outputRawTypes[0])) return mlir::failure();
Y
Yan Chunwei 已提交
168 169 170
  if (!outputRawTypes[0].isa<TensorType>())
    return parser.emitError(loc, "invalid kind of type specified");
  result.addTypes(outputTypes);
171
  return mlir::success();
Y
Yan Chunwei 已提交
172 173 174
}

template <typename CreateUninitTensorOp>
175
static void printCreateUninitTensorOp(mlir::OpAsmPrinter &p,  // NOLINT
Y
Yan Chunwei 已提交
176 177 178 179
                                      CreateUninitTensorOp op) {
  p << CreateUninitTensorOp::getOperationName();
  p << " ";
  p.printAttributeWithoutType(op.shapeAttr());
180
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
Y
Yan Chunwei 已提交
181 182 183 184
  p << " -> ";
  p << op.getOperation()->getResultTypes();
}

185 186 187 188 189
static mlir::ParseResult parseSetTensorOp(
    mlir::OpAsmParser &parser,       // NOLINT
    mlir::OperationState &result) {  // NOLINT
  llvm::SmallVector<mlir::OpAsmParser::OperandType, 1> operands;
  if (parser.parseOperandList(operands, 1)) return mlir::failure();
Y
Yan Chunwei 已提交
190 191 192

  auto tensor_type = getTensorType(result.getContext());

193 194
  mlir::Attribute value_attr;
  return mlir::failure(
Y
Yan Chunwei 已提交
195 196 197 198 199
      parser.resolveOperand(operands[0], tensor_type, result.operands) ||
      parser.parseAttribute(value_attr, "values", result.attributes));
}

template <typename SetTensorOp>
200
static void printSetTensorOp(mlir::OpAsmPrinter &p, SetTensorOp op) {  // NOLINT
Y
Yan Chunwei 已提交
201 202
  p << SetTensorOp::getOperationName() << " ";
  p.printOperand(op.getOperand());
203
  p << " " << op->getAttr("values");
Y
Yan Chunwei 已提交
204
}
}  // namespace dt
}  // namespace infrt

#define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.cpp.inc"  // NOLINT

#include "paddle/infrt/dialect/dense_tensor_dialect.cpp.inc"