未验证 提交 517b1a7c 编写于 作者: 王明冬 提交者: GitHub

[infrt] add parse for infrt.dense_tensor_type. test=develop (#40592)

上级 84e17a31
...@@ -134,6 +134,10 @@ mlir::Type InfrtDialect::parseType(::mlir::DialectAsmParser &parser) const { ...@@ -134,6 +134,10 @@ mlir::Type InfrtDialect::parseType(::mlir::DialectAsmParser &parser) const {
return DenseTensorType::get( return DenseTensorType::get(
parser.getContext(), *targetType, *precisionType, *layoutType); parser.getContext(), *targetType, *precisionType, *layoutType);
} }
if (keyword == "dense_tensor_map") {
return DenseTensorMapType::get(parser.getContext());
}
// Todo: parse other type // Todo: parse other type
return mlir::Type(); return mlir::Type();
} }
...@@ -156,7 +160,7 @@ void InfrtDialect::printType(::mlir::Type type, ...@@ -156,7 +160,7 @@ void InfrtDialect::printType(::mlir::Type type,
} }
// print DenseTensorType, for example: !infrt.dense_tensor<CPU, FP32, NCHW> // print DenseTensorType, for example: !infrt.dense_tensor<CPU, FP32, NCHW>
if (type.isa<infrt::DenseTensorType>()) { if (type.isa<DenseTensorType>()) {
auto dense_tensor_type = type.cast<infrt::DenseTensorType>(); auto dense_tensor_type = type.cast<infrt::DenseTensorType>();
os << "dense_tensor<" << dense_tensor_type.getTarget() << ", " os << "dense_tensor<" << dense_tensor_type.getTarget() << ", "
<< dense_tensor_type.getPrecision() << ", " << dense_tensor_type.getPrecision() << ", "
...@@ -164,6 +168,12 @@ void InfrtDialect::printType(::mlir::Type type, ...@@ -164,6 +168,12 @@ void InfrtDialect::printType(::mlir::Type type,
return; return;
} }
// print DenseTensorType, for example: !infrt.dense_tensor<CPU, FP32, NCHW>
if (type.isa<DenseTensorMapType>()) {
os << "dense_tensor_map";
return;
}
llvm_unreachable("unknown infrt type."); llvm_unreachable("unknown infrt type.");
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "paddle/infrt/common/global.h" #include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.h" #include "paddle/infrt/dialect/infrt/pass/infrt_op_fuse_pass.h"
#include "paddle/infrt/dialect/mlir_loader.h" #include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/phi/pass/phi_op_cvt_pass.h" #include "paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h"
int main(int argc, char** argv) { int main(int argc, char** argv) {
static llvm::cl::opt<std::string> input_file( static llvm::cl::opt<std::string> input_file(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册