未验证 提交 4c925242 编写于 作者: J JingZhuangzhuang 提交者: GitHub

add _get_phi_kernel_name interface (#47033)

上级 c74bf018
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h" #include "paddle/fluid/inference/api/paddle_pass_builder.h"
#include "paddle/fluid/inference/utils/io_utils.h" #include "paddle/fluid/inference/utils/io_utils.h"
#include "paddle/phi/core/compat/convert_utils.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/phi/core/cuda_stream.h" #include "paddle/phi/core/cuda_stream.h"
...@@ -401,6 +402,12 @@ void BindInferenceApi(py::module *m) { ...@@ -401,6 +402,12 @@ void BindInferenceApi(py::module *m) {
new paddle_infer::Predictor(config)); new paddle_infer::Predictor(config));
return pred; return pred;
}); });
m->def(
"_get_phi_kernel_name",
[](const std::string &fluid_op_name) {
return phi::TransToPhiKernelName(fluid_op_name);
},
py::return_value_policy::reference);
m->def("copy_tensor", &CopyPaddleInferTensor); m->def("copy_tensor", &CopyPaddleInferTensor);
m->def("paddle_dtype_size", &paddle::PaddleDtypeSize); m->def("paddle_dtype_size", &paddle::PaddleDtypeSize);
m->def("paddle_tensor_to_bytes", &SerializePDTensorToBytes); m->def("paddle_tensor_to_bytes", &SerializePDTensorToBytes);
......
...@@ -282,6 +282,7 @@ try: ...@@ -282,6 +282,7 @@ try:
from .libpaddle import _get_current_stream from .libpaddle import _get_current_stream
from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent
from .libpaddle import _set_current_stream from .libpaddle import _set_current_stream
from .libpaddle import _get_phi_kernel_name
if sys.platform != 'win32': if sys.platform != 'win32':
from .libpaddle import _set_process_pids from .libpaddle import _set_process_pids
from .libpaddle import _erase_process_pids from .libpaddle import _erase_process_pids
......
...@@ -15,4 +15,4 @@ ...@@ -15,4 +15,4 @@
from .wrapper import Config, DataType, PlaceType, PrecisionType, Tensor, Predictor from .wrapper import Config, DataType, PlaceType, PrecisionType, Tensor, Predictor
from .wrapper import convert_to_mixed_precision from .wrapper import convert_to_mixed_precision
from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool, get_trt_compile_version, get_trt_runtime_version from ..core import create_predictor, get_version, _get_phi_kernel_name, get_num_bytes_of_data_type, PredictorPool, get_trt_compile_version, get_trt_runtime_version
...@@ -19,6 +19,7 @@ from ..fluid.inference import PrecisionType # noqa: F401 ...@@ -19,6 +19,7 @@ from ..fluid.inference import PrecisionType # noqa: F401
from ..fluid.inference import Tensor # noqa: F401 from ..fluid.inference import Tensor # noqa: F401
from ..fluid.inference import Predictor # noqa: F401 from ..fluid.inference import Predictor # noqa: F401
from ..fluid.inference import create_predictor # noqa: F401 from ..fluid.inference import create_predictor # noqa: F401
from ..fluid.inference import _get_phi_kernel_name
from ..fluid.inference import get_version # noqa: F401 from ..fluid.inference import get_version # noqa: F401
from ..fluid.inference import get_trt_compile_version # noqa: F401 from ..fluid.inference import get_trt_compile_version # noqa: F401
from ..fluid.inference import get_trt_runtime_version # noqa: F401 from ..fluid.inference import get_trt_runtime_version # noqa: F401
...@@ -28,7 +29,7 @@ from ..fluid.inference import PredictorPool # noqa: F401 ...@@ -28,7 +29,7 @@ from ..fluid.inference import PredictorPool # noqa: F401
__all__ = [ # noqa __all__ = [ # noqa
'Config', 'DataType', 'PlaceType', 'PrecisionType', 'Tensor', 'Predictor', 'Config', 'DataType', 'PlaceType', 'PrecisionType', 'Tensor', 'Predictor',
'create_predictor', 'get_version', 'get_trt_compile_version', 'create_predictor', 'get_version', '_get_phi_kernel_name',
'convert_to_mixed_precision', 'get_trt_runtime_version', 'get_trt_compile_version', 'convert_to_mixed_precision',
'get_num_bytes_of_data_type', 'PredictorPool' 'get_trt_runtime_version', 'get_num_bytes_of_data_type', 'PredictorPool'
] ]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册