Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
a78ca1cf
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a78ca1cf
编写于
4月 10, 2022
作者:
W
Wilber
提交者:
GitHub
4月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
predictor support trt (#41556)
上级
e68da187
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
193 addition
and
33 deletion
+193
-33
paddle/infrt/api/CMakeLists.txt
paddle/infrt/api/CMakeLists.txt
+2
-0
paddle/infrt/api/infrt_api.cc
paddle/infrt/api/infrt_api.cc
+36
-9
paddle/infrt/api/infrt_api.h
paddle/infrt/api/infrt_api.h
+8
-0
paddle/infrt/api/infrt_api_test.cc.in
paddle/infrt/api/infrt_api_test.cc.in
+43
-0
paddle/infrt/backends/tensorrt/trt_utils.h
paddle/infrt/backends/tensorrt/trt_utils.h
+2
-1
paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td
paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td
+2
-1
paddle/infrt/dialect/tensorrt/trt_exec.cc
paddle/infrt/dialect/tensorrt/trt_exec.cc
+1
-1
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
+5
-0
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
+3
-0
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
+5
-0
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
+3
-0
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
+4
-0
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h
+3
-0
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
+5
-0
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
+3
-0
paddle/infrt/dialect/tensorrt/trt_type_convert_pass.cc
paddle/infrt/dialect/tensorrt/trt_type_convert_pass.cc
+1
-1
paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h
paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h
+1
-1
paddle/infrt/host_context/paddle_mlir.cc
paddle/infrt/host_context/paddle_mlir.cc
+50
-13
paddle/infrt/host_context/paddle_mlir.h
paddle/infrt/host_context/paddle_mlir.h
+15
-5
paddle/infrt/kernel/phi/registry.cc
paddle/infrt/kernel/phi/registry.cc
+1
-1
未找到文件。
paddle/infrt/api/CMakeLists.txt
浏览文件 @
a78ca1cf
...
...
@@ -7,3 +7,5 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/infrt_api_test.cc.in ${CMAKE_CURRENT_
# Disable temporarily for the external-kernel's mkldnn is outdate
cc_test_tiny
(
test_infrt_api SRCS infrt_api_test.cc DEPS infrt
${
MLIR_IR_LIBS
}
)
# TODO(inference): remove after optimize weight unfold.
set_tests_properties
(
test_infrt_api PROPERTIES TIMEOUT 200
)
paddle/infrt/api/infrt_api.cc
浏览文件 @
a78ca1cf
...
...
@@ -17,12 +17,14 @@
#include <llvm/ADT/SmallVector.h>
#include <llvm/Support/DynamicLibrary.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Transforms/Passes.h>
#include <unordered_map>
#include <vector>
#include "mlir/Pass/PassManager.h"
#include "paddle/infrt/backends/host/phi_allocator.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/dense_tensor.h"
...
...
@@ -48,8 +50,16 @@
#include "paddle/infrt/kernel/test_kernels.h"
#include "paddle/infrt/tensor/tensor_map.h"
#include "paddle/infrt/dialect/infrt/pass/infrt_weights_unfold_pass.h"
#if defined(INFRT_WITH_GPU) && defined(INFRT_WITH_TRT)
#include "paddle/infrt/kernel/tensorrt/registry.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h"
#include "paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h"
#endif
using
namespace
infrt
::
host_context
;
// NOLINT
...
...
@@ -233,17 +243,34 @@ int InfRtPredictor::Init(const InfRtConfig& config) {
#endif // INFRT_WITH_GPU && INFRT_WITH_TRT
#endif
auto
module_op
=
impl_
->
module_gen_
.
ImportPaddleModel
(
config
.
model_dir
(),
config
.
param_dir
());
mlir
::
ModuleOp
module_op
;
if
(
config
.
tensorrt_enabled
())
{
module_op
=
impl_
->
module_gen_
.
ImportPaddleModel
(
config
.
model_dir
(),
config
.
param_dir
(),
false
);
}
else
{
module_op
=
impl_
->
module_gen_
.
ImportPaddleModel
(
config
.
model_dir
(),
config
.
param_dir
());
}
context
->
loadAllAvailableDialects
();
::
mlir
::
PassManager
pm
(
context
);
::
mlir
::
OpPassManager
&
phi_pass_manager
=
pm
.
nest
<::
mlir
::
FuncOp
>
();
std
::
vector
<::
infrt
::
Place
>
valid_places
=
{{
::
infrt
::
TargetType
::
CPU
,
::
infrt
::
PrecisionType
::
FLOAT32
,
::
infrt
::
LayoutType
::
NCHW
}};
phi_pass_manager
.
addPass
(
CreatePhiOpCvtPass
(
valid_places
));
phi_pass_manager
.
addPass
(
CreateInfrtOpFusePass
());
::
mlir
::
OpPassManager
&
pass_manager
=
pm
.
nest
<::
mlir
::
FuncOp
>
();
if
(
config
.
tensorrt_enabled
())
{
pass_manager
.
addPass
(
::
infrt
::
CreateInfrtWeightsUnfoldPass
());
pass_manager
.
addPass
(
::
infrt
::
trt
::
CreateTrtOpTellerPass
());
pass_manager
.
addPass
(
::
infrt
::
trt
::
CreateTrtGraphFusePass
());
pass_manager
.
addPass
(
::
infrt
::
trt
::
CreateTrtGraphSplitPass
(
1
));
pass_manager
.
addPass
(
::
infrt
::
trt
::
CreateTrtOpConverterPass
());
pass_manager
.
addPass
(
::
infrt
::
trt
::
CreateTrtTypeConvertPass
());
pass_manager
.
addPass
(
::
mlir
::
createCanonicalizerPass
());
}
else
{
std
::
vector
<::
infrt
::
Place
>
valid_places
=
{
{
::
infrt
::
TargetType
::
CPU
,
::
infrt
::
PrecisionType
::
FLOAT32
,
::
infrt
::
LayoutType
::
NCHW
}};
pass_manager
.
addPass
(
CreatePhiOpCvtPass
(
valid_places
));
pass_manager
.
addPass
(
CreateInfrtOpFusePass
());
}
if
(
mlir
::
failed
(
pm
.
run
(
module_op
)))
{
std
::
cout
<<
"
\n
pass failed!
\n
"
<<
std
::
endl
;
return
4
;
...
...
paddle/infrt/api/infrt_api.h
浏览文件 @
a78ca1cf
...
...
@@ -26,6 +26,9 @@ class InfRtConfig {
std
::
string
param_dir_
;
std
::
vector
<
std
::
string
>
shared_libs_
;
// TODO(wilber): Design an easy-to-use interface.
bool
tensorrt_enabled_
{
false
};
public:
InfRtConfig
()
=
default
;
void
set_model_dir
(
const
std
::
string
&
model_dir
)
{
model_dir_
=
model_dir
;
}
...
...
@@ -39,6 +42,11 @@ class InfRtConfig {
}
const
std
::
vector
<
std
::
string
>&
shared_libs
()
const
{
return
shared_libs_
;
}
// TODO(wilber): Design an easy-to-use interface.
void
enable_tensorrt
()
{
tensorrt_enabled_
=
true
;
}
void
disable_tensorrt
()
{
tensorrt_enabled_
=
false
;
}
bool
tensorrt_enabled
()
const
{
return
tensorrt_enabled_
;
}
virtual
~
InfRtConfig
()
=
default
;
};
...
...
paddle/infrt/api/infrt_api_test.cc.in
浏览文件 @
a78ca1cf
...
...
@@ -57,4 +57,47 @@ TEST(InfRtPredictor, predictor) {
ASSERT_EQ(output->dims(), ::phi::DDim({16, 10}));
}
#ifdef INFRT_WITH_TRT
TEST(InfRtPredictor, trt_predictor) {
std::vector<std::string> shared_libs;
InfRtConfig config;
config.enable_tensorrt();
config.set_model_dir("@CMAKE_BINARY_DIR@/models/resnet50/model.pdmodel");
config.set_param_dir("@CMAKE_BINARY_DIR@/models/resnet50/model.pdiparams");
std::unique_ptr<InfRtPredictor> predictor = CreateInfRtPredictor(config);
::infrt::backends::CpuPhiAllocator cpu_allocator;
::phi::DenseTensor* input = predictor->GetInput(0);
input->Resize({2, 3, 256, 256});
input->AllocateFrom(&cpu_allocator, ::phi::DataType::FLOAT32);
auto* input_data = reinterpret_cast<float*>(input->data());
for (int i = 0; i < input->numel(); i++) input_data[i] = 1.0;
predictor->Run();
// get and print output tensor
auto* output = predictor->GetOutput(0);
ASSERT_EQ(output->dims(), ::phi::DDim({2, 1000}));
const std::vector<float> true_vals {
-3.319006264209747314e-01, -1.418896913528442383e+00,
-6.934890151023864746e-01, -1.498023152351379395e+00,
3.078042864799499512e-01, -1.340998053550720215e+00,
3.508620023727416992e+00, 2.274388313293457031e+00,
-1.321727275848388672e+00, -8.888689428567886353e-02,
-3.319006264209747314e-01, -1.418896913528442383e+00,
-6.934890151023864746e-01, -1.498023152351379395e+00,
3.078042864799499512e-01, -1.340998053550720215e+00,
3.508620023727416992e+00, 2.274388313293457031e+00,
-1.321727275848388672e+00, -8.888689428567886353e-02
};
for (size_t i = 0; i < true_vals.size(); i+=100) {
CHECK_NEAR(output->data<float>()[i*100], true_vals[i], 1e-5);
}
}
#endif
} // namespace infrt
paddle/infrt/backends/tensorrt/trt_utils.h
浏览文件 @
a78ca1cf
...
...
@@ -50,7 +50,8 @@ inline nvinfer1::Dims VecToDims(const std::vector<int>& vec) {
assert
(
false
);
}
// Pick first nvinfer1::Dims::MAX_DIMS elements
nvinfer1
::
Dims
dims
{
std
::
min
(
static_cast
<
int
>
(
vec
.
size
()),
limit
),
{}};
nvinfer1
::
Dims
dims
;
dims
.
nbDims
=
std
::
min
(
static_cast
<
int
>
(
vec
.
size
()),
limit
);
std
::
copy_n
(
vec
.
begin
(),
dims
.
nbDims
,
std
::
begin
(
dims
.
d
));
return
dims
;
}
...
...
paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td
浏览文件 @
a78ca1cf
...
...
@@ -34,7 +34,8 @@ def CreateHostInitedDenseTensorOp : PDT_Op<"create_host_inited_dense_tensor.f32"
I64ArrayAttr:$dims,
LayoutAttr:$layout,
I64ArrayAttr:$lod,
F32ArrayAttr:$values
F32ArrayAttr:$values,
DefaultValuedAttr<BoolAttr, "true">:$run_once
);
let results = (outs DenseTensor:$output);
}
...
...
paddle/infrt/dialect/tensorrt/trt_exec.cc
浏览文件 @
a78ca1cf
...
...
@@ -81,7 +81,7 @@ int main(int argc, char** argv) {
trt_pass_manager
.
addPass
(
std
::
make_unique
<
infrt
::
trt
::
TRTGraphFusePass
>
());
trt_pass_manager
.
addPass
(
std
::
make_unique
<
infrt
::
trt
::
TRTGraphSplitPass
>
(
1
));
trt_pass_manager
.
addPass
(
std
::
make_unique
<
infrt
::
trt
::
TRTOpConverterPass
>
());
trt_pass_manager
.
addPass
(
infrt
::
trt
::
c
reateTrtTypeConvertPass
());
trt_pass_manager
.
addPass
(
infrt
::
trt
::
C
reateTrtTypeConvertPass
());
trt_pass_manager
.
addPass
(
::
mlir
::
createCanonicalizerPass
());
if
(
mlir
::
failed
(
pm
.
run
(
*
module
)))
{
std
::
cout
<<
"
\n
pass failed!
\n
"
<<
std
::
endl
;
...
...
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.cc
浏览文件 @
a78ca1cf
...
...
@@ -181,5 +181,10 @@ void TRTGraphFusePass::runOnFunction() {
// TODO(wilber): Implement a toposort for efficiency.
// topoSortBlock(body);
}
std
::
unique_ptr
<
mlir
::
Pass
>
CreateTrtGraphFusePass
()
{
return
std
::
make_unique
<
TRTGraphFusePass
>
();
}
}
// namespace trt
}
// namespace infrt
paddle/infrt/dialect/tensorrt/trt_graph_fuse_pass.h
浏览文件 @
a78ca1cf
...
...
@@ -17,6 +17,9 @@
namespace
infrt
{
namespace
trt
{
std
::
unique_ptr
<
mlir
::
Pass
>
CreateTrtGraphFusePass
();
/*
* trtGraphFusePass.
*
...
...
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.cc
浏览文件 @
a78ca1cf
...
...
@@ -44,5 +44,10 @@ void TRTGraphSplitPass::runOnFunction() {
graph_op
.
erase
();
}
}
std
::
unique_ptr
<
mlir
::
Pass
>
CreateTrtGraphSplitPass
(
size_t
min_subgraph_size
)
{
return
std
::
make_unique
<
TRTGraphSplitPass
>
(
min_subgraph_size
);
}
}
// namespace trt
}
// namespace infrt
paddle/infrt/dialect/tensorrt/trt_graph_split_pass.h
浏览文件 @
a78ca1cf
...
...
@@ -17,6 +17,9 @@
namespace
infrt
{
namespace
trt
{
std
::
unique_ptr
<
mlir
::
Pass
>
CreateTrtGraphSplitPass
(
size_t
min_subgraph_size
);
/*
* trtGraphSplitPass.
*
...
...
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.cc
浏览文件 @
a78ca1cf
...
...
@@ -260,5 +260,9 @@ void TRTOpConverterPass::runOnOperation() {
signalPassFailure
();
}
std
::
unique_ptr
<
mlir
::
Pass
>
CreateTrtOpConverterPass
()
{
return
std
::
make_unique
<
TRTOpConverterPass
>
();
}
}
// namespace trt
}
// namespace infrt
paddle/infrt/dialect/tensorrt/trt_op_converter_pass.h
浏览文件 @
a78ca1cf
...
...
@@ -20,6 +20,9 @@
namespace
infrt
{
namespace
trt
{
std
::
unique_ptr
<
mlir
::
Pass
>
CreateTrtOpConverterPass
();
/*
* trtOpConverterPass.
*
...
...
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.cc
浏览文件 @
a78ca1cf
...
...
@@ -58,5 +58,10 @@ void TRTOpTellerPass::runOnFunction() {
builder
.
create
<::
infrt
::
ReturnOp
>
(
loc
,
op
->
getResults
());
}
}
std
::
unique_ptr
<
mlir
::
Pass
>
CreateTrtOpTellerPass
()
{
return
std
::
make_unique
<
TRTOpTellerPass
>
();
}
}
// namespace trt
}
// namespace infrt
paddle/infrt/dialect/tensorrt/trt_op_teller_pass.h
浏览文件 @
a78ca1cf
...
...
@@ -17,6 +17,9 @@
namespace
infrt
{
namespace
trt
{
std
::
unique_ptr
<
mlir
::
Pass
>
CreateTrtOpTellerPass
();
/*
* trtOpTellerPass.
*
...
...
paddle/infrt/dialect/tensorrt/trt_type_convert_pass.cc
浏览文件 @
a78ca1cf
...
...
@@ -175,7 +175,7 @@ void TrtTypeConvertPass::runOnFunction() {
namespace
infrt
{
namespace
trt
{
std
::
unique_ptr
<
mlir
::
Pass
>
c
reateTrtTypeConvertPass
()
{
std
::
unique_ptr
<
mlir
::
Pass
>
C
reateTrtTypeConvertPass
()
{
return
std
::
make_unique
<
TrtTypeConvertPass
>
();
}
...
...
paddle/infrt/dialect/tensorrt/trt_type_convert_pass.h
浏览文件 @
a78ca1cf
...
...
@@ -19,7 +19,7 @@
namespace
infrt
{
namespace
trt
{
std
::
unique_ptr
<
mlir
::
Pass
>
c
reateTrtTypeConvertPass
();
std
::
unique_ptr
<
mlir
::
Pass
>
C
reateTrtTypeConvertPass
();
}
// namespace trt
}
// namespace infrt
paddle/infrt/host_context/paddle_mlir.cc
浏览文件 @
a78ca1cf
...
...
@@ -15,11 +15,13 @@
#include "paddle/infrt/host_context/paddle_mlir.h"
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/Value.h>
#include "paddle/infrt/dialect/infrt/ir/basic_kernels.h"
#include "paddle/infrt/dialect/infrt/ir/infrt_dialect.h"
#include "paddle/infrt/dialect/pd/common/pd_ops_info.h"
#include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h"
#include "paddle/infrt/dialect/phi/ir/phi_base.h"
MLIRModelGenImpl
::
MLIRModelGenImpl
()
:
context_
(
infrt
::
Global
::
getMLIRContext
()),
builder_
(
context_
)
{
...
...
@@ -35,32 +37,40 @@ MLIRModelGenImpl::MLIRModelGenImpl()
infrt
::
paddle
::
framework_proto
::
ProgramDesc
MLIRModelGenImpl
::
ParsePaddleModel
(
const
std
::
string
&
model_file
)
{
model_file_
=
model_file
;
infrt
::
paddle
::
framework_proto
::
ProgramDesc
program_proto
=
*
infrt
::
paddle
::
LoadProgram
(
model_file
);
return
program_proto
;
}
mlir
::
ModuleOp
MLIRModelGenImpl
::
ImportPaddleModel
(
const
std
::
string
&
model_dir
)
{
mlir
::
ModuleOp
MLIRModelGenImpl
::
ImportPaddleModel
(
const
std
::
string
&
model_dir
,
bool
arg_has_map
)
{
model_dir_
=
model_dir
;
infrt
::
paddle
::
framework_proto
::
ProgramDesc
program_proto
=
ParsePaddleModel
(
model_dir
+
"/__model__"
);
return
ImportPaddleModel
(
program_proto
);
return
ImportPaddleModel
(
program_proto
,
arg_has_map
);
}
mlir
::
ModuleOp
MLIRModelGenImpl
::
ImportPaddleModel
(
const
std
::
string
&
model_file
,
const
std
::
string
&
param_file
)
{
const
std
::
string
&
model_file
,
const
std
::
string
&
param_file
,
bool
arg_has_map
)
{
model_file_
=
model_file
;
params_file_
=
param_file
;
infrt
::
paddle
::
framework_proto
::
ProgramDesc
program_proto
=
ParsePaddleModel
(
model_file
);
return
ImportPaddleModel
(
program_proto
);
return
ImportPaddleModel
(
program_proto
,
arg_has_map
);
}
mlir
::
ModuleOp
MLIRModelGenImpl
::
ImportPaddleModel
(
const
infrt
::
paddle
::
framework_proto
::
ProgramDesc
&
program
)
{
const
infrt
::
paddle
::
framework_proto
::
ProgramDesc
&
program
,
bool
arg_has_map
)
{
main_block_
=
program
.
blocks
(
0
);
llvm
::
SmallVector
<
mlir
::
Type
,
4
>
operandTypes
=
GetModelInputsType
(
program
);
llvm
::
SmallVector
<
mlir
::
Type
,
4
>
operandTypes
=
GetModelInputsType
(
program
,
arg_has_map
);
llvm
::
SmallVector
<
mlir
::
Type
,
4
>
resultTypes
=
GetModelOutputsType
(
program
);
mlir
::
FuncOp
mainFunc
=
UpdateModelModule
(
operandTypes
,
resultTypes
);
UpdateModelParams
(
program
,
&
mainFunc
);
UpdateModelParams
(
program
,
&
mainFunc
,
arg_has_map
);
UpdateModelOps
(
program
);
UpdateModelOutputs
(
program
);
return
module_
;
...
...
@@ -83,9 +93,12 @@ mlir::FuncOp MLIRModelGenImpl::UpdateModelModule(
}
llvm
::
SmallVector
<
mlir
::
Type
,
4
>
MLIRModelGenImpl
::
GetModelInputsType
(
const
infrt
::
paddle
::
framework_proto
::
ProgramDesc
&
program
)
{
const
infrt
::
paddle
::
framework_proto
::
ProgramDesc
&
program
,
bool
arg_has_map
)
{
llvm
::
SmallVector
<
mlir
::
Type
,
4
>
operandTypes
;
operandTypes
.
push_back
(
infrt
::
phi
::
DenseTensorMapType
::
get
(
context_
));
if
(
arg_has_map
)
{
operandTypes
.
push_back
(
infrt
::
phi
::
DenseTensorMapType
::
get
(
context_
));
}
for
(
auto
&
op_desc
:
main_block_
.
ops
())
{
if
(
op_desc
.
type
()
!=
"feed"
)
continue
;
for
(
int
var_idx
=
0
;
var_idx
<
op_desc
.
outputs_size
();
++
var_idx
)
{
...
...
@@ -155,9 +168,14 @@ void MLIRModelGenImpl::UpdateModelOps(
void
MLIRModelGenImpl
::
UpdateModelParams
(
const
infrt
::
paddle
::
framework_proto
::
ProgramDesc
&
program
,
mlir
::
FuncOp
*
mainFunc
)
{
mlir
::
FuncOp
*
mainFunc
,
bool
arg_has_map
)
{
// update input vars
int
input_index
=
1
;
int
input_index
;
if
(
arg_has_map
)
input_index
=
1
;
else
input_index
=
0
;
for
(
auto
&
op_desc
:
main_block_
.
ops
())
{
if
(
op_desc
.
type
()
==
"feed"
)
{
for
(
int
var_idx
=
0
;
var_idx
<
op_desc
.
outputs_size
();
++
var_idx
)
{
...
...
@@ -170,9 +188,28 @@ void MLIRModelGenImpl::UpdateModelParams(
}
}
}
::
mlir
::
Value
map
;
if
(
arg_has_map
)
{
map
=
mainFunc
->
getArgument
(
0
);
}
else
{
builder_
.
setInsertionPointToStart
(
&
mainFunc
->
body
().
front
());
if
(
!
model_dir_
.
empty
())
{
auto
load_op
=
builder_
.
create
<::
infrt
::
phi
::
LoadParamsOp
>
(
mlir
::
UnknownLoc
::
get
(
context_
),
::
infrt
::
phi
::
DenseTensorMapType
::
get
(
context_
),
builder_
.
getStringAttr
(
model_dir_
));
map
=
load_op
.
out
();
}
else
if
(
!
model_file_
.
empty
())
{
auto
load_op
=
builder_
.
create
<::
infrt
::
phi
::
LoadCombinedParamsOp
>
(
mlir
::
UnknownLoc
::
get
(
context_
),
::
infrt
::
phi
::
DenseTensorMapType
::
get
(
context_
),
builder_
.
getStringAttr
(
model_file_
),
builder_
.
getStringAttr
(
params_file_
));
map
=
load_op
.
out
();
}
}
// update persistable tensors
::
mlir
::
Value
map
=
mainFunc
->
getArgument
(
0
);
for
(
int
i
=
0
;
i
<
main_block_
.
vars_size
();
i
++
)
{
auto
var_desc
=
main_block_
.
vars
(
i
);
if
(
params_map_
.
find
(
var_desc
.
name
())
!=
params_map_
.
end
())
continue
;
...
...
paddle/infrt/host_context/paddle_mlir.h
浏览文件 @
a78ca1cf
...
...
@@ -37,8 +37,10 @@ class MLIRModelGenImpl {
public:
MLIRModelGenImpl
();
mlir
::
ModuleOp
ImportPaddleModel
(
const
std
::
string
&
model_file
,
const
std
::
string
&
param_file
);
mlir
::
ModuleOp
ImportPaddleModel
(
const
std
::
string
&
model_dir
);
const
std
::
string
&
param_file
,
bool
arg_has_map
=
true
);
mlir
::
ModuleOp
ImportPaddleModel
(
const
std
::
string
&
model_dir
,
bool
arg_has_map
=
true
);
private:
// parse paddle model file
...
...
@@ -47,11 +49,13 @@ class MLIRModelGenImpl {
// convert paddle model proto into paddle dialect module
mlir
::
ModuleOp
ImportPaddleModel
(
const
infrt
::
paddle
::
framework_proto
::
ProgramDesc
&
program
);
const
infrt
::
paddle
::
framework_proto
::
ProgramDesc
&
program
,
bool
arg_has_map
);
// get inputs and outputs info from program_desc
llvm
::
SmallVector
<
mlir
::
Type
,
4
>
GetModelInputsType
(
const
infrt
::
paddle
::
framework_proto
::
ProgramDesc
&
program
);
const
infrt
::
paddle
::
framework_proto
::
ProgramDesc
&
program
,
bool
arg_has_map
);
llvm
::
SmallVector
<
mlir
::
Type
,
4
>
GetModelOutputsType
(
const
infrt
::
paddle
::
framework_proto
::
ProgramDesc
&
program
);
// create main function module
...
...
@@ -63,7 +67,8 @@ class MLIRModelGenImpl {
// convert persistable params and inputs variable into mlir domain
void
UpdateModelParams
(
const
infrt
::
paddle
::
framework_proto
::
ProgramDesc
&
program
,
mlir
::
FuncOp
*
mainFunc
);
mlir
::
FuncOp
*
mainFunc
,
bool
arg_has_map
);
// register model outpus into params_map_
void
UpdateModelOutputs
(
const
infrt
::
paddle
::
framework_proto
::
ProgramDesc
&
program
);
...
...
@@ -80,11 +85,16 @@ class MLIRModelGenImpl {
void
RegisterOpOutputVars
(
const
infrt
::
paddle
::
framework_proto
::
OpDesc
&
op_
,
mlir
::
Operation
*
mlir_op_
);
private:
mlir
::
MLIRContext
*
context_
;
mlir
::
OpBuilder
builder_
;
mlir
::
ModuleOp
module_
;
infrt
::
paddle
::
framework_proto
::
BlockDesc
main_block_
;
std
::
string
model_dir_
{};
std
::
string
model_file_
{};
std
::
string
params_file_
{};
std
::
map
<
std
::
string
,
mlir
::
Value
>
params_map_
;
};
...
...
paddle/infrt/kernel/phi/registry.cc
浏览文件 @
a78ca1cf
...
...
@@ -46,7 +46,7 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
registry
->
AddKernel
(
"phi_dt.create_host_inited_dense_tensor.f32"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
CreateHostInitedDenseTensorF32
),
{
"dims"
,
"lod"
,
"layout"
,
"values"
});
{
"dims"
,
"lod"
,
"layout"
,
"values"
,
"run_once"
});
registry
->
AddKernel
(
"phi_dt.fill_dense_tensor.f32"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
FillDenseTensorF32
),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录