未验证 提交 40e51b25 编写于 作者: 石晓伟 提交者: GitHub

python inference supports custom operators, test=develop (#32533)

上级 8e66046b
......@@ -28,5 +28,8 @@ void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name);
void RegisterOperatorWithMetaInfoMap(
const paddle::OpMetaInfoMap& op_meta_info_map);
// Interface for selective register custom op.
void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos);
} // namespace framework
} // namespace paddle
......@@ -32,10 +32,10 @@ cc_library(paddle_pass_builder SRCS paddle_pass_builder.cc)
if(WITH_CRYPTO)
cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope reset_tensor_array
analysis_config zero_copy_tensor trainer_desc_proto paddle_crypto)
analysis_config zero_copy_tensor trainer_desc_proto paddle_crypto custom_operator)
else()
cc_library(paddle_inference_api SRCS api.cc api_impl.cc helper.cc DEPS lod_tensor scope reset_tensor_array
analysis_config zero_copy_tensor trainer_desc_proto)
analysis_config zero_copy_tensor trainer_desc_proto custom_operator)
endif()
if(WIN32)
......
......@@ -628,7 +628,7 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
// This function can only be executed once per process.
static std::once_flag custom_operators_registered;
std::call_once(custom_operators_registered,
[]() { paddle::RegisterAllCustomOperator(); });
[]() { inference::RegisterAllCustomOperator(); });
if (config.use_gpu()) {
static std::once_flag gflags_initialized;
......
......@@ -13,6 +13,9 @@
// limitations under the License.
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/extension/include/ext_op_meta_info.h"
#include "paddle/fluid/framework/custom_operator.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace inference {
......@@ -40,5 +43,20 @@ std::string to_string<std::vector<std::vector<float>>>(
return ss.str();
}
// Register every custom operator collected in the global OpMetaInfoMap with
// the framework. Ops whose kernels are already present in the framework's
// kernel registry are skipped (with an INFO log) to avoid double registration.
void RegisterAllCustomOperator() {
  const auto &meta_info_map = OpMetaInfoMap::Instance().GetMap();
  for (const auto &name_and_infos : meta_info_map) {
    // Query the kernel registry each iteration: registering one op may
    // extend it, and we must see the up-to-date state.
    const auto &registered_kernels =
        framework::OperatorWithKernel::AllOpKernels();
    if (registered_kernels.count(name_and_infos.first) == 0) {
      framework::RegisterOperatorWithMetaInfo(name_and_infos.second);
    } else {
      LOG(INFO) << "The operator `" << name_and_infos.first
                << "` has been registered. "
                   "Therefore, we will not repeat the registration here.";
    }
  }
}
} // namespace inference
} // namespace paddle
......@@ -398,5 +398,7 @@ static bool IsFileExists(const std::string &path) {
return exists;
}
void RegisterAllCustomOperator();
} // namespace inference
} // namespace paddle
......@@ -255,6 +255,35 @@ class TestNewCustomOpSetUpInstall(unittest.TestCase):
format(predict, predict_infer))
paddle.disable_static()
def test_static_save_and_run_inference_predictor(self):
paddle.enable_static()
np_data = np.random.random((1, 1, 28, 28)).astype("float32")
np_label = np.random.random((1, 1)).astype("int64")
path_prefix = "custom_op_inference/custom_relu"
from paddle.inference import Config
from paddle.inference import create_predictor
for device in self.devices:
predict = custom_relu_static_inference(
self.custom_ops[0], device, np_data, np_label, path_prefix)
# load inference model
config = Config(path_prefix + ".pdmodel",
path_prefix + ".pdiparams")
predictor = create_predictor(config)
input_tensor = predictor.get_input_handle(predictor.get_input_names(
)[0])
input_tensor.reshape(np_data.shape)
input_tensor.copy_from_cpu(np_data.copy())
predictor.run()
output_tensor = predictor.get_output_handle(
predictor.get_output_names()[0])
predict_infer = output_tensor.copy_to_cpu()
self.assertTrue(
np.isclose(
predict, predict_infer, rtol=5e-5).any(),
"custom op predict: {},\n custom op infer predict: {}".format(
predict, predict_infer))
paddle.disable_static()
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册