提交 4e5c920a 编写于 作者: S suweiyue

1. Add a neg_type option to the dataset generator; 2. never ignore edges (the train phase now uses an empty ignore_edges set)

上级 08da20a6
...@@ -24,7 +24,7 @@ from pgl.sample import edge_hash ...@@ -24,7 +24,7 @@ from pgl.sample import edge_hash
class GraphGenerator(BaseDataGenerator): class GraphGenerator(BaseDataGenerator):
def __init__(self, graph_wrappers, data, batch_size, samples, def __init__(self, graph_wrappers, data, batch_size, samples,
num_workers, feed_name_list, use_pyreader, num_workers, feed_name_list, use_pyreader,
phase, graph_data_path, shuffle=True, buf_size=1000): phase, graph_data_path, shuffle=True, buf_size=1000, neg_type="batch_neg"):
super(GraphGenerator, self).__init__( super(GraphGenerator, self).__init__(
buf_size=buf_size, buf_size=buf_size,
...@@ -40,6 +40,7 @@ class GraphGenerator(BaseDataGenerator): ...@@ -40,6 +40,7 @@ class GraphGenerator(BaseDataGenerator):
self.phase = phase self.phase = phase
self.load_graph(graph_data_path) self.load_graph(graph_data_path)
self.num_layers = len(graph_wrappers) self.num_layers = len(graph_wrappers)
self.neg_type= neg_type
def load_graph(self, graph_data_path): def load_graph(self, graph_data_path):
self.graph = pgl.graph.MemmapGraph(graph_data_path) self.graph = pgl.graph.MemmapGraph(graph_data_path)
...@@ -72,7 +73,11 @@ class GraphGenerator(BaseDataGenerator): ...@@ -72,7 +73,11 @@ class GraphGenerator(BaseDataGenerator):
batch_src = np.array(batch_src, dtype="int64") batch_src = np.array(batch_src, dtype="int64")
batch_dst = np.array(batch_dst, dtype="int64") batch_dst = np.array(batch_dst, dtype="int64")
sampled_batch_neg = alias_sample(batch_dst.shape, self.alias, self.events) if neg_type == "batch_neg":
neg_shape = [1]
else:
neg_shape = batch_dst.shape
sampled_batch_neg = alias_sample(neg_shape, self.alias, self.events)
if len(batch_neg) > 0: if len(batch_neg) > 0:
batch_neg = np.concatenate([batch_neg, sampled_batch_neg], 0) batch_neg = np.concatenate([batch_neg, sampled_batch_neg], 0)
...@@ -80,7 +85,8 @@ class GraphGenerator(BaseDataGenerator): ...@@ -80,7 +85,8 @@ class GraphGenerator(BaseDataGenerator):
batch_neg = sampled_batch_neg batch_neg = sampled_batch_neg
if self.phase == "train": if self.phase == "train":
ignore_edges = np.concatenate([np.stack([batch_src, batch_dst], 1), np.stack([batch_dst, batch_src], 1)], 0) #ignore_edges = np.concatenate([np.stack([batch_src, batch_dst], 1), np.stack([batch_dst, batch_src], 1)], 0)
ignore_edges = set()
else: else:
ignore_edges = set() ignore_edges = set()
......
...@@ -32,8 +32,9 @@ class TrainData(object): ...@@ -32,8 +32,9 @@ class TrainData(object):
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
log.info("trainer_id: %s, trainer_count: %s." % (trainer_id, trainer_count)) log.info("trainer_id: %s, trainer_count: %s." % (trainer_id, trainer_count))
edges = np.load(os.path.join(graph_path, "edges.npy"), allow_pickle=True) bidirectional_edges = np.load(os.path.join(graph_path, "edges.npy"), allow_pickle=True)
# edges is bidirectional. # edges is bidirectional.
edges = bidirectional_edges[0::2]
train_usr = edges[trainer_id::trainer_count, 0] train_usr = edges[trainer_id::trainer_count, 0]
train_ad = edges[trainer_id::trainer_count, 1] train_ad = edges[trainer_id::trainer_count, 1]
returns = { returns = {
...@@ -73,7 +74,8 @@ def main(config): ...@@ -73,7 +74,8 @@ def main(config):
use_pyreader=config.use_pyreader, use_pyreader=config.use_pyreader,
phase="train", phase="train",
graph_data_path=config.graph_path, graph_data_path=config.graph_path,
shuffle=True) shuffle=True,
neg_type=config.neg_type)
log.info("build graph reader done.") log.info("build graph reader done.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册