/**
 * \file src/jit/impl/mlir/ir/dialect.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megbrain_build_config.h"
#if MGB_JIT && MGB_JIT_MLIR

#include "megbrain/jit/mlir/ir/dialect.h"

// NOTE(review): these header names were lost when the source was mangled and
// have been restored; confirm they match the pinned MLIR version.
#include <mlir/IR/Builders.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Support/LogicalResult.h>

using namespace mgb;
using namespace jit;

/// Constructs the "mgb" dialect and registers every operation listed in the
/// generated ops.cpp.inc (expanded here with GET_OP_LIST).
MgbDialect::MgbDialect(mlir::MLIRContext *ctx) : mlir::Dialect("mgb", ctx) {
  // Each preprocessor directive must start on its own line; the generated
  // include expands to the op classes passed as template arguments.
  addOperations<
#define GET_OP_LIST
#include "megbrain/jit/mlir/ir/ops.cpp.inc"
      >();
}

/// Custom assembly parser shared by binary operations.
///
/// Expects exactly two operands, an optional attribute dictionary and a
/// trailing `: type`, where `type` is either
///   * a function type `(lhs, rhs) -> results`, which gives each operand type
///     and the result type(s) explicitly, or
///   * a single type that is used for both operands and the result.
///
/// \param parser  the MLIR assembly parser positioned at the operand list
/// \param result  state that receives the resolved operands, attributes and
///                result types
/// \return failure if any component is missing or cannot be resolved
static mlir::ParseResult parseBinaryOp(mlir::OpAsmParser &parser,
                                       mlir::OperationState &result) {
  llvm::SmallVector<mlir::OpAsmParser::OperandType, 2> operands;
  // Remember where the operands started so resolution errors point there.
  llvm::SMLoc operandsLoc = parser.getCurrentLocation();
  mlir::Type type;
  if (parser.parseOperandList(operands, /*requiredOperandCount=*/2) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type))
    return mlir::failure();

  // If the type is a function type, it contains the input and result types of
  // this operation.
  if (mlir::FunctionType funcType = type.dyn_cast<mlir::FunctionType>()) {
    if (parser.resolveOperands(operands, funcType.getInputs(), operandsLoc,
                               result.operands))
      return mlir::failure();
    result.addTypes(funcType.getResults());
    return mlir::success();
  }

  // Otherwise, the parsed type is the type of both operands and results.
  if (parser.resolveOperands(operands, type, result.operands))
    return mlir::failure();
  result.addTypes(type);
  return mlir::success();
}

/// Custom assembly printer shared by binary operations (inverse of
/// parseBinaryOp).
///
/// Prints `name operands {attrs} : T` when every operand type equals the first
/// result type, and the functional form `: (lhs, rhs) -> results` otherwise.
static void printBinaryOp(mlir::OpAsmPrinter &printer, mlir::Operation *op) {
  printer << op->getName() << " " << op->getOperands();
  printer.printOptionalAttrDict(op->getAttrs());
  printer << " : ";

  // If all of the types are the same, print the type directly.
  // Dereferences the first result type, so the op must have at least one
  // result (AddOp::build below always adds exactly one).
  mlir::Type resultType = *op->result_type_begin();
  if (llvm::all_of(op->getOperandTypes(),
                   [=](mlir::Type type) { return type == resultType; })) {
    printer << resultType;
    return;
  }

  // Otherwise, print a functional type.
  printer.printFunctionalType(op->getOperandTypes(), op->getResultTypes());
}

///////////////////////// ElemwiseOp /////////////////////////////////////////////

/// Builds an AddOp over `lhs` and `rhs`; the single result takes the type of
/// `lhs` (the type of `rhs` is not consulted).
void AddOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
                  mlir::Value lhs, mlir::Value rhs) {
  state.addTypes(lhs.getType());
  state.addOperands({lhs, rhs});
}

/// Shape inference hook: copies the first operand's type onto the result.
/// No broadcasting against the second operand is performed here.
void AddOp::infer_shapes() { getResult().setType(getOperand(0).getType()); }

#define GET_OP_CLASSES
#include "megbrain/jit/mlir/ir/ops.cpp.inc"

#endif  // MGB_JIT && MGB_JIT_MLIR

// vim: syntax=cpp.doxygen