提交 d0b02001 编写于 作者: T typhoonzero

update

上级 934f1f67
...@@ -65,6 +65,7 @@ def parse_args(): ...@@ -65,6 +65,7 @@ def parse_args():
add_arg('split_var', bool, True, "Split params on pserver.") add_arg('split_var', bool, True, "Split params on pserver.")
add_arg('async_mode', bool, False, "Async distributed training, only for pserver mode.") add_arg('async_mode', bool, False, "Async distributed training, only for pserver mode.")
add_arg('reduce_strategy', str, "allreduce", "Choose from reduce or allreduce.") add_arg('reduce_strategy', str, "allreduce", "Choose from reduce or allreduce.")
add_arg('skip_unbalanced_data', bool, False, "Skip data not if data not balanced on nodes.")
# yapf: enable # yapf: enable
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -99,6 +100,7 @@ def build_program(is_train, main_prog, startup_prog, args): ...@@ -99,6 +100,7 @@ def build_program(is_train, main_prog, startup_prog, args):
image_shape = [int(m) for m in args.image_shape.split(",")] image_shape = [int(m) for m in args.image_shape.split(",")]
trainer_count = args.dist_env["num_trainers"] trainer_count = args.dist_env["num_trainers"]
device_num_per_worker = get_device_num()
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
pyreader = fluid.layers.py_reader( pyreader = fluid.layers.py_reader(
capacity=16, capacity=16,
...@@ -124,8 +126,12 @@ def build_program(is_train, main_prog, startup_prog, args): ...@@ -124,8 +126,12 @@ def build_program(is_train, main_prog, startup_prog, args):
optimizer = None optimizer = None
if is_train: if is_train:
start_lr = args.lr start_lr = args.lr
# end_lr = base lr * number of trainers * multi-batch repeat
end_lr = args.lr * trainer_count * args.multi_batch_repeat end_lr = args.lr * trainer_count * args.multi_batch_repeat
if os.getenv("FLAGS_selected_gpus"):
# In multi-process mode, "trainer_count" is the total number of devices
# in the whole cluster, and end_lr is further scaled by the number of devices per worker.
end_lr *= device_num_per_worker
total_images = args.total_images / trainer_count total_images = args.total_images / trainer_count
step = int(total_images / (args.batch_size * args.multi_batch_repeat) + 1) step = int(total_images / (args.batch_size * args.multi_batch_repeat) + 1)
warmup_steps = step * 5 # warmup 5 passes warmup_steps = step * 5 # warmup 5 passes
...@@ -251,6 +257,7 @@ def train_parallel(args): ...@@ -251,6 +257,7 @@ def train_parallel(args):
over_all_start = time.time() over_all_start = time.time()
fetch_list = [train_cost.name, train_acc1.name, train_acc5.name] fetch_list = [train_cost.name, train_acc1.name, train_acc5.name]
steps_per_pass = args.total_images / args.batch_size / args.dist_env["num_trainers"]
for pass_id in range(args.num_epochs): for pass_id in range(args.num_epochs):
num_samples = 0 num_samples = 0
start_time = time.time() start_time = time.time()
...@@ -273,6 +280,8 @@ def train_parallel(args): ...@@ -273,6 +280,8 @@ def train_parallel(args):
break break
num_samples += args.batch_size num_samples += args.batch_size
batch_id += 1 batch_id += 1
if args.skip_unbalanced_data and batch_id >= steps_per_pass:
break
print_train_time(start_time, time.time(), num_samples) print_train_time(start_time, time.time(), num_samples)
train_pyreader.reset() train_pyreader.reset()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册