未验证 提交 eeb267da 作者: Weiyue Su 提交者: GitHub

Merge pull request #79 from WeiyueSu/erniesage

Erniesage
...@@ -31,7 +31,7 @@ final_fc: true ...@@ -31,7 +31,7 @@ final_fc: true
final_l2_norm: true final_l2_norm: true
loss_type: "hinge" loss_type: "hinge"
margin: 0.3 margin: 0.3
neg_type: "random_neg" neg_type: "batch_neg"
# infer config ------ # infer config ------
infer_model: "./output/last" infer_model: "./output/last"
......
...@@ -31,7 +31,7 @@ final_fc: true ...@@ -31,7 +31,7 @@ final_fc: true
final_l2_norm: true final_l2_norm: true
loss_type: "hinge" loss_type: "hinge"
margin: 0.3 margin: 0.3
neg_type: "random_neg" neg_type: "batch_neg"
# infer config ------ # infer config ------
infer_model: "./output/last" infer_model: "./output/last"
......
...@@ -24,7 +24,7 @@ from pgl.sample import edge_hash ...@@ -24,7 +24,7 @@ from pgl.sample import edge_hash
class GraphGenerator(BaseDataGenerator): class GraphGenerator(BaseDataGenerator):
def __init__(self, graph_wrappers, data, batch_size, samples, def __init__(self, graph_wrappers, data, batch_size, samples,
num_workers, feed_name_list, use_pyreader, num_workers, feed_name_list, use_pyreader,
phase, graph_data_path, shuffle=True, buf_size=1000): phase, graph_data_path, shuffle=True, buf_size=1000, neg_type="batch_neg"):
super(GraphGenerator, self).__init__( super(GraphGenerator, self).__init__(
buf_size=buf_size, buf_size=buf_size,
...@@ -40,6 +40,7 @@ class GraphGenerator(BaseDataGenerator): ...@@ -40,6 +40,7 @@ class GraphGenerator(BaseDataGenerator):
self.phase = phase self.phase = phase
self.load_graph(graph_data_path) self.load_graph(graph_data_path)
self.num_layers = len(graph_wrappers) self.num_layers = len(graph_wrappers)
self.neg_type= neg_type
def load_graph(self, graph_data_path): def load_graph(self, graph_data_path):
self.graph = pgl.graph.MemmapGraph(graph_data_path) self.graph = pgl.graph.MemmapGraph(graph_data_path)
...@@ -72,7 +73,11 @@ class GraphGenerator(BaseDataGenerator): ...@@ -72,7 +73,11 @@ class GraphGenerator(BaseDataGenerator):
batch_src = np.array(batch_src, dtype="int64") batch_src = np.array(batch_src, dtype="int64")
batch_dst = np.array(batch_dst, dtype="int64") batch_dst = np.array(batch_dst, dtype="int64")
sampled_batch_neg = alias_sample(batch_dst.shape, self.alias, self.events) if self.neg_type == "batch_neg":
neg_shape = [1]
else:
neg_shape = batch_dst.shape
sampled_batch_neg = alias_sample(neg_shape, self.alias, self.events)
if len(batch_neg) > 0: if len(batch_neg) > 0:
batch_neg = np.concatenate([batch_neg, sampled_batch_neg], 0) batch_neg = np.concatenate([batch_neg, sampled_batch_neg], 0)
...@@ -80,7 +85,8 @@ class GraphGenerator(BaseDataGenerator): ...@@ -80,7 +85,8 @@ class GraphGenerator(BaseDataGenerator):
batch_neg = sampled_batch_neg batch_neg = sampled_batch_neg
if self.phase == "train": if self.phase == "train":
ignore_edges = np.concatenate([np.stack([batch_src, batch_dst], 1), np.stack([batch_dst, batch_src], 1)], 0) #ignore_edges = np.concatenate([np.stack([batch_src, batch_dst], 1), np.stack([batch_dst, batch_src], 1)], 0)
ignore_edges = set()
else: else:
ignore_edges = set() ignore_edges = set()
......
...@@ -191,12 +191,12 @@ def all_gather(X): ...@@ -191,12 +191,12 @@ def all_gather(X):
for i in range(trainer_num): for i in range(trainer_num):
copy_X = X * 1 copy_X = X * 1
copy_X = L.collective._broadcast(copy_X, i, True) copy_X = L.collective._broadcast(copy_X, i, True)
copy_X.stop_gradients=True copy_X.stop_gradient=True
Xs.append(copy_X) Xs.append(copy_X)
if len(Xs) > 1: if len(Xs) > 1:
Xs=L.concat(Xs, 0) Xs=L.concat(Xs, 0)
Xs.stop_gradients=True Xs.stop_gradient=True
else: else:
Xs = Xs[0] Xs = Xs[0]
return Xs return Xs
......
...@@ -27,7 +27,7 @@ class ErnieSageV2(BaseNet): ...@@ -27,7 +27,7 @@ class ErnieSageV2(BaseNet):
src_position_ids = L.expand(src_position_ids, [src_batch, 1, 1]) # [B, slot_seqlen * num_b, 1] src_position_ids = L.expand(src_position_ids, [src_batch, 1, 1]) # [B, slot_seqlen * num_b, 1]
zero = L.fill_constant([1], dtype='int64', value=0) zero = L.fill_constant([1], dtype='int64', value=0)
input_mask = L.cast(L.equal(src_ids, zero), "int32") # assume pad id == 0 [B, slot_seqlen, 1] input_mask = L.cast(L.equal(src_ids, zero), "int32") # assume pad id == 0 [B, slot_seqlen, 1]
src_pad_len = L.reduce_sum(input_mask, 1) # [B, 1, 1] src_pad_len = L.reduce_sum(input_mask, 1, keep_dim=True) # [B, 1, 1]
dst_position_ids = L.reshape( dst_position_ids = L.reshape(
L.range( L.range(
......
...@@ -32,8 +32,9 @@ class TrainData(object): ...@@ -32,8 +32,9 @@ class TrainData(object):
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
log.info("trainer_id: %s, trainer_count: %s." % (trainer_id, trainer_count)) log.info("trainer_id: %s, trainer_count: %s." % (trainer_id, trainer_count))
edges = np.load(os.path.join(graph_path, "edges.npy"), allow_pickle=True) bidirectional_edges = np.load(os.path.join(graph_path, "edges.npy"), allow_pickle=True)
# edges is bidirectional. # edges is bidirectional.
edges = bidirectional_edges[0::2]
train_usr = edges[trainer_id::trainer_count, 0] train_usr = edges[trainer_id::trainer_count, 0]
train_ad = edges[trainer_id::trainer_count, 1] train_ad = edges[trainer_id::trainer_count, 1]
returns = { returns = {
...@@ -73,7 +74,8 @@ def main(config): ...@@ -73,7 +74,8 @@ def main(config):
use_pyreader=config.use_pyreader, use_pyreader=config.use_pyreader,
phase="train", phase="train",
graph_data_path=config.graph_path, graph_data_path=config.graph_path,
shuffle=True) shuffle=True,
neg_type=config.neg_type)
log.info("build graph reader done.") log.info("build graph reader done.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册