From ad7a7ffb503eec1c61c25d0205da68275862bdae Mon Sep 17 00:00:00 2001 From: frankwhzhang Date: Tue, 27 Nov 2018 19:05:35 +0800 Subject: [PATCH] fix cluster_train.py style --- fluid/PaddleRec/gru4rec/cluster_train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/fluid/PaddleRec/gru4rec/cluster_train.py b/fluid/PaddleRec/gru4rec/cluster_train.py index 18b6e0df..b9b0820d 100644 --- a/fluid/PaddleRec/gru4rec/cluster_train.py +++ b/fluid/PaddleRec/gru4rec/cluster_train.py @@ -13,7 +13,6 @@ import net SEED = 102 - def parse_args(): parser = argparse.ArgumentParser("gru4rec benchmark.") parser.add_argument( @@ -74,7 +73,6 @@ def train(): sgd_optimizer = fluid.optimizer.SGD(learning_rate=args.base_lr) sgd_optimizer.minimize(avg_cost) - def train_loop(main_program): """ train network """ pass_num = args.pass_num @@ -122,7 +120,6 @@ def train(): if args.is_local: print("run local training") train_loop(fluid.default_main_program()) - else: print("run distribute training") t = fluid.DistributeTranspiler() @@ -138,5 +135,6 @@ def train(): elif args.role == "trainer": print("run trainer") train_loop(t.get_trainer_program()) + if __name__ == "__main__": train() -- GitLab