diff --git a/paddle/fluid/framework/custom_operator.h b/paddle/fluid/framework/custom_operator.h index 117841f80cf47ed95251fee1d01f7fd87caa600b..259901c09f3e00729876d7bea062237ad5bad94a 100644 --- a/paddle/fluid/framework/custom_operator.h +++ b/paddle/fluid/framework/custom_operator.h @@ -28,5 +28,8 @@ void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name); void RegisterOperatorWithMetaInfoMap( const paddle::OpMetaInfoMap& op_meta_info_map); +// Interface for selective register custom op. +void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos); + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index 03f86cc7ba6de608f2f755ad5ff0f76578575575..82c95ba2c95712d2ebe3aa80286689028febf3fe 100755 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -32,10 +32,10 @@ cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc) if(WITH_CRYPTO) cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope reset_tensor_array - analysis_config zero_copy_tensor trainer_desc_proto paddle_crypto) + analysis_config zero_copy_tensor trainer_desc_proto paddle_crypto custom_operator) else() cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope reset_tensor_array - analysis_config zero_copy_tensor trainer_desc_proto) + analysis_config zero_copy_tensor trainer_desc_proto custom_operator) endif() if(WIN32) diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 95b08318368438943b538774cbeb83e2d92a5103..6a6be14fd5977dcb7a7909b17a7684780391042c 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -628,7 +628,7 @@ std::unique_ptr CreatePaddlePredictor< // This function can only be executed once per process. static std::once_flag custom_operators_registered; std::call_once(custom_operators_registered, - []() { paddle::RegisterAllCustomOperator(); }); + []() { inference::RegisterAllCustomOperator(); }); if (config.use_gpu()) { static std::once_flag gflags_initialized; diff --git a/paddle/fluid/inference/api/helper.cc b/paddle/fluid/inference/api/helper.cc index 9cc491e10d691a206dd903b78c0ea570741da44c..d78560239de50eb224641583d62b55bac75be465 100644 --- a/paddle/fluid/inference/api/helper.cc +++ b/paddle/fluid/inference/api/helper.cc @@ -13,6 +13,9 @@ // limitations under the License. #include "paddle/fluid/inference/api/helper.h" +#include "paddle/fluid/extension/include/ext_op_meta_info.h" +#include "paddle/fluid/framework/custom_operator.h" +#include "paddle/fluid/framework/operator.h" namespace paddle { namespace inference { @@ -40,5 +43,20 @@ std::string to_string>>( return ss.str(); } +void RegisterAllCustomOperator() { + auto &op_meta_info_map = OpMetaInfoMap::Instance(); + const auto &meta_info_map = op_meta_info_map.GetMap(); + for (auto &pair : meta_info_map) { + const auto &all_op_kernels{framework::OperatorWithKernel::AllOpKernels()}; + if (all_op_kernels.find(pair.first) == all_op_kernels.end()) { + framework::RegisterOperatorWithMetaInfo(pair.second); + } else { + LOG(INFO) << "The operator `" << pair.first + << "` has been registered. " + "Therefore, we will not repeat the registration here."; + } + } +} + } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h index 14b968f5834da8618f6af16aa8c25e1d1baaae5e..c6d25137594b76a1ff67d9fb25b2480372c3eefa 100644 --- a/paddle/fluid/inference/api/helper.h +++ b/paddle/fluid/inference/api/helper.h @@ -398,5 +398,7 @@ static bool IsFileExists(const std::string &path) { return exists; } +void RegisterAllCustomOperator(); + } // namespace inference } // namespace paddle diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py index 7f5c76d0aeeae3a20b3310144dfc8f45c6461f84..642e93ebcb85e0eab3aa373243d7b44b42aab443 100644 --- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py +++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_setup.py @@ -255,6 +255,35 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase): format(predict, predict_infer)) paddle.disable_static() + def test_static_save_and_run_inference_predictor(self): + paddle.enable_static() + np_data = np.random.random((1, 1, 28, 28)).astype("float32") + np_label = np.random.random((1, 1)).astype("int64") + path_prefix = "custom_op_inference/custom_relu" + from paddle.inference import Config + from paddle.inference import create_predictor + for device in self.devices: + predict = custom_relu_static_inference( + self.custom_ops[0], device, np_data, np_label, path_prefix) + # load inference model + config = Config(path_prefix + ".pdmodel", + path_prefix + ".pdiparams") + predictor = create_predictor(config) + input_tensor = predictor.get_input_handle(predictor.get_input_names( + )[0]) + input_tensor.reshape(np_data.shape) + input_tensor.copy_from_cpu(np_data.copy()) + predictor.run() + output_tensor = predictor.get_output_handle( + predictor.get_output_names()[0]) + predict_infer = output_tensor.copy_to_cpu() + self.assertTrue( + np.isclose( + predict, predict_infer, rtol=5e-5).any(), + "custom op predict: {},\n custom op infer predict: {}".format( + predict, predict_infer)) + paddle.disable_static() + if __name__ == '__main__': unittest.main()