未验证 提交 d60751fb 编写于 作者: F flame 提交者: GitHub

add python inference api (#15248)

add python inference api
上级 59ab98c9
...@@ -45,6 +45,7 @@ paddle.fluid.AsyncExecutor.save_model ArgSpec(args=['self', 'save_path'], vararg ...@@ -45,6 +45,7 @@ paddle.fluid.AsyncExecutor.save_model ArgSpec(args=['self', 'save_path'], vararg
paddle.fluid.AsyncExecutor.stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) paddle.fluid.AsyncExecutor.stop ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.fluid.CompiledProgram.__init__ ArgSpec(args=['self', 'program'], varargs=None, keywords=None, defaults=None) paddle.fluid.CompiledProgram.__init__ ArgSpec(args=['self', 'program'], varargs=None, keywords=None, defaults=None)
paddle.fluid.CompiledProgram.with_data_parallel ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from'], varargs=None, keywords=None, defaults=(None, None, None, None)) paddle.fluid.CompiledProgram.with_data_parallel ArgSpec(args=['self', 'loss_name', 'build_strategy', 'exec_strategy', 'share_vars_from'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.CompiledProgram.with_inference_optimize ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=None)
paddle.fluid.ExecutionStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.ExecutionStrategy) -> None paddle.fluid.ExecutionStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.ExecutionStrategy) -> None
paddle.fluid.BuildStrategy.GradientScaleStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.BuildStrategy.GradientScaleStrategy, arg0: int) -> None paddle.fluid.BuildStrategy.GradientScaleStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.BuildStrategy.GradientScaleStrategy, arg0: int) -> None
paddle.fluid.BuildStrategy.ReduceStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.BuildStrategy.ReduceStrategy, arg0: int) -> None paddle.fluid.BuildStrategy.ReduceStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.BuildStrategy.ReduceStrategy, arg0: int) -> None
......
...@@ -45,6 +45,7 @@ using contrib::AnalysisConfig; ...@@ -45,6 +45,7 @@ using contrib::AnalysisConfig;
class AnalysisPredictor : public PaddlePredictor { class AnalysisPredictor : public PaddlePredictor {
public: public:
explicit AnalysisPredictor(const AnalysisConfig &config) : config_(config) {} explicit AnalysisPredictor(const AnalysisConfig &config) : config_(config) {}
~AnalysisPredictor();
bool Init(const std::shared_ptr<framework::Scope> &parent_scope, bool Init(const std::shared_ptr<framework::Scope> &parent_scope,
const std::shared_ptr<framework::ProgramDesc> &program = nullptr); const std::shared_ptr<framework::ProgramDesc> &program = nullptr);
...@@ -95,7 +96,6 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -95,7 +96,6 @@ class AnalysisPredictor : public PaddlePredictor {
template <typename T> template <typename T>
void GetFetchOne(const framework::LoDTensor &fetchs, void GetFetchOne(const framework::LoDTensor &fetchs,
PaddleTensor *output_data); PaddleTensor *output_data);
~AnalysisPredictor();
// Some more detailed tests, they are made the friends of the predictor, so that // Some more detailed tests, they are made the friends of the predictor, so that
// the all the details can be tested. // the all the details can be tested.
......
set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune
feed_fetch_method pass_builder parallel_executor profiler layer scope_pool feed_fetch_method pass_builder parallel_executor profiler layer scope_pool
tracer) tracer analysis_predictor)
if(WITH_PYTHON) if(WITH_PYTHON)
list(APPEND PYBIND_DEPS py_func_op) list(APPEND PYBIND_DEPS py_func_op)
endif() endif()
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc ir.cc) set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc ir.cc inference_api.cc)
if(WITH_PYTHON) if(WITH_PYTHON)
if(WITH_AMD_GPU) if(WITH_AMD_GPU)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/pybind/inference_api.h"
#include <pybind11/stl.h>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
using paddle::PaddleDType;
using paddle::PaddleBuf;
using paddle::PaddleTensor;
using paddle::PaddlePlace;
using paddle::PaddlePredictor;
using paddle::NativeConfig;
using paddle::NativePaddlePredictor;
using paddle::AnalysisPredictor;
using paddle::contrib::AnalysisConfig;
static void BindPaddleDType(py::module *m);
static void BindPaddleBuf(py::module *m);
static void BindPaddleTensor(py::module *m);
static void BindPaddlePlace(py::module *m);
static void BindPaddlePredictor(py::module *m);
static void BindNativeConfig(py::module *m);
static void BindNativePredictor(py::module *m);
static void BindAnalysisConfig(py::module *m);
static void BindAnalysisPredictor(py::module *m);
void BindInferenceApi(py::module *m) {
BindPaddleDType(m);
BindPaddleBuf(m);
BindPaddleTensor(m);
BindPaddlePlace(m);
BindPaddlePredictor(m);
BindNativeConfig(m);
BindNativePredictor(m);
BindAnalysisConfig(m);
BindAnalysisPredictor(m);
m->def("create_paddle_predictor",
&paddle::CreatePaddlePredictor<AnalysisConfig>);
m->def("create_paddle_predictor",
&paddle::CreatePaddlePredictor<NativeConfig>);
m->def("paddle_dtype_size", &paddle::PaddleDtypeSize);
}
void BindPaddleDType(py::module *m) {
py::enum_<PaddleDType>(*m, "PaddleDType")
.value("FLOAT32", PaddleDType::FLOAT32)
.value("INT64", PaddleDType::INT64);
}
void BindPaddleBuf(py::module *m) {
py::class_<PaddleBuf>(*m, "PaddleBuf")
.def(py::init<size_t>())
.def(py::init([](std::vector<float> &data) {
auto buf = PaddleBuf(data.size() * sizeof(float));
std::memcpy(buf.data(), static_cast<void *>(data.data()), buf.length());
return std::move(buf);
}))
.def(py::init([](std::vector<int64_t> &data) {
auto buf = PaddleBuf(data.size() * sizeof(int64_t));
std::memcpy(buf.data(), static_cast<void *>(data.data()), buf.length());
return std::move(buf);
}))
.def("resize", &PaddleBuf::Resize)
.def("reset",
[](PaddleBuf &self, std::vector<float> &data) {
self.Resize(data.size() * sizeof(float));
std::memcpy(self.data(), data.data(), self.length());
})
.def("reset",
[](PaddleBuf &self, std::vector<int64_t> &data) {
self.Resize(data.size() * sizeof(int64_t));
std::memcpy(self.data(), data.data(), self.length());
})
.def("empty", &PaddleBuf::empty)
.def("float_data",
[](PaddleBuf &self) -> std::vector<float> {
auto *data = static_cast<float *>(self.data());
return {data, data + self.length() / sizeof(*data)};
})
.def("int64_data",
[](PaddleBuf &self) -> std::vector<int64_t> {
int64_t *data = static_cast<int64_t *>(self.data());
return {data, data + self.length() / sizeof(*data)};
})
.def("length", &PaddleBuf::length);
}
void BindPaddleTensor(py::module *m) {
py::class_<PaddleTensor>(*m, "PaddleTensor")
.def(py::init<>())
.def_readwrite("name", &PaddleTensor::name)
.def_readwrite("shape", &PaddleTensor::shape)
.def_readwrite("data", &PaddleTensor::data)
.def_readwrite("dtype", &PaddleTensor::dtype)
.def_readwrite("lod", &PaddleTensor::lod);
}
void BindPaddlePlace(py::module *m) {
py::enum_<PaddlePlace>(*m, "PaddlePlace")
.value("UNK", PaddlePlace::kUNK)
.value("CPU", PaddlePlace::kCPU)
.value("GPU", PaddlePlace::kGPU);
}
void BindPaddlePredictor(py::module *m) {
auto paddle_predictor = py::class_<PaddlePredictor>(*m, "PaddlePredictor");
paddle_predictor
.def("run",
[](PaddlePredictor &self, const std::vector<PaddleTensor> &inputs) {
std::vector<PaddleTensor> outputs;
self.Run(inputs, &outputs);
return outputs;
})
.def("get_input_tensor", &PaddlePredictor::GetInputTensor)
.def("get_output_tensor", &PaddlePredictor::GetOutputTensor)
.def("zero_copy_run", &PaddlePredictor::ZeroCopyRun)
.def("clone", &PaddlePredictor::Clone);
auto config = py::class_<PaddlePredictor::Config>(paddle_predictor, "Config");
config.def(py::init<>())
.def_readwrite("model_dir", &PaddlePredictor::Config::model_dir);
}
void BindNativeConfig(py::module *m) {
py::class_<NativeConfig, PaddlePredictor::Config>(*m, "NativeConfig")
.def(py::init<>())
.def_readwrite("use_gpu", &NativeConfig::use_gpu)
.def_readwrite("device", &NativeConfig::device)
.def_readwrite("fraction_of_gpu_memory",
&NativeConfig::fraction_of_gpu_memory)
.def_readwrite("prog_file", &NativeConfig::prog_file)
.def_readwrite("param_file", &NativeConfig::param_file)
.def_readwrite("specify_input_name", &NativeConfig::specify_input_name)
.def("set_cpu_math_library_num_threads",
&NativeConfig::SetCpuMathLibraryNumThreads)
.def("cpu_math_library_num_threads",
&NativeConfig::cpu_math_library_num_threads);
}
void BindNativePredictor(py::module *m) {
py::class_<NativePaddlePredictor, PaddlePredictor>(*m,
"NativePaddlePredictor")
.def(py::init<const NativeConfig &>())
.def("init", &NativePaddlePredictor::Init)
.def("run",
[](NativePaddlePredictor &self,
const std::vector<PaddleTensor> &inputs) {
std::vector<PaddleTensor> outputs;
self.Run(inputs, &outputs);
return outputs;
})
.def("get_input_tensor", &NativePaddlePredictor::GetInputTensor)
.def("get_output_tensor", &NativePaddlePredictor::GetOutputTensor)
.def("zero_copy_run", &NativePaddlePredictor::ZeroCopyRun)
.def("clone", &NativePaddlePredictor::Clone)
.def("scope", &NativePaddlePredictor::scope,
py::return_value_policy::reference);
}
void BindAnalysisConfig(py::module *m) {
py::class_<AnalysisConfig>(*m, "AnalysisConfig")
.def(py::init<const AnalysisConfig &>())
.def(py::init<const std::string &>())
.def(py::init<const std::string &, const std::string &>())
.def("set_model", (void (AnalysisConfig::*)(const std::string &)) &
AnalysisConfig::SetModel)
.def("set_model", (void (AnalysisConfig::*)(const std::string &,
const std::string &)) &
AnalysisConfig::SetModel)
.def("set_prog_file", &AnalysisConfig::SetProgFile)
.def("set_params_file", &AnalysisConfig::SetParamsFile)
.def("model_dir", &AnalysisConfig::model_dir)
.def("prog_file", &AnalysisConfig::prog_file)
.def("params_file", &AnalysisConfig::params_file)
.def("enable_use_gpu", &AnalysisConfig::EnableUseGpu,
py::arg("memory_pool_init_size_mb"), py::arg("device_id") = 0)
.def("disable_gpu", &AnalysisConfig::DisableGpu)
.def("use_gpu", &AnalysisConfig::use_gpu)
.def("gpu_device_id", &AnalysisConfig::gpu_device_id)
.def("memory_pool_init_size_mb",
&AnalysisConfig::memory_pool_init_size_mb)
.def("fraction_of_gpu_memory_for_pool",
&AnalysisConfig::fraction_of_gpu_memory_for_pool)
.def("switch_ir_optim", &AnalysisConfig::SwitchIrOptim,
py::arg("x") = true)
.def("ir_optim", &AnalysisConfig::ir_optim)
.def("switch_use_feed_fetch_ops", &AnalysisConfig::SwitchUseFeedFetchOps,
py::arg("x") = true)
.def("use_feed_fetch_ops_enabled",
&AnalysisConfig::use_feed_fetch_ops_enabled)
.def("switch_specify_input_names",
&AnalysisConfig::SwitchSpecifyInputNames, py::arg("x") = true)
.def("specify_input_name", &AnalysisConfig::specify_input_name)
.def("enable_tensorrt_engine", &AnalysisConfig::EnableTensorRtEngine,
py::arg("workspace_size") = 1 << 20, py::arg("max_batch_size") = 1,
py::arg("min_subgraph_size") = 3)
.def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
.def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug,
py::arg("x") = true)
.def("enable_mkldnn", &AnalysisConfig::EnableMKLDNN)
.def("mkldnn_enabled", &AnalysisConfig::mkldnn_enabled)
.def("set_cpu_math_library_num_threads",
&AnalysisConfig::SetCpuMathLibraryNumThreads)
.def("cpu_math_library_num_threads",
&AnalysisConfig::cpu_math_library_num_threads)
.def("to_native_config", &AnalysisConfig::ToNativeConfig)
.def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp)
.def("set_model_buffer", &AnalysisConfig::SetModelBuffer)
.def("model_from_memory", &AnalysisConfig::model_from_memory)
.def("pass_builder", &AnalysisConfig::pass_builder,
py::return_value_policy::reference);
}
void BindAnalysisPredictor(py::module *m) {
py::class_<AnalysisPredictor, PaddlePredictor>(*m, "AnalysisPredictor")
.def(py::init<const AnalysisConfig &>())
.def("init", &AnalysisPredictor::Init)
.def(
"run",
[](AnalysisPredictor &self, const std::vector<PaddleTensor> &inputs) {
std::vector<PaddleTensor> outputs;
self.Run(inputs, &outputs);
return outputs;
})
.def("get_input_tensor", &AnalysisPredictor::GetInputTensor)
.def("get_output_tensor", &AnalysisPredictor::GetOutputTensor)
.def("zero_copy_run", &AnalysisPredictor::ZeroCopyRun)
.def("clone", &AnalysisPredictor::Clone)
.def("scope", &AnalysisPredictor::scope,
py::return_value_policy::reference);
}
} // namespace pybind
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <pybind11/pybind11.h>
namespace paddle {
namespace pybind {
void BindInferenceApi(pybind11::module *m);
} // namespace pybind
} // namespace paddle
...@@ -49,6 +49,7 @@ limitations under the License. */ ...@@ -49,6 +49,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/imperative.h" #include "paddle/fluid/pybind/imperative.h"
#include "paddle/fluid/pybind/inference_api.h"
#include "paddle/fluid/pybind/ir.h" #include "paddle/fluid/pybind/ir.h"
#include "paddle/fluid/pybind/protobuf.h" #include "paddle/fluid/pybind/protobuf.h"
#include "paddle/fluid/pybind/pybind.h" // NOLINT #include "paddle/fluid/pybind/pybind.h" // NOLINT
...@@ -1083,9 +1084,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1083,9 +1084,9 @@ All parameter, weight, gradient are variables in Paddle.
BindRecordIOWriter(&m); BindRecordIOWriter(&m);
BindAsyncExecutor(&m); BindAsyncExecutor(&m);
BindGraph(&m); BindGraph(&m);
BindNode(&m); BindNode(&m);
BindInferenceApi(&m);
} }
} // namespace pybind } // namespace pybind
} // namespace paddle } // namespace paddle
...@@ -24,6 +24,8 @@ __all__ = ['CompiledProgram', 'ExecutionStrategy', 'BuildStrategy'] ...@@ -24,6 +24,8 @@ __all__ = ['CompiledProgram', 'ExecutionStrategy', 'BuildStrategy']
ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
BuildStrategy = core.ParallelExecutor.BuildStrategy BuildStrategy = core.ParallelExecutor.BuildStrategy
InferNativeConfig = core.NativeConfig
InferAnalysisConfig = core.AnalysisConfig
def _place_obj(place): def _place_obj(place):
...@@ -70,6 +72,7 @@ class CompiledProgram(object): ...@@ -70,6 +72,7 @@ class CompiledProgram(object):
self._executor = None self._executor = None
self._compiled = False self._compiled = False
self._is_data_parallel = False self._is_data_parallel = False
self._is_inference = False
def with_data_parallel(self, def with_data_parallel(self,
loss_name=None, loss_name=None,
...@@ -109,10 +112,24 @@ class CompiledProgram(object): ...@@ -109,10 +112,24 @@ class CompiledProgram(object):
self._build_strategy = BuildStrategy() self._build_strategy = BuildStrategy()
return self return self
def _with_distributed(self): def with_inference_optimize(self, config):
raise NotImplementedError() """ Add inference optimize
Args:
config: instance of `NativeConfig` or `AnalysisConfig` to create predictor
Returns:
self
"""
assert any([
isinstance(config, InferNativeConfig),
isinstance(config, InferAnalysisConfig)
])
self._is_data_parallel = False
self._is_inference = True
self._infer_config = config
return self
def _with_inference_optimize(self): def _with_distributed(self):
raise NotImplementedError() raise NotImplementedError()
def _compile_data_parallel(self): def _compile_data_parallel(self):
...@@ -177,6 +194,10 @@ class CompiledProgram(object): ...@@ -177,6 +194,10 @@ class CompiledProgram(object):
if self._loss_name else six.u(''), self._scope, self._local_scopes, if self._loss_name else six.u(''), self._scope, self._local_scopes,
self._exec_strategy, self._build_strategy) self._exec_strategy, self._build_strategy)
def _compile_inference(self):
assert self._is_data_parallel is False
return core.create_paddle_predictor(self._infer_config)
def _compile(self, scope, place): def _compile(self, scope, place):
"""Compile the program based on the configs. """Compile the program based on the configs.
...@@ -200,6 +221,8 @@ class CompiledProgram(object): ...@@ -200,6 +221,8 @@ class CompiledProgram(object):
self._place = place self._place = place
if self._is_data_parallel: if self._is_data_parallel:
self._executor = self._compile_data_parallel() self._executor = self._compile_data_parallel()
elif self._is_inference:
self._executor = self._compile_inference()
else: else:
p = _place_obj(self._place) p = _place_obj(self._place)
self._executor = core.Executor(p) self._executor = core.Executor(p)
......
...@@ -27,6 +27,8 @@ from .. import compat as cpt ...@@ -27,6 +27,8 @@ from .. import compat as cpt
__all__ = ['Executor', 'global_scope', 'scope_guard'] __all__ = ['Executor', 'global_scope', 'scope_guard']
g_scope = core.Scope() g_scope = core.Scope()
InferNativeConfig = core.NativeConfig
InferAnalysisConfig = core.AnalysisConfig
def global_scope(): def global_scope():
...@@ -533,6 +535,8 @@ class Executor(object): ...@@ -533,6 +535,8 @@ class Executor(object):
fetch_list=fetch_list, fetch_list=fetch_list,
fetch_var_name=fetch_var_name, fetch_var_name=fetch_var_name,
return_numpy=return_numpy) return_numpy=return_numpy)
elif program._is_inference:
return self._run_inference(program, feed)
else: else:
# TODO(panyx0718): Can compile program to optimize executor # TODO(panyx0718): Can compile program to optimize executor
# performance. # performance.
...@@ -590,3 +594,6 @@ class Executor(object): ...@@ -590,3 +594,6 @@ class Executor(object):
if return_numpy: if return_numpy:
outs = as_numpy(outs) outs = as_numpy(outs)
return outs return outs
def _run_inference(self, program, feed):
return self.executor.run(feed)
...@@ -195,9 +195,34 @@ def infer(use_cuda, save_dirname=None): ...@@ -195,9 +195,34 @@ def infer(use_cuda, save_dirname=None):
}, },
fetch_list=fetch_targets, fetch_list=fetch_targets,
return_numpy=False) return_numpy=False)
print(results[0].recursive_sequence_lengths())
def to_infer_tensor(lod_tensor):
infer_tensor = fluid.core.PaddleTensor()
infer_tensor.lod = lod_tensor.lod()
infer_tensor.data = fluid.core.PaddleBuf(np.array(lod_tensor))
infer_tensor.shape = lod_tensor.shape()
infer_tensor.dtype = fluid.core.PaddleDType.INT64
return infer_tensor
infer_inputs = [first_word, second_word, third_word, fourth_word]
infer_inputs = [to_infer_tensor(t) for t in infer_inputs]
infer_config = fluid.core.NativeConfig()
infer_config.model_dir = 'word2vec.inference.model'
infer_config.use_gpu = use_cuda
if use_cuda:
infer_config.device = 0
infer_config.fraction_of_gpu_memory = 0.15
compiled_program = fluid.compiler.CompiledProgram(inference_program)
compiled_program.with_inference_optimize(infer_config)
assert compiled_program._is_inference is True
infer_outputs = exe.run(compiled_program, feed=infer_inputs)
np_data = np.array(results[0]) np_data = np.array(results[0])
print("Inference Shape: ", np_data.shape) infer_out = infer_outputs[0].data.float_data()
for a, b in zip(np_data[0], infer_out):
g_a = float("{:.6g}".format(a))
g_b = float("{:.6g}".format(b))
assert g_a == g_b
def main(use_cuda, is_sparse, is_parallel): def main(use_cuda, is_sparse, is_parallel):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册