未验证 提交 ee0caea5 编写于 作者: H Huang Zhengjie 提交者: GitHub

Merge pull request #118 from Yelrose/master

fixed subgraph edge_feat inheritance
...@@ -589,25 +589,25 @@ class Graph(object): ...@@ -589,25 +589,25 @@ class Graph(object):
if eid is None and edges is None: if eid is None and edges is None:
raise ValueError("Eid and edges can't be None at the same time.") raise ValueError("Eid and edges can't be None at the same time.")
sub_edge_feat = {}
if edges is None: if edges is None:
edges = self._edges[eid] edges = self._edges[eid]
else: else:
edges = np.array(edges, dtype="int64") edges = np.array(edges, dtype="int64")
sub_edges = graph_kernel.map_edges(
np.arange(
len(edges), dtype="int64"), edges, reindex)
sub_edge_feat = {}
if edges is None:
if with_edge_feat: if with_edge_feat:
for key, value in self._edge_feat.items(): for key, value in self._edge_feat.items():
if eid is None: if eid is None:
raise ValueError( raise ValueError(
"Eid can not be None with edge features.") "Eid can not be None with edge features.")
sub_edge_feat[key] = value[eid] sub_edge_feat[key] = value[eid]
else:
sub_edge_feat = edge_feats if edge_feats is not None:
sub_edge_feat.update(edge_feats)
sub_edges = graph_kernel.map_edges(
np.arange(
len(edges), dtype="int64"), edges, reindex)
sub_node_feat = {} sub_node_feat = {}
if with_node_feat: if with_node_feat:
......
...@@ -109,8 +109,6 @@ def graphsage_sample(graph, nodes, samples, ignore_edges=[]): ...@@ -109,8 +109,6 @@ def graphsage_sample(graph, nodes, samples, ignore_edges=[]):
start = time.time() start = time.time()
# Find new nodes # Find new nodes
feed_dict = {}
subgraphs = [] subgraphs = []
for i in range(num_layers): for i in range(num_layers):
subgraphs.append( subgraphs.append(
...@@ -471,7 +469,8 @@ def pinsage_sample(graph, ...@@ -471,7 +469,8 @@ def pinsage_sample(graph,
graph.subgraph( graph.subgraph(
nodes=layer_nodes[0], nodes=layer_nodes[0],
edges=layer_edges[i], edges=layer_edges[i],
edge_feats=edge_feat_dict)) edge_feats=edge_feat_dict,
with_edge_feat=False))
subgraphs[i].node_feat["index"] = np.array( subgraphs[i].node_feat["index"] = np.array(
layer_nodes[0], dtype="int64") layer_nodes[0], dtype="int64")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册