未验证 提交 c2d45238 编写于 作者: Z zhang wenhui 提交者: GitHub

0515 fix ssr (#4630)

* update api 1.8

* fix ssr infer
上级 7616d3bc
......@@ -40,13 +40,13 @@ def model(vocab_size, emb_size, hidden_size):
user_data = fluid.data(
name="user", shape=[None, 1], dtype="int64", lod_level=1)
all_item_data = fluid.data(
name="all_item", shape=[None, vocab_size, 1], dtype="int64")
name="all_item", shape=[None, vocab_size], dtype="int64")
user_emb = fluid.embedding(
input=user_data, size=[vocab_size, emb_size], param_attr="emb.item")
all_item_emb = fluid.embedding(
input=all_item_data, size=[vocab_size, emb_size], param_attr="emb.item")
all_item_emb_re = all_item_emb
all_item_emb_re = fluid.layers.reshape(x=all_item_emb, shape=[-1, emb_size])
user_encoder = net.GrnnEncoder(hidden_size=hidden_size)
user_enc = user_encoder.forward(user_emb)
......@@ -94,7 +94,7 @@ def infer(args, vocab_size, test_reader):
user_data, pos_label = utils.infer_data(data, place)
all_item_numpy = np.tile(
np.arange(vocab_size), len(pos_label)).reshape(
len(pos_label), vocab_size, 1).astype("int64")
len(pos_label), vocab_size).astype("int64")
para = exe.run(copy_program,
feed={
"user": user_data,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册