提交 1ff69d40 编写于 作者: Y Yelrose

fixed dim inference

上级 2dfbbe82
......@@ -86,13 +86,14 @@ def graph_gather(gw, feature, index):
A tensor with shape (num_graph, k1, k2, k3, ..., kn, hidden_size)
"""
shape = L.shape(index)
output_dim = feature.shape[-1]
index = index + gw.graph_lod[:-1]
index = L.reshape(index, [-1])
feature = L.gather(feature, index, overwrite=False)
new_shape = []
for i in range(shape.shape[0]):
new_shape.append(shape[i])
new_shape.append(-1)
new_shape.append(output_dim)
feature = L.reshape(feature, new_shape)
return feature
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册