From a78ca1cf87079a04591b536237c0415ba8526679 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Sun, 10 Apr 2022 12:40:37 +0800
Subject: [PATCH] predictor support trt (#41556)

---
 paddle/infrt/api/CMakeLists.txt               |  2 +
 paddle/infrt/api/infrt_api.cc                 | 45 ++++++++++---
 paddle/infrt/api/infrt_api.h                  |  8 +++
 paddle/infrt/api/infrt_api_test.cc.in         | 43 +++++++++++++
 paddle/infrt/backends/tensorrt/trt_utils.h    |  3 +-
 .../infrt/dialect/phi/ir/infrt_phi_tensor.td  |  3 +-
 paddle/infrt/dialect/tensorrt/trt_exec.cc     |  2 +-
 .../dialect/tensorrt/trt_graph_fuse_pass.cc   |  5 ++
 .../dialect/tensorrt/trt_graph_fuse_pass.h    |  3 +
 .../dialect/tensorrt/trt_graph_split_pass.cc  |  5 ++
 .../dialect/tensorrt/trt_graph_split_pass.h   |  3 +
 .../dialect/tensorrt/trt_op_converter_pass.cc |  4 ++
 .../dialect/tensorrt/trt_op_converter_pass.h  |  3 +
 .../dialect/tensorrt/trt_op_teller_pass.cc    |  5 ++
 .../dialect/tensorrt/trt_op_teller_pass.h     |  3 +
 .../dialect/tensorrt/trt_type_convert_pass.cc |  2 +-
 .../dialect/tensorrt/trt_type_convert_pass.h  |  2 +-
 paddle/infrt/host_context/paddle_mlir.cc      | 63 +++++++++++++++----
 paddle/infrt/host_context/paddle_mlir.h       | 20 ++++--
 paddle/infrt/kernel/phi/registry.cc           |  2 +-
 20 files changed, 193 insertions(+), 33 deletions(-)

diff --git a/paddle/infrt/api/CMakeLists.txt b/paddle/infrt/api/CMakeLists.txt
index 27d736cfdf7..6d4604edee6 100644
--- a/paddle/infrt/api/CMakeLists.txt
+++ b/paddle/infrt/api/CMakeLists.txt
@@ -7,3 +7,5 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/infrt_api_test.cc.in ${CMAKE_CURRENT
 
 # Disable temporarily for the external-kernel's mkldnn is outdate
 cc_test_tiny(test_infrt_api SRCS infrt_api_test.cc DEPS infrt ${MLIR_IR_LIBS})
+# TODO(inference): remove after optimizing weight unfold.
+set_tests_properties(test_infrt_api PROPERTIES TIMEOUT 200)
diff --git a/paddle/infrt/api/infrt_api.cc b/paddle/infrt/api/infrt_api.cc
index 2e8b64f768f..8b4b14a3ca0 100644
--- a/paddle/infrt/api/infrt_api.cc
+++ b/paddle/infrt/api/infrt_api.cc
@@ -17,12 +17,14 @@
 #include
 #include
 #include
+#include
 #include
+#include <mlir/Pass/PassManager.h>
+#include <mlir/Transforms/Passes.h>
 #include
 #include
-#include "mlir/Pass/PassManager.h"
 #include "paddle/infrt/backends/host/phi_allocator.h"
@@ -48,8 +50,16 @@
 #include "paddle/infrt/kernel/test_kernels.h"
 #include "paddle/infrt/tensor/tensor_map.h"
 
+#include "paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.h"
+
 #if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT)
 #include "paddle/infrt/kernel/tensorrt/registry.h"
+
+#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
+#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
+#include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h"
+#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
+#include "paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h"
 #endif
 
 using namespace infrt::host_context;  // NOLINT
@@ -233,17 +243,34 @@ int InfRtPredictor::Init(const InfRtConfig& config) {
 #endif  // INFRT_WITH_GPU && INFRT_WITH_TRT
 #endif
 
-  auto module_op = impl_->module_gen_.ImportPaddleModel(config.model_dir(),
-                                                        config.param_dir());
+  mlir::ModuleOp module_op;
+  if (config.tensorrt_enabled()) {
+    module_op = impl_->module_gen_.ImportPaddleModel(
+        config.model_dir(), config.param_dir(), false);
+  } else {
+    module_op = impl_->module_gen_.ImportPaddleModel(config.model_dir(),
+                                                     config.param_dir());
+  }
 
   context->loadAllAvailableDialects();
   ::mlir::PassManager pm(context);
-  ::mlir::OpPassManager& phi_pass_manager = pm.nest<::mlir::FuncOp>();
-  std::vector<::infrt::Place> valid_places = {{::infrt::TargetType::CPU,
-                                               ::infrt::PrecisionType::FLOAT32,
-                                               ::infrt::LayoutType::NCHW}};
-  phi_pass_manager.addPass(CreatePhiOpCvtPass(valid_places));
-  phi_pass_manager.addPass(CreateInfrtOpFusePass());
+  ::mlir::OpPassManager& pass_manager = pm.nest<::mlir::FuncOp>();
+  if (config.tensorrt_enabled()) {
+    pass_manager.addPass(::infrt::CreateInfrtWeightsUnfoldPass());
+    pass_manager.addPass(::infrt::trt::CreateTrtOpTellerPass());
+    pass_manager.addPass(::infrt::trt::CreateTrtGraphFusePass());
+    pass_manager.addPass(::infrt::trt::CreateTrtGraphSplitPass(1));
+    pass_manager.addPass(::infrt::trt::CreateTrtOpConverterPass());
+    pass_manager.addPass(::infrt::trt::CreateTrtTypeConvertPass());
+    pass_manager.addPass(::mlir::createCanonicalizerPass());
+  } else {
+    std::vector<::infrt::Place> valid_places = {
+        {::infrt::TargetType::CPU,
+         ::infrt::PrecisionType::FLOAT32,
+         ::infrt::LayoutType::NCHW}};
+    pass_manager.addPass(CreatePhiOpCvtPass(valid_places));
+    pass_manager.addPass(CreateInfrtOpFusePass());
+  }
   if (mlir::failed(pm.run(module_op))) {
     std::cout << "\npass failed!\n" << std::endl;
     return 4;
diff --git a/paddle/infrt/api/infrt_api.h b/paddle/infrt/api/infrt_api.h
index cf14cab3c06..231f496bb89 100644
--- a/paddle/infrt/api/infrt_api.h
+++ b/paddle/infrt/api/infrt_api.h
@@ -26,6 +26,9 @@ class InfRtConfig {
   std::string param_dir_;
   std::vector<std::string> shared_libs_;
 
+  // TODO(wilber): Design an easy-to-use interface.
+  bool tensorrt_enabled_{false};
+
  public:
   InfRtConfig() = default;
   void set_model_dir(const std::string& model_dir) { model_dir_ = model_dir; }
@@ -39,6 +42,11 @@ class InfRtConfig {
   }
   const std::vector<std::string>& shared_libs() const { return shared_libs_; }
 
+  // TODO(wilber): Design an easy-to-use interface.
+  void enable_tensorrt() { tensorrt_enabled_ = true; }
+  void disable_tensorrt() { tensorrt_enabled_ = false; }
+  bool tensorrt_enabled() const { return tensorrt_enabled_; }
+
   virtual ~InfRtConfig() = default;
 };
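
The InfRtConfig toggle above is the whole user-facing surface of this patch. A
minimal end-to-end sketch of the intended call sequence (condensed from the
test added to infrt_api_test.cc.in below; the model paths are placeholders):

    #include <memory>
    #include "paddle/infrt/api/infrt_api.h"

    int main() {
      infrt::InfRtConfig config;
      config.enable_tensorrt();  // route Init() through the TRT pass pipeline
      config.set_model_dir("/path/to/model.pdmodel");    // placeholder
      config.set_param_dir("/path/to/model.pdiparams");  // placeholder
      std::unique_ptr<infrt::InfRtPredictor> predictor =
          infrt::CreateInfRtPredictor(config);
      // Input setup (GetInput/Resize/AllocateFrom) omitted; see the test below.
      predictor->Run();
      return 0;
    }
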
diff --git a/paddle/infrt/api/infrt_api_test.cc.in b/paddle/infrt/api/infrt_api_test.cc.in
index 6323b6a540a..13635ddaaab 100644
--- a/paddle/infrt/api/infrt_api_test.cc.in
+++ b/paddle/infrt/api/infrt_api_test.cc.in
@@ -57,4 +57,47 @@ TEST(InfRtPredictor, predictor) {
   ASSERT_EQ(output->dims(), ::phi::DDim({16, 10}));
 }
 
+#ifdef INFRT_WITH_TRT
+TEST(InfRtPredictor, trt_predictor) {
+  std::vector<std::string> shared_libs;
+
+  InfRtConfig config;
+  config.enable_tensorrt();
+
+  config.set_model_dir("@CMAKE_BINARY_DIR@/models/resnet50/model.pdmodel");
+  config.set_param_dir("@CMAKE_BINARY_DIR@/models/resnet50/model.pdiparams");
+
+  std::unique_ptr<InfRtPredictor> predictor = CreateInfRtPredictor(config);
+
+  ::infrt::backends::CpuPhiAllocator cpu_allocator;
+  ::phi::DenseTensor* input = predictor->GetInput(0);
+  input->Resize({2, 3, 256, 256});
+  input->AllocateFrom(&cpu_allocator, ::phi::DataType::FLOAT32);
+  auto* input_data = reinterpret_cast<float*>(input->data());
+  for (int i = 0; i < input->numel(); i++) input_data[i] = 1.0;
+  predictor->Run();
+
+  // get output tensor
+  auto* output = predictor->GetOutput(0);
+
+  ASSERT_EQ(output->dims(), ::phi::DDim({2, 1000}));
+  const std::vector<float> true_vals {
+    -3.319006264209747314e-01, -1.418896913528442383e+00,
+    -6.934890151023864746e-01, -1.498023152351379395e+00,
+    3.078042864799499512e-01, -1.340998053550720215e+00,
+    3.508620023727416992e+00, 2.274388313293457031e+00,
+    -1.321727275848388672e+00, -8.888689428567886353e-02,
+    -3.319006264209747314e-01, -1.418896913528442383e+00,
+    -6.934890151023864746e-01, -1.498023152351379395e+00,
+    3.078042864799499512e-01, -1.340998053550720215e+00,
+    3.508620023727416992e+00, 2.274388313293457031e+00,
+    -1.321727275848388672e+00, -8.888689428567886353e-02
+  };
+
+  for (size_t i = 0; i < true_vals.size(); i++) {
+    CHECK_NEAR(output->data<float>()[i * 100], true_vals[i], 1e-5);
+  }
+}
+#endif
+
 } // namespace infrt
diff --git a/paddle/infrt/backends/tensorrt/trt_utils.h b/paddle/infrt/backends/tensorrt/trt_utils.h
index c23d4608bb3..b2d5659fd25 100644
--- a/paddle/infrt/backends/tensorrt/trt_utils.h
+++ b/paddle/infrt/backends/tensorrt/trt_utils.h
@@ -50,7 +50,8 @@ inline nvinfer1::Dims VecToDims(const std::vector<int>& vec) {
     assert(false);
   }
   // Pick first nvinfer1::Dims::MAX_DIMS elements
-  nvinfer1::Dims dims{std::min(static_cast<int>(vec.size()), limit), {}};
+  nvinfer1::Dims dims;
+  dims.nbDims = std::min(static_cast<int>(vec.size()), limit);
   std::copy_n(vec.begin(), dims.nbDims, std::begin(dims.d));
   return dims;
 }
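
One behavioral nuance of the trt_utils.h change above: the old aggregate
initializer zero-filled all of dims.d, while the member-wise rewrite leaves
entries past nbDims indeterminate, so callers should read only the first
nbDims entries. A usage sketch (the namespace qualification is inferred from
the file path, not shown in this hunk):

    #include <cassert>
    #include <vector>
    #include "paddle/infrt/backends/tensorrt/trt_utils.h"

    int main() {
      std::vector<int> shape = {2, 3, 256, 256};
      nvinfer1::Dims dims = infrt::backends::tensorrt::VecToDims(shape);
      assert(dims.nbDims == 4);
      assert(dims.d[3] == 256);  // only d[0] .. d[nbDims-1] are meaningful
      return 0;
    }
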
diff --git a/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td b/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td
index c4707c367bc..2078ebb1442 100644
--- a/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td
+++ b/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td
@@ -34,7 +34,8 @@ def CreateHostInitedDenseTensorOp : PDT_Op<"create_host_inited_dense_tensor.f32"
     I64ArrayAttr:$dims,
     LayoutAttr:$layout,
     I64ArrayAttr:$lod,
-    F32ArrayAttr:$values
+    F32ArrayAttr:$values,
+    DefaultValuedAttr<BoolAttr, "false">:$run_once
   );
   let results = (outs DenseTensor:$output);
 }
diff --git a/paddle/infrt/dialect/tensorrt/trt_exec.cc b/paddle/infrt/dialect/tensorrt/trt_exec.cc
index 837ca209374..2682a744bb0 100644
--- a/paddle/infrt/dialect/tensorrt/trt_exec.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_exec.cc
@@ -81,7 +81,7 @@ int main(int argc, char** argv) {
   trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphFusePass>());
   trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphSplitPass>(1));
   trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTOpConverterPass>());
-  trt_pass_manager.addPass(infrt::trt::createTrtTypeConvertPass());
+  trt_pass_manager.addPass(infrt::trt::CreateTrtTypeConvertPass());
   trt_pass_manager.addPass(::mlir::createCanonicalizerPass());
   if (mlir::failed(pm.run(*module))) {
     std::cout << "\npass failed!\n" << std::endl;
diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
index 55964b77e21..bbe9a76e87b 100644
--- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
@@ -181,5 +181,10 @@ void TRTGraphFusePass::runOnFunction() {
     // TODO(wilber): Implement a toposort for efficiency.
     // topoSortBlock(body);
   }
+
+std::unique_ptr<mlir::Pass> CreateTrtGraphFusePass() {
+  return std::make_unique<TRTGraphFusePass>();
+}
+
 }  // namespace trt
 }  // namespace infrt
diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
index 4c721476230..515e73df854 100644
--- a/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
+++ b/paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
@@ -17,6 +17,9 @@
 
 namespace infrt {
 namespace trt {
+
+std::unique_ptr<mlir::Pass> CreateTrtGraphFusePass();
+
 /*
  * trtGraphFusePass.
  *
diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
index 2136f19fd1a..d5ce871edd1 100644
--- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
@@ -44,5 +44,10 @@ void TRTGraphSplitPass::runOnFunction() {
     graph_op.erase();
   }
 }
+
+std::unique_ptr<mlir::Pass> CreateTrtGraphSplitPass(size_t min_subgraph_size) {
+  return std::make_unique<TRTGraphSplitPass>(min_subgraph_size);
+}
+
 }  // namespace trt
 }  // namespace infrt
diff --git a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
index a71b9cb6536..fa101a8db02 100644
--- a/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
+++ b/paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
@@ -17,6 +17,9 @@
 
 namespace infrt {
 namespace trt {
+
+std::unique_ptr<mlir::Pass> CreateTrtGraphSplitPass(size_t min_subgraph_size);
+
 /*
  * trtGraphSplitPass.
  *
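
The Create*Pass() factories added in the four files above (and repeated for
the converter, teller, and type-convert passes below) follow the usual MLIR
idiom: callers build the pipeline against a plain std::unique_ptr<mlir::Pass>
declaration and never see the pass class itself, which is what lets
infrt_api.cc assemble the TRT pipeline from headers alone. A generic sketch
with a hypothetical MyPass (not part of this patch):

    #include <memory>
    #include <mlir/Pass/Pass.h>

    // Header: only the opaque factory is exported.
    std::unique_ptr<mlir::Pass> CreateMyPass();

    // .cc file: the concrete pass stays private to the translation unit.
    namespace {
    struct MyPass : public mlir::PassWrapper<MyPass, mlir::FunctionPass> {
      void runOnFunction() override { /* transformation body */ }
    };
    }  // namespace

    std::unique_ptr<mlir::Pass> CreateMyPass() {
      return std::make_unique<MyPass>();
    }
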
diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
index e40bbd67c0b..6776f01e36d 100644
--- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
@@ -260,5 +260,9 @@ void TRTOpConverterPass::runOnOperation() {
     signalPassFailure();
 }
 
+std::unique_ptr<mlir::Pass> CreateTrtOpConverterPass() {
+  return std::make_unique<TRTOpConverterPass>();
+}
+
 }  // namespace trt
 }  // namespace infrt
diff --git a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h
index 685686493c9..84bc7194636 100644
--- a/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h
+++ b/paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h
@@ -20,6 +20,9 @@
 
 namespace infrt {
 namespace trt {
+
+std::unique_ptr<mlir::Pass> CreateTrtOpConverterPass();
+
 /*
  * trtOpConverterPass.
  *
diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
index 77c22c12854..d7b917385cf 100644
--- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
@@ -58,5 +58,10 @@ void TRTOpTellerPass::runOnFunction() {
       builder.create<::infrt::ReturnOp>(loc, op->getResults());
   }
 }
+
+std::unique_ptr<mlir::Pass> CreateTrtOpTellerPass() {
+  return std::make_unique<TRTOpTellerPass>();
+}
+
 }  // namespace trt
 }  // namespace infrt
diff --git a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
index 47375d838a9..566c5a45da0 100644
--- a/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
+++ b/paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
@@ -17,6 +17,9 @@
 
 namespace infrt {
 namespace trt {
+
+std::unique_ptr<mlir::Pass> CreateTrtOpTellerPass();
+
 /*
  * trtOpTellerPass.
  *
diff --git a/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.cc b/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.cc
index 0ed79c79db6..35c81d02301 100644
--- a/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.cc
+++ b/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.cc
@@ -175,7 +175,7 @@ void TrtTypeConvertPass::runOnFunction() {
 
 namespace infrt {
 namespace trt {
 
-std::unique_ptr<mlir::Pass> createTrtTypeConvertPass() {
+std::unique_ptr<mlir::Pass> CreateTrtTypeConvertPass() {
   return std::make_unique<TrtTypeConvertPass>();
 }
 
diff --git a/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h b/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h
index fbc30cdbeb7..68a15696b3e 100644
--- a/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h
+++ b/paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h
@@ -19,7 +19,7 @@
 
 namespace infrt {
 namespace trt {
 
-std::unique_ptr<mlir::Pass> createTrtTypeConvertPass();
+std::unique_ptr<mlir::Pass> CreateTrtTypeConvertPass();
 
 }  // namespace trt
 }  // namespace infrt
diff --git a/paddle/infrt/host_context/paddle_mlir.cc b/paddle/infrt/host_context/paddle_mlir.cc
index 8b7bbe13260..0264920a600 100644
--- a/paddle/infrt/host_context/paddle_mlir.cc
+++ b/paddle/infrt/host_context/paddle_mlir.cc
@@ -15,11 +15,13 @@
 #include "paddle/infrt/host_context/paddle_mlir.h"
 
 #include
+#include
 
 #include "paddle/infrt/dialect/infrt/ir/basic_kernels.h"
 #include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
 #include "paddle/infrt/dialect/pd/common/pd_ops_info.h"
 #include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
+#include "paddle/infrt/dialect/phi/ir/phi_base.h"
 
 MLIRModelGenImpl::MLIRModelGenImpl()
     : context_(infrt::Global::getMLIRContext()), builder_(context_) {
@@ -35,32 +37,40 @@ MLIRModelGenImpl::MLIRModelGenImpl()
 infrt::paddle::framework_proto::ProgramDesc MLIRModelGenImpl::ParsePaddleModel(
     const std::string &model_file) {
+  model_file_ = model_file;
   infrt::paddle::framework_proto::ProgramDesc program_proto =
       *infrt::paddle::LoadProgram(model_file);
   return program_proto;
 }
 
-mlir::ModuleOp MLIRModelGenImpl::ImportPaddleModel(
-    const std::string &model_dir) {
+mlir::ModuleOp MLIRModelGenImpl::ImportPaddleModel(const std::string &model_dir,
+                                                   bool arg_has_map) {
+  model_dir_ = model_dir;
   infrt::paddle::framework_proto::ProgramDesc program_proto =
       ParsePaddleModel(model_dir + "/__model__");
-  return ImportPaddleModel(program_proto);
+  return ImportPaddleModel(program_proto, arg_has_map);
 }
 
 mlir::ModuleOp MLIRModelGenImpl::ImportPaddleModel(
-    const std::string &model_file, const std::string &param_file) {
+    const std::string &model_file,
+    const std::string &param_file,
+    bool arg_has_map) {
+  model_file_ = model_file;
+  params_file_ = param_file;
   infrt::paddle::framework_proto::ProgramDesc program_proto =
       ParsePaddleModel(model_file);
-  return ImportPaddleModel(program_proto);
+  return ImportPaddleModel(program_proto, arg_has_map);
 }
 
 mlir::ModuleOp MLIRModelGenImpl::ImportPaddleModel(
-    const infrt::paddle::framework_proto::ProgramDesc &program) {
+    const infrt::paddle::framework_proto::ProgramDesc &program,
+    bool arg_has_map) {
   main_block_ = program.blocks(0);
-  llvm::SmallVector<mlir::Type, 4> operandTypes = GetModelInputsType(program);
+  llvm::SmallVector<mlir::Type, 4> operandTypes =
+      GetModelInputsType(program, arg_has_map);
   llvm::SmallVector<mlir::Type, 4> resultTypes = GetModelOutputsType(program);
   mlir::FuncOp mainFunc = UpdateModelModule(operandTypes, resultTypes);
-  UpdateModelParams(program, &mainFunc);
+  UpdateModelParams(program, &mainFunc, arg_has_map);
   UpdateModelOps(program);
   UpdateModelOutputs(program);
   return module_;
@@ -83,9 +93,12 @@ mlir::FuncOp MLIRModelGenImpl::UpdateModelModule(
 }
 
 llvm::SmallVector<mlir::Type, 4> MLIRModelGenImpl::GetModelInputsType(
-    const infrt::paddle::framework_proto::ProgramDesc &program) {
+    const infrt::paddle::framework_proto::ProgramDesc &program,
+    bool arg_has_map) {
   llvm::SmallVector<mlir::Type, 4> operandTypes;
-  operandTypes.push_back(infrt::phi::DenseTensorMapType::get(context_));
+  if (arg_has_map) {
+    operandTypes.push_back(infrt::phi::DenseTensorMapType::get(context_));
+  }
   for (auto &op_desc : main_block_.ops()) {
     if (op_desc.type() != "feed") continue;
     for (int var_idx = 0; var_idx < op_desc.outputs_size(); ++var_idx) {
@@ -155,9 +168,14 @@ void MLIRModelGenImpl::UpdateModelOps(
 
 void MLIRModelGenImpl::UpdateModelParams(
     const infrt::paddle::framework_proto::ProgramDesc &program,
-    mlir::FuncOp *mainFunc) {
+    mlir::FuncOp *mainFunc,
+    bool arg_has_map) {
   // update input vars
-  int input_index = 1;
+  int input_index;
+  if (arg_has_map)
+    input_index = 1;
+  else
+    input_index = 0;
   for (auto &op_desc : main_block_.ops()) {
     if (op_desc.type() == "feed") {
       for (int var_idx = 0; var_idx < op_desc.outputs_size(); ++var_idx) {
@@ -170,9 +188,28 @@ void MLIRModelGenImpl::UpdateModelParams(
       }
     }
   }
+  ::mlir::Value map;
+  if (arg_has_map) {
+    map = mainFunc->getArgument(0);
+  } else {
+    builder_.setInsertionPointToStart(&mainFunc->body().front());
+    if (!model_dir_.empty()) {
+      auto load_op = builder_.create<::infrt::phi::LoadParamsOp>(
+          mlir::UnknownLoc::get(context_),
+          ::infrt::phi::DenseTensorMapType::get(context_),
+          builder_.getStringAttr(model_dir_));
+      map = load_op.out();
+    } else if (!model_file_.empty()) {
+      auto load_op = builder_.create<::infrt::phi::LoadCombinedParamsOp>(
+          mlir::UnknownLoc::get(context_),
+          ::infrt::phi::DenseTensorMapType::get(context_),
+          builder_.getStringAttr(model_file_),
+          builder_.getStringAttr(params_file_));
+      map = load_op.out();
+    }
+  }
 
   // update persistable tensors
-  ::mlir::Value map = mainFunc->getArgument(0);
   for (int i = 0; i < main_block_.vars_size(); i++) {
     auto var_desc = main_block_.vars(i);
     if (params_map_.find(var_desc.name()) != params_map_.end()) continue;
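
Net effect of arg_has_map in the hunks above: with the default (true) the
generated main function takes a pre-loaded DenseTensorMap as argument 0, while
the TRT path passes false so the module materializes the map itself through
LoadParamsOp (directory models) or LoadCombinedParamsOp (combined
model/params files). Condensed from the infrt_api.cc hunk earlier in this
patch (paths are placeholders):

    #include "paddle/infrt/host_context/paddle_mlir.h"

    mlir::ModuleOp BuildModule() {
      MLIRModelGenImpl module_gen;
      // Combined model, TRT path: the emitted main() self-loads its parameters.
      mlir::ModuleOp m = module_gen.ImportPaddleModel(
          "/path/model.pdmodel", "/path/model.pdiparams", /*arg_has_map=*/false);
      // Directory model, default: main() expects the runtime to pass the map.
      // mlir::ModuleOp m2 = module_gen.ImportPaddleModel("/path/to/model_dir");
      return m;
    }
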
diff --git a/paddle/infrt/host_context/paddle_mlir.h b/paddle/infrt/host_context/paddle_mlir.h
index 3d79d608e70..57bdc1b4857 100644
--- a/paddle/infrt/host_context/paddle_mlir.h
+++ b/paddle/infrt/host_context/paddle_mlir.h
@@ -37,8 +37,10 @@ class MLIRModelGenImpl {
  public:
   MLIRModelGenImpl();
   mlir::ModuleOp ImportPaddleModel(const std::string &model_file,
-                                   const std::string &param_file);
-  mlir::ModuleOp ImportPaddleModel(const std::string &model_dir);
+                                   const std::string &param_file,
+                                   bool arg_has_map = true);
+  mlir::ModuleOp ImportPaddleModel(const std::string &model_dir,
+                                   bool arg_has_map = true);
 
  private:
   // parse paddle model file
@@ -47,11 +49,13 @@ class MLIRModelGenImpl {
 
   // convert paddle model proto into paddle dialect module
   mlir::ModuleOp ImportPaddleModel(
-      const infrt::paddle::framework_proto::ProgramDesc &program);
+      const infrt::paddle::framework_proto::ProgramDesc &program,
+      bool arg_has_map);
 
   // get inputs and outputs info from program_desc
   llvm::SmallVector<mlir::Type, 4> GetModelInputsType(
-      const infrt::paddle::framework_proto::ProgramDesc &program);
+      const infrt::paddle::framework_proto::ProgramDesc &program,
+      bool arg_has_map);
   llvm::SmallVector<mlir::Type, 4> GetModelOutputsType(
       const infrt::paddle::framework_proto::ProgramDesc &program);
   // create main function module
@@ -63,7 +67,8 @@ class MLIRModelGenImpl {
   // convert persistable params and inputs variable into mlir domain
   void UpdateModelParams(
       const infrt::paddle::framework_proto::ProgramDesc &program,
-      mlir::FuncOp *mainFunc);
+      mlir::FuncOp *mainFunc,
+      bool arg_has_map);
   // register model outputs into params_map_
   void UpdateModelOutputs(
       const infrt::paddle::framework_proto::ProgramDesc &program);
@@ -80,11 +85,16 @@ class MLIRModelGenImpl {
   void RegisterOpOutputVars(const infrt::paddle::framework_proto::OpDesc &op_,
                             mlir::Operation *mlir_op_);
 
+ private:
   mlir::MLIRContext *context_;
   mlir::OpBuilder builder_;
   mlir::ModuleOp module_;
   infrt::paddle::framework_proto::BlockDesc main_block_;
+
+  std::string model_dir_{};
+  std::string model_file_{};
+  std::string params_file_{};
+
   std::map<std::string, mlir::Value> params_map_;
 };
diff --git a/paddle/infrt/kernel/phi/registry.cc b/paddle/infrt/kernel/phi/registry.cc
index 928209ab182..848ff28faff 100644
--- a/paddle/infrt/kernel/phi/registry.cc
+++ b/paddle/infrt/kernel/phi/registry.cc
@@ -46,7 +46,7 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
   registry->AddKernel(
       "phi_dt.create_host_inited_dense_tensor.f32",
       INFRT_KERNEL(infrt::kernel::phi::CreateHostInitedDenseTensorF32),
-      {"dims", "lod", "layout", "values"});
+      {"dims", "lod", "layout", "values", "run_once"});
 
   registry->AddKernel("phi_dt.fill_dense_tensor.f32",
                       INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32),
-- 
GitLab