提交 79bda2c4 编写于 作者: T tensor-tang

enable infer api with multi-threads

上级 418c41d8
......@@ -63,6 +63,7 @@ class PaddlePredictor {
struct Config;
PaddlePredictor() = default;
PaddlePredictor(const PaddlePredictor&) = delete;
PaddlePredictor& operator=(const PaddlePredictor&) = delete;
// Predict an record.
// The caller should be responsible for allocating and releasing the memory of
......@@ -76,7 +77,7 @@ class PaddlePredictor {
virtual std::unique_ptr<PaddlePredictor> Clone() = 0;
// Destroy the Predictor.
virtual ~PaddlePredictor() {}
virtual ~PaddlePredictor() = default;
// The common configs for all the predictors.
struct Config {
......
......@@ -54,7 +54,7 @@ std::string num2str(T a) {
}
} // namespace
bool NativePaddlePredictor::Init() {
bool NativePaddlePredictor::Init(std::shared_ptr<framework::Scope> scope) {
VLOG(3) << "Predictor::init()";
if (config_.use_gpu) {
......@@ -62,9 +62,15 @@ bool NativePaddlePredictor::Init() {
} else {
place_ = paddle::platform::CPUPlace();
}
paddle::framework::InitDevices(false);
if (scope) {
scope_ = scope;
sub_scope_ = &(scope->NewScope());
} else {
paddle::framework::InitDevices(false);
scope_.reset(new paddle::framework::Scope());
}
executor_.reset(new paddle::framework::Executor(place_));
scope_.reset(new paddle::framework::Scope());
// Initialize the inference program
if (!config_.model_dir.empty()) {
......@@ -83,13 +89,8 @@ bool NativePaddlePredictor::Init() {
return false;
}
ctx_ = executor_->Prepare(*inference_program_, 0);
// Create temporary variables first, so that the first batch do not need to
// create variables in the runtime. This is the logics of the old inference
// API.
// TODO(Superjomn) this should be modified when `Clone` is valid for
// multi-thread application.
executor_->CreateVariables(*inference_program_, scope_.get(), 0);
executor_->CreateVariables(
*inference_program_, sub_scope_ ? sub_scope_ : scope_.get(), 0);
// Get the feed_target_names and fetch_target_names
feed_target_names_ = inference_program_->GetFeedTargetNames();
......@@ -97,6 +98,13 @@ bool NativePaddlePredictor::Init() {
return true;
}
NativePaddlePredictor::~NativePaddlePredictor() {
if (sub_scope_) {
PADDLE_ENFORCE_NOT_NULL(scope_, "Should have parent scope!");
scope_->DeleteScope(sub_scope_);
}
};
bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data) {
VLOG(3) << "Predictor::predict";
......@@ -121,11 +129,12 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
}
// Run the inference program
// if share variables, we need not create variables
executor_->RunPreparedContext(ctx_.get(),
scope_.get(),
&feed_targets,
&fetch_targets,
false /* don't create variable eatch time */);
executor_->RunPreparedContext(
ctx_.get(),
sub_scope_ != nullptr ? sub_scope_ : scope_.get(),
&feed_targets,
&fetch_targets,
false /* don't create variable eatch time */);
if (!GetFetch(fetchs, output_data)) {
LOG(ERROR) << "fail to get fetchs";
return false;
......@@ -138,7 +147,7 @@ std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
VLOG(3) << "Predictor::clone";
std::unique_ptr<PaddlePredictor> cls(new NativePaddlePredictor(config_));
if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init()) {
if (!dynamic_cast<NativePaddlePredictor *>(cls.get())->Init(scope_)) {
LOG(ERROR) << "fail to call Init";
return nullptr;
}
......
......@@ -34,14 +34,15 @@ class NativePaddlePredictor : public PaddlePredictor {
explicit NativePaddlePredictor(const NativeConfig &config)
: config_(config) {}
bool Init();
// will only create sub scope if have global scope
bool Init(std::shared_ptr<framework::Scope> scope = nullptr);
bool Run(const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data) override;
std::unique_ptr<PaddlePredictor> Clone() override;
~NativePaddlePredictor() override{};
~NativePaddlePredictor() override;
private:
bool SetFeed(const std::vector<PaddleTensor> &input_datas,
......@@ -52,11 +53,13 @@ class NativePaddlePredictor : public PaddlePredictor {
NativeConfig config_;
platform::Place place_;
std::unique_ptr<framework::Executor> executor_;
std::unique_ptr<framework::Scope> scope_;
std::shared_ptr<framework::Scope> scope_;
std::unique_ptr<framework::ExecutorPrepareContext> ctx_;
std::unique_ptr<framework::ProgramDesc> inference_program_;
std::vector<std::string> feed_target_names_;
std::vector<std::string> fetch_target_names_;
// Do not use unique_ptr, use parent scope to delete
framework::Scope *sub_scope_{nullptr};
};
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册