未验证 提交 628f0856 编写于 作者: 石晓伟 提交者: GitHub

[Cherry-pick] inference modification for custom operator (#31283) (#31300)

上级 227a6775
...@@ -667,10 +667,6 @@ void RegisterOperatorWithMetaInfo( ...@@ -667,10 +667,6 @@ void RegisterOperatorWithMetaInfo(
void RegisterOperatorWithMetaInfoMap( void RegisterOperatorWithMetaInfoMap(
const paddle::OpMetaInfoMap& op_meta_info_map) { const paddle::OpMetaInfoMap& op_meta_info_map) {
auto& meta_info_map = op_meta_info_map.GetMap(); auto& meta_info_map = op_meta_info_map.GetMap();
PADDLE_ENFORCE_EQ(meta_info_map.empty(), false,
platform::errors::PreconditionNotMet(
"No custom operator that needs to be registered."));
VLOG(1) << "Custom Operator: size of op meta info map - " VLOG(1) << "Custom Operator: size of op meta info map - "
<< meta_info_map.size(); << meta_info_map.size();
// pair: {op_type, OpMetaInfo} // pair: {op_type, OpMetaInfo}
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/extension/include/ext_op_meta_info.h"
#include "paddle/fluid/framework/feed_fetch_method.h" #include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
...@@ -612,6 +613,12 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor< ...@@ -612,6 +613,12 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Note: Each config can only be used for one predictor.")); "Note: Each config can only be used for one predictor."));
// Register custom operators compiled by the user.
// This function can only be executed once per process.
static std::once_flag custom_operators_registered;
std::call_once(custom_operators_registered,
[]() { paddle::RegisterAllCustomOperator(); });
if (config.use_gpu()) { if (config.use_gpu()) {
static std::once_flag gflags_initialized; static std::once_flag gflags_initialized;
static bool process_level_allocator_enabled; static bool process_level_allocator_enabled;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册