未验证 提交 a78ca1cf 编写于 作者: W Wilber 提交者: GitHub

predictor support trt (#41556)

上级 e68da187
...@@ -7,3 +7,5 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/infrt_api_test.cc.in ${CMAKE_CURRENT_ ...@@ -7,3 +7,5 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/infrt_api_test.cc.in ${CMAKE_CURRENT_
# Disable temporarily for the external-kernel's mkldnn is outdate # Disable temporarily for the external-kernel's mkldnn is outdate
cc_test_tiny(test_infrt_api SRCS infrt_api_test.cc DEPS infrt ${MLIR_IR_LIBS}) cc_test_tiny(test_infrt_api SRCS infrt_api_test.cc DEPS infrt ${MLIR_IR_LIBS})
# TODO(inference): remove after optimize weight unfold.
set_tests_properties(test_infrt_api PROPERTIES TIMEOUT 200)
...@@ -17,12 +17,14 @@ ...@@ -17,12 +17,14 @@
#include <llvm/ADT/SmallVector.h> #include <llvm/ADT/SmallVector.h>
#include <llvm/Support/DynamicLibrary.h> #include <llvm/Support/DynamicLibrary.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h> #include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/Parser.h> #include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Transforms/Passes.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "mlir/Pass/PassManager.h"
#include "paddle/infrt/backends/host/phi_allocator.h" #include "paddle/infrt/backends/host/phi_allocator.h"
#include "paddle/infrt/common/global.h" #include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/dense_tensor.h" #include "paddle/infrt/dialect/dense_tensor.h"
...@@ -48,8 +50,16 @@ ...@@ -48,8 +50,16 @@
#include "paddle/infrt/kernel/test_kernels.h" #include "paddle/infrt/kernel/test_kernels.h"
#include "paddle/infrt/tensor/tensor_map.h" #include "paddle/infrt/tensor/tensor_map.h"
#include "paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.h"
#if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT) #if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT)
#include "paddle/infrt/kernel/tensorrt/registry.h" #include "paddle/infrt/kernel/tensorrt/registry.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h"
#endif #endif
using namespace infrt::host_context; // NOLINT using namespace infrt::host_context; // NOLINT
...@@ -233,17 +243,34 @@ int InfRtPredictor::Init(const InfRtConfig& config) { ...@@ -233,17 +243,34 @@ int InfRtPredictor::Init(const InfRtConfig& config) {
#endif // INFRT_WITH_GPU && INFRT_WITH_TRT #endif // INFRT_WITH_GPU && INFRT_WITH_TRT
#endif #endif
auto module_op = impl_->module_gen_.ImportPaddleModel(config.model_dir(), mlir::ModuleOp module_op;
if (config.tensorrt_enabled()) {
module_op = impl_->module_gen_.ImportPaddleModel(
config.model_dir(), config.param_dir(), false);
} else {
module_op = impl_->module_gen_.ImportPaddleModel(config.model_dir(),
config.param_dir()); config.param_dir());
}
context->loadAllAvailableDialects(); context->loadAllAvailableDialects();
::mlir::PassManager pm(context); ::mlir::PassManager pm(context);
::mlir::OpPassManager& phi_pass_manager = pm.nest<::mlir::FuncOp>(); ::mlir::OpPassManager& pass_manager = pm.nest<::mlir::FuncOp>();
std::vector<::infrt::Place> valid_places = {{::infrt::TargetType::CPU, if (config.tensorrt_enabled()) {
pass_manager.addPass(::infrt::CreateInfrtWeightsUnfoldPass());
pass_manager.addPass(::infrt::trt::CreateTrtOpTellerPass());
pass_manager.addPass(::infrt::trt::CreateTrtGraphFusePass());
pass_manager.addPass(::infrt::trt::CreateTrtGraphSplitPass(1));
pass_manager.addPass(::infrt::trt::CreateTrtOpConverterPass());
pass_manager.addPass(::infrt::trt::CreateTrtTypeConvertPass());
pass_manager.addPass(::mlir::createCanonicalizerPass());
} else {
std::vector<::infrt::Place> valid_places = {
{::infrt::TargetType::CPU,
::infrt::PrecisionType::FLOAT32, ::infrt::PrecisionType::FLOAT32,
::infrt::LayoutType::NCHW}}; ::infrt::LayoutType::NCHW}};
phi_pass_manager.addPass(CreatePhiOpCvtPass(valid_places)); pass_manager.addPass(CreatePhiOpCvtPass(valid_places));
phi_pass_manager.addPass(CreateInfrtOpFusePass()); pass_manager.addPass(CreateInfrtOpFusePass());
}
if (mlir::failed(pm.run(module_op))) { if (mlir::failed(pm.run(module_op))) {
std::cout << "\npass failed!\n" << std::endl; std::cout << "\npass failed!\n" << std::endl;
return 4; return 4;
......
...@@ -26,6 +26,9 @@ class InfRtConfig { ...@@ -26,6 +26,9 @@ class InfRtConfig {
std::string param_dir_; std::string param_dir_;
std::vector<std::string> shared_libs_; std::vector<std::string> shared_libs_;
// TODO(wilber): Design an easy-to-use interface.
bool tensorrt_enabled_{false};
public: public:
InfRtConfig() = default; InfRtConfig() = default;
void set_model_dir(const std::string& model_dir) { model_dir_ = model_dir; } void set_model_dir(const std::string& model_dir) { model_dir_ = model_dir; }
...@@ -39,6 +42,11 @@ class InfRtConfig { ...@@ -39,6 +42,11 @@ class InfRtConfig {
} }
const std::vector<std::string>& shared_libs() const { return shared_libs_; } const std::vector<std::string>& shared_libs() const { return shared_libs_; }
// TODO(wilber): Design an easy-to-use interface.
void enable_tensorrt() { tensorrt_enabled_ = true; }
void disable_tensorrt() { tensorrt_enabled_ = false; }
bool tensorrt_enabled() const { return tensorrt_enabled_; }
virtual ~InfRtConfig() = default; virtual ~InfRtConfig() = default;
}; };
......
...@@ -57,4 +57,47 @@ TEST(InfRtPredictor, predictor) { ...@@ -57,4 +57,47 @@ TEST(InfRtPredictor, predictor) {
ASSERT_EQ(output->dims(), ::phi::DDim({16, 10})); ASSERT_EQ(output->dims(), ::phi::DDim({16, 10}));
} }
#ifdef INFRT_WITH_TRT
TEST(InfRtPredictor, trt_predictor) {
std::vector<std::string> shared_libs;
InfRtConfig config;
config.enable_tensorrt();
config.set_model_dir("@CMAKE_BINARY_DIR@/models/resnet50/model.pdmodel");
config.set_param_dir("@CMAKE_BINARY_DIR@/models/resnet50/model.pdiparams");
std::unique_ptr<InfRtPredictor> predictor = CreateInfRtPredictor(config);
::infrt::backends::CpuPhiAllocator cpu_allocator;
::phi::DenseTensor* input = predictor->GetInput(0);
input->Resize({2, 3, 256, 256});
input->AllocateFrom(&cpu_allocator, ::phi::DataType::FLOAT32);
auto* input_data = reinterpret_cast<float*>(input->data());
for (int i = 0; i < input->numel(); i++) input_data[i] = 1.0;
predictor->Run();
// get and print output tensor
auto* output = predictor->GetOutput(0);
ASSERT_EQ(output->dims(), ::phi::DDim({2, 1000}));
const std::vector<float> true_vals {
-3.319006264209747314e-01, -1.418896913528442383e+00,
-6.934890151023864746e-01, -1.498023152351379395e+00,
3.078042864799499512e-01, -1.340998053550720215e+00,
3.508620023727416992e+00, 2.274388313293457031e+00,
-1.321727275848388672e+00, -8.888689428567886353e-02,
-3.319006264209747314e-01, -1.418896913528442383e+00,
-6.934890151023864746e-01, -1.498023152351379395e+00,
3.078042864799499512e-01, -1.340998053550720215e+00,
3.508620023727416992e+00, 2.274388313293457031e+00,
-1.321727275848388672e+00, -8.888689428567886353e-02
};
for (size_t i = 0; i < true_vals.size(); i+=100) {
CHECK_NEAR(output->data<float>()[i*100], true_vals[i], 1e-5);
}
}
#endif
} // namespace infrt } // namespace infrt
...@@ -50,7 +50,8 @@ inline nvinfer1::Dims VecToDims(const std::vector<int>& vec) { ...@@ -50,7 +50,8 @@ inline nvinfer1::Dims VecToDims(const std::vector<int>& vec) {
assert(false); assert(false);
} }
// Pick first nvinfer1::Dims::MAX_DIMS elements // Pick first nvinfer1::Dims::MAX_DIMS elements
nvinfer1::Dims dims{std::min(static_cast<int>(vec.size()), limit), {}}; nvinfer1::Dims dims;
dims.nbDims = std::min(static_cast<int>(vec.size()), limit);
std::copy_n(vec.begin(), dims.nbDims, std::begin(dims.d)); std::copy_n(vec.begin(), dims.nbDims, std::begin(dims.d));
return dims; return dims;
} }
......
...@@ -34,7 +34,8 @@ def CreateHostInitedDenseTensorOp : PDT_Op<"create_host_inited_dense_tensor.f32" ...@@ -34,7 +34,8 @@ def CreateHostInitedDenseTensorOp : PDT_Op<"create_host_inited_dense_tensor.f32"
I64ArrayAttr:$dims, I64ArrayAttr:$dims,
LayoutAttr:$layout, LayoutAttr:$layout,
I64ArrayAttr:$lod, I64ArrayAttr:$lod,
F32ArrayAttr:$values F32ArrayAttr:$values,
DefaultValuedAttr<BoolAttr, "true">:$run_once
); );
let results = (outs DenseTensor:$output); let results = (outs DenseTensor:$output);
} }
......
...@@ -81,7 +81,7 @@ int main(int argc, char** argv) { ...@@ -81,7 +81,7 @@ int main(int argc, char** argv) {
trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphFusePass>()); trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphFusePass>());
trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphSplitPass>(1)); trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTGraphSplitPass>(1));
trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTOpConverterPass>()); trt_pass_manager.addPass(std::make_unique<infrt::trt::TRTOpConverterPass>());
trt_pass_manager.addPass(infrt::trt::createTrtTypeConvertPass()); trt_pass_manager.addPass(infrt::trt::CreateTrtTypeConvertPass());
trt_pass_manager.addPass(::mlir::createCanonicalizerPass()); trt_pass_manager.addPass(::mlir::createCanonicalizerPass());
if (mlir::failed(pm.run(*module))) { if (mlir::failed(pm.run(*module))) {
std::cout << "\npass failed!\n" << std::endl; std::cout << "\npass failed!\n" << std::endl;
......
...@@ -181,5 +181,10 @@ void TRTGraphFusePass::runOnFunction() { ...@@ -181,5 +181,10 @@ void TRTGraphFusePass::runOnFunction() {
// TODO(wilber): Implement a toposort for efficiency. // TODO(wilber): Implement a toposort for efficiency.
// topoSortBlock(body); // topoSortBlock(body);
} }
std::unique_ptr<mlir::Pass> CreateTrtGraphFusePass() {
return std::make_unique<TRTGraphFusePass>();
}
} // namespace trt } // namespace trt
} // namespace infrt } // namespace infrt
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
namespace infrt { namespace infrt {
namespace trt { namespace trt {
std::unique_ptr<mlir::Pass> CreateTrtGraphFusePass();
/* /*
* trtGraphFusePass. * trtGraphFusePass.
* *
......
...@@ -44,5 +44,10 @@ void TRTGraphSplitPass::runOnFunction() { ...@@ -44,5 +44,10 @@ void TRTGraphSplitPass::runOnFunction() {
graph_op.erase(); graph_op.erase();
} }
} }
std::unique_ptr<mlir::Pass> CreateTrtGraphSplitPass(size_t min_subgraph_size) {
return std::make_unique<TRTGraphSplitPass>(min_subgraph_size);
}
} // namespace trt } // namespace trt
} // namespace infrt } // namespace infrt
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
namespace infrt { namespace infrt {
namespace trt { namespace trt {
std::unique_ptr<mlir::Pass> CreateTrtGraphSplitPass(size_t min_subgraph_size);
/* /*
* trtGraphSplitPass. * trtGraphSplitPass.
* *
......
...@@ -260,5 +260,9 @@ void TRTOpConverterPass::runOnOperation() { ...@@ -260,5 +260,9 @@ void TRTOpConverterPass::runOnOperation() {
signalPassFailure(); signalPassFailure();
} }
std::unique_ptr<mlir::Pass> CreateTrtOpConverterPass() {
return std::make_unique<TRTOpConverterPass>();
}
} // namespace trt } // namespace trt
} // namespace infrt } // namespace infrt
...@@ -20,6 +20,9 @@ ...@@ -20,6 +20,9 @@
namespace infrt { namespace infrt {
namespace trt { namespace trt {
std::unique_ptr<mlir::Pass> CreateTrtOpConverterPass();
/* /*
* trtOpConverterPass. * trtOpConverterPass.
* *
......
...@@ -58,5 +58,10 @@ void TRTOpTellerPass::runOnFunction() { ...@@ -58,5 +58,10 @@ void TRTOpTellerPass::runOnFunction() {
builder.create<::infrt::ReturnOp>(loc, op->getResults()); builder.create<::infrt::ReturnOp>(loc, op->getResults());
} }
} }
std::unique_ptr<mlir::Pass> CreateTrtOpTellerPass() {
return std::make_unique<TRTOpTellerPass>();
}
} // namespace trt } // namespace trt
} // namespace infrt } // namespace infrt
...@@ -17,6 +17,9 @@ ...@@ -17,6 +17,9 @@
namespace infrt { namespace infrt {
namespace trt { namespace trt {
std::unique_ptr<mlir::Pass> CreateTrtOpTellerPass();
/* /*
* trtOpTellerPass. * trtOpTellerPass.
* *
......
...@@ -175,7 +175,7 @@ void TrtTypeConvertPass::runOnFunction() { ...@@ -175,7 +175,7 @@ void TrtTypeConvertPass::runOnFunction() {
namespace infrt { namespace infrt {
namespace trt { namespace trt {
std::unique_ptr<mlir::Pass> createTrtTypeConvertPass() { std::unique_ptr<mlir::Pass> CreateTrtTypeConvertPass() {
return std::make_unique<TrtTypeConvertPass>(); return std::make_unique<TrtTypeConvertPass>();
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace infrt { namespace infrt {
namespace trt { namespace trt {
std::unique_ptr<mlir::Pass> createTrtTypeConvertPass(); std::unique_ptr<mlir::Pass> CreateTrtTypeConvertPass();
} // namespace trt } // namespace trt
} // namespace infrt } // namespace infrt
...@@ -15,11 +15,13 @@ ...@@ -15,11 +15,13 @@
#include "paddle/infrt/host_context/paddle_mlir.h" #include "paddle/infrt/host_context/paddle_mlir.h"
#include <mlir/IR/OpDefinition.h> #include <mlir/IR/OpDefinition.h>
#include <mlir/IR/Value.h>
#include "paddle/infrt/dialect/infrt/ir/basic_kernels.h" #include "paddle/infrt/dialect/infrt/ir/basic_kernels.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h" #include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/pd/common/pd_ops_info.h" #include "paddle/infrt/dialect/pd/common/pd_ops_info.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h" #include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/ir/phi_base.h"
MLIRModelGenImpl::MLIRModelGenImpl() MLIRModelGenImpl::MLIRModelGenImpl()
: context_(infrt::Global::getMLIRContext()), builder_(context_) { : context_(infrt::Global::getMLIRContext()), builder_(context_) {
...@@ -35,32 +37,40 @@ MLIRModelGenImpl::MLIRModelGenImpl() ...@@ -35,32 +37,40 @@ MLIRModelGenImpl::MLIRModelGenImpl()
infrt::paddle::framework_proto::ProgramDesc MLIRModelGenImpl::ParsePaddleModel( infrt::paddle::framework_proto::ProgramDesc MLIRModelGenImpl::ParsePaddleModel(
const std::string &model_file) { const std::string &model_file) {
model_file_ = model_file;
infrt::paddle::framework_proto::ProgramDesc program_proto = infrt::paddle::framework_proto::ProgramDesc program_proto =
*infrt::paddle::LoadProgram(model_file); *infrt::paddle::LoadProgram(model_file);
return program_proto; return program_proto;
} }
mlir::ModuleOp MLIRModelGenImpl::ImportPaddleModel( mlir::ModuleOp MLIRModelGenImpl::ImportPaddleModel(const std::string &model_dir,
const std::string &model_dir) { bool arg_has_map) {
model_dir_ = model_dir;
infrt::paddle::framework_proto::ProgramDesc program_proto = infrt::paddle::framework_proto::ProgramDesc program_proto =
ParsePaddleModel(model_dir + "/__model__"); ParsePaddleModel(model_dir + "/__model__");
return ImportPaddleModel(program_proto); return ImportPaddleModel(program_proto, arg_has_map);
} }
mlir::ModuleOp MLIRModelGenImpl::ImportPaddleModel( mlir::ModuleOp MLIRModelGenImpl::ImportPaddleModel(
const std::string &model_file, const std::string &param_file) { const std::string &model_file,
const std::string &param_file,
bool arg_has_map) {
model_file_ = model_file;
params_file_ = param_file;
infrt::paddle::framework_proto::ProgramDesc program_proto = infrt::paddle::framework_proto::ProgramDesc program_proto =
ParsePaddleModel(model_file); ParsePaddleModel(model_file);
return ImportPaddleModel(program_proto); return ImportPaddleModel(program_proto, arg_has_map);
} }
mlir::ModuleOp MLIRModelGenImpl::ImportPaddleModel( mlir::ModuleOp MLIRModelGenImpl::ImportPaddleModel(
const infrt::paddle::framework_proto::ProgramDesc &program) { const infrt::paddle::framework_proto::ProgramDesc &program,
bool arg_has_map) {
main_block_ = program.blocks(0); main_block_ = program.blocks(0);
llvm::SmallVector<mlir::Type, 4> operandTypes = GetModelInputsType(program); llvm::SmallVector<mlir::Type, 4> operandTypes =
GetModelInputsType(program, arg_has_map);
llvm::SmallVector<mlir::Type, 4> resultTypes = GetModelOutputsType(program); llvm::SmallVector<mlir::Type, 4> resultTypes = GetModelOutputsType(program);
mlir::FuncOp mainFunc = UpdateModelModule(operandTypes, resultTypes); mlir::FuncOp mainFunc = UpdateModelModule(operandTypes, resultTypes);
UpdateModelParams(program, &mainFunc); UpdateModelParams(program, &mainFunc, arg_has_map);
UpdateModelOps(program); UpdateModelOps(program);
UpdateModelOutputs(program); UpdateModelOutputs(program);
return module_; return module_;
...@@ -83,9 +93,12 @@ mlir::FuncOp MLIRModelGenImpl::UpdateModelModule( ...@@ -83,9 +93,12 @@ mlir::FuncOp MLIRModelGenImpl::UpdateModelModule(
} }
llvm::SmallVector<mlir::Type, 4> MLIRModelGenImpl::GetModelInputsType( llvm::SmallVector<mlir::Type, 4> MLIRModelGenImpl::GetModelInputsType(
const infrt::paddle::framework_proto::ProgramDesc &program) { const infrt::paddle::framework_proto::ProgramDesc &program,
bool arg_has_map) {
llvm::SmallVector<mlir::Type, 4> operandTypes; llvm::SmallVector<mlir::Type, 4> operandTypes;
if (arg_has_map) {
operandTypes.push_back(infrt::phi::DenseTensorMapType::get(context_)); operandTypes.push_back(infrt::phi::DenseTensorMapType::get(context_));
}
for (auto &op_desc : main_block_.ops()) { for (auto &op_desc : main_block_.ops()) {
if (op_desc.type() != "feed") continue; if (op_desc.type() != "feed") continue;
for (int var_idx = 0; var_idx < op_desc.outputs_size(); ++var_idx) { for (int var_idx = 0; var_idx < op_desc.outputs_size(); ++var_idx) {
...@@ -155,9 +168,14 @@ void MLIRModelGenImpl::UpdateModelOps( ...@@ -155,9 +168,14 @@ void MLIRModelGenImpl::UpdateModelOps(
void MLIRModelGenImpl::UpdateModelParams( void MLIRModelGenImpl::UpdateModelParams(
const infrt::paddle::framework_proto::ProgramDesc &program, const infrt::paddle::framework_proto::ProgramDesc &program,
mlir::FuncOp *mainFunc) { mlir::FuncOp *mainFunc,
bool arg_has_map) {
// update input vars // update input vars
int input_index = 1; int input_index;
if (arg_has_map)
input_index = 1;
else
input_index = 0;
for (auto &op_desc : main_block_.ops()) { for (auto &op_desc : main_block_.ops()) {
if (op_desc.type() == "feed") { if (op_desc.type() == "feed") {
for (int var_idx = 0; var_idx < op_desc.outputs_size(); ++var_idx) { for (int var_idx = 0; var_idx < op_desc.outputs_size(); ++var_idx) {
...@@ -170,9 +188,28 @@ void MLIRModelGenImpl::UpdateModelParams( ...@@ -170,9 +188,28 @@ void MLIRModelGenImpl::UpdateModelParams(
} }
} }
} }
::mlir::Value map;
if (arg_has_map) {
map = mainFunc->getArgument(0);
} else {
builder_.setInsertionPointToStart(&mainFunc->body().front());
if (!model_dir_.empty()) {
auto load_op = builder_.create<::infrt::phi::LoadParamsOp>(
mlir::UnknownLoc::get(context_),
::infrt::phi::DenseTensorMapType::get(context_),
builder_.getStringAttr(model_dir_));
map = load_op.out();
} else if (!model_file_.empty()) {
auto load_op = builder_.create<::infrt::phi::LoadCombinedParamsOp>(
mlir::UnknownLoc::get(context_),
::infrt::phi::DenseTensorMapType::get(context_),
builder_.getStringAttr(model_file_),
builder_.getStringAttr(params_file_));
map = load_op.out();
}
}
// update persistable tensors // update persistable tensors
::mlir::Value map = mainFunc->getArgument(0);
for (int i = 0; i < main_block_.vars_size(); i++) { for (int i = 0; i < main_block_.vars_size(); i++) {
auto var_desc = main_block_.vars(i); auto var_desc = main_block_.vars(i);
if (params_map_.find(var_desc.name()) != params_map_.end()) continue; if (params_map_.find(var_desc.name()) != params_map_.end()) continue;
......
...@@ -37,8 +37,10 @@ class MLIRModelGenImpl { ...@@ -37,8 +37,10 @@ class MLIRModelGenImpl {
public: public:
MLIRModelGenImpl(); MLIRModelGenImpl();
mlir::ModuleOp ImportPaddleModel(const std::string &model_file, mlir::ModuleOp ImportPaddleModel(const std::string &model_file,
const std::string &param_file); const std::string &param_file,
mlir::ModuleOp ImportPaddleModel(const std::string &model_dir); bool arg_has_map = true);
mlir::ModuleOp ImportPaddleModel(const std::string &model_dir,
bool arg_has_map = true);
private: private:
// parse paddle model file // parse paddle model file
...@@ -47,11 +49,13 @@ class MLIRModelGenImpl { ...@@ -47,11 +49,13 @@ class MLIRModelGenImpl {
// convert paddle model proto into paddle dialect module // convert paddle model proto into paddle dialect module
mlir::ModuleOp ImportPaddleModel( mlir::ModuleOp ImportPaddleModel(
const infrt::paddle::framework_proto::ProgramDesc &program); const infrt::paddle::framework_proto::ProgramDesc &program,
bool arg_has_map);
// get inputs and outputs info from program_desc // get inputs and outputs info from program_desc
llvm::SmallVector<mlir::Type, 4> GetModelInputsType( llvm::SmallVector<mlir::Type, 4> GetModelInputsType(
const infrt::paddle::framework_proto::ProgramDesc &program); const infrt::paddle::framework_proto::ProgramDesc &program,
bool arg_has_map);
llvm::SmallVector<mlir::Type, 4> GetModelOutputsType( llvm::SmallVector<mlir::Type, 4> GetModelOutputsType(
const infrt::paddle::framework_proto::ProgramDesc &program); const infrt::paddle::framework_proto::ProgramDesc &program);
// create main function module // create main function module
...@@ -63,7 +67,8 @@ class MLIRModelGenImpl { ...@@ -63,7 +67,8 @@ class MLIRModelGenImpl {
// convert persistable params and inputs variable into mlir domain // convert persistable params and inputs variable into mlir domain
void UpdateModelParams( void UpdateModelParams(
const infrt::paddle::framework_proto::ProgramDesc &program, const infrt::paddle::framework_proto::ProgramDesc &program,
mlir::FuncOp *mainFunc); mlir::FuncOp *mainFunc,
bool arg_has_map);
// register model outpus into params_map_ // register model outpus into params_map_
void UpdateModelOutputs( void UpdateModelOutputs(
const infrt::paddle::framework_proto::ProgramDesc &program); const infrt::paddle::framework_proto::ProgramDesc &program);
...@@ -80,11 +85,16 @@ class MLIRModelGenImpl { ...@@ -80,11 +85,16 @@ class MLIRModelGenImpl {
void RegisterOpOutputVars(const infrt::paddle::framework_proto::OpDesc &op_, void RegisterOpOutputVars(const infrt::paddle::framework_proto::OpDesc &op_,
mlir::Operation *mlir_op_); mlir::Operation *mlir_op_);
private:
mlir::MLIRContext *context_; mlir::MLIRContext *context_;
mlir::OpBuilder builder_; mlir::OpBuilder builder_;
mlir::ModuleOp module_; mlir::ModuleOp module_;
infrt::paddle::framework_proto::BlockDesc main_block_; infrt::paddle::framework_proto::BlockDesc main_block_;
std::string model_dir_{};
std::string model_file_{};
std::string params_file_{};
std::map<std::string, mlir::Value> params_map_; std::map<std::string, mlir::Value> params_map_;
}; };
......
...@@ -46,7 +46,7 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) { ...@@ -46,7 +46,7 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
registry->AddKernel( registry->AddKernel(
"phi_dt.create_host_inited_dense_tensor.f32", "phi_dt.create_host_inited_dense_tensor.f32",
INFRT_KERNEL(infrt::kernel::phi::CreateHostInitedDenseTensorF32), INFRT_KERNEL(infrt::kernel::phi::CreateHostInitedDenseTensorF32),
{"dims", "lod", "layout", "values"}); {"dims", "lod", "layout", "values", "run_once"});
registry->AddKernel("phi_dt.fill_dense_tensor.f32", registry->AddKernel("phi_dt.fill_dense_tensor.f32",
INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32), INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册