From f89d9893b696507c3478fe4fbb381010d03dd4bd Mon Sep 17 00:00:00 2001 From: zhangwenhui03 Date: Thu, 10 Jan 2019 17:18:06 +0800 Subject: [PATCH] fix style 2 --- fluid/PaddleRec/gru4rec/train.py | 40 ++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/fluid/PaddleRec/gru4rec/train.py b/fluid/PaddleRec/gru4rec/train.py index 2b889a44..c593ad69 100644 --- a/fluid/PaddleRec/gru4rec/train.py +++ b/fluid/PaddleRec/gru4rec/train.py @@ -17,21 +17,24 @@ SEED = 102 def parse_args(): parser = argparse.ArgumentParser("gru4rec benchmark.") parser.add_argument( - '--train_dir', type=str, default='train_data', help='train file address') + '--train_dir', + type=str, + default='train_data', + help='train file address') parser.add_argument( - '--vocab_path', type=str, default='vocab.txt', help='vocab file address') - parser.add_argument( - '--is_local', type=int, default=1, help='whether local') - parser.add_argument( - '--hid_size', type=int, default=100, help='hid size') + '--vocab_path', + type=str, + default='vocab.txt', + help='vocab file address') + parser.add_argument('--is_local', type=int, default=1, help='whether local') + parser.add_argument('--hid_size', type=int, default=100, help='hid size') parser.add_argument( '--model_dir', type=str, default='model_recall20', help='model dir') parser.add_argument( '--batch_size', type=int, default=5, help='num of batch size') parser.add_argument( '--print_batch', type=int, default=10, help='num of print batch') - parser.add_argument( - '--pass_num', type=int, default=10, help='num of epoch') + parser.add_argument('--pass_num', type=int, default=10, help='num of epoch') parser.add_argument( '--use_cuda', type=int, default=0, help='whether use gpu') parser.add_argument( @@ -43,9 +46,11 @@ def parse_args(): args = parser.parse_args() return args + def get_cards(args): return args.num_devices + def train(): """ do training """ args = parse_args() @@ -61,23 +66,23 @@ def train(): buffer_size=1000, word_freq_threshold=0, is_train=True) # Train program - src_wordseq, dst_wordseq, avg_cost, acc = net.network(vocab_size=vocab_size, hid_size=hid_size) + src_wordseq, dst_wordseq, avg_cost, acc = net.network( + vocab_size=vocab_size, hid_size=hid_size) # Optimization to minimize lost sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=args.base_lr) sgd_optimizer.minimize(avg_cost) - + # Initialize executor place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) if parallel: train_exe = fluid.ParallelExecutor( - use_cuda=use_cuda, - loss_name=avg_cost.name) + use_cuda=use_cuda, loss_name=avg_cost.name) else: train_exe = exe - + pass_num = args.pass_num model_dir = args.model_dir fetch_list = [avg_cost.name] @@ -96,10 +101,11 @@ def train(): place) lod_dst_wordseq = utils.to_lodtensor([dat[1] for dat in data], place) - ret_avg_cost = train_exe.run(feed={ - "src_wordseq": lod_src_wordseq, - "dst_wordseq": lod_dst_wordseq}, - fetch_list=fetch_list) + ret_avg_cost = train_exe.run(feed={ + "src_wordseq": lod_src_wordseq, + "dst_wordseq": lod_dst_wordseq + }, + fetch_list=fetch_list) avg_ppl = np.exp(ret_avg_cost[0]) newest_ppl = np.mean(avg_ppl) if i % args.print_batch == 0: -- GitLab