未验证 提交 eeb267da 编写于 作者: W Weiyue Su 提交者: GitHub

Merge pull request #79 from WeiyueSu/erniesage

Erniesage
......@@ -31,7 +31,7 @@ final_fc: true
final_l2_norm: true
loss_type: "hinge"
margin: 0.3
neg_type: "random_neg"
neg_type: "batch_neg"
# infer config ------
infer_model: "./output/last"
......
......@@ -31,7 +31,7 @@ final_fc: true
final_l2_norm: true
loss_type: "hinge"
margin: 0.3
neg_type: "random_neg"
neg_type: "batch_neg"
# infer config ------
infer_model: "./output/last"
......
......@@ -24,7 +24,7 @@ from pgl.sample import edge_hash
class GraphGenerator(BaseDataGenerator):
def __init__(self, graph_wrappers, data, batch_size, samples,
num_workers, feed_name_list, use_pyreader,
phase, graph_data_path, shuffle=True, buf_size=1000):
phase, graph_data_path, shuffle=True, buf_size=1000, neg_type="batch_neg"):
super(GraphGenerator, self).__init__(
buf_size=buf_size,
......@@ -40,6 +40,7 @@ class GraphGenerator(BaseDataGenerator):
self.phase = phase
self.load_graph(graph_data_path)
self.num_layers = len(graph_wrappers)
self.neg_type= neg_type
def load_graph(self, graph_data_path):
self.graph = pgl.graph.MemmapGraph(graph_data_path)
......@@ -72,7 +73,11 @@ class GraphGenerator(BaseDataGenerator):
batch_src = np.array(batch_src, dtype="int64")
batch_dst = np.array(batch_dst, dtype="int64")
sampled_batch_neg = alias_sample(batch_dst.shape, self.alias, self.events)
if self.neg_type == "batch_neg":
neg_shape = [1]
else:
neg_shape = batch_dst.shape
sampled_batch_neg = alias_sample(neg_shape, self.alias, self.events)
if len(batch_neg) > 0:
batch_neg = np.concatenate([batch_neg, sampled_batch_neg], 0)
......@@ -80,7 +85,8 @@ class GraphGenerator(BaseDataGenerator):
batch_neg = sampled_batch_neg
if self.phase == "train":
ignore_edges = np.concatenate([np.stack([batch_src, batch_dst], 1), np.stack([batch_dst, batch_src], 1)], 0)
#ignore_edges = np.concatenate([np.stack([batch_src, batch_dst], 1), np.stack([batch_dst, batch_src], 1)], 0)
ignore_edges = set()
else:
ignore_edges = set()
......
......@@ -191,12 +191,12 @@ def all_gather(X):
for i in range(trainer_num):
copy_X = X * 1
copy_X = L.collective._broadcast(copy_X, i, True)
copy_X.stop_gradients=True
copy_X.stop_gradient=True
Xs.append(copy_X)
if len(Xs) > 1:
Xs=L.concat(Xs, 0)
Xs.stop_gradients=True
Xs.stop_gradient=True
else:
Xs = Xs[0]
return Xs
......
......@@ -27,7 +27,7 @@ class ErnieSageV2(BaseNet):
src_position_ids = L.expand(src_position_ids, [src_batch, 1, 1]) # [B, slot_seqlen * num_b, 1]
zero = L.fill_constant([1], dtype='int64', value=0)
input_mask = L.cast(L.equal(src_ids, zero), "int32") # assume pad id == 0 [B, slot_seqlen, 1]
src_pad_len = L.reduce_sum(input_mask, 1) # [B, 1, 1]
src_pad_len = L.reduce_sum(input_mask, 1, keep_dim=True) # [B, 1, 1]
dst_position_ids = L.reshape(
L.range(
......
......@@ -32,8 +32,9 @@ class TrainData(object):
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
log.info("trainer_id: %s, trainer_count: %s." % (trainer_id, trainer_count))
edges = np.load(os.path.join(graph_path, "edges.npy"), allow_pickle=True)
bidirectional_edges = np.load(os.path.join(graph_path, "edges.npy"), allow_pickle=True)
# edges is bidirectional.
edges = bidirectional_edges[0::2]
train_usr = edges[trainer_id::trainer_count, 0]
train_ad = edges[trainer_id::trainer_count, 1]
returns = {
......@@ -73,7 +74,8 @@ def main(config):
use_pyreader=config.use_pyreader,
phase="train",
graph_data_path=config.graph_path,
shuffle=True)
shuffle=True,
neg_type=config.neg_type)
log.info("build graph reader done.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册