未验证 提交 9afe4f67 编写于 作者: W Wu Yi 提交者: GitHub

Merge pull request #2152 from typhoonzero/fix_mp_mode_step_counts

Fix MP mode step count and add infinite mode for reader
...@@ -66,7 +66,6 @@ def parse_args(): ...@@ -66,7 +66,6 @@ def parse_args():
add_arg('split_var', bool, True, "Split params on pserver.") add_arg('split_var', bool, True, "Split params on pserver.")
add_arg('async_mode', bool, False, "Async distributed training, only for pserver mode.") add_arg('async_mode', bool, False, "Async distributed training, only for pserver mode.")
add_arg('reduce_strategy', str, "allreduce", "Choose from reduce or allreduce.") add_arg('reduce_strategy', str, "allreduce", "Choose from reduce or allreduce.")
add_arg('skip_unbalanced_data', bool, False, "Skip data not if data not balanced on nodes.")
add_arg('enable_sequential_execution', bool, False, "Skip data not if data not balanced on nodes.") add_arg('enable_sequential_execution', bool, False, "Skip data not if data not balanced on nodes.")
#for dgc #for dgc
add_arg('enable_dgc', bool, False, "Skip data not if data not balanced on nodes.") add_arg('enable_dgc', bool, False, "Skip data not if data not balanced on nodes.")
...@@ -85,9 +84,11 @@ def get_device_num(): ...@@ -85,9 +84,11 @@ def get_device_num():
device_num = subprocess.check_output(['nvidia-smi', '-L']).decode().count('\n') device_num = subprocess.check_output(['nvidia-smi', '-L']).decode().count('\n')
return device_num return device_num
def prepare_reader(is_train, pyreader, args, pass_id=0): def prepare_reader(is_train, pyreader, args, pass_id=1):
# NOTE: always use infinite reader for dist training
if is_train: if is_train:
reader = train(data_dir=args.data_dir, pass_id_as_seed=pass_id) reader = train(data_dir=args.data_dir, pass_id_as_seed=pass_id,
infinite=True)
else: else:
reader = val(data_dir=args.data_dir) reader = val(data_dir=args.data_dir)
if is_train: if is_train:
...@@ -138,6 +139,9 @@ def build_program(is_train, main_prog, startup_prog, args): ...@@ -138,6 +139,9 @@ def build_program(is_train, main_prog, startup_prog, args):
end_lr /= device_num_per_worker end_lr /= device_num_per_worker
total_images = args.total_images / trainer_count total_images = args.total_images / trainer_count
if os.getenv("FLAGS_selected_gpus"):
step = int(total_images / (args.batch_size / device_num_per_worker * args.multi_batch_repeat) + 1)
else:
step = int(total_images / (args.batch_size * args.multi_batch_repeat) + 1) step = int(total_images / (args.batch_size * args.multi_batch_repeat) + 1)
warmup_steps = step * 5 # warmup 5 passes warmup_steps = step * 5 # warmup 5 passes
epochs = [30, 60, 80] epochs = [30, 60, 80]
...@@ -264,7 +268,7 @@ def train_parallel(args): ...@@ -264,7 +268,7 @@ def train_parallel(args):
# num_iteration_per_drop_scope indicates how # num_iteration_per_drop_scope indicates how
# many iterations to clean up the temp variables which # many iterations to clean up the temp variables which
# is generated during execution. It may make the execution faster, # is generated during execution. It may make the execution faster,
# because the temp variable's shape maybe the same between two iterations # because the temp variable's shape are the same between two iterations.
strategy.num_iteration_per_drop_scope = 30 strategy.num_iteration_per_drop_scope = 30
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
...@@ -317,13 +321,18 @@ def train_parallel(args): ...@@ -317,13 +321,18 @@ def train_parallel(args):
over_all_start = time.time() over_all_start = time.time()
fetch_list = [train_cost.name, train_acc1.name, train_acc5.name] fetch_list = [train_cost.name, train_acc1.name, train_acc5.name]
# 1. MP mode, batch size for current process should be args.batch_size / GPUs
# 2. SP/PG mode, batch size for each process should be original args.batch_size
if os.getenv("FLAGS_selected_gpus"):
steps_per_pass = args.total_images / (args.batch_size / get_device_num()) / args.dist_env["num_trainers"]
else:
steps_per_pass = args.total_images / args.batch_size / args.dist_env["num_trainers"] steps_per_pass = args.total_images / args.batch_size / args.dist_env["num_trainers"]
for pass_id in range(args.num_epochs): for pass_id in range(args.num_epochs):
num_samples = 0 num_samples = 0
start_time = time.time() start_time = time.time()
batch_id = 1 batch_id = 1
# use pass_id+1 as per pass global shuffle for distributed training if pass_id == 0:
prepare_reader(True, train_pyreader, args, pass_id + 1)
train_pyreader.start() train_pyreader.start()
while True: while True:
try: try:
...@@ -342,11 +351,10 @@ def train_parallel(args): ...@@ -342,11 +351,10 @@ def train_parallel(args):
break break
num_samples += args.batch_size num_samples += args.batch_size
batch_id += 1 batch_id += 1
if args.skip_unbalanced_data and batch_id >= steps_per_pass: if batch_id >= steps_per_pass:
break break
print_train_time(start_time, time.time(), num_samples) print_train_time(start_time, time.time(), num_samples)
train_pyreader.reset()
if pass_id >= args.start_test_pass: if pass_id >= args.start_test_pass:
if args.multi_batch_repeat > 1: if args.multi_batch_repeat > 1:
copyback_repeat_bn_params(train_prog) copyback_repeat_bn_params(train_prog)
...@@ -362,6 +370,7 @@ def train_parallel(args): ...@@ -362,6 +370,7 @@ def train_parallel(args):
if not os.path.isdir(model_path): if not os.path.isdir(model_path):
os.makedirs(model_path) os.makedirs(model_path)
fluid.io.save_persistables(startup_exe, model_path, main_program=train_prog) fluid.io.save_persistables(startup_exe, model_path, main_program=train_prog)
train_pyreader.reset()
startup_exe.close() startup_exe.close()
print("total train time: ", time.time() - over_all_start) print("total train time: ", time.time() - over_all_start)
......
...@@ -15,5 +15,7 @@ PADDLE_TRAINING_ROLE="TRAINER" \ ...@@ -15,5 +15,7 @@ PADDLE_TRAINING_ROLE="TRAINER" \
PADDLE_CURRENT_ENDPOINT="127.0.0.1:716${i}" \ PADDLE_CURRENT_ENDPOINT="127.0.0.1:716${i}" \
PADDLE_TRAINER_ID="${i}" \ PADDLE_TRAINER_ID="${i}" \
FLAGS_selected_gpus="${i}" \ FLAGS_selected_gpus="${i}" \
python dist_train.py --model $MODEL --update_method nccl2 --batch_size 32 --fp16 1 --scale_loss 8 &> logs/tr$i.log & python -u dist_train.py --model $MODEL --update_method nccl2 \
--batch_size 32 \
--fp16 0 --scale_loss 1 &> logs/tr$i.log &
done done
...@@ -12,7 +12,7 @@ np.random.seed(0) ...@@ -12,7 +12,7 @@ np.random.seed(0)
DATA_DIM = 224 DATA_DIM = 224
THREAD = 8 THREAD = 8
BUF_SIZE = 102400 BUF_SIZE = 1024
DATA_DIR = 'data/ILSVRC2012' DATA_DIR = 'data/ILSVRC2012'
...@@ -131,13 +131,16 @@ def _reader_creator(file_list, ...@@ -131,13 +131,16 @@ def _reader_creator(file_list,
color_jitter=False, color_jitter=False,
rotate=False, rotate=False,
data_dir=DATA_DIR, data_dir=DATA_DIR,
pass_id_as_seed=0): pass_id_as_seed=1,
infinite=False):
def reader(): def reader():
with open(file_list) as flist: with open(file_list) as flist:
full_lines = [line.strip() for line in flist] full_lines = [line.strip() for line in flist]
pass_id_as_seed_counter = pass_id_as_seed
while True:
if shuffle: if shuffle:
if pass_id_as_seed: if pass_id_as_seed_counter:
np.random.seed(pass_id_as_seed) np.random.seed(pass_id_as_seed_counter)
np.random.shuffle(full_lines) np.random.shuffle(full_lines)
if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'): if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exits # distributed mode if the env var `PADDLE_TRAINING_ROLE` exits
...@@ -161,8 +164,11 @@ def _reader_creator(file_list, ...@@ -161,8 +164,11 @@ def _reader_creator(file_list,
elif mode == 'test': elif mode == 'test':
img_path, label = line.split() img_path, label = line.split()
img_path = os.path.join(data_dir, img_path) img_path = os.path.join(data_dir, img_path)
yield [img_path] yield [img_path]
if not infinite:
break
pass_id_as_seed_counter += 1
print("passid ++, current: ", pass_id_as_seed_counter)
mapper = functools.partial( mapper = functools.partial(
process_image, mode=mode, color_jitter=color_jitter, rotate=rotate) process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)
...@@ -170,7 +176,7 @@ def _reader_creator(file_list, ...@@ -170,7 +176,7 @@ def _reader_creator(file_list,
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def train(data_dir=DATA_DIR, pass_id_as_seed=0): def train(data_dir=DATA_DIR, pass_id_as_seed=1, infinite=False):
file_list = os.path.join(data_dir, 'train_list.txt') file_list = os.path.join(data_dir, 'train_list.txt')
return _reader_creator( return _reader_creator(
file_list, file_list,
...@@ -179,7 +185,8 @@ def train(data_dir=DATA_DIR, pass_id_as_seed=0): ...@@ -179,7 +185,8 @@ def train(data_dir=DATA_DIR, pass_id_as_seed=0):
color_jitter=False, color_jitter=False,
rotate=False, rotate=False,
data_dir=data_dir, data_dir=data_dir,
pass_id_as_seed=pass_id_as_seed) pass_id_as_seed=pass_id_as_seed,
infinite=infinite)
def val(data_dir=DATA_DIR): def val(data_dir=DATA_DIR):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册