Commit e6f4a801 authored by W wuzewu

Automatically adjust the batch size when it is less than the number of devices

Parent 8fcf35b9
@@ -110,8 +110,17 @@ class BasicTask(object):
         # run config
         self.config = config if config else RunConfig()
-        self.place, self.device_count = hub.common.get_running_device_info(
-            self.config)
+        self.place = self.places[0]
+        self.device_count = len(self.places)
+        if self.config.batch_size < self.device_count:
+            logger.warning(
+                "Batch size({}) is less than the count of devices({}), which is not allowed in current Paddle versions"
+                .format(self.config.batch_size, self.device_count))
+            logger.warning("Batch size automatically adjusted to {}".format(
+                self.device_count))
+            self.config._batch_size = self.device_count
         self.exe = fluid.Executor(place=self.place)
         self.build_strategy = fluid.BuildStrategy()
         if self.config.enable_memory_optim:
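
For readers skimming the diff, here is a minimal standalone sketch of what this hunk does; the helper name adjust_batch_size and the logging setup are illustrative, not part of the patch.

    import logging

    logger = logging.getLogger(__name__)

    def adjust_batch_size(batch_size, device_count):
        # Illustrative only: Paddle's data-parallel executor splits each batch
        # across devices, so a batch smaller than the device count cannot be
        # divided and is bumped up to one sample per device.
        if batch_size < device_count:
            logger.warning(
                "Batch size(%d) is less than the count of devices(%d), "
                "adjusting it to %d", batch_size, device_count, device_count)
            return device_count
        return batch_size
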
@@ -239,6 +248,12 @@ class BasicTask(object):
         self.exe.run(self.env.startup_program)
         self._build_env_end_event()
 
+    @property
+    def places(self):
+        if self.config.use_cuda:
+            return fluid.framework.cuda_places()
+        return fluid.framework.cpu_places()
+
     @property
     def is_train_phase(self):
         return self.phase in ["train"]
@@ -481,6 +496,9 @@ class BasicTask(object):
         period_run_states = []
         for run_step, batch in enumerate(self.reader(), start=1):
+            if self.config.use_data_parallel and len(batch) < self.device_count:
+                continue
+
             step_run_state = RunState(len(self.fetch_list))
             step_run_state.run_step = 1
             num_batch_examples = len(batch)
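
Finally, a short hedged sketch of why the last hunk drops under-sized batches when data-parallel execution is on; the function iter_full_batches, the toy reader, and the device count of 4 are illustrative, mirroring self.reader, self.device_count, and self.config.use_data_parallel.

    def iter_full_batches(reader, device_count, use_data_parallel=True):
        # ParallelExecutor shards each batch across device_count devices, so a
        # trailing batch with fewer samples than devices is skipped here rather
        # than failing inside the executor.
        for step, batch in enumerate(reader(), start=1):
            if use_data_parallel and len(batch) < device_count:
                continue
            yield step, batch

    # Toy usage: with 4 devices, the final 1-sample batch is dropped.
    toy_reader = lambda: iter([[1, 2, 3, 4], [5, 6, 7, 8], [9]])
    for step, batch in iter_full_batches(toy_reader, device_count=4):
        print(step, batch)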