未验证 提交 967115b8 编写于 作者: D dyning 提交者: GitHub

Merge pull request #592 from littletomatodonkey/fix_predictor_run

Replace zero_copy_run with run to avoid a memory leak
...@@ -41,6 +41,8 @@ public: ...@@ -41,6 +41,8 @@ public:
this->use_mkldnn = bool(stoi(config_map_["use_mkldnn"])); this->use_mkldnn = bool(stoi(config_map_["use_mkldnn"]));
this->use_zero_copy_run = bool(stoi(config_map_["use_zero_copy_run"]));
this->max_side_len = stoi(config_map_["max_side_len"]); this->max_side_len = stoi(config_map_["max_side_len"]);
this->det_db_thresh = stod(config_map_["det_db_thresh"]); this->det_db_thresh = stod(config_map_["det_db_thresh"]);
...@@ -68,6 +70,8 @@ public: ...@@ -68,6 +70,8 @@ public:
bool use_mkldnn = false; bool use_mkldnn = false;
bool use_zero_copy_run = false;
int max_side_len = 960; int max_side_len = 960;
double det_db_thresh = 0.3; double det_db_thresh = 0.3;
......
...@@ -39,8 +39,8 @@ public: ...@@ -39,8 +39,8 @@ public:
explicit DBDetector(const std::string &model_dir, const bool &use_gpu, explicit DBDetector(const std::string &model_dir, const bool &use_gpu,
const int &gpu_id, const int &gpu_mem, const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads, const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const int &max_side_len, const bool &use_mkldnn, const bool &use_zero_copy_run,
const double &det_db_thresh, const int &max_side_len, const double &det_db_thresh,
const double &det_db_box_thresh, const double &det_db_box_thresh,
const double &det_db_unclip_ratio, const double &det_db_unclip_ratio,
const bool &visualize) { const bool &visualize) {
...@@ -49,6 +49,7 @@ public: ...@@ -49,6 +49,7 @@ public:
this->gpu_mem_ = gpu_mem; this->gpu_mem_ = gpu_mem;
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
this->use_mkldnn_ = use_mkldnn; this->use_mkldnn_ = use_mkldnn;
this->use_zero_copy_run_ = use_zero_copy_run;
this->max_side_len_ = max_side_len; this->max_side_len_ = max_side_len;
...@@ -75,6 +76,7 @@ private: ...@@ -75,6 +76,7 @@ private:
int gpu_mem_ = 4000; int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4; int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false; bool use_mkldnn_ = false;
bool use_zero_copy_run_ = false;
int max_side_len_ = 960; int max_side_len_ = 960;
......
...@@ -38,12 +38,14 @@ public: ...@@ -38,12 +38,14 @@ public:
explicit CRNNRecognizer(const std::string &model_dir, const bool &use_gpu, explicit CRNNRecognizer(const std::string &model_dir, const bool &use_gpu,
const int &gpu_id, const int &gpu_mem, const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads, const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const string &label_path) { const bool &use_mkldnn, const bool &use_zero_copy_run,
const string &label_path) {
this->use_gpu_ = use_gpu; this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id; this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem; this->gpu_mem_ = gpu_mem;
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
this->use_mkldnn_ = use_mkldnn; this->use_mkldnn_ = use_mkldnn;
this->use_zero_copy_run_ = use_zero_copy_run;
this->label_list_ = Utility::ReadDict(label_path); this->label_list_ = Utility::ReadDict(label_path);
this->label_list_.push_back(" "); this->label_list_.push_back(" ");
...@@ -64,6 +66,7 @@ private: ...@@ -64,6 +66,7 @@ private:
int gpu_mem_ = 4000; int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4; int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false; bool use_mkldnn_ = false;
bool use_zero_copy_run_ = false;
std::vector<std::string> label_list_; std::vector<std::string> label_list_;
......
...@@ -48,14 +48,15 @@ int main(int argc, char **argv) { ...@@ -48,14 +48,15 @@ int main(int argc, char **argv) {
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR); cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
DBDetector det(config.det_model_dir, config.use_gpu, config.gpu_id, DBDetector det(
config.gpu_mem, config.cpu_math_library_num_threads, config.det_model_dir, config.use_gpu, config.gpu_id, config.gpu_mem,
config.use_mkldnn, config.max_side_len, config.det_db_thresh, config.cpu_math_library_num_threads, config.use_mkldnn,
config.det_db_box_thresh, config.det_db_unclip_ratio, config.use_zero_copy_run, config.max_side_len, config.det_db_thresh,
config.visualize); config.det_db_box_thresh, config.det_db_unclip_ratio, config.visualize);
CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id, CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
config.gpu_mem, config.cpu_math_library_num_threads, config.gpu_mem, config.cpu_math_library_num_threads,
config.use_mkldnn, config.char_list_file); config.use_mkldnn, config.use_zero_copy_run,
config.char_list_file);
auto start = std::chrono::system_clock::now(); auto start = std::chrono::system_clock::now();
std::vector<std::vector<std::vector<int>>> boxes; std::vector<std::vector<std::vector<int>>> boxes;
......
...@@ -31,7 +31,8 @@ void DBDetector::LoadModel(const std::string &model_dir) { ...@@ -31,7 +31,8 @@ void DBDetector::LoadModel(const std::string &model_dir) {
} }
// false for zero copy tensor // false for zero copy tensor
config.SwitchUseFeedFetchOps(false); // true for common tensor
config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_);
// true for multiple input // true for multiple input
config.SwitchSpecifyInputNames(true); config.SwitchSpecifyInputNames(true);
...@@ -59,12 +60,22 @@ void DBDetector::Run(cv::Mat &img, ...@@ -59,12 +60,22 @@ void DBDetector::Run(cv::Mat &img,
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
this->permute_op_.Run(&resize_img, input.data()); this->permute_op_.Run(&resize_img, input.data());
// Inference.
if (this->use_zero_copy_run_) {
auto input_names = this->predictor_->GetInputNames(); auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputTensor(input_names[0]); auto input_t = this->predictor_->GetInputTensor(input_names[0]);
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
input_t->copy_from_cpu(input.data()); input_t->copy_from_cpu(input.data());
this->predictor_->ZeroCopyRun(); this->predictor_->ZeroCopyRun();
} else {
paddle::PaddleTensor input_t;
input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
input_t.data =
paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
input_t.dtype = PaddleDType::FLOAT32;
std::vector<paddle::PaddleTensor> outputs;
this->predictor_->Run({input_t}, &outputs, 1);
}
std::vector<float> out_data; std::vector<float> out_data;
auto output_names = this->predictor_->GetOutputNames(); auto output_names = this->predictor_->GetOutputNames();
......
...@@ -39,18 +39,29 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes, ...@@ -39,18 +39,29 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
this->permute_op_.Run(&resize_img, input.data()); this->permute_op_.Run(&resize_img, input.data());
// Inference.
if (this->use_zero_copy_run_) {
auto input_names = this->predictor_->GetInputNames(); auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputTensor(input_names[0]); auto input_t = this->predictor_->GetInputTensor(input_names[0]);
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
input_t->copy_from_cpu(input.data()); input_t->copy_from_cpu(input.data());
this->predictor_->ZeroCopyRun(); this->predictor_->ZeroCopyRun();
} else {
paddle::PaddleTensor input_t;
input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
input_t.data =
paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
input_t.dtype = PaddleDType::FLOAT32;
std::vector<paddle::PaddleTensor> outputs;
this->predictor_->Run({input_t}, &outputs, 1);
}
std::vector<int64_t> rec_idx; std::vector<int64_t> rec_idx;
auto output_names = this->predictor_->GetOutputNames(); auto output_names = this->predictor_->GetOutputNames();
auto output_t = this->predictor_->GetOutputTensor(output_names[0]); auto output_t = this->predictor_->GetOutputTensor(output_names[0]);
auto rec_idx_lod = output_t->lod(); auto rec_idx_lod = output_t->lod();
auto shape_out = output_t->shape(); auto shape_out = output_t->shape();
int out_num = std::accumulate(shape_out.begin(), shape_out.end(), 1, int out_num = std::accumulate(shape_out.begin(), shape_out.end(), 1,
std::multiplies<int>()); std::multiplies<int>());
...@@ -120,7 +131,8 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) { ...@@ -120,7 +131,8 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
} }
// false for zero copy tensor // false for zero copy tensor
config.SwitchUseFeedFetchOps(false); // true for common tensor
config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_);
// true for multiple input // true for multiple input
config.SwitchSpecifyInputNames(true); config.SwitchSpecifyInputNames(true);
......
...@@ -4,6 +4,7 @@ gpu_id 0 ...@@ -4,6 +4,7 @@ gpu_id 0
gpu_mem 4000 gpu_mem 4000
cpu_math_library_num_threads 10 cpu_math_library_num_threads 10
use_mkldnn 0 use_mkldnn 0
use_zero_copy_run 1
# det config # det config
max_side_len 960 max_side_len 960
......
...@@ -17,28 +17,32 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) ...@@ -17,28 +17,32 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import cv2
import copy
import numpy as np
import math
import time
import sys
import paddle.fluid as fluid
import tools.infer.utility as utility import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
logger = initial_logger() logger = initial_logger()
from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.utility import get_image_file_list, check_and_read_gif
import cv2
from ppocr.data.det.sast_process import SASTProcessTest from ppocr.data.det.sast_process import SASTProcessTest
from ppocr.data.det.east_process import EASTProcessTest from ppocr.data.det.east_process import EASTProcessTest
from ppocr.data.det.db_process import DBProcessTest from ppocr.data.det.db_process import DBProcessTest
from ppocr.postprocess.db_postprocess import DBPostProcess from ppocr.postprocess.db_postprocess import DBPostProcess
from ppocr.postprocess.east_postprocess import EASTPostPocess from ppocr.postprocess.east_postprocess import EASTPostPocess
from ppocr.postprocess.sast_postprocess import SASTPostProcess from ppocr.postprocess.sast_postprocess import SASTPostProcess
import copy
import numpy as np
import math
import time
import sys
class TextDetector(object): class TextDetector(object):
def __init__(self, args): def __init__(self, args):
max_side_len = args.det_max_side_len max_side_len = args.det_max_side_len
self.det_algorithm = args.det_algorithm self.det_algorithm = args.det_algorithm
self.use_zero_copy_run = args.use_zero_copy_run
preprocess_params = {'max_side_len': max_side_len} preprocess_params = {'max_side_len': max_side_len}
postprocess_params = {} postprocess_params = {}
if self.det_algorithm == "DB": if self.det_algorithm == "DB":
...@@ -135,8 +139,12 @@ class TextDetector(object): ...@@ -135,8 +139,12 @@ class TextDetector(object):
return None, 0 return None, 0
im = im.copy() im = im.copy()
starttime = time.time() starttime = time.time()
if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(im) self.input_tensor.copy_from_cpu(im)
self.predictor.zero_copy_run() self.predictor.zero_copy_run()
else:
im = fluid.core.PaddleTensor(im)
self.predictor.run([im])
outputs = [] outputs = []
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
......
...@@ -17,15 +17,18 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) ...@@ -17,15 +17,18 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
import cv2 import cv2
import copy import copy
import numpy as np import numpy as np
import math import math
import time import time
import paddle.fluid as fluid
import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.character import CharacterOps from ppocr.utils.character import CharacterOps
...@@ -37,6 +40,7 @@ class TextRecognizer(object): ...@@ -37,6 +40,7 @@ class TextRecognizer(object):
self.character_type = args.rec_char_type self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm self.rec_algorithm = args.rec_algorithm
self.use_zero_copy_run = args.use_zero_copy_run
char_ops_params = { char_ops_params = {
"character_type": args.rec_char_type, "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path, "character_dict_path": args.rec_char_dict_path,
...@@ -102,8 +106,12 @@ class TextRecognizer(object): ...@@ -102,8 +106,12 @@ class TextRecognizer(object):
norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
starttime = time.time() starttime = time.time()
if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run() self.predictor.zero_copy_run()
else:
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
self.predictor.run([norm_img_batch])
if self.loss_type == "ctc": if self.loss_type == "ctc":
rec_idx_batch = self.output_tensors[0].copy_to_cpu() rec_idx_batch = self.output_tensors[0].copy_to_cpu()
......
...@@ -71,6 +71,7 @@ def parse_args(): ...@@ -71,6 +71,7 @@ def parse_args():
default="./ppocr/utils/ppocr_keys_v1.txt") default="./ppocr/utils/ppocr_keys_v1.txt")
parser.add_argument("--use_space_char", type=bool, default=True) parser.add_argument("--use_space_char", type=bool, default=True)
parser.add_argument("--enable_mkldnn", type=bool, default=False) parser.add_argument("--enable_mkldnn", type=bool, default=False)
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
return parser.parse_args() return parser.parse_args()
...@@ -105,9 +106,12 @@ def create_predictor(args, mode): ...@@ -105,9 +106,12 @@ def create_predictor(args, mode):
#config.enable_memory_optim() #config.enable_memory_optim()
config.disable_glog_info() config.disable_glog_info()
# use zero copy if args.use_zero_copy_run:
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False) config.switch_use_feed_fetch_ops(False)
else:
config.switch_use_feed_fetch_ops(True)
predictor = create_paddle_predictor(config) predictor = create_paddle_predictor(config)
input_names = predictor.get_input_names() input_names = predictor.get_input_names()
input_tensor = predictor.get_input_tensor(input_names[0]) input_tensor = predictor.get_input_tensor(input_names[0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册