Unverified commit 0de8a805, authored by 王明冬, committed by GitHub

[infrt] update the version of llvm. test=develop (#38843)

Parent 4c77a908
include(FetchContent)
set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/CINN/llvm11.tar.gz)
set(LLVM_MD5 39d32b6be466781dddf5869318dcba53)
set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/infrt/llvm_b5149f4e66a49a98b67e8e2de4e24a4af8e2781b.tar.gz)
set(LLVM_MD5 022819bb5760817013cf4b8a37e97d5e)
set(FETCHCONTENT_BASE_DIR ${THIRD_PARTY_PATH}/llvm)
set(FETCHCONTENT_QUIET OFF)
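# Sketch (not part of this commit): the elided lines presumably feed these
# variables into FetchContent, roughly:
#   FetchContent_Declare(llvm URL ${LLVM_DOWNLOAD_URL} URL_MD5 ${LLVM_MD5})
#   FetchContent_MakeAvailable(llvm)
# The actual declaration is hidden by the diff elision below.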
......@@ -51,7 +51,7 @@ message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
# To build with MLIR, LLVM is built from source with the following flags:
#[==[
cmake -G Ninja ../llvm \
cmake ../llvm -G "Unix Makefiles" \
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
-DLLVM_BUILD_EXAMPLES=OFF \
-DLLVM_TARGETS_TO_BUILD="X86" \
......@@ -59,8 +59,10 @@ cmake -G Ninja ../llvm \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_ZLIB=OFF \
-DLLVM_ENABLE_RTTI=ON \
-DLLVM_INSTALL_UTILS=ON \
-DCMAKE_INSTALL_PREFIX=./install
#]==]
# The matched llvm-project version is f9dc2b7079350d0fed3bb3775f496b90483c9e42 (currently a temporary commit)
# The matched llvm-project version is b5149f4e66a49a98b67e8e2de4e24a4af8e2781b (currently a temporary commit)
add_definitions(${LLVM_DEFINITIONS})
......@@ -75,7 +77,7 @@ add_definitions(${LLVM_DEFINITIONS})
# The minimum needed libraries for MLIR IR parse and transform.
set(MLIR_IR_LIBS MLIRAnalysis MLIRStandardOps MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib)
set(MLIR_IR_LIBS MLIRAnalysis MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib)
# td_base is the name of a xxx.td file (without the .td suffix)
......@@ -89,6 +91,7 @@ function(mlir_tablegen_on td_base)
mlir_tablegen(${td_base}.cpp.inc -gen-op-defs)
if (mlir_tablegen_on_DIALECT)
mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls -dialect=${mlir_tablegen_on_DIALECT})
mlir_tablegen(${td_base}_dialect.cpp.inc --gen-dialect-defs -dialect=${mlir_tablegen_on_DIALECT})
endif()
add_public_tablegen_target(${td_base}_IncGen)
add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen)
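# Usage, as seen later in this diff: mlir_tablegen_on(test_kernels) generates
# test_kernels.{hpp,cpp}.inc, while mlir_tablegen_on(infrt_base DIALECT infrt)
# additionally emits infrt_base_dialect.{hpp,cpp}.inc via the
# --gen-dialect-decls/--gen-dialect-defs invocations added above.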
......
......@@ -77,7 +77,6 @@ add_subdirectory(paddle)
# MLIR td file generations
set(infrt_mlir_incs
ops_inc
basic_kernels_inc
test_kernels_inc
infrt_base_inc
......
......@@ -14,7 +14,7 @@
#pragma once
#include "mlir/IR/MLIRContext.h"
#include <mlir/IR/MLIRContext.h>
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
......
......@@ -2,7 +2,6 @@ core_gather_headers()
gather_srcs(infrt_src SRCS
dialect.cc
types.cc
basic_kernels.cc
test_kernels.cc
infrt_base.cc
......@@ -14,8 +13,6 @@ gather_srcs(infrt_src SRCS
pd_types.cc
pd_ops.cc
)
mlir_tablegen_on(ops)
mlir_tablegen_on(basic_kernels)
mlir_tablegen_on(test_kernels)
mlir_tablegen_on(infrt_base DIALECT infrt)
......@@ -27,8 +24,7 @@ mlir_add_rewriter(rewrite)
# TODO(Superjomn) add a cmake function cc_executable to encapsulate the following code
add_executable(infrtopt opt.cc)
target_link_libraries(infrtopt infrt ${mlir_libs})
add_dependencies(infrtopt infrt)
target_link_libraries(infrtopt infrt)
add_executable(print-ir print_ir.cc)
target_link_libraries(print-ir infrt ${mlir_libs})
......
......@@ -17,17 +17,17 @@
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include "paddle/infrt/dialect/dense_tensor.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
using namespace mlir; // NOLINT
static ParseResult parseCallOp(OpAsmParser &parser, // NOLINT
......@@ -71,12 +71,12 @@ static ParseResult parseConstantF64Op(OpAsmParser &parser, // NOLINT
static ParseResult parseConstantI32Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
return parseConstantOp(
IntegerType::get(32, result.getContext()), parser, result);
IntegerType::get(result.getContext(), 32), parser, result);
}
static ParseResult parseConstantI64Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
return parseConstantOp(
IntegerType::get(64, result.getContext()), parser, result);
IntegerType::get(result.getContext(), 64), parser, result);
}
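// Note on the two edits above: builtin type factories in this LLVM revision
// take the MLIRContext as the first argument, e.g.
//   IntegerType::get(result.getContext(), 32);   // new order
//   IntegerType::get(32, result.getContext());   // old order, removed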
static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT
......@@ -90,10 +90,10 @@ static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT
}
static void print(OpAsmPrinter &p, CallOp op) { // NOLINT
p << "infrt.call " << op.getAttr("callee") << "(";
p << "infrt.call " << op->getAttr("callee") << "(";
p.printOperands(op.getOperands());
p << ")";
p.printOptionalAttrDict(op.getAttrs(), {"callee"});
p.printOptionalAttrDict(op->getAttrs(), {"callee"});
p << " : ";
}
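// Note: OpState no longer forwards generic attribute accessors in this LLVM
// revision, so they are reached through the underlying mlir::Operation*:
//   op->getAttr("callee");    // new: via Operation*
//   op.getAttr("callee");     // old, removed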
......@@ -145,7 +145,7 @@ static LogicalResult verify(ConstantF64Op op) { return success(); }
static LogicalResult verify(ConstantI64Op op) { return success(); }
static LogicalResult verify(ReturnOp op) {
auto function = dyn_cast<FuncOp>(op.getParentOp());
auto function = dyn_cast<FuncOp>(op->getParentOp());
if (!function) return success();
......@@ -157,8 +157,8 @@ static LogicalResult verify(ReturnOp op) {
return success();
}
} // namespace dialect
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.cpp.inc"
} // namespace infrt::dialect
......@@ -13,12 +13,9 @@
// limitations under the License.
#pragma once
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
using namespace mlir; // NOLINT
namespace infrt::dialect {
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.hpp.inc"
} // namespace infrt::dialect
......@@ -27,7 +27,7 @@ def CallOp : INFRT_Op<"call"> {
let results = (outs Variadic<AnyType>);
let extraClassDeclaration = [{
StringRef getCallee() { return callee(); }
mlir::StringRef getCallee() { return callee(); }
mlir::FunctionType getCalleeType();
}];
}
......@@ -57,9 +57,8 @@ def ReturnOp : INFRT_Op<"return", [Terminator]> {
let arguments = (ins Variadic<AnyType>:$operands);
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result",
[{ build(b, result, llvm::None); }]>];
let builders = [OpBuilder<(ins),
[{ build($_builder, $_state, llvm::None); }]>];
}
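// Note: ODS builder declarations changed in this revision. Instead of a raw
// C++ signature string, builders take a dag of inputs and use the implicit
// $_builder/$_state variables, e.g.
//   OpBuilder<(ins), [{ build($_builder, $_state, llvm::None); }]>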
class AddOp<string suffix, Type type> : INFRT_Op<"add." # suffix, [NoSideEffect]> {
......
......@@ -17,12 +17,11 @@
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
......@@ -31,68 +30,37 @@
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt::dt {
namespace infrt {
namespace dt {
void DTDialect::initialize() {
allowUnknownTypes();
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/dense_tensor.cpp.inc"
>();
}
namespace detail {
struct TensorTypeStorage : public mlir::TypeStorage {
TensorTypeStorage(TargetType target,
LayoutType layout,
PrecisionType precision)
: target_(target), layout_(layout), precision_(precision) {}
using KeyTy = std::tuple<TargetType, LayoutType, PrecisionType>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(target_, layout_, precision_);
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static TensorTypeStorage *construct(
mlir::TypeStorageAllocator &allocator, // NOLINT
const KeyTy &key) {
return new (allocator.allocate<TensorTypeStorage>())
TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key));
}
TargetType target_;
LayoutType layout_;
PrecisionType precision_;
};
} // namespace detail
llvm::Optional<TargetType> GetTargetType(mlir::StringRef key) {
if (key.equals_lower("x86"))
if (key.equals_insensitive("x86"))
return TargetType::X86;
else if (key.equals_lower("cuda"))
else if (key.equals_insensitive("cuda"))
return TargetType::CUDA;
else
return llvm::None;
}
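// Note: llvm::StringRef::equals_lower() was renamed equals_insensitive() in
// the LLVM revision this commit pins; the case-insensitive comparison
// semantics are unchanged.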
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key) {
if (key.equals_lower("nchw"))
if (key.equals_insensitive("nchw"))
return LayoutType::NCHW;
else if (key.equals_lower("nhwc"))
else if (key.equals_insensitive("nhwc"))
return LayoutType::NHWC;
else
return llvm::None;
}
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key) {
if (key.equals_lower("i32"))
if (key.equals_insensitive("i32"))
return PrecisionType::I32;
else if (key.equals_lower("f32"))
else if (key.equals_insensitive("f32"))
return PrecisionType::F32;
else
return llvm::None;
......@@ -111,7 +79,7 @@ LayoutType TensorType::layout() { return getImpl()->layout_; }
PrecisionType TensorType::precision() { return getImpl()->precision_; }
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType) {
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TensorType tensorType) {
os << "TensorType<" << tensorType.target() << ", " << tensorType.layout()
<< ", " << tensorType.precision() << ">";
return os;
......@@ -133,7 +101,7 @@ StringType StringType::get(mlir::MLIRContext *context) {
return Base::get(context);
}
raw_ostream &operator<<(raw_ostream &os, TargetType type) {
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type) {
switch (type) {
case (TargetType::X86):
os << "X86";
......@@ -147,7 +115,7 @@ raw_ostream &operator<<(raw_ostream &os, TargetType type) {
return os;
}
raw_ostream &operator<<(raw_ostream &os, LayoutType type) {
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type) {
switch (type) {
case (LayoutType::NCHW):
os << "NCHW";
......@@ -161,7 +129,7 @@ raw_ostream &operator<<(raw_ostream &os, LayoutType type) {
return os;
}
raw_ostream &operator<<(raw_ostream &os, PrecisionType type) {
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type) {
switch (type) {
case (PrecisionType::I32):
os << "I32";
......@@ -175,103 +143,69 @@ raw_ostream &operator<<(raw_ostream &os, PrecisionType type) {
return os;
}
static Type getTensorType(mlir::MLIRContext *context) {
auto t_dialect = Identifier::get("t", context);
return OpaqueType::get(t_dialect, "tensor", context);
static mlir::Type getTensorType(mlir::MLIRContext *context) {
auto t_dialect = mlir::Identifier::get("t", context);
return mlir::OpaqueType::get(t_dialect, "tensor");
}
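// Note: mlir::OpaqueType::get() dropped its trailing MLIRContext* parameter;
// the context is now carried by the mlir::Identifier, so
//   mlir::OpaqueType::get(t_dialect, "tensor")   // new
// replaces OpaqueType::get(t_dialect, "tensor", context).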
static ParseResult parseCreateUninitTensorOp(
OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
static mlir::ParseResult parseCreateUninitTensorOp(
mlir::OpAsmParser &parser, // NOLINT
mlir::OperationState &result) { // NOLINT
auto loc = parser.getCurrentLocation();
::mlir::Type outputRawTypes[1];
::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes);
mlir::Type outputRawTypes[1];
::llvm::ArrayRef<mlir::Type> outputTypes(outputRawTypes);
mlir::ArrayAttr shapeAttr;
if (parser.parseAttribute(shapeAttr,
parser.getBuilder().getI64Type(),
"shape",
result.attributes))
return failure();
if (parser.parseOptionalAttrDict(result.attributes)) return failure();
return mlir::failure();
if (parser.parseOptionalAttrDict(result.attributes)) return mlir::failure();
if (parser.parseArrow()) return failure();
if (parser.parseType(outputRawTypes[0])) return failure();
if (parser.parseArrow()) return mlir::failure();
if (parser.parseType(outputRawTypes[0])) return mlir::failure();
if (!outputRawTypes[0].isa<TensorType>())
return parser.emitError(loc, "invalid kind of type specified");
result.addTypes(outputTypes);
return success();
return mlir::success();
}
template <typename CreateUninitTensorOp>
static void printCreateUninitTensorOp(OpAsmPrinter &p, // NOLINT
static void printCreateUninitTensorOp(mlir::OpAsmPrinter &p, // NOLINT
CreateUninitTensorOp op) {
p << CreateUninitTensorOp::getOperationName();
p << " ";
p.printAttributeWithoutType(op.shapeAttr());
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"shape"});
p << " -> ";
p << op.getOperation()->getResultTypes();
}
// TODO(shibo): can be removed?
// static ParseResult parseFillTensorWithConstantOp(OpAsmParser& parser,
// OperationState& result) {
// auto loc = parser.getCurrentLocation();
// ::mlir::OpAsmParser::OperandType inputRawOperands[1];
// ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType>
// inputOperands(inputRawOperands);
// ::mlir::Type inputRawTypes[1];
// ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes);
//
// if (parser.parseOperand(inputRawOperands[0])) return failure();
//
// if (parser.parseColon()) return failure();
// if (parser.parseType(inputRawTypes[0])) return failure();
// if (!inputRawTypes[0].isa<TensorType>())
// return parser.emitError(loc, "invalid kind of type specified");
//
// Attribute value_attr;
// if (parser.resolveOperands(inputOperands, inputTypes, loc, result.operands))
// return failure();
// if (parser.parseAttribute(value_attr, "value", result.attributes)) return
// failure();
// return success();
//}
// TODO(shibo): can be removed?
// template <typename FillTensorOp>
// static void printFillTensorWithConstantOp(OpAsmPrinter& p, FillTensorOp op) {
// p << FillTensorOp::getOperationName();
// p << " ";
// p.printOperand(op.getOperand());
// p << " : ";
// p << op.getOperation()->getOperandTypes();
// p << " ";
// p << op.getAttr("value");
//}
static ParseResult parseSetTensorOp(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
SmallVector<OpAsmParser::OperandType, 1> operands;
if (parser.parseOperandList(operands, 1)) return failure();
static mlir::ParseResult parseSetTensorOp(
mlir::OpAsmParser &parser, // NOLINT
mlir::OperationState &result) { // NOLINT
llvm::SmallVector<mlir::OpAsmParser::OperandType, 1> operands;
if (parser.parseOperandList(operands, 1)) return mlir::failure();
auto tensor_type = getTensorType(result.getContext());
Attribute value_attr;
return failure(
mlir::Attribute value_attr;
return mlir::failure(
parser.resolveOperand(operands[0], tensor_type, result.operands) ||
parser.parseAttribute(value_attr, "values", result.attributes));
}
template <typename SetTensorOp>
static void printSetTensorOp(OpAsmPrinter &p, SetTensorOp op) { // NOLINT
static void printSetTensorOp(mlir::OpAsmPrinter &p, SetTensorOp op) { // NOLINT
p << SetTensorOp::getOperationName() << " ";
p.printOperand(op.getOperand());
p << " " << op.getAttr("values");
p << " " << op->getAttr("values");
}
} // namespace dt
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.cpp.inc" // NOLINT
} // namespace infrt::dt
#include "paddle/infrt/dialect/dense_tensor_dialect.cpp.inc"
......@@ -19,13 +19,8 @@
#include <string>
using namespace mlir; // NOLINT
namespace infrt::dt {
namespace detail {
struct TensorTypeStorage;
} // namespace detail
namespace infrt {
namespace dt {
enum class TargetType : uint8_t { X86, CUDA };
enum class LayoutType : uint8_t { NCHW, NHWC };
enum class PrecisionType : uint8_t { I32, F32 };
......@@ -34,9 +29,39 @@ llvm::Optional<TargetType> GetTargetType(mlir::StringRef key);
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key);
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key);
raw_ostream &operator<<(raw_ostream &os, TargetType type);
raw_ostream &operator<<(raw_ostream &os, LayoutType type);
raw_ostream &operator<<(raw_ostream &os, PrecisionType type);
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TargetType type);
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, LayoutType type);
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, PrecisionType type);
namespace detail {
struct TensorTypeStorage : public mlir::TypeStorage {
TensorTypeStorage(TargetType target,
LayoutType layout,
PrecisionType precision)
: target_(target), layout_(layout), precision_(precision) {}
using KeyTy = std::tuple<TargetType, LayoutType, PrecisionType>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(target_, layout_, precision_);
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static TensorTypeStorage *construct(
mlir::TypeStorageAllocator &allocator, // NOLINT
const KeyTy &key) {
return new (allocator.allocate<TensorTypeStorage>())
TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key));
}
TargetType target_;
LayoutType layout_;
PrecisionType precision_;
};
} // namespace detail
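// The struct above is the standard MLIR parametric-type storage recipe: KeyTy
// identifies an instance, operator==/hashKey make it uniquable, and
// construct() placement-news it into the context-owned allocator. Instances
// are then presumably obtained through the TensorType declared below with the
// (target, layout, precision) key; the exact get() overload is in elided
// lines.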
class TensorType : public mlir::Type::TypeBase<TensorType,
mlir::Type,
......@@ -52,7 +77,7 @@ class TensorType : public mlir::Type::TypeBase<TensorType,
PrecisionType precision();
};
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType);
mlir::raw_ostream &operator<<(mlir::raw_ostream &os, TensorType tensorType);
class TensorMapType : public mlir::Type::TypeBase<TensorMapType,
mlir::Type,
......@@ -70,10 +95,10 @@ class StringType
static StringType get();
static StringType get(mlir::MLIRContext *context);
};
} // namespace dt
} // namespace infrt
#include "paddle/infrt/dialect/dense_tensor_dialect.hpp.inc"
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.hpp.inc"
} // namespace infrt::dt
......@@ -14,9 +14,11 @@
#include "paddle/infrt/dialect/diagnostic_utils.h"
#include <llvm/Support/raw_ostream.h>
#include <string>
namespace infrt::dialect {
namespace infrt {
namespace dialect {
struct MyScopedDiagnosicHandler::Impl {
Impl() : diag_stream_(diag_str_) {}
......@@ -49,4 +51,5 @@ mlir::LogicalResult MyScopedDiagnosicHandler::handler(mlir::Diagnostic *diag) {
return mlir::failure(true);
}
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -18,7 +18,8 @@
#include <memory>
namespace infrt::dialect {
namespace infrt {
namespace dialect {
/**
* A scoped diagnostic handler to help debug the MLIR process.
......@@ -36,4 +37,5 @@ class MyScopedDiagnosicHandler : public mlir::SourceMgrDiagnosticHandler {
std::unique_ptr<Impl> impl_;
};
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -13,24 +13,26 @@
// limitations under the License.
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LogicalResult.h>
namespace infrt::hlir::dialect {
namespace infrt {
namespace hlir {
namespace dialect {
class CinnDialect : public ::mlir::Dialect {
class CinnDialect : public mlir::Dialect {
public:
explicit CinnDialect(::mlir::MLIRContext* ctx);
explicit CinnDialect(mlir::MLIRContext* ctx);
//! We should register this function in the dialect
static llvm::StringRef getDialectNamespace() {
return "infrt::hlir::dialect";
}
};
} // namespace infrt::hlir::dialect
} // namespace dialect
} // namespace hlir
} // namespace infrt
......@@ -18,7 +18,8 @@
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/test_kernels.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
// ----INFRTDialect definition begin----
void INFRTDialect::initialize() {
......@@ -124,4 +125,5 @@ void INFRTDialect::printType(mlir::Type type,
// ----INFRTDialect definition end----
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -18,19 +18,17 @@
#include <mlir/IR/Dialect.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h>
#include "paddle/infrt/dialect/infrt_base.hpp.inc"
namespace infrt::dialect {
class INFRTDialect : public ::mlir::Dialect {
explicit INFRTDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(),
context,
::mlir::TypeID::get<INFRTDialect>()) {
namespace infrt {
namespace dialect {
class INFRTDialect : public mlir::Dialect {
explicit INFRTDialect(mlir::MLIRContext *context)
: mlir::Dialect(
getDialectNamespace(), context, mlir::TypeID::get<INFRTDialect>()) {
initialize();
}
......@@ -41,15 +39,12 @@ class INFRTDialect : public ::mlir::Dialect {
mlir::DialectAsmPrinter &printer) const override;
void initialize();
friend class ::mlir::MLIRContext;
friend class mlir::MLIRContext;
public:
static ::llvm::StringRef getDialectNamespace() { return "infrt"; }
};
} // namespace infrt::dialect
namespace mlir {
} // namespace dialect
template <typename T>
static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT
......@@ -58,17 +53,16 @@ static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT
return b.getIntegerAttr(b.getI32Type(), constant);
}
static mlir::SmallVector<::mlir::Value, 4> cvtValueToValueRange(
static mlir::SmallVector<mlir::Value, 4> cvtValueToValueRange(
const mlir::Value &operand) {
return mlir::SmallVector<::mlir::Value, 4>(1, operand);
return mlir::SmallVector<mlir::Value, 4>(1, operand);
}
static mlir::SmallVector<::mlir::Value, 4> concatTwoValueRange(
static mlir::SmallVector<mlir::Value, 4> concatTwoValueRange(
mlir::ValueRange operand_0, mlir::ValueRange operand_1) {
mlir::SmallVector<::mlir::Value, 4> operands;
mlir::SmallVector<mlir::Value, 4> operands;
operands.append(operand_0.begin(), operand_0.end());
operands.append(operand_1.begin(), operand_1.end());
return operands;
}
} // namespace mlir
} // namespace infrt
......@@ -28,11 +28,11 @@ def TensorMapType :
def BufferType : OpaqueType<"b", "buffer", "buffer">;
class INFRT_createI32Attr<string value> : NativeCodeCall<
"mlir::createI32Attr($_builder, $_loc, " # value # ")">;
"infrt::createI32Attr($_builder, $_loc, " # value # ")">;
def INFRT_cvtValueToValueRange : NativeCodeCall<
"mlir::cvtValueToValueRange($0)">;
"infrt::cvtValueToValueRange($0)">;
def INFRT_concatTwoValueRange : NativeCodeCall<
"mlir::concatTwoValueRange($0, $1)">;
"infrt::concatTwoValueRange($0, $1)">;
#endif // INFRT_BASE
......@@ -23,12 +23,10 @@
#include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt {
void RegisterCinnDialects(mlir::DialectRegistry& registry) { // NOLINT
registry.insert<ts::TensorShapeDialect>();
registry.insert<dialect::INFRTDialect>();
registry.insert<dt::DTDialect>();
registry.insert<mlir::pd::PaddleDialect>();
void registerCinnDialects(mlir::DialectRegistry &registry) { // NOLINT
registry.insert<ts::TensorShapeDialect,
dialect::INFRTDialect,
dt::DTDialect,
mlir::pd::PaddleDialect>();
}
} // namespace infrt
......@@ -14,10 +14,8 @@
#pragma once
#include "mlir/IR/Dialect.h"
#include <mlir/IR/Dialect.h>
#include <mlir/IR/MLIRContext.h>
namespace infrt {
void RegisterCinnDialects(mlir::DialectRegistry& registry); // NOLINT
void registerCinnDialects(mlir::DialectRegistry &registry); // NOLINT
} // namespace infrt
......@@ -16,8 +16,8 @@
#include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h>
#include <unordered_map>
......@@ -30,12 +30,15 @@
#include "paddle/infrt/dialect/diagnostic_utils.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source) {
// context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry());
mlir::DialectRegistry registry;
registerCinnDialects(registry);
context->appendDialectRegistry(registry);
// Currently, the CinnDialect together with mlir::BuiltinDialect is
// enough; StandardOpsDialect is not needed.
// context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
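// Note: MLIRContext::getDialectRegistry() is no longer the mutation point in
// this LLVM revision; the pattern (used above) is to fill a local
// mlir::DialectRegistry and pass it to context->appendDialectRegistry().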
......@@ -57,9 +60,9 @@ mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
mlir::MLIRContext* context) {
// context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry());
context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
mlir::DialectRegistry registry;
registerCinnDialects(registry);
context->appendDialectRegistry(registry);
mlir::ScopedDiagnosticHandler scope_handler(
context, [](mlir::Diagnostic& diag) {
if (diag.getSeverity() != mlir::DiagnosticSeverity::Error)
......@@ -71,4 +74,5 @@ mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
return mlir::parseSourceFile(std::string(file_name), context);
}
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -15,16 +15,17 @@
#pragma once
#include <glog/logging.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/BuiltinOps.h>
#include <string>
#include <memory>
namespace infrt::dialect {
namespace infrt {
namespace dialect {
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source);
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
mlir::MLIRContext* context);
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -17,14 +17,15 @@
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Parser.h>
#include <string>
#include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect {
namespace infrt {
namespace dialect {
TEST(MlirLoader, basic) {
mlir::MLIRContext context;
......@@ -42,8 +43,7 @@ func @main() -> f32 {
)ROC";
auto module = LoadMlirSource(&context, source);
module->verify();
EXPECT_TRUE(mlir::succeeded(module->verify()));
LOG(INFO) << "module name: " << module->getOperationName().data();
for (auto func : module->getOps<mlir::FuncOp>()) {
LOG(INFO) << "get func " << func.getName().str();
......@@ -54,4 +54,5 @@ func @main() -> f32 {
}
}
} // namespace infrt::dialect
} // namespace dialect
} // namespace infrt
......@@ -20,5 +20,5 @@ func @main() -> tensor<?xf32> {
%c2 = "pd.matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
infrt.return %e2 : tensor<?xf32>
"pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->()
}
\ No newline at end of file
......@@ -11,5 +11,5 @@ func @main() -> tensor<?xf32> {
%c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor<?x3x256x256xf32>, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
%d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor<?x3x256x256xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
infrt.return %d : tensor<?x3x256x256xf32>
"pd.fetch"(%d) {name="output"} :(tensor<?x3x256x256xf32>)->()
}
\ No newline at end of file
......@@ -18,5 +18,5 @@ func @main() -> tensor<?xf32> {
%d2 = "pd.elementwise_add"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
"pd.fetch"(%e2) :(tensor<?xf32>)->()
"pd.fetch"(%e2) {name="output"} :(tensor<?xf32>)->()
}
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/infrt_base.td"
class INFRT_Op<string mnemonic, list<OpTrait> traits = []> :
Op<INFRT_Dialect, mnemonic, traits>;
......@@ -12,34 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <llvm/Support/CommandLine.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/IR/AsmState.h>
#include <mlir/IR/Dialect.h>
#include <mlir/InitAllDialects.h>
#include <mlir/InitAllPasses.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/MlirOptMain.h>
#include <mlir/Transforms/Passes.h>
#include <iostream>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
#include "paddle/infrt/dialect/mlir_loader.h"
int main(int argc, char **argv) {
mlir::MLIRContext *context = infrt::Global::getMLIRContext();
auto &registry = context->getDialectRegistry();
infrt::RegisterCinnDialects(registry);
mlir::DialectRegistry registry;
infrt::registerCinnDialects(registry);
mlir::registerCanonicalizerPass();
return mlir::failed(
mlir::MlirOptMain(argc, argv, "INFRT mlir pass driver", registry));
mlir::MlirOptMain(argc, argv, "infrt mlir pass driver", registry));
}
......@@ -16,7 +16,7 @@ def PD_Dialect : Dialect {
This dialect contains the PaddlePaddle operators.
}];
let cppNamespace = "::mlir::pd";
let cppNamespace = "mlir::pd";
}
class PD_Op<string mnemonic, list<OpTrait> traits = []> :
......
......@@ -14,10 +14,15 @@
#include "paddle/infrt/dialect/pd_ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include <mlir/IR/Matchers.h>
#include <mlir/IR/PatternMatch.h>
#include "paddle/infrt/dialect/infrt_base.h"
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT
namespace mlir {
namespace pd {
PaddleDialect::PaddleDialect(MLIRContext *context)
......@@ -36,12 +41,6 @@ mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder,
return builder.create<ConstantOp>(loc, value);
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
#undef GET_OP_CLASSES
#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT
void ConstantOp::build(OpBuilder &builder,
OperationState &state,
Attribute value) {
......@@ -66,8 +65,8 @@ LogicalResult ConstantOp::inferReturnTypes(
inferredReturnTypes.push_back(attributes.get("value").getType());
return success();
}
::mlir::OpFoldResult ConstantOp::fold(
::llvm::ArrayRef<::mlir::Attribute> operands) {
mlir::OpFoldResult ConstantOp::fold(
::llvm::ArrayRef<mlir::Attribute> operands) {
return value();
}
......@@ -82,11 +81,11 @@ LogicalResult ElementwiseAdd::inferReturnTypes(
return success();
}
void ElementwiseAdd::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseMulAdd>(context);
}
::mlir::OpFoldResult ElementwiseAdd::fold(
mlir::OpFoldResult ElementwiseAdd::fold(
llvm::ArrayRef<mlir::Attribute> operands) {
if (getElementTypeOrSelf(getType()).isa<FloatType>()) {
if (!operands[0] || !operands[1]) return {};
......@@ -154,17 +153,17 @@ LogicalResult MulOp::inferReturnTypes(
}
void ReluOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseFCRelu>(context);
}
void FusedRepeatedFCRelu::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseRepeatedFCRelu2>(context);
}
void BatchNormOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
mlir::OwningRewritePatternList &results, mlir::MLIRContext *context) {
results.insert<FuseBatchNormWithConvPattern>(context);
}
......
......@@ -14,21 +14,20 @@
#pragma once
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <mlir/Dialect/Traits.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace mlir {
namespace pd {
......@@ -53,9 +52,8 @@ class PaddleDialect : public Dialect {
}
};
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
#undef GET_OP_CLASSES
} // namespace pd
} // namespace mlir
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
......@@ -24,6 +24,16 @@ def PD_FeedOp : PD_Op<"feed"> {
def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
let summary = "fetch Op";
let description = [{
Fetch the output tensor from the graph.
}];
let arguments = (ins PD_Tensor :$inputs, StrAttr:$name);
}
def PD_ReturnOp : PD_Op<"return", [Terminator]> {
let summary = "return Op";
let description = [{
Return the output tensors from the subgraph.
}];
......@@ -31,7 +41,7 @@ def PD_FetchOp : PD_Op<"fetch", [Terminator]> {
let arguments = (ins Variadic<PD_Tensor>:$inputs);
}
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"FetchOp">]> {
def PD_GraphOp : PD_Op<"graph", [SingleBlockImplicitTerminator<"ReturnOp">]> {
let summary = "paddle graph Op";
let description = [{
Describe a paddle graph or subgraph.
......@@ -50,7 +60,7 @@ def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInte
let hasFolder = 1;
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute value">,
OpBuilder<(ins "Attribute":$value)>,
];
}
......
......@@ -18,12 +18,11 @@
#pragma once
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h>
namespace mlir {
namespace PD {
......
......@@ -11,26 +11,25 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <llvm/ADT/Optional.h>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/ScopedPrinter.h>
#include <mlir/IR/BuiltinOps.h>
#include <llvm/Support/raw_os_ostream.h>
#include <llvm/Support/raw_ostream.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/AsmState.h>
#include <mlir/IR/Block.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/Operation.h>
#include <mlir/IR/Region.h>
#include <mlir/IR/Verifier.h>
#include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/Passes.h>
#include <iostream>
#include "llvm/ADT/Optional.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
......@@ -114,17 +113,15 @@ int main(int argc, char **argv) {
mlir::registerPassManagerCLOptions();
cl::ParseCommandLineOptions(argc, argv, "mlir demo");
mlir::MLIRContext *context = infrt::Global::getMLIRContext();
// context->allowUnregisteredDialects();
auto &registry = context->getDialectRegistry();
infrt::RegisterCinnDialects(registry);
mlir::DialectRegistry registry;
infrt::registerCinnDialects(registry);
mlir::MLIRContext context(registry);
// mlir will verify module automatically after parsing.
// https://github.com/llvm/llvm-project/blob/38d18d93534d290d045bbbfa86337e70f1139dc2/mlir/lib/Parser/Parser.cpp#L2051
// mlir::OwningModuleRef module_ref = mlir::parseSourceString(mlir_source,
// context);
mlir::OwningModuleRef module_ref =
mlir::parseSourceFile(inputFilename, context);
mlir::parseSourceFile(inputFilename, &context);
std::cout << "----------print IR Structure begin----------" << std::endl;
printOperation(module_ref->getOperation(), 0);
std::cout << "----------print IR Structure end----------" << std::endl;
......
......@@ -17,16 +17,16 @@
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
namespace infrt::ts {
namespace infrt {
namespace ts {
using namespace mlir; // NOLINT
void TensorShapeDialect::initialize() {
......@@ -48,8 +48,8 @@ Type TensorShapeDialect::parseType(DialectAsmParser &parser) const {
return Type();
}
void TensorShapeDialect::printType(::mlir::Type type,
::mlir::DialectAsmPrinter &os) const {
void TensorShapeDialect::printType(mlir::Type type,
mlir::DialectAsmPrinter &os) const {
if (type.isa<ShapeType>()) {
os << "shape";
return;
......@@ -61,8 +61,10 @@ void TensorShapeDialect::printType(::mlir::Type type,
}
llvm_unreachable("unexpected 'shape' type kind");
}
} // namespace ts
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensor_shape.cpp.inc" // NOLINT
} // namespace infrt::ts
#include "paddle/infrt/dialect/tensor_shape_dialect.cpp.inc"
......@@ -17,7 +17,8 @@
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt::ts {
namespace infrt {
namespace ts {
class ShapeType
: public mlir::Type::TypeBase<ShapeType, mlir::Type, mlir::TypeStorage> {
......@@ -31,10 +32,9 @@ class PartialShapeType : public mlir::Type::TypeBase<PartialShapeType,
public:
using Base::Base;
};
} // namespace ts
} // namespace infrt
using namespace mlir; // NOLINT
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensor_shape.hpp.inc"
#include "paddle/infrt/dialect/tensor_shape_dialect.hpp.inc"
} // namespace infrt::ts
......@@ -19,7 +19,7 @@ def TensorShapeDialect : Dialect {
def TS_Shape : DialectType<TensorShapeDialect,
CPred<"$_self.isa<::infrt::ts::ShapeType>()">, "!ts.shape type">,
BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> {
let typeDescription = [{
let description = [{
`!ts.shape type` represents a static tensor shape.
}];
}
......@@ -27,7 +27,7 @@ BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> {
def TS_PartialShape : DialectType<TensorShapeDialect,
CPred<"$_self.isa<::infrt::ts::PartialShapeType>()">, "!ts.partial_shape type">,
BuildableType<"$_builder.getType<::infrt::ts::PartialShapeType>()"> {
let typeDescription = [{
let description = [{
`!ts.partial_shape type` represents either a static tensor shape, unranked
tensor shape or a ranked tensor shape with unknown dimension sizes.
}];
......
......@@ -11,10 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <llvm/Support/CommandLine.h>
#include <mlir/Pass/PassManager.h>
#include <iostream>
#include <string>
#include "llvm/Support/CommandLine.h"
#include "mlir/Pass/PassManager.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
......
......@@ -14,14 +14,13 @@
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include <llvm/ADT/SetVector.h>
#include <mlir/Analysis/SliceAnalysis.h>
#include <mlir/IR/Builders.h>
#include <paddle/infrt/dialect/pd_ops.h>
#include <list>
#include <unordered_set>
#include <vector>
#include "llvm/ADT/SetVector.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/Builders.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
......@@ -32,9 +31,9 @@ namespace {
// Based on the function named "FlexibleDFS" defined in:
// paddle/fluid/framework/ir/subgraph_detector.cc.
bool reverseDfs(std::vector<::mlir::Operation *> source,
const std::function<bool(const ::mlir::Operation *)> &func) {
std::unordered_set<const ::mlir::Operation *> visited;
bool reverseDfs(std::vector<mlir::Operation *> source,
const std::function<bool(const mlir::Operation *)> &func) {
std::unordered_set<const mlir::Operation *> visited;
while (!source.empty()) {
auto node = source.back();
source.pop_back();
......@@ -44,7 +43,7 @@ bool reverseDfs(std::vector<::mlir::Operation *> source,
auto values = node->getOperands();
for (auto value : values) {
// if the value is a block argument, the node is nullptr.
::mlir::Operation *node = value.getDefiningOp();
mlir::Operation *node = value.getDefiningOp();
if (node != nullptr && !visited.count(node)) {
source.emplace_back(node);
}
......@@ -54,19 +53,19 @@ bool reverseDfs(std::vector<::mlir::Operation *> source,
}
// Merge the first & second graph ops into a new graph op.
void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
::mlir::pd::GraphOp first,
::mlir::pd::GraphOp second) {
void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
mlir::pd::GraphOp first,
mlir::pd::GraphOp second) {
// compute inputs and outputs
::llvm::SmallVector<::mlir::Value, 4> inputs(first.getOperands()), outputs;
for (::mlir::Value input : second.getOperands()) {
::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
for (mlir::Value input : second.getOperands()) {
if (input.getDefiningOp() != first) {
inputs.push_back(input);
}
}
::llvm::DenseMap<::mlir::Value, unsigned int> op_output_mapping;
for (::mlir::Value output : first.getResults()) {
for (::mlir::Operation *user : output.getUsers()) {
::llvm::DenseMap<mlir::Value, unsigned int> op_output_mapping;
for (mlir::Value output : first.getResults()) {
for (mlir::Operation *user : output.getUsers()) {
if (user != second && user->getParentOp() != second) {
op_output_mapping[output] = outputs.size();
outputs.push_back(output);
......@@ -74,19 +73,19 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
}
}
}
auto fetch_op = second.getBody()->getTerminator();
outputs.append(fetch_op->getOperands().begin(),
fetch_op->getOperands().end());
::llvm::SmallVector<::mlir::Type, 4> fetch_types;
auto return_op = second.getBody()->getTerminator();
outputs.append(return_op->getOperands().begin(),
return_op->getOperands().end());
::llvm::SmallVector<mlir::Type, 4> return_types;
for (auto value : outputs) {
fetch_types.push_back(value.getType());
return_types.push_back(value.getType());
}
// create the new graph op
builder.setInsertionPoint(first);
auto loc = first.getLoc();
auto graph_op = builder.create<::mlir::pd::GraphOp>(loc, fetch_types, inputs);
::mlir::Block *block = new ::mlir::Block;
auto graph_op = builder.create<mlir::pd::GraphOp>(loc, return_types, inputs);
mlir::Block *block = new mlir::Block;
auto copy_range = second.getBody()->without_terminator();
block->getOperations().splice(block->begin(),
second.getBody()->getOperations(),
......@@ -98,18 +97,18 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
copy_range.begin(),
copy_range.end());
builder.setInsertionPointToEnd(block);
builder.create<mlir::pd::FetchOp>(loc, outputs);
builder.create<mlir::pd::ReturnOp>(loc, outputs);
graph_op.body().push_back(block);
// mapping the output
unsigned int num_result = first.getNumResults();
fetch_op = first.getBody()->getTerminator();
return_op = first.getBody()->getTerminator();
for (unsigned int index = 0; index < num_result; ++index) {
auto origin_value = first.getResult(index);
if (op_output_mapping.find(origin_value) == op_output_mapping.end()) {
origin_value.replaceAllUsesWith(fetch_op->getOperand(index));
origin_value.replaceAllUsesWith(return_op->getOperand(index));
} else {
auto inner_value = fetch_op->getOperand(index);
auto inner_value = return_op->getOperand(index);
auto outer_value = graph_op.getResult(op_output_mapping[origin_value]);
while (!origin_value.use_empty()) {
auto replace_value =
......@@ -128,13 +127,13 @@ void mergeTwoAdjacentGraphOp(::mlir::OpBuilder &builder, // NOLINT
// Topologically sort the ops in a block.
void topoSortBlock(mlir::Block &body) { // NOLINT
llvm::SetVector<Operation *> toSort;
llvm::SetVector<mlir::Operation *> toSort;
if (body.empty()) return;
for (auto it = body.rbegin(); it != body.rend(); ++it) {
toSort.insert(&*it);
}
llvm::SetVector<Operation *> result =
::mlir::topologicalSort(std::move(toSort));
llvm::SetVector<mlir::Operation *> result =
mlir::topologicalSort(std::move(toSort));
for (auto *op : result) {
op->moveBefore(body.getTerminator());
}
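// mlir::topologicalSort returns the ops in dependency order; moving each one
// before the terminator rewrites the block into that order while keeping the
// terminator last.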
......@@ -145,21 +144,21 @@ void topoSortBlock(mlir::Block &body) { // NOLINT
// Implementation of the trtGraphFusePass.
void trtGraphFusePass::runOnFunction() {
mlir::Block &body = getFunction().front();
::mlir::OpBuilder builder(&body, body.begin());
mlir::OpBuilder builder(&body, body.begin());
bool changed = false;
do {
changed = false;
for (auto &op : body) {
::mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op);
mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr == graph_op) continue;
for (auto user_op : op.getUsers()) {
::mlir::pd::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(user_op);
mlir::pd::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(user_op);
if (nullptr == user_graph_op) continue;
// get all dst input nodes except src.
std::vector<::mlir::Operation *> source_nodes;
std::vector<mlir::Operation *> source_nodes;
for (auto operand : user_op->getOperands()) {
auto input = operand.getDefiningOp();
if (input != &op && input != nullptr) {
......@@ -167,9 +166,8 @@ void trtGraphFusePass::runOnFunction() {
}
}
// Reverse DFS from the source_nodes.
if (!reverseDfs(source_nodes, [&op](const ::mlir::Operation *n) {
return n == &op;
})) {
if (!reverseDfs(source_nodes,
[&op](const mlir::Operation *n) { return n == &op; })) {
mergeTwoAdjacentGraphOp(builder, graph_op, user_graph_op);
changed = true;
break;
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
#include <mlir/Pass/Pass.h>
namespace infrt {
namespace trt {
......@@ -28,15 +28,15 @@ namespace trt {
* %a = "pd.feed"()...
* %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.return" %m
* } ...
* %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)...
* "pd.fetch" %m
* "pd.return" %m
* } ...
* %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.return" %m
* } ...
* "pd.fetch" %d, %f
*
......@@ -47,13 +47,13 @@ namespace trt {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "pd.fetch" %n, %s
* "pd.return" %n, %s
* } ...
* "pd.fetch" %d, %f
* }
*/
class trtGraphFusePass
: public ::mlir::PassWrapper<trtGraphFusePass, ::mlir::FunctionPass> {
: public mlir::PassWrapper<trtGraphFusePass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "trtGraphFusePass"; }
void runOnFunction() override;
......
......@@ -14,7 +14,7 @@
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "mlir/IR/Builders.h"
#include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
......@@ -22,24 +22,24 @@ namespace infrt {
namespace trt {
// Implementation of the trtGraphSplitPass.
void trtGraphSplitPass::runOnFunction() {
std::vector<::mlir::pd::GraphOp> worklist;
::mlir::Block& block = getFunction().front();
std::vector<mlir::pd::GraphOp> worklist;
mlir::Block& block = getFunction().front();
for (auto& op : block) {
::mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(&op);
mlir::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op);
if (nullptr != graph_op &&
graph_op.getBody()->getOperations().size() <= min_subgraph_size_) {
worklist.push_back(graph_op);
}
}
while (!worklist.empty()) {
::mlir::pd::GraphOp graph_op = worklist.back();
mlir::pd::GraphOp graph_op = worklist.back();
worklist.pop_back();
::mlir::Block* body = graph_op.getBody();
auto fetch_op = body->getTerminator();
graph_op.replaceAllUsesWith(fetch_op->getOperands());
mlir::Block* body = graph_op.getBody();
auto return_op = body->getTerminator();
graph_op.replaceAllUsesWith(return_op->getOperands());
auto copy_range = body->without_terminator();
block.getOperations().splice(::mlir::Block::iterator(graph_op),
block.getOperations().splice(mlir::Block::iterator(graph_op),
body->getOperations(),
copy_range.begin(),
copy_range.end());
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
#include <mlir/Pass/Pass.h>
namespace infrt {
namespace trt {
......@@ -31,9 +31,9 @@ namespace trt {
* %m = "pd.conv2d"(%a)...
* %n = "pd.conv3d"(%m)...
* %s = "pd.conv2d"(%a)...
* "pd.fetch" %n, %s
* "pd.return" (%n, %s)
* } ...
* "pd.fetch" %d, %f
* "pd.fetch" (%d, %f)
* }
*
* destination func:
......@@ -42,11 +42,11 @@ namespace trt {
* %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ...
* "pd.fetch" %d, %f
* "pd.fetch" (%d, %f)
* }
*/
class trtGraphSplitPass
: public ::mlir::PassWrapper<trtGraphSplitPass, ::mlir::FunctionPass> {
: public mlir::PassWrapper<trtGraphSplitPass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "trtGraphSplitPass"; }
void runOnFunction() override;
......
......@@ -14,49 +14,48 @@
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
#include "mlir/IR/Builders.h"
#include <mlir/IR/Builders.h>
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
namespace infrt {
namespace trt {
// Implementation of the trtOpTellerPass.
void trtOpTellerPass::runOnFunction() {
::mlir::Block &body = getFunction().front();
std::vector<::mlir::Operation *> worklist;
mlir::Block &body = getFunction().front();
std::vector<mlir::Operation *> worklist;
worklist.reserve(body.getOperations().size());
for (auto &op : body) {
worklist.push_back(&op);
}
// Build GraphOp.
::mlir::OpBuilder builder(&body, body.begin());
mlir::OpBuilder builder(&body, body.begin());
while (!worklist.empty()) {
auto *op = worklist.back();
worklist.pop_back();
if (op == nullptr) continue;
auto op1 = ::llvm::dyn_cast_or_null<::mlir::pd::FeedOp>(op);
auto op1 = ::llvm::dyn_cast_or_null<mlir::pd::FeedOp>(op);
if (op1) continue;
auto op2 = ::llvm::dyn_cast_or_null<::mlir::pd::FetchOp>(op);
auto op2 = ::llvm::dyn_cast_or_null<mlir::pd::FetchOp>(op);
if (op2) continue;
auto op3 = ::llvm::dyn_cast_or_null<::mlir::pd::GraphOp>(op);
auto op3 = ::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(op);
if (op3) continue;
builder.setInsertionPoint(op);
auto loc = getFunction().getLoc();
auto graph_op = builder.create<::mlir::pd::GraphOp>(
auto graph_op = builder.create<mlir::pd::GraphOp>(
loc, op->getResultTypes(), op->getOperands());
::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;
::llvm::SmallVector<mlir::Value, 4> tblgen_repl_values;
for (auto v :
::llvm::SmallVector<::mlir::Value, 4>{graph_op.getODSResults(0)}) {
::llvm::SmallVector<mlir::Value, 4>{graph_op.getODSResults(0)}) {
tblgen_repl_values.push_back(v);
}
op->replaceAllUsesWith(tblgen_repl_values);
// Build graph op.
::mlir::Block *block = new ::mlir::Block;
mlir::Block *block = new mlir::Block;
graph_op.body().push_back(block);
op->moveBefore(block, block->begin());
builder.setInsertionPointToEnd(block);
builder.create<mlir::pd::FetchOp>(loc, op->getResults());
builder.create<mlir::pd::ReturnOp>(loc, op->getResults());
}
}
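// Net effect of this pass: every op except pd.feed, pd.fetch and existing
// pd.graph is wrapped into a single-op pd.graph terminated by pd.return; the
// graph-fuse pass then merges adjacent graphs.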
} // namespace trt
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#pragma once
#include "mlir/Pass/Pass.h"
#include <mlir/Pass/Pass.h>
namespace infrt {
namespace trt {
......@@ -29,7 +29,7 @@ namespace trt {
* %c = "pd.conv2d"(%a) ...
* %d = "pd.conv3d"(%c) ...
* %f = "pd.conv2d"(%a) ...
* "pd.fetch" %d, %f
* "pd.fetch" (%d, %f)
* }
*
* destination func:
......@@ -37,23 +37,23 @@ namespace trt {
* %a = "pd.feed"()...
* %c = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.return" (%m)
* } ...
* %d = "pd.graph"(%c) {
* %m = "pd.conv3d"(%c)...
* "pd.fetch" %m
* "pd.return" (%m)
* } ...
* %f = "pd.graph"(%a) {
* %m = "pd.conv2d"(%a)...
* "pd.fetch" %m
* "pd.return" (%m)
* } ...
* "pd.fetch" %d, %f
* "pd.fetch" (%d, %f)
* }
 * TODO(winter-wang): Document how to judge whether an operator can be supported
 * by TensorRT.
*/
class trtOpTellerPass
: public ::mlir::PassWrapper<trtOpTellerPass, ::mlir::FunctionPass> {
: public mlir::PassWrapper<trtOpTellerPass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "trtOpTellerPass"; }
void runOnFunction() override;
......
......@@ -13,27 +13,25 @@
// limitations under the License.
#include "paddle/infrt/dialect/tensorrt/trt_ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <mlir/IR/Matchers.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt {
namespace trt {
TensorRTDialect::TensorRTDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect("trt", context, ::mlir::TypeID::get<TensorRTDialect>()) {
TensorRTDialect::TensorRTDialect(mlir::MLIRContext *context)
: mlir::Dialect("trt", context, mlir::TypeID::get<TensorRTDialect>()) {
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
>();
#undef GET_OP_LIST
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
#undef GET_OP_CLASSES
} // namespace trt
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.cpp.inc" // NOLINT
......@@ -14,37 +14,32 @@
#pragma once
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <mlir/Dialect/Traits.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Matchers.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Interfaces/CallInterfaces.h>
#include <mlir/Interfaces/DerivedAttributeOpInterface.h>
#include <mlir/Interfaces/InferTypeOpInterface.h>
#include <mlir/Interfaces/LoopLikeInterface.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt {
namespace trt {
class TensorRTDialect : public ::mlir::Dialect {
class TensorRTDialect : public mlir::Dialect {
public:
explicit TensorRTDialect(::mlir::MLIRContext* context);
explicit TensorRTDialect(mlir::MLIRContext* context);
static llvm::StringRef getDialectNamespace() { return "trt"; }
};
// mlir bug. Can be removed safely when mlir is updated to llvm11.
using namespace mlir; // NOLINT
} // namespace trt
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensorrt/trt_ops.hpp.inc"
#undef GET_OP_CLASSES
} // namespace trt
} // namespace infrt
......@@ -14,14 +14,13 @@
#include "paddle/infrt/dialect/test_kernels.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
namespace infrt::dialect {
#include <mlir/IR/Builders.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/TypeUtilities.h>
namespace infrt {
namespace dialect {
//===----------------------------------------------------------------------===//
// BenchmarkOp
//===----------------------------------------------------------------------===//
......@@ -32,65 +31,67 @@ namespace infrt::dialect {
// ...
// }
static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
StringAttr nameAttr;
static mlir::ParseResult parseBenchmarkOp(
mlir::OpAsmParser &parser, // NOLINT
mlir::OperationState &result) { // NOLINT
mlir::StringAttr nameAttr;
if (parser.parseAttribute(nameAttr, "name", result.attributes))
return failure();
return mlir::failure();
// Parse the operands, e.g. (%c : i32, %d : f32)
if (parser.parseLParen()) return failure();
if (parser.parseLParen()) return mlir::failure();
SmallVector<OpAsmParser::OperandType, 4> operands;
SmallVector<Type, 4> types;
llvm::SmallVector<mlir::OpAsmParser::OperandType, 4> operands;
llvm::SmallVector<mlir::Type, 4> types;
llvm::SMLoc type_loc = parser.getCurrentLocation();
if (parser.parseOptionalRParen()) {
// Parse non-empty operands
do {
// Parse %c : i32,
OpAsmParser::OperandType operand;
Type type;
mlir::OpAsmParser::OperandType operand;
mlir::Type type;
if (parser.parseOperand(operand) || parser.parseColonType(type))
return failure();
return mlir::failure();
operands.push_back(operand);
types.push_back(type);
} while (succeeded(parser.parseOptionalComma()));
if (parser.parseRParen()) return failure();
if (parser.parseRParen()) return mlir::failure();
}
if (parser.resolveOperands(operands, types, type_loc, result.operands))
return failure();
return mlir::failure();
// Parse the keyword attribute, e.g. max_count = 100, duration_secs = 1
do {
StringRef attr;
Attribute resultAttr;
mlir::StringRef attr;
mlir::Attribute resultAttr;
if (parser.parseKeyword(&attr) || parser.parseEqual() ||
parser.parseAttribute(resultAttr,
parser.getBuilder().getIntegerType(32),
attr,
result.attributes))
return failure();
} while (succeeded(parser.parseOptionalComma()));
return mlir::failure();
} while (mlir::succeeded(parser.parseOptionalComma()));
// Set the default attribute num_warmup_runs to 1 if unset
auto setDefaultAttrIfUnset = [&](const char *attr_name, int value) {
bool found = llvm::any_of(result.attributes,
[attr_name](const NamedAttribute &attr) {
return attr.first == attr_name;
[attr_name](const mlir::NamedAttribute &attr) {
return attr.getName() == attr_name;
});
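// mlir::NamedAttribute became a class in newer MLIR; the former
// attr.first / attr.second pair members are now attr.getName() / attr.getValue().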
if (!found) {
IntegerAttr default_val = parser.getBuilder().getI32IntegerAttr(value);
mlir::IntegerAttr default_val =
parser.getBuilder().getI32IntegerAttr(value);
result.addAttribute(attr_name, default_val);
}
};
setDefaultAttrIfUnset("num_warmup_runs", 1);
Region *target = result.addRegion();
mlir::Region *target = result.addRegion();
return parser.parseRegion(*target,
operands,
types,
......@@ -102,11 +103,11 @@ static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT
// max_count = 100, duration_secs = 1 {
// ...
// }
static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
static void print(mlir::OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
p << "infrt.benchmark ";
// Print the name attribute, e.g. "add.i32"
auto name_attr = op.getAttr("name");
auto name_attr = op->getAttr("name");
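// getAttr is now reached through operator-> (the underlying mlir::Operation),
// since the OpState-level accessor was dropped in newer MLIR.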
p << name_attr;
// Print the operands and types, e.g. (%c : i32, %d : f32)
......@@ -120,13 +121,13 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
bool need_comma = false;
// Print the attributes, e.g. max_count = 100, duration_secs = 1
for (auto &name_attr : op.getAttrs()) {
auto id = name_attr.first;
for (auto &name_attr : op->getAttrs()) {
auto id = name_attr.getName();
if (id == "name") continue;
if (need_comma) p << ", ";
auto attr = name_attr.second;
auto attr = name_attr.getValue();
p << id << " = ";
if (auto int_attr = attr.dyn_cast<IntegerAttr>()) {
if (auto int_attr = attr.dyn_cast<mlir::IntegerAttr>()) {
int_attr.getValue().print(p.getStream(), /*isSigned=*/false);
} else {
op.emitOpError("Unexpected attribute");
......@@ -142,7 +143,7 @@ static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
}
static LogicalResult verify(BenchmarkOp op) {
static mlir::LogicalResult verify(BenchmarkOp op) {
// Verify that the target benchmark region has exactly one return value.
auto &region = op.region();
auto &last_op = region.front().back();
......@@ -154,10 +155,10 @@ static LogicalResult verify(BenchmarkOp op) {
"incorrect number of return values. One return value is expected");
}
return success();
return mlir::success();
}
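// These static parse/print/verify helpers are referenced by the ODS-generated
// op definitions included from test_kernels.cpp.inc below.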
} // namespace dialect
} // namespace infrt
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.cpp.inc"
} // namespace infrt::dialect
......@@ -13,11 +13,8 @@
// limitations under the License.
#pragma once
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt::dialect {
using namespace mlir; // NOLINT
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.hpp.inc"
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/types.h"
namespace infrt::hlir::mlir {} // namespace infrt::hlir::mlir
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/StandardTypes.h>
......@@ -23,7 +23,8 @@
#include "paddle/infrt/host_context/op_executable.h"
#include "paddle/infrt/host_context/symbol_table.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
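// The namespaces are spelled out (instead of C++17 "namespace
// infrt::host_context") presumably to keep the tree building in pre-C++17
// modes; the same change recurs across the files below.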
struct CoreRuntime::Impl {
KernelRegistry* kernel_registry{};
......@@ -90,4 +91,5 @@ llvm::SmallVector<ValueRef, 4> CoreRuntime::GetResults(
CoreRuntime::~CoreRuntime() {}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -22,7 +22,8 @@
#include "paddle/infrt/host_context/value.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
class KernelRegistry;
class OpExecutable;
......@@ -83,4 +84,5 @@ class CoreRuntimeBuilder : public CoreRuntime {
OpExecutableBuilder* NewOpExecutable(const std::string& op_name);
};
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -21,7 +21,8 @@
#include "llvm/ADT/SmallVector.h"
#include "paddle/infrt/host_context/value.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
/**
* KernelFrame captures the states(input arguments, attributes, results)
......@@ -163,4 +164,5 @@ class KernelFrameBuilder : public KernelFrame {
}
};
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -18,7 +18,8 @@
#include "paddle/infrt/host_context/kernel_utils.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
int add_i32(int a, int b) { return a + b; }
......@@ -44,4 +45,5 @@ TEST(KernelRegistry, basic) {
ASSERT_EQ(results[0]->get<int>(), 3);
}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -16,7 +16,8 @@
#include <gtest/gtest.h>
namespace infrt::host_context {
namespace infrt {
namespace host_context {
int add_i32(int a, int b) { return a + b; }
float add_f32(float a, float b) { return a + b; }
......@@ -66,4 +67,5 @@ TEST(KernelImpl, pair) {
ASSERT_EQ(results[1]->get<float>(), 3.f);
}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -15,6 +15,7 @@
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include <glog/logging.h>
#include <mlir/IR/BuiltinOps.h>
#include <string> // NOLINT
......
......@@ -13,7 +13,8 @@
// limitations under the License.
#pragma once
#include <mlir/IR/Function.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Region.h>
#include <string>
#include <unordered_map>
......
......@@ -15,9 +15,9 @@
#pragma once
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OperationSupport.h>
#include <unordered_map>
......
......@@ -16,8 +16,9 @@
#include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h>
......@@ -40,7 +41,8 @@
#include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
template <typename T>
std::string DumpToString(T& op) { // NOLINT
......@@ -113,10 +115,10 @@ bool MlirToRuntimeTranslator::EmitConstantOp(mlir::Operation* op) {
template <>
boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::IntegerAttr>()) return boost::none;
if (attr->isa<mlir::IntegerAttr>()) {
auto val = attr->cast<mlir::IntegerAttr>();
const mlir::Attribute& attr) {
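// mlir::Attribute is a small value type wrapping immutable storage, so it is
// now passed by const reference rather than by pointer.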
if (!attr.isa<mlir::IntegerAttr>()) return boost::none;
if (attr.isa<mlir::IntegerAttr>()) {
auto val = attr.cast<mlir::IntegerAttr>();
if (val.getType().isInteger(32)) {
return val.getInt();
}
......@@ -125,10 +127,10 @@ boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute(
}
template <>
boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::IntegerAttr>()) return boost::none;
if (attr->isa<mlir::IntegerAttr>()) {
auto val = attr->cast<mlir::IntegerAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::IntegerAttr>()) return boost::none;
if (attr.isa<mlir::IntegerAttr>()) {
auto val = attr.cast<mlir::IntegerAttr>();
if (val.getType().isInteger(64)) {
return val.getInt();
}
......@@ -139,10 +141,10 @@ boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute(
// TODO(Superjomn) Make double and float parsing share something.
template <>
boost::optional<float> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::FloatAttr>()) return boost::none;
if (attr->isa<mlir::FloatAttr>()) {
auto val = attr->cast<mlir::FloatAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::FloatAttr>()) return boost::none;
if (attr.isa<mlir::FloatAttr>()) {
auto val = attr.cast<mlir::FloatAttr>();
if (val.getType().isF32()) return val.getValueAsDouble();
}
return boost::none;
......@@ -150,10 +152,10 @@ boost::optional<float> MlirToRuntimeTranslator::EmitAttribute(
template <>
boost::optional<double> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::FloatAttr>()) return boost::none;
if (attr->isa<mlir::FloatAttr>()) {
auto val = attr->cast<mlir::FloatAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::FloatAttr>()) return boost::none;
if (attr.isa<mlir::FloatAttr>()) {
auto val = attr.cast<mlir::FloatAttr>();
if (val.getType().isF64()) return val.getValueAsDouble();
}
return boost::none;
......@@ -161,17 +163,17 @@ boost::optional<double> MlirToRuntimeTranslator::EmitAttribute(
template <>
boost::optional<std::string> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::StringAttr>()) return boost::none;
return attr->cast<mlir::StringAttr>().getValue().str();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::StringAttr>()) return boost::none;
return attr.cast<mlir::StringAttr>().getValue().str();
}
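// The macro below stamps out EmitAttribute specializations for integer-vector
// attributes of the given element type and bit width.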
#define PROCESS_ARRAY_INT(type__, bits__) \
template <> \
boost::optional<std::vector<type__>> MlirToRuntimeTranslator::EmitAttribute( \
const mlir::Attribute* attr) { \
if (!attr->isa<mlir::ArrayAttr>()) return boost::none; \
auto array = attr->cast<mlir::ArrayAttr>(); \
const mlir::Attribute& attr) { \
if (!attr.isa<mlir::ArrayAttr>()) return boost::none; \
auto array = attr.cast<mlir::ArrayAttr>(); \
CHECK(!array.empty()); \
\
if (!array[0].getType().isInteger(bits__)) { \
......@@ -191,9 +193,9 @@ PROCESS_ARRAY_INT(int64_t, 64);
template <>
boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr->cast<mlir::ArrayAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr.cast<mlir::ArrayAttr>();
CHECK(!array.empty());
if (!array[0].getType().isF32()) return boost::none;
......@@ -207,9 +209,9 @@ boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute(
template <>
boost::optional<std::vector<double>> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr->cast<mlir::ArrayAttr>();
const mlir::Attribute& attr) {
if (!attr.isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr.cast<mlir::ArrayAttr>();
CHECK(!array.empty());
if (!array[0].getType().isF64()) return boost::none;
......@@ -236,7 +238,8 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
for (int i = 0, e = op->getNumOperands(); i < e; i++) {
// function argument as value
auto operand = op->getOperand(i);
if (operand.getKind() == mlir::Value::Kind::BlockArgument) {
// mlir::Value::getKind() was removed upstream; isa<mlir::BlockArgument>() is
// the replacement.
if (operand.isa<mlir::BlockArgument>()) {
mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>();
Value* arg_value = GetValue(arg);
impl_->cur_op->AppendArgument(arg_value);
......@@ -283,25 +286,25 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
for (size_t i = 0; i < attrs.size(); i++) {
auto& attr = attrs[i];
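// Try each supported attribute type in turn; the first EmitAttribute
// specialization that matches converts the MLIR attribute into a host Value.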
if (auto v = EmitAttribute<int32_t>(&attr.second)) {
if (auto v = EmitAttribute<int32_t>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<int64_t>(&attr.second)) {
} else if (auto v = EmitAttribute<int64_t>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<float>(&attr.second)) {
} else if (auto v = EmitAttribute<float>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<double>(&attr.second)) {
} else if (auto v = EmitAttribute<double>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<std::string>(&attr.second)) {
} else if (auto v = EmitAttribute<std::string>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int16_t>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<int16_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int32_t>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<int32_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int64_t>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<int64_t>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<float>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<float>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<double>>(&attr.second)) {
} else if (auto v = EmitAttribute<std::vector<double>>(attr.getValue())) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else {
LOG(FATAL) << "Not supported attribute type";
......@@ -330,7 +333,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
llvm::SmallVector<mlir::Type, 0> results;
auto func_type =
mlir::FunctionType::get(inputs, results, region.getContext());
mlir::FunctionType::get(region.getContext(), inputs, results);
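// Newer MLIR moved MLIRContext* to the first parameter of FunctionType::get;
// it used to come last.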
auto* function = impl_->cur_op->CreateFunctionExecutable(
&region, func_type, &impl_->func_defs);
impl_->cur_op->AppendAttribute(new Value(function));
......@@ -555,4 +558,5 @@ void TestMlir(mlir::ModuleOp module, KernelRegistry* registry) {
execute.Run();
}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -29,7 +29,8 @@ class Attribute;
class Value;
} // namespace mlir
namespace infrt::host_context {
namespace infrt {
namespace host_context {
class CoreRuntimeBuilder;
class Value;
......@@ -73,7 +74,7 @@ class MlirToRuntimeTranslator {
bool EmitCallOp(mlir::Operation* op, function_defs_t* function_table);
template <typename T>
boost::optional<T> EmitAttribute(const mlir::Attribute* attr);
boost::optional<T> EmitAttribute(const mlir::Attribute& attr);
Value* GetOpResult(mlir::Operation* op);
......@@ -104,4 +105,5 @@ void MlirToRuntimeTranslate(mlir::ModuleOp module, CoreRuntimeBuilder* runtime);
*/
void TestMlir(mlir::ModuleOp module, KernelRegistry* registry);
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -29,7 +29,8 @@
#include "paddle/infrt/kernel/tensor_shape_kernels.h"
#include "paddle/infrt/kernel/test_kernels.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
TEST(MlirToRuntimeTranslate, basic) {
mlir::MLIRContext context;
......@@ -48,7 +49,7 @@ func @main() -> () {
)ROC";
auto module = dialect::LoadMlirSource(&context, source);
module->verify();
EXPECT_TRUE(mlir::succeeded(module->verify()));
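// verify() returns mlir::LogicalResult (marked nodiscard), so the result is
// now checked explicitly instead of being dropped.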
KernelRegistry registry;
kernel::RegisterFloatBasicKernels(&registry);
......@@ -74,7 +75,7 @@ func @main() -> () {
)ROC";
auto module = dialect::LoadMlirSource(&context, source);
module->verify();
EXPECT_TRUE(mlir::succeeded(module->verify()));
KernelRegistry registry;
kernel::RegisterFloatBasicKernels(&registry);
......@@ -115,7 +116,7 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F
// LOG(INFO) << "content: " << content << std::endl;
auto module = dialect::LoadMlirSource(context, content);
module->verify();
EXPECT_TRUE(mlir::succeeded(module->verify()));
host_context::KernelRegistry registry;
......@@ -157,4 +158,5 @@ infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F
}
}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -14,6 +14,7 @@
#include "paddle/infrt/host_context/op_executable.h"
#include <mlir/IR/BuiltinOps.h>
#include <string>
#include "paddle/infrt/host_context/kernel_frame.h"
......@@ -21,7 +22,8 @@
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include "paddle/infrt/host_context/symbol_table.h"
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct OpExecutable::Impl {
Impl(const std::string& op_name,
......@@ -148,4 +150,5 @@ void OpExecutable::Execute() {
OpExecutable::~OpExecutable() {}
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -14,19 +14,18 @@
#pragma once
#include <llvm/ADT/ArrayRef.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Region.h>
#include <memory>
#include <string>
#include <unordered_map>
#include "mlir/IR/Function.h"
#include "mlir/IR/Region.h"
namespace mlir {
class FuncOp;
} // namespace mlir
namespace infrt::host_context {
namespace infrt {
namespace host_context {
class SymbolTable;
class KernelRegistry;
......@@ -89,4 +88,5 @@ class OpExecutableBuilder : public OpExecutable {
function_defs_t* function_defs);
};
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
......@@ -23,7 +23,8 @@
using infrt::host_context::Attribute;
namespace infrt::kernel {
namespace infrt {
namespace kernel {
template <typename T>
T add(T a, T b) {
......@@ -82,4 +83,5 @@ void RegisterFloatBasicKernels(host_context::KernelRegistry *registry) {
registry->AddKernel("infrt.print.f32", INFRT_KERNEL(print<float>));
}
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -15,13 +15,16 @@
#pragma once
#include <string>
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct KernelRegistry;
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
namespace infrt::kernel {
namespace infrt {
namespace kernel {
/**
* Register all the basic kernels to \p registry.
......@@ -31,4 +34,5 @@ void RegisterBasicKernels(host_context::KernelRegistry* registry);
void RegisterIntBasicKernels(host_context::KernelRegistry* registry);
void RegisterFloatBasicKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -25,7 +25,8 @@
#include "paddle/infrt/tensor/tensor_map.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::kernel {
namespace infrt {
namespace kernel {
using namespace host_context; // NOLINT
using namespace tensor; // NOLINT
......@@ -76,4 +77,5 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
INFRT_KERNEL(ShallowCopyTensor));
}
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -14,12 +14,16 @@
#pragma once
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct KernelRegistry;
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
namespace infrt::kernel {
namespace infrt {
namespace kernel {
void RegisterTensorKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -24,7 +24,8 @@
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::kernel {
namespace infrt {
namespace kernel {
void PrintShape(const tensor::TensorShape& shape) {
llvm::raw_os_ostream oos(std::cout);
......@@ -35,4 +36,5 @@ void RegisterTensorShapeKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("ts.print_shape", INFRT_KERNEL(PrintShape));
}
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -14,14 +14,18 @@
#pragma once
namespace infrt::host_context {
namespace infrt {
namespace host_context {
class KernelRegistry;
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
namespace infrt::kernel {
namespace infrt {
namespace kernel {
void RegisterTensorShapeKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -33,7 +33,8 @@ using infrt::host_context::Attribute;
using infrt::host_context::MlirFunctionExecutable;
using infrt::host_context::RemainingArguments;
namespace infrt::kernel {
namespace infrt {
namespace kernel {
namespace {
class BenchmarkStats {
public:
......@@ -197,4 +198,5 @@ void RegisterTestKernels(host_context::KernelRegistry *registry) {
INFRT_KERNEL(ShadowCopyTensor));
}
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -15,17 +15,21 @@
#pragma once
#include <string>
namespace infrt::host_context {
namespace infrt {
namespace host_context {
struct KernelRegistry;
} // namespace infrt::host_context
} // namespace host_context
} // namespace infrt
namespace infrt::kernel {
namespace infrt {
namespace kernel {
/**
* Register all the test kernels to registry.
*/
void RegisterTestKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
} // namespace kernel
} // namespace infrt
......@@ -18,7 +18,9 @@
#include <string>
#include <vector>
namespace infrt::paddle::cpp {
namespace infrt {
namespace paddle {
namespace cpp {
/*
* Compatible interfaces for all the different kinds of XXXDesc. All the XXXDesc
......@@ -226,4 +228,6 @@ class ProgramDescAPI {
virtual void SetVersion(int64_t version) = 0;
};
} // namespace infrt::paddle::cpp
} // namespace cpp
} // namespace paddle
} // namespace infrt
......@@ -22,7 +22,8 @@
#include "paddle/infrt/common/target.h"
#include "paddle/infrt/common/type.h"
namespace infrt::paddle {
namespace infrt {
namespace paddle {
int SizeOfType(framework_proto::VarType::Type type) {
using Type = framework_proto::VarType::Type;
......@@ -169,4 +170,5 @@ void LoadParam(const std::string &path, _Variable *out, const Target &target) {
LoadLoDTensor(fin, out, target);
}
} // namespace infrt::paddle
} // namespace paddle
} // namespace infrt
......@@ -25,7 +25,8 @@
#include "paddle/infrt/paddle/scope.h"
#include "paddle/infrt/paddle/tensor.h"
namespace infrt::paddle {
namespace infrt {
namespace paddle {
namespace framework_proto = ::paddle::framework::proto;
// Read a __model__ file.
......@@ -52,4 +53,5 @@ void TensorFromStream(
const common::Target& target = common::DefaultHostTarget());
void ReadBinaryFile(const std::string& filename, std::string* contents);
} // namespace infrt::paddle
} // namespace paddle
} // namespace infrt
......@@ -14,7 +14,9 @@
#include "paddle/infrt/paddle/pb/block_desc.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
template <>
framework_proto::VarDesc* BlockDesc::GetVar<framework_proto::VarDesc>(
......@@ -40,4 +42,6 @@ framework_proto::OpDesc* BlockDesc::AddOp<framework_proto::OpDesc>() {
return desc_->add_ops();
}
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -18,7 +18,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto;
......@@ -74,4 +76,6 @@ class BlockDesc : public cpp::BlockDescAPI {
framework_proto::BlockDesc* desc_; // not_own
};
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -14,7 +14,9 @@
#include "paddle/infrt/paddle/pb/op_desc.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
google::protobuf::internal::RepeatedPtrIterator<framework_proto::OpDesc_Attr>
FindAttr(framework_proto::OpDesc *desc, const std::string &name) {
......@@ -136,4 +138,6 @@ GET_ATTRS_IMPL(std::vector<std::string>, strings);
GET_ATTR_IMPL(std::string, s);
GET_ATTRS_IMPL(std::vector<int64_t>, longs);
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -19,7 +19,9 @@
#include "paddle/infrt/paddle/framework.pb.h"
#include "paddle/infrt/support/variant.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto;
......@@ -195,4 +197,6 @@ template <>
void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
const std::vector<int> &v);
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -17,7 +17,9 @@
#include <algorithm>
#include <limits>
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
template <>
framework_proto::BlockDesc* ProgramDesc::GetBlock<framework_proto::BlockDesc>(
......@@ -32,4 +34,6 @@ ProgramDesc::AddBlock<framework_proto::BlockDesc>() {
return desc_->add_blocks();
}
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -21,7 +21,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto;
class ProgramDesc : public cpp::ProgramDescAPI {
......@@ -58,4 +60,6 @@ class ProgramDesc : public cpp::ProgramDescAPI {
framework_proto::ProgramDesc *desc_; // not_own
};
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -19,7 +19,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
cpp::VarDescAPI::Type VarDesc::GetType() const {
auto type = desc_->type().type();
......@@ -364,4 +366,6 @@ VarDesc::mutable_tensor_descs() {
return std::vector<framework_proto::VarType::TensorDesc *>();
}
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt
......@@ -23,7 +23,9 @@
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace infrt {
namespace paddle {
namespace pb {
namespace framework_proto = ::paddle::framework::proto;
// Convert between std::vector and protobuf repeated fields.
......@@ -121,4 +123,6 @@ class VarDesc : public cpp::VarDescAPI {
framework_proto::VarDesc *desc_;
};
} // namespace infrt::paddle::pb
} // namespace pb
} // namespace paddle
} // namespace infrt