提交 f89d9893 编写于 作者: Z zhangwenhui03

fix style 2

上级 9677fd6c
...@@ -17,21 +17,24 @@ SEED = 102 ...@@ -17,21 +17,24 @@ SEED = 102
def parse_args(): def parse_args():
parser = argparse.ArgumentParser("gru4rec benchmark.") parser = argparse.ArgumentParser("gru4rec benchmark.")
parser.add_argument( parser.add_argument(
'--train_dir', type=str, default='train_data', help='train file address') '--train_dir',
type=str,
default='train_data',
help='train file address')
parser.add_argument( parser.add_argument(
'--vocab_path', type=str, default='vocab.txt', help='vocab file address') '--vocab_path',
parser.add_argument( type=str,
'--is_local', type=int, default=1, help='whether local') default='vocab.txt',
parser.add_argument( help='vocab file address')
'--hid_size', type=int, default=100, help='hid size') parser.add_argument('--is_local', type=int, default=1, help='whether local')
parser.add_argument('--hid_size', type=int, default=100, help='hid size')
parser.add_argument( parser.add_argument(
'--model_dir', type=str, default='model_recall20', help='model dir') '--model_dir', type=str, default='model_recall20', help='model dir')
parser.add_argument( parser.add_argument(
'--batch_size', type=int, default=5, help='num of batch size') '--batch_size', type=int, default=5, help='num of batch size')
parser.add_argument( parser.add_argument(
'--print_batch', type=int, default=10, help='num of print batch') '--print_batch', type=int, default=10, help='num of print batch')
parser.add_argument( parser.add_argument('--pass_num', type=int, default=10, help='num of epoch')
'--pass_num', type=int, default=10, help='num of epoch')
parser.add_argument( parser.add_argument(
'--use_cuda', type=int, default=0, help='whether use gpu') '--use_cuda', type=int, default=0, help='whether use gpu')
parser.add_argument( parser.add_argument(
...@@ -43,9 +46,11 @@ def parse_args(): ...@@ -43,9 +46,11 @@ def parse_args():
args = parser.parse_args() args = parser.parse_args()
return args return args
def get_cards(args): def get_cards(args):
return args.num_devices return args.num_devices
def train(): def train():
""" do training """ """ do training """
args = parse_args() args = parse_args()
...@@ -61,23 +66,23 @@ def train(): ...@@ -61,23 +66,23 @@ def train():
buffer_size=1000, word_freq_threshold=0, is_train=True) buffer_size=1000, word_freq_threshold=0, is_train=True)
# Train program # Train program
src_wordseq, dst_wordseq, avg_cost, acc = net.network(vocab_size=vocab_size, hid_size=hid_size) src_wordseq, dst_wordseq, avg_cost, acc = net.network(
vocab_size=vocab_size, hid_size=hid_size)
# Optimization to minimize lost # Optimization to minimize lost
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=args.base_lr) sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=args.base_lr)
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
# Initialize executor # Initialize executor
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
if parallel: if parallel:
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda, use_cuda=use_cuda, loss_name=avg_cost.name)
loss_name=avg_cost.name)
else: else:
train_exe = exe train_exe = exe
pass_num = args.pass_num pass_num = args.pass_num
model_dir = args.model_dir model_dir = args.model_dir
fetch_list = [avg_cost.name] fetch_list = [avg_cost.name]
...@@ -96,10 +101,11 @@ def train(): ...@@ -96,10 +101,11 @@ def train():
place) place)
lod_dst_wordseq = utils.to_lodtensor([dat[1] for dat in data], lod_dst_wordseq = utils.to_lodtensor([dat[1] for dat in data],
place) place)
ret_avg_cost = train_exe.run(feed={ ret_avg_cost = train_exe.run(feed={
"src_wordseq": lod_src_wordseq, "src_wordseq": lod_src_wordseq,
"dst_wordseq": lod_dst_wordseq}, "dst_wordseq": lod_dst_wordseq
fetch_list=fetch_list) },
fetch_list=fetch_list)
avg_ppl = np.exp(ret_avg_cost[0]) avg_ppl = np.exp(ret_avg_cost[0])
newest_ppl = np.mean(avg_ppl) newest_ppl = np.mean(avg_ppl)
if i % args.print_batch == 0: if i % args.print_batch == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册