未验证 提交 dbd67cb9 编写于 作者: Z zhang wenhui 提交者: GitHub

fix window numpy int64 bug in ssr/word2vec (#2671)

* fix word2vec Readme

* fix windows numpy int64 bug
上级 d89d8144
...@@ -95,7 +95,7 @@ def infer(args, vocab_size, test_reader): ...@@ -95,7 +95,7 @@ def infer(args, vocab_size, test_reader):
user_data, pos_label = utils.infer_data(data, place) user_data, pos_label = utils.infer_data(data, place)
all_item_numpy = np.tile( all_item_numpy = np.tile(
np.arange(vocab_size), len(pos_label)).reshape( np.arange(vocab_size), len(pos_label)).reshape(
len(pos_label), vocab_size, 1) len(pos_label), vocab_size, 1).astype("int64")
para = exe.run(copy_program, para = exe.run(copy_program,
feed={ feed={
"user": user_data, "user": user_data,
......
...@@ -88,17 +88,17 @@ def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w): ...@@ -88,17 +88,17 @@ def infer_epoch(args, vocab_size, test_reader, use_cuda, i2w):
label = [dat[3] for dat in data] label = [dat[3] for dat in data]
input_word = [dat[4] for dat in data] input_word = [dat[4] for dat in data]
para = exe.run( para = exe.run(copy_program,
copy_program, feed={
feed={ "analogy_a": wa,
"analogy_a": wa, "analogy_b": wb,
"analogy_b": wb, "analogy_c": wc,
"analogy_c": wc, "all_label":
"all_label": np.arange(vocab_size).reshape(
np.arange(vocab_size).reshape(vocab_size, 1), vocab_size, 1).astype("int64"),
}, },
fetch_list=[pred.name, values], fetch_list=[pred.name, values],
return_numpy=False) return_numpy=False)
pre = np.array(para[0]) pre = np.array(para[0])
val = np.array(para[1]) val = np.array(para[1])
for ii in range(len(label)): for ii in range(len(label)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册