未验证 提交 e3c3af79 编写于 作者: Z zhang wenhui 提交者: GitHub

Merge pull request #1795 from frankwhzhang/fix_bug

fix vocab path & cluster train network
...@@ -13,22 +13,26 @@ import net ...@@ -13,22 +13,26 @@ import net
SEED = 102 SEED = 102
def parse_args(): def parse_args():
parser = argparse.ArgumentParser("gru4rec benchmark.") parser = argparse.ArgumentParser("gru4rec benchmark.")
parser.add_argument( parser.add_argument(
'--train_dir', type=str, default='train_data', help='train file address') '--train_dir',
parser.add_argument( type=str,
'--vocab_path', type=str, default='vocab.txt', help='vocab file address') default='train_data',
parser.add_argument( help='train file address')
'--is_local', type=int, default=1, help='whether local')
parser.add_argument( parser.add_argument(
'--hid_size', type=int, default=100, help='hid size') '--vocab_path',
type=str,
default='vocab.txt',
help='vocab file address')
parser.add_argument('--is_local', type=int, default=1, help='whether local')
parser.add_argument('--hid_size', type=int, default=100, help='hid size')
parser.add_argument( parser.add_argument(
'--model_dir', type=str, default='model_recall20', help='model dir') '--model_dir', type=str, default='model_recall20', help='model dir')
parser.add_argument( parser.add_argument(
'--batch_size', type=int, default=5, help='num of batch size') '--batch_size', type=int, default=5, help='num of batch size')
parser.add_argument( parser.add_argument('--pass_num', type=int, default=10, help='num of epoch')
'--pass_num', type=int, default=10, help='num of epoch')
parser.add_argument( parser.add_argument(
'--print_batch', type=int, default=10, help='num of print batch') '--print_batch', type=int, default=10, help='num of print batch')
parser.add_argument( parser.add_argument(
...@@ -40,19 +44,33 @@ def parse_args(): ...@@ -40,19 +44,33 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--role', type=str, default='pserver', help='trainer or pserver') '--role', type=str, default='pserver', help='trainer or pserver')
parser.add_argument( parser.add_argument(
'--endpoints', type=str, default='127.0.0.1:6000', help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001') '--endpoints',
parser.add_argument( type=str,
'--current_endpoint', type=str, default='127.0.0.1:6000', help='The current_endpoint') default='127.0.0.1:6000',
parser.add_argument( help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001')
'--trainer_id', type=int, default=0, help='trainer id ,only trainer_id=0 save model') parser.add_argument(
parser.add_argument( '--current_endpoint',
'--trainers', type=int, default=1, help='The num of trianers, (default: 1)') type=str,
default='127.0.0.1:6000',
help='The current_endpoint')
parser.add_argument(
'--trainer_id',
type=int,
default=0,
help='trainer id ,only trainer_id=0 save model')
parser.add_argument(
'--trainers',
type=int,
default=1,
help='The num of trianers, (default: 1)')
args = parser.parse_args() args = parser.parse_args()
return args return args
def get_cards(args): def get_cards(args):
return args.num_devices return args.num_devices
def train(): def train():
""" do training """ """ do training """
args = parse_args() args = parse_args()
...@@ -67,7 +85,8 @@ def train(): ...@@ -67,7 +85,8 @@ def train():
buffer_size=1000, word_freq_threshold=0, is_train=True) buffer_size=1000, word_freq_threshold=0, is_train=True)
# Train program # Train program
src_wordseq, dst_wordseq, avg_cost, acc = net.network(vocab_size=vocab_size, hid_size=hid_size) src_wordseq, dst_wordseq, avg_cost, acc = net.all_vocab_network(
vocab_size=vocab_size, hid_size=hid_size)
# Optimization to minimize lost # Optimization to minimize lost
sgd_optimizer = fluid.optimizer.SGD(learning_rate=args.base_lr) sgd_optimizer = fluid.optimizer.SGD(learning_rate=args.base_lr)
...@@ -97,8 +116,10 @@ def train(): ...@@ -97,8 +116,10 @@ def train():
lod_dst_wordseq = utils.to_lodtensor([dat[1] for dat in data], lod_dst_wordseq = utils.to_lodtensor([dat[1] for dat in data],
place) place)
ret_avg_cost = exe.run(main_program, ret_avg_cost = exe.run(main_program,
feed={ "src_wordseq": lod_src_wordseq, feed={
"dst_wordseq": lod_dst_wordseq}, "src_wordseq": lod_src_wordseq,
"dst_wordseq": lod_dst_wordseq
},
fetch_list=fetch_list) fetch_list=fetch_list)
avg_ppl = np.exp(ret_avg_cost[0]) avg_ppl = np.exp(ret_avg_cost[0])
newest_ppl = np.mean(avg_ppl) newest_ppl = np.mean(avg_ppl)
...@@ -113,7 +134,8 @@ def train(): ...@@ -113,7 +134,8 @@ def train():
feed_var_names = ["src_wordseq", "dst_wordseq"] feed_var_names = ["src_wordseq", "dst_wordseq"]
fetch_vars = [avg_cost, acc] fetch_vars = [avg_cost, acc]
if args.trainer_id == 0: if args.trainer_id == 0:
fluid.io.save_inference_model(save_dir, feed_var_names, fetch_vars, exe) fluid.io.save_inference_model(save_dir, feed_var_names,
fetch_vars, exe)
print("model saved in %s" % save_dir) print("model saved in %s" % save_dir)
print("finish training") print("finish training")
...@@ -123,7 +145,8 @@ def train(): ...@@ -123,7 +145,8 @@ def train():
else: else:
print("run distribute training") print("run distribute training")
t = fluid.DistributeTranspiler() t = fluid.DistributeTranspiler()
t.transpile(args.trainer_id, pservers=args.endpoints, trainers=args.trainers) t.transpile(
args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver": if args.role == "pserver":
print("run psever") print("run psever")
pserver_prog = t.get_pserver_program(args.current_endpoint) pserver_prog = t.get_pserver_program(args.current_endpoint)
...@@ -136,5 +159,6 @@ def train(): ...@@ -136,5 +159,6 @@ def train():
print("run trainer") print("run trainer")
train_loop(t.get_trainer_program()) train_loop(t.get_trainer_program())
if __name__ == "__main__": if __name__ == "__main__":
train() train()
...@@ -11,6 +11,7 @@ import paddle ...@@ -11,6 +11,7 @@ import paddle
import utils import utils
def parse_args(): def parse_args():
parser = argparse.ArgumentParser("gru4rec benchmark.") parser = argparse.ArgumentParser("gru4rec benchmark.")
parser.add_argument( parser.add_argument(
...@@ -22,12 +23,15 @@ def parse_args(): ...@@ -22,12 +23,15 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--model_dir', type=str, default='model_recall20', help='model dir') '--model_dir', type=str, default='model_recall20', help='model dir')
parser.add_argument( parser.add_argument(
'--use_cuda', type=int, default='1', help='whether use cuda') '--use_cuda', type=int, default='0', help='whether use cuda')
parser.add_argument( parser.add_argument(
'--batch_size', type=int, default='5', help='batch_size') '--batch_size', type=int, default='5', help='batch_size')
parser.add_argument(
'--vocab_path', type=str, default='vocab.txt', help='vocab file')
args = parser.parse_args() args = parser.parse_args()
return args return args
def infer(test_reader, use_cuda, model_path): def infer(test_reader, use_cuda, model_path):
""" inference function """ """ inference function """
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
...@@ -72,11 +76,16 @@ if __name__ == "__main__": ...@@ -72,11 +76,16 @@ if __name__ == "__main__":
test_dir = args.test_dir test_dir = args.test_dir
model_dir = args.model_dir model_dir = args.model_dir
batch_size = args.batch_size batch_size = args.batch_size
vocab_path = args.vocab_path
use_cuda = True if args.use_cuda else False use_cuda = True if args.use_cuda else False
print("start index: ", start_index, " last_index:" ,last_index) print("start index: ", start_index, " last_index:", last_index)
vocab_size, test_reader = utils.prepare_data( vocab_size, test_reader = utils.prepare_data(
test_dir, "", batch_size=batch_size, test_dir,
buffer_size=1000, word_freq_threshold=0, is_train=False) vocab_path,
batch_size=batch_size,
buffer_size=1000,
word_freq_threshold=0,
is_train=False)
for epoch in range(start_index, last_index + 1): for epoch in range(start_index, last_index + 1):
epoch_path = model_dir + "/epoch_" + str(epoch) epoch_path = model_dir + "/epoch_" + str(epoch)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册