未验证 提交 e3c3af79 编写于 作者: Z zhang wenhui 提交者: GitHub

Merge pull request #1795 from frankwhzhang/fix_bug

fix vocab path & cluster train network
......@@ -13,22 +13,26 @@ import net
SEED = 102
def parse_args():
parser = argparse.ArgumentParser("gru4rec benchmark.")
parser.add_argument(
'--train_dir', type=str, default='train_data', help='train file address')
parser.add_argument(
'--vocab_path', type=str, default='vocab.txt', help='vocab file address')
parser.add_argument(
'--is_local', type=int, default=1, help='whether local')
'--train_dir',
type=str,
default='train_data',
help='train file address')
parser.add_argument(
'--hid_size', type=int, default=100, help='hid size')
'--vocab_path',
type=str,
default='vocab.txt',
help='vocab file address')
parser.add_argument('--is_local', type=int, default=1, help='whether local')
parser.add_argument('--hid_size', type=int, default=100, help='hid size')
parser.add_argument(
'--model_dir', type=str, default='model_recall20', help='model dir')
parser.add_argument(
'--batch_size', type=int, default=5, help='num of batch size')
parser.add_argument(
'--pass_num', type=int, default=10, help='num of epoch')
parser.add_argument('--pass_num', type=int, default=10, help='num of epoch')
parser.add_argument(
'--print_batch', type=int, default=10, help='num of print batch')
parser.add_argument(
......@@ -40,19 +44,33 @@ def parse_args():
parser.add_argument(
'--role', type=str, default='pserver', help='trainer or pserver')
parser.add_argument(
'--endpoints', type=str, default='127.0.0.1:6000', help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001')
parser.add_argument(
'--current_endpoint', type=str, default='127.0.0.1:6000', help='The current_endpoint')
parser.add_argument(
'--trainer_id', type=int, default=0, help='trainer id ,only trainer_id=0 save model')
parser.add_argument(
'--trainers', type=int, default=1, help='The num of trianers, (default: 1)')
'--endpoints',
type=str,
default='127.0.0.1:6000',
help='The pserver endpoints, like: 127.0.0.1:6000, 127.0.0.1:6001')
parser.add_argument(
'--current_endpoint',
type=str,
default='127.0.0.1:6000',
help='The current_endpoint')
parser.add_argument(
'--trainer_id',
type=int,
default=0,
help='trainer id ,only trainer_id=0 save model')
parser.add_argument(
'--trainers',
type=int,
default=1,
help='The num of trianers, (default: 1)')
args = parser.parse_args()
return args
def get_cards(args):
return args.num_devices
def train():
""" do training """
args = parse_args()
......@@ -67,12 +85,13 @@ def train():
buffer_size=1000, word_freq_threshold=0, is_train=True)
# Train program
src_wordseq, dst_wordseq, avg_cost, acc = net.network(vocab_size=vocab_size, hid_size=hid_size)
src_wordseq, dst_wordseq, avg_cost, acc = net.all_vocab_network(
vocab_size=vocab_size, hid_size=hid_size)
# Optimization to minimize lost
sgd_optimizer = fluid.optimizer.SGD(learning_rate=args.base_lr)
sgd_optimizer.minimize(avg_cost)
def train_loop(main_program):
""" train network """
pass_num = args.pass_num
......@@ -97,9 +116,11 @@ def train():
lod_dst_wordseq = utils.to_lodtensor([dat[1] for dat in data],
place)
ret_avg_cost = exe.run(main_program,
feed={ "src_wordseq": lod_src_wordseq,
"dst_wordseq": lod_dst_wordseq},
fetch_list=fetch_list)
feed={
"src_wordseq": lod_src_wordseq,
"dst_wordseq": lod_dst_wordseq
},
fetch_list=fetch_list)
avg_ppl = np.exp(ret_avg_cost[0])
newest_ppl = np.mean(avg_ppl)
if i % args.print_batch == 0:
......@@ -113,7 +134,8 @@ def train():
feed_var_names = ["src_wordseq", "dst_wordseq"]
fetch_vars = [avg_cost, acc]
if args.trainer_id == 0:
fluid.io.save_inference_model(save_dir, feed_var_names, fetch_vars, exe)
fluid.io.save_inference_model(save_dir, feed_var_names,
fetch_vars, exe)
print("model saved in %s" % save_dir)
print("finish training")
......@@ -123,7 +145,8 @@ def train():
else:
print("run distribute training")
t = fluid.DistributeTranspiler()
t.transpile(args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
t.transpile(
args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver":
print("run psever")
pserver_prog = t.get_pserver_program(args.current_endpoint)
......@@ -136,5 +159,6 @@ def train():
print("run trainer")
train_loop(t.get_trainer_program())
if __name__ == "__main__":
train()
......@@ -11,23 +11,27 @@ import paddle
import utils
def parse_args():
parser = argparse.ArgumentParser("gru4rec benchmark.")
parser.add_argument(
'--test_dir', type=str, default='test_data', help='test file address')
parser.add_argument(
'--start_index', type=int, default='1', help='start index')
'--start_index', type=int, default='1', help='start index')
parser.add_argument(
'--last_index', type=int, default='10', help='end index')
parser.add_argument(
'--last_index', type=int, default='10', help='end index')
'--model_dir', type=str, default='model_recall20', help='model dir')
parser.add_argument(
'--model_dir', type=str, default='model_recall20', help='model dir')
'--use_cuda', type=int, default='0', help='whether use cuda')
parser.add_argument(
'--use_cuda', type=int, default='1', help='whether use cuda')
'--batch_size', type=int, default='5', help='batch_size')
parser.add_argument(
'--batch_size', type=int, default='5', help='batch_size')
'--vocab_path', type=str, default='vocab.txt', help='vocab file')
args = parser.parse_args()
return args
def infer(test_reader, use_cuda, model_path):
""" inference function """
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
......@@ -72,11 +76,16 @@ if __name__ == "__main__":
test_dir = args.test_dir
model_dir = args.model_dir
batch_size = args.batch_size
vocab_path = args.vocab_path
use_cuda = True if args.use_cuda else False
print("start index: ", start_index, " last_index:" ,last_index)
print("start index: ", start_index, " last_index:", last_index)
vocab_size, test_reader = utils.prepare_data(
test_dir, "", batch_size=batch_size,
buffer_size=1000, word_freq_threshold=0, is_train=False)
test_dir,
vocab_path,
batch_size=batch_size,
buffer_size=1000,
word_freq_threshold=0,
is_train=False)
for epoch in range(start_index, last_index + 1):
epoch_path = model_dir + "/epoch_" + str(epoch)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册