Unverified · Commit a3858036 · authored by tensor-tang · committed by GitHub

Merge pull request #11162 from tensor-tang/infer/api

enable infer api with multi-threads
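The gist of the change: a predictor created via CreatePaddlePredictor owns a fresh scope, and every Clone() shares that parent scope while working in its own sub-scope, so clones can run concurrently. Below is a minimal usage sketch of that multi-threaded pattern (not part of this commit; the thread count, helper name, and header path are illustrative assumptions, while NativeConfig, CreatePaddlePredictor, PaddleTensor, Clone, and Run are the API touched in the diff):

// Hedged sketch: one predictor is built on the main thread, each worker runs
// inference on its own clone. Filling the feed tensors is model-specific and
// omitted here.
#include <thread>
#include <vector>
#include "paddle/contrib/inference/paddle_inference_api.h"  // assumed header path

void RunMultiThreaded(const paddle::NativeConfig &config, int num_threads) {
  // CreatePaddlePredictor -> Init(nullptr): this predictor owns the parent scope.
  auto main_predictor =
      paddle::CreatePaddlePredictor<paddle::NativeConfig,
                                    paddle::PaddleEngineKind::kNative>(config);
  std::vector<std::thread> workers;
  for (int i = 0; i < num_threads; ++i) {
    // Clone on the main thread; each clone shares the parent scope but gets its
    // own sub-scope, so concurrent Run() calls do not clobber each other.
    std::shared_ptr<paddle::PaddlePredictor> predictor = main_predictor->Clone();
    workers.emplace_back([predictor] {
      std::vector<paddle::PaddleTensor> inputs;  // fill feeds for the model here
      std::vector<paddle::PaddleTensor> outputs;
      predictor->Run(inputs, &outputs);
    });
  }
  for (auto &t : workers) t.join();
}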
@@ -65,7 +65,10 @@ void Main(bool use_gpu) {
 }
 TEST(demo, word2vec_cpu) { Main(false /*use_gpu*/); }
+#ifdef PADDLE_WITH_CUDA
 TEST(demo, word2vec_gpu) { Main(true /*use_gpu*/); }
+#endif
 }  // namespace demo
 }  // namespace paddle
@@ -63,6 +63,7 @@ class PaddlePredictor {
   struct Config;
   PaddlePredictor() = default;
   PaddlePredictor(const PaddlePredictor&) = delete;
+  PaddlePredictor& operator=(const PaddlePredictor&) = delete;
   // Predict an record.
   // The caller should be responsible for allocating and releasing the memory of
@@ -76,7 +77,7 @@ class PaddlePredictor {
   virtual std::unique_ptr<PaddlePredictor> Clone() = 0;
   // Destroy the Predictor.
-  virtual ~PaddlePredictor() {}
+  virtual ~PaddlePredictor() = default;
   // The common configs for all the predictors.
   struct Config {
...
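This hunk makes PaddlePredictor explicitly non-copyable (copy construction was already deleted; copy assignment now is too) and gives it a defaulted virtual destructor, so the only supported way to duplicate a predictor is Clone(). A generic illustration of that idiom (not Paddle code; the names are made up):

#include <memory>

// A polymorphic base that cannot be copied; duplication goes through Clone(),
// which lets each copy set up its own internal state correctly.
class Engine {
 public:
  Engine() = default;
  Engine(const Engine &) = delete;             // no copy construction
  Engine &operator=(const Engine &) = delete;  // no copy assignment
  virtual ~Engine() = default;                 // safe deletion via base pointer
  virtual std::unique_ptr<Engine> Clone() = 0;
};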
@@ -54,7 +54,8 @@ std::string num2str(T a) {
 }
 }  // namespace
-bool NativePaddlePredictor::Init() {
+bool NativePaddlePredictor::Init(
+    std::shared_ptr<framework::Scope> parent_scope) {
   VLOG(3) << "Predictor::init()";
   if (config_.use_gpu) {
@@ -62,9 +63,15 @@ bool NativePaddlePredictor::Init() {
   } else {
     place_ = paddle::platform::CPUPlace();
   }
-  paddle::framework::InitDevices(false);
+  if (parent_scope) {
+    scope_ = parent_scope;
+    sub_scope_ = &(parent_scope->NewScope());
+  } else {
+    paddle::framework::InitDevices(false);
+    scope_.reset(new paddle::framework::Scope());
+  }
   executor_.reset(new paddle::framework::Executor(place_));
-  scope_.reset(new paddle::framework::Scope());
   // Initialize the inference program
   if (!config_.model_dir.empty()) {
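The branch above is the heart of the multi-threading support: Init(nullptr) builds a brand-new parent scope, while Init(parent_scope) reuses the caller's scope and allocates a private sub-scope inside it. A toy model of that ownership pattern (hypothetical types, not Paddle's framework::Scope implementation):

#include <list>
#include <memory>

// Toy scope tree: a parent owns its children and can drop one on request.
class Scope {
 public:
  Scope &NewScope() {
    children_.emplace_back(new Scope());
    return *children_.back();
  }
  void DeleteScope(Scope *child) {
    children_.remove_if(
        [child](const std::unique_ptr<Scope> &c) { return c.get() == child; });
  }

 private:
  std::list<std::unique_ptr<Scope>> children_;
};

// Mirrors the two Init paths: no parent -> own a fresh root scope;
// with parent -> share it and work in a private sub-scope.
struct ToyPredictor {
  std::shared_ptr<Scope> scope;  // parent scope, shared between clones
  Scope *sub_scope{nullptr};     // this predictor's private working area

  explicit ToyPredictor(std::shared_ptr<Scope> parent) {
    if (parent) {
      scope = std::move(parent);
      sub_scope = &scope->NewScope();
    } else {
      scope = std::make_shared<Scope>();
    }
  }
  ~ToyPredictor() {
    // Matches the destructor added later in this file: the parent scope,
    // not a unique_ptr, releases the sub-scope.
    if (sub_scope) scope->DeleteScope(sub_scope);
  }
};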
@@ -83,13 +90,8 @@ bool NativePaddlePredictor::Init() {
     return false;
   }
   ctx_ = executor_->Prepare(*inference_program_, 0);
-  // Create temporary variables first, so that the first batch do not need to
-  // create variables in the runtime. This is the logics of the old inference
-  // API.
-  // TODO(Superjomn) this should be modified when `Clone` is valid for
-  // multi-thread application.
-  executor_->CreateVariables(*inference_program_, scope_.get(), 0);
+  executor_->CreateVariables(
+      *inference_program_, sub_scope_ ? sub_scope_ : scope_.get(), 0);
   // Get the feed_target_names and fetch_target_names
   feed_target_names_ = inference_program_->GetFeedTargetNames();
@@ -97,6 +99,13 @@ bool NativePaddlePredictor::Init() {
   return true;
 }
+NativePaddlePredictor::~NativePaddlePredictor() {
+  if (sub_scope_) {
+    PADDLE_ENFORCE_NOT_NULL(scope_, "Should have parent scope!");
+    scope_->DeleteScope(sub_scope_);
+  }
+};
 bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
                                 std::vector<PaddleTensor> *output_data) {
   VLOG(3) << "Predictor::predict";
@@ -121,11 +130,12 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
   }
   // Run the inference program
   // if share variables, we need not create variables
-  executor_->RunPreparedContext(ctx_.get(),
-                                scope_.get(),
-                                &feed_targets,
-                                &fetch_targets,
-                                false /* don't create variable eatch time */);
+  executor_->RunPreparedContext(
+      ctx_.get(),
+      sub_scope_ != nullptr ? sub_scope_ : scope_.get(),
+      &feed_targets,
+      &fetch_targets,
+      false /* don't create variable eatch time */);
   if (!GetFetch(fetchs, output_data)) {
     LOG(ERROR) << "fail to get fetchs";
     return false;
@@ -138,7 +148,7 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
   VLOG(3) << "Predictor::clone";
   std::unique_ptr<PaddlePredictor> cls(new NativePaddlePredictor(config_));
-  if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init()) {
+  if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init(scope_)) {
     LOG(ERROR) << "fail to call Init";
     return nullptr;
   }
@@ -266,7 +276,7 @@ CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
   }
   std::unique_ptr<PaddlePredictor> predictor(new NativePaddlePredictor(config));
-  if (!dynamic_cast<NativePaddlePredictor *>(predictor.get())->Init()) {
+  if (!dynamic_cast<NativePaddlePredictor *>(predictor.get())->Init(nullptr)) {
     return nullptr;
   }
   return std::move(predictor);
...
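With both call sites updated, the two construction paths read as follows in user code (a hedged fragment; the config values and helper name are placeholders):

#include "paddle/contrib/inference/paddle_inference_api.h"  // assumed header path

void BuildPredictors() {  // hypothetical helper
  paddle::NativeConfig config;
  config.model_dir = "/path/to/model";  // placeholder model directory
  config.use_gpu = false;

  // CreatePaddlePredictor -> Init(nullptr): owns a brand-new parent scope.
  auto main_predictor =
      paddle::CreatePaddlePredictor<paddle::NativeConfig,
                                    paddle::PaddleEngineKind::kNative>(config);

  // Clone -> Init(scope_): shares the parent scope, gets its own sub-scope.
  auto worker_predictor = main_predictor->Clone();
}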
@@ -34,14 +34,15 @@ class NativePaddlePredictor : public PaddlePredictor {
   explicit NativePaddlePredictor(const NativeConfig &config)
       : config_(config) {}
-  bool Init();
+  // will only create sub scope if have global scope
+  bool Init(std::shared_ptr<framework::Scope> parent_scope);
   bool Run(const std::vector<PaddleTensor> &inputs,
            std::vector<PaddleTensor> *output_data) override;
   std::unique_ptr<PaddlePredictor> Clone() override;
-  ~NativePaddlePredictor() override{};
+  ~NativePaddlePredictor() override;
  private:
   bool SetFeed(const std::vector<PaddleTensor> &input_datas,
@@ -52,11 +53,13 @@ class NativePaddlePredictor : public PaddlePredictor {
   NativeConfig config_;
   platform::Place place_;
   std::unique_ptr<framework::Executor> executor_;
-  std::unique_ptr<framework::Scope> scope_;
+  std::shared_ptr<framework::Scope> scope_;
   std::unique_ptr<framework::ExecutorPrepareContext> ctx_;
   std::unique_ptr<framework::ProgramDesc> inference_program_;
   std::vector<std::string> feed_target_names_;
   std::vector<std::string> fetch_target_names_;
+  // Do not use unique_ptr, use parent scope to delete
+  framework::Scope *sub_scope_{nullptr};
 };
 }  // namespace paddle