Unverified commit 9afe4f67 authored by W Wu Yi, committed by GitHub

Merge pull request #2152 from typhoonzero/fix_mp_mode_step_counts

Fix MP mode step count and add infinite mode for reader
......@@ -66,7 +66,6 @@ def parse_args():
add_arg('split_var', bool, True, "Split params on pserver.")
add_arg('async_mode', bool, False, "Async distributed training, only for pserver mode.")
add_arg('reduce_strategy', str, "allreduce", "Choose from reduce or allreduce.")
add_arg('skip_unbalanced_data', bool, False, "Skip trailing batches if data is not balanced across nodes.")
add_arg('enable_sequential_execution', bool, False, "Enable sequential execution.")
# for dgc
add_arg('enable_dgc', bool, False, "Enable Deep Gradient Compression (DGC).")
......@@ -85,9 +84,11 @@ def get_device_num():
device_num = subprocess.check_output(['nvidia-smi', '-L']).decode().count('\n')
return device_num
def prepare_reader(is_train, pyreader, args, pass_id=0):
def prepare_reader(is_train, pyreader, args, pass_id=1):
# NOTE: always use infinite reader for dist training
if is_train:
reader = train(data_dir=args.data_dir, pass_id_as_seed=pass_id)
reader = train(data_dir=args.data_dir, pass_id_as_seed=pass_id,
infinite=True)
else:
reader = val(data_dir=args.data_dir)
if is_train:
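Why the reader must be infinite (the NOTE above): in multi-process distributed training, if one trainer's shard runs out of samples before the others, the collective all-reduce blocks forever waiting for the empty ranks, so the fix is to never let the reader end and instead bound each pass by an explicit step count. A minimal sketch of that pattern, assuming plain Python samples and numpy shuffling (names here are illustrative, not from this PR):

import numpy as np

def make_infinite(samples, seed=1):
    # Wrap a finite sample list into a never-ending generator with a
    # deterministic, per-pass reshuffle (the seed advances every pass).
    def reader():
        pass_seed = seed
        while True:
            np.random.seed(pass_seed)
            order = np.random.permutation(len(samples))
            for idx in order:
                yield samples[idx]
            pass_seed += 1  # next pass gets a new shuffle order
    return reader

# The consumer, not the reader, decides where a pass ends:
r = make_infinite(list(range(10)))()
first_pass = [next(r) for _ in range(10)]
second_pass = [next(r) for _ in range(10)]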
......@@ -138,7 +139,10 @@ def build_program(is_train, main_prog, startup_prog, args):
end_lr /= device_num_per_worker
total_images = args.total_images / trainer_count
step = int(total_images / (args.batch_size * args.multi_batch_repeat) + 1)
if os.getenv("FLAGS_selected_gpus"):
step = int(total_images / (args.batch_size / device_num_per_worker * args.multi_batch_repeat) + 1)
else:
step = int(total_images / (args.batch_size * args.multi_batch_repeat) + 1)
warmup_steps = step * 5 # warmup 5 passes
epochs = [30, 60, 80]
bd = [step * e for e in epochs]
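To see what the branch above changes, here is a worked example with illustrative numbers (assuming 4 trainers, 8 GPUs per worker, ImageNet-sized data; none of these values come from the PR):

total_images = 1281167 // 4        # images per trainer
batch_size = 256
multi_batch_repeat = 1
device_num_per_worker = 8

# MP mode (FLAGS_selected_gpus is set): each process only consumes
# batch_size / device_num_per_worker samples per step.
step_mp = int(total_images / (batch_size / device_num_per_worker
                              * multi_batch_repeat) + 1)              # 10010

# SP/PG mode: each process consumes the full batch_size per step.
step_sp = int(total_images / (batch_size * multi_batch_repeat) + 1)   # 1252

warmup_steps = step_mp * 5                 # warmup 5 passes
bd = [step_mp * e for e in [30, 60, 80]]   # LR decay boundaries

Without the branch, MP mode would inherit the SP step count and hit the learning-rate decay boundaries roughly 8x too early in terms of epochs.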
......@@ -262,9 +266,9 @@ def train_parallel(args):
strategy = fluid.ExecutionStrategy()
strategy.num_threads = args.num_threads
# num_iteration_per_drop_scope indicates how
# many iterations to clean up the temp variables which
# is generated during execution. It may make the execution faster,
# because the temp variable's shape maybe the same between two iterations
# many iterations to clean up the temp variables which
# are generated during execution. It may make the execution faster,
# because the temp variables' shapes are the same between two iterations.
strategy.num_iteration_per_drop_scope = 30
build_strategy = fluid.BuildStrategy()
......@@ -317,14 +321,19 @@ def train_parallel(args):
over_all_start = time.time()
fetch_list = [train_cost.name, train_acc1.name, train_acc5.name]
steps_per_pass = args.total_images / args.batch_size / args.dist_env["num_trainers"]
# 1. In MP mode, the batch size for the current process is args.batch_size / number of GPUs
# 2. In SP/PG mode, the batch size for each process is the original args.batch_size
if os.getenv("FLAGS_selected_gpus"):
steps_per_pass = args.total_images / (args.batch_size / get_device_num()) / args.dist_env["num_trainers"]
else:
steps_per_pass = args.total_images / args.batch_size / args.dist_env["num_trainers"]
for pass_id in range(args.num_epochs):
num_samples = 0
start_time = time.time()
batch_id = 1
# use pass_id+1 as the per-pass global shuffle seed for distributed training
prepare_reader(True, train_pyreader, args, pass_id + 1)
train_pyreader.start()
if pass_id == 0:
train_pyreader.start()
while True:
try:
if batch_id % 30 == 0:
......@@ -342,11 +351,10 @@ def train_parallel(args):
break
num_samples += args.batch_size
batch_id += 1
if args.skip_unbalanced_data and batch_id >= steps_per_pass:
if batch_id >= steps_per_pass:
break
print_train_time(start_time, time.time(), num_samples)
train_pyreader.reset()
if pass_id >= args.start_test_pass:
if args.multi_batch_repeat > 1:
copyback_repeat_bn_params(train_prog)
......@@ -362,6 +370,7 @@ def train_parallel(args):
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_persistables(startup_exe, model_path, main_program=train_prog)
train_pyreader.reset()
startup_exe.close()
print("total train time: ", time.time() - over_all_start)
......
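The loop changes above fit together as one pattern: with an infinite reader there is no EOFException at pass boundaries, so the pyreader is started exactly once (pass_id == 0), each pass ends at the now-unconditional batch_id >= steps_per_pass check (no longer gated on skip_unbalanced_data), and reset() moves to after the whole epoch loop. A condensed, runnable sketch of that control flow (executor details replaced by a stub; names illustrative, not the full script):

def run_epochs(train_pyreader, num_epochs, steps_per_pass, run_step):
    for pass_id in range(num_epochs):
        if pass_id == 0:
            train_pyreader.start()    # started once; the reader never ends
        batch_id = 1
        while True:
            run_step(batch_id)        # exe.run(fetch_list=...) in the real code
            batch_id += 1
            if batch_id >= steps_per_pass:
                break                 # pass boundary is an explicit count
    train_pyreader.reset()            # reset once, after all passes

class FakePyReader(object):           # stand-in for the fluid pyreader
    def start(self): print("reader started")
    def reset(self): print("reader reset")

run_epochs(FakePyReader(), num_epochs=2, steps_per_pass=5,
           run_step=lambda batch_id: None)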
......@@ -15,5 +15,7 @@ PADDLE_TRAINING_ROLE="TRAINER" \
PADDLE_CURRENT_ENDPOINT="127.0.0.1:716${i}" \
PADDLE_TRAINER_ID="${i}" \
FLAGS_selected_gpus="${i}" \
python dist_train.py --model $MODEL --update_method nccl2 --batch_size 32 --fp16 1 --scale_loss 8 &> logs/tr$i.log &
python -u dist_train.py --model $MODEL --update_method nccl2 \
--batch_size 32 \
--fp16 0 --scale_loss 1 &> logs/tr$i.log &
done
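The script pins one trainer process per GPU through FLAGS_selected_gpus and gives each process its own PADDLE_TRAINER_ID and endpoint. For reference, a hedged Python equivalent of that launch loop (the env-var scheme and CLI flags are copied from the script; the GPU count and model name are assumptions):

import os, subprocess

MODEL = "DistResNet"    # hypothetical; the script takes $MODEL from outside
procs = []
for i in range(8):      # assumes 8 local GPUs
    env = dict(os.environ,
               PADDLE_TRAINING_ROLE="TRAINER",
               PADDLE_CURRENT_ENDPOINT="127.0.0.1:716%d" % i,
               PADDLE_TRAINER_ID=str(i),
               FLAGS_selected_gpus=str(i))
    with open("logs/tr%d.log" % i, "w") as log:
        procs.append(subprocess.Popen(
            ["python", "-u", "dist_train.py", "--model", MODEL,
             "--update_method", "nccl2", "--batch_size", "32",
             "--fp16", "0", "--scale_loss", "1"],
            env=env, stdout=log, stderr=subprocess.STDOUT))
for p in procs:
    p.wait()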
......@@ -12,7 +12,7 @@ np.random.seed(0)
DATA_DIM = 224
THREAD = 8
BUF_SIZE = 102400
BUF_SIZE = 1024
DATA_DIR = 'data/ILSVRC2012'
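The BUF_SIZE cut is a memory fix: this constant is handed to paddle.reader.xmap_readers as the buffer of already-decoded samples, so a back-of-the-envelope estimate (assuming float32 pixels and 3 x 224 x 224 per sample, ignoring label overhead) shows why 102400 was untenable:

DATA_DIM = 224
bytes_per_sample = 3 * DATA_DIM * DATA_DIM * 4   # ~0.57 MB per decoded image

for buf in (102400, 1024):
    print("BUF_SIZE=%6d -> ~%.1f GB buffered"
          % (buf, buf * bytes_per_sample / 1024.0 ** 3))
# BUF_SIZE=102400 -> ~57.4 GB;  BUF_SIZE=1024 -> ~0.6 GB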
......@@ -131,38 +131,44 @@ def _reader_creator(file_list,
color_jitter=False,
rotate=False,
data_dir=DATA_DIR,
pass_id_as_seed=0):
pass_id_as_seed=1,
infinite=False):
def reader():
with open(file_list) as flist:
full_lines = [line.strip() for line in flist]
if shuffle:
if pass_id_as_seed:
np.random.seed(pass_id_as_seed)
np.random.shuffle(full_lines)
if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exists
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1)
* per_node_lines]
print(
"read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines, len(lines),
len(full_lines)))
else:
lines = full_lines
for line in lines:
if mode == 'train' or mode == 'val':
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
yield img_path, int(label)
elif mode == 'test':
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
yield [img_path]
pass_id_as_seed_counter = pass_id_as_seed
while True:
if shuffle:
if pass_id_as_seed_counter:
np.random.seed(pass_id_as_seed_counter)
np.random.shuffle(full_lines)
if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exists
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
per_node_lines = len(full_lines) // trainer_count
lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1)
* per_node_lines]
print(
"read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines, len(lines),
len(full_lines)))
else:
lines = full_lines
for line in lines:
if mode == 'train' or mode == 'val':
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
yield img_path, int(label)
elif mode == 'test':
img_path, label = line.split()
img_path = os.path.join(data_dir, img_path)
yield [img_path]
if not infinite:
break
pass_id_as_seed_counter += 1
print("passid ++, current: ", pass_id_as_seed_counter)
mapper = functools.partial(
process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)
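One subtlety worth noting in the loop above: the seed is applied before the per-trainer slicing, so every trainer reshuffles the identical full_lines with the same per-pass seed, and the trainer_id-indexed slices stay a consistent, non-overlapping partition of the epoch. A minimal demonstration (names illustrative):

import numpy as np

full_lines = ["img_%d" % i for i in range(8)]
trainer_count = 2
per_node = len(full_lines) // trainer_count

shards = []
for trainer_id in range(trainer_count):
    lines = list(full_lines)
    np.random.seed(1)             # same pass_id_as_seed_counter on every rank
    np.random.shuffle(lines)
    shards.append(lines[trainer_id * per_node:(trainer_id + 1) * per_node])

assert sorted(shards[0] + shards[1]) == sorted(full_lines)  # full coverage
assert not set(shards[0]) & set(shards[1])                  # disjoint shards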
......@@ -170,7 +176,7 @@ def _reader_creator(file_list,
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def train(data_dir=DATA_DIR, pass_id_as_seed=0):
def train(data_dir=DATA_DIR, pass_id_as_seed=1, infinite=False):
file_list = os.path.join(data_dir, 'train_list.txt')
return _reader_creator(
file_list,
......@@ -179,7 +185,8 @@ def train(data_dir=DATA_DIR, pass_id_as_seed=0):
color_jitter=False,
rotate=False,
data_dir=data_dir,
pass_id_as_seed=pass_id_as_seed)
pass_id_as_seed=pass_id_as_seed,
infinite=infinite)
def val(data_dir=DATA_DIR):
......