// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/infrt/dialect/dense_tensor.h"

#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>

#include <tuple>

#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/tensor_shape.h"

namespace infrt {
namespace dt {
// Registers every operation of the `dt` dialect with this dialect instance.
// The template argument list of addOperations<> is expanded from the
// (presumably ODS/TableGen-generated) dense_tensor.cpp.inc: defining
// GET_OP_LIST before the include selects the op-class list section of that
// file, which becomes the comma-separated list of op types registered here.
void DTDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/dense_tensor.cpp.inc"
      >();
}
// Builds the opaque tensor type used by parseSetTensorOp when resolving its
// operand: an mlir::OpaqueType in the dialect namespace "t" whose type data
// is "tensor".
//
// @param context  MLIR context that owns the identifier and the type.
// @return         The opaque "t"/"tensor" type.
static mlir::Type getTensorType(mlir::MLIRContext *context) {
  auto t_dialect = mlir::Identifier::get("t", context);
  return mlir::OpaqueType::get(t_dialect, "tensor");
}

// Custom parser for the "create uninitialized tensor" op. It parses the
// assembly that follows the op name:
//   <shape-array> {optional-attr-dict} -> <result-type>
// The shape array is parsed with an i64 type hint and stored in
// `result.attributes` under "shape". Exactly one result type is parsed, and it
// must be a DenseTensorType.
//
// @param parser  Parser positioned just after the op name.
// @param result  Receives the "shape" attribute, any extra attributes, and the
//                single result type.
// @return        success() on a well-formed op, failure() otherwise (with a
//                diagnostic already emitted by the parser).
static mlir::ParseResult parseCreateUninitTensorOp(
    mlir::OpAsmParser &parser,       // NOLINT
    mlir::OperationState &result) {  // NOLINT
  mlir::ArrayAttr shapeAttr;
  if (parser.parseAttribute(shapeAttr,
                            parser.getBuilder().getI64Type(),
                            "shape",
                            result.attributes))
    return mlir::failure();
  if (parser.parseOptionalAttrDict(result.attributes)) return mlir::failure();

  if (parser.parseArrow()) return mlir::failure();

  // Record where the result type starts so a wrong-kind diagnostic points at
  // the offending type instead of at the beginning of the operation.
  auto typeLoc = parser.getCurrentLocation();
  mlir::Type outputType;
  if (parser.parseType(outputType)) return mlir::failure();
  if (!outputType.isa<DenseTensorType>())
    return parser.emitError(typeLoc, "invalid kind of type specified");

  result.addTypes(outputType);
  return mlir::success();
}

template <typename CreateUninitTensorOp>
70
static void printCreateUninitTensorOp(mlir::OpAsmPrinter &p,  // NOLINT
Y
Yan Chunwei 已提交
71 72 73 74
                                      CreateUninitTensorOp op) {
  p << CreateUninitTensorOp::getOperationName();
  p << " ";
  p.printAttributeWithoutType(op.shapeAttr());
75
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
Y
Yan Chunwei 已提交
76 77 78 79
  p << " -> ";
  p << op.getOperation()->getResultTypes();
}

// Custom parser for the set-tensor op. It parses the assembly that follows the
// op name:
//   <operand> <attribute>
// Exactly one operand is required. That operand is resolved against the
// opaque tensor type returned by getTensorType(). The trailing attribute is
// stored in `result.attributes` under "values".
// NOTE(review): the operand is resolved as the opaque "t"/"tensor" type, not
// as DenseTensorType; confirm this matches the op's declared operand type.
//
// @param parser  Parser positioned just after the op name.
// @param result  Receives the resolved operand and the "values" attribute.
// @return        success() on a well-formed op, failure() otherwise.
static mlir::ParseResult parseSetTensorOp(
    mlir::OpAsmParser &parser,       // NOLINT
    mlir::OperationState &result) {  // NOLINT
  llvm::SmallVector<mlir::OpAsmParser::OperandType, 1> operands;
  if (parser.parseOperandList(operands, 1)) return mlir::failure();

  auto tensor_type = getTensorType(result.getContext());
  if (parser.resolveOperand(operands[0], tensor_type, result.operands))
    return mlir::failure();

  mlir::Attribute value_attr;
  if (parser.parseAttribute(value_attr, "values", result.attributes))
    return mlir::failure();
  return mlir::success();
}

template <typename SetTensorOp>
95
static void printSetTensorOp(mlir::OpAsmPrinter &p, SetTensorOp op) {  // NOLINT
Y
Yan Chunwei 已提交
96 97
  p << SetTensorOp::getOperationName() << " ";
  p.printOperand(op.getOperand());
98
  p << " " << op->getAttr("values");
Y
Yan Chunwei 已提交
99
}
}  // namespace dt
}  // namespace infrt

#define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.cpp.inc"  // NOLINT
#include "paddle/infrt/dialect/dense_tensor_dialect.cpp.inc"