Commit 71ab42bc authored by dongshuilong

Add label printing for the classification C++ inference demo

Parent: 524190a7
Classifier header (declares the new Run() interface):

@@ -35,55 +35,56 @@ using namespace paddle_infer;

```cpp
namespace PaddleClas {

class Classifier {
public:
  explicit Classifier(const ClsConfig &config) {
    this->use_gpu_ = config.use_gpu;
    this->gpu_id_ = config.gpu_id;
    this->gpu_mem_ = config.gpu_mem;
    this->cpu_math_library_num_threads_ = config.cpu_threads;
    this->use_fp16_ = config.use_fp16;
    this->use_mkldnn_ = config.use_mkldnn;
    this->use_tensorrt_ = config.use_tensorrt;
    this->mean_ = config.mean;
    this->std_ = config.std;
    this->resize_short_size_ = config.resize_short_size;
    this->scale_ = config.scale;
    this->crop_size_ = config.crop_size;
    this->ir_optim_ = config.ir_optim;
    LoadModel(config.cls_model_path, config.cls_params_path);
  }

  // Load Paddle inference model
  void LoadModel(const std::string &model_path, const std::string &params_path);

  // Run predictor: fills out_data with one score per class, idx with class
  // ids sorted by descending score, and times with the pre-process,
  // inference, and post-process costs in milliseconds.
  void Run(cv::Mat &img, std::vector<float> &out_data, std::vector<int> &idx,
           std::vector<double> &times);

private:
  std::shared_ptr<Predictor> predictor_;
  bool use_gpu_ = false;
  int gpu_id_ = 0;
  int gpu_mem_ = 4000;
  int cpu_math_library_num_threads_ = 4;
  bool use_mkldnn_ = false;
  bool use_tensorrt_ = false;
  bool use_fp16_ = false;
  bool ir_optim_ = true;
  std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
  std::vector<float> std_ = {0.229f, 0.224f, 0.225f};
  float scale_ = 0.00392157;
  int resize_short_size_ = 256;
  int crop_size_ = 224;
  // pre-process ops
  ResizeImg resize_op_;
  Normalize normalize_op_;
  Permute permute_op_;
  CenterCropImg crop_op_;
};

} // namespace PaddleClas
```
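The `Run` signature change is the heart of this commit: instead of returning only the inference time, it now hands back the full score vector plus a score-sorted index vector, so the caller can print top-k labels. A minimal caller sketch; the YAML and image paths are placeholders, not part of the commit:

```cpp
#include <cstdio>
#include <vector>

#include <include/cls.h>
#include <include/cls_config.h>
#include <opencv2/opencv.hpp>

int main() {
  PaddleClas::ClsConfig config("inference_cls.yaml"); // hypothetical config path
  PaddleClas::Classifier classifier(config);

  cv::Mat img = cv::imread("demo.jpg", cv::IMREAD_COLOR); // hypothetical image
  cv::cvtColor(img, img, cv::COLOR_BGR2RGB); // the demo feeds RGB, not BGR

  std::vector<float> scores;             // filled with one score per class
  std::vector<int> ranking;              // class ids, best score first
  std::vector<double> times = {0, 0, 0}; // Run() writes times[0..2] in ms
  classifier.Run(img, scores, ranking, times);

  // ranking[0] is the top-1 class id; scores[ranking[0]] is its score.
  printf("top-1 id=%d score=%.4f\n", ranking[0], scores[ranking[0]]);
  return 0;
}
```

Note that `times` must already hold three slots, since `Run` assigns `times[0]` through `times[2]` rather than appending.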
Config header (parses the YAML config, now including the PostProcess label options):

@@ -31,83 +31,101 @@

```cpp
namespace PaddleClas {

class ClsConfig {
public:
  explicit ClsConfig(const std::string &path) {
    ReadYamlConfig(path);
    this->infer_imgs =
        this->config_file["Global"]["infer_imgs"].as<std::string>();
    this->batch_size = this->config_file["Global"]["batch_size"].as<int>();
    this->use_gpu = this->config_file["Global"]["use_gpu"].as<bool>();
    if (this->config_file["Global"]["gpu_id"].IsDefined())
      this->gpu_id = this->config_file["Global"]["gpu_id"].as<int>();
    else
      this->gpu_id = 0;
    this->gpu_mem = this->config_file["Global"]["gpu_mem"].as<int>();
    this->cpu_threads =
        this->config_file["Global"]["cpu_num_threads"].as<int>();
    this->use_mkldnn = this->config_file["Global"]["enable_mkldnn"].as<bool>();
    this->use_tensorrt = this->config_file["Global"]["use_tensorrt"].as<bool>();
    this->use_fp16 = this->config_file["Global"]["use_fp16"].as<bool>();
    this->enable_benchmark =
        this->config_file["Global"]["enable_benchmark"].as<bool>();
    this->ir_optim = this->config_file["Global"]["ir_optim"].as<bool>();
    this->enable_profile =
        this->config_file["Global"]["enable_profile"].as<bool>();
    this->cls_model_path =
        this->config_file["Global"]["inference_model_dir"].as<std::string>() +
        OS_PATH_SEP + "inference.pdmodel";
    this->cls_params_path =
        this->config_file["Global"]["inference_model_dir"].as<std::string>() +
        OS_PATH_SEP + "inference.pdiparams";
    this->resize_short_size =
        this->config_file["PreProcess"]["transform_ops"][0]["ResizeImage"]
                         ["resize_short"]
            .as<int>();
    this->crop_size =
        this->config_file["PreProcess"]["transform_ops"][1]["CropImage"]["size"]
            .as<int>();
    this->scale = this->config_file["PreProcess"]["transform_ops"][2]
                                   ["NormalizeImage"]["scale"]
                      .as<float>();
    this->mean = this->config_file["PreProcess"]["transform_ops"][2]
                                  ["NormalizeImage"]["mean"]
                     .as<std::vector<float>>();
    this->std = this->config_file["PreProcess"]["transform_ops"][2]
                                 ["NormalizeImage"]["std"]
                    .as<std::vector<float>>();
    if (this->config_file["Global"]["benchmark"].IsDefined())
      this->benchmark = this->config_file["Global"]["benchmark"].as<bool>();
    else
      this->benchmark = false;

    if (this->config_file["PostProcess"]["Topk"]["topk"].IsDefined())
      this->topk = this->config_file["PostProcess"]["Topk"]["topk"].as<int>();
    if (this->config_file["PostProcess"]["Topk"]["class_id_map_file"]
            .IsDefined())
      this->class_id_map_path =
          this->config_file["PostProcess"]["Topk"]["class_id_map_file"]
              .as<std::string>();
    if (this->config_file["PostProcess"]["SavePreLabel"]["save_dir"]
            .IsDefined())
      this->label_save_dir =
          this->config_file["PostProcess"]["SavePreLabel"]["save_dir"]
              .as<std::string>();
    ReadLabelMap();
  }

  YAML::Node config_file;
  bool use_gpu = false;
  int gpu_id = 0;
  int gpu_mem = 4000;
  int cpu_threads = 1;
  bool use_mkldnn = false;
  bool use_tensorrt = false;
  bool use_fp16 = false;
  bool benchmark = false;
  int batch_size = 1;
  bool enable_benchmark = false;
  bool ir_optim = true;
  bool enable_profile = false;

  std::string cls_model_path;
  std::string cls_params_path;
  std::string infer_imgs;

  int resize_short_size = 256;
  int crop_size = 224;
  float scale = 0.00392157;
  std::vector<float> mean = {0.485, 0.456, 0.406};
  std::vector<float> std = {0.229, 0.224, 0.225};
  int topk = 5;
  std::string class_id_map_path;
  std::map<int, std::string> id_map;
  std::string label_save_dir;

  void PrintConfigInfo();
  void ReadLabelMap();
  void ReadYamlConfig(const std::string &path);
};

} // namespace PaddleClas
```
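For reference, the key paths read by this constructor correspond to a config file shaped like the sketch below. Only the key names are taken from the parsing code above; all values are illustrative. Note that the parser indexes `transform_ops` by position, so the ResizeImage / CropImage / NormalizeImage order matters:

```yaml
Global:
  infer_imgs: ./images
  inference_model_dir: ./MobileNetV3_large_x1_0_infer
  batch_size: 1
  use_gpu: false
  gpu_mem: 4000
  cpu_num_threads: 4
  enable_mkldnn: true
  use_tensorrt: false
  use_fp16: false
  enable_benchmark: false
  ir_optim: true
  enable_profile: false
PreProcess:
  transform_ops:
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 0.00392157
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
PostProcess:
  Topk:
    topk: 5
    class_id_map_file: ./imagenet1k_label_list.txt
  SavePreLabel:
    save_dir: ./pre_label
```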
Classifier implementation (LoadModel is unchanged; Run now returns the full ranking):

@@ -12,101 +12,105 @@

```cpp
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <include/cls.h>
#include <numeric>

namespace PaddleClas {

void Classifier::LoadModel(const std::string &model_path,
                           const std::string &params_path) {
  paddle_infer::Config config;
  config.SetModel(model_path, params_path);
  if (this->use_gpu_) {
    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
    if (this->use_tensorrt_) {
      config.EnableTensorRtEngine(
          1 << 20, 1, 3,
          this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
                          : paddle_infer::Config::Precision::kFloat32,
          false, false);
    }
  } else {
    config.DisableGpu();
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
      // cache 10 different shapes for mkldnn to avoid memory leak
      config.SetMkldnnCacheCapacity(10);
    }
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }
  config.SwitchUseFeedFetchOps(false);
  // true for multiple input
  config.SwitchSpecifyInputNames(true);
  config.SwitchIrOptim(this->ir_optim_);
  config.EnableMemoryOptim();
  config.DisableGlogInfo();
  this->predictor_ = CreatePredictor(config);
}

void Classifier::Run(cv::Mat &img, std::vector<float> &out_data,
                     std::vector<int> &idx, std::vector<double> &times) {
  cv::Mat srcimg;
  cv::Mat resize_img;
  img.copyTo(srcimg);

  auto preprocess_start = std::chrono::system_clock::now();
  this->resize_op_.Run(img, resize_img, this->resize_short_size_);
  this->crop_op_.Run(resize_img, this->crop_size_);
  this->normalize_op_.Run(&resize_img, this->mean_, this->std_, this->scale_);
  std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
  this->permute_op_.Run(&resize_img, input.data());
  auto input_names = this->predictor_->GetInputNames();
  auto input_t = this->predictor_->GetInputHandle(input_names[0]);
  input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
  auto preprocess_end = std::chrono::system_clock::now();

  auto infer_start = std::chrono::system_clock::now();
  input_t->CopyFromCpu(input.data());
  this->predictor_->Run();

  auto output_names = this->predictor_->GetOutputNames();
  auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
  std::vector<int> output_shape = output_t->shape();
  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
                                std::multiplies<int>());
  out_data.resize(out_num);
  idx.resize(out_num);
  output_t->CopyToCpu(out_data.data());
  auto infer_end = std::chrono::system_clock::now();

  auto postprocess_start = std::chrono::system_clock::now();
  // argsort: rank all class ids by descending score instead of keeping only
  // the old single argmax; result printing now happens in the caller
  std::iota(idx.begin(), idx.end(), 0);
  std::stable_sort(idx.begin(), idx.end(), [&out_data](int i1, int i2) {
    return out_data[i1] > out_data[i2];
  });
  auto postprocess_end = std::chrono::system_clock::now();

  std::chrono::duration<float> preprocess_diff =
      preprocess_end - preprocess_start;
  times[0] = double(preprocess_diff.count() * 1000);
  std::chrono::duration<float> inference_diff = infer_end - infer_start;
  times[1] = double(inference_diff.count() * 1000);
  std::chrono::duration<float> postprocess_diff =
      postprocess_end - postprocess_start;
  times[2] = double(postprocess_diff.count() * 1000);
}

} // namespace PaddleClas
```
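The post-process step above replaces the old `max_element` argmax with an argsort: `std::iota` fills `idx` with 0..N-1, and `std::stable_sort` reorders those indices by descending score, so the first k entries of `idx` are the top-k class ids. The idiom in isolation, with toy scores:

```cpp
// Standalone illustration of the argsort idiom used in Classifier::Run().
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<float> scores = {0.10f, 0.55f, 0.05f, 0.30f}; // toy scores
  std::vector<int> idx(scores.size());
  std::iota(idx.begin(), idx.end(), 0); // idx = {0, 1, 2, 3}
  std::stable_sort(idx.begin(), idx.end(), [&scores](int a, int b) {
    return scores[a] > scores[b]; // descending by score
  });
  // idx is now {1, 3, 0, 2}; the first k entries are the top-k class ids.
  for (int i : idx)
    printf("class %d: %.2f\n", i, scores[i]);
  return 0;
}
```

For a large class count and a small k, `std::partial_sort` over the index vector would avoid sorting the full range, at the cost of the stable tie ordering that `stable_sort` guarantees.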
Config implementation (adds ReadLabelMap for the id-to-label file):

@@ -13,23 +13,40 @@

```cpp
// limitations under the License.

#include <include/cls_config.h>
#include <fstream> // for std::ifstream in ReadLabelMap
#include <ostream>

namespace PaddleClas {

void ClsConfig::PrintConfigInfo() {
  std::cout << "=======Paddle Class inference config======" << std::endl;
  std::cout << this->config_file << std::endl;
  std::cout << "=======End of Paddle Class inference config======" << std::endl;
}

void ClsConfig::ReadYamlConfig(const std::string &path) {
  try {
    this->config_file = YAML::LoadFile(path);
  } catch (YAML::BadFile &e) {
    std::cout << "Something went wrong while reading the yaml file, please "
                 "check it"
              << std::endl;
    exit(1);
  }
}

void ClsConfig::ReadLabelMap() {
  if (this->class_id_map_path.empty()) {
    std::cout << "No class label file is specified; class ids will be "
                 "printed without label names"
              << std::endl;
    return;
  }
  // each line holds "<id> <label text>", split at the first space
  std::ifstream in(this->class_id_map_path);
  std::string line;
  if (in) {
    while (getline(in, line)) {
      int split_flag = line.find_first_of(" ");
      this->id_map[std::stoi(line.substr(0, split_flag))] =
          line.substr(split_flag + 1, line.size());
    }
  }
}

} // namespace PaddleClas
```
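`ReadLabelMap` splits each line at the first space: the prefix is parsed as the integer class id, and the remainder, spaces included, becomes the label. A two-line excerpt in the expected shape (contents illustrative, in the style of an ImageNet label list):

```
0 tench, Tinca tinca
1 goldfish, Carassius auratus
```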
Demo entry point (prints top-k scores and labels plus aggregate timing):

@@ -22,6 +22,7 @@

```cpp
#include <ostream>
#include <vector>

#include <algorithm>
#include <cstring>
#include <fstream>
#include <numeric>
```

@@ -35,81 +36,99 @@ using namespace std;

```cpp
using namespace cv;
using namespace PaddleClas;

DEFINE_string(config, "", "Path of yaml file");
DEFINE_string(c, "", "Path of yaml file");

int main(int argc, char **argv) {
  google::ParseCommandLineFlags(&argc, &argv, true);
  std::string yaml_path = "";
  if (FLAGS_config == "" && FLAGS_c == "") {
    std::cerr << "[ERROR] usage: " << std::endl
              << argv[0] << " -c $yaml_path" << std::endl
              << "or:" << std::endl
              << argv[0] << " -config $yaml_path" << std::endl;
    exit(1);
  } else if (FLAGS_config != "") {
    yaml_path = FLAGS_config;
  } else {
    yaml_path = FLAGS_c;
  }

  ClsConfig config(yaml_path);
  config.PrintConfigInfo();

  std::string path(config.infer_imgs);

  std::vector<std::string> img_files_list;
  if (cv::utils::fs::isDirectory(path)) {
    std::vector<cv::String> filenames;
    cv::glob(path, filenames);
    for (auto f : filenames) {
      img_files_list.push_back(f);
    }
  } else {
    img_files_list.push_back(path);
  }
  std::cout << "img_file_list length: " << img_files_list.size() << std::endl;

  Classifier classifier(config);

  std::vector<double> cls_times = {0, 0, 0};
  std::vector<double> cls_times_total = {0, 0, 0};
  double infer_time;
  std::vector<float> out_data;
  std::vector<int> result_index;
  int warmup_iter = 5;
  bool label_output_equal_flag = true;
  for (int idx = 0; idx < img_files_list.size(); ++idx) {
    std::string img_path = img_files_list[idx];
    cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
    if (!srcimg.data) {
      std::cerr << "[ERROR] image read failed! image path: " << img_path
                << "\n";
      exit(-1);
    }

    cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB);
    classifier.Run(srcimg, out_data, result_index, cls_times);
    // warn once if the label map does not cover every output class
    if (label_output_equal_flag && out_data.size() != config.id_map.size()) {
      std::cout << "Warning: the label map size does not equal the model "
                   "output size!"
                << std::endl;
      label_output_equal_flag = false;
    }

    int max_len = std::min(config.topk, int(out_data.size()));
    std::cout << "Current image path: " << img_path << std::endl;
    infer_time = cls_times[0] + cls_times[1] + cls_times[2];
    std::cout << "Current total inference time cost: " << infer_time << " ms."
              << std::endl;
    for (int i = 0; i < max_len; ++i) {
      printf("\tTop%d: score: %.4f, ", i + 1, out_data[result_index[i]]);
      if (label_output_equal_flag)
        printf("label: %s\n", config.id_map[result_index[i]].c_str());
      else
        printf("\n");
    }
    // skip the warm-up iterations when accumulating timing statistics
    if (idx >= warmup_iter) {
      for (int i = 0; i < cls_times.size(); ++i)
        cls_times_total[i] += cls_times[i];
    }
  }
  if (img_files_list.size() > warmup_iter) {
    infer_time = cls_times_total[0] + cls_times_total[1] + cls_times_total[2];
    std::cout << "average time cost in all: "
              << infer_time / (img_files_list.size() - warmup_iter) << " ms."
              << std::endl;
  }

  std::string presion = "fp32";
  if (config.use_fp16)
    presion = "fp16";
  if (config.benchmark) {
    AutoLogger autolog("Classification", config.use_gpu, config.use_tensorrt,
                       config.use_mkldnn, config.cpu_threads, 1,
                       "1, 3, 224, 224", presion, cls_times_total,
                       img_files_list.size());
    autolog.report();
  }
  return 0;
}
```
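Assuming the demo builds to the usual `clas_system` binary (a guess based on this repo's build scripts; adjust the name if yours differs), the two flag spellings are equivalent:

```bash
./clas_system -c inference_cls.yaml
# or, spelled out:
./clas_system -config inference_cls.yaml
```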
Build script paths (updated to the new workspace layout):

```bash
OPENCV_DIR=/work/project/project/test/opencv-3.4.7/opencv3
LIB_DIR=/work/project/project/test/paddle_inference/
CUDA_LIB_DIR=/usr/local/cuda/lib64
CUDNN_LIB_DIR=/usr/lib/x86_64-linux-gnu/
```
......