提交 36594017 编写于 作者: T tensor-tang

make infer init explicit

上级 211b707b
...@@ -54,7 +54,8 @@ std::string num2str(T a) { ...@@ -54,7 +54,8 @@ std::string num2str(T a) {
} }
} // namespace } // namespace
bool NativePaddlePredictor::Init(std::shared_ptr<framework::Scope> scope) { bool NativePaddlePredictor::Init(
std::shared_ptr<framework::Scope> parent_scope) {
VLOG(3) << "Predictor::init()"; VLOG(3) << "Predictor::init()";
if (config_.use_gpu) { if (config_.use_gpu) {
...@@ -62,9 +63,9 @@ bool NativePaddlePredictor::Init(std::shared_ptr<framework::Scope> scope) { ...@@ -62,9 +63,9 @@ bool NativePaddlePredictor::Init(std::shared_ptr<framework::Scope> scope) {
} else { } else {
place_ = paddle::platform::CPUPlace(); place_ = paddle::platform::CPUPlace();
} }
if (scope) { if (parent_scope) {
scope_ = scope; scope_ = parent_scope;
sub_scope_ = &(scope->NewScope()); sub_scope_ = &(parent_scope->NewScope());
} else { } else {
paddle::framework::InitDevices(false); paddle::framework::InitDevices(false);
scope_.reset(new paddle::framework::Scope()); scope_.reset(new paddle::framework::Scope());
...@@ -275,7 +276,7 @@ CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>( ...@@ -275,7 +276,7 @@ CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(
} }
std::unique_ptr<PaddlePredictor> predictor(new NativePaddlePredictor(config)); std::unique_ptr<PaddlePredictor> predictor(new NativePaddlePredictor(config));
if (!dynamic_cast<NativePaddlePredictor *>(predictor.get())->Init()) { if (!dynamic_cast<NativePaddlePredictor *>(predictor.get())->Init(nullptr)) {
return nullptr; return nullptr;
} }
return std::move(predictor); return std::move(predictor);
......
...@@ -35,7 +35,7 @@ class NativePaddlePredictor : public PaddlePredictor { ...@@ -35,7 +35,7 @@ class NativePaddlePredictor : public PaddlePredictor {
: config_(config) {} : config_(config) {}
// will only create sub scope if have global scope // will only create sub scope if have global scope
bool Init(std::shared_ptr<framework::Scope> scope = nullptr); bool Init(std::shared_ptr<framework::Scope> parent_scope);
bool Run(const std::vector<PaddleTensor> &inputs, bool Run(const std::vector<PaddleTensor> &inputs,
std::vector<PaddleTensor> *output_data) override; std::vector<PaddleTensor> *output_data) override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册