diff --git a/cmake/external/llvm.cmake b/cmake/external/llvm.cmake index e080a7359af98276be2d6bfc53e6b5917f83bde9..27210e5260048a57cc442fce4c6cf8657e401568 100644 --- a/cmake/external/llvm.cmake +++ b/cmake/external/llvm.cmake @@ -1,7 +1,7 @@ include(FetchContent) -set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/CINN/llvm11.tar.gz) -set(LLVM_MD5 39d32b6be466781dddf5869318dcba53) +set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/infrt/llvm_b5149f4e66a49a98b67e8e2de4e24a4af8e2781b.tar.gz) +set(LLVM_MD5 022819bb5760817013cf4b8a37e97d5e) set(FETCHCONTENT_BASE_DIR ${THIRD_PARTY_PATH}/llvm) set(FETCHCONTENT_QUIET OFF) @@ -51,7 +51,7 @@ message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") # To build with MLIR, the LLVM is build from source code using the following flags: #[==[ -cmake -G Ninja ../llvm \ +cmake ../llvm -G "Unix Makefiles" \ -DLLVM_ENABLE_PROJECTS="mlir;clang" \ -DLLVM_BUILD_EXAMPLES=OFF \ -DLLVM_TARGETS_TO_BUILD="X86" \ @@ -59,8 +59,10 @@ cmake -G Ninja ../llvm \ -DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_ENABLE_ZLIB=OFF \ -DLLVM_ENABLE_RTTI=ON \ + -DLLVM_INSTALL_UTILS=ON \ + -DCMAKE_INSTALL_PREFIX=./install #]==] -# The matched llvm-project version is f9dc2b7079350d0fed3bb3775f496b90483c9e42 (currently a temporary commit) +# The matched llvm-project version is b5149f4e66a49a98b67e8e2de4e24a4af8e2781b (currently a temporary commit) add_definitions(${LLVM_DEFINITIONS}) @@ -75,7 +77,7 @@ add_definitions(${LLVM_DEFINITIONS}) # The minimum needed libraries for MLIR IR parse and transform. -set(MLIR_IR_LIBS MLIRAnalysis MLIRStandardOps MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib) +set(MLIR_IR_LIBS MLIRAnalysis MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib) # tb_base is the name of a xxx.td file (without the .td suffix) @@ -89,6 +91,7 @@ function(mlir_tablegen_on td_base) mlir_tablegen(${td_base}.cpp.inc -gen-op-defs) if (mlir_tablegen_on_DIALECT) mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls -dialect=${mlir_tablegen_on_DIALECT}) + mlir_tablegen(${td_base}_dialect.cpp.inc --gen-dialect-defs -dialect=${mlir_tablegen_on_DIALECT}) endif() add_public_tablegen_target(${td_base}_IncGen) add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen) diff --git a/paddle/infrt/CMakeLists.txt b/paddle/infrt/CMakeLists.txt index 8f05d286bf0339e52eecdb043731bba41db7504d..8af3012a220ad1e06803b6832dc3c3558af7bb53 100644 --- a/paddle/infrt/CMakeLists.txt +++ b/paddle/infrt/CMakeLists.txt @@ -77,7 +77,6 @@ add_subdirectory(paddle) # MLIR td file generations set(infrt_mlir_incs - ops_inc basic_kernels_inc test_kernels_inc infrt_base_inc diff --git a/paddle/infrt/common/global.h b/paddle/infrt/common/global.h index f89164d03f31dedc81aca779f16fd42f979f3aab..e6586cb3a3c603ed352b360a45c3cce879978657 100644 --- a/paddle/infrt/common/global.h +++ b/paddle/infrt/common/global.h @@ -14,7 +14,7 @@ #pragma once -#include "mlir/IR/MLIRContext.h" +#include #include "paddle/infrt/tensor/dense_host_tensor.h" namespace infrt { diff --git a/paddle/infrt/dialect/CMakeLists.txt b/paddle/infrt/dialect/CMakeLists.txt index d145843684c6366897d4347d66998af71e4250c2..c064b2145266bfb44f05c0c118b03388fa1b8e8b 100644 --- a/paddle/infrt/dialect/CMakeLists.txt +++ b/paddle/infrt/dialect/CMakeLists.txt @@ -2,7 +2,6 @@ core_gather_headers() gather_srcs(infrt_src SRCS dialect.cc - types.cc basic_kernels.cc test_kernels.cc infrt_base.cc @@ -14,8 +13,6 @@ gather_srcs(infrt_src SRCS pd_types.cc pd_ops.cc ) - -mlir_tablegen_on(ops) mlir_tablegen_on(basic_kernels) mlir_tablegen_on(test_kernels) mlir_tablegen_on(infrt_base DIALECT infrt) @@ -27,8 +24,7 @@ mlir_add_rewriter(rewrite) # TODO(Superjomn) add a cmake function cc_executable to ecapsulate the following code add_executable(infrtopt opt.cc) -target_link_libraries(infrtopt infrt ${mlir_libs}) -add_dependencies(infrtopt infrt) +target_link_libraries(infrtopt infrt) add_executable(print-ir print_ir.cc) target_link_libraries(print-ir infrt ${mlir_libs}) diff --git a/paddle/infrt/dialect/basic_kernels.cc b/paddle/infrt/dialect/basic_kernels.cc index b4d2b9182b0c5035f829715c21970a47fb79e9cb..bad7e73ec5ae5c3216a912729637664bba17d3b0 100644 --- a/paddle/infrt/dialect/basic_kernels.cc +++ b/paddle/infrt/dialect/basic_kernels.cc @@ -17,17 +17,17 @@ #include #include #include -#include -#include +#include +#include #include #include -#include #include #include #include "paddle/infrt/dialect/dense_tensor.h" -namespace infrt::dialect { +namespace infrt { +namespace dialect { using namespace mlir; // NOLINT static ParseResult parseCallOp(OpAsmParser &parser, // NOLINT @@ -71,12 +71,12 @@ static ParseResult parseConstantF64Op(OpAsmParser &parser, // NOLINT static ParseResult parseConstantI32Op(OpAsmParser &parser, // NOLINT OperationState &result) { // NOLINT return parseConstantOp( - IntegerType::get(32, result.getContext()), parser, result); + IntegerType::get(result.getContext(), 32), parser, result); } static ParseResult parseConstantI64Op(OpAsmParser &parser, // NOLINT OperationState &result) { // NOLINT return parseConstantOp( - IntegerType::get(64, result.getContext()), parser, result); + IntegerType::get(result.getContext(), 64), parser, result); } static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT @@ -90,10 +90,10 @@ static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT } static void print(OpAsmPrinter &p, CallOp op) { // NOLINT - p << "infrt.call " << op.getAttr("callee") << "("; + p << "infrt.call " << op->getAttr("callee") << "("; p.printOperands(op.getOperands()); p << ")"; - p.printOptionalAttrDict(op.getAttrs(), {"callee"}); + p.printOptionalAttrDict(op->getAttrs(), {"callee"}); p << " : "; } @@ -145,7 +145,7 @@ static LogicalResult verify(ConstantF64Op op) { return success(); } static LogicalResult verify(ConstantI64Op op) { return success(); } static LogicalResult verify(ReturnOp op) { - auto function = dyn_cast(op.getParentOp()); + auto function = dyn_cast(op->getParentOp()); if (!function) return success(); @@ -157,8 +157,8 @@ static LogicalResult verify(ReturnOp op) { return success(); } +} // namespace dialect +} // namespace infrt #define GET_OP_CLASSES #include "paddle/infrt/dialect/basic_kernels.cpp.inc" - -} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/basic_kernels.h b/paddle/infrt/dialect/basic_kernels.h index 65316bc1437c027a03e629e8f5cab868b5470758..b82abcd52d28f45b18824d9ea6f9e12c2ec1c574 100644 --- a/paddle/infrt/dialect/basic_kernels.h +++ b/paddle/infrt/dialect/basic_kernels.h @@ -13,12 +13,9 @@ // limitations under the License. #pragma once +#include #include #include -using namespace mlir; // NOLINT - -namespace infrt::dialect { #define GET_OP_CLASSES #include "paddle/infrt/dialect/basic_kernels.hpp.inc" -} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/basic_kernels.td b/paddle/infrt/dialect/basic_kernels.td index df5e4d8a2c6a1c50bb959ec5ec4a18b6bf451d59..7d8de79fbae2b0cb36ca354b8f6f39fc94851ebe 100644 --- a/paddle/infrt/dialect/basic_kernels.td +++ b/paddle/infrt/dialect/basic_kernels.td @@ -27,7 +27,7 @@ def CallOp : INFRT_Op<"call"> { let results = (outs Variadic); let extraClassDeclaration = [{ - StringRef getCallee() { return callee(); } + mlir::StringRef getCallee() { return callee(); } mlir::FunctionType getCalleeType(); }]; } @@ -57,9 +57,8 @@ def ReturnOp : INFRT_Op<"return", [Terminator]> { let arguments = (ins Variadic:$operands); - let builders = [OpBuilder< - "OpBuilder &b, OperationState &result", - [{ build(b, result, llvm::None); }]>]; + let builders = [OpBuilder<(ins), + [{ build($_builder, $_state, llvm::None); }]>]; } class AddOp : INFRT_Op<"add." # suffix, [NoSideEffect]> { diff --git a/paddle/infrt/dialect/dense_tensor.cc b/paddle/infrt/dialect/dense_tensor.cc index 629a7b16523fcaabe789b7a5f8d2146c6cd7633d..7685cdc65b9ad00492e0ca8a084ac7c686c94d89 100644 --- a/paddle/infrt/dialect/dense_tensor.cc +++ b/paddle/infrt/dialect/dense_tensor.cc @@ -17,12 +17,11 @@ #include #include #include +#include +#include #include -#include -#include #include #include -#include #include #include @@ -31,68 +30,37 @@ #include "paddle/infrt/common/global.h" #include "paddle/infrt/dialect/tensor_shape.h" -namespace infrt::dt { - +namespace infrt { +namespace dt { void DTDialect::initialize() { - allowUnknownTypes(); addOperations< #define GET_OP_LIST #include "paddle/infrt/dialect/dense_tensor.cpp.inc" >(); } -namespace detail { -struct TensorTypeStorage : public mlir::TypeStorage { - TensorTypeStorage(TargetType target, - LayoutType layout, - PrecisionType precision) - : target_(target), layout_(layout), precision_(precision) {} - - using KeyTy = std::tuple; - - bool operator==(const KeyTy &key) const { - return key == KeyTy(target_, layout_, precision_); - } - - static llvm::hash_code hashKey(const KeyTy &key) { - return llvm::hash_value(key); - } - - static TensorTypeStorage *construct( - mlir::TypeStorageAllocator &allocator, // NOLINT - const KeyTy &key) { - return new (allocator.allocate()) - TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key)); - } - - TargetType target_; - LayoutType layout_; - PrecisionType precision_; -}; -} // namespace detail - llvm::Optional GetTargetType(mlir::StringRef key) { - if (key.equals_lower("x86")) + if (key.equals_insensitive("x86")) return TargetType::X86; - else if (key.equals_lower("cuda")) + else if (key.equals_insensitive("cuda")) return TargetType::CUDA; else return llvm::None; } llvm::Optional GetLayoutType(mlir::StringRef key) { - if (key.equals_lower("nchw")) + if (key.equals_insensitive("nchw")) return LayoutType::NCHW; - else if (key.equals_lower("nhwc")) + else if (key.equals_insensitive("nhwc")) return LayoutType::NHWC; else return llvm::None; } llvm::Optional GetPrecisionType(mlir::StringRef key) { - if (key.equals_lower("i32")) + if (key.equals_insensitive("i32")) return PrecisionType::I32; - else if (key.equals_lower("f32")) + else if (key.equals_insensitive("f32")) return PrecisionType::F32; else return llvm::None; @@ -111,7 +79,7 @@ LayoutType TensorType::layout() { return getImpl()->layout_; } PrecisionType TensorType::precision() { return getImpl()->precision_; } -raw_ostream &operator<<(raw_ostream &os, TensorType tensorType) { +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TensorType tensorType) { os << "TensorType<" << tensorType.target() << ", " << tensorType.layout() << ", " << tensorType.precision() << ">"; return os; @@ -133,7 +101,7 @@ StringType StringType::get(mlir::MLIRContext *context) { return Base::get(context); } -raw_ostream &operator<<(raw_ostream &os, TargetType type) { +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type) { switch (type) { case (TargetType::X86): os << "X86"; @@ -147,7 +115,7 @@ raw_ostream &operator<<(raw_ostream &os, TargetType type) { return os; } -raw_ostream &operator<<(raw_ostream &os, LayoutType type) { +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type) { switch (type) { case (LayoutType::NCHW): os << "NCHW"; @@ -161,7 +129,7 @@ raw_ostream &operator<<(raw_ostream &os, LayoutType type) { return os; } -raw_ostream &operator<<(raw_ostream &os, PrecisionType type) { +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type) { switch (type) { case (PrecisionType::I32): os << "I32"; @@ -175,103 +143,69 @@ raw_ostream &operator<<(raw_ostream &os, PrecisionType type) { return os; } -static Type getTensorType(mlir::MLIRContext *context) { - auto t_dialect = Identifier::get("t", context); - return OpaqueType::get(t_dialect, "tensor", context); +static mlir::Type getTensorType(mlir::MLIRContext *context) { + auto t_dialect = mlir::Identifier::get("t", context); + return mlir::OpaqueType::get(t_dialect, "tensor"); } -static ParseResult parseCreateUninitTensorOp( - OpAsmParser &parser, // NOLINT - OperationState &result) { // NOLINT +static mlir::ParseResult parseCreateUninitTensorOp( + mlir::OpAsmParser &parser, // NOLINT + mlir::OperationState &result) { // NOLINT auto loc = parser.getCurrentLocation(); - ::mlir::Type outputRawTypes[1]; - ::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes); + mlir::Type outputRawTypes[1]; + ::llvm::ArrayRef outputTypes(outputRawTypes); mlir::ArrayAttr shapeAttr; if (parser.parseAttribute(shapeAttr, parser.getBuilder().getI64Type(), "shape", result.attributes)) - return failure(); - if (parser.parseOptionalAttrDict(result.attributes)) return failure(); + return mlir::failure(); + if (parser.parseOptionalAttrDict(result.attributes)) return mlir::failure(); - if (parser.parseArrow()) return failure(); - if (parser.parseType(outputRawTypes[0])) return failure(); + if (parser.parseArrow()) return mlir::failure(); + if (parser.parseType(outputRawTypes[0])) return mlir::failure(); if (!outputRawTypes[0].isa()) return parser.emitError(loc, "invalid kind of type specified"); result.addTypes(outputTypes); - return success(); + return mlir::success(); } template -static void printCreateUninitTensorOp(OpAsmPrinter &p, // NOLINT +static void printCreateUninitTensorOp(mlir::OpAsmPrinter &p, // NOLINT CreateUninitTensorOp op) { p << CreateUninitTensorOp::getOperationName(); p << " "; p.printAttributeWithoutType(op.shapeAttr()); - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); + p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"}); p << " -> "; p << op.getOperation()->getResultTypes(); } -// TODO(shibo): can be removed? -// static ParseResult parseFillTensorWithConstantOp(OpAsmParser& parser, -// OperationState& result) { -// auto loc = parser.getCurrentLocation(); -// ::mlir::OpAsmParser::OperandType inputRawOperands[1]; -// ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> -// inputOperands(inputRawOperands); -// ::mlir::Type inputRawTypes[1]; -// ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); -// -// if (parser.parseOperand(inputRawOperands[0])) return failure(); -// -// if (parser.parseColon()) return failure(); -// if (parser.parseType(inputRawTypes[0])) return failure(); -// if (!inputRawTypes[0].isa()) -// return parser.emitError(loc, "invalid kind of type specified"); -// -// Attribute value_attr; -// if (parser.resolveOperands(inputOperands, inputTypes, loc, result.operands)) -// return failure(); -// if (parser.parseAttribute(value_attr, "value", result.attributes)) return -// failure(); -// return success(); -//} - -// TODO(shibo): can be removed? -// template -// static void printFillTensorWithConstantOp(OpAsmPrinter& p, FillTensorOp op) { -// p << FillTensorOp::getOperationName(); -// p << " "; -// p.printOperand(op.getOperand()); -// p << " : "; -// p << op.getOperation()->getOperandTypes(); -// p << " "; -// p << op.getAttr("value"); -//} - -static ParseResult parseSetTensorOp(OpAsmParser &parser, // NOLINT - OperationState &result) { // NOLINT - SmallVector operands; - if (parser.parseOperandList(operands, 1)) return failure(); +static mlir::ParseResult parseSetTensorOp( + mlir::OpAsmParser &parser, // NOLINT + mlir::OperationState &result) { // NOLINT + llvm::SmallVector operands; + if (parser.parseOperandList(operands, 1)) return mlir::failure(); auto tensor_type = getTensorType(result.getContext()); - Attribute value_attr; - return failure( + mlir::Attribute value_attr; + return mlir::failure( parser.resolveOperand(operands[0], tensor_type, result.operands) || parser.parseAttribute(value_attr, "values", result.attributes)); } template -static void printSetTensorOp(OpAsmPrinter &p, SetTensorOp op) { // NOLINT +static void printSetTensorOp(mlir::OpAsmPrinter &p, SetTensorOp op) { // NOLINT p << SetTensorOp::getOperationName() << " "; p.printOperand(op.getOperand()); - p << " " << op.getAttr("values"); + p << " " << op->getAttr("values"); } +} // namespace dt +} // namespace infrt #define GET_OP_CLASSES #include "paddle/infrt/dialect/dense_tensor.cpp.inc" // NOLINT -} // namespace infrt::dt +#include "paddle/infrt/dialect/dense_tensor_dialect.cpp.inc" diff --git a/paddle/infrt/dialect/dense_tensor.h b/paddle/infrt/dialect/dense_tensor.h index 866c62213ab058037bafb116602cc0d609fd3bec..416925d3382bad640753b77e5516d6e45a425eef 100644 --- a/paddle/infrt/dialect/dense_tensor.h +++ b/paddle/infrt/dialect/dense_tensor.h @@ -19,13 +19,8 @@ #include -using namespace mlir; // NOLINT -namespace infrt::dt { - -namespace detail { -struct TensorTypeStorage; -} // namespace detail - +namespace infrt { +namespace dt { enum class TargetType : uint8_t { X86, CUDA }; enum class LayoutType : uint8_t { NCHW, NHWC }; enum class PrecisionType : uint8_t { I32, F32 }; @@ -34,9 +29,39 @@ llvm::Optional GetTargetType(mlir::StringRef key); llvm::Optional GetLayoutType(mlir::StringRef key); llvm::Optional GetPrecisionType(mlir::StringRef key); -raw_ostream &operator<<(raw_ostream &os, TargetType type); -raw_ostream &operator<<(raw_ostream &os, LayoutType type); -raw_ostream &operator<<(raw_ostream &os, PrecisionType type); +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type); +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type); +mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type); + +namespace detail { +struct TensorTypeStorage : public mlir::TypeStorage { + TensorTypeStorage(TargetType target, + LayoutType layout, + PrecisionType precision) + : target_(target), layout_(layout), precision_(precision) {} + + using KeyTy = std::tuple; + + bool operator==(const KeyTy &key) const { + return key == KeyTy(target_, layout_, precision_); + } + + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + static TensorTypeStorage *construct( + mlir::TypeStorageAllocator &allocator, // NOLINT + const KeyTy &key) { + return new (allocator.allocate()) + TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key)); + } + + TargetType target_; + LayoutType layout_; + PrecisionType precision_; +}; +} // namespace detail class TensorType : public mlir::Type::TypeBase #include -namespace infrt::dialect { +namespace infrt { +namespace dialect { struct MyScopedDiagnosicHandler::Impl { Impl() : diag_stream_(diag_str_) {} @@ -49,4 +51,5 @@ mlir::LogicalResult MyScopedDiagnosicHandler::handler(mlir::Diagnostic *diag) { return mlir::failure(true); } -} // namespace infrt::dialect +} // namespace dialect +} // namespace infrt diff --git a/paddle/infrt/dialect/diagnostic_utils.h b/paddle/infrt/dialect/diagnostic_utils.h index 3a8098cf751812d35dc3eac1041bed0536055288..746e61c8fe5c3151f3c6ea1da5bd105d1492082e 100644 --- a/paddle/infrt/dialect/diagnostic_utils.h +++ b/paddle/infrt/dialect/diagnostic_utils.h @@ -18,7 +18,8 @@ #include -namespace infrt::dialect { +namespace infrt { +namespace dialect { /** * A scoped diagnostic handler to help debug MLIR process. @@ -36,4 +37,5 @@ class MyScopedDiagnosicHandler : public mlir::SourceMgrDiagnosticHandler { std::unique_ptr impl_; }; -} // namespace infrt::dialect +} // namespace dialect +} // namespace infrt diff --git a/paddle/infrt/dialect/dialect.cc b/paddle/infrt/dialect/dialect.cc index cbcd5d0f0fa785a21c78d0ae25f40e6211a504ee..fe07b91d22ed54ae576d828f208ec766f5b719da 100644 --- a/paddle/infrt/dialect/dialect.cc +++ b/paddle/infrt/dialect/dialect.cc @@ -13,24 +13,26 @@ // limitations under the License. #include +#include #include -#include #include #include -#include #include #include -namespace infrt::hlir::dialect { +namespace infrt { +namespace hlir { +namespace dialect { -class CinnDialect : public ::mlir::Dialect { +class CinnDialect : public mlir::Dialect { public: - explicit CinnDialect(::mlir::MLIRContext* ctx); + explicit CinnDialect(mlir::MLIRContext* ctx); //! We should register this function in dialect static llvm::StringRef getDialectNamespace() { return "infrt::hlir::dialect"; } }; - -} // namespace infrt::hlir::dialect +} // namespace dialect +} // namespace hlir +} // namespace infrt diff --git a/paddle/infrt/dialect/infrt_base.cc b/paddle/infrt/dialect/infrt_base.cc index b28ad5ad4b5a59c898cc08303626df09b2ef70c9..e8005661bbd6527f6c21076fd0f3a362a5541968 100644 --- a/paddle/infrt/dialect/infrt_base.cc +++ b/paddle/infrt/dialect/infrt_base.cc @@ -18,7 +18,8 @@ #include "paddle/infrt/dialect/dense_tensor.h" #include "paddle/infrt/dialect/test_kernels.h" -namespace infrt::dialect { +namespace infrt { +namespace dialect { // ----INFRTDialect definition begin---- void INFRTDialect::initialize() { @@ -124,4 +125,5 @@ void INFRTDialect::printType(mlir::Type type, // ----INFRTDialect definition end---- -} // namespace infrt::dialect +} // namespace dialect +} // namespace infrt diff --git a/paddle/infrt/dialect/infrt_base.h b/paddle/infrt/dialect/infrt_base.h index 58acd7c9a409a5a13f31bf3bd1688f0bb26b3e0b..1a7fbcf395a6e9be70de6021f8e60f94922f32c3 100644 --- a/paddle/infrt/dialect/infrt_base.h +++ b/paddle/infrt/dialect/infrt_base.h @@ -18,19 +18,17 @@ #include #include #include -#include #include #include #include "paddle/infrt/dialect/infrt_base.hpp.inc" -namespace infrt::dialect { - -class INFRTDialect : public ::mlir::Dialect { - explicit INFRTDialect(::mlir::MLIRContext *context) - : ::mlir::Dialect(getDialectNamespace(), - context, - ::mlir::TypeID::get()) { +namespace infrt { +namespace dialect { +class INFRTDialect : public mlir::Dialect { + explicit INFRTDialect(mlir::MLIRContext *context) + : mlir::Dialect( + getDialectNamespace(), context, mlir::TypeID::get()) { initialize(); } @@ -41,15 +39,12 @@ class INFRTDialect : public ::mlir::Dialect { mlir::DialectAsmPrinter &printer) const override; void initialize(); - friend class ::mlir::MLIRContext; + friend class mlir::MLIRContext; public: static ::llvm::StringRef getDialectNamespace() { return "infrt"; } }; - -} // namespace infrt::dialect - -namespace mlir { +} // namespace dialect template static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT @@ -58,17 +53,16 @@ static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT return b.getIntegerAttr(b.getI32Type(), constant); } -static mlir::SmallVector<::mlir::Value, 4> cvtValueToValueRange( +static mlir::SmallVector cvtValueToValueRange( const mlir::Value &operand) { - return mlir::SmallVector<::mlir::Value, 4>(1, operand); + return mlir::SmallVector(1, operand); } -static mlir::SmallVector<::mlir::Value, 4> concatTwoValueRange( +static mlir::SmallVector concatTwoValueRange( mlir::ValueRange operand_0, mlir::ValueRange operand_1) { - mlir::SmallVector<::mlir::Value, 4> operands; + mlir::SmallVector operands; operands.append(operand_0.begin(), operand_0.end()); operands.append(operand_1.begin(), operand_1.end()); return operands; } - -} // namespace mlir +} // namespace infrt diff --git a/paddle/infrt/dialect/infrt_base.td b/paddle/infrt/dialect/infrt_base.td index 7d6fdbbbf2f68f6629c2299f807cbb9fa7605f74..1abd294236d93cfb0aa7ce70db25f2acfb57a06a 100644 --- a/paddle/infrt/dialect/infrt_base.td +++ b/paddle/infrt/dialect/infrt_base.td @@ -28,11 +28,11 @@ def TensorMapType : def BufferType : OpaqueType<"b", "buffer", "buffer">; class INFRT_createI32Attr : NativeCodeCall< - "mlir::createI32Attr($_builder, $_loc, " # value # ")">; + "infrt::createI32Attr($_builder, $_loc, " # value # ")">; def INFRT_cvtValueToValueRange : NativeCodeCall< - "mlir::cvtValueToValueRange($0)">; + "infrt::cvtValueToValueRange($0)">; def INFRT_concatTwoValueRange : NativeCodeCall< - "mlir::concatTwoValueRange($0, $1)">; + "infrt::concatTwoValueRange($0, $1)">; #endif // INFRT_BASE diff --git a/paddle/infrt/dialect/init_infrt_dialects.cc b/paddle/infrt/dialect/init_infrt_dialects.cc index 4bc2bf70942d29723c731f90da446ee0acc257f5..c3769414dbb390566a177cfcec0b62009b53018a 100644 --- a/paddle/infrt/dialect/init_infrt_dialects.cc +++ b/paddle/infrt/dialect/init_infrt_dialects.cc @@ -23,12 +23,10 @@ #include "paddle/infrt/dialect/tensor_shape.h" namespace infrt { - -void RegisterCinnDialects(mlir::DialectRegistry& registry) { // NOLINT - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); +void registerCinnDialects(mlir::DialectRegistry ®istry) { // NOLINT + registry.insert(); } - } // namespace infrt diff --git a/paddle/infrt/dialect/init_infrt_dialects.h b/paddle/infrt/dialect/init_infrt_dialects.h index 50caca018980d05b112459ecf27f81e538cf9e2a..0912e9ef2555b49a7fd2d22c5e3ab6a457cbb05b 100644 --- a/paddle/infrt/dialect/init_infrt_dialects.h +++ b/paddle/infrt/dialect/init_infrt_dialects.h @@ -14,10 +14,8 @@ #pragma once -#include "mlir/IR/Dialect.h" - +#include +#include namespace infrt { - -void RegisterCinnDialects(mlir::DialectRegistry& registry); // NOLINT - +void registerCinnDialects(mlir::DialectRegistry ®istry); // NOLINT } // namespace infrt diff --git a/paddle/infrt/dialect/mlir_loader.cc b/paddle/infrt/dialect/mlir_loader.cc index b318a6a763483141de7c1521614cb82538615bb6..1d0696e77dcda612eeb8c367958e44e2efed5354 100644 --- a/paddle/infrt/dialect/mlir_loader.cc +++ b/paddle/infrt/dialect/mlir_loader.cc @@ -16,8 +16,8 @@ #include #include +#include #include -#include #include #include #include @@ -30,12 +30,15 @@ #include "paddle/infrt/dialect/diagnostic_utils.h" #include "paddle/infrt/dialect/init_infrt_dialects.h" -namespace infrt::dialect { +namespace infrt { +namespace dialect { mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, const std::string& mlir_source) { // context->allowUnregisteredDialects(); - RegisterCinnDialects(context->getDialectRegistry()); + mlir::DialectRegistry registry; + registerCinnDialects(registry); + context->appendDialectRegistry(registry); // Currenetly, We only used the CinnDialect and mlir::BuiltinDialect is // enough。Don't need StandardOpsDialect. // context->getDialectRegistry().insert(); @@ -57,9 +60,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, mlir::MLIRContext* context) { // context->allowUnregisteredDialects(); - RegisterCinnDialects(context->getDialectRegistry()); - context->getDialectRegistry().insert(); - + mlir::DialectRegistry registry; + registerCinnDialects(registry); + context->appendDialectRegistry(registry); mlir::ScopedDiagnosticHandler scope_handler( context, [](mlir::Diagnostic& diag) { if (diag.getSeverity() != mlir::DiagnosticSeverity::Error) @@ -71,4 +74,5 @@ mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, return mlir::parseSourceFile(std::string(file_name), context); } -} // namespace infrt::dialect +} // namespace dialect +} // namespace infrt diff --git a/paddle/infrt/dialect/mlir_loader.h b/paddle/infrt/dialect/mlir_loader.h index 092da7d9ce03f64f43a2bfa237c7fa60983959a1..5e50ad9e5a27176a1bea32356b0cf343140bb441 100644 --- a/paddle/infrt/dialect/mlir_loader.h +++ b/paddle/infrt/dialect/mlir_loader.h @@ -15,16 +15,17 @@ #pragma once #include -#include +#include #include #include -namespace infrt::dialect { +namespace infrt { +namespace dialect { mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, const std::string& mlir_source); mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, mlir::MLIRContext* context); - -} // namespace infrt::dialect +} // namespace dialect +} // namespace infrt diff --git a/paddle/infrt/dialect/mlir_loader_test.cc b/paddle/infrt/dialect/mlir_loader_test.cc index 1b622d585ad8ee556ea8f35eb64560f49fb5710d..11150530730444ed74f547b9bb8abef5473c61b0 100644 --- a/paddle/infrt/dialect/mlir_loader_test.cc +++ b/paddle/infrt/dialect/mlir_loader_test.cc @@ -17,14 +17,15 @@ #include #include #include -#include +#include #include #include #include "paddle/infrt/dialect/init_infrt_dialects.h" -namespace infrt::dialect { +namespace infrt { +namespace dialect { TEST(MlirLoader, basic) { mlir::MLIRContext context; @@ -42,8 +43,7 @@ func @main() -> f32 { )ROC"; auto module = LoadMlirSource(&context, source); - module->verify(); - + EXPECT_TRUE(mlir::succeeded(module->verify())); LOG(INFO) << "module name: " << module->getOperationName().data(); for (auto func : module->getOps()) { LOG(INFO) << "get func " << func.getName().str(); @@ -54,4 +54,5 @@ func @main() -> f32 { } } -} // namespace infrt::dialect +} // namespace dialect +} // namespace infrt diff --git a/paddle/infrt/dialect/mlir_tests/rewrite.mlir b/paddle/infrt/dialect/mlir_tests/rewrite.mlir index bfad9d1f6924d4da7b818968ebb796cf8f346935..5e207634da8e4bb96719254700d7f30e4cdfe52a 100644 --- a/paddle/infrt/dialect/mlir_tests/rewrite.mlir +++ b/paddle/infrt/dialect/mlir_tests/rewrite.mlir @@ -20,5 +20,5 @@ func @main() -> tensor { %c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor, tensor) -> tensor %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor, tensor) -> tensor %e2 = "pd.relu"(%d2) {} : (tensor) -> tensor - infrt.return %e2 : tensor + "pd.fetch"(%e2) {name="output"} :(tensor)->() } \ No newline at end of file diff --git a/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir b/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir index 9ea1ec0ebca365b42be8d310793dc3c5f7dd4cf4..2889b92b18ef08fb6014eff948e2a5fc3d50c7f3 100644 --- a/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir +++ b/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir @@ -11,5 +11,5 @@ func @main() -> tensor { %c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor %d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor - infrt.return %d : tensor + "pd.fetch"(%d) {name="output"} :(tensor)->() } \ No newline at end of file diff --git a/paddle/infrt/dialect/mlir_tests/trt_ops.mlir b/paddle/infrt/dialect/mlir_tests/trt_ops.mlir index 009b6d1c19653e52a0ef0174892cdcbeccf18154..d98f107bab41e959d82acfd681d762d7981eab51 100644 --- a/paddle/infrt/dialect/mlir_tests/trt_ops.mlir +++ b/paddle/infrt/dialect/mlir_tests/trt_ops.mlir @@ -18,5 +18,5 @@ func @main() -> tensor { %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor, tensor) -> tensor %e2 = "pd.relu"(%d2) {} : (tensor) -> tensor - "pd.fetch"(%e2) :(tensor)->() + "pd.fetch"(%e2) {name="output"} :(tensor)->() } diff --git a/paddle/infrt/dialect/ops.td b/paddle/infrt/dialect/ops.td deleted file mode 100644 index 264134a447c63f637090e2f9919f2b97cad1ab4f..0000000000000000000000000000000000000000 --- a/paddle/infrt/dialect/ops.td +++ /dev/null @@ -1,6 +0,0 @@ -include "mlir/IR/OpBase.td" -include "paddle/infrt/dialect/infrt_base.td" - - -class INFRT_Op traits = []> : - Op; diff --git a/paddle/infrt/dialect/opt.cc b/paddle/infrt/dialect/opt.cc index d90d25230d0c24fb84ccdcf2cd282ba814b9a665..5bcf5a23f4c532b1056ceaa54c80902b32e4061a 100644 --- a/paddle/infrt/dialect/opt.cc +++ b/paddle/infrt/dialect/opt.cc @@ -12,34 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include #include - -#include - -#include "paddle/infrt/common/global.h" #include "paddle/infrt/dialect/init_infrt_dialects.h" -#include "paddle/infrt/dialect/mlir_loader.h" int main(int argc, char **argv) { - mlir::MLIRContext *context = infrt::Global::getMLIRContext(); - - auto ®istry = context->getDialectRegistry(); - infrt::RegisterCinnDialects(registry); - + mlir::DialectRegistry registry; + infrt::registerCinnDialects(registry); mlir::registerCanonicalizerPass(); - return mlir::failed( - mlir::MlirOptMain(argc, argv, "INFRT mlir pass driver", registry)); + mlir::MlirOptMain(argc, argv, "infrt mlir pass driver", registry)); } diff --git a/paddle/infrt/dialect/pd_op_base.td b/paddle/infrt/dialect/pd_op_base.td index af53df113dfb3e908d5066fed984a8c37942df25..a3e3c4ae592779c36f175ecfc20c154724be0863 100644 --- a/paddle/infrt/dialect/pd_op_base.td +++ b/paddle/infrt/dialect/pd_op_base.td @@ -16,7 +16,7 @@ def PD_Dialect : Dialect { This dialect contains the PaddlePaddle operators. }]; - let cppNamespace = "::mlir::pd"; + let cppNamespace = "mlir::pd"; } class PD_Op traits = []> : diff --git a/paddle/infrt/dialect/pd_ops.cc b/paddle/infrt/dialect/pd_ops.cc index ce10be6d100f82b3a431b45098121fc5011496e6..fe3899688384628b2f1a5cba577f5f46515275e0 100644 --- a/paddle/infrt/dialect/pd_ops.cc +++ b/paddle/infrt/dialect/pd_ops.cc @@ -14,10 +14,15 @@ #include "paddle/infrt/dialect/pd_ops.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/PatternMatch.h" +#include +#include #include "paddle/infrt/dialect/infrt_base.h" +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT + +#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT + namespace mlir { namespace pd { PaddleDialect::PaddleDialect(MLIRContext *context) @@ -36,12 +41,6 @@ mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder, return builder.create(loc, value); } -#define GET_OP_CLASSES -#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT -#undef GET_OP_CLASSES - -#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT - void ConstantOp::build(OpBuilder &builder, OperationState &state, Attribute value) { @@ -66,8 +65,8 @@ LogicalResult ConstantOp::inferReturnTypes( inferredReturnTypes.push_back(attributes.get("value").getType()); return success(); } -::mlir::OpFoldResult ConstantOp::fold( - ::llvm::ArrayRef<::mlir::Attribute> operands) { +mlir::OpFoldResult ConstantOp::fold( + ::llvm::ArrayRef operands) { return value(); } @@ -82,11 +81,11 @@ LogicalResult ElementwiseAdd::inferReturnTypes( return success(); } void ElementwiseAdd::getCanonicalizationPatterns( - ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { results.insert(context); } -::mlir::OpFoldResult ElementwiseAdd::fold( +mlir::OpFoldResult ElementwiseAdd::fold( llvm::ArrayRef operands) { if (getElementTypeOrSelf(getType()).isa()) { if (!operands[0] || !operands[1]) return {}; @@ -154,17 +153,17 @@ LogicalResult MulOp::inferReturnTypes( } void ReluOp::getCanonicalizationPatterns( - ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { results.insert(context); } void FusedRepeatedFCRelu::getCanonicalizationPatterns( - ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { results.insert(context); } void BatchNormOp::getCanonicalizationPatterns( - ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { + mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) { results.insert(context); } diff --git a/paddle/infrt/dialect/pd_ops.h b/paddle/infrt/dialect/pd_ops.h index 71e0a53988d1ac8dbd9e1031f830360dc4167cc4..7d1d1d6f58451321a7edae50df4c19a043bf6b29 100644 --- a/paddle/infrt/dialect/pd_ops.h +++ b/paddle/infrt/dialect/pd_ops.h @@ -14,21 +14,20 @@ #pragma once -#include "mlir/Dialect/Traits.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/CallInterfaces.h" -#include "mlir/Interfaces/DerivedAttributeOpInterface.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/LoopLikeInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace mlir { namespace pd { @@ -53,9 +52,8 @@ class PaddleDialect : public Dialect { } }; -#define GET_OP_CLASSES -#include "paddle/infrt/dialect/pd_ops.hpp.inc" -#undef GET_OP_CLASSES - } // namespace pd } // namespace mlir + +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/pd_ops.hpp.inc" diff --git a/paddle/infrt/dialect/pd_ops.td b/paddle/infrt/dialect/pd_ops.td index b020b7ad5dbc783c1dba192bcb64f02080fbf93c..3addf15082a12c28341da53add36ab1541721b67 100644 --- a/paddle/infrt/dialect/pd_ops.td +++ b/paddle/infrt/dialect/pd_ops.td @@ -24,6 +24,16 @@ def PD_FeedOp : PD_Op<"feed"> { def PD_FetchOp : PD_Op<"fetch", [Terminator]> { let summary = "fetch Op"; + let description = [{ + Return the output tensor from the subgraph. + }]; + + let arguments = (ins PD_Tensor :$inputs, StrAttr:$name); +} + +def PD_ReturnOp : PD_Op<"return", [Terminator]> { + let summary = "return Op"; + let description = [{ Fetch tensor from the graph. }]; @@ -31,7 +41,7 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> { let arguments = (ins Variadic:$inputs); } -def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> { +def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"ReturnOp">]> { let summary = "paddle graph Op"; let description = [{ Describe a paddle graph or subgraph. @@ -50,7 +60,7 @@ def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInte let hasFolder = 1; let builders = [ - OpBuilder<"OpBuilder &builder, OperationState &state, Attribute value">, + OpBuilder<(ins "Attribute":$value)>, ]; } diff --git a/paddle/infrt/dialect/pd_types.h b/paddle/infrt/dialect/pd_types.h index 6f9fe56338a9fd7e5a6b1d532d396cc75efe0415..0da888a9c076922fc21d5cce004dc839bd705762 100644 --- a/paddle/infrt/dialect/pd_types.h +++ b/paddle/infrt/dialect/pd_types.h @@ -18,12 +18,11 @@ #pragma once -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" +#include +#include +#include +#include +#include namespace mlir { namespace PD { diff --git a/paddle/infrt/dialect/print_ir.cc b/paddle/infrt/dialect/print_ir.cc index 43a3577b90f109c638aa08c00de3feb6e8150a7d..5cfd16ee859438c891d6ccf77b97e663620e584c 100644 --- a/paddle/infrt/dialect/print_ir.cc +++ b/paddle/infrt/dialect/print_ir.cc @@ -11,26 +11,25 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include -#include "llvm/ADT/Optional.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/ScopedPrinter.h" -#include "llvm/Support/raw_os_ostream.h" -#include "llvm/Support/raw_ostream.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/AsmState.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Region.h" -#include "mlir/IR/Verifier.h" -#include "mlir/Parser.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/Passes.h" #include "paddle/infrt/common/global.h" #include "paddle/infrt/dialect/init_infrt_dialects.h" @@ -114,17 +113,15 @@ int main(int argc, char **argv) { mlir::registerPassManagerCLOptions(); cl::ParseCommandLineOptions(argc, argv, "mlir demo"); - mlir::MLIRContext *context = infrt::Global::getMLIRContext(); - // context->allowUnregisteredDialects(); - auto ®istry = context->getDialectRegistry(); - infrt::RegisterCinnDialects(registry); - + mlir::DialectRegistry registry; + infrt::registerCinnDialects(registry); + mlir::MLIRContext context(registry); // mlir will verify module automatically after parsing. // https://github.com/llvm/llvm-project/blob/38d18d93534d290d045bbbfa86337e70f1139dc2/mlir/lib/Parser/Parser.cpp#L2051 // mlir::OwningModuleRef module_ref = mlir::parseSourceString(mlir_source, // context); mlir::OwningModuleRef module_ref = - mlir::parseSourceFile(inputFilename, context); + mlir::parseSourceFile(inputFilename, &context); std::cout << "----------print IR Structure begin----------" << std::endl; printOperation(module_ref->getOperation(), 0); std::cout << "----------print IR Structure end----------" << std::endl; diff --git a/paddle/infrt/dialect/tensor_shape.cc b/paddle/infrt/dialect/tensor_shape.cc index ef5a5525cb22f337f6111823283fadde7c6aff22..92c03818264ee7c44626042dd1de53b66bb8c54b 100644 --- a/paddle/infrt/dialect/tensor_shape.cc +++ b/paddle/infrt/dialect/tensor_shape.cc @@ -17,16 +17,16 @@ #include #include #include +#include +#include #include -#include -#include #include #include -#include #include #include -namespace infrt::ts { +namespace infrt { +namespace ts { using namespace mlir; // NOLINT void TensorShapeDialect::initialize() { @@ -48,8 +48,8 @@ Type TensorShapeDialect::parseType(DialectAsmParser &parser) const { return Type(); } -void TensorShapeDialect::printType(::mlir::Type type, - ::mlir::DialectAsmPrinter &os) const { +void TensorShapeDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter &os) const { if (type.isa()) { os << "shape"; return; @@ -61,8 +61,10 @@ void TensorShapeDialect::printType(::mlir::Type type, } llvm_unreachable("unexpected 'shape' type kind"); } +} // namespace ts +} // namespace infrt #define GET_OP_CLASSES #include "paddle/infrt/dialect/tensor_shape.cpp.inc" // NOLINT -} // namespace infrt::ts +#include "paddle/infrt/dialect/tensor_shape_dialect.cpp.inc" diff --git a/paddle/infrt/dialect/tensor_shape.h b/paddle/infrt/dialect/tensor_shape.h index bd3fa8853675af4f1a19d2bdcf413cc0f80809fb..af892af735d2a4e2a8e97ac90e5fb2ba0e9fd1d8 100644 --- a/paddle/infrt/dialect/tensor_shape.h +++ b/paddle/infrt/dialect/tensor_shape.h @@ -17,7 +17,8 @@ #include #include -namespace infrt::ts { +namespace infrt { +namespace ts { class ShapeType : public mlir::Type::TypeBase { @@ -31,10 +32,9 @@ class PartialShapeType : public mlir::Type::TypeBase()">, "!ts.shape type">, BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> { - let typeDescription = [{ + let description = [{ `!ts.shape type` represents a static tensor shape. }]; } @@ -27,7 +27,7 @@ BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> { def TS_PartialShape : DialectType()">, "!ts.partial_shape type">, BuildableType<"$_builder.getType<::infrt::ts::PartialShapeType>()"> { - let typeDescription = [{ + let description = [{ `!ts.partial_shape type` represents either a static tensor shape, unranked tensor shape or a ranked tensor shape with unknown dimension sizes. }]; diff --git a/paddle/infrt/dialect/tensorrt/trt_exec.cc b/paddle/infrt/dialect/tensorrt/trt_exec.cc index dc0f2acb2b733e0f9d35f8153d6ac7f8ab0610cc..1baef7a3f77fdd9d3e363110ea3679aa942e222f 100644 --- a/paddle/infrt/dialect/tensorrt/trt_exec.cc +++ b/paddle/infrt/dialect/tensorrt/trt_exec.cc @@ -11,10 +11,10 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. +#include +#include #include #include -#include "llvm/Support/CommandLine.h" -#include "mlir/Pass/PassManager.h" #include "paddle/infrt/common/global.h" #include "paddle/infrt/dialect/mlir_loader.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc index 181f462962aeefa91ee572716090b86946a4cd42..1da80ef2c3b1000c045327510a03081f8aa954ca 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc @@ -14,14 +14,13 @@ #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" +#include +#include +#include +#include #include #include #include -#include "llvm/ADT/SetVector.h" -#include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/IR/Builders.h" -#include "paddle/infrt/dialect/pd_ops.h" -#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { @@ -32,9 +31,9 @@ namespace { // Reference the function nameed "FlexibleDFS" but defined in: // paddle/fluid/framework/ir/subgraph_detector.cc. -bool reverseDfs(std::vector<::mlir::Operation *> source, - const std::function &func) { - std::unordered_set visited; +bool reverseDfs(std::vector source, + const std::function &func) { + std::unordered_set visited; while (!source.empty()) { auto node = source.back(); source.pop_back(); @@ -44,7 +43,7 @@ bool reverseDfs(std::vector<::mlir::Operation *> source, auto values = node->getOperands(); for (auto value : values) { // if the value is a block argument, the node is nullptr. - ::mlir::Operation *node = value.getDefiningOp(); + mlir::Operation *node = value.getDefiningOp(); if (node != nullptr && !visited.count(node)) { source.emplace_back(node); } @@ -54,19 +53,19 @@ bool reverseDfs(std::vector<::mlir::Operation *> source, } // merge the first&second graph op to a new graph op. -void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT - ::mlir::pd::GraphOp first, - ::mlir::pd::GraphOp second) { +void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT + mlir::pd::GraphOp first, + mlir::pd::GraphOp second) { // comput inputs and outputs - ::llvm::SmallVector<::mlir::Value, 4> inputs(first.getOperands()), outputs; - for (::mlir::Value input : second.getOperands()) { + ::llvm::SmallVector inputs(first.getOperands()), outputs; + for (mlir::Value input : second.getOperands()) { if (input.getDefiningOp() != first) { inputs.push_back(input); } } - ::llvm::DenseMap<::mlir::Value, unsigned int> op_output_mapping; - for (::mlir::Value output : first.getResults()) { - for (::mlir::Operation *user : output.getUsers()) { + ::llvm::DenseMap op_output_mapping; + for (mlir::Value output : first.getResults()) { + for (mlir::Operation *user : output.getUsers()) { if (user != second && user->getParentOp() != second) { op_output_mapping[output] = outputs.size(); outputs.push_back(output); @@ -74,19 +73,19 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT } } } - auto fetch_op = second.getBody()->getTerminator(); - outputs.append(fetch_op->getOperands().begin(), - fetch_op->getOperands().end()); - ::llvm::SmallVector<::mlir::Type, 4> fetch_types; + auto return_op = second.getBody()->getTerminator(); + outputs.append(return_op->getOperands().begin(), + return_op->getOperands().end()); + ::llvm::SmallVector return_types; for (auto value : outputs) { - fetch_types.push_back(value.getType()); + return_types.push_back(value.getType()); } // create the new graph op builder.setInsertionPoint(first); auto loc = first.getLoc(); - auto graph_op = builder.create<::mlir::pd::GraphOp>(loc, fetch_types, inputs); - ::mlir::Block *block = new ::mlir::Block; + auto graph_op = builder.create(loc, return_types, inputs); + mlir::Block *block = new mlir::Block; auto copy_range = second.getBody()->without_terminator(); block->getOperations().splice(block->begin(), second.getBody()->getOperations(), @@ -98,18 +97,18 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT copy_range.begin(), copy_range.end()); builder.setInsertionPointToEnd(block); - builder.create(loc, outputs); + builder.create(loc, outputs); graph_op.body().push_back(block); // mapping the output unsigned int num_result = first.getNumResults(); - fetch_op = first.getBody()->getTerminator(); + return_op = first.getBody()->getTerminator(); for (unsigned int index = 0; index < num_result; ++index) { auto origin_value = first.getResult(index); if (op_output_mapping.find(origin_value) == op_output_mapping.end()) { - origin_value.replaceAllUsesWith(fetch_op->getOperand(index)); + origin_value.replaceAllUsesWith(return_op->getOperand(index)); } else { - auto inner_value = fetch_op->getOperand(index); + auto inner_value = return_op->getOperand(index); auto outer_value = graph_op.getResult(op_output_mapping[origin_value]); while (!origin_value.use_empty()) { auto replace_value = @@ -128,13 +127,13 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT // Topological sort the function op. void topoSortBlock(mlir::Block &body) { // NOLINT - llvm::SetVector toSort; + llvm::SetVector toSort; if (body.empty()) return; for (auto it = body.rbegin(); it != body.rend(); ++it) { toSort.insert(&*it); } - llvm::SetVector result = - ::mlir::topologicalSort(std::move(toSort)); + llvm::SetVector result = + mlir::topologicalSort(std::move(toSort)); for (auto *op : result) { op->moveBefore(body.getTerminator()); } @@ -145,21 +144,21 @@ void topoSortBlock(mlir::Block &body) { // NOLINT // Implementation of the trtGraphFusePass. void trtGraphFusePass::runOnFunction() { mlir::Block &body = getFunction().front(); - ::mlir::OpBuilder builder(&body, body.begin()); + mlir::OpBuilder builder(&body, body.begin()); bool changed = false; do { changed = false; for (auto &op : body) { - ::mlir::pd::GraphOp graph_op = - ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op); + mlir::pd::GraphOp graph_op = + ::llvm::dyn_cast_or_null(&op); if (nullptr == graph_op) continue; for (auto user_op : op.getUsers()) { - ::mlir::pd::GraphOp user_graph_op = - ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(user_op); + mlir::pd::GraphOp user_graph_op = + ::llvm::dyn_cast_or_null(user_op); if (nullptr == user_graph_op) continue; // get all dst input nodes except src. - std::vector<::mlir::Operation *> source_nodes; + std::vector source_nodes; for (auto operand : user_op->getOperands()) { auto input = operand.getDefiningOp(); if (input != &op && input != nullptr) { @@ -167,9 +166,8 @@ void trtGraphFusePass::runOnFunction() { } } // Reverse DFS from the source_nodes. - if (!reverseDfs(source_nodes, [&op](const ::mlir::Operation *n) { - return n == &op; - })) { + if (!reverseDfs(source_nodes, + [&op](const mlir::Operation *n) { return n == &op; })) { mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op); changed = true; break; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h index e7134e88f316c916787e7faba7f34432922d36c6..f1e555c6f67ecaadff76fb17f68ebaae1a6528e1 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h @@ -13,7 +13,7 @@ // limitations under the License. #pragma once -#include "mlir/Pass/Pass.h" +#include namespace infrt { namespace trt { @@ -28,15 +28,15 @@ namespace trt { * %a = "pd.feed"()... * %c = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.fetch" %m + * "pd.return" %m * } ... * %d = "pd.graph"(%c) { * %m = "pd.conv3d"(%c)... - * "pd.fetch" %m + * "pd.return" %m * } ... * %f = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.fetch" %m + * "pd.return" %m * } ... * "pd.fetch" %d, %f * @@ -47,13 +47,13 @@ namespace trt { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... - * "pd.fetch" %n, %s + * "pd.return" %n, %s * } ... * "pd.fetch" %d, %f * } */ class trtGraphFusePass - : public ::mlir::PassWrapper { + : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "trtGraphFusePass"; } void runOnFunction() override; diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc index 2b45364de2036f0fc1747b42e860bf2a22b80b51..257f2b528542557db33121a4c304eb8e6f657007 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc @@ -14,7 +14,7 @@ #include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h" -#include "mlir/IR/Builders.h" +#include #include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h" @@ -22,24 +22,24 @@ namespace infrt { namespace trt { // Implementation of the trtGraphSplitPass。 void trtGraphSplitPass::runOnFunction() { - std::vector<::mlir::pd::GraphOp> worklist; - ::mlir::Block& block = getFunction().front(); + std::vector worklist; + mlir::Block& block = getFunction().front(); for (auto& op : block) { - ::mlir::pd::GraphOp graph_op = - ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op); + mlir::pd::GraphOp graph_op = + ::llvm::dyn_cast_or_null(&op); if (nullptr != graph_op && graph_op.getBody()->getOperations().size() <= min_subgraph_size_) { worklist.push_back(graph_op); } } while (!worklist.empty()) { - ::mlir::pd::GraphOp graph_op = worklist.back(); + mlir::pd::GraphOp graph_op = worklist.back(); worklist.pop_back(); - ::mlir::Block* body = graph_op.getBody(); - auto fetch_op = body->getTerminator(); - graph_op.replaceAllUsesWith(fetch_op->getOperands()); + mlir::Block* body = graph_op.getBody(); + auto return_op = body->getTerminator(); + graph_op.replaceAllUsesWith(return_op->getOperands()); auto copy_range = body->without_terminator(); - block.getOperations().splice(::mlir::Block::iterator(graph_op), + block.getOperations().splice(mlir::Block::iterator(graph_op), body->getOperations(), copy_range.begin(), copy_range.end()); diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h index 092df0cf834e5995cf0c3c693a3cb4949856ca58..d30d186647fc32aa4e16047000ee4071effb900d 100644 --- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h @@ -13,7 +13,7 @@ // limitations under the License. #pragma once -#include "mlir/Pass/Pass.h" +#include namespace infrt { namespace trt { @@ -31,9 +31,9 @@ namespace trt { * %m = "pd.conv2d"(%a)... * %n = "pd.conv3d"(%m)... * %s = "pd.conv2d"(%a)... - * "pd.fetch" %n, %s + * "pd.return" (%n, %s) * } ... - * "pd.fetch" %d, %f + * "pd.fetch" (%d, %f) * } * * destination func: @@ -42,11 +42,11 @@ namespace trt { * %c = "pd.conv2d"(%a) ... * %d = "pd.conv3d"(%c) ... * %f = "pd.conv2d"(%a) ... - * "pd.fetch" %d, %f + * "pd.fetch" (%d, %f) * } */ class trtGraphSplitPass - : public ::mlir::PassWrapper { + : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "trtGraphSplitPass"; } void runOnFunction() override; diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc index 7b7fbb05c1d13b834447932f63c7b394e14b9715..4e8d40b982b2eaf13aeef4f026d783c3f353c14b 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc @@ -14,49 +14,48 @@ #include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h" -#include "mlir/IR/Builders.h" +#include #include "paddle/infrt/dialect/pd_ops.h" -#include "paddle/infrt/dialect/tensorrt/trt_ops.h" namespace infrt { namespace trt { // Implementation of the trtOpTellerPass。 void trtOpTellerPass::runOnFunction() { - ::mlir::Block &body = getFunction().front(); - std::vector<::mlir::Operation *> worklist; + mlir::Block &body = getFunction().front(); + std::vector worklist; worklist.reserve(body.getOperations().size()); for (auto &op : body) { worklist.push_back(&op); } // Build GraphOp. - ::mlir::OpBuilder builder(&body, body.begin()); + mlir::OpBuilder builder(&body, body.begin()); while (!worklist.empty()) { auto *op = worklist.back(); worklist.pop_back(); if (op == nullptr) continue; - auto op1 = ::llvm::dyn_cast_or_null<::mlir::pd::FeedOp>(op); + auto op1 = ::llvm::dyn_cast_or_null(op); if (op1) continue; - auto op2 = ::llvm::dyn_cast_or_null<::mlir::pd::FetchOp>(op); + auto op2 = ::llvm::dyn_cast_or_null(op); if (op2) continue; - auto op3 = ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(op); + auto op3 = ::llvm::dyn_cast_or_null(op); if (op3) continue; builder.setInsertionPoint(op); auto loc = getFunction().getLoc(); - auto graph_op = builder.create<::mlir::pd::GraphOp>( + auto graph_op = builder.create( loc, op->getResultTypes(), op->getOperands()); - ::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; + ::llvm::SmallVector tblgen_repl_values; for (auto v : - ::llvm::SmallVector<::mlir::Value, 4>{graph_op.getODSResults(0)}) { + ::llvm::SmallVector{graph_op.getODSResults(0)}) { tblgen_repl_values.push_back(v); } op->replaceAllUsesWith(tblgen_repl_values); // Build graph op. - ::mlir::Block *block = new ::mlir::Block; + mlir::Block *block = new mlir::Block; graph_op.body().push_back(block); op->moveBefore(block, block->begin()); builder.setInsertionPointToEnd(block); - builder.create(loc, op->getResults()); + builder.create(loc, op->getResults()); } } } // namespace trt diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h index b03945b3459c0237343006019fe15e8a2e508492..fb16c974f7fb3f923bdc460d62d8e5b9f628fff9 100644 --- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h +++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h @@ -13,7 +13,7 @@ // limitations under the License. #pragma once -#include "mlir/Pass/Pass.h" +#include namespace infrt { namespace trt { @@ -29,7 +29,7 @@ namespace trt { * %c = "pd.conv2d"(%a) ... * %d = "pd.conv3d"(%c) ... * %f = "pd.conv2d"(%a) ... - * "pd.fetch" %d, %f + * "pd.fetch" (%d, %f) * } * * destination func: @@ -37,23 +37,23 @@ namespace trt { * %a = "pd.feed"()... * %c = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.fetch" %m + * "pd.return" (%m) * } ... * %d = "pd.graph"(%c) { * %m = "pd.conv3d"(%c)... - * "pd.fetch" %m + * "pd.return" (%m) * } ... * %f = "pd.graph"(%a) { * %m = "pd.conv2d"(%a)... - * "pd.fetch" %m + * "pd.return" (%m) * } ... - * "pd.fetch" %d, %f + * "pd.fetch" (%d, %f) * } * TODO(winter-wang): Supplementary how to judge the operators can be supported * by tensorrt. */ class trtOpTellerPass - : public ::mlir::PassWrapper { + : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "trtOpTellerPass"; } void runOnFunction() override; diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.cc b/paddle/infrt/dialect/tensorrt/trt_ops.cc index 4c02238b10e1da770454682d889addbe078b0a54..35b7967892cafcea66c382e5681ee43480b02735 100644 --- a/paddle/infrt/dialect/tensorrt/trt_ops.cc +++ b/paddle/infrt/dialect/tensorrt/trt_ops.cc @@ -13,27 +13,25 @@ // limitations under the License. #include "paddle/infrt/dialect/tensorrt/trt_ops.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/CallInterfaces.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include +#include +#include +#include +#include namespace infrt { namespace trt { -TensorRTDialect::TensorRTDialect(::mlir::MLIRContext *context) - : ::mlir::Dialect("trt", context, ::mlir::TypeID::get()) { +TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context) + : mlir::Dialect("trt", context, mlir::TypeID::get()) { addOperations< #define GET_OP_LIST #include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT >(); -#undef GET_OP_LIST } -#define GET_OP_CLASSES -#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT -#undef GET_OP_CLASSES - } // namespace trt } // namespace infrt + +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT diff --git a/paddle/infrt/dialect/tensorrt/trt_ops.h b/paddle/infrt/dialect/tensorrt/trt_ops.h index c9043c2280de0f7970fb323876b10c68c6a63de7..a37491ec1abc7fd423fef23df5170936d2a769c7 100644 --- a/paddle/infrt/dialect/tensorrt/trt_ops.h +++ b/paddle/infrt/dialect/tensorrt/trt_ops.h @@ -14,37 +14,32 @@ #pragma once -#include "mlir/Dialect/Traits.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/Function.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/Module.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/Interfaces/CallInterfaces.h" -#include "mlir/Interfaces/DerivedAttributeOpInterface.h" -#include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Interfaces/LoopLikeInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace infrt { namespace trt { -class TensorRTDialect : public ::mlir::Dialect { +class TensorRTDialect : public mlir::Dialect { public: - explicit TensorRTDialect(::mlir::MLIRContext* context); + explicit TensorRTDialect(mlir::MLIRContext* context); static llvm::StringRef getDialectNamespace() { return "trt"; } }; -// mlir bug。 can be removed safety when update mlir to llvm11. -using namespace mlir; // NOLINT +} // namespace trt +} // namespace infrt #define GET_OP_CLASSES #include "paddle/infrt/dialect/tensorrt/trt_ops.hpp.inc" -#undef GET_OP_CLASSES - -} // namespace trt -} // namespace infrt diff --git a/paddle/infrt/dialect/test_kernels.cc b/paddle/infrt/dialect/test_kernels.cc index 894d96f95ad5cb291ced0c71ecb94ec9ab879423..c4588d7cf8bab748832865fc3aaab1913f33d11b 100644 --- a/paddle/infrt/dialect/test_kernels.cc +++ b/paddle/infrt/dialect/test_kernels.cc @@ -14,14 +14,13 @@ #include "paddle/infrt/dialect/test_kernels.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/TypeUtilities.h" - -namespace infrt::dialect { +#include +#include +#include +#include +namespace infrt { +namespace dialect { //===----------------------------------------------------------------------===// // BenchmarkOp //===----------------------------------------------------------------------===// @@ -32,65 +31,67 @@ namespace infrt::dialect { // ... // } -static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT - OperationState &result) { // NOLINT - StringAttr nameAttr; +static mlir::ParseResult parseBenchmarkOp( + mlir::OpAsmParser &parser, // NOLINT + mlir::OperationState &result) { // NOLINT + mlir::StringAttr nameAttr; if (parser.parseAttribute(nameAttr, "name", result.attributes)) - return failure(); + return mlir::failure(); // Parse the operands, e.g. (%c : i32, %d : f32) - if (parser.parseLParen()) return failure(); + if (parser.parseLParen()) return mlir::failure(); - SmallVector operands; - SmallVector types; + llvm::SmallVector operands; + llvm::SmallVector types; llvm::SMLoc type_loc = parser.getCurrentLocation(); if (parser.parseOptionalRParen()) { // Parse non-empty operands do { // Parse %c : i32, - OpAsmParser::OperandType operand; - Type type; + mlir::OpAsmParser::OperandType operand; + mlir::Type type; if (parser.parseOperand(operand) || parser.parseColonType(type)) - return failure(); + return mlir::failure(); operands.push_back(operand); types.push_back(type); } while (succeeded(parser.parseOptionalComma())); - if (parser.parseRParen()) return failure(); + if (parser.parseRParen()) return mlir::failure(); } if (parser.resolveOperands(operands, types, type_loc, result.operands)) - return failure(); + return mlir::failure(); // Parse the keyword attribute, e.g. max_count = 100, duration_secs = 1 do { - StringRef attr; - Attribute resultAttr; + mlir::StringRef attr; + mlir::Attribute resultAttr; if (parser.parseKeyword(&attr) || parser.parseEqual() || parser.parseAttribute(resultAttr, parser.getBuilder().getIntegerType(32), attr, result.attributes)) - return failure(); - } while (succeeded(parser.parseOptionalComma())); + return mlir::failure(); + } while (mlir::succeeded(parser.parseOptionalComma())); // Set the default attribute num_warmup_runs to 1 if unset auto setDefaultAttrIfUnset = [&](const char *attr_name, int value) { bool found = llvm::any_of(result.attributes, - [attr_name](const NamedAttribute &attr) { - return attr.first == attr_name; + [attr_name](const mlir::NamedAttribute &attr) { + return attr.getName() == attr_name; }); if (!found) { - IntegerAttr default_val = parser.getBuilder().getI32IntegerAttr(value); + mlir::IntegerAttr default_val = + parser.getBuilder().getI32IntegerAttr(value); result.addAttribute(attr_name, default_val); } }; setDefaultAttrIfUnset("num_warmup_runs", 1); - Region *target = result.addRegion(); + mlir::Region *target = result.addRegion(); return parser.parseRegion(*target, operands, types, @@ -102,11 +103,11 @@ static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT // max_count = 100, duration_secs = 1 { // ... // } -static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT +static void print(mlir::OpAsmPrinter &p, BenchmarkOp op) { // NOLINT p << "infrt.benchmark "; // Print the name attribute, e.g "add.i32" - auto name_attr = op.getAttr("name"); + auto name_attr = op->getAttr("name"); p << name_attr; // Print the operands and types, e.g. (%c : i32, %d : f32) @@ -120,13 +121,13 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT bool need_comma = false; // Print the attributes, e.g. max_count = 100, duration_secs = 1 - for (auto &name_attr : op.getAttrs()) { - auto id = name_attr.first; + for (auto &name_attr : op->getAttrs()) { + auto id = name_attr.getName(); if (id == "name") continue; if (need_comma) p << ", "; - auto attr = name_attr.second; + auto attr = name_attr.getValue(); p << id << " = "; - if (auto int_attr = attr.dyn_cast()) { + if (auto int_attr = attr.dyn_cast()) { int_attr.getValue().print(p.getStream(), /*isSigned=*/false); } else { op.emitOpError("Unexpected attribute"); @@ -142,7 +143,7 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT p.printRegion(op.region(), /*printEntryBlockArgs=*/false); } -static LogicalResult verify(BenchmarkOp op) { +static mlir::LogicalResult verify(BenchmarkOp op) { // Verify that the target benchmark region has exactly one return value. auto ®ion = op.region(); auto &last_op = region.front().back(); @@ -154,10 +155,10 @@ static LogicalResult verify(BenchmarkOp op) { "incorrect number of return values. One return value is expected"); } - return success(); + return mlir::success(); } +} // namespace dialect +} // namespace infrt #define GET_OP_CLASSES #include "paddle/infrt/dialect/test_kernels.cpp.inc" - -} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/test_kernels.h b/paddle/infrt/dialect/test_kernels.h index 29d4209cb7280e0d3d9947c1a9d0cfff75ade01b..73c8a6fb387bca6ebc7ae393e4bba32ab94aa951 100644 --- a/paddle/infrt/dialect/test_kernels.h +++ b/paddle/infrt/dialect/test_kernels.h @@ -13,11 +13,8 @@ // limitations under the License. #pragma once -#include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include +#include -namespace infrt::dialect { -using namespace mlir; // NOLINT #define GET_OP_CLASSES #include "paddle/infrt/dialect/test_kernels.hpp.inc" -} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/types.cc b/paddle/infrt/dialect/types.cc deleted file mode 100644 index 6d6f6a20b46c90d0bdbb79e5b732255b4a6e27bf..0000000000000000000000000000000000000000 --- a/paddle/infrt/dialect/types.cc +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/infrt/dialect/types.h" - -namespace infrt::hlir::mlir {} // namespace infrt::hlir::mlir diff --git a/paddle/infrt/dialect/types.h b/paddle/infrt/dialect/types.h deleted file mode 100644 index a9a2b61871cc0911b756deddda8ba60fade4ac94..0000000000000000000000000000000000000000 --- a/paddle/infrt/dialect/types.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include diff --git a/paddle/infrt/host_context/core_runtime.cc b/paddle/infrt/host_context/core_runtime.cc index cdb8cc99ecb2631d1b9cdf1b8adb830fe9e826a5..e3917bd07d24248becb013e2d6ef6546608285f9 100644 --- a/paddle/infrt/host_context/core_runtime.cc +++ b/paddle/infrt/host_context/core_runtime.cc @@ -23,7 +23,8 @@ #include "paddle/infrt/host_context/op_executable.h" #include "paddle/infrt/host_context/symbol_table.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { struct CoreRuntime::Impl { KernelRegistry* kernel_registry{}; @@ -90,4 +91,5 @@ llvm::SmallVector CoreRuntime::GetResults( CoreRuntime::~CoreRuntime() {} -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/core_runtime.h b/paddle/infrt/host_context/core_runtime.h index 802f8b17bb0105169c269e6dae9f37331655a1de..acb6a66cac630f695afbdcc527d7b397973aa84f 100644 --- a/paddle/infrt/host_context/core_runtime.h +++ b/paddle/infrt/host_context/core_runtime.h @@ -22,7 +22,8 @@ #include "paddle/infrt/host_context/value.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { class KernelRegistry; class OpExecutable; @@ -83,4 +84,5 @@ class CoreRuntimeBuilder : public CoreRuntime { OpExecutableBuilder* NewOpExecutable(const std::string& op_name); }; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/kernel_frame.h b/paddle/infrt/host_context/kernel_frame.h index 20cb17dc7fbe241557f70e5d0e2f6cf15dc69b56..5186b88fe2c41a8b4939dd70fde9123549764856 100644 --- a/paddle/infrt/host_context/kernel_frame.h +++ b/paddle/infrt/host_context/kernel_frame.h @@ -21,7 +21,8 @@ #include "llvm/ADT/SmallVector.h" #include "paddle/infrt/host_context/value.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { /** * KernelFrame captures the states(input arguments, attributes, results) @@ -163,4 +164,5 @@ class KernelFrameBuilder : public KernelFrame { } }; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/kernel_registry_test.cc b/paddle/infrt/host_context/kernel_registry_test.cc index f36ec2a1cac7ded8bd1fc6c30061ce001bdeda1c..7fca56343041c2827f0dce57ca98fb9158ef66f4 100644 --- a/paddle/infrt/host_context/kernel_registry_test.cc +++ b/paddle/infrt/host_context/kernel_registry_test.cc @@ -18,7 +18,8 @@ #include "paddle/infrt/host_context/kernel_utils.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { int add_i32(int a, int b) { return a + b; } @@ -44,4 +45,5 @@ TEST(KernelRegistry, basic) { ASSERT_EQ(results[0]->get(), 3); } -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/kernel_utils_test.cc b/paddle/infrt/host_context/kernel_utils_test.cc index 1904eb106a29375f4997ec099151835d409e09b8..bebd8d86e50bbd6a2d80325f9fbd8254718c8d0a 100644 --- a/paddle/infrt/host_context/kernel_utils_test.cc +++ b/paddle/infrt/host_context/kernel_utils_test.cc @@ -16,7 +16,8 @@ #include -namespace infrt::host_context { +namespace infrt { +namespace host_context { int add_i32(int a, int b) { return a + b; } float add_f32(float a, float b) { return a + b; } @@ -66,4 +67,5 @@ TEST(KernelImpl, pair) { ASSERT_EQ(results[1]->get(), 3.f); } -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/mlir_function_executable.cc b/paddle/infrt/host_context/mlir_function_executable.cc index 5f8dacf8e448acc494856fb1c7117d61b3075190..47ec27ebec300f1cedd57b11e0dd1e6b37611141 100644 --- a/paddle/infrt/host_context/mlir_function_executable.cc +++ b/paddle/infrt/host_context/mlir_function_executable.cc @@ -15,6 +15,7 @@ #include "paddle/infrt/host_context/mlir_function_executable.h" #include +#include #include // NOLINT diff --git a/paddle/infrt/host_context/mlir_function_executable.h b/paddle/infrt/host_context/mlir_function_executable.h index ba5fa154d6fcc3183c3a882e1eb1bd05daa66129..a6428df86e6b27061d92856970682bc29499d825 100644 --- a/paddle/infrt/host_context/mlir_function_executable.h +++ b/paddle/infrt/host_context/mlir_function_executable.h @@ -13,7 +13,8 @@ // limitations under the License. #pragma once -#include +#include +#include #include #include diff --git a/paddle/infrt/host_context/mlir_program_executor.h b/paddle/infrt/host_context/mlir_program_executor.h index b2af4d2d79db54aa02a34b0371c63f992c055f58..c2ccb90640b21bcfb675a707d6cb60cf5028ab36 100644 --- a/paddle/infrt/host_context/mlir_program_executor.h +++ b/paddle/infrt/host_context/mlir_program_executor.h @@ -15,9 +15,9 @@ #pragma once #include +#include +#include #include -#include -#include #include #include diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate.cc b/paddle/infrt/host_context/mlir_to_runtime_translate.cc index 25324b1291582b406eb5b33c1241609a9e2ed5d6..3dbc7a702be38d986b6f77b345abe85f939370e6 100644 --- a/paddle/infrt/host_context/mlir_to_runtime_translate.cc +++ b/paddle/infrt/host_context/mlir_to_runtime_translate.cc @@ -16,8 +16,9 @@ #include #include +#include +#include #include -#include #include #include @@ -40,7 +41,8 @@ #include "paddle/infrt/host_context/value.h" #include "paddle/infrt/tensor/tensor_shape.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { template std::string DumpToString(T& op) { // NOLINT @@ -113,10 +115,10 @@ bool MlirToRuntimeTranslator::EmitConstantOp(mlir::Operation* op) { template <> boost::optional MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - if (attr->isa()) { - auto val = attr->cast(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + if (attr.isa()) { + auto val = attr.cast(); if (val.getType().isInteger(32)) { return val.getInt(); } @@ -125,10 +127,10 @@ boost::optional MlirToRuntimeTranslator::EmitAttribute( } template <> boost::optional MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - if (attr->isa()) { - auto val = attr->cast(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + if (attr.isa()) { + auto val = attr.cast(); if (val.getType().isInteger(64)) { return val.getInt(); } @@ -139,10 +141,10 @@ boost::optional MlirToRuntimeTranslator::EmitAttribute( // TODO(Superjomn) Make double and float parsing share some thing. template <> boost::optional MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - if (attr->isa()) { - auto val = attr->cast(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + if (attr.isa()) { + auto val = attr.cast(); if (val.getType().isF32()) return val.getValueAsDouble(); } return boost::none; @@ -150,10 +152,10 @@ boost::optional MlirToRuntimeTranslator::EmitAttribute( template <> boost::optional MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - if (attr->isa()) { - auto val = attr->cast(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + if (attr.isa()) { + auto val = attr.cast(); if (val.getType().isF64()) return val.getValueAsDouble(); } return boost::none; @@ -161,17 +163,17 @@ boost::optional MlirToRuntimeTranslator::EmitAttribute( template <> boost::optional MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - return attr->cast().getValue().str(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + return attr.cast().getValue().str(); } #define PROCESS_ARRAY_INT(type__, bits__) \ template <> \ boost::optional> MlirToRuntimeTranslator::EmitAttribute( \ - const mlir::Attribute* attr) { \ - if (!attr->isa()) return boost::none; \ - auto array = attr->cast(); \ + const mlir::Attribute& attr) { \ + if (!attr.isa()) return boost::none; \ + auto array = attr.cast(); \ CHECK(!array.empty()); \ \ if (!array[0].getType().isInteger(bits__)) { \ @@ -191,9 +193,9 @@ PROCESS_ARRAY_INT(int64_t, 64); template <> boost::optional> MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - auto array = attr->cast(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + auto array = attr.cast(); CHECK(!array.empty()); if (!array[0].getType().isF32()) return boost::none; @@ -207,9 +209,9 @@ boost::optional> MlirToRuntimeTranslator::EmitAttribute( template <> boost::optional> MlirToRuntimeTranslator::EmitAttribute( - const mlir::Attribute* attr) { - if (!attr->isa()) return boost::none; - auto array = attr->cast(); + const mlir::Attribute& attr) { + if (!attr.isa()) return boost::none; + auto array = attr.cast(); CHECK(!array.empty()); if (!array[0].getType().isF64()) return boost::none; @@ -236,7 +238,8 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { for (int i = 0, e = op->getNumOperands(); i < e; i++) { // function argument as value auto operand = op->getOperand(i); - if (operand.getKind() == mlir::Value::Kind::BlockArgument) { + /// if (operand.getKind() == mlir::Value::Kind::BlockArgument) { + if (operand.isa()) { mlir::BlockArgument arg = operand.dyn_cast(); Value* arg_value = GetValue(arg); impl_->cur_op->AppendArgument(arg_value); @@ -283,25 +286,25 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { for (size_t i = 0; i < attrs.size(); i++) { auto& attr = attrs[i]; - if (auto v = EmitAttribute(&attr.second)) { + if (auto v = EmitAttribute(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(*v)); - } else if (auto v = EmitAttribute(&attr.second)) { + } else if (auto v = EmitAttribute(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(*v)); - } else if (auto v = EmitAttribute(&attr.second)) { + } else if (auto v = EmitAttribute(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(*v)); - } else if (auto v = EmitAttribute(&attr.second)) { + } else if (auto v = EmitAttribute(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(*v)); - } else if (auto v = EmitAttribute(&attr.second)) { + } else if (auto v = EmitAttribute(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); - } else if (auto v = EmitAttribute>(&attr.second)) { + } else if (auto v = EmitAttribute>(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); - } else if (auto v = EmitAttribute>(&attr.second)) { + } else if (auto v = EmitAttribute>(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); - } else if (auto v = EmitAttribute>(&attr.second)) { + } else if (auto v = EmitAttribute>(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); - } else if (auto v = EmitAttribute>(&attr.second)) { + } else if (auto v = EmitAttribute>(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); - } else if (auto v = EmitAttribute>(&attr.second)) { + } else if (auto v = EmitAttribute>(attr.getValue())) { impl_->cur_op->AppendAttribute(new Value(std::move(*v))); } else { LOG(FATAL) << "Not supported attribute type"; @@ -330,7 +333,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { llvm::SmallVector results; auto func_type = - mlir::FunctionType::get(inputs, results, region.getContext()); + mlir::FunctionType::get(region.getContext(), inputs, results); auto* function = impl_->cur_op->CreateFunctionExecutable( ®ion, func_type, &impl_->func_defs); impl_->cur_op->AppendAttribute(new Value(function)); @@ -555,4 +558,5 @@ void TestMlir(mlir::ModuleOp module, KernelRegistry* registry) { execute.Run(); } -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate.h b/paddle/infrt/host_context/mlir_to_runtime_translate.h index 598e81bfd96d8acbc6d7eeba046df701a955b628..fcd79eaf386eed5a6a8eaa31712e344bab56dbd4 100644 --- a/paddle/infrt/host_context/mlir_to_runtime_translate.h +++ b/paddle/infrt/host_context/mlir_to_runtime_translate.h @@ -29,7 +29,8 @@ class Attribute; class Value; } // namespace mlir -namespace infrt::host_context { +namespace infrt { +namespace host_context { class CoreRuntimeBuilder; class Value; @@ -73,7 +74,7 @@ class MlirToRuntimeTranslator { bool EmitCallOp(mlir::Operation* op, function_defs_t* function_table); template - boost::optional EmitAttribute(const mlir::Attribute* attr); + boost::optional EmitAttribute(const mlir::Attribute& attr); Value* GetOpResult(mlir::Operation* op); @@ -104,4 +105,5 @@ void MlirToRuntimeTranslate(mlir::ModuleOp module, CoreRuntimeBuilder* runtime); */ void TestMlir(mlir::ModuleOp module, KernelRegistry* registry); -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate_test.cc b/paddle/infrt/host_context/mlir_to_runtime_translate_test.cc index 9b85be977ab6c1964b006385dfdc78414f1e482b..375daa4515e17fe1618c71d642825d112a3f788f 100644 --- a/paddle/infrt/host_context/mlir_to_runtime_translate_test.cc +++ b/paddle/infrt/host_context/mlir_to_runtime_translate_test.cc @@ -29,7 +29,8 @@ #include "paddle/infrt/kernel/tensor_shape_kernels.h" #include "paddle/infrt/kernel/test_kernels.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { TEST(MlirToRuntimeTranslate, basic) { mlir::MLIRContext context; @@ -48,7 +49,7 @@ func @main() -> () { )ROC"; auto module = dialect::LoadMlirSource(&context, source); - module->verify(); + EXPECT_TRUE(mlir::succeeded(module->verify())); KernelRegistry registry; kernel::RegisterFloatBasicKernels(®istry); @@ -74,7 +75,7 @@ func @main() -> () { )ROC"; auto module = dialect::LoadMlirSource(&context, source); - module->verify(); + EXPECT_TRUE(mlir::succeeded(module->verify())); KernelRegistry registry; kernel::RegisterFloatBasicKernels(®istry); @@ -115,7 +116,7 @@ infrt.return %a0, %b0: !infrt.tensor, !infrt.tensorverify(); + EXPECT_TRUE(mlir::succeeded(module->verify())); host_context::KernelRegistry registry; @@ -157,4 +158,5 @@ infrt.return %a0, %b0: !infrt.tensor, !infrt.tensor #include #include "paddle/infrt/host_context/kernel_frame.h" @@ -21,7 +22,8 @@ #include "paddle/infrt/host_context/mlir_function_executable.h" #include "paddle/infrt/host_context/symbol_table.h" -namespace infrt::host_context { +namespace infrt { +namespace host_context { struct OpExecutable::Impl { Impl(const std::string& op_name, @@ -148,4 +150,5 @@ void OpExecutable::Execute() { OpExecutable::~OpExecutable() {} -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/op_executable.h b/paddle/infrt/host_context/op_executable.h index e2248225a5cafa44be27604ad3b5f606c37cf6c7..550f6ab6349ed2f3f503ea7b0b425f7dbc1aea2c 100644 --- a/paddle/infrt/host_context/op_executable.h +++ b/paddle/infrt/host_context/op_executable.h @@ -14,19 +14,18 @@ #pragma once #include - +#include +#include #include #include #include -#include "mlir/IR/Function.h" -#include "mlir/IR/Region.h" - namespace mlir { class FuncOp; } // namespace mlir -namespace infrt::host_context { +namespace infrt { +namespace host_context { class SymbolTable; class KernelRegistry; @@ -89,4 +88,5 @@ class OpExecutableBuilder : public OpExecutable { function_defs_t* function_defs); }; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/kernel/basic_kernels.cc b/paddle/infrt/kernel/basic_kernels.cc index d7f2c3865157dd973e21f8527a4804e4e3209bf3..b186cfcfd2b355f97711ecc916e497c2916d4060 100644 --- a/paddle/infrt/kernel/basic_kernels.cc +++ b/paddle/infrt/kernel/basic_kernels.cc @@ -23,7 +23,8 @@ using infrt::host_context::Attribute; -namespace infrt::kernel { +namespace infrt { +namespace kernel { template T add(T a, T b) { @@ -82,4 +83,5 @@ void RegisterFloatBasicKernels(host_context::KernelRegistry *registry) { registry->AddKernel("infrt.print.f32", INFRT_KERNEL(print)); } -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/basic_kernels.h b/paddle/infrt/kernel/basic_kernels.h index 9e98885cf6ebfb8e000424874da70f3a34e2e127..feb66be61f530676cf79a32be1e52d69017d21bc 100644 --- a/paddle/infrt/kernel/basic_kernels.h +++ b/paddle/infrt/kernel/basic_kernels.h @@ -15,13 +15,16 @@ #pragma once #include -namespace infrt::host_context { +namespace infrt { +namespace host_context { struct KernelRegistry; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt -namespace infrt::kernel { +namespace infrt { +namespace kernel { /** * Register all the basic kernels to \p registry. @@ -31,4 +34,5 @@ void RegisterBasicKernels(host_context::KernelRegistry* registry); void RegisterIntBasicKernels(host_context::KernelRegistry* registry); void RegisterFloatBasicKernels(host_context::KernelRegistry* registry); -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/tensor_kernels.cc b/paddle/infrt/kernel/tensor_kernels.cc index 2fa477aa4dbda6f7282e65d705c27d433f2839c1..51e000492237435de555bc53bb63d23ce7ecbeb2 100644 --- a/paddle/infrt/kernel/tensor_kernels.cc +++ b/paddle/infrt/kernel/tensor_kernels.cc @@ -25,7 +25,8 @@ #include "paddle/infrt/tensor/tensor_map.h" #include "paddle/infrt/tensor/tensor_shape.h" -namespace infrt::kernel { +namespace infrt { +namespace kernel { using namespace host_context; // NOLINT using namespace tensor; // NOLINT @@ -76,4 +77,5 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) { INFRT_KERNEL(ShallowCopyTensor)); } -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/tensor_kernels.h b/paddle/infrt/kernel/tensor_kernels.h index 8f2180ba80a4f81c910aafe915d61288da99c930..df8e25c32393c903c3e6801e23095aeff6eca9b4 100644 --- a/paddle/infrt/kernel/tensor_kernels.h +++ b/paddle/infrt/kernel/tensor_kernels.h @@ -14,12 +14,16 @@ #pragma once -namespace infrt::host_context { +namespace infrt { +namespace host_context { struct KernelRegistry; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt -namespace infrt::kernel { +namespace infrt { +namespace kernel { void RegisterTensorKernels(host_context::KernelRegistry* registry); -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/tensor_shape_kernels.cc b/paddle/infrt/kernel/tensor_shape_kernels.cc index a04b492819298b3b6673324856d6277cc1ab6bea..4edbecfa108869ee1f8181c8efd42adc91224d6d 100644 --- a/paddle/infrt/kernel/tensor_shape_kernels.cc +++ b/paddle/infrt/kernel/tensor_shape_kernels.cc @@ -24,7 +24,8 @@ #include "paddle/infrt/host_context/kernel_utils.h" #include "paddle/infrt/tensor/tensor_shape.h" -namespace infrt::kernel { +namespace infrt { +namespace kernel { void PrintShape(const tensor::TensorShape& shape) { llvm::raw_os_ostream oos(std::cout); @@ -35,4 +36,5 @@ void RegisterTensorShapeKernels(host_context::KernelRegistry* registry) { registry->AddKernel("ts.print_shape", INFRT_KERNEL(PrintShape)); } -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/tensor_shape_kernels.h b/paddle/infrt/kernel/tensor_shape_kernels.h index e87c6c37e88a08fa2b2c85d35621786e9a46e65e..e31a37463be43bcc997368bd9693b3d866eff454 100644 --- a/paddle/infrt/kernel/tensor_shape_kernels.h +++ b/paddle/infrt/kernel/tensor_shape_kernels.h @@ -14,14 +14,18 @@ #pragma once -namespace infrt::host_context { +namespace infrt { +namespace host_context { class KernelRegistry; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt -namespace infrt::kernel { +namespace infrt { +namespace kernel { void RegisterTensorShapeKernels(host_context::KernelRegistry* registry); -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/test_kernels.cc b/paddle/infrt/kernel/test_kernels.cc index d5f64d09b602fd8696eba966f717d440214043c8..ccfb3356a855f418f14e42ed8a368f31d2fe8b27 100644 --- a/paddle/infrt/kernel/test_kernels.cc +++ b/paddle/infrt/kernel/test_kernels.cc @@ -33,7 +33,8 @@ using infrt::host_context::Attribute; using infrt::host_context::MlirFunctionExecutable; using infrt::host_context::RemainingArguments; -namespace infrt::kernel { +namespace infrt { +namespace kernel { namespace { class BenchmarkStats { public: @@ -197,4 +198,5 @@ void RegisterTestKernels(host_context::KernelRegistry *registry) { INFRT_KERNEL(ShadowCopyTensor)); } -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/test_kernels.h b/paddle/infrt/kernel/test_kernels.h index f42884dfaf2c9005b31e0ef335d1316625337a6f..f5639ec1afaad769d62530c4ef91eafa35779218 100644 --- a/paddle/infrt/kernel/test_kernels.h +++ b/paddle/infrt/kernel/test_kernels.h @@ -15,17 +15,21 @@ #pragma once #include -namespace infrt::host_context { +namespace infrt { +namespace host_context { struct KernelRegistry; -} // namespace infrt::host_context +} // namespace host_context +} // namespace infrt -namespace infrt::kernel { +namespace infrt { +namespace kernel { /** * Register all the test kernels to registry. */ void RegisterTestKernels(host_context::KernelRegistry* registry); -} // namespace infrt::kernel +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/paddle/cpp/desc_api.h b/paddle/infrt/paddle/cpp/desc_api.h index ccd79c048ab14593838b5173cadcbe979019045a..3b2dcb0018b2fcf733585ce28dac16aadffd7639 100644 --- a/paddle/infrt/paddle/cpp/desc_api.h +++ b/paddle/infrt/paddle/cpp/desc_api.h @@ -18,7 +18,9 @@ #include #include -namespace infrt::paddle::cpp { +namespace infrt { +namespace paddle { +namespace cpp { /* * Compatible interfaces for all the different kinds of XXXDesc. All the XXXDesc @@ -226,4 +228,6 @@ class ProgramDescAPI { virtual void SetVersion(int64_t version) = 0; }; -} // namespace infrt::paddle::cpp +} // namespace cpp +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/model_parser.cc b/paddle/infrt/paddle/model_parser.cc index 285280e69435b046a3c3073faf575de31662a2b5..f3de1a630451cc387765040191be8715768be510 100644 --- a/paddle/infrt/paddle/model_parser.cc +++ b/paddle/infrt/paddle/model_parser.cc @@ -22,7 +22,8 @@ #include "paddle/infrt/common/target.h" #include "paddle/infrt/common/type.h" -namespace infrt::paddle { +namespace infrt { +namespace paddle { int SizeOfType(framework_proto::VarType::Type type) { using Type = framework_proto::VarType::Type; @@ -169,4 +170,5 @@ void LoadParam(const std::string &path, _Variable *out, const Target &target) { LoadLoDTensor(fin, out, target); } -} // namespace infrt::paddle +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/model_parser.h b/paddle/infrt/paddle/model_parser.h index 73125fadedb82b9dd628fc0fb65a3b5c54d24a42..373f77033dcefa1a81cd8756da859b6d232337a0 100644 --- a/paddle/infrt/paddle/model_parser.h +++ b/paddle/infrt/paddle/model_parser.h @@ -25,7 +25,8 @@ #include "paddle/infrt/paddle/scope.h" #include "paddle/infrt/paddle/tensor.h" -namespace infrt::paddle { +namespace infrt { +namespace paddle { namespace framework_proto = ::paddle::framework::proto; // Read a __model__ file. @@ -52,4 +53,5 @@ void TensorFromStream( const common::Target& target = common::DefaultHostTarget()); void ReadBinaryFile(const std::string& filename, std::string* contents); -} // namespace infrt::paddle +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/block_desc.cc b/paddle/infrt/paddle/pb/block_desc.cc index 11186bc68af1640e5a83559d3eda4ca958eab8b4..5b28fa5464c547a9badeefef0ef5888fc10ccaaf 100644 --- a/paddle/infrt/paddle/pb/block_desc.cc +++ b/paddle/infrt/paddle/pb/block_desc.cc @@ -14,7 +14,9 @@ #include "paddle/infrt/paddle/pb/block_desc.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { template <> framework_proto::VarDesc* BlockDesc::GetVar( @@ -40,4 +42,6 @@ framework_proto::OpDesc* BlockDesc::AddOp() { return desc_->add_ops(); } -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/block_desc.h b/paddle/infrt/paddle/pb/block_desc.h index 9c1b7f9adf172fa615415f786b9a94e6ee03e22e..c9e325699a4bc4bd18eaf76a5f44cc37aa8c17d9 100644 --- a/paddle/infrt/paddle/pb/block_desc.h +++ b/paddle/infrt/paddle/pb/block_desc.h @@ -18,7 +18,9 @@ #include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/framework.pb.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { namespace framework_proto = ::paddle::framework::proto; @@ -74,4 +76,6 @@ class BlockDesc : public cpp::BlockDescAPI { framework_proto::BlockDesc* desc_; // not_own }; -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/op_desc.cc b/paddle/infrt/paddle/pb/op_desc.cc index c7b1e66f506425db53fd0dfdbbf43c2dc2bc4b2a..32dcefb1ac684a647d978e7d92351ae46a58f9d6 100644 --- a/paddle/infrt/paddle/pb/op_desc.cc +++ b/paddle/infrt/paddle/pb/op_desc.cc @@ -14,7 +14,9 @@ #include "paddle/infrt/paddle/pb/op_desc.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { google::protobuf::internal::RepeatedPtrIterator FindAttr(framework_proto::OpDesc *desc, const std::string &name) { @@ -136,4 +138,6 @@ GET_ATTRS_IMPL(std::vector, strings); GET_ATTR_IMPL(std::string, s); GET_ATTRS_IMPL(std::vector, longs); -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/op_desc.h b/paddle/infrt/paddle/pb/op_desc.h index 81d57d9f32252773626db6bf554c388253c99a1f..2829f2aca2e08dd186c4a38c3b26d808cc1e1138 100644 --- a/paddle/infrt/paddle/pb/op_desc.h +++ b/paddle/infrt/paddle/pb/op_desc.h @@ -19,7 +19,9 @@ #include "paddle/infrt/paddle/framework.pb.h" #include "paddle/infrt/support/variant.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { namespace framework_proto = ::paddle::framework::proto; @@ -195,4 +197,6 @@ template <> void OpDesc::SetAttr>(const std::string &name, const std::vector &v); -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/program_desc.cc b/paddle/infrt/paddle/pb/program_desc.cc index ed8a7e36e0129c7b8b121989fcb80c363f73fc8d..9d725485a974d3f6800a4bb3cca661d8653333c3 100644 --- a/paddle/infrt/paddle/pb/program_desc.cc +++ b/paddle/infrt/paddle/pb/program_desc.cc @@ -17,7 +17,9 @@ #include #include -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { template <> framework_proto::BlockDesc* ProgramDesc::GetBlock( @@ -32,4 +34,6 @@ ProgramDesc::AddBlock() { return desc_->add_blocks(); } -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/program_desc.h b/paddle/infrt/paddle/pb/program_desc.h index 4adad650c974dfc4cffe57ff70ac01965a3e733d..b1e64f8e86611fd8ef4e8be8a2064ceb1cd7a5ae 100644 --- a/paddle/infrt/paddle/pb/program_desc.h +++ b/paddle/infrt/paddle/pb/program_desc.h @@ -21,7 +21,9 @@ #include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/framework.pb.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { namespace framework_proto = ::paddle::framework::proto; class ProgramDesc : public cpp::ProgramDescAPI { @@ -58,4 +60,6 @@ class ProgramDesc : public cpp::ProgramDescAPI { framework_proto::ProgramDesc *desc_; // not_own }; -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/var_desc.cc b/paddle/infrt/paddle/pb/var_desc.cc index cf80df4f1b845b1f89d971a353563d934144b7ca..7ea2e24da3446c22e5f359122eb2d8d1ef5b12b4 100644 --- a/paddle/infrt/paddle/pb/var_desc.cc +++ b/paddle/infrt/paddle/pb/var_desc.cc @@ -19,7 +19,9 @@ #include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/framework.pb.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { cpp::VarDescAPI::Type VarDesc::GetType() const { auto type = desc_->type().type(); @@ -364,4 +366,6 @@ VarDesc::mutable_tensor_descs() { return std::vector(); } -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/pb/var_desc.h b/paddle/infrt/paddle/pb/var_desc.h index 4cff5fdee0375d02e5fd014e287fe74f2c9a0d77..7215ba6bb6aa7b52af69ed76562d3c65422c95a5 100644 --- a/paddle/infrt/paddle/pb/var_desc.h +++ b/paddle/infrt/paddle/pb/var_desc.h @@ -23,7 +23,9 @@ #include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/framework.pb.h" -namespace infrt::paddle::pb { +namespace infrt { +namespace paddle { +namespace pb { namespace framework_proto = ::paddle::framework::proto; // convert between std::vector and protobuf repeated. @@ -121,4 +123,6 @@ class VarDesc : public cpp::VarDescAPI { framework_proto::VarDesc *desc_; }; -} // namespace infrt::paddle::pb +} // namespace pb +} // namespace paddle +} // namespace infrt