Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
40e51b25
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
40e51b25
编写于
4月 26, 2021
作者:
石
石晓伟
提交者:
GitHub
4月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
python inference supports custom operators, test=develop (#32533)
上级
8e66046b
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
55 addition
and
3 deletion
+55
-3
paddle/fluid/framework/custom_operator.h
paddle/fluid/framework/custom_operator.h
+3
-0
paddle/fluid/inference/api/CMakeLists.txt
paddle/fluid/inference/api/CMakeLists.txt
+2
-2
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+1
-1
paddle/fluid/inference/api/helper.cc
paddle/fluid/inference/api/helper.cc
+18
-0
paddle/fluid/inference/api/helper.h
paddle/fluid/inference/api/helper.h
+2
-0
python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
...paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
+29
-0
未找到文件。
paddle/fluid/framework/custom_operator.h
浏览文件 @
40e51b25
...
...
@@ -28,5 +28,8 @@ void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name);
void
RegisterOperatorWithMetaInfoMap
(
const
paddle
::
OpMetaInfoMap
&
op_meta_info_map
);
// Interface for selective register custom op.
void
RegisterOperatorWithMetaInfo
(
const
std
::
vector
<
OpMetaInfo
>&
op_meta_infos
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/inference/api/CMakeLists.txt
浏览文件 @
40e51b25
...
...
@@ -32,10 +32,10 @@ cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc)
if
(
WITH_CRYPTO
)
cc_library
(
paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope reset_tensor_array
analysis_config zero_copy_tensor trainer_desc_proto paddle_crypto
)
analysis_config zero_copy_tensor trainer_desc_proto paddle_crypto
custom_operator
)
else
()
cc_library
(
paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope reset_tensor_array
analysis_config zero_copy_tensor trainer_desc_proto
)
analysis_config zero_copy_tensor trainer_desc_proto
custom_operator
)
endif
()
if
(
WIN32
)
...
...
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
40e51b25
...
...
@@ -628,7 +628,7 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
// This function can only be executed once per process.
static
std
::
once_flag
custom_operators_registered
;
std
::
call_once
(
custom_operators_registered
,
[]()
{
paddl
e
::
RegisterAllCustomOperator
();
});
[]()
{
inferenc
e
::
RegisterAllCustomOperator
();
});
if
(
config
.
use_gpu
())
{
static
std
::
once_flag
gflags_initialized
;
...
...
paddle/fluid/inference/api/helper.cc
浏览文件 @
40e51b25
...
...
@@ -13,6 +13,9 @@
// limitations under the License.
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/extension/include/ext_op_meta_info.h"
#include "paddle/fluid/framework/custom_operator.h"
#include "paddle/fluid/framework/operator.h"
namespace
paddle
{
namespace
inference
{
...
...
@@ -40,5 +43,20 @@ std::string to_string<std::vector<std::vector<float>>>(
return
ss
.
str
();
}
void
RegisterAllCustomOperator
()
{
auto
&
op_meta_info_map
=
OpMetaInfoMap
::
Instance
();
const
auto
&
meta_info_map
=
op_meta_info_map
.
GetMap
();
for
(
auto
&
pair
:
meta_info_map
)
{
const
auto
&
all_op_kernels
{
framework
::
OperatorWithKernel
::
AllOpKernels
()};
if
(
all_op_kernels
.
find
(
pair
.
first
)
==
all_op_kernels
.
end
())
{
framework
::
RegisterOperatorWithMetaInfo
(
pair
.
second
);
}
else
{
LOG
(
INFO
)
<<
"The operator `"
<<
pair
.
first
<<
"` has been registered. "
"Therefore, we will not repeat the registration here."
;
}
}
}
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/api/helper.h
浏览文件 @
40e51b25
...
...
@@ -398,5 +398,7 @@ static bool IsFileExists(const std::string &path) {
return
exists
;
}
void
RegisterAllCustomOperator
();
}
// namespace inference
}
// namespace paddle
python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py
浏览文件 @
40e51b25
...
...
@@ -255,6 +255,35 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
format
(
predict
,
predict_infer
))
paddle
.
disable_static
()
def
test_static_save_and_run_inference_predictor
(
self
):
paddle
.
enable_static
()
np_data
=
np
.
random
.
random
((
1
,
1
,
28
,
28
)).
astype
(
"float32"
)
np_label
=
np
.
random
.
random
((
1
,
1
)).
astype
(
"int64"
)
path_prefix
=
"custom_op_inference/custom_relu"
from
paddle.inference
import
Config
from
paddle.inference
import
create_predictor
for
device
in
self
.
devices
:
predict
=
custom_relu_static_inference
(
self
.
custom_ops
[
0
],
device
,
np_data
,
np_label
,
path_prefix
)
# load inference model
config
=
Config
(
path_prefix
+
".pdmodel"
,
path_prefix
+
".pdiparams"
)
predictor
=
create_predictor
(
config
)
input_tensor
=
predictor
.
get_input_handle
(
predictor
.
get_input_names
(
)[
0
])
input_tensor
.
reshape
(
np_data
.
shape
)
input_tensor
.
copy_from_cpu
(
np_data
.
copy
())
predictor
.
run
()
output_tensor
=
predictor
.
get_output_handle
(
predictor
.
get_output_names
()[
0
])
predict_infer
=
output_tensor
.
copy_to_cpu
()
self
.
assertTrue
(
np
.
isclose
(
predict
,
predict_infer
,
rtol
=
5e-5
).
any
(),
"custom op predict: {},
\n
custom op infer predict: {}"
.
format
(
predict
,
predict_infer
))
paddle
.
disable_static
()
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录