未验证 提交 ef4ef154 编写于 作者: 王明冬 提交者: GitHub

[infrt] rename pd dialect from mlir to infrt. (#40651)

* [infrt] rename pd dialect from mlir to infrt. test=develop

* [infrt] fix the kernel signature generator bug.
上级 081e4307
...@@ -46,10 +46,19 @@ int main(int argc, char **argv) { ...@@ -46,10 +46,19 @@ int main(int argc, char **argv) {
auto &kernel_factory = phi::KernelFactory::Instance(); auto &kernel_factory = phi::KernelFactory::Instance();
std::string kernel_signature_map_str{"{"}; std::string kernel_signature_map_str{"{"};
for (const auto &op_kernel_pair : kernel_factory.kernels()) { for (const auto &op_kernel_pair : kernel_factory.kernels()) {
if (kernel_signature_map.Has(op_kernel_pair.first)) { std::string op_name = op_kernel_pair.first;
const paddle::flat_hash_map<std::string, std::string> &kernel_name_map =
phi::OpUtilsMap::Instance().base_kernel_name_map();
for (auto &it : kernel_name_map) {
if (it.second == op_name) {
op_name = it.first;
break;
}
}
if (kernel_signature_map.Has(op_name)) {
kernel_signature_map_str = kernel_signature_map_str =
kernel_signature_map_str + "\"" + op_kernel_pair.first + "\":{"; kernel_signature_map_str + "\"" + op_kernel_pair.first + "\":{";
auto &args = kernel_signature_map.Get(op_kernel_pair.first).args; auto &args = kernel_signature_map.Get(op_name).args;
kernel_signature_map_str += "\"inputs\":["; kernel_signature_map_str += "\"inputs\":[";
auto inputs_ = std::get<0>(args); auto inputs_ = std::get<0>(args);
......
...@@ -33,13 +33,14 @@ void registerCinnDialects(mlir::DialectRegistry &registry) { // NOLINT ...@@ -33,13 +33,14 @@ void registerCinnDialects(mlir::DialectRegistry &registry) { // NOLINT
registry.insert<ts::TensorShapeDialect, registry.insert<ts::TensorShapeDialect,
InfrtDialect, InfrtDialect,
dt::DTDialect, dt::DTDialect,
mlir::pd::PaddleDialect, pd::PaddleDialect,
trt::TensorRTDialect
#ifdef INFRT_WITH_PHI #ifdef INFRT_WITH_PHI
,
phi::PHIDenseTensorDialect, phi::PHIDenseTensorDialect,
phi::PHICPUKernelDialect, phi::PHICPUKernelDialect,
phi::PHIGPUKernelDialect, phi::PHIGPUKernelDialect,
phi::PHIDialect, phi::PHIDialect
infrt::trt::TensorRTDialect
#endif #endif
>(); >();
} }
......
...@@ -17,7 +17,7 @@ def Paddle_Dialect : Dialect { ...@@ -17,7 +17,7 @@ def Paddle_Dialect : Dialect {
This dialect contains the PaddlePaddle operators. This dialect contains the PaddlePaddle operators.
}]; }];
let hasConstantMaterializer = 1; let hasConstantMaterializer = 1;
let cppNamespace = "mlir::pd"; let cppNamespace = "infrt::pd";
} }
class PD_Op<string mnemonic, list<OpTrait> traits = []> : class PD_Op<string mnemonic, list<OpTrait> traits = []> :
...@@ -25,7 +25,7 @@ class PD_Op<string mnemonic, list<OpTrait> traits = []> : ...@@ -25,7 +25,7 @@ class PD_Op<string mnemonic, list<OpTrait> traits = []> :
class PD_PaddleAttr <string name, string description> : class PD_PaddleAttr <string name, string description> :
Attr<CPred<"$_self.isa<mlir::pd::" # name # "Attr>()">, Attr<CPred<"$_self.isa<infrt::pd::" # name # "Attr>()">,
"PaddlePaddle " # description # " attribute">; "PaddlePaddle " # description # " attribute">;
...@@ -33,12 +33,12 @@ class PD_PaddleAttr <string name, string description> : ...@@ -33,12 +33,12 @@ class PD_PaddleAttr <string name, string description> :
// PaddlePaddle type definitions // PaddlePaddle type definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def PD_PDDialectType : Type<CPred<"$_self.isa<mlir::pd::PDType>()">, "PaddlePaddle type">; def PD_PDDialectType : Type<CPred<"$_self.isa<infrt::pd::PDType>()">, "PaddlePaddle type">;
class PD_PaddleType <string name, string description> : class PD_PaddleType <string name, string description> :
Type<CPred<"$_self.isa<mlir::pd::" # name #"Type>()">, Type<CPred<"$_self.isa<infrt::pd::" # name #"Type>()">,
"Paddle " # description # " type">, "Paddle " # description # " type">,
BuildableType<"getType<mlir::pd::" # name # "Type>()">; BuildableType<"getType<infrt::pd::" # name # "Type>()">;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Integer types // Integer types
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd/ir/pd_extra_ops.cpp.inc" // NOLINT #include "paddle/infrt/dialect/pd/ir/pd_extra_ops.cpp.inc" // NOLINT
namespace mlir { namespace infrt {
namespace pd { namespace pd {
void PaddleDialect::initialize() { void PaddleDialect::initialize() {
addOperations< addOperations<
...@@ -43,33 +43,34 @@ mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder, ...@@ -43,33 +43,34 @@ mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder,
return builder.create<ConstantOp>(loc, value); return builder.create<ConstantOp>(loc, value);
} }
void ConstantOp::build(OpBuilder &builder, void ConstantOp::build(mlir::OpBuilder &builder,
OperationState &state, mlir::OperationState &state,
Attribute value) { mlir::Attribute value) {
if (auto elem_attr = value.dyn_cast<ElementsAttr>()) { if (auto elem_attr = value.dyn_cast<mlir::ElementsAttr>()) {
return ConstantOp::build(builder, state, elem_attr); return ConstantOp::build(builder, state, elem_attr);
} else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) { } else if (value.isa<mlir::BoolAttr, mlir::FloatAttr, mlir::IntegerAttr>()) {
ShapedType type = RankedTensorType::get(/*shape=*/{}, value.getType()); mlir::ShapedType type =
state.addAttribute("value", DenseElementsAttr::get(type, value)); mlir::RankedTensorType::get(/*shape=*/{}, value.getType());
state.addAttribute("value", mlir::DenseElementsAttr::get(type, value));
state.addTypes(type); state.addTypes(type);
return; return;
} }
llvm_unreachable("unsupported attribute type for building pd.constant"); llvm_unreachable("unsupported attribute type for building pd.constant");
} }
LogicalResult ConstantOp::inferReturnTypes( mlir::LogicalResult ConstantOp::inferReturnTypes(
MLIRContext *context, mlir::MLIRContext *context,
Optional<Location> location, mlir::Optional<mlir::Location> location,
ValueRange operands, mlir::ValueRange operands,
DictionaryAttr attributes, mlir::DictionaryAttr attributes,
RegionRange regions, mlir::RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) { llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(attributes.get("value").getType()); inferredReturnTypes.push_back(attributes.get("value").getType());
return success(); return mlir::success();
} }
mlir::OpFoldResult ConstantOp::fold( mlir::OpFoldResult ConstantOp::fold(
::llvm::ArrayRef<mlir::Attribute> operands) { ::llvm::ArrayRef<mlir::Attribute> operands) {
return value(); return value();
} }
} // namespace pd } // namespace pd
} // namespace mlir } // namespace infrt
...@@ -55,8 +55,8 @@ bool reverseDfs(std::vector<mlir::Operation *> source, ...@@ -55,8 +55,8 @@ bool reverseDfs(std::vector<mlir::Operation *> source,
// merge the first&second graph op to a new graph op. // merge the first&second graph op to a new graph op.
void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
mlir::pd::GraphOp first, infrt::pd::GraphOp first,
mlir::pd::GraphOp second) { infrt::pd::GraphOp second) {
// comput inputs and outputs // comput inputs and outputs
::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs; ::llvm::SmallVector<mlir::Value, 4> inputs(first.getOperands()), outputs;
for (mlir::Value input : second.getOperands()) { for (mlir::Value input : second.getOperands()) {
...@@ -85,7 +85,7 @@ void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT ...@@ -85,7 +85,7 @@ void mergeTwoAdjacentGraphOp(mlir::OpBuilder &builder, // NOLINT
// create the new graph op // create the new graph op
builder.setInsertionPoint(first); builder.setInsertionPoint(first);
auto loc = first.getLoc(); auto loc = first.getLoc();
auto graph_op = builder.create<mlir::pd::GraphOp>(loc, return_types, inputs); auto graph_op = builder.create<infrt::pd::GraphOp>(loc, return_types, inputs);
mlir::Block *block = new mlir::Block; mlir::Block *block = new mlir::Block;
auto copy_range = second.getBody()->without_terminator(); auto copy_range = second.getBody()->without_terminator();
block->getOperations().splice(block->begin(), block->getOperations().splice(block->begin(),
...@@ -150,13 +150,13 @@ void TRTGraphFusePass::runOnFunction() { ...@@ -150,13 +150,13 @@ void TRTGraphFusePass::runOnFunction() {
do { do {
changed = false; changed = false;
for (auto &op : body) { for (auto &op : body) {
mlir::pd::GraphOp graph_op = infrt::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op); ::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(&op);
if (nullptr == graph_op) continue; if (nullptr == graph_op) continue;
for (auto user_op : op.getUsers()) { for (auto user_op : op.getUsers()) {
mlir::pd::GraphOp user_graph_op = infrt::pd::GraphOp user_graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(user_op); ::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(user_op);
if (nullptr == user_graph_op) continue; if (nullptr == user_graph_op) continue;
// get all dst input nodes except src. // get all dst input nodes except src.
std::vector<mlir::Operation *> source_nodes; std::vector<mlir::Operation *> source_nodes;
......
...@@ -21,18 +21,18 @@ namespace infrt { ...@@ -21,18 +21,18 @@ namespace infrt {
namespace trt { namespace trt {
// Implementation of the trtGraphSplitPass。 // Implementation of the trtGraphSplitPass。
void TRTGraphSplitPass::runOnFunction() { void TRTGraphSplitPass::runOnFunction() {
std::vector<mlir::pd::GraphOp> worklist; std::vector<infrt::pd::GraphOp> worklist;
mlir::Block& block = getFunction().front(); mlir::Block& block = getFunction().front();
for (auto& op : block) { for (auto& op : block) {
mlir::pd::GraphOp graph_op = infrt::pd::GraphOp graph_op =
::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(&op); ::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(&op);
if (nullptr != graph_op && if (nullptr != graph_op &&
graph_op.getBody()->getOperations().size() <= min_subgraph_size_) { graph_op.getBody()->getOperations().size() <= min_subgraph_size_) {
worklist.push_back(graph_op); worklist.push_back(graph_op);
} }
} }
while (!worklist.empty()) { while (!worklist.empty()) {
mlir::pd::GraphOp graph_op = worklist.back(); infrt::pd::GraphOp graph_op = worklist.back();
worklist.pop_back(); worklist.pop_back();
mlir::Block* body = graph_op.getBody(); mlir::Block* body = graph_op.getBody();
auto return_op = body->getTerminator(); auto return_op = body->getTerminator();
......
...@@ -27,7 +27,7 @@ struct PD2TRT_GraphLower : public ::mlir::RewritePattern { ...@@ -27,7 +27,7 @@ struct PD2TRT_GraphLower : public ::mlir::RewritePattern {
: ::mlir::RewritePattern("pd.graph", 1, context, {"trt.create_engine"}) {} : ::mlir::RewritePattern("pd.graph", 1, context, {"trt.create_engine"}) {}
::mlir::LogicalResult matchAndRewrite( ::mlir::LogicalResult matchAndRewrite(
::mlir::Operation *op, ::mlir::PatternRewriter &rewriter) const override { ::mlir::Operation *op, ::mlir::PatternRewriter &rewriter) const override {
auto casted_op = ::llvm::dyn_cast<mlir::pd::GraphOp>(op); auto casted_op = ::llvm::dyn_cast<infrt::pd::GraphOp>(op);
::mlir::Operation::operand_range inputs = casted_op.inputs(); ::mlir::Operation::operand_range inputs = casted_op.inputs();
auto ods_loc = rewriter.getFusedLoc(op->getLoc()); auto ods_loc = rewriter.getFusedLoc(op->getLoc());
CreateEngineOp create_engine_op; CreateEngineOp create_engine_op;
......
...@@ -35,13 +35,13 @@ void TRTOpTellerPass::runOnFunction() { ...@@ -35,13 +35,13 @@ void TRTOpTellerPass::runOnFunction() {
auto *op = worklist.back(); auto *op = worklist.back();
worklist.pop_back(); worklist.pop_back();
if (op == nullptr) continue; if (op == nullptr) continue;
if (::llvm::dyn_cast_or_null<mlir::pd::FeedOp>(op)) continue; if (::llvm::dyn_cast_or_null<infrt::pd::FeedOp>(op)) continue;
if (::llvm::dyn_cast_or_null<mlir::pd::FetchOp>(op)) continue; if (::llvm::dyn_cast_or_null<infrt::pd::FetchOp>(op)) continue;
if (::llvm::dyn_cast_or_null<mlir::pd::GraphOp>(op)) continue; if (::llvm::dyn_cast_or_null<infrt::pd::GraphOp>(op)) continue;
if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue; if (::llvm::dyn_cast_or_null<::infrt::ReturnOp>(op)) continue;
builder.setInsertionPoint(op); builder.setInsertionPoint(op);
auto loc = getFunction().getLoc(); auto loc = getFunction().getLoc();
auto graph_op = builder.create<mlir::pd::GraphOp>( auto graph_op = builder.create<infrt::pd::GraphOp>(
loc, op->getResultTypes(), op->getOperands()); loc, op->getResultTypes(), op->getOperands());
::llvm::SmallVector<mlir::Value, 4> tblgen_repl_values; ::llvm::SmallVector<mlir::Value, 4> tblgen_repl_values;
......
...@@ -22,7 +22,7 @@ MLIRModelGenImpl::MLIRModelGenImpl() ...@@ -22,7 +22,7 @@ MLIRModelGenImpl::MLIRModelGenImpl()
context_->getOrLoadDialect<mlir::StandardOpsDialect>(); context_->getOrLoadDialect<mlir::StandardOpsDialect>();
context_->getOrLoadDialect<infrt::ts::TensorShapeDialect>(); context_->getOrLoadDialect<infrt::ts::TensorShapeDialect>();
context_->getOrLoadDialect<infrt::dt::DTDialect>(); context_->getOrLoadDialect<infrt::dt::DTDialect>();
context_->getOrLoadDialect<mlir::pd::PaddleDialect>(); context_->getOrLoadDialect<infrt::pd::PaddleDialect>();
context_->getOrLoadDialect<::infrt::InfrtDialect>(); context_->getOrLoadDialect<::infrt::InfrtDialect>();
module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(context_)); module_ = mlir::ModuleOp::create(mlir::UnknownLoc::get(context_));
} }
......
...@@ -72,7 +72,7 @@ TEST(ABS_MODEL, convert_and_execute) { ...@@ -72,7 +72,7 @@ TEST(ABS_MODEL, convert_and_execute) {
context->getOrLoadDialect<infrt::ts::TensorShapeDialect>(); context->getOrLoadDialect<infrt::ts::TensorShapeDialect>();
context->getOrLoadDialect<infrt::InfrtDialect>(); context->getOrLoadDialect<infrt::InfrtDialect>();
context->getOrLoadDialect<infrt::dt::DTDialect>(); context->getOrLoadDialect<infrt::dt::DTDialect>();
context->getOrLoadDialect<mlir::pd::PaddleDialect>(); context->getOrLoadDialect<infrt::pd::PaddleDialect>();
context->getOrLoadDialect<infrt::phi::PHIDenseTensorDialect>(); context->getOrLoadDialect<infrt::phi::PHIDenseTensorDialect>();
context->getOrLoadDialect<infrt::phi::PHICPUKernelDialect>(); context->getOrLoadDialect<infrt::phi::PHICPUKernelDialect>();
......
...@@ -42,6 +42,6 @@ def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInte ...@@ -42,6 +42,6 @@ def PD_ConstantOp : PD_Op<"constant", [NoSideEffect, ConstantLike, DeclareOpInte
let hasFolder = 1; let hasFolder = 1;
let builders = [ let builders = [
OpBuilder<(ins "Attribute":$value)>, OpBuilder<(ins "mlir::Attribute":$value)>,
]; ];
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册