提交 4e5c920a 编写于 作者: S suweiyue

1. dataset with neg_type, 2. never ignore edges

上级 08da20a6
......@@ -24,7 +24,7 @@ from pgl.sample import edge_hash
class GraphGenerator(BaseDataGenerator):
def __init__(self, graph_wrappers, data, batch_size, samples,
num_workers, feed_name_list, use_pyreader,
phase, graph_data_path, shuffle=True, buf_size=1000):
phase, graph_data_path, shuffle=True, buf_size=1000, neg_type="batch_neg"):
super(GraphGenerator, self).__init__(
buf_size=buf_size,
......@@ -40,6 +40,7 @@ class GraphGenerator(BaseDataGenerator):
self.phase = phase
self.load_graph(graph_data_path)
self.num_layers = len(graph_wrappers)
self.neg_type= neg_type
def load_graph(self, graph_data_path):
self.graph = pgl.graph.MemmapGraph(graph_data_path)
......@@ -72,7 +73,11 @@ class GraphGenerator(BaseDataGenerator):
batch_src = np.array(batch_src, dtype="int64")
batch_dst = np.array(batch_dst, dtype="int64")
sampled_batch_neg = alias_sample(batch_dst.shape, self.alias, self.events)
if neg_type == "batch_neg":
neg_shape = [1]
else:
neg_shape = batch_dst.shape
sampled_batch_neg = alias_sample(neg_shape, self.alias, self.events)
if len(batch_neg) > 0:
batch_neg = np.concatenate([batch_neg, sampled_batch_neg], 0)
......@@ -80,7 +85,8 @@ class GraphGenerator(BaseDataGenerator):
batch_neg = sampled_batch_neg
if self.phase == "train":
ignore_edges = np.concatenate([np.stack([batch_src, batch_dst], 1), np.stack([batch_dst, batch_src], 1)], 0)
#ignore_edges = np.concatenate([np.stack([batch_src, batch_dst], 1), np.stack([batch_dst, batch_src], 1)], 0)
ignore_edges = set()
else:
ignore_edges = set()
......
......@@ -32,8 +32,9 @@ class TrainData(object):
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
log.info("trainer_id: %s, trainer_count: %s." % (trainer_id, trainer_count))
edges = np.load(os.path.join(graph_path, "edges.npy"), allow_pickle=True)
bidirectional_edges = np.load(os.path.join(graph_path, "edges.npy"), allow_pickle=True)
# edges is bidirectional.
edges = bidirectional_edges[0::2]
train_usr = edges[trainer_id::trainer_count, 0]
train_ad = edges[trainer_id::trainer_count, 1]
returns = {
......@@ -73,7 +74,8 @@ def main(config):
use_pyreader=config.use_pyreader,
phase="train",
graph_data_path=config.graph_path,
shuffle=True)
shuffle=True,
neg_type=config.neg_type)
log.info("build graph reader done.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册