From bb6581d21bc18a1fbc7d5cc379f89818e826be11 Mon Sep 17 00:00:00 2001
From: Tingquan Gao
Date: Wed, 19 Jan 2022 06:26:01 +0000
Subject: [PATCH] refactor: raise warning when the number of gpus is not 4

---
 ppcls/engine/engine.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index f86b092e..fc01de94 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -231,11 +231,13 @@ class Engine(object):
                 save_dtype='float32')
 
         # for distributed
-        self.config["Global"][
-            "distributed"] = paddle.distributed.get_world_size() != 1
+        world_size = dist.get_world_size()
+        self.config["Global"]["distributed"] = world_size != 1
+        if world_size != 4 and self.mode == "train":
+            msg = f"The training strategy in config files provided by PaddleClas is based on 4 gpus. But the number of gpus is {world_size} in current training. Please modify the strategy (learning rate, batch size and so on) if you use config files in PaddleClas to train."
+            logger.warning(msg)
         if self.config["Global"]["distributed"]:
             dist.init_parallel_env()
-        if self.config["Global"]["distributed"]:
             self.model = paddle.DataParallel(self.model)
 
         # build postprocess for infer
@@ -346,8 +348,8 @@ class Engine(object):
     @paddle.no_grad()
     def infer(self):
         assert self.mode == "infer" and self.eval_mode == "classification"
-        total_trainer = paddle.distributed.get_world_size()
-        local_rank = paddle.distributed.get_rank()
+        total_trainer = dist.get_world_size()
+        local_rank = dist.get_rank()
         image_list = get_image_list(self.config["Infer"]["infer_imgs"])
         # data split
         image_list = image_list[local_rank::total_trainer]
-- 
GitLab
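
For readers adapting this change elsewhere: the first hunk aliases paddle.distributed as dist, warns when the world size differs from the 4 GPUs that the PaddleClas config files assume, and folds the two identical `if self.config["Global"]["distributed"]` checks into one. Below is a minimal standalone sketch of the same pattern, assuming a working PaddlePaddle install; setup_distributed, expected_gpus, and the stdlib-logging logger are illustrative names for this sketch, not part of PaddleClas:

    import logging

    import paddle
    import paddle.distributed as dist

    logger = logging.getLogger(__name__)


    def setup_distributed(model, mode="train", expected_gpus=4):
        # Warn if the launch uses a different GPU count than the config
        # files were tuned for (the patch hardcodes 4 for PaddleClas).
        world_size = dist.get_world_size()
        if mode == "train" and world_size != expected_gpus:
            logger.warning(
                "Configs are tuned for %d gpus but %d are in use; adjust "
                "the learning rate and batch size accordingly.",
                expected_gpus, world_size)
        # Only set up communicators and wrap the model when actually
        # running multi-process (world_size > 1).
        if world_size != 1:
            dist.init_parallel_env()            # initialize the process group
            model = paddle.DataParallel(model)  # all-reduce gradients across ranks
        return model

Keeping world_size in a local variable, as the patch does, avoids repeating the collective query and makes the 4-GPU check cheap to add.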
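
The second hunk only swaps the fully qualified paddle.distributed calls for the existing dist alias; the sharding below it is untouched. The stride slice `image_list[local_rank::total_trainer]` deals the images out round-robin, so every image lands on exactly one rank. A plain-Python illustration with made-up file names (no Paddle required):

    # Hypothetical inputs: 10 images sharded across a world size of 4.
    image_list = [f"img_{i}.jpg" for i in range(10)]
    total_trainer = 4

    for local_rank in range(total_trainer):
        shard = image_list[local_rank::total_trainer]
        print(local_rank, shard)
    # 0 ['img_0.jpg', 'img_4.jpg', 'img_8.jpg']
    # 1 ['img_1.jpg', 'img_5.jpg', 'img_9.jpg']
    # 2 ['img_2.jpg', 'img_6.jpg']
    # 3 ['img_3.jpg', 'img_7.jpg']

The shards are disjoint and cover the whole list, which is what lets each rank run inference on its slice independently.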