diff --git a/fluid/PaddleCV/image_classification/dist_train/dist_train.py b/fluid/PaddleCV/image_classification/dist_train/dist_train.py
index eb314085ea890ca8c6650a2c71d3d08f195a4def..cd1a84be8bc41630646289edcfcf9745a051d7e5 100644
--- a/fluid/PaddleCV/image_classification/dist_train/dist_train.py
+++ b/fluid/PaddleCV/image_classification/dist_train/dist_train.py
@@ -65,6 +65,7 @@ def parse_args():
     add_arg('split_var', bool, True, "Split params on pserver.")
     add_arg('async_mode', bool, False, "Async distributed training, only for pserver mode.")
     add_arg('reduce_strategy', str, "allreduce", "Choose from reduce or allreduce.")
+    add_arg('skip_unbalanced_data', bool, False, "Skip the trailing batches if data is not balanced across nodes.")
     # yapf: enable
     args = parser.parse_args()
     return args
@@ -99,6 +100,7 @@ def build_program(is_train, main_prog, startup_prog, args):
 
     image_shape = [int(m) for m in args.image_shape.split(",")]
     trainer_count = args.dist_env["num_trainers"]
+    device_num_per_worker = get_device_num()
     with fluid.program_guard(main_prog, startup_prog):
         pyreader = fluid.layers.py_reader(
             capacity=16,
@@ -124,8 +126,12 @@ def build_program(is_train, main_prog, startup_prog, args):
         optimizer = None
         if is_train:
             start_lr = args.lr
-            # n * worker * repeat
             end_lr = args.lr * trainer_count * args.multi_batch_repeat
+            if os.getenv("FLAGS_selected_gpus"):
+                # In multi-process mode "trainer_count" counts every device in
+                # the cluster, so also scale by the per-worker device count.
+                end_lr *= device_num_per_worker
+
             total_images = args.total_images / trainer_count
             step = int(total_images / (args.batch_size * args.multi_batch_repeat) + 1)
             warmup_steps = step * 5  # warmup 5 passes
@@ -251,6 +257,7 @@ def train_parallel(args):
 
     over_all_start = time.time()
     fetch_list = [train_cost.name, train_acc1.name, train_acc5.name]
+    steps_per_pass = args.total_images / args.batch_size / args.dist_env["num_trainers"]
    for pass_id in range(args.num_epochs):
         num_samples = 0
         start_time = time.time()
@@ -273,6 +280,8 @@ def train_parallel(args):
                 break
             num_samples += args.batch_size
             batch_id += 1
+            if args.skip_unbalanced_data and batch_id >= steps_per_pass:
+                break
 
         print_train_time(start_time, time.time(), num_samples)
         train_pyreader.reset()
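
Reviewer note: below is a minimal, standalone sketch of the skip_unbalanced_data logic this patch adds. The constants and the fake_reader helper are illustrative assumptions, not part of the patch; the point is that every trainer derives the same steps_per_pass bound from the global image count and stops there, so a trainer that happens to receive a few extra batches never issues a collective step its peers will not join.

    # Illustrative sketch only; constants and fake_reader are assumed, not from the patch.
    total_images = 1281167   # e.g. ImageNet train-set size (assumption)
    batch_size = 32          # per-trainer batch size (assumption)
    num_trainers = 8         # cluster size (assumption)

    # Same arithmetic as the patch: a shared, possibly fractional bound.
    steps_per_pass = total_images / batch_size / num_trainers  # ~5004.6

    def fake_reader(n):
        # Stand-in for the py_reader; n varies per trainer when the
        # dataset does not split evenly across nodes.
        for i in range(n):
            yield i

    batch_id = 0
    for _ in fake_reader(5010):  # this trainer happened to get extra batches
        batch_id += 1
        if batch_id >= steps_per_pass:
            # All trainers break at the same step, keeping allreduce in lockstep.
            break
    print("ran %d steps (bound %.1f)" % (batch_id, steps_per_pass))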