未验证 提交 0de8a805 编写于 作者: 王明冬 提交者: GitHub

[infrt] update the version of llvm. test=develop (#38843)

上级 4c77a908
include(FetchContent) include(FetchContent)
set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/CINN/llvm11.tar.gz) set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/infrt/llvm_b5149f4e66a49a98b67e8e2de4e24a4af8e2781b.tar.gz)
set(LLVM_MD5 39d32b6be466781dddf5869318dcba53) set(LLVM_MD5 022819bb5760817013cf4b8a37e97d5e)
set(FETCHCONTENT_BASE_DIR ${THIRD_PARTY_PATH}/llvm) set(FETCHCONTENT_BASE_DIR ${THIRD_PARTY_PATH}/llvm)
set(FETCHCONTENT_QUIET OFF) set(FETCHCONTENT_QUIET OFF)
...@@ -51,7 +51,7 @@ message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") ...@@ -51,7 +51,7 @@ message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
# To build with MLIR, the LLVM is build from source code using the following flags: # To build with MLIR, the LLVM is build from source code using the following flags:
#[==[ #[==[
cmake -G Ninja ../llvm \ cmake ../llvm -G "Unix Makefiles" \
-DLLVM_ENABLE_PROJECTS="mlir;clang" \ -DLLVM_ENABLE_PROJECTS="mlir;clang" \
-DLLVM_BUILD_EXAMPLES=OFF \ -DLLVM_BUILD_EXAMPLES=OFF \
-DLLVM_TARGETS_TO_BUILD="X86" \ -DLLVM_TARGETS_TO_BUILD="X86" \
...@@ -59,8 +59,10 @@ cmake -G Ninja ../llvm \ ...@@ -59,8 +59,10 @@ cmake -G Ninja ../llvm \
-DLLVM_ENABLE_ASSERTIONS=ON \ -DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_ZLIB=OFF \ -DLLVM_ENABLE_ZLIB=OFF \
-DLLVM_ENABLE_RTTI=ON \ -DLLVM_ENABLE_RTTI=ON \
-DLLVM_INSTALL_UTILS=ON \
-DCMAKE_INSTALL_PREFIX=./install
#]==] #]==]
# The matched llvm-project version is f9dc2b7079350d0fed3bb3775f496b90483c9e42 (currently a temporary commit) # The matched llvm-project version is b5149f4e66a49a98b67e8e2de4e24a4af8e2781b (currently a temporary commit)
add_definitions(${LLVM_DEFINITIONS}) add_definitions(${LLVM_DEFINITIONS})
...@@ -75,7 +77,7 @@ add_definitions(${LLVM_DEFINITIONS}) ...@@ -75,7 +77,7 @@ add_definitions(${LLVM_DEFINITIONS})
# The minimum needed libraries for MLIR IR parse and transform. # The minimum needed libraries for MLIR IR parse and transform.
set(MLIR_IR_LIBS MLIRAnalysis MLIRStandardOps MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib) set(MLIR_IR_LIBS MLIRAnalysis MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib)
# tb_base is the name of a xxx.td file (without the .td suffix) # tb_base is the name of a xxx.td file (without the .td suffix)
...@@ -89,6 +91,7 @@ function(mlir_tablegen_on td_base) ...@@ -89,6 +91,7 @@ function(mlir_tablegen_on td_base)
mlir_tablegen(${td_base}.cpp.inc -gen-op-defs) mlir_tablegen(${td_base}.cpp.inc -gen-op-defs)
if (mlir_tablegen_on_DIALECT) if (mlir_tablegen_on_DIALECT)
mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls -dialect=${mlir_tablegen_on_DIALECT}) mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls -dialect=${mlir_tablegen_on_DIALECT})
mlir_tablegen(${td_base}_dialect.cpp.inc --gen-dialect-defs -dialect=${mlir_tablegen_on_DIALECT})
endif() endif()
add_public_tablegen_target(${td_base}_IncGen) add_public_tablegen_target(${td_base}_IncGen)
add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen) add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen)
......
...@@ -77,7 +77,6 @@ add_subdirectory(paddle) ...@@ -77,7 +77,6 @@ add_subdirectory(paddle)
# MLIR td file generations # MLIR td file generations
set(infrt_mlir_incs set(infrt_mlir_incs
ops_inc
basic_kernels_inc basic_kernels_inc
test_kernels_inc test_kernels_inc
infrt_base_inc infrt_base_inc
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#pragma once #pragma once
#include "mlir/IR/MLIRContext.h" #include <mlir/IR/MLIRContext.h>
#include "paddle/infrt/tensor/dense_host_tensor.h" #include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt { namespace infrt {
......
...@@ -2,7 +2,6 @@ core_gather_headers() ...@@ -2,7 +2,6 @@ core_gather_headers()
gather_srcs(infrt_src SRCS gather_srcs(infrt_src SRCS
dialect.cc dialect.cc
types.cc
basic_kernels.cc basic_kernels.cc
test_kernels.cc test_kernels.cc
infrt_base.cc infrt_base.cc
...@@ -14,8 +13,6 @@ gather_srcs(infrt_src SRCS ...@@ -14,8 +13,6 @@ gather_srcs(infrt_src SRCS
pd_types.cc pd_types.cc
pd_ops.cc pd_ops.cc
) )
mlir_tablegen_on(ops)
mlir_tablegen_on(basic_kernels) mlir_tablegen_on(basic_kernels)
mlir_tablegen_on(test_kernels) mlir_tablegen_on(test_kernels)
mlir_tablegen_on(infrt_base DIALECT infrt) mlir_tablegen_on(infrt_base DIALECT infrt)
...@@ -27,8 +24,7 @@ mlir_add_rewriter(rewrite) ...@@ -27,8 +24,7 @@ mlir_add_rewriter(rewrite)
# TODO(Superjomn) add a cmake function cc_executable to ecapsulate the following code # TODO(Superjomn) add a cmake function cc_executable to ecapsulate the following code
add_executable(infrtopt opt.cc) add_executable(infrtopt opt.cc)
target_link_libraries(infrtopt infrt ${mlir_libs}) target_link_libraries(infrtopt infrt)
add_dependencies(infrtopt infrt)
add_executable(print-ir print_ir.cc) add_executable(print-ir print_ir.cc)
target_link_libraries(print-ir infrt ${mlir_libs}) target_link_libraries(print-ir infrt ${mlir_libs})
......
...@@ -17,17 +17,17 @@ ...@@ -17,17 +17,17 @@
#include <llvm/ADT/STLExtras.h> #include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h> #include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include <mlir/IR/Function.h> #include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/Module.h> #include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/OpDefinition.h> #include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h> #include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h> #include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h> #include <mlir/Support/LogicalResult.h>
#include "paddle/infrt/dialect/dense_tensor.h" #include "paddle/infrt/dialect/dense_tensor.h"
namespace infrt::dialect { namespace infrt {
namespace dialect {
using namespace mlir; // NOLINT using namespace mlir; // NOLINT
static ParseResult parseCallOp(OpAsmParser &parser, // NOLINT static ParseResult parseCallOp(OpAsmParser &parser, // NOLINT
...@@ -71,12 +71,12 @@ static ParseResult parseConstantF64Op(OpAsmParser &parser, // NOLINT ...@@ -71,12 +71,12 @@ static ParseResult parseConstantF64Op(OpAsmParser &parser, // NOLINT
static ParseResult parseConstantI32Op(OpAsmParser &parser, // NOLINT static ParseResult parseConstantI32Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT OperationState &result) { // NOLINT
return parseConstantOp( return parseConstantOp(
IntegerType::get(32, result.getContext()), parser, result); IntegerType::get(result.getContext(), 32), parser, result);
} }
static ParseResult parseConstantI64Op(OpAsmParser &parser, // NOLINT static ParseResult parseConstantI64Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT OperationState &result) { // NOLINT
return parseConstantOp( return parseConstantOp(
IntegerType::get(64, result.getContext()), parser, result); IntegerType::get(result.getContext(), 64), parser, result);
} }
static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT
...@@ -90,10 +90,10 @@ static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT ...@@ -90,10 +90,10 @@ static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT
} }
static void print(OpAsmPrinter &p, CallOp op) { // NOLINT static void print(OpAsmPrinter &p, CallOp op) { // NOLINT
p << "infrt.call " << op.getAttr("callee") << "("; p << "infrt.call " << op->getAttr("callee") << "(";
p.printOperands(op.getOperands()); p.printOperands(op.getOperands());
p << ")"; p << ")";
p.printOptionalAttrDict(op.getAttrs(), {"callee"}); p.printOptionalAttrDict(op->getAttrs(), {"callee"});
p << " : "; p << " : ";
} }
...@@ -145,7 +145,7 @@ static LogicalResult verify(ConstantF64Op op) { return success(); } ...@@ -145,7 +145,7 @@ static LogicalResult verify(ConstantF64Op op) { return success(); }
static LogicalResult verify(ConstantI64Op op) { return success(); } static LogicalResult verify(ConstantI64Op op) { return success(); }
static LogicalResult verify(ReturnOp op) { static LogicalResult verify(ReturnOp op) {
auto function = dyn_cast<FuncOp>(op.getParentOp()); auto function = dyn_cast<FuncOp>(op->getParentOp());
if (!function) return success(); if (!function) return success();
...@@ -157,8 +157,8 @@ static LogicalResult verify(ReturnOp op) { ...@@ -157,8 +157,8 @@ static LogicalResult verify(ReturnOp op) {
return success(); return success();
} }
} // namespace dialect
} // namespace infrt
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.cpp.inc" #include "paddle/infrt/dialect/basic_kernels.cpp.inc"
} // namespace infrt::dialect
...@@ -13,12 +13,9 @@ ...@@ -13,12 +13,9 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/OpDefinition.h> #include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h> #include <mlir/Interfaces/SideEffectInterfaces.h>
using namespace mlir; // NOLINT
namespace infrt::dialect {
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.hpp.inc" #include "paddle/infrt/dialect/basic_kernels.hpp.inc"
} // namespace infrt::dialect
...@@ -27,7 +27,7 @@ def CallOp : INFRT_Op<"call"> { ...@@ -27,7 +27,7 @@ def CallOp : INFRT_Op<"call"> {
let results = (outs Variadic<AnyType>); let results = (outs Variadic<AnyType>);
let extraClassDeclaration = [{ let extraClassDeclaration = [{
StringRef getCallee() { return callee(); } mlir::StringRef getCallee() { return callee(); }
mlir::FunctionType getCalleeType(); mlir::FunctionType getCalleeType();
}]; }];
} }
...@@ -57,9 +57,8 @@ def ReturnOp : INFRT_Op<"return", [Terminator]> { ...@@ -57,9 +57,8 @@ def ReturnOp : INFRT_Op<"return", [Terminator]> {
let arguments = (ins Variadic<AnyType>:$operands); let arguments = (ins Variadic<AnyType>:$operands);
let builders = [OpBuilder< let builders = [OpBuilder<(ins),
"OpBuilder &b, OperationState &result", [{ build($_builder, $_state, llvm::None); }]>];
[{ build(b, result, llvm::None); }]>];
} }
class AddOp<string suffix, Type type> : INFRT_Op<"add." # suffix, [NoSideEffect]> { class AddOp<string suffix, Type type> : INFRT_Op<"add." # suffix, [NoSideEffect]> {
......
...@@ -17,12 +17,11 @@ ...@@ -17,12 +17,11 @@
#include <llvm/ADT/STLExtras.h> #include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h> #include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h> #include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h> #include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h> #include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h> #include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h> #include <mlir/Support/LogicalResult.h>
...@@ -31,68 +30,37 @@ ...@@ -31,68 +30,37 @@
#include "paddle/infrt/common/global.h" #include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/tensor_shape.h" #include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt::dt { namespace infrt {
namespace dt {
void DTDialect::initialize() { void DTDialect::initialize() {
allowUnknownTypes();
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "paddle/infrt/dialect/dense_tensor.cpp.inc" #include "paddle/infrt/dialect/dense_tensor.cpp.inc"
>(); >();
} }
namespace detail {
struct TensorTypeStorage : public mlir::TypeStorage {
TensorTypeStorage(TargetType target,
LayoutType layout,
PrecisionType precision)
: target_(target), layout_(layout), precision_(precision) {}
using KeyTy = std::tuple<TargetType, LayoutType, PrecisionType>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(target_, layout_, precision_);
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static TensorTypeStorage *construct(
mlir::TypeStorageAllocator &allocator, // NOLINT
const KeyTy &key) {
return new (allocator.allocate<TensorTypeStorage>())
TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key));
}
TargetType target_;
LayoutType layout_;
PrecisionType precision_;
};
} // namespace detail
llvm::Optional<TargetType> GetTargetType(mlir::StringRef key) { llvm::Optional<TargetType> GetTargetType(mlir::StringRef key) {
if (key.equals_lower("x86")) if (key.equals_insensitive("x86"))
return TargetType::X86; return TargetType::X86;
else if (key.equals_lower("cuda")) else if (key.equals_insensitive("cuda"))
return TargetType::CUDA; return TargetType::CUDA;
else else
return llvm::None; return llvm::None;
} }
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key) { llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key) {
if (key.equals_lower("nchw")) if (key.equals_insensitive("nchw"))
return LayoutType::NCHW; return LayoutType::NCHW;
else if (key.equals_lower("nhwc")) else if (key.equals_insensitive("nhwc"))
return LayoutType::NHWC; return LayoutType::NHWC;
else else
return llvm::None; return llvm::None;
} }
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key) { llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key) {
if (key.equals_lower("i32")) if (key.equals_insensitive("i32"))
return PrecisionType::I32; return PrecisionType::I32;
else if (key.equals_lower("f32")) else if (key.equals_insensitive("f32"))
return PrecisionType::F32; return PrecisionType::F32;
else else
return llvm::None; return llvm::None;
...@@ -111,7 +79,7 @@ LayoutType TensorType::layout() { return getImpl()->layout_; } ...@@ -111,7 +79,7 @@ LayoutType TensorType::layout() { return getImpl()->layout_; }
PrecisionType TensorType::precision() { return getImpl()->precision_; } PrecisionType TensorType::precision() { return getImpl()->precision_; }
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType) { mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TensorType tensorType) {
os << "TensorType<" << tensorType.target() << ", " << tensorType.layout() os << "TensorType<" << tensorType.target() << ", " << tensorType.layout()
<< ", " << tensorType.precision() << ">"; << ", " << tensorType.precision() << ">";
return os; return os;
...@@ -133,7 +101,7 @@ StringType StringType::get(mlir::MLIRContext *context) { ...@@ -133,7 +101,7 @@ StringType StringType::get(mlir::MLIRContext *context) {
return Base::get(context); return Base::get(context);
} }
raw_ostream &operator<<(raw_ostream &os, TargetType type) { mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type) {
switch (type) { switch (type) {
case (TargetType::X86): case (TargetType::X86):
os << "X86"; os << "X86";
...@@ -147,7 +115,7 @@ raw_ostream &operator<<(raw_ostream &os, TargetType type) { ...@@ -147,7 +115,7 @@ raw_ostream &operator<<(raw_ostream &os, TargetType type) {
return os; return os;
} }
raw_ostream &operator<<(raw_ostream &os, LayoutType type) { mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type) {
switch (type) { switch (type) {
case (LayoutType::NCHW): case (LayoutType::NCHW):
os << "NCHW"; os << "NCHW";
...@@ -161,7 +129,7 @@ raw_ostream &operator<<(raw_ostream &os, LayoutType type) { ...@@ -161,7 +129,7 @@ raw_ostream &operator<<(raw_ostream &os, LayoutType type) {
return os; return os;
} }
raw_ostream &operator<<(raw_ostream &os, PrecisionType type) { mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type) {
switch (type) { switch (type) {
case (PrecisionType::I32): case (PrecisionType::I32):
os << "I32"; os << "I32";
...@@ -175,103 +143,69 @@ raw_ostream &operator<<(raw_ostream &os, PrecisionType type) { ...@@ -175,103 +143,69 @@ raw_ostream &operator<<(raw_ostream &os, PrecisionType type) {
return os; return os;
} }
static Type getTensorType(mlir::MLIRContext *context) { static mlir::Type getTensorType(mlir::MLIRContext *context) {
auto t_dialect = Identifier::get("t", context); auto t_dialect = mlir::Identifier::get("t", context);
return OpaqueType::get(t_dialect, "tensor", context); return mlir::OpaqueType::get(t_dialect, "tensor");
} }
static ParseResult parseCreateUninitTensorOp( static mlir::ParseResult parseCreateUninitTensorOp(
OpAsmParser &parser, // NOLINT mlir::OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT mlir::OperationState &result) { // NOLINT
auto loc = parser.getCurrentLocation(); auto loc = parser.getCurrentLocation();
::mlir::Type outputRawTypes[1]; mlir::Type outputRawTypes[1];
::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes); ::llvm::ArrayRef<mlir::Type> outputTypes(outputRawTypes);
mlir::ArrayAttr shapeAttr; mlir::ArrayAttr shapeAttr;
if (parser.parseAttribute(shapeAttr, if (parser.parseAttribute(shapeAttr,
parser.getBuilder().getI64Type(), parser.getBuilder().getI64Type(),
"shape", "shape",
result.attributes)) result.attributes))
return failure(); return mlir::failure();
if (parser.parseOptionalAttrDict(result.attributes)) return failure(); if (parser.parseOptionalAttrDict(result.attributes)) return mlir::failure();
if (parser.parseArrow()) return failure(); if (parser.parseArrow()) return mlir::failure();
if (parser.parseType(outputRawTypes[0])) return failure(); if (parser.parseType(outputRawTypes[0])) return mlir::failure();
if (!outputRawTypes[0].isa<TensorType>()) if (!outputRawTypes[0].isa<TensorType>())
return parser.emitError(loc, "invalid kind of type specified"); return parser.emitError(loc, "invalid kind of type specified");
result.addTypes(outputTypes); result.addTypes(outputTypes);
return success(); return mlir::success();
} }
template <typename CreateUninitTensorOp> template <typename CreateUninitTensorOp>
static void printCreateUninitTensorOp(OpAsmPrinter &p, // NOLINT static void printCreateUninitTensorOp(mlir::OpAsmPrinter &p, // NOLINT
CreateUninitTensorOp op) { CreateUninitTensorOp op) {
p << CreateUninitTensorOp::getOperationName(); p << CreateUninitTensorOp::getOperationName();
p << " "; p << " ";
p.printAttributeWithoutType(op.shapeAttr()); p.printAttributeWithoutType(op.shapeAttr());
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
p << " -> "; p << " -> ";
p << op.getOperation()->getResultTypes(); p << op.getOperation()->getResultTypes();
} }
// TODO(shibo): can be removed? static mlir::ParseResult parseSetTensorOp(
// static ParseResult parseFillTensorWithConstantOp(OpAsmParser& parser, mlir::OpAsmParser &parser, // NOLINT
// OperationState& result) { mlir::OperationState &result) { // NOLINT
// auto loc = parser.getCurrentLocation(); llvm::SmallVector<mlir::OpAsmParser::OperandType, 1> operands;
// ::mlir::OpAsmParser::OperandType inputRawOperands[1]; if (parser.parseOperandList(operands, 1)) return mlir::failure();
// ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType>
// inputOperands(inputRawOperands);
// ::mlir::Type inputRawTypes[1];
// ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes);
//
// if (parser.parseOperand(inputRawOperands[0])) return failure();
//
// if (parser.parseColon()) return failure();
// if (parser.parseType(inputRawTypes[0])) return failure();
// if (!inputRawTypes[0].isa<TensorType>())
// return parser.emitError(loc, "invalid kind of type specified");
//
// Attribute value_attr;
// if (parser.resolveOperands(inputOperands, inputTypes, loc, result.operands))
// return failure();
// if (parser.parseAttribute(value_attr, "value", result.attributes)) return
// failure();
// return success();
//}
// TODO(shibo): can be removed?
// template <typename FillTensorOp>
// static void printFillTensorWithConstantOp(OpAsmPrinter& p, FillTensorOp op) {
// p << FillTensorOp::getOperationName();
// p << " ";
// p.printOperand(op.getOperand());
// p << " : ";
// p << op.getOperation()->getOperandTypes();
// p << " ";
// p << op.getAttr("value");
//}
static ParseResult parseSetTensorOp(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
SmallVector<OpAsmParser::OperandType, 1> operands;
if (parser.parseOperandList(operands, 1)) return failure();
auto tensor_type = getTensorType(result.getContext()); auto tensor_type = getTensorType(result.getContext());
Attribute value_attr; mlir::Attribute value_attr;
return failure( return mlir::failure(
parser.resolveOperand(operands[0], tensor_type, result.operands) || parser.resolveOperand(operands[0], tensor_type, result.operands) ||
parser.parseAttribute(value_attr, "values", result.attributes)); parser.parseAttribute(value_attr, "values", result.attributes));
} }
template <typename SetTensorOp> template <typename SetTensorOp>
static void printSetTensorOp(OpAsmPrinter &p, SetTensorOp op) { // NOLINT static void printSetTensorOp(mlir::OpAsmPrinter &p, SetTensorOp op) { // NOLINT
p << SetTensorOp::getOperationName() << " "; p << SetTensorOp::getOperationName() << " ";
p.printOperand(op.getOperand()); p.printOperand(op.getOperand());
p << " " << op.getAttr("values"); p << " " << op->getAttr("values");
} }
} // namespace dt
} // namespace infrt
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.cpp.inc" // NOLINT #include "paddle/infrt/dialect/dense_tensor.cpp.inc" // NOLINT
} // namespace infrt::dt #include "paddle/infrt/dialect/dense_tensor_dialect.cpp.inc"
...@@ -19,13 +19,8 @@ ...@@ -19,13 +19,8 @@
#include <string> #include <string>
using namespace mlir; // NOLINT namespace infrt {
namespace infrt::dt { namespace dt {
namespace detail {
struct TensorTypeStorage;
} // namespace detail
enum class TargetType : uint8_t { X86, CUDA }; enum class TargetType : uint8_t { X86, CUDA };
enum class LayoutType : uint8_t { NCHW, NHWC }; enum class LayoutType : uint8_t { NCHW, NHWC };
enum class PrecisionType : uint8_t { I32, F32 }; enum class PrecisionType : uint8_t { I32, F32 };
...@@ -34,9 +29,39 @@ llvm::Optional<TargetType> GetTargetType(mlir::StringRef key); ...@@ -34,9 +29,39 @@ llvm::Optional<TargetType> GetTargetType(mlir::StringRef key);
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key); llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key);
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key); llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key);
raw_ostream &operator<<(raw_ostream &os, TargetType type); mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type);
raw_ostream &operator<<(raw_ostream &os, LayoutType type); mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type);
raw_ostream &operator<<(raw_ostream &os, PrecisionType type); mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type);
namespace detail {
struct TensorTypeStorage : public mlir::TypeStorage {
TensorTypeStorage(TargetType target,
LayoutType layout,
PrecisionType precision)
: target_(target), layout_(layout), precision_(precision) {}
using KeyTy = std::tuple<TargetType, LayoutType, PrecisionType>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(target_, layout_, precision_);
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static TensorTypeStorage *construct(
mlir::TypeStorageAllocator &allocator, // NOLINT
const KeyTy &key) {
return new (allocator.allocate<TensorTypeStorage>())
TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key));
}
TargetType target_;
LayoutType layout_;
PrecisionType precision_;
};
} // namespace detail
class TensorType : public mlir::Type::TypeBase<TensorType, class TensorType : public mlir::Type::TypeBase<TensorType,
mlir::Type, mlir::Type,
...@@ -52,7 +77,7 @@ class TensorType : public mlir::Type::TypeBase<TensorType, ...@@ -52,7 +77,7 @@ class TensorType : public mlir::Type::TypeBase<TensorType,
PrecisionType precision(); PrecisionType precision();
}; };
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType); mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TensorType tensorType);
class TensorMapType : public mlir::Type::TypeBase<TensorMapType, class TensorMapType : public mlir::Type::TypeBase<TensorMapType,
mlir::Type, mlir::Type,
...@@ -70,10 +95,10 @@ class StringType ...@@ -70,10 +95,10 @@ class StringType
static StringType get(); static StringType get();
static StringType get(mlir::MLIRContext *context); static StringType get(mlir::MLIRContext *context);
}; };
} // namespace dt
} // namespace infrt
#include "paddle/infrt/dialect/dense_tensor_dialect.hpp.inc" #include "paddle/infrt/dialect/dense_tensor_dialect.hpp.inc"
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.hpp.inc" #include "paddle/infrt/dialect/dense_tensor.hpp.inc"
} // namespace infrt::dt
...@@ -14,9 +14,11 @@ ...@@ -14,9 +14,11 @@
#include "paddle/infrt/dialect/diagnostic_utils.h" #include "paddle/infrt/dialect/diagnostic_utils.h"
#include <llvm/Support/raw_ostream.h>
#include <string> #include <string>
namespace infrt::dialect { namespace infrt {
namespace dialect {
struct MyScopedDiagnosicHandler::Impl { struct MyScopedDiagnosicHandler::Impl {
Impl() : diag_stream_(diag_str_) {} Impl() : diag_stream_(diag_str_) {}
...@@ -49,4 +51,5 @@ mlir::LogicalResult MyScopedDiagnosicHandler::handler(mlir::Diagnostic *diag) { ...@@ -49,4 +51,5 @@ mlir::LogicalResult MyScopedDiagnosicHandler::handler(mlir::Diagnostic *diag) {
return mlir::failure(true); return mlir::failure(true);
} }
} // namespace infrt::dialect } // namespace dialect
} // namespace infrt
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
#include <memory> #include <memory>
namespace infrt::dialect { namespace infrt {
namespace dialect {
/** /**
* A scoped diagnostic handler to help debug MLIR process. * A scoped diagnostic handler to help debug MLIR process.
...@@ -36,4 +37,5 @@ class MyScopedDiagnosicHandler : public mlir::SourceMgrDiagnosticHandler { ...@@ -36,4 +37,5 @@ class MyScopedDiagnosicHandler : public mlir::SourceMgrDiagnosticHandler {
std::unique_ptr<Impl> impl_; std::unique_ptr<Impl> impl_;
}; };
} // namespace infrt::dialect } // namespace dialect
} // namespace infrt
...@@ -13,24 +13,26 @@ ...@@ -13,24 +13,26 @@
// limitations under the License. // limitations under the License.
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h> #include <mlir/IR/Dialect.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OpDefinition.h> #include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h> #include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Interfaces/SideEffectInterfaces.h> #include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LogicalResult.h> #include <mlir/Support/LogicalResult.h>
namespace infrt::hlir::dialect { namespace infrt {
namespace hlir {
namespace dialect {
class CinnDialect : public ::mlir::Dialect { class CinnDialect : public mlir::Dialect {
public: public:
explicit CinnDialect(::mlir::MLIRContext* ctx); explicit CinnDialect(mlir::MLIRContext* ctx);
//! We should register this function in dialect //! We should register this function in dialect
static llvm::StringRef getDialectNamespace() { static llvm::StringRef getDialectNamespace() {
return "infrt::hlir::dialect"; return "infrt::hlir::dialect";
} }
}; };
} // namespace dialect
} // namespace infrt::hlir::dialect } // namespace hlir
} // namespace infrt
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
#include "paddle/infrt/dialect/dense_tensor.h" #include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/test_kernels.h" #include "paddle/infrt/dialect/test_kernels.h"
namespace infrt::dialect { namespace infrt {
namespace dialect {
// ----INFRTDialect definition begin---- // ----INFRTDialect definition begin----
void INFRTDialect::initialize() { void INFRTDialect::initialize() {
...@@ -124,4 +125,5 @@ void INFRTDialect::printType(mlir::Type type, ...@@ -124,4 +125,5 @@ void INFRTDialect::printType(mlir::Type type,
// ----INFRTDialect definition end---- // ----INFRTDialect definition end----
} // namespace infrt::dialect } // namespace dialect
} // namespace infrt
...@@ -18,19 +18,17 @@ ...@@ -18,19 +18,17 @@
#include <mlir/IR/Dialect.h> #include <mlir/IR/Dialect.h>
#include <mlir/IR/DialectImplementation.h> #include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/MLIRContext.h> #include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h> #include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h> #include <mlir/IR/Types.h>
#include "paddle/infrt/dialect/infrt_base.hpp.inc" #include "paddle/infrt/dialect/infrt_base.hpp.inc"
namespace infrt::dialect { namespace infrt {
namespace dialect {
class INFRTDialect : public ::mlir::Dialect { class INFRTDialect : public mlir::Dialect {
explicit INFRTDialect(::mlir::MLIRContext *context) explicit INFRTDialect(mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(), : mlir::Dialect(
context, getDialectNamespace(), context, mlir::TypeID::get<INFRTDialect>()) {
::mlir::TypeID::get<INFRTDialect>()) {
initialize(); initialize();
} }
...@@ -41,15 +39,12 @@ class INFRTDialect : public ::mlir::Dialect { ...@@ -41,15 +39,12 @@ class INFRTDialect : public ::mlir::Dialect {
mlir::DialectAsmPrinter &printer) const override; mlir::DialectAsmPrinter &printer) const override;
void initialize(); void initialize();
friend class ::mlir::MLIRContext; friend class mlir::MLIRContext;
public: public:
static ::llvm::StringRef getDialectNamespace() { return "infrt"; } static ::llvm::StringRef getDialectNamespace() { return "infrt"; }
}; };
} // namespace dialect
} // namespace infrt::dialect
namespace mlir {
template <typename T> template <typename T>
static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT
...@@ -58,17 +53,16 @@ static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT ...@@ -58,17 +53,16 @@ static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT
return b.getIntegerAttr(b.getI32Type(), constant); return b.getIntegerAttr(b.getI32Type(), constant);
} }
static mlir::SmallVector<::mlir::Value, 4> cvtValueToValueRange( static mlir::SmallVector<mlir::Value, 4> cvtValueToValueRange(
const mlir::Value &operand) { const mlir::Value &operand) {
return mlir::SmallVector<::mlir::Value, 4>(1, operand); return mlir::SmallVector<mlir::Value, 4>(1, operand);
} }
static mlir::SmallVector<::mlir::Value, 4> concatTwoValueRange( static mlir::SmallVector<mlir::Value, 4> concatTwoValueRange(
mlir::ValueRange operand_0, mlir::ValueRange operand_1) { mlir::ValueRange operand_0, mlir::ValueRange operand_1) {
mlir::SmallVector<::mlir::Value, 4> operands; mlir::SmallVector<mlir::Value, 4> operands;
operands.append(operand_0.begin(), operand_0.end()); operands.append(operand_0.begin(), operand_0.end());
operands.append(operand_1.begin(), operand_1.end()); operands.append(operand_1.begin(), operand_1.end());
return operands; return operands;
} }
} // namespace infrt
} // namespace mlir
...@@ -28,11 +28,11 @@ def TensorMapType : ...@@ -28,11 +28,11 @@ def TensorMapType :
def BufferType : OpaqueType<"b", "buffer", "buffer">; def BufferType : OpaqueType<"b", "buffer", "buffer">;
class INFRT_createI32Attr<string value> : NativeCodeCall< class INFRT_createI32Attr<string value> : NativeCodeCall<
"mlir::createI32Attr($_builder, $_loc, " # value # ")">; "infrt::createI32Attr($_builder, $_loc, " # value # ")">;
def INFRT_cvtValueToValueRange : NativeCodeCall< def INFRT_cvtValueToValueRange : NativeCodeCall<
"mlir::cvtValueToValueRange($0)">; "infrt::cvtValueToValueRange($0)">;
def INFRT_concatTwoValueRange : NativeCodeCall< def INFRT_concatTwoValueRange : NativeCodeCall<
"mlir::concatTwoValueRange($0, $1)">; "infrt::concatTwoValueRange($0, $1)">;
#endif // INFRT_BASE #endif // INFRT_BASE
...@@ -23,12 +23,10 @@ ...@@ -23,12 +23,10 @@
#include "paddle/infrt/dialect/tensor_shape.h" #include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt { namespace infrt {
void registerCinnDialects(mlir::DialectRegistry &registry) { // NOLINT
void RegisterCinnDialects(mlir::DialectRegistry& registry) { // NOLINT registry.insert<ts::TensorShapeDialect,
registry.insert<ts::TensorShapeDialect>(); dialect::INFRTDialect,
registry.insert<dialect::INFRTDialect>(); dt::DTDialect,
registry.insert<dt::DTDialect>(); mlir::pd::PaddleDialect>();
registry.insert<mlir::pd::PaddleDialect>();
} }
} // namespace infrt } // namespace infrt
...@@ -14,10 +14,8 @@ ...@@ -14,10 +14,8 @@
#pragma once #pragma once
#include "mlir/IR/Dialect.h" #include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h>
namespace infrt { namespace infrt {
void registerCinnDialects(mlir::DialectRegistry &registry); // NOLINT
void RegisterCinnDialects(mlir::DialectRegistry& registry); // NOLINT
} // namespace infrt } // namespace infrt
...@@ -16,8 +16,8 @@ ...@@ -16,8 +16,8 @@
#include <llvm/Support/SourceMgr.h> #include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h> #include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h> #include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OperationSupport.h> #include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h> #include <mlir/Parser.h>
#include <unordered_map> #include <unordered_map>
...@@ -30,12 +30,15 @@ ...@@ -30,12 +30,15 @@
#include "paddle/infrt/dialect/diagnostic_utils.h" #include "paddle/infrt/dialect/diagnostic_utils.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h" #include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect { namespace infrt {
namespace dialect {
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source) { const std::string& mlir_source) {
// context->allowUnregisteredDialects(); // context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry()); mlir::DialectRegistry registry;
registerCinnDialects(registry);
context->appendDialectRegistry(registry);
// Currenetly, We only used the CinnDialect and mlir::BuiltinDialect is // Currenetly, We only used the CinnDialect and mlir::BuiltinDialect is
// enough。Don't need StandardOpsDialect. // enough。Don't need StandardOpsDialect.
// context->getDialectRegistry().insert<mlir::StandardOpsDialect>(); // context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
...@@ -57,9 +60,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, ...@@ -57,9 +60,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
mlir::MLIRContext* context) { mlir::MLIRContext* context) {
// context->allowUnregisteredDialects(); // context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry()); mlir::DialectRegistry registry;
context->getDialectRegistry().insert<mlir::StandardOpsDialect>(); registerCinnDialects(registry);
context->appendDialectRegistry(registry);
mlir::ScopedDiagnosticHandler scope_handler( mlir::ScopedDiagnosticHandler scope_handler(
context, [](mlir::Diagnostic& diag) { context, [](mlir::Diagnostic& diag) {
if (diag.getSeverity() != mlir::DiagnosticSeverity::Error) if (diag.getSeverity() != mlir::DiagnosticSeverity::Error)
...@@ -71,4 +74,5 @@ mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, ...@@ -71,4 +74,5 @@ mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
return mlir::parseSourceFile(std::string(file_name), context); return mlir::parseSourceFile(std::string(file_name), context);
} }
} // namespace infrt::dialect } // namespace dialect
} // namespace infrt
...@@ -15,16 +15,17 @@ ...@@ -15,16 +15,17 @@
#pragma once #pragma once
#include <glog/logging.h> #include <glog/logging.h>
#include <mlir/IR/Module.h> #include <mlir/IR/BuiltinOps.h>
#include <string> #include <string>
#include <memory> #include <memory>
namespace infrt::dialect { namespace infrt {
namespace dialect {
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source); const std::string& mlir_source);
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
mlir::MLIRContext* context); mlir::MLIRContext* context);
} // namespace dialect
} // namespace infrt::dialect } // namespace infrt
...@@ -17,14 +17,15 @@ ...@@ -17,14 +17,15 @@
#include <glog/logging.h> #include <glog/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <llvm/Support/SourceMgr.h> #include <llvm/Support/SourceMgr.h>
#include <mlir/IR/Function.h> #include <mlir/IR/BuiltinTypes.h>
#include <mlir/Parser.h> #include <mlir/Parser.h>
#include <string> #include <string>
#include "paddle/infrt/dialect/init_infrt_dialects.h" #include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect { namespace infrt {
namespace dialect {
TEST(MlirLoader, basic) { TEST(MlirLoader, basic) {
mlir::MLIRContext context; mlir::MLIRContext context;
...@@ -42,8 +43,7 @@ func @main() -> f32 { ...@@ -42,8 +43,7 @@ func @main() -> f32 {
)ROC"; )ROC";
auto module = LoadMlirSource(&context, source); auto module = LoadMlirSource(&context, source);
module->verify(); EXPECT_TRUE(mlir::succeeded(module->verify()));
LOG(INFO) << "module name: " << module->getOperationName().data(); LOG(INFO) << "module name: " << module->getOperationName().data();
for (auto func : module->getOps<mlir::FuncOp>()) { for (auto func : module->getOps<mlir::FuncOp>()) {
LOG(INFO) << "get func " << func.getName().str(); LOG(INFO) << "get func " << func.getName().str();
...@@ -54,4 +54,5 @@ func @main() -> f32 { ...@@ -54,4 +54,5 @@ func @main() -> f32 {
} }
} }
} // namespace infrt::dialect } // namespace dialect
} // namespace infrt
...@@ -20,5 +20,5 @@ func @main() -> tensor<?xf32> { ...@@ -20,5 +20,5 @@ func @main() -> tensor<?xf32> {
%c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32> %e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
infrt.return %e2 : tensor<?xf32> "pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->()
} }
\ No newline at end of file
...@@ -11,5 +11,5 @@ func @main() -> tensor<?xf32> { ...@@ -11,5 +11,5 @@ func @main() -> tensor<?xf32> {
%c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor<?x3x256x256xf32>, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32> %c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor<?x3x256x256xf32>, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
%d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor<?x3x256x256xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32> %d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor<?x3x256x256xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
infrt.return %d : tensor<?x3x256x256xf32> "pd.fetch"(%d) {name="output"} :(tensor<?x3x256x256xf32>)->()
} }
\ No newline at end of file
...@@ -18,5 +18,5 @@ func @main() -> tensor<?xf32> { ...@@ -18,5 +18,5 @@ func @main() -> tensor<?xf32> {
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32> %e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
"pd.fetch"(%e2) :(tensor<?xf32>)->() "pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->()
} }
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/infrt_base.td"
class INFRT_Op<string mnemonic, list<OpTrait> traits = []> :
Op<INFRT_Dialect, mnemonic, traits>;
...@@ -12,34 +12,14 @@ ...@@ -12,34 +12,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <glog/logging.h>
#include <llvm/Support/CommandLine.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/IR/AsmState.h>
#include <mlir/IR/Dialect.h>
#include <mlir/InitAllDialects.h>
#include <mlir/InitAllPasses.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/MlirOptMain.h> #include <mlir/Support/MlirOptMain.h>
#include <mlir/Transforms/Passes.h> #include <mlir/Transforms/Passes.h>
#include <iostream>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h" #include "paddle/infrt/dialect/init_infrt_dialects.h"
#include "paddle/infrt/dialect/mlir_loader.h"
int main(int argc, char **argv) { int main(int argc, char **argv) {
mlir::MLIRContext *context = infrt::Global::getMLIRContext(); mlir::DialectRegistry registry;
infrt::registerCinnDialects(registry);
auto &registry = context->getDialectRegistry();
infrt::RegisterCinnDialects(registry);
mlir::registerCanonicalizerPass(); mlir::registerCanonicalizerPass();
return mlir::failed( return mlir::failed(
mlir::MlirOptMain(argc, argv, "INFRT mlir pass driver", registry)); mlir::MlirOptMain(argc, argv, "infrt mlir pass driver", registry));
} }
...@@ -16,7 +16,7 @@ def PD_Dialect : Dialect { ...@@ -16,7 +16,7 @@ def PD_Dialect : Dialect {
This dialect contains the PaddlePaddle operators. This dialect contains the PaddlePaddle operators.
}]; }];
let cppNamespace = "::mlir::pd"; let cppNamespace = "mlir::pd";
} }
class PD_Op<string mnemonic, list<OpTrait> traits = []> : class PD_Op<string mnemonic, list<OpTrait> traits = []> :
......
...@@ -14,10 +14,15 @@ ...@@ -14,10 +14,15 @@
#include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/pd_ops.h"
#include "mlir/IR/Matchers.h" #include <mlir/IR/Matchers.h>
#include "mlir/IR/PatternMatch.h" #include <mlir/IR/PatternMatch.h>
#include "paddle/infrt/dialect/infrt_base.h" #include "paddle/infrt/dialect/infrt_base.h"
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT
namespace mlir { namespace mlir {
namespace pd { namespace pd {
PaddleDialect::PaddleDialect(MLIRContext *context) PaddleDialect::PaddleDialect(MLIRContext *context)
...@@ -36,12 +41,6 @@ mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder, ...@@ -36,12 +41,6 @@ mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder,
return builder.create<ConstantOp>(loc, value); return builder.create<ConstantOp>(loc, value);
} }
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
#undef GET_OP_CLASSES
#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT
void ConstantOp::build(OpBuilder &builder, void ConstantOp::build(OpBuilder &builder,
OperationState &state, OperationState &state,
Attribute value) { Attribute value) {
...@@ -66,8 +65,8 @@ LogicalResult ConstantOp::inferReturnTypes( ...@@ -66,8 +65,8 @@ LogicalResult ConstantOp::inferReturnTypes(
inferredReturnTypes.push_back(attributes.get("value").getType()); inferredReturnTypes.push_back(attributes.get("value").getType());
return success(); return success();
} }
::mlir::OpFoldResult ConstantOp::fold( mlir::OpFoldResult ConstantOp::fold(
::llvm::ArrayRef<::mlir::Attribute> operands) { ::llvm::ArrayRef<mlir::Attribute> operands) {
return value(); return value();
} }
...@@ -82,11 +81,11 @@ LogicalResult ElementwiseAdd::inferReturnTypes( ...@@ -82,11 +81,11 @@ LogicalResult ElementwiseAdd::inferReturnTypes(
return success(); return success();
} }
void ElementwiseAdd::getCanonicalizationPatterns( void ElementwiseAdd::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseMulAdd>(context); results.insert<FuseMulAdd>(context);
} }
::mlir::OpFoldResult ElementwiseAdd::fold( mlir::OpFoldResult ElementwiseAdd::fold(
llvm::ArrayRef<mlir::Attribute> operands) { llvm::ArrayRef<mlir::Attribute> operands) {
if (getElementTypeOrSelf(getType()).isa<FloatType>()) { if (getElementTypeOrSelf(getType()).isa<FloatType>()) {
if (!operands[0] || !operands[1]) return {}; if (!operands[0] || !operands[1]) return {};
...@@ -154,17 +153,17 @@ LogicalResult MulOp::inferReturnTypes( ...@@ -154,17 +153,17 @@ LogicalResult MulOp::inferReturnTypes(
} }
void ReluOp::getCanonicalizationPatterns( void ReluOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseFCRelu>(context); results.insert<FuseFCRelu>(context);
} }
void FusedRepeatedFCRelu::getCanonicalizationPatterns( void FusedRepeatedFCRelu::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseRepeatedFCRelu2>(context); results.insert<FuseRepeatedFCRelu2>(context);
} }
void BatchNormOp::getCanonicalizationPatterns( void BatchNormOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseBatchNormWithConvPattern>(context); results.insert<FuseBatchNormWithConvPattern>(context);
} }
......
...@@ -14,21 +14,20 @@ ...@@ -14,21 +14,20 @@
#pragma once #pragma once
#include "mlir/Dialect/Traits.h" #include <mlir/Dialect/Traits.h>
#include "mlir/IR/Attributes.h" #include <mlir/IR/Attributes.h>
#include "mlir/IR/Builders.h" #include <mlir/IR/Builders.h>
#include "mlir/IR/Dialect.h" #include <mlir/IR/BuiltinOps.h>
#include "mlir/IR/Function.h" #include <mlir/IR/BuiltinTypes.h>
#include "mlir/IR/Matchers.h" #include <mlir/IR/Dialect.h>
#include "mlir/IR/Module.h" #include <mlir/IR/Matchers.h>
#include "mlir/IR/OpImplementation.h" #include <mlir/IR/OpImplementation.h>
#include "mlir/IR/StandardTypes.h" #include <mlir/IR/TypeUtilities.h>
#include "mlir/IR/TypeUtilities.h" #include <mlir/Interfaces/CallInterfaces.h>
#include "mlir/Interfaces/CallInterfaces.h" #include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" #include <mlir/Interfaces/InferTypeOpInterface.h>
#include "mlir/Interfaces/InferTypeOpInterface.h" #include <mlir/Interfaces/LoopLikeInterface.h>
#include "mlir/Interfaces/LoopLikeInterface.h" #include <mlir/Interfaces/SideEffectInterfaces.h>
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir { namespace mlir {
namespace pd { namespace pd {
...@@ -53,9 +52,8 @@ class PaddleDialect : public Dialect { ...@@ -53,9 +52,8 @@ class PaddleDialect : public Dialect {
} }
}; };
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
#undef GET_OP_CLASSES
} // namespace pd } // namespace pd
} // namespace mlir } // namespace mlir
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
...@@ -24,6 +24,16 @@ def PD_FeedOp : PD_Op<"feed"> { ...@@ -24,6 +24,16 @@ def PD_FeedOp : PD_Op<"feed"> {
def PD_FetchOp : PD_Op<"fetch", [Terminator]> { def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
let summary = "fetch Op"; let summary = "fetch Op";
let description = [{
Return the output tensor from the subgraph.
}];
let arguments = (ins PD_Tensor :$inputs, StrAttr:$name);
}
def PD_ReturnOp : PD_Op<"return", [Terminator]> {
let summary = "return Op";
let description = [{ let description = [{
Fetch tensor from the graph. Fetch tensor from the graph.
}]; }];
...@@ -31,7 +41,7 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> { ...@@ -31,7 +41,7 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
let arguments = (ins Variadic<PD_Tensor>:$inputs); let arguments = (ins Variadic<PD_Tensor>:$inputs);
} }
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> { def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"ReturnOp">]> {
let summary = "paddle graph Op"; let summary = "paddle graph Op";
let description = [{ let description = [{
Describe a paddle graph or subgraph. Describe a paddle graph or subgraph.
...@@ -50,7 +60,7 @@ def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInte ...@@ -50,7 +60,7 @@ def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInte
let hasFolder = 1; let hasFolder = 1;
let builders = [ let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute value">, OpBuilder<(ins "Attribute":$value)>,
]; ];
} }
......
...@@ -18,12 +18,11 @@ ...@@ -18,12 +18,11 @@
#pragma once #pragma once
#include "mlir/IR/Diagnostics.h" #include <mlir/IR/Diagnostics.h>
#include "mlir/IR/Location.h" #include <mlir/IR/Location.h>
#include "mlir/IR/Operation.h" #include <mlir/IR/Operation.h>
#include "mlir/IR/StandardTypes.h" #include <mlir/IR/TypeUtilities.h>
#include "mlir/IR/TypeUtilities.h" #include <mlir/IR/Types.h>
#include "mlir/IR/Types.h"
namespace mlir { namespace mlir {
namespace PD { namespace PD {
......
...@@ -11,26 +11,25 @@ ...@@ -11,26 +11,25 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <llvm/ADT/Optional.h>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/ScopedPrinter.h>
#include <mlir/IR/BuiltinOps.h>
#include <llvm/Support/raw_os_ostream.hv
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/AsmState.h>
#include <mlir/IR/Block.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/Region.h>
#include <mlir/IR/Verifier.h>
#include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/Passes.h>
#include <iostream> #include <iostream>
#include "llvm/ADT/Optional.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "paddle/infrt/common/global.h" #include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h" #include "paddle/infrt/dialect/init_infrt_dialects.h"
...@@ -114,17 +113,15 @@ int main(int argc, char **argv) { ...@@ -114,17 +113,15 @@ int main(int argc, char **argv) {
mlir::registerPassManagerCLOptions(); mlir::registerPassManagerCLOptions();
cl::ParseCommandLineOptions(argc, argv, "mlir demo"); cl::ParseCommandLineOptions(argc, argv, "mlir demo");
mlir::MLIRContext *context = infrt::Global::getMLIRContext(); mlir::DialectRegistry registry;
// context->allowUnregisteredDialects(); infrt::registerCinnDialects(registry);
auto &registry = context->getDialectRegistry(); mlir::MLIRContext context(registry);
infrt::RegisterCinnDialects(registry);
// mlir will verify module automatically after parsing. // mlir will verify module automatically after parsing.
// https://github.com/llvm/llvm-project/blob/38d18d93534d290d045bbbfa86337e70f1139dc2/mlir/lib/Parser/Parser.cpp#L2051 // https://github.com/llvm/llvm-project/blob/38d18d93534d290d045bbbfa86337e70f1139dc2/mlir/lib/Parser/Parser.cpp#L2051
// mlir::OwningModuleRef module_ref = mlir::parseSourceString(mlir_source, // mlir::OwningModuleRef module_ref = mlir::parseSourceString(mlir_source,
// context); // context);
mlir::OwningModuleRef module_ref = mlir::OwningModuleRef module_ref =
mlir::parseSourceFile(inputFilename, context); mlir::parseSourceFile(inputFilename, &context);
std::cout << "----------print IR Structure begin----------" << std::endl; std::cout << "----------print IR Structure begin----------" << std::endl;
printOperation(module_ref->getOperation(), 0); printOperation(module_ref->getOperation(), 0);
std::cout << "----------print IR Structure end----------" << std::endl; std::cout << "----------print IR Structure end----------" << std::endl;
......
...@@ -17,16 +17,16 @@ ...@@ -17,16 +17,16 @@
#include <llvm/ADT/STLExtras.h> #include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h> #include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h> #include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h> #include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h> #include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h> #include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h> #include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h> #include <mlir/Support/LogicalResult.h>
namespace infrt::ts { namespace infrt {
namespace ts {
using namespace mlir; // NOLINT using namespace mlir; // NOLINT
void TensorShapeDialect::initialize() { void TensorShapeDialect::initialize() {
...@@ -48,8 +48,8 @@ Type TensorShapeDialect::parseType(DialectAsmParser &parser) const { ...@@ -48,8 +48,8 @@ Type TensorShapeDialect::parseType(DialectAsmParser &parser) const {
return Type(); return Type();
} }
void TensorShapeDialect::printType(::mlir::Type type, void TensorShapeDialect::printType(mlir::Type type,
::mlir::DialectAsmPrinter &os) const { mlir::DialectAsmPrinter &os) const {
if (type.isa<ShapeType>()) { if (type.isa<ShapeType>()) {
os << "shape"; os << "shape";
return; return;
...@@ -61,8 +61,10 @@ void TensorShapeDialect::printType(::mlir::Type type, ...@@ -61,8 +61,10 @@ void TensorShapeDialect::printType(::mlir::Type type,
} }
llvm_unreachable("unexpected 'shape' type kind"); llvm_unreachable("unexpected 'shape' type kind");
} }
} // namespace ts
} // namespace infrt
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensor_shape.cpp.inc" // NOLINT #include "paddle/infrt/dialect/tensor_shape.cpp.inc" // NOLINT
} // namespace infrt::ts #include "paddle/infrt/dialect/tensor_shape_dialect.cpp.inc"
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
#include <mlir/IR/OpDefinition.h> #include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h> #include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt::ts { namespace infrt {
namespace ts {
class ShapeType class ShapeType
: public mlir::Type::TypeBase<ShapeType, mlir::Type, mlir::TypeStorage> { : public mlir::Type::TypeBase<ShapeType, mlir::Type, mlir::TypeStorage> {
...@@ -31,10 +32,9 @@ class PartialShapeType : public mlir::Type::TypeBase<PartialShapeType, ...@@ -31,10 +32,9 @@ class PartialShapeType : public mlir::Type::TypeBase<PartialShapeType,
public: public:
using Base::Base; using Base::Base;
}; };
} // namespace ts
} // namespace infrt
using namespace mlir; // NOLINT
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensor_shape.hpp.inc" #include "paddle/infrt/dialect/tensor_shape.hpp.inc"
#include "paddle/infrt/dialect/tensor_shape_dialect.hpp.inc" #include "paddle/infrt/dialect/tensor_shape_dialect.hpp.inc"
} // namespace infrt::ts
...@@ -19,7 +19,7 @@ def TensorShapeDialect : Dialect { ...@@ -19,7 +19,7 @@ def TensorShapeDialect : Dialect {
def TS_Shape : DialectType<TensorShapeDialect, def TS_Shape : DialectType<TensorShapeDialect,
CPred<"$_self.isa<::infrt::ts::ShapeType>()">, "!ts.shape type">, CPred<"$_self.isa<::infrt::ts::ShapeType>()">, "!ts.shape type">,
BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> { BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> {
let typeDescription = [{ let description = [{
`!ts.shape type` represents a static tensor shape. `!ts.shape type` represents a static tensor shape.
}]; }];
} }
...@@ -27,7 +27,7 @@ BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> { ...@@ -27,7 +27,7 @@ BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> {
def TS_PartialShape : DialectType<TensorShapeDialect, def TS_PartialShape : DialectType<TensorShapeDialect,
CPred<"$_self.isa<::infrt::ts::PartialShapeType>()">, "!ts.partial_shape type">, CPred<"$_self.isa<::infrt::ts::PartialShapeType>()">, "!ts.partial_shape type">,
BuildableType<"$_builder.getType<::infrt::ts::PartialShapeType>()"> { BuildableType<"$_builder.getType<::infrt::ts::PartialShapeType>()"> {
let typeDescription = [{ let description = [{
`!ts.partial_shape type` represents either a static tensor shape, unranked `!ts.partial_shape type` represents either a static tensor shape, unranked
tensor shape or a ranked tensor shape with unknown dimension sizes. tensor shape or a ranked tensor shape with unknown dimension sizes.
}]; }];
......
...@@ -11,10 +11,10 @@ ...@@ -11,10 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <llvm/Support/CommandLine.h>
#include <mlir/Pass/PassManager.h>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include "llvm/Support/CommandLine.h"
#include "mlir/Pass/PassManager.h"
#include "paddle/infrt/common/global.h" #include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/mlir_loader.h" #include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
......
...@@ -14,14 +14,13 @@ ...@@ -14,14 +14,13 @@
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <paddle/infrt/dialect/pd_ops.h>
#include <list> #include <list>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "llvm/ADT/SetVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
...@@ -32,9 +31,9 @@ namespace { ...@@ -32,9 +31,9 @@ namespace {
// Reference the function nameed "FlexibleDFS" but defined in: // Reference the function nameed "FlexibleDFS" but defined in:
// paddle/fluid/framework/ir/subgraph_detector.cc. // paddle/fluid/framework/ir/subgraph_detector.cc.
bool reverseDfs(std::vector<::mlir::Operation *> source, bool reverseDfs(std::vector<mlir::Operation *> source,
const std::function<bool(const ::mlir::Operation *)> &func) { const std::function<bool(const mlir::Operation *)> &func) {
std::unordered_set<const ::mlir::Operation *> visited; std::unordered_set<const mlir::Operation *> visited;
while (!source.empty()) { while (!source.empty()) {
auto node = source.back(); auto node = source.back();
source.pop_back(); source.pop_back();
...@@ -44,7 +43,7 @@ bool reverseDfs(std::vector<::mlir::Operation *> source, ...@@ -44,7 +43,7 @@ bool reverseDfs(std::vector<::mlir::Operation *> source,
auto values = node->getOperands(); auto values = node->getOperands();
for (auto value : values) { for (auto value : values) {
// if the value is a block argument, the node is nullptr. // if the value is a block argument, the node is nullptr.
::mlir::Operation *node = value.getDefiningOp(); mlir::Operation *node = value.getDefiningOp();
if (node != nullptr && !visited.count(node)) { if (node != nullptr && !visited.count(node)) {
source.emplace_back(node); source.emplace_back(node);
} }
...@@ -54,19 +53,19 @@ bool reverseDfs(std::vector<::mlir::Operation *> source, ...@@ -54,19 +53,19 @@ bool reverseDfs(std::vector<::mlir::Operation *> source,
} }
// merge the first&second graph op to a new graph op. // merge the first&second graph op to a new graph op.
void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
::mlir::pd::GraphOp first, mlir::pd::GraphOp first,
::mlir::pd::GraphOp second) { mlir::pd::GraphOp second) {
// comput inputs and outputs // comput inputs and outputs
::llvm::SmallVector<::mlir::Value, 4> inputs(first.getOperands()), outputs; ::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
for (::mlir::Value input : second.getOperands()) { for (mlir::Value input : second.getOperands()) {
if (input.getDefiningOp() != first) { if (input.getDefiningOp() != first) {
inputs.push_back(input); inputs.push_back(input);
} }
} }
::llvm::DenseMap<::mlir::Value, unsigned int> op_output_mapping; ::llvm::DenseMap<mlir::Value, unsigned int> op_output_mapping;
for (::mlir::Value output : first.getResults()) { for (mlir::Value output : first.getResults()) {
for (::mlir::Operation *user : output.getUsers()) { for (mlir::Operation *user : output.getUsers()) {
if (user != second && user->getParentOp() != second) { if (user != second && user->getParentOp() != second) {
op_output_mapping[output] = outputs.size(); op_output_mapping[output] = outputs.size();
outputs.push_back(output); outputs.push_back(output);
...@@ -74,19 +73,19 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT ...@@ -74,19 +73,19 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
} }
} }
} }
auto fetch_op = second.getBody()->getTerminator(); auto return_op = second.getBody()->getTerminator();
outputs.append(fetch_op->getOperands().begin(), outputs.append(return_op->getOperands().begin(),
fetch_op->getOperands().end()); return_op->getOperands().end());
::llvm::SmallVector<::mlir::Type, 4> fetch_types; ::llvm::SmallVector<mlir::Type, 4> return_types;
for (auto value : outputs) { for (auto value : outputs) {
fetch_types.push_back(value.getType()); return_types.push_back(value.getType());
} }
// create the new graph op // create the new graph op
builder.setInsertionPoint(first); builder.setInsertionPoint(first);
auto loc = first.getLoc(); auto loc = first.getLoc();
auto graph_op = builder.create<::mlir::pd::GraphOp>(loc, fetch_types, inputs); auto graph_op = builder.create<mlir::pd::GraphOp>(loc, return_types, inputs);
::mlir::Block *block = new ::mlir::Block; mlir::Block *block = new mlir::Block;
auto copy_range = second.getBody()->without_terminator(); auto copy_range = second.getBody()->without_terminator();
block->getOperations().splice(block->begin(), block->getOperations().splice(block->begin(),
second.getBody()->getOperations(), second.getBody()->getOperations(),
...@@ -98,18 +97,18 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT ...@@ -98,18 +97,18 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
copy_range.begin(), copy_range.begin(),
copy_range.end()); copy_range.end());
builder.setInsertionPointToEnd(block); builder.setInsertionPointToEnd(block);
builder.create<mlir::pd::FetchOp>(loc, outputs); builder.create<mlir::pd::ReturnOp>(loc, outputs);
graph_op.body().push_back(block); graph_op.body().push_back(block);
// mapping the output // mapping the output
unsigned int num_result = first.getNumResults(); unsigned int num_result = first.getNumResults();
fetch_op = first.getBody()->getTerminator(); return_op = first.getBody()->getTerminator();
for (unsigned int index = 0; index < num_result; ++index) { for (unsigned int index = 0; index < num_result; ++index) {
auto origin_value = first.getResult(index); auto origin_value = first.getResult(index);
if (op_output_mapping.find(origin_value) == op_output_mapping.end()) { if (op_output_mapping.find(origin_value) == op_output_mapping.end()) {
origin_value.replaceAllUsesWith(fetch_op->getOperand(index)); origin_value.replaceAllUsesWith(return_op->getOperand(index));
} else { } else {
auto inner_value = fetch_op->getOperand(index); auto inner_value = return_op->getOperand(index);
auto outer_value = graph_op.getResult(op_output_mapping[origin_value]); auto outer_value = graph_op.getResult(op_output_mapping[origin_value]);
while (!origin_value.use_empty()) { while (!origin_value.use_empty()) {
auto replace_value = auto replace_value =
...@@ -128,13 +127,13 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT ...@@ -128,13 +127,13 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
// Topological sort the function op. // Topological sort the function op.
void topoSortBlock(mlir::Block &body) { // NOLINT void topoSortBlock(mlir::Block &body) { // NOLINT
llvm::SetVector<Operation *> toSort; llvm::SetVector<mlir::Operation *> toSort;
if (body.empty()) return; if (body.empty()) return;
for (auto it = body.rbegin(); it != body.rend(); ++it) { for (auto it = body.rbegin(); it != body.rend(); ++it) {
toSort.insert(&*it); toSort.insert(&*it);
} }
llvm::SetVector<Operation *> result = llvm::SetVector<mlir::Operation *> result =
::mlir::topologicalSort(std::move(toSort)); mlir::topologicalSort(std::move(toSort));
for (auto *op : result) { for (auto *op : result) {
op->moveBefore(body.getTerminator()); op->moveBefore(body.getTerminator());
} }
...@@ -145,21 +144,21 @@ void topoSortBlock(mlir::Block &body) { // NOLINT ...@@ -145,21 +144,21 @@ void topoSortBlock(mlir::Block &body) { // NOLINT
// Implementation of the trtGraphFusePass. // Implementation of the trtGraphFusePass.
void trtGraphFusePass::runOnFunction() { void trtGraphFusePass::runOnFunction() {
mlir::Block &body = getFunction().front(); mlir::Block &body = getFunction().front();
::mlir::OpBuilder builder(&body, body.begin()); mlir::OpBuilder builder(&body, body.begin());
bool changed = false; bool changed = false;
do { do {
changed = false; changed = false;
for (auto &op : body) { for (auto &op : body) {
::mlir::pd::GraphOp graph_op = mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op); ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr == graph_op) continue; if (nullptr == graph_op) continue;
for (auto user_op : op.getUsers()) { for (auto user_op : op.getUsers()) {
::mlir::pd::GraphOp user_graph_op = mlir::pd::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(user_op); ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(user_op);
if (nullptr == user_graph_op) continue; if (nullptr == user_graph_op) continue;
// get all dst input nodes except src. // get all dst input nodes except src.
std::vector<::mlir::Operation *> source_nodes; std::vector<mlir::Operation *> source_nodes;
for (auto operand : user_op->getOperands()) { for (auto operand : user_op->getOperands()) {
auto input = operand.getDefiningOp(); auto input = operand.getDefiningOp();
if (input != &op && input != nullptr) { if (input != &op && input != nullptr) {
...@@ -167,9 +166,8 @@ void trtGraphFusePass::runOnFunction() { ...@@ -167,9 +166,8 @@ void trtGraphFusePass::runOnFunction() {
} }
} }
// Reverse DFS from the source_nodes. // Reverse DFS from the source_nodes.
if (!reverseDfs(source_nodes, [&op](const ::mlir::Operation *n) { if (!reverseDfs(source_nodes,
return n == &op; [&op](const mlir::Operation *n) { return n == &op; })) {
})) {
mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op); mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
changed = true; changed = true;
break; break;
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "mlir/Pass/Pass.h" #include <mlir/Pass/Pass.h>
namespace infrt { namespace infrt {
namespace trt { namespace trt {
...@@ -28,15 +28,15 @@ namespace trt { ...@@ -28,15 +28,15 @@ namespace trt {
* %a = "pd.feed"()... * %a = "pd.feed"()...
* %c = "pd.graph"(%a) { * %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "pd.fetch" %m * "pd.return" %m
* } ... * } ...
* %d = "pd.graph"(%c) { * %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)... * %m = "pd.conv3d"(%c)...
* "pd.fetch" %m * "pd.return" %m
* } ... * } ...
* %f = "pd.graph"(%a) { * %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "pd.fetch" %m * "pd.return" %m
* } ... * } ...
* "pd.fetch" %d, %f * "pd.fetch" %d, %f
* *
...@@ -47,13 +47,13 @@ namespace trt { ...@@ -47,13 +47,13 @@ namespace trt {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)... * %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)... * %s = "pd.conv2d"(%a)...
* "pd.fetch" %n, %s * "pd.return" %n, %s
* } ... * } ...
* "pd.fetch" %d, %f * "pd.fetch" %d, %f
* } * }
*/ */
class trtGraphFusePass class trtGraphFusePass
: public ::mlir::PassWrapper<trtGraphFusePass, ::mlir::FunctionPass> { : public mlir::PassWrapper<trtGraphFusePass, mlir::FunctionPass> {
public: public:
::llvm::StringRef getName() const override { return "trtGraphFusePass"; } ::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
void runOnFunction() override; void runOnFunction() override;
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "mlir/IR/Builders.h" #include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
...@@ -22,24 +22,24 @@ namespace infrt { ...@@ -22,24 +22,24 @@ namespace infrt {
namespace trt { namespace trt {
// Implementation of the trtGraphSplitPass。 // Implementation of the trtGraphSplitPass。
void trtGraphSplitPass::runOnFunction() { void trtGraphSplitPass::runOnFunction() {
std::vector<::mlir::pd::GraphOp> worklist; std::vector<mlir::pd::GraphOp> worklist;
::mlir::Block& block = getFunction().front(); mlir::Block& block = getFunction().front();
for (auto& op : block) { for (auto& op : block) {
::mlir::pd::GraphOp graph_op = mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op); ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr != graph_op && if (nullptr != graph_op &&
graph_op.getBody()->getOperations().size() <= min_subgraph_size_) { graph_op.getBody()->getOperations().size() <= min_subgraph_size_) {
worklist.push_back(graph_op); worklist.push_back(graph_op);
} }
} }
while (!worklist.empty()) { while (!worklist.empty()) {
::mlir::pd::GraphOp graph_op = worklist.back(); mlir::pd::GraphOp graph_op = worklist.back();
worklist.pop_back(); worklist.pop_back();
::mlir::Block* body = graph_op.getBody(); mlir::Block* body = graph_op.getBody();
auto fetch_op = body->getTerminator(); auto return_op = body->getTerminator();
graph_op.replaceAllUsesWith(fetch_op->getOperands()); graph_op.replaceAllUsesWith(return_op->getOperands());
auto copy_range = body->without_terminator(); auto copy_range = body->without_terminator();
block.getOperations().splice(::mlir::Block::iterator(graph_op), block.getOperations().splice(mlir::Block::iterator(graph_op),
body->getOperations(), body->getOperations(),
copy_range.begin(), copy_range.begin(),
copy_range.end()); copy_range.end());
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "mlir/Pass/Pass.h" #include <mlir/Pass/Pass.h>
namespace infrt { namespace infrt {
namespace trt { namespace trt {
...@@ -31,9 +31,9 @@ namespace trt { ...@@ -31,9 +31,9 @@ namespace trt {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)... * %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)... * %s = "pd.conv2d"(%a)...
* "pd.fetch" %n, %s * "pd.return" (%n, %s)
* } ... * } ...
* "pd.fetch" %d, %f * "pd.fetch" (%d, %f)
* } * }
* *
* destination func: * destination func:
...@@ -42,11 +42,11 @@ namespace trt { ...@@ -42,11 +42,11 @@ namespace trt {
* %c = "pd.conv2d"(%a) ... * %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ... * %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ... * %f = "pd.conv2d"(%a) ...
* "pd.fetch" %d, %f * "pd.fetch" (%d, %f)
* } * }
*/ */
class trtGraphSplitPass class trtGraphSplitPass
: public ::mlir::PassWrapper<trtGraphSplitPass, ::mlir::FunctionPass> { : public mlir::PassWrapper<trtGraphSplitPass, mlir::FunctionPass> {
public: public:
::llvm::StringRef getName() const override { return "trtGraphSplitPass"; } ::llvm::StringRef getName() const override { return "trtGraphSplitPass"; }
void runOnFunction() override; void runOnFunction() override;
......
...@@ -14,49 +14,48 @@ ...@@ -14,49 +14,48 @@
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h" #include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
#include "mlir/IR/Builders.h" #include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
// Implementation of the trtOpTellerPass。 // Implementation of the trtOpTellerPass。
void trtOpTellerPass::runOnFunction() { void trtOpTellerPass::runOnFunction() {
::mlir::Block &body = getFunction().front(); mlir::Block &body = getFunction().front();
std::vector<::mlir::Operation *> worklist; std::vector<mlir::Operation *> worklist;
worklist.reserve(body.getOperations().size()); worklist.reserve(body.getOperations().size());
for (auto &op : body) { for (auto &op : body) {
worklist.push_back(&op); worklist.push_back(&op);
} }
// Build GraphOp. // Build GraphOp.
::mlir::OpBuilder builder(&body, body.begin()); mlir::OpBuilder builder(&body, body.begin());
while (!worklist.empty()) { while (!worklist.empty()) {
auto *op = worklist.back(); auto *op = worklist.back();
worklist.pop_back(); worklist.pop_back();
if (op == nullptr) continue; if (op == nullptr) continue;
auto op1 = ::llvm::dyn_cast_or_null<::mlir::pd::FeedOp>(op); auto op1 = ::llvm::dyn_cast_or_null<mlir::pd::FeedOp>(op);
if (op1) continue; if (op1) continue;
auto op2 = ::llvm::dyn_cast_or_null<::mlir::pd::FetchOp>(op); auto op2 = ::llvm::dyn_cast_or_null<mlir::pd::FetchOp>(op);
if (op2) continue; if (op2) continue;
auto op3 = ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(op); auto op3 = ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(op);
if (op3) continue; if (op3) continue;
builder.setInsertionPoint(op); builder.setInsertionPoint(op);
auto loc = getFunction().getLoc(); auto loc = getFunction().getLoc();
auto graph_op = builder.create<::mlir::pd::GraphOp>( auto graph_op = builder.create<mlir::pd::GraphOp>(
loc, op->getResultTypes(), op->getOperands()); loc, op->getResultTypes(), op->getOperands());
::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values; ::llvm::SmallVector<mlir::Value, 4> tblgen_repl_values;
for (auto v : for (auto v :
::llvm::SmallVector<::mlir::Value, 4>{graph_op.getODSResults(0)}) { ::llvm::SmallVector<mlir::Value, 4>{graph_op.getODSResults(0)}) {
tblgen_repl_values.push_back(v); tblgen_repl_values.push_back(v);
} }
op->replaceAllUsesWith(tblgen_repl_values); op->replaceAllUsesWith(tblgen_repl_values);
// Build graph op. // Build graph op.
::mlir::Block *block = new ::mlir::Block; mlir::Block *block = new mlir::Block;
graph_op.body().push_back(block); graph_op.body().push_back(block);
op->moveBefore(block, block->begin()); op->moveBefore(block, block->begin());
builder.setInsertionPointToEnd(block); builder.setInsertionPointToEnd(block);
builder.create<mlir::pd::FetchOp>(loc, op->getResults()); builder.create<mlir::pd::ReturnOp>(loc, op->getResults());
} }
} }
} // namespace trt } // namespace trt
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "mlir/Pass/Pass.h" #include <mlir/Pass/Pass.h>
namespace infrt { namespace infrt {
namespace trt { namespace trt {
...@@ -29,7 +29,7 @@ namespace trt { ...@@ -29,7 +29,7 @@ namespace trt {
* %c = "pd.conv2d"(%a) ... * %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ... * %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ... * %f = "pd.conv2d"(%a) ...
* "pd.fetch" %d, %f * "pd.fetch" (%d, %f)
* } * }
* *
* destination func: * destination func:
...@@ -37,23 +37,23 @@ namespace trt { ...@@ -37,23 +37,23 @@ namespace trt {
* %a = "pd.feed"()... * %a = "pd.feed"()...
* %c = "pd.graph"(%a) { * %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "pd.fetch" %m * "pd.return" (%m)
* } ... * } ...
* %d = "pd.graph"(%c) { * %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)... * %m = "pd.conv3d"(%c)...
* "pd.fetch" %m * "pd.return" (%m)
* } ... * } ...
* %f = "pd.graph"(%a) { * %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)... * %m = "pd.conv2d"(%a)...
* "pd.fetch" %m * "pd.return" (%m)
* } ... * } ...
* "pd.fetch" %d, %f * "pd.fetch" (%d, %f)
* } * }
* TODO(winter-wang): Supplementary how to judge the operators can be supported * TODO(winter-wang): Supplementary how to judge the operators can be supported
* by tensorrt. * by tensorrt.
*/ */
class trtOpTellerPass class trtOpTellerPass
: public ::mlir::PassWrapper<trtOpTellerPass, ::mlir::FunctionPass> { : public mlir::PassWrapper<trtOpTellerPass, mlir::FunctionPass> {
public: public:
::llvm::StringRef getName() const override { return "trtOpTellerPass"; } ::llvm::StringRef getName() const override { return "trtOpTellerPass"; }
void runOnFunction() override; void runOnFunction() override;
......
...@@ -13,27 +13,25 @@ ...@@ -13,27 +13,25 @@
// limitations under the License. // limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_ops.h" #include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "mlir/IR/Matchers.h" #include <mlir/IR/Matchers.h>
#include "mlir/IR/OpImplementation.h" #include <mlir/IR/OpImplementation.h>
#include "mlir/IR/PatternMatch.h" #include <mlir/IR/PatternMatch.h>
#include "mlir/Interfaces/CallInterfaces.h" #include <mlir/Interfaces/CallInterfaces.h>
#include "mlir/Interfaces/SideEffectInterfaces.h" #include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt { namespace infrt {
namespace trt { namespace trt {
TensorRTDialect::TensorRTDialect(::mlir::MLIRContext *context) TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context)
: ::mlir::Dialect("trt", context, ::mlir::TypeID::get<TensorRTDialect>()) { : mlir::Dialect("trt", context, mlir::TypeID::get<TensorRTDialect>()) {
addOperations< addOperations<
#define GET_OP_LIST #define GET_OP_LIST
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT #include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
>(); >();
#undef GET_OP_LIST
} }
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
#undef GET_OP_CLASSES
} // namespace trt } // namespace trt
} // namespace infrt } // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
...@@ -14,37 +14,32 @@ ...@@ -14,37 +14,32 @@
#pragma once #pragma once
#include "mlir/Dialect/Traits.h" #include <mlir/Dialect/Traits.h>
#include "mlir/IR/Attributes.h" #include <mlir/IR/Attributes.h>
#include "mlir/IR/Builders.h" #include <mlir/IR/Builders.h>
#include "mlir/IR/Dialect.h" #include <mlir/IR/BuiltinOps.h>
#include "mlir/IR/Function.h" #include <mlir/IR/BuiltinTypes.h>
#include "mlir/IR/Matchers.h" #include <mlir/IR/Dialect.h>
#include "mlir/IR/Module.h" #include <mlir/IR/Matchers.h>
#include "mlir/IR/OpImplementation.h" #include <mlir/IR/OpImplementation.h>
#include "mlir/IR/StandardTypes.h" #include <mlir/IR/TypeUtilities.h>
#include "mlir/IR/TypeUtilities.h" #include <mlir/Interfaces/CallInterfaces.h>
#include "mlir/Interfaces/CallInterfaces.h" #include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include "mlir/Interfaces/DerivedAttributeOpInterface.h" #include <mlir/Interfaces/InferTypeOpInterface.h>
#include "mlir/Interfaces/InferTypeOpInterface.h" #include <mlir/Interfaces/LoopLikeInterface.h>
#include "mlir/Interfaces/LoopLikeInterface.h" #include <mlir/Interfaces/SideEffectInterfaces.h>
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace infrt { namespace infrt {
namespace trt { namespace trt {
class TensorRTDialect : public ::mlir::Dialect { class TensorRTDialect : public mlir::Dialect {
public: public:
explicit TensorRTDialect(::mlir::MLIRContext* context); explicit TensorRTDialect(mlir::MLIRContext* context);
static llvm::StringRef getDialectNamespace() { return "trt"; } static llvm::StringRef getDialectNamespace() { return "trt"; }
}; };
// mlir bug。 can be removed safety when update mlir to llvm11. } // namespace trt
using namespace mlir; // NOLINT } // namespace infrt
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.hpp.inc" #include "paddle/infrt/dialect/tensorrt/trt_ops.hpp.inc"
#undef GET_OP_CLASSES
} // namespace trt
} // namespace infrt
...@@ -14,14 +14,13 @@ ...@@ -14,14 +14,13 @@
#include "paddle/infrt/dialect/test_kernels.h" #include "paddle/infrt/dialect/test_kernels.h"
#include "mlir/IR/Builders.h" #include <mlir/IR/Builders.h>
#include "mlir/IR/OpDefinition.h" #include <mlir/IR/OpDefinition.h>
#include "mlir/IR/OpImplementation.h" #include <mlir/IR/OpImplementation.h>
#include "mlir/IR/StandardTypes.h" #include <mlir/IR/TypeUtilities.h>
#include "mlir/IR/TypeUtilities.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// BenchmarkOp // BenchmarkOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
...@@ -32,65 +31,67 @@ namespace infrt::dialect { ...@@ -32,65 +31,67 @@ namespace infrt::dialect {
// ... // ...
// } // }
static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT static mlir::ParseResult parseBenchmarkOp(
OperationState &result) { // NOLINT mlir::OpAsmParser &parser, // NOLINT
StringAttr nameAttr; mlir::OperationState &result) { // NOLINT
mlir::StringAttr nameAttr;
if (parser.parseAttribute(nameAttr, "name", result.attributes)) if (parser.parseAttribute(nameAttr, "name", result.attributes))
return failure(); return mlir::failure();
// Parse the operands, e.g. (%c : i32, %d : f32) // Parse the operands, e.g. (%c : i32, %d : f32)
if (parser.parseLParen()) return failure(); if (parser.parseLParen()) return mlir::failure();
SmallVector<OpAsmParser::OperandType, 4> operands; llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands;
SmallVector<Type, 4> types; llvm::SmallVector<mlir::Type, 4> types;
llvm::SMLoc type_loc = parser.getCurrentLocation(); llvm::SMLoc type_loc = parser.getCurrentLocation();
if (parser.parseOptionalRParen()) { if (parser.parseOptionalRParen()) {
// Parse non-empty operands // Parse non-empty operands
do { do {
// Parse %c : i32, // Parse %c : i32,
OpAsmParser::OperandType operand; mlir::OpAsmParser::OperandType operand;
Type type; mlir::Type type;
if (parser.parseOperand(operand) || parser.parseColonType(type)) if (parser.parseOperand(operand) || parser.parseColonType(type))
return failure(); return mlir::failure();
operands.push_back(operand); operands.push_back(operand);
types.push_back(type); types.push_back(type);
} while (succeeded(parser.parseOptionalComma())); } while (succeeded(parser.parseOptionalComma()));
if (parser.parseRParen()) return failure(); if (parser.parseRParen()) return mlir::failure();
} }
if (parser.resolveOperands(operands, types, type_loc, result.operands)) if (parser.resolveOperands(operands, types, type_loc, result.operands))
return failure(); return mlir::failure();
// Parse the keyword attribute, e.g. max_count = 100, duration_secs = 1 // Parse the keyword attribute, e.g. max_count = 100, duration_secs = 1
do { do {
StringRef attr; mlir::StringRef attr;
Attribute resultAttr; mlir::Attribute resultAttr;
if (parser.parseKeyword(&attr) || parser.parseEqual() || if (parser.parseKeyword(&attr) || parser.parseEqual() ||
parser.parseAttribute(resultAttr, parser.parseAttribute(resultAttr,
parser.getBuilder().getIntegerType(32), parser.getBuilder().getIntegerType(32),
attr, attr,
result.attributes)) result.attributes))
return failure(); return mlir::failure();
} while (succeeded(parser.parseOptionalComma())); } while (mlir::succeeded(parser.parseOptionalComma()));
// Set the default attribute num_warmup_runs to 1 if unset // Set the default attribute num_warmup_runs to 1 if unset
auto setDefaultAttrIfUnset = [&](const char *attr_name, int value) { auto setDefaultAttrIfUnset = [&](const char *attr_name, int value) {
bool found = llvm::any_of(result.attributes, bool found = llvm::any_of(result.attributes,
[attr_name](const NamedAttribute &attr) { [attr_name](const mlir::NamedAttribute &attr) {
return attr.first == attr_name; return attr.getName() == attr_name;
}); });
if (!found) { if (!found) {
IntegerAttr default_val = parser.getBuilder().getI32IntegerAttr(value); mlir::IntegerAttr default_val =
parser.getBuilder().getI32IntegerAttr(value);
result.addAttribute(attr_name, default_val); result.addAttribute(attr_name, default_val);
} }
}; };
setDefaultAttrIfUnset("num_warmup_runs", 1); setDefaultAttrIfUnset("num_warmup_runs", 1);
Region *target = result.addRegion(); mlir::Region *target = result.addRegion();
return parser.parseRegion(*target, return parser.parseRegion(*target,
operands, operands,
types, types,
...@@ -102,11 +103,11 @@ static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT ...@@ -102,11 +103,11 @@ static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT
// max_count = 100, duration_secs = 1 { // max_count = 100, duration_secs = 1 {
// ... // ...
// } // }
static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT static void print(mlir::OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
p << "infrt.benchmark "; p << "infrt.benchmark ";
// Print the name attribute, e.g "add.i32" // Print the name attribute, e.g "add.i32"
auto name_attr = op.getAttr("name"); auto name_attr = op->getAttr("name");
p << name_attr; p << name_attr;
// Print the operands and types, e.g. (%c : i32, %d : f32) // Print the operands and types, e.g. (%c : i32, %d : f32)
...@@ -120,13 +121,13 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT ...@@ -120,13 +121,13 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
bool need_comma = false; bool need_comma = false;
// Print the attributes, e.g. max_count = 100, duration_secs = 1 // Print the attributes, e.g. max_count = 100, duration_secs = 1
for (auto &name_attr : op.getAttrs()) { for (auto &name_attr : op->getAttrs()) {
auto id = name_attr.first; auto id = name_attr.getName();
if (id == "name") continue; if (id == "name") continue;
if (need_comma) p << ", "; if (need_comma) p << ", ";
auto attr = name_attr.second; auto attr = name_attr.getValue();
p << id << " = "; p << id << " = ";
if (auto int_attr = attr.dyn_cast<IntegerAttr>()) { if (auto int_attr = attr.dyn_cast<mlir::IntegerAttr>()) {
int_attr.getValue().print(p.getStream(), /*isSigned=*/false); int_attr.getValue().print(p.getStream(), /*isSigned=*/false);
} else { } else {
op.emitOpError("Unexpected attribute"); op.emitOpError("Unexpected attribute");
...@@ -142,7 +143,7 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT ...@@ -142,7 +143,7 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
p.printRegion(op.region(), /*printEntryBlockArgs=*/false); p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
} }
static LogicalResult verify(BenchmarkOp op) { static mlir::LogicalResult verify(BenchmarkOp op) {
// Verify that the target benchmark region has exactly one return value. // Verify that the target benchmark region has exactly one return value.
auto &region = op.region(); auto &region = op.region();
auto &last_op = region.front().back(); auto &last_op = region.front().back();
...@@ -154,10 +155,10 @@ static LogicalResult verify(BenchmarkOp op) { ...@@ -154,10 +155,10 @@ static LogicalResult verify(BenchmarkOp op) {
"incorrect number of return values. One return value is expected"); "incorrect number of return values. One return value is expected");
} }
return success(); return mlir::success();
} }
} // namespace dialect
} // namespace infrt
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.cpp.inc" #include "paddle/infrt/dialect/test_kernels.cpp.inc"
} // namespace infrt::dialect
...@@ -13,11 +13,8 @@ ...@@ -13,11 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "mlir/IR/OpDefinition.h" #include <mlir/IR/OpDefinition.h>
#include "mlir/Interfaces/SideEffectInterfaces.h" #include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt::dialect {
using namespace mlir; // NOLINT
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.hpp.inc" #include "paddle/infrt/dialect/test_kernels.hpp.inc"
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/types.h"
namespace infrt::hlir::mlir {} // namespace infrt::hlir::mlir
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/StandardTypes.h>
...@@ -23,7 +23,8 @@ ...@@ -23,7 +23,8 @@
#include "paddle/infrt/host_context/op_executable.h" #include "paddle/infrt/host_context/op_executable.h"
#include "paddle/infrt/host_context/symbol_table.h" #include "paddle/infrt/host_context/symbol_table.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
struct CoreRuntime::Impl { struct CoreRuntime::Impl {
KernelRegistry* kernel_registry{}; KernelRegistry* kernel_registry{};
...@@ -90,4 +91,5 @@ llvm::SmallVector<ValueRef, 4> CoreRuntime::GetResults( ...@@ -90,4 +91,5 @@ llvm::SmallVector<ValueRef, 4> CoreRuntime::GetResults(
CoreRuntime::~CoreRuntime() {} CoreRuntime::~CoreRuntime() {}
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
#include "paddle/infrt/host_context/value.h" #include "paddle/infrt/host_context/value.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
class KernelRegistry; class KernelRegistry;
class OpExecutable; class OpExecutable;
...@@ -83,4 +84,5 @@ class CoreRuntimeBuilder : public CoreRuntime { ...@@ -83,4 +84,5 @@ class CoreRuntimeBuilder : public CoreRuntime {
OpExecutableBuilder* NewOpExecutable(const std::string& op_name); OpExecutableBuilder* NewOpExecutable(const std::string& op_name);
}; };
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -21,7 +21,8 @@ ...@@ -21,7 +21,8 @@
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "paddle/infrt/host_context/value.h" #include "paddle/infrt/host_context/value.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
/** /**
* KernelFrame captures the states(input arguments, attributes, results) * KernelFrame captures the states(input arguments, attributes, results)
...@@ -163,4 +164,5 @@ class KernelFrameBuilder : public KernelFrame { ...@@ -163,4 +164,5 @@ class KernelFrameBuilder : public KernelFrame {
} }
}; };
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
#include "paddle/infrt/host_context/kernel_utils.h" #include "paddle/infrt/host_context/kernel_utils.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
int add_i32(int a, int b) { return a + b; } int add_i32(int a, int b) { return a + b; }
...@@ -44,4 +45,5 @@ TEST(KernelRegistry, basic) { ...@@ -44,4 +45,5 @@ TEST(KernelRegistry, basic) {
ASSERT_EQ(results[0]->get<int>(), 3); ASSERT_EQ(results[0]->get<int>(), 3);
} }
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
namespace infrt::host_context { namespace infrt {
namespace host_context {
int add_i32(int a, int b) { return a + b; } int add_i32(int a, int b) { return a + b; }
float add_f32(float a, float b) { return a + b; } float add_f32(float a, float b) { return a + b; }
...@@ -66,4 +67,5 @@ TEST(KernelImpl, pair) { ...@@ -66,4 +67,5 @@ TEST(KernelImpl, pair) {
ASSERT_EQ(results[1]->get<float>(), 3.f); ASSERT_EQ(results[1]->get<float>(), 3.f);
} }
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/infrt/host_context/mlir_function_executable.h" #include "paddle/infrt/host_context/mlir_function_executable.h"
#include <glog/logging.h> #include <glog/logging.h>
#include <mlir/IR/BuiltinOps.h>
#include <string> // NOLINT #include <string> // NOLINT
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <mlir/IR/Function.h> #include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Region.h>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
......
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
#pragma once #pragma once
#include <mlir/Dialect/StandardOps/IR/Ops.h> #include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h> #include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OperationSupport.h> #include <mlir/IR/OperationSupport.h>
#include <unordered_map> #include <unordered_map>
......
...@@ -16,8 +16,9 @@ ...@@ -16,8 +16,9 @@
#include <llvm/Support/SourceMgr.h> #include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h> #include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h> #include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OperationSupport.h> #include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h> #include <mlir/Parser.h>
...@@ -40,7 +41,8 @@ ...@@ -40,7 +41,8 @@
#include "paddle/infrt/host_context/value.h" #include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/tensor/tensor_shape.h" #include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
template <typename T> template <typename T>
std::string DumpToString(T& op) { // NOLINT std::string DumpToString(T& op) { // NOLINT
...@@ -113,10 +115,10 @@ bool MlirToRuntimeTranslator::EmitConstantOp(mlir::Operation* op) { ...@@ -113,10 +115,10 @@ bool MlirToRuntimeTranslator::EmitConstantOp(mlir::Operation* op) {
template <> template <>
boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute( boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::IntegerAttr>()) return boost::none; if (!attr.isa<mlir::IntegerAttr>()) return boost::none;
if (attr->isa<mlir::IntegerAttr>()) { if (attr.isa<mlir::IntegerAttr>()) {
auto val = attr->cast<mlir::IntegerAttr>(); auto val = attr.cast<mlir::IntegerAttr>();
if (val.getType().isInteger(32)) { if (val.getType().isInteger(32)) {
return val.getInt(); return val.getInt();
} }
...@@ -125,10 +127,10 @@ boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute( ...@@ -125,10 +127,10 @@ boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute(
} }
template <> template <>
boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute( boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::IntegerAttr>()) return boost::none; if (!attr.isa<mlir::IntegerAttr>()) return boost::none;
if (attr->isa<mlir::IntegerAttr>()) { if (attr.isa<mlir::IntegerAttr>()) {
auto val = attr->cast<mlir::IntegerAttr>(); auto val = attr.cast<mlir::IntegerAttr>();
if (val.getType().isInteger(64)) { if (val.getType().isInteger(64)) {
return val.getInt(); return val.getInt();
} }
...@@ -139,10 +141,10 @@ boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute( ...@@ -139,10 +141,10 @@ boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute(
// TODO(Superjomn) Make double and float parsing share some thing. // TODO(Superjomn) Make double and float parsing share some thing.
template <> template <>
boost::optional<float> MlirToRuntimeTranslator::EmitAttribute( boost::optional<float> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::FloatAttr>()) return boost::none; if (!attr.isa<mlir::FloatAttr>()) return boost::none;
if (attr->isa<mlir::FloatAttr>()) { if (attr.isa<mlir::FloatAttr>()) {
auto val = attr->cast<mlir::FloatAttr>(); auto val = attr.cast<mlir::FloatAttr>();
if (val.getType().isF32()) return val.getValueAsDouble(); if (val.getType().isF32()) return val.getValueAsDouble();
} }
return boost::none; return boost::none;
...@@ -150,10 +152,10 @@ boost::optional<float> MlirToRuntimeTranslator::EmitAttribute( ...@@ -150,10 +152,10 @@ boost::optional<float> MlirToRuntimeTranslator::EmitAttribute(
template <> template <>
boost::optional<double> MlirToRuntimeTranslator::EmitAttribute( boost::optional<double> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::FloatAttr>()) return boost::none; if (!attr.isa<mlir::FloatAttr>()) return boost::none;
if (attr->isa<mlir::FloatAttr>()) { if (attr.isa<mlir::FloatAttr>()) {
auto val = attr->cast<mlir::FloatAttr>(); auto val = attr.cast<mlir::FloatAttr>();
if (val.getType().isF64()) return val.getValueAsDouble(); if (val.getType().isF64()) return val.getValueAsDouble();
} }
return boost::none; return boost::none;
...@@ -161,17 +163,17 @@ boost::optional<double> MlirToRuntimeTranslator::EmitAttribute( ...@@ -161,17 +163,17 @@ boost::optional<double> MlirToRuntimeTranslator::EmitAttribute(
template <> template <>
boost::optional<std::string> MlirToRuntimeTranslator::EmitAttribute( boost::optional<std::string> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::StringAttr>()) return boost::none; if (!attr.isa<mlir::StringAttr>()) return boost::none;
return attr->cast<mlir::StringAttr>().getValue().str(); return attr.cast<mlir::StringAttr>().getValue().str();
} }
#define PROCESS_ARRAY_INT(type__, bits__) \ #define PROCESS_ARRAY_INT(type__, bits__) \
template <> \ template <> \
boost::optional<std::vector<type__>> MlirToRuntimeTranslator::EmitAttribute( \ boost::optional<std::vector<type__>> MlirToRuntimeTranslator::EmitAttribute( \
const mlir::Attribute* attr) { \ const mlir::Attribute& attr) { \
if (!attr->isa<mlir::ArrayAttr>()) return boost::none; \ if (!attr.isa<mlir::ArrayAttr>()) return boost::none; \
auto array = attr->cast<mlir::ArrayAttr>(); \ auto array = attr.cast<mlir::ArrayAttr>(); \
CHECK(!array.empty()); \ CHECK(!array.empty()); \
\ \
if (!array[0].getType().isInteger(bits__)) { \ if (!array[0].getType().isInteger(bits__)) { \
...@@ -191,9 +193,9 @@ PROCESS_ARRAY_INT(int64_t, 64); ...@@ -191,9 +193,9 @@ PROCESS_ARRAY_INT(int64_t, 64);
template <> template <>
boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute( boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::ArrayAttr>()) return boost::none; if (!attr.isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr->cast<mlir::ArrayAttr>(); auto array = attr.cast<mlir::ArrayAttr>();
CHECK(!array.empty()); CHECK(!array.empty());
if (!array[0].getType().isF32()) return boost::none; if (!array[0].getType().isF32()) return boost::none;
...@@ -207,9 +209,9 @@ boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute( ...@@ -207,9 +209,9 @@ boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute(
template <> template <>
boost::optional<std::vector<double>> MlirToRuntimeTranslator::EmitAttribute( boost::optional<std::vector<double>> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) { const mlir::Attribute& attr) {
if (!attr->isa<mlir::ArrayAttr>()) return boost::none; if (!attr.isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr->cast<mlir::ArrayAttr>(); auto array = attr.cast<mlir::ArrayAttr>();
CHECK(!array.empty()); CHECK(!array.empty());
if (!array[0].getType().isF64()) return boost::none; if (!array[0].getType().isF64()) return boost::none;
...@@ -236,7 +238,8 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { ...@@ -236,7 +238,8 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
for (int i = 0, e = op->getNumOperands(); i < e; i++) { for (int i = 0, e = op->getNumOperands(); i < e; i++) {
// function argument as value // function argument as value
auto operand = op->getOperand(i); auto operand = op->getOperand(i);
if (operand.getKind() == mlir::Value::Kind::BlockArgument) { /// if (operand.getKind() == mlir::Value::Kind::BlockArgument) {
if (operand.isa<mlir::BlockArgument>()) {
mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>(); mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>();
Value* arg_value = GetValue(arg); Value* arg_value = GetValue(arg);
impl_->cur_op->AppendArgument(arg_value); impl_->cur_op->AppendArgument(arg_value);
...@@ -283,25 +286,25 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { ...@@ -283,25 +286,25 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
for (size_t i = 0; i < attrs.size(); i++) { for (size_t i = 0; i < attrs.size(); i++) {
auto& attr = attrs[i]; auto& attr = attrs[i];
if (auto v = EmitAttribute<int32_t>(&attr.second)) { if (auto v = EmitAttribute<int32_t>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v)); impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<int64_t>(&attr.second)) { } else if (auto v = EmitAttribute<int64_t>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v)); impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<float>(&attr.second)) { } else if (auto v = EmitAttribute<float>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v)); impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<double>(&attr.second)) { } else if (auto v = EmitAttribute<double>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v)); impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<std::string>(&attr.second)) { } else if (auto v = EmitAttribute<std::string>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int16_t>>(&attr.second)) { } else if (auto v = EmitAttribute<std::vector<int16_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int32_t>>(&attr.second)) { } else if (auto v = EmitAttribute<std::vector<int32_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int64_t>>(&attr.second)) { } else if (auto v = EmitAttribute<std::vector<int64_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<float>>(&attr.second)) { } else if (auto v = EmitAttribute<std::vector<float>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<double>>(&attr.second)) { } else if (auto v = EmitAttribute<std::vector<double>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v))); impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else { } else {
LOG(FATAL) << "Not supported attribute type"; LOG(FATAL) << "Not supported attribute type";
...@@ -330,7 +333,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) { ...@@ -330,7 +333,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
llvm::SmallVector<mlir::Type, 0> results; llvm::SmallVector<mlir::Type, 0> results;
auto func_type = auto func_type =
mlir::FunctionType::get(inputs, results, region.getContext()); mlir::FunctionType::get(region.getContext(), inputs, results);
auto* function = impl_->cur_op->CreateFunctionExecutable( auto* function = impl_->cur_op->CreateFunctionExecutable(
&region, func_type, &impl_->func_defs); &region, func_type, &impl_->func_defs);
impl_->cur_op->AppendAttribute(new Value(function)); impl_->cur_op->AppendAttribute(new Value(function));
...@@ -555,4 +558,5 @@ void TestMlir(mlir::ModuleOp module, KernelRegistry* registry) { ...@@ -555,4 +558,5 @@ void TestMlir(mlir::ModuleOp module, KernelRegistry* registry) {
execute.Run(); execute.Run();
} }
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -29,7 +29,8 @@ class Attribute; ...@@ -29,7 +29,8 @@ class Attribute;
class Value; class Value;
} // namespace mlir } // namespace mlir
namespace infrt::host_context { namespace infrt {
namespace host_context {
class CoreRuntimeBuilder; class CoreRuntimeBuilder;
class Value; class Value;
...@@ -73,7 +74,7 @@ class MlirToRuntimeTranslator { ...@@ -73,7 +74,7 @@ class MlirToRuntimeTranslator {
bool EmitCallOp(mlir::Operation* op, function_defs_t* function_table); bool EmitCallOp(mlir::Operation* op, function_defs_t* function_table);
template <typename T> template <typename T>
boost::optional<T> EmitAttribute(const mlir::Attribute* attr); boost::optional<T> EmitAttribute(const mlir::Attribute& attr);
Value* GetOpResult(mlir::Operation* op); Value* GetOpResult(mlir::Operation* op);
...@@ -104,4 +105,5 @@ void MlirToRuntimeTranslate(mlir::ModuleOp module, CoreRuntimeBuilder* runtime); ...@@ -104,4 +105,5 @@ void MlirToRuntimeTranslate(mlir::ModuleOp module, CoreRuntimeBuilder* runtime);
*/ */
void TestMlir(mlir::ModuleOp module, KernelRegistry* registry); void TestMlir(mlir::ModuleOp module, KernelRegistry* registry);
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -29,7 +29,8 @@ ...@@ -29,7 +29,8 @@
#include "paddle/infrt/kernel/tensor_shape_kernels.h" #include "paddle/infrt/kernel/tensor_shape_kernels.h"
#include "paddle/infrt/kernel/test_kernels.h" #include "paddle/infrt/kernel/test_kernels.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
TEST(MlirToRuntimeTranslate, basic) { TEST(MlirToRuntimeTranslate, basic) {
mlir::MLIRContext context; mlir::MLIRContext context;
...@@ -48,7 +49,7 @@ func @main() -> () { ...@@ -48,7 +49,7 @@ func @main() -> () {
)ROC"; )ROC";
auto module = dialect::LoadMlirSource(&context, source); auto module = dialect::LoadMlirSource(&context, source);
module->verify(); EXPECT_TRUE(mlir::succeeded(module->verify()));
KernelRegistry registry; KernelRegistry registry;
kernel::RegisterFloatBasicKernels(&registry); kernel::RegisterFloatBasicKernels(&registry);
...@@ -74,7 +75,7 @@ func @main() -> () { ...@@ -74,7 +75,7 @@ func @main() -> () {
)ROC"; )ROC";
auto module = dialect::LoadMlirSource(&context, source); auto module = dialect::LoadMlirSource(&context, source);
module->verify(); EXPECT_TRUE(mlir::succeeded(module->verify()));
KernelRegistry registry; KernelRegistry registry;
kernel::RegisterFloatBasicKernels(&registry); kernel::RegisterFloatBasicKernels(&registry);
...@@ -115,7 +116,7 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F ...@@ -115,7 +116,7 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F
// LOG(INFO) << "content: " << content << std::endl; // LOG(INFO) << "content: " << content << std::endl;
auto module = dialect::LoadMlirSource(context, content); auto module = dialect::LoadMlirSource(context, content);
module->verify(); EXPECT_TRUE(mlir::succeeded(module->verify()));
host_context::KernelRegistry registry; host_context::KernelRegistry registry;
...@@ -157,4 +158,5 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F ...@@ -157,4 +158,5 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F
} }
} }
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/infrt/host_context/op_executable.h" #include "paddle/infrt/host_context/op_executable.h"
#include <mlir/IR/BuiltinOps.h>
#include <string> #include <string>
#include "paddle/infrt/host_context/kernel_frame.h" #include "paddle/infrt/host_context/kernel_frame.h"
...@@ -21,7 +22,8 @@ ...@@ -21,7 +22,8 @@
#include "paddle/infrt/host_context/mlir_function_executable.h" #include "paddle/infrt/host_context/mlir_function_executable.h"
#include "paddle/infrt/host_context/symbol_table.h" #include "paddle/infrt/host_context/symbol_table.h"
namespace infrt::host_context { namespace infrt {
namespace host_context {
struct OpExecutable::Impl { struct OpExecutable::Impl {
Impl(const std::string& op_name, Impl(const std::string& op_name,
...@@ -148,4 +150,5 @@ void OpExecutable::Execute() { ...@@ -148,4 +150,5 @@ void OpExecutable::Execute() {
OpExecutable::~OpExecutable() {} OpExecutable::~OpExecutable() {}
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -14,19 +14,18 @@ ...@@ -14,19 +14,18 @@
#pragma once #pragma once
#include <llvm/ADT/ArrayRef.h> #include <llvm/ADT/ArrayRef.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Region.h>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "mlir/IR/Function.h"
#include "mlir/IR/Region.h"
namespace mlir { namespace mlir {
class FuncOp; class FuncOp;
} // namespace mlir } // namespace mlir
namespace infrt::host_context { namespace infrt {
namespace host_context {
class SymbolTable; class SymbolTable;
class KernelRegistry; class KernelRegistry;
...@@ -89,4 +88,5 @@ class OpExecutableBuilder : public OpExecutable { ...@@ -89,4 +88,5 @@ class OpExecutableBuilder : public OpExecutable {
function_defs_t* function_defs); function_defs_t* function_defs);
}; };
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
...@@ -23,7 +23,8 @@ ...@@ -23,7 +23,8 @@
using infrt::host_context::Attribute; using infrt::host_context::Attribute;
namespace infrt::kernel { namespace infrt {
namespace kernel {
template <typename T> template <typename T>
T add(T a, T b) { T add(T a, T b) {
...@@ -82,4 +83,5 @@ void RegisterFloatBasicKernels(host_context::KernelRegistry *registry) { ...@@ -82,4 +83,5 @@ void RegisterFloatBasicKernels(host_context::KernelRegistry *registry) {
registry->AddKernel("infrt.print.f32", INFRT_KERNEL(print<float>)); registry->AddKernel("infrt.print.f32", INFRT_KERNEL(print<float>));
} }
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -15,13 +15,16 @@ ...@@ -15,13 +15,16 @@
#pragma once #pragma once
#include <string> #include <string>
namespace infrt::host_context { namespace infrt {
namespace host_context {
struct KernelRegistry; struct KernelRegistry;
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
namespace infrt::kernel { namespace infrt {
namespace kernel {
/** /**
* Register all the basic kernels to \p registry. * Register all the basic kernels to \p registry.
...@@ -31,4 +34,5 @@ void RegisterBasicKernels(host_context::KernelRegistry* registry); ...@@ -31,4 +34,5 @@ void RegisterBasicKernels(host_context::KernelRegistry* registry);
void RegisterIntBasicKernels(host_context::KernelRegistry* registry); void RegisterIntBasicKernels(host_context::KernelRegistry* registry);
void RegisterFloatBasicKernels(host_context::KernelRegistry* registry); void RegisterFloatBasicKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -25,7 +25,8 @@ ...@@ -25,7 +25,8 @@
#include "paddle/infrt/tensor/tensor_map.h" #include "paddle/infrt/tensor/tensor_map.h"
#include "paddle/infrt/tensor/tensor_shape.h" #include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::kernel { namespace infrt {
namespace kernel {
using namespace host_context; // NOLINT using namespace host_context; // NOLINT
using namespace tensor; // NOLINT using namespace tensor; // NOLINT
...@@ -76,4 +77,5 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) { ...@@ -76,4 +77,5 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
INFRT_KERNEL(ShallowCopyTensor)); INFRT_KERNEL(ShallowCopyTensor));
} }
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -14,12 +14,16 @@ ...@@ -14,12 +14,16 @@
#pragma once #pragma once
namespace infrt::host_context { namespace infrt {
namespace host_context {
struct KernelRegistry; struct KernelRegistry;
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
namespace infrt::kernel { namespace infrt {
namespace kernel {
void RegisterTensorKernels(host_context::KernelRegistry* registry); void RegisterTensorKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -24,7 +24,8 @@ ...@@ -24,7 +24,8 @@
#include "paddle/infrt/host_context/kernel_utils.h" #include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/tensor/tensor_shape.h" #include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::kernel { namespace infrt {
namespace kernel {
void PrintShape(const tensor::TensorShape& shape) { void PrintShape(const tensor::TensorShape& shape) {
llvm::raw_os_ostream oos(std::cout); llvm::raw_os_ostream oos(std::cout);
...@@ -35,4 +36,5 @@ void RegisterTensorShapeKernels(host_context::KernelRegistry* registry) { ...@@ -35,4 +36,5 @@ void RegisterTensorShapeKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("ts.print_shape", INFRT_KERNEL(PrintShape)); registry->AddKernel("ts.print_shape", INFRT_KERNEL(PrintShape));
} }
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -14,14 +14,18 @@ ...@@ -14,14 +14,18 @@
#pragma once #pragma once
namespace infrt::host_context { namespace infrt {
namespace host_context {
class KernelRegistry; class KernelRegistry;
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
namespace infrt::kernel { namespace infrt {
namespace kernel {
void RegisterTensorShapeKernels(host_context::KernelRegistry* registry); void RegisterTensorShapeKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -33,7 +33,8 @@ using infrt::host_context::Attribute; ...@@ -33,7 +33,8 @@ using infrt::host_context::Attribute;
using infrt::host_context::MlirFunctionExecutable; using infrt::host_context::MlirFunctionExecutable;
using infrt::host_context::RemainingArguments; using infrt::host_context::RemainingArguments;
namespace infrt::kernel { namespace infrt {
namespace kernel {
namespace { namespace {
class BenchmarkStats { class BenchmarkStats {
public: public:
...@@ -197,4 +198,5 @@ void RegisterTestKernels(host_context::KernelRegistry *registry) { ...@@ -197,4 +198,5 @@ void RegisterTestKernels(host_context::KernelRegistry *registry) {
INFRT_KERNEL(ShadowCopyTensor)); INFRT_KERNEL(ShadowCopyTensor));
} }
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -15,17 +15,21 @@ ...@@ -15,17 +15,21 @@
#pragma once #pragma once
#include <string> #include <string>
namespace infrt::host_context { namespace infrt {
namespace host_context {
struct KernelRegistry; struct KernelRegistry;
} // namespace infrt::host_context } // namespace host_context
} // namespace infrt
namespace infrt::kernel { namespace infrt {
namespace kernel {
/** /**
* Register all the test kernels to registry. * Register all the test kernels to registry.
*/ */
void RegisterTestKernels(host_context::KernelRegistry* registry); void RegisterTestKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel } // namespace kernel
} // namespace infrt
...@@ -18,7 +18,9 @@ ...@@ -18,7 +18,9 @@
#include <string> #include <string>
#include <vector> #include <vector>
namespace infrt::paddle::cpp { namespace infrt {
namespace paddle {
namespace cpp {
/* /*
* Compatible interfaces for all the different kinds of XXXDesc. All the XXXDesc * Compatible interfaces for all the different kinds of XXXDesc. All the XXXDesc
...@@ -226,4 +228,6 @@ class ProgramDescAPI { ...@@ -226,4 +228,6 @@ class ProgramDescAPI {
virtual void SetVersion(int64_t version) = 0; virtual void SetVersion(int64_t version) = 0;
}; };
} // namespace infrt::paddle::cpp } // namespace cpp
} // namespace paddle
} // namespace infrt
...@@ -22,7 +22,8 @@ ...@@ -22,7 +22,8 @@
#include "paddle/infrt/common/target.h" #include "paddle/infrt/common/target.h"
#include "paddle/infrt/common/type.h" #include "paddle/infrt/common/type.h"
namespace infrt::paddle { namespace infrt {
namespace paddle {
int SizeOfType(framework_proto::VarType::Type type) { int SizeOfType(framework_proto::VarType::Type type) {
using Type = framework_proto::VarType::Type; using Type = framework_proto::VarType::Type;
...@@ -169,4 +170,5 @@ void LoadParam(const std::string &path, _Variable *out, const Target &target) { ...@@ -169,4 +170,5 @@ void LoadParam(const std::string &path, _Variable *out, const Target &target) {
LoadLoDTensor(fin, out, target); LoadLoDTensor(fin, out, target);
} }
} // namespace infrt::paddle } // namespace paddle
} // namespace infrt
...@@ -25,7 +25,8 @@ ...@@ -25,7 +25,8 @@
#include "paddle/infrt/paddle/scope.h" #include "paddle/infrt/paddle/scope.h"
#include "paddle/infrt/paddle/tensor.h" #include "paddle/infrt/paddle/tensor.h"
namespace infrt::paddle { namespace infrt {
namespace paddle {
namespace framework_proto = ::paddle::framework::proto; namespace framework_proto = ::paddle::framework::proto;
// Read a __model__ file. // Read a __model__ file.
...@@ -52,4 +53,5 @@ void TensorFromStream( ...@@ -52,4 +53,5 @@ void TensorFromStream(
const common::Target& target = common::DefaultHostTarget()); const common::Target& target = common::DefaultHostTarget());
void ReadBinaryFile(const std::string& filename, std::string* contents); void ReadBinaryFile(const std::string& filename, std::string* contents);
} // namespace infrt::paddle } // namespace paddle
} // namespace infrt
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#include "paddle/infrt/paddle/pb/block_desc.h" #include "paddle/infrt/paddle/pb/block_desc.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
template <> template <>
framework_proto::VarDesc* BlockDesc::GetVar<framework_proto::VarDesc>( framework_proto::VarDesc* BlockDesc::GetVar<framework_proto::VarDesc>(
...@@ -40,4 +42,6 @@ framework_proto::OpDesc* BlockDesc::AddOp<framework_proto::OpDesc>() { ...@@ -40,4 +42,6 @@ framework_proto::OpDesc* BlockDesc::AddOp<framework_proto::OpDesc>() {
return desc_->add_ops(); return desc_->add_ops();
} }
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -18,7 +18,9 @@ ...@@ -18,7 +18,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h" #include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto; namespace framework_proto = ::paddle::framework::proto;
...@@ -74,4 +76,6 @@ class BlockDesc : public cpp::BlockDescAPI { ...@@ -74,4 +76,6 @@ class BlockDesc : public cpp::BlockDescAPI {
framework_proto::BlockDesc* desc_; // not_own framework_proto::BlockDesc* desc_; // not_own
}; };
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -14,7 +14,9 @@ ...@@ -14,7 +14,9 @@
#include "paddle/infrt/paddle/pb/op_desc.h" #include "paddle/infrt/paddle/pb/op_desc.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
google::protobuf::internal::RepeatedPtrIterator<framework_proto::OpDesc_Attr> google::protobuf::internal::RepeatedPtrIterator<framework_proto::OpDesc_Attr>
FindAttr(framework_proto::OpDesc *desc, const std::string &name) { FindAttr(framework_proto::OpDesc *desc, const std::string &name) {
...@@ -136,4 +138,6 @@ GET_ATTRS_IMPL(std::vector<std::string>, strings); ...@@ -136,4 +138,6 @@ GET_ATTRS_IMPL(std::vector<std::string>, strings);
GET_ATTR_IMPL(std::string, s); GET_ATTR_IMPL(std::string, s);
GET_ATTRS_IMPL(std::vector<int64_t>, longs); GET_ATTRS_IMPL(std::vector<int64_t>, longs);
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -19,7 +19,9 @@ ...@@ -19,7 +19,9 @@
#include "paddle/infrt/paddle/framework.pb.h" #include "paddle/infrt/paddle/framework.pb.h"
#include "paddle/infrt/support/variant.h" #include "paddle/infrt/support/variant.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto; namespace framework_proto = ::paddle::framework::proto;
...@@ -195,4 +197,6 @@ template <> ...@@ -195,4 +197,6 @@ template <>
void OpDesc::SetAttr<std::vector<int>>(const std::string &name, void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
const std::vector<int> &v); const std::vector<int> &v);
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -17,7 +17,9 @@ ...@@ -17,7 +17,9 @@
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
template <> template <>
framework_proto::BlockDesc* ProgramDesc::GetBlock<framework_proto::BlockDesc>( framework_proto::BlockDesc* ProgramDesc::GetBlock<framework_proto::BlockDesc>(
...@@ -32,4 +34,6 @@ ProgramDesc::AddBlock<framework_proto::BlockDesc>() { ...@@ -32,4 +34,6 @@ ProgramDesc::AddBlock<framework_proto::BlockDesc>() {
return desc_->add_blocks(); return desc_->add_blocks();
} }
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -21,7 +21,9 @@ ...@@ -21,7 +21,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h" #include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto; namespace framework_proto = ::paddle::framework::proto;
class ProgramDesc : public cpp::ProgramDescAPI { class ProgramDesc : public cpp::ProgramDescAPI {
...@@ -58,4 +60,6 @@ class ProgramDesc : public cpp::ProgramDescAPI { ...@@ -58,4 +60,6 @@ class ProgramDesc : public cpp::ProgramDescAPI {
framework_proto::ProgramDesc *desc_; // not_own framework_proto::ProgramDesc *desc_; // not_own
}; };
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -19,7 +19,9 @@ ...@@ -19,7 +19,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h" #include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
cpp::VarDescAPI::Type VarDesc::GetType() const { cpp::VarDescAPI::Type VarDesc::GetType() const {
auto type = desc_->type().type(); auto type = desc_->type().type();
...@@ -364,4 +366,6 @@ VarDesc::mutable_tensor_descs() { ...@@ -364,4 +366,6 @@ VarDesc::mutable_tensor_descs() {
return std::vector<framework_proto::VarType::TensorDesc *>(); return std::vector<framework_proto::VarType::TensorDesc *>();
} }
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
...@@ -23,7 +23,9 @@ ...@@ -23,7 +23,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h" #include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h" #include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb { namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto; namespace framework_proto = ::paddle::framework::proto;
// convert between std::vector and protobuf repeated. // convert between std::vector and protobuf repeated.
...@@ -121,4 +123,6 @@ class VarDesc : public cpp::VarDescAPI { ...@@ -121,4 +123,6 @@ class VarDesc : public cpp::VarDescAPI {
framework_proto::VarDesc *desc_; framework_proto::VarDesc *desc_;
}; };
} // namespace infrt::paddle::pb } // namespace pb
} // namespace paddle
} // namespace infrt
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册