From 9029fde7487f6c168e4ee22c886a3e7f2a255537 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Fri, 13 May 2022 14:50:31 +0800 Subject: [PATCH] [IPU] fix ipu and add python infer api, test=develop (#42724) * [IPU] fix ipu and add python infer api, test=develop * [IPU] add paddlepaddle-ipu package name, test=develop --- paddle/fluid/framework/ir/CMakeLists.txt | 1 - paddle/fluid/pybind/inference_api.cc | 8 ++++++++ python/CMakeLists.txt | 2 ++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index b430a409e9..bfefb89ade 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -164,7 +164,6 @@ if(WITH_IPU) pass_library(infer_shape_pass base DIR ipu) pass_library(delete_scale_op_pass base DIR ipu) pass_library(avg_shard_pass base DIR ipu) - pass_library(transfer_cast_op_pass base DIR ipu) endif() cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector ) diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 1bbe6808b2..9447814840 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -601,6 +601,14 @@ void BindAnalysisConfig(py::module *m) { .def("set_xpu_device_id", &AnalysisConfig::SetXpuDeviceId, py::arg("device_id") = 0) .def("enable_npu", &AnalysisConfig::EnableNpu, py::arg("device_id") = 0) + .def("enable_ipu", &AnalysisConfig::EnableIpu, + py::arg("ipu_device_num") = 1, py::arg("ipu_micro_batch_size") = 1, + py::arg("ipu_enable_pipelining") = false, + py::arg("ipu_batches_per_step") = 1) + .def("set_ipu_config", &AnalysisConfig::SetIpuConfig, + py::arg("ipu_enable_fp16") = false, py::arg("ipu_replica_num") = 1, + py::arg("ipu_available_memory_proportion") = 1.0, + py::arg("ipu_enable_half_partial") = false) .def("disable_gpu", &AnalysisConfig::DisableGpu) .def("enable_onnxruntime", &AnalysisConfig::EnableONNXRuntime) .def("disable_onnxruntime", &AnalysisConfig::DisableONNXRuntime) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index fe5f2c25ca..fdcd560658 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -14,6 +14,8 @@ elseif(WITH_ASCEND_CL) SET(PACKAGE_NAME "paddlepaddle-npu") elseif(WITH_XPU) SET(PACKAGE_NAME "paddlepaddle-xpu") +elseif(WITH_IPU) + SET(PACKAGE_NAME "paddlepaddle-ipu") else() SET(PACKAGE_NAME "paddlepaddle") endif() -- GitLab