From c3922941de8114e8552afb4fc36a52adb925c404 Mon Sep 17 00:00:00 2001 From: littletomatodonkey <2120160898@bit.edu.cn> Date: Wed, 4 Nov 2020 22:48:19 +0800 Subject: [PATCH] fix dist training (#363) 1. fix dist training 2. fix cpp infer to support dir inference --- deploy/cpp_infer/src/main.cpp | 49 +++++++++++++++++++++++++---------- tools/program.py | 13 +++------- tools/train.py | 4 ++- 3 files changed, 43 insertions(+), 23 deletions(-) diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp index 01397ef9..a5430838 100644 --- a/deploy/cpp_infer/src/main.cpp +++ b/deploy/cpp_infer/src/main.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -43,26 +44,48 @@ int main(int argc, char **argv) { config.PrintConfigInfo(); - std::string img_path(argv[2]); + std::string path(argv[2]); - cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR); - cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB); + std::vector img_files_list; + if (cv::utils::fs::isDirectory(path)) { + std::vector filenames; + cv::glob(path, filenames); + for (auto f : filenames) { + img_files_list.push_back(f); + } + } else { + img_files_list.push_back(path); + } + + std::cout << "img_file_list length: " << img_files_list.size() << std::endl; Classifier classifier(config.cls_model_dir, config.use_gpu, config.gpu_id, config.gpu_mem, config.cpu_math_library_num_threads, config.use_mkldnn, config.use_zero_copy_run, config.resize_short_size, config.crop_size); - auto start = std::chrono::system_clock::now(); - classifier.Run(srcimg); - auto end = std::chrono::system_clock::now(); - auto duration = - std::chrono::duration_cast(end - start); - std::cout << "Cost " - << double(duration.count()) * - std::chrono::microseconds::period::num / - std::chrono::microseconds::period::den - << " s" << std::endl; + double elapsed_time = 0.0; + int warmup_iter = img_files_list.size() > 5 ? 5 : 0; + for (int idx = 0; idx < img_files_list.size(); ++idx) { + std::string img_path = img_files_list[idx]; + cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR); + cv::cvtColor(srcimg, srcimg, cv::COLOR_BGR2RGB); + + auto start = std::chrono::system_clock::now(); + classifier.Run(srcimg); + auto end = std::chrono::system_clock::now(); + auto duration = + std::chrono::duration_cast(end - start); + double curr_time = double(duration.count()) * + std::chrono::microseconds::period::num / + std::chrono::microseconds::period::den; + if (idx >= warmup_iter) { + elapsed_time += curr_time; + } + std::cout << "Current time cost: " << curr_time << " s, " + << "average time cost in all: " + << elapsed_time / (idx + 1 - warmup_iter) << " s." << std::endl; + } return 0; } diff --git a/tools/program.py b/tools/program.py index a55bacef..1154f8e6 100644 --- a/tools/program.py +++ b/tools/program.py @@ -295,16 +295,11 @@ def run(dataloader, feeds = create_feeds(batch, use_mix) fetchs = create_fetchs(feeds, net, config, mode) if mode == 'train': - if config["use_data_parallel"]: - avg_loss = net.scale_loss(fetchs['loss']) - avg_loss.backward() - net.apply_collective_grads() - else: - avg_loss = fetchs['loss'] - avg_loss.backward() + avg_loss = fetchs['loss'] + avg_loss.backward() - optimizer.minimize(avg_loss) - net.clear_gradients() + optimizer.step() + optimizer.clear_grad() metric_list['lr'].update( optimizer._global_learning_rate().numpy()[0], batch_size) diff --git a/tools/train.py b/tools/train.py index 61612b4d..db623581 100644 --- a/tools/train.py +++ b/tools/train.py @@ -63,13 +63,15 @@ def main(args): use_data_parallel = int(os.getenv("PADDLE_TRAINERS_NUM", 1)) != 1 config["use_data_parallel"] = use_data_parallel + if config["use_data_parallel"]: + strategy = paddle.distributed.init_parallel_env() + net = program.create_model(config.ARCHITECTURE, config.classes_num) optimizer, lr_scheduler = program.create_optimizer( config, parameter_list=net.parameters()) if config["use_data_parallel"]: - strategy = paddle.distributed.init_parallel_env() net = paddle.DataParallel(net, strategy) # load model from checkpoint or pretrained model -- GitLab