未验证 提交 e84234b5 编写于 作者: Y Yan Chunwei 提交者: GitHub

make clone thread safe (#15363)

上级 adba4384
...@@ -561,6 +561,7 @@ AnalysisPredictor::~AnalysisPredictor() { ...@@ -561,6 +561,7 @@ AnalysisPredictor::~AnalysisPredictor() {
} }
std::unique_ptr<PaddlePredictor> AnalysisPredictor::Clone() { std::unique_ptr<PaddlePredictor> AnalysisPredictor::Clone() {
std::lock_guard<std::mutex> lk(clone_mutex_);
auto *x = new AnalysisPredictor(config_); auto *x = new AnalysisPredictor(config_);
x->Init(scope_, inference_program_); x->Init(scope_, inference_program_);
return std::unique_ptr<PaddlePredictor>(x); return std::unique_ptr<PaddlePredictor>(x);
......
...@@ -115,6 +115,8 @@ class AnalysisPredictor : public PaddlePredictor { ...@@ -115,6 +115,8 @@ class AnalysisPredictor : public PaddlePredictor {
// concurrency problems, wrong results and memory leak, so cache them. // concurrency problems, wrong results and memory leak, so cache them.
std::vector<framework::LoDTensor> feed_tensors_; std::vector<framework::LoDTensor> feed_tensors_;
details::TensorArrayBatchCleaner tensor_array_batch_cleaner_; details::TensorArrayBatchCleaner tensor_array_batch_cleaner_;
// A mutex help to make Clone thread safe.
std::mutex clone_mutex_;
private: private:
// Some status here that help to determine the status inside the predictor. // Some status here that help to determine the status inside the predictor.
......
...@@ -179,8 +179,9 @@ TEST(AnalysisPredictor, Clone) { ...@@ -179,8 +179,9 @@ TEST(AnalysisPredictor, Clone) {
threads.emplace_back([&predictors, &inputs, i] { threads.emplace_back([&predictors, &inputs, i] {
LOG(INFO) << "thread #" << i << " running"; LOG(INFO) << "thread #" << i << " running";
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
auto predictor = predictors.front()->Clone();
for (int j = 0; j < 10; j++) { for (int j = 0; j < 10; j++) {
ASSERT_TRUE(predictors[i]->Run(inputs, &outputs)); ASSERT_TRUE(predictor->Run(inputs, &outputs));
} }
}); });
} }
......
...@@ -161,6 +161,8 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -161,6 +161,8 @@ bool NativePaddlePredictor::Run(const std::vector<PaddleTensor> &inputs,
} }
std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() { std::unique_ptr<PaddlePredictor> NativePaddlePredictor::Clone() {
std::lock_guard<std::mutex> lk(clone_mutex_);
VLOG(3) << "Predictor::clone";
std::unique_ptr<PaddlePredictor> cls(new NativePaddlePredictor(config_)); std::unique_ptr<PaddlePredictor> cls(new NativePaddlePredictor(config_));
// Hot fix the bug that result diff in multi-thread. // Hot fix the bug that result diff in multi-thread.
// TODO(Superjomn) re-implement a real clone here. // TODO(Superjomn) re-implement a real clone here.
......
...@@ -74,6 +74,8 @@ class NativePaddlePredictor : public PaddlePredictor { ...@@ -74,6 +74,8 @@ class NativePaddlePredictor : public PaddlePredictor {
// Do not use unique_ptr, use parent scope to delete // Do not use unique_ptr, use parent scope to delete
framework::Scope *sub_scope_{nullptr}; framework::Scope *sub_scope_{nullptr};
details::TensorArrayBatchCleaner tensor_array_batch_cleaner_; details::TensorArrayBatchCleaner tensor_array_batch_cleaner_;
// A mutex to make Clone thread safe.
std::mutex clone_mutex_;
}; };
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册