dense_tensor.cc 4.0 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/dense_tensor.h"

#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
20 21
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
Y
Yan Chunwei 已提交
22 23 24 25 26 27 28 29 30 31 32
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>

#include <tuple>

#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/tensor_shape.h"

33 34
namespace infrt {
namespace dt {
Y
Yan Chunwei 已提交
35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
// Registers this dialect's operations. The operation list is not spelled out
// here: it is whatever dense_tensor.cpp.inc expands to when GET_OP_LIST is
// defined (presumably a TableGen-generated, comma-separated list of op
// classes -- confirm against the build rules).
void DTDialect::initialize() {
  addOperations<
// The include below supplies the template arguments of addOperations<>.
#define GET_OP_LIST
#include "paddle/infrt/dialect/dense_tensor.cpp.inc"
      >();
}

// Convenience overload: builds the TensorMapType in the context returned by
// ::infrt::Global::getMLIRContext(), so callers need not pass one explicitly.
TensorMapType TensorMapType::get() {
  mlir::MLIRContext *global_context = ::infrt::Global::getMLIRContext();
  return Base::get(global_context);
}

// Builds the TensorMapType in the caller-supplied MLIR context by forwarding
// to the base class factory.
TensorMapType TensorMapType::get(mlir::MLIRContext *context) {
  TensorMapType map_type = Base::get(context);
  return map_type;
}

// Convenience overload: builds the StringType in the context returned by
// ::infrt::Global::getMLIRContext(), so callers need not pass one explicitly.
StringType StringType::get() {
  mlir::MLIRContext *global_context = ::infrt::Global::getMLIRContext();
  return Base::get(global_context);
}

// Builds the StringType in the caller-supplied MLIR context by forwarding to
// the base class factory.
StringType StringType::get(mlir::MLIRContext *context) {
  StringType string_type = Base::get(context);
  return string_type;
}

58 59 60
// Returns the opaque MLIR type whose dialect namespace is "t" and whose type
// data is "tensor", created in `context`. Used below as the expected type of
// the operand of the set-tensor ops.
static mlir::Type getTensorType(mlir::MLIRContext *context) {
  const auto tensor_dialect = mlir::Identifier::get("t", context);
  return mlir::OpaqueType::get(tensor_dialect, "tensor");
}

63 64 65
// Custom assembly parser for the create-uninit-tensor ops.
//
// Consumes, in this order:
//   <array attribute> [<optional attribute dict>] `->` <result type>
//
// The array attribute is parsed with an i64 type hint and stored in
// `result.attributes` under the name "shape". The result type must be a
// DenseTensorType; otherwise the error "invalid kind of type specified" is
// emitted at the location recorded before any of the above was parsed.
//
// Returns success once the single result type has been added to `result`,
// and a failure as soon as any parsing step fails.
static mlir::ParseResult parseCreateUninitTensorOp(
    mlir::OpAsmParser &parser,       // NOLINT
    mlir::OperationState &result) {  // NOLINT
  // Captured first so the type diagnostic below refers to where parsing of
  // this op's custom syntax started.
  const auto start_loc = parser.getCurrentLocation();

  mlir::ArrayAttr shape_attr;
  if (parser.parseAttribute(shape_attr,
                            parser.getBuilder().getI64Type(),
                            "shape",
                            result.attributes)) {
    return mlir::failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes)) {
    return mlir::failure();
  }

  if (parser.parseArrow()) {
    return mlir::failure();
  }

  mlir::Type output_type;
  if (parser.parseType(output_type)) {
    return mlir::failure();
  }
  if (!output_type.isa<DenseTensorType>()) {
    return parser.emitError(start_loc, "invalid kind of type specified");
  }

  // The op produces exactly one result.
  result.addTypes(output_type);
  return mlir::success();
}

template <typename CreateUninitTensorOp>
87
static void printCreateUninitTensorOp(mlir::OpAsmPrinter &p,  // NOLINT
Y
Yan Chunwei 已提交
88 89 90 91
                                      CreateUninitTensorOp op) {
  p << CreateUninitTensorOp::getOperationName();
  p << " ";
  p.printAttributeWithoutType(op.shapeAttr());
92
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
Y
Yan Chunwei 已提交
93 94 95 96
  p << " -> ";
  p << op.getOperation()->getResultTypes();
}

97 98 99 100 101
// Custom assembly parser for the set-tensor ops.
//
// Consumes exactly one SSA operand followed by one attribute. The operand is
// resolved against the type returned by getTensorType() and appended to
// `result.operands`; the attribute is stored in `result.attributes` under the
// name "values".
//
// Returns a failure if any of those steps fails, success otherwise.
static mlir::ParseResult parseSetTensorOp(
    mlir::OpAsmParser &parser,       // NOLINT
    mlir::OperationState &result) {  // NOLINT
  llvm::SmallVector<mlir::OpAsmParser::OperandType, 1> tensor_operands;
  // Requiring a count of 1 makes the tensor_operands[0] access below safe.
  if (parser.parseOperandList(tensor_operands, 1)) {
    return mlir::failure();
  }

  const mlir::Type expected_type = getTensorType(result.getContext());
  if (parser.resolveOperand(
          tensor_operands[0], expected_type, result.operands)) {
    return mlir::failure();
  }

  mlir::Attribute values;
  if (parser.parseAttribute(values, "values", result.attributes)) {
    return mlir::failure();
  }
  return mlir::success();
}

template <typename SetTensorOp>
112
static void printSetTensorOp(mlir::OpAsmPrinter &p, SetTensorOp op) {  // NOLINT
Y
Yan Chunwei 已提交
113 114
  p << SetTensorOp::getOperationName() << " ";
  p.printOperand(op.getOperand());
115
  p << " " << op->getAttr("values");
Y
Yan Chunwei 已提交
116
}
117 118
}  // namespace dt
}  // namespace infrt
Y
Yan Chunwei 已提交
119 120 121

#define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.cpp.inc"  // NOLINT
122
#include "paddle/infrt/dialect/dense_tensor_dialect.cpp.inc"