From 1ff69d401aa3329b3ffac7ae8dc8086cab0ab049 Mon Sep 17 00:00:00 2001 From: Yelrose <270018958@qq.com> Date: Wed, 12 Aug 2020 13:26:42 +0800 Subject: [PATCH] fixed dim inference --- pgl/layers/graph_op.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pgl/layers/graph_op.py b/pgl/layers/graph_op.py index 1528bfe..de63059 100644 --- a/pgl/layers/graph_op.py +++ b/pgl/layers/graph_op.py @@ -86,13 +86,14 @@ def graph_gather(gw, feature, index): A tensor with shape (num_graph, k1, k2, k3, ..., kn, hidden_size) """ shape = L.shape(index) + output_dim = feature.shape[-1] index = index + gw.graph_lod[:-1] index = L.reshape(index, [-1]) feature = L.gather(feature, index, overwrite=False) new_shape = [] for i in range(shape.shape[0]): new_shape.append(shape[i]) - new_shape.append(-1) + new_shape.append(output_dim) feature = L.reshape(feature, new_shape) return feature -- GitLab