提交 570bf814 编写于 作者: L liweibin

release hetergraph

上级 38057f1d
......@@ -21,7 +21,7 @@ import tqdm
import numpy as np
import logging
import random
from pgl.contrib import heter_graph
from pgl import heter_graph
import pickle as pkl
......
......@@ -21,7 +21,7 @@ import logging
import paddle.fluid as fluid
import paddle.fluid.layers as fl
from pgl.contrib import heter_graph_wrapper
from pgl import heter_graph_wrapper
class GATNE(object):
......
......@@ -23,7 +23,7 @@ import tqdm
import time
import logging
import random
from pgl.contrib import heter_graph
from pgl import heter_graph
import pickle as pkl
......@@ -71,8 +71,12 @@ class Dataset(object):
if len(walk) > 1:
self.sentences_count += 1
for word in walk:
self.token_count += 1
word_freq[word] = word_freq.get(word, 0) + 1
if int(word) >= self.config[
'paper_start_index']: # remove paper
continue
else:
self.token_count += 1
word_freq[word] = word_freq.get(word, 0) + 1
wid = 0
logging.info('Read %d sentences.' % self.sentences_count)
......@@ -126,6 +130,10 @@ class Dataset(object):
with open(filename) as reader:
for line in reader:
words = line.strip().split()
words = [
w for w in words
if int(w) < self.config['paper_start_index']
]
if len(words) > 1:
word_ids = [
self.word2id[w] for w in words if w in self.word2id
......
......@@ -42,9 +42,10 @@ data_loader:
walk_path: walks/*
word2id_file: word2id.pkl
batch_size: 32
win_size: 7 # default: 7
win_size: 5 # default: 7
neg_num: 5
min_count: 10
paper_start_index: 1697414
model:
type: SkipgramModel
......
......@@ -28,7 +28,7 @@ import tqdm
import time
import logging
import random
from pgl.contrib import heter_graph
from pgl import heter_graph
from pgl.sample import metapath_randomwalk
from utils import *
......
......@@ -18,4 +18,6 @@ from pgl import layers
from pgl import graph_wrapper
from pgl import graph
from pgl import data_loader
from pgl import heter_graph
from pgl import heter_graph_wrapper
from pgl import contrib
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contrib API: exposes the heter_graph and heter_graph_wrapper modules.
"""
from pgl.contrib import heter_graph
from pgl.contrib import heter_graph_wrapper
......@@ -64,8 +64,8 @@ class HeterGraphWrapper(object):
import paddle.fluid as fluid
import numpy as np
from pgl.contrib import heter_graph
from pgl.contrib import heter_graph_wrapper
from pgl import heter_graph
from pgl import heter_graph_wrapper
num_nodes = 4
node_types = [(0, 'user'), (1, 'item'), (2, 'item'), (3, 'user')]
edges = {
......
......@@ -28,7 +28,7 @@ import pgl.graph as pgraph
import pickle as pkl
from pgl.utils.logger import log
import pgl.graph_kernel as graph_kernel
from pgl.contrib import heter_graph
from pgl import heter_graph
import pgl.redis_graph as rg
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for HeterGraph: per-type node counts, batch iteration and node sampling."""
import time
import unittest
import json
import os
import numpy as np
from pgl.sample import metapath_randomwalk
from pgl.graph import Graph
from pgl import heter_graph
class HeterGraphTest(unittest.TestCase):
    """Tests for node counting, batching and sampling on ``HeterGraph``.

    The fixture graph has 15 nodes: ids 0-3 have type 'c', ids 4-9 have
    type 'p' and ids 10-14 have type 'a'.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared fixture graph once for all tests in the class."""
        np.random.seed(1)

        edges = {}
        # Node 5 (type 'p') has no outgoing 'p2a' edge, so the fixture
        # includes a node without a successor for that edge type.
        edges['c2p'] = [(1, 4), (0, 5), (1, 9), (1, 8), (2, 8), (2, 5), (3, 6),
                        (3, 7), (3, 4), (3, 8)]
        edges['p2c'] = [(v, u) for u, v in edges['c2p']]
        edges['p2a'] = [(4, 10), (4, 11), (4, 12), (4, 14), (4, 13), (6, 12),
                        (6, 11), (6, 14), (7, 12), (7, 11), (8, 14), (9, 10)]
        edges['a2p'] = [(v, u) for u, v in edges['p2a']]

        # (node_id, node_type) pairs in id order.
        type_names = ['c'] * 4 + ['p'] * 6 + ['a'] * 5
        node_types = list(enumerate(type_names))

        cls.graph = heter_graph.HeterGraph(
            num_nodes=len(node_types), edges=edges, node_types=node_types)

    def test_num_nodes_by_type(self):
        """Each node type reports the number of nodes the fixture created."""
        n_types = {'c': 4, 'p': 6, 'a': 5}
        for nt, expected in n_types.items():
            self.assertEqual(self.graph.num_nodes_by_type(nt), expected)

    def test_node_batch_iter(self):
        """Unshuffled batches over type 'p' come out in id order."""
        batch_size = 2
        ground = [[4, 5], [6, 7], [8, 9]]
        batches = [
            list(nodes)
            for nodes in self.graph.node_batch_iter(
                batch_size=batch_size, shuffle=False, n_type='p')
        ]
        # Without this check the test would pass vacuously if the iterator
        # yielded no batches (or fewer than expected).
        self.assertEqual(len(batches), len(ground))
        for nodes, expected in zip(batches, ground):
            self.assertEqual(len(nodes), batch_size)
            self.assertListEqual(nodes, expected)

    def test_sample_nodes(self):
        """Sampled nodes stay inside the requested type (or the whole graph)."""
        sample_num = 10

        # Restricting to type 'p' must only return ids 4-9.
        p_ground = [4, 5, 6, 7, 8, 9]
        nodes = self.graph.sample_nodes(sample_num=sample_num, n_type='p')
        self.assertEqual(len(nodes), sample_num)
        for n in nodes:
            self.assertIn(n, p_ground)

        # n_type=None: any of the 15 node ids is acceptable.
        ground = list(range(15))
        nodes = self.graph.sample_nodes(sample_num=sample_num, n_type=None)
        self.assertEqual(len(nodes), sample_num)
        for n in nodes:
            self.assertIn(n, ground)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for metapath_randomwalk on a small heterogeneous graph."""
import time
import unittest
import json
import os
import numpy as np
from pgl.sample import metapath_randomwalk
from pgl.graph import Graph
from pgl import heter_graph
np.random.seed(1)
class MetapathRandomwalkTest(unittest.TestCase):
    """Checks that metapath random walks follow the expected node types."""

    def setUp(self):
        """Create a fresh 15-node fixture graph before each test.

        Node ids 0-3 have type 'c', 4-9 have type 'p', 10-14 have type 'a'.
        """
        # Node 5 (type 'p') has no outgoing 'p2a' edge, so walks can reach
        # a node without a successor.
        c2p = [(1, 4), (0, 5), (1, 9), (1, 8), (2, 8), (2, 5), (3, 6),
               (3, 7), (3, 4), (3, 8)]
        p2a = [(4, 10), (4, 11), (4, 12), (4, 14), (4, 13), (6, 12),
               (6, 11), (6, 14), (7, 12), (7, 11), (8, 14), (9, 10)]
        edges = {
            'c2p': c2p,
            'p2c': [(dst, src) for src, dst in c2p],
            'p2a': p2a,
            'a2p': [(dst, src) for src, dst in p2a],
        }

        # Alternative edge set kept for speed testing:
        # c2p = [(0, 4), (0, 5), (1, 9), (1, 8), (2, 8), (2, 5), (3, 6),
        #        (3, 7), (3, 4), (3, 8)]
        # p2a = [(4, 10), (4, 11), (4, 12), (4, 14), (5, 13), (6, 13),
        #        (6, 11), (6, 14), (7, 12), (7, 11), (8, 14), (9, 13)]
        # ('p2c' and 'a2p' are the reversed pairs, as above.)

        self.node_types = ['c'] * 4 + ['p'] * 6 + ['a'] * 5
        typed_nodes = list(enumerate(self.node_types))
        self.graph = heter_graph.HeterGraph(
            num_nodes=len(typed_nodes), edges=edges, node_types=typed_nodes)

    def test_metapath_randomwalk(self):
        """Every step of every walk has the node type the metapath dictates."""
        meta_path = 'c2p-p2a-a2p-p2c'
        path = ['c', 'p', 'a', 'p', 'c']
        # The metapath starts and ends on 'c', so the type sequence repeats
        # with a period one shorter than ``path``.
        period = len(path) - 1

        walks = metapath_randomwalk(
            graph=self.graph,
            start_nodes=[0, 1, 2, 3],
            metapath=meta_path,
            walk_length=10)

        self.assertEqual(len(walks), 4)
        for walk in walks:
            for step, node in enumerate(walk):
                self.assertEqual(self.node_types[node], path[step % period])
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册