提交 66f866f1 编写于 作者: Q qiaolongfei

add save/load dict_and_embedding for word2vector

上级 b3754c77
import math, os
import numpy
import paddle.v2 as paddle
......@@ -18,6 +19,31 @@ def wordemb(inlayer):
return wordemb
# save and load word dict and embedding table
def save_dict_and_embedding(word_dict, embeddings):
with open("word_dict", "w") as f:
for key in word_dict:
f.write(key + " " + str(word_dict[key]) + "\n")
with open("embedding_table", "w") as f:
for line in embeddings:
f.write(",".join([str(x) for x in line]) + "\n")
def load_dict_and_embedding():
word_dict = dict()
embeddings = []
with open("word_dict", "r") as f:
for line in f:
key, value = line.strip().split(" ")
word_dict[key] = value
with open("embedding_table", "r") as f:
for line in f:
embeddings.append(
numpy.array([float(x) for x in line.strip().split(',')]))
return word_dict, embeddings
def main():
paddle.init(use_gpu=with_gpu, trainer_count=3)
word_dict = paddle.dataset.imikolov.build_dict()
......@@ -76,9 +102,13 @@ def main():
trainer = paddle.trainer.SGD(cost, parameters, adagrad)
trainer.train(
paddle.batch(paddle.dataset.imikolov.train(word_dict, N), 32),
num_passes=100,
num_passes=1,
event_handler=event_handler)
# save word dict and embedding table
embeddings = parameters.get("_proj").reshape(len(word_dict), embsize)
save_dict_and_embedding(word_dict, embeddings)
if __name__ == '__main__':
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册