提交 65b49ba5 编写于 作者: R root

fix criteo

上级 0961379e
...@@ -44,12 +44,13 @@ for ei in range(10000): ...@@ -44,12 +44,13 @@ for ei in range(10000):
feed_dict = {} feed_dict = {}
feed_dict['dense_input'] = np.array(data[0][0]).astype("float32").reshape( feed_dict['dense_input'] = np.array(data[0][0]).astype("float32").reshape(
1, 13) 1, 13)
feed_dict['dense_input.lod'] = [0, 1]
for i in range(1, 27): for i in range(1, 27):
tmp_data = np.array(data[0][i]).astype(np.int64) tmp_data = np.array(data[0][i]).astype(np.int64)
feed_dict["embedding_{}.tmp_0".format(i - 1)] = tmp_data.reshape( feed_dict["embedding_{}.tmp_0".format(i - 1)] = tmp_data.reshape(
(1, len(data[0][i]))) (1, len(data[0][i])))
print(feed_dict) feed_dict["embedding_{}.tmp_0.lod".format(i - 1)] = [0, 1]
fetch_map = client.predict(feed=feed_dict, fetch=["prob"]) fetch_map = client.predict(feed=feed_dict, fetch=["prob"], batch=True)
prob_list.append(fetch_map['prob'][0][1]) prob_list.append(fetch_map['prob'][0][1])
label_list.append(data[0][-1][0]) label_list.append(data[0][-1][0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册