diff --git a/examples/SAGPool/layers.py b/examples/SAGPool/layers.py index f1f45fb7d0eaf1831e4ae87aca0e286579a25e18..3dfa0822ece9e564adfbc9b15e0a62e1e1f4b08d 100644 --- a/examples/SAGPool/layers.py +++ b/examples/SAGPool/layers.py @@ -53,7 +53,7 @@ def topk_pool(gw, score, graph_id, ratio): index = L.arange(0, gw.num_nodes, dtype="int64") offset = L.gather(graph_lod, graph_id, overwrite=False) - index = (index - temp) + (graph_id * max_num_nodes) + index = (index - offset) + (graph_id * max_num_nodes) index.stop_gradient = True # padding