Commit 333be4e6 authored by JiabinYang

merge hs

Parent d27d28f8
@@ -50,11 +50,12 @@ def skip_gram_word2vec(dict_size,
             sample_weight=sample_weight,
             param_attr=fluid.ParamAttr(name=w_param_name),
             bias_attr=fluid.ParamAttr(name=b_param_name),
-            num_neg_samples=num_neg_samples, is_sparse=is_sparse)
+            num_neg_samples=num_neg_samples,
+            is_sparse=is_sparse)
 
         return cost
 
-    def hsigmoid_layer(input, label, ptable, pcode, non_leaf_num):
+    def hsigmoid_layer(input, label, ptable, pcode, non_leaf_num, is_sparse):
         if non_leaf_num is None:
             non_leaf_num = dict_size
@@ -64,14 +65,16 @@ def skip_gram_word2vec(dict_size,
             non_leaf_num=non_leaf_num,
             ptable=ptable,
             pcode=pcode,
-            is_costum=True)
+            is_costum=True,
+            is_sparse=is_sparse)
 
         return cost
 
     datas = []
     input_word = fluid.layers.data(name="input_word", shape=[1], dtype='int64')
-    predict_word = fluid.layers.data(name='predict_word', shape=[1], dtype='int64')
+    predict_word = fluid.layers.data(
+        name='predict_word', shape=[1], dtype='int64')
     datas.append(input_word)
     datas.append(predict_word)
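For orientation: these hunks thread the new is_sparse flag into the custom hierarchical-sigmoid path, where ptable holds the inner-node ids along a word's path through the tree and pcode holds the 0/1 branch labels. A rough NumPy sketch of the loss such a layer computes (function name, shapes, and the padding convention are illustrative assumptions, not the fluid op's internals):

```python
import numpy as np

def hsigmoid_loss(h, ptable, pcode, node_w):
    """h: (D,) hidden vector; ptable: (L,) inner-node ids along the path;
    pcode: (L,) 0/1 branch labels; node_w: (num_inner_nodes, D) weights."""
    loss = 0.0
    for node, code in zip(ptable, pcode):
        if node < 0:  # assumed padding convention for short paths
            continue
        p = 1.0 / (1.0 + np.exp(-np.dot(node_w[node], h)))  # sigmoid(w . h)
        # every inner node on the path is one binary classification
        loss -= np.log(p) if code == 1 else np.log(1.0 - p)
    return loss

# toy usage: 4-dim hidden vector, a path through inner nodes 0 and 2
h = np.random.randn(4)
node_w = np.random.randn(5, 4)
print(hsigmoid_loss(h, ptable=[0, 2], pcode=[1, 0], node_w=node_w))
```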
@@ -87,10 +90,8 @@ def skip_gram_word2vec(dict_size,
         datas.append(ptable)
         datas.append(pcode)
 
-    py_reader = fluid.layers.create_py_reader_by_data(capacity=64,
-                                                      feed_list=datas,
-                                                      name='py_reader',
-                                                      use_double_buffer=True)
+    py_reader = fluid.layers.create_py_reader_by_data(
+        capacity=64, feed_list=datas, name='py_reader', use_double_buffer=True)
 
     words = fluid.layers.read_file(py_reader)
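create_py_reader_by_data wires the data layers into an asynchronous feed queue that read_file then unpacks inside the program. A minimal usage sketch of the Paddle Fluid 1.x py_reader protocol (the generator is a made-up stand-in for the real word2vec reader):

```python
import paddle

# Hypothetical stand-in for the dataset reader: yields
# (input_word, predict_word) id pairs, one sample per yield.
def fake_reader():
    for i in range(1000):
        yield [i % 100], [(i + 1) % 100]

# `py_reader` is the object returned by create_py_reader_by_data above;
# decorate_paddle_reader attaches the batched Python generator to it.
py_reader.decorate_paddle_reader(paddle.batch(fake_reader, batch_size=64))
```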
@@ -107,7 +108,8 @@ def skip_gram_word2vec(dict_size,
         cost = nce_layer(emb, words[1], embedding_size, dict_size, 5, "uniform",
                          word_frequencys, None)
     if with_hsigmoid:
-        cost = hsigmoid_layer(emb, words[1], words[2], words[3], dict_size)
+        cost = hsigmoid_layer(emb, words[1], words[2], words[3], dict_size,
+                              is_sparse)
 
     avg_cost = fluid.layers.reduce_mean(cost)
...
@@ -39,7 +39,8 @@ logger.setLevel(logging.INFO)
 
 def parse_args():
-    parser = argparse.ArgumentParser(description="PaddlePaddle Word2vec example")
+    parser = argparse.ArgumentParser(
+        description="PaddlePaddle Word2vec example")
     parser.add_argument(
         '--train_data_path',
         type=str,
@@ -87,7 +88,7 @@ def parse_args():
         '--with_nce',
         action='store_true',
         required=False,
-        default=True,
+        default=False,
         help='using negtive sampling, (default: True)')
 
     parser.add_argument(
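The default flip matters because of how argparse booleans behave: action='store_true' can only switch the value on, so with default=True the option could never be disabled from the command line. A quick standalone check:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--with_nce', action='store_true', default=False)

print(parser.parse_args([]).with_nce)               # False
print(parser.parse_args(['--with_nce']).with_nce)   # True
# With default=True both calls would print True: the flag
# would be impossible to turn off.
```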
@@ -165,24 +166,28 @@ def train_loop(args, train_program, reader, py_reader, loss, trainer_id):
                 if batch_id % 10 == 0:
                     logger.info("TRAIN --> pass: {} batch: {} loss: {}".format(
                         pass_id, batch_id, loss_val.mean() / args.batch_size))
-                if batch_id % 1000 == 0 and batch_id != 0:
+                if batch_id % 100 == 0 and batch_id != 0:
                     elapsed = (time.clock() - start)
                     logger.info("Time used: {}".format(elapsed))
                 if batch_id % 1000 == 0 and batch_id != 0:
-                    model_dir = args.model_output_dir + '/batch-' + str(batch_id)
+                    model_dir = args.model_output_dir + '/batch-' + str(
+                        batch_id)
                     if trainer_id == 0:
-                        fluid.io.save_inference_model(model_dir, data_name_list, [loss], exe)
+                        fluid.io.save_inference_model(model_dir, data_name_list,
+                                                      [loss], exe)
                 batch_id += 1
         except fluid.core.EOFException:
             py_reader.reset()
             epoch_end = time.time()
-            print("Epoch: {0}, Train total expend: {1} ".format(pass_id, epoch_end - epoch_start))
+            print("Epoch: {0}, Train total expend: {1} ".format(
+                pass_id, epoch_end - epoch_start))
             model_dir = args.model_output_dir + '/pass-' + str(pass_id)
             if trainer_id == 0:
-                fluid.io.save_inference_model(model_dir, data_name_list, [loss], exe)
+                fluid.io.save_inference_model(model_dir, data_name_list,
                                               [loss], exe)
 
 
 def train():
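The surrounding train_loop follows the standard fluid py_reader pattern: start the reader, run batches until the queue raises fluid.core.EOFException, then reset it so the next pass can start it again. A condensed sketch of one pass (variable names are assumptions):

```python
import time
import paddle.fluid as fluid

def run_one_pass(exe, train_program, py_reader, loss):
    # Start filling the queue, drain it until exhausted, then reset.
    py_reader.start()
    batch_id = 0
    start = time.time()
    try:
        while True:
            loss_val, = exe.run(train_program, fetch_list=[loss])
            batch_id += 1
    except fluid.core.EOFException:
        py_reader.reset()
    return batch_id, time.time() - start
```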
@@ -197,11 +202,15 @@ def train():
     logger.info("dict_size: {}".format(word2vec_reader.dict_size))
 
     loss, py_reader = skip_gram_word2vec(
-        word2vec_reader.dict_size, word2vec_reader.word_frequencys,
-        args.embedding_size, args.max_code_length,
-        args.with_hs, args.with_nce, is_sparse=args.is_sparse)
+        word2vec_reader.dict_size,
+        word2vec_reader.word_frequencys,
+        args.embedding_size,
+        args.max_code_length,
+        args.with_hs,
+        args.with_nce,
+        is_sparse=args.is_sparse)
 
-    optimizer = fluid.optimizer.Adam(learning_rate=1e-3)
+    optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
     optimizer.minimize(loss)
 
     if os.getenv("PADDLE_IS_LOCAL", "1") == "1":
@@ -228,14 +237,20 @@ def train():
         config = fluid.DistributeTranspilerConfig()
         config.slice_var_up = False
         t = fluid.DistributeTranspiler(config=config)
-        t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers, sync_mode=True)
+        t.transpile(
+            trainer_id,
+            pservers=pserver_endpoints,
+            trainers=trainers,
+            sync_mode=True)
         if training_role == "PSERVER":
             logger.info("run pserver")
             prog = t.get_pserver_program(current_endpoint)
-            startup = t.get_startup_program(current_endpoint, pserver_program=prog)
-            with open("pserver.main.proto.{}".format(os.getenv("CUR_PORT")), "w") as f:
+            startup = t.get_startup_program(
+                current_endpoint, pserver_program=prog)
+            with open("pserver.main.proto.{}".format(os.getenv("CUR_PORT")),
+                      "w") as f:
                 f.write(str(prog))
 
             exe = fluid.Executor(fluid.CPUPlace())
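This branch uses the classic fluid DistributeTranspiler flow: transpile the program once, then hand each role (parameter server or trainer) its own program. A condensed sketch of the role split (the env-var name and function boundaries are assumptions mirroring the snippet above):

```python
import os
import paddle.fluid as fluid

def build_distributed(trainer_id, trainers, pserver_endpoints, current_endpoint):
    config = fluid.DistributeTranspilerConfig()
    config.slice_var_up = False
    t = fluid.DistributeTranspiler(config=config)
    t.transpile(
        trainer_id,
        pservers=pserver_endpoints,
        trainers=trainers,
        sync_mode=True)

    if os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER":
        # Parameter server: run its own program, blocking to serve updates.
        prog = t.get_pserver_program(current_endpoint)
        startup = t.get_startup_program(current_endpoint, pserver_program=prog)
        return startup, prog
    # Trainer: the rewritten program that pushes/pulls from the pservers.
    return fluid.default_startup_program(), t.get_trainer_program()
```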
@@ -248,8 +263,9 @@ def train():
         with open("trainer.main.proto.{}".format(trainer_id), "w") as f:
             f.write(str(train_prog))
 
-        train_loop(args, train_prog, word2vec_reader, py_reader, loss, trainer_id)
+        train_loop(args, train_prog, word2vec_reader, py_reader, loss,
+                   trainer_id)
 
 
 if __name__ == '__main__':
     train()
\ No newline at end of file