Commit bb6581d2 authored by Tingquan Gao, committed by Tingquan Gao

refactor: raise a warning when the number of GPUs is not 4

Parent 8f0bd5b5
@@ -231,11 +231,13 @@ class Engine(object):
             save_dtype='float32')
         # for distributed
-        self.config["Global"][
-            "distributed"] = paddle.distributed.get_world_size() != 1
+        world_size = dist.get_world_size()
+        self.config["Global"]["distributed"] = world_size != 1
+        if world_size != 4 and self.mode == "train":
+            msg = f"The training strategy in config files provided by PaddleClas is based on 4 gpus. But the number of gpus is {world_size} in the current training. Please modify the strategy (learning rate, batch size and so on) if you use config files in PaddleClas to train."
+            logger.warning(msg)
         if self.config["Global"]["distributed"]:
             dist.init_parallel_env()
         if self.config["Global"]["distributed"]:
             self.model = paddle.DataParallel(self.model)
         # build postprocess for infer
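Note that the new check only warns; it does not adjust any hyperparameters. Below is a minimal, hypothetical sketch of what the warning asks the user to do by hand: rescale the learning rate linearly when the detected world size differs from the 4-GPU baseline the provided configs were tuned for. `BASELINE_GPUS` and `scale_lr` are illustrative names, not part of PaddleClas.

```python
# Hypothetical sketch (not PaddleClas code): apply the linear scaling rule
# that the warning above asks users to perform manually.
import paddle.distributed as dist

BASELINE_GPUS = 4  # GPU count the provided config files assume

def scale_lr(base_lr: float, world_size: int) -> float:
    # Linear scaling rule: the global batch size grows with world_size,
    # so the learning rate is scaled by the same factor.
    return base_lr * world_size / BASELINE_GPUS

world_size = dist.get_world_size()
if world_size != BASELINE_GPUS:
    print(f"Configs assume {BASELINE_GPUS} GPUs but found {world_size}; "
          f"e.g. a base lr of 0.1 becomes {scale_lr(0.1, world_size):.4f}.")
```

Batch size per GPU is usually left unchanged under this rule, since the global batch size already scales with the number of processes.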
@@ -346,8 +348,8 @@ class Engine(object):
     @paddle.no_grad()
     def infer(self):
         assert self.mode == "infer" and self.eval_mode == "classification"
-        total_trainer = paddle.distributed.get_world_size()
-        local_rank = paddle.distributed.get_rank()
+        total_trainer = dist.get_world_size()
+        local_rank = dist.get_rank()
         image_list = get_image_list(self.config["Infer"]["infer_imgs"])
         # data split
         image_list = image_list[local_rank::total_trainer]
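For context, the slicing idiom `image_list[local_rank::total_trainer]` shards the image list round-robin across processes, so each rank infers a disjoint subset and no image is processed twice. A self-contained sketch of that idiom, with hypothetical names (`shard`, `images`):

```python
# Round-robin sharding: slicing with [rank::world_size] deals items out
# like cards, giving each rank a disjoint subset of roughly equal size.
def shard(items, rank: int, world_size: int):
    """Return the slice of `items` owned by this rank."""
    return items[rank::world_size]

images = [f"img_{i}.jpg" for i in range(10)]
assert shard(images, 0, 4) == ["img_0.jpg", "img_4.jpg", "img_8.jpg"]
assert shard(images, 1, 4) == ["img_1.jpg", "img_5.jpg", "img_9.jpg"]
```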