未验证 提交 96cf0846 编写于 作者: Z zhang wenhui 提交者: GitHub

Merge pull request #1478 from frankwhzhang/gru4rec_mc

fix cluster_train.py style
......@@ -13,7 +13,6 @@ import net
SEED = 102
def parse_args():
parser = argparse.ArgumentParser("gru4rec benchmark.")
parser.add_argument(
......@@ -74,7 +73,6 @@ def train():
sgd_optimizer = fluid.optimizer.SGD(learning_rate=args.base_lr)
sgd_optimizer.minimize(avg_cost)
def train_loop(main_program):
""" train network """
pass_num = args.pass_num
......@@ -122,7 +120,6 @@ def train():
if args.is_local:
print("run local training")
train_loop(fluid.default_main_program())
else:
print("run distribute training")
t = fluid.DistributeTranspiler()
......@@ -138,5 +135,6 @@ def train():
elif args.role == "trainer":
print("run trainer")
train_loop(t.get_trainer_program())
if __name__ == "__main__":
train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册