未验证 提交 b128cb50 编写于 作者: Z Zhong Hui 提交者: GitHub

fix erniesage for pgl 2.0 alpha

fix erniesage for pgl 2.0 alpha 
上级 8668f6c0
...@@ -20,9 +20,9 @@ ...@@ -20,9 +20,9 @@
<img src="https://raw.githubusercontent.com/PaddlePaddle/PGL/main/examples/erniesage/docs/source/_static/ERNIESage_v1_4.png" alt="ERNIESage_v1_4" width="800"> <img src="https://raw.githubusercontent.com/PaddlePaddle/PGL/main/examples/erniesage/docs/source/_static/ERNIESage_v1_4.png" alt="ERNIESage_v1_4" width="800">
## 环境依赖 ## 环境依赖
- paddlepaddle >= 2.0rc - paddlepaddle >= 2.0.0rc1
- pgl >= 2.0 - pgl >= 2.0.0a0
- paddlenlp >= 2.0-beta - paddlenlp >= 2.0.0b0
## 数据准备 ## 数据准备
示例数据```data.txt```中使用了NLPCC2016-DBQA的部分数据,格式为每行"query \t answer"。 示例数据```data.txt```中使用了NLPCC2016-DBQA的部分数据,格式为每行"query \t answer"。
...@@ -36,6 +36,8 @@ NLPCC2016-DBQA 是由国际自然语言处理和中文计算会议 NLPCC 于 201 ...@@ -36,6 +36,8 @@ NLPCC2016-DBQA 是由国际自然语言处理和中文计算会议 NLPCC 于 201
```sh ```sh
# 数据预处理,建图
python ./preprocessing/dump_graph.py --conf ./config/erniesage_link_prediction.yaml
# GPU多卡或单卡模式ErnieSage # GPU多卡或单卡模式ErnieSage
python link_prediction.py --conf ./config/erniesage_link_prediction.yaml python link_prediction.py --conf ./config/erniesage_link_prediction.yaml
# 对图节点的embedding进行预测 # 对图节点的embedding进行预测
...@@ -51,7 +53,7 @@ python link_prediction.py --conf ./config/erniesage_link_prediction.yaml --do_pr ...@@ -51,7 +53,7 @@ python link_prediction.py --conf ./config/erniesage_link_prediction.yaml --do_pr
- samples: 采样邻居数 - samples: 采样邻居数
- model_type: 模型类型,包括ErnieSageV2。 - model_type: 模型类型,包括ErnieSageV2。
- ernie_name: 热启模型类型,支持“ernie”和"ernie_tiny",后者速度更快,指定该参数后会自动从服务器下载预训练模型文件。 - ernie_name: 热启模型类型,支持“ernie”和"ernie_tiny",后者速度更快,指定该参数后会自动从服务器下载预训练模型文件。
- num_layers: 图神经网络层数。 - num_layers: 图神经网络层数。
- hidden_size: 隐藏层大小。 - hidden_size: 隐藏层大小。
- batch_size: 训练时的batchsize。 - batch_size: 训练时的batchsize。
- infer_batch_size: 预测时batchsize。 - infer_batch_size: 预测时batchsize。
# Global Environment Settings # Global Environment Settings
# trainer config ------ # trainer config ------
n_gpu: 2 n_gpu: 2 # delete it, if use cpu to train
seed: 2020 seed: 2020
# ernie_tiny or ernie available
model_name_or_path: "ernie_tiny" task: "link_prediction"
task: "link_predict" model_name_or_path: "ernie-tiny" # ernie-tiny or ernie-1.0 available
learner_type: "gpu" sample_workers: 1
optimizer_type: "adam" optimizer_type: "adam"
lr: 0.00005 lr: 0.00005
batch_size: 32 batch_size: 32
...@@ -17,11 +17,11 @@ save_per_step: 200 ...@@ -17,11 +17,11 @@ save_per_step: 200
output_path: "./output" output_path: "./output"
# data config ------ # data config ------
train_data: "./example_data/link_prediction/graph_data.txt" train_data: "./example_data/graph_data.txt"
graph_data: "./example_data/link_prediction/train_data.txt" graph_data: "./example_data/train_data.txt"
graph_work_path: "./graph" graph_work_path: "./graph_workdir"
sample_workers: 1
input_type: "text" input_type: "text"
encoding: "utf8"
# model config ------ # model config ------
samples: [10] samples: [10]
...@@ -38,4 +38,3 @@ neg_type: "batch_neg" ...@@ -38,4 +38,3 @@ neg_type: "batch_neg"
# infer config ------ # infer config ------
infer_model: "./output/last" infer_model: "./output/last"
infer_batch_size: 128 infer_batch_size: 128
encoding: "utf8"
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from data.dataset import *
from data.graph_reader import *
__all__ = []
__all__ += dataset.__all__
__all__ += graph_reader.__all__
...@@ -13,23 +13,24 @@ ...@@ -13,23 +13,24 @@
# limitations under the License. # limitations under the License.
import os import os
import numpy as np
import numpy as np
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.io import IterableDataset from paddle.io import Dataset, IterableDataset
from paddlenlp.utils.log import logger
import pgl import pgl
from pgl.utils.logger import log from pgl import Graph
from pgl.sample import alias_sample, graphsage_sample from pgl.nn.sampling import graphsage_sample
__all__ = [ __all__ = [
"TrainData", "TrainData",
"PredictData", "PredictData",
"GraphDataset", "batch_fn",
] ]
class TrainData(object): class TrainData(Dataset):
def __init__(self, graph_work_path): def __init__(self, graph_work_path):
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
...@@ -60,7 +61,7 @@ class TrainData(object): ...@@ -60,7 +61,7 @@ class TrainData(object):
return len(self.data["train_data"][0]) return len(self.data["train_data"][0])
class PredictData(object): class PredictData(Dataset):
def __init__(self, num_nodes): def __init__(self, num_nodes):
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1")) trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
...@@ -73,129 +74,43 @@ class PredictData(object): ...@@ -73,129 +74,43 @@ class PredictData(object):
return len(self.data) return len(self.data)
class GraphDataset(IterableDataset): def batch_fn(batch_ex, samples, base_graph, term_ids):
"""load graph, sample, feed as numpy list. batch_src = []
""" batch_dst = []
batch_neg = []
def __init__(self, for batch in batch_ex:
graphs, batch_src.append(batch[0])
data, batch_dst.append(batch[1])
batch_size, if len(batch) == 3: # default neg samples
samples, batch_neg.append(batch[2])
mode,
graph_data_path, batch_src = np.array(batch_src, dtype="int64")
shuffle=True, batch_dst = np.array(batch_dst, dtype="int64")
neg_type="batch_neg"): if len(batch_neg) > 0:
"""[summary] batch_neg = np.unique(np.concatenate(batch_neg))
else:
Args: batch_neg = batch_dst
graphs (GraphTensor List): GraphTensor of each layers.
data (List): train/test source list. Can be edges or nodes. nodes = np.unique(np.concatenate([batch_src, batch_dst, batch_neg], 0))
batch_size (int): the batch size means the edges num to be sampled. subgraphs = graphsage_sample(base_graph, nodes, samples)
samples (List): List of sample number for each layer.
mode (str): train, eval, test subgraph, sample_index, node_index = subgraphs[0]
graph_data_path (str): the real graph object. from_reindex = {int(x): i for i, x in enumerate(sample_index)}
shuffle (bool, optional): shuffle data. Defaults to True.
neg_type (str, optional): negative sample methods. Defaults to "batch_neg". term_ids = term_ids[sample_index].astype(np.int64)
"""
sub_src_idx = pgl.graph_kernel.map_nodes(batch_src, from_reindex)
super(GraphDataset, self).__init__() sub_dst_idx = pgl.graph_kernel.map_nodes(batch_dst, from_reindex)
self.line_examples = data sub_neg_idx = pgl.graph_kernel.map_nodes(batch_neg, from_reindex)
self.graphs = graphs
self.samples = samples user_index = np.array(sub_src_idx, dtype="int64")
self.mode = mode pos_item_index = np.array(sub_dst_idx, dtype="int64")
self.load_graph(graph_data_path) neg_item_index = np.array(sub_neg_idx, dtype="int64")
self.num_layers = len(graphs)
self.neg_type = neg_type user_real_index = np.array(batch_src, dtype="int64")
self.batch_size = batch_size pos_item_real_index = np.array(batch_dst, dtype="int64")
self.shuffle = shuffle
self.num_workers = 1 return np.array([subgraph.num_nodes], dtype="int32"), \
subgraph.edges.astype("int32"), \
def load_graph(self, graph_data_path): term_ids, user_index, pos_item_index, neg_item_index, \
self.graph = pgl.graph.MemmapGraph(graph_data_path) user_real_index, pos_item_real_index
self.alias = np.load(
os.path.join(graph_data_path, "alias.npy"), mmap_mode="r")
self.events = np.load(
os.path.join(graph_data_path, "events.npy"), mmap_mode="r")
self.term_ids = np.load(
os.path.join(graph_data_path, "term_ids.npy"), mmap_mode="r")
def batch_fn(self, batch_ex):
"""Convert a list of examples into the flat list of arrays fed to the model.

Each example is indexed as (src, dst) or (src, dst, neg). The method
collects the ids, samples subgraphs around them with ``graphsage_sample``,
attaches ``index`` and ``term_ids`` node features to the first subgraph,
flattens every layer's graph into numpy arrays and finally adds the
re-indexed and the original src/dst ids.

NOTE(review): indentation is not visible in this view, so the exact
nesting of the statements below is inferred -- confirm against the file.

Args:
batch_ex (list): examples produced by ``to_batch``.

Returns:
list or None: ``list(feed_dict.values())``; ``None`` when the number of
examples differs from ``self.batch_size`` and ``self.mode == "train"``.
"""
# batch_ex = [
# (src, dst, neg),
# (src, dst, neg),
# (src, dst, neg),
# ]
batch_src = []
batch_dst = []
batch_neg = []
for batch in batch_ex:
batch_src.append(batch[0])
batch_dst.append(batch[1])
if len(batch) == 3: # default neg samples
batch_neg.append(batch[2])
# A batch of unexpected size is skipped during training only.
if len(batch_src) != self.batch_size:
if self.mode == "train":
return None #Skip
# Explicit negatives, when present, are merged and de-duplicated.
if len(batch_neg) > 0:
batch_neg = np.unique(np.concatenate(batch_neg))
batch_src = np.array(batch_src, dtype="int64")
batch_dst = np.array(batch_dst, dtype="int64")
# "batch_neg" reuses the destinations of the batch as negatives; any other
# setting draws extra negatives with alias_sample.
if self.neg_type == "batch_neg":
batch_neg = batch_dst
else:
# TODO user define shape of neg_sample
neg_shape = batch_dst.shape
sampled_batch_neg = alias_sample(neg_shape, self.alias, self.events)
batch_neg = np.concatenate([batch_neg, sampled_batch_neg], 0)
# Sample around the unique set of every id that appears in the batch.
nodes = np.unique(np.concatenate([batch_src, batch_dst, batch_neg], 0))
subgraphs = graphsage_sample(self.graph, nodes, self.samples)
# "index" holds the parent-graph id of each subgraph node; it is then used
# to look up the matching rows of self.term_ids.
subgraphs[0].node_feat["index"] = subgraphs[0].reindex_to_parrent_nodes(
subgraphs[0].nodes).astype(np.int64)
subgraphs[0].node_feat["term_ids"] = self.term_ids[subgraphs[
0].node_feat["index"]].astype(np.int64)
feed_dict = {}
# Keys are "<layer>_<attribute name>", one entry per array of each layer.
for i in range(self.num_layers):
numpy_list = self.graphs[i].to_numpy(subgraphs[i])
for j in range(len(numpy_list)):
attr = "{}_{}".format(i, self.graphs[i]._graph_attr_holder[j])
feed_dict[attr] = numpy_list[j]
# only reindex from first subgraph
sub_src_idx = subgraphs[0].reindex_from_parrent_nodes(batch_src)
sub_dst_idx = subgraphs[0].reindex_from_parrent_nodes(batch_dst)
sub_neg_idx = subgraphs[0].reindex_from_parrent_nodes(batch_neg)
feed_dict["user_index"] = np.array(sub_src_idx, dtype="int64")
feed_dict["pos_item_index"] = np.array(sub_dst_idx, dtype="int64")
feed_dict["neg_item_index"] = np.array(sub_neg_idx, dtype="int64")
feed_dict["user_real_index"] = np.array(batch_src, dtype="int64")
feed_dict["pos_item_real_index"] = np.array(batch_dst, dtype="int64")
# The values are returned in insertion order: graph arrays first, then the
# five index arrays above.
return list(feed_dict.values())
def to_batch(self):
    """Yield the examples of ``self.line_examples`` in lists of ``self.batch_size``.

    The visiting order is ``0..len-1``, shuffled in place with
    ``np.random.shuffle`` when ``self.shuffle`` is truthy.

    Outside of ``"train"`` mode the trailing partial batch is yielded as
    well, so that no example is silently dropped. In ``"train"`` mode it is
    still omitted, matching ``batch_fn``, which skips short batches only
    while training.

    Yields:
        list: up to ``self.batch_size`` examples.
    """
    order = np.arange(0, len(self.line_examples))
    if self.shuffle:
        np.random.shuffle(order)

    batch = []
    for idx in order:
        batch.append(self.line_examples[idx])
        if len(batch) == self.batch_size:
            yield batch
            batch = []

    # Previously the leftover examples were discarded in every mode, which
    # lost the tail of the data when predicting or evaluating.
    if batch and self.mode != "train":
        yield batch
def __iter__(self):
    """Iterate over the feed lists built by ``batch_fn`` for every batch.

    Batches for which ``batch_fn`` returns ``None`` (its "skip" signal) are
    not yielded. Any exception raised while producing batches is logged with
    ``log.exception`` and ends the iteration, as before.

    Yields:
        list: the value returned by ``self.batch_fn`` for one batch.
    """
    try:
        for batch in self.to_batch():
            if batch is None:
                continue
            feed = self.batch_fn(batch)
            # The None check used to be applied to the input batch only, so
            # a skipped batch would have been handed to the consumer.
            if feed is None:
                continue
            yield feed
    except Exception as e:
        log.exception(e)
...@@ -14,15 +14,27 @@ ...@@ -14,15 +14,27 @@
import paddle import paddle
import copy import copy
import pgl
from paddle.io import DataLoader from paddle.io import DataLoader
__all__ = ["GraphDataLoader"] __all__ = ["GraphDataLoader"]
class GraphDataLoader(object): class GraphDataLoader(object):
def __init__(self, dataset): def __init__(self,
self.loader = DataLoader(dataset) dataset,
self.graphs = dataset.graphs batch_size=1,
shuffle=True,
num_workers=1,
collate_fn=None,
**kwargs):
self.loader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
collate_fn=collate_fn,
**kwargs)
def __iter__(self): def __iter__(self):
func = self.__callback__() func = self.__callback__()
...@@ -40,18 +52,16 @@ class GraphDataLoader(object): ...@@ -40,18 +52,16 @@ class GraphDataLoader(object):
""" tensor list to ([graph_tensor, graph_tensor, ...], """ tensor list to ([graph_tensor, graph_tensor, ...],
other tensor) other tensor)
""" """
graph_num = 1
start_len = 0 start_len = 0
datas = [] datas = []
graph_list = [] graph_list = []
for i in range(len(tensors)): for graph in range(graph_num):
tensors[i] = paddle.squeeze(tensors[i], axis=0) graph_list.append(
pgl.Graph(
for graph in self.graphs: num_nodes=tensors[start_len],
new_graph = copy.deepcopy(graph) edges=tensors[start_len + 1]))
length = len(new_graph._graph_attr_holder) start_len += 2
graph_tensor_list = tensors[start_len:start_len + length]
start_len += length
graph_list.append(new_graph.from_tensor(graph_tensor_list))
for i in range(start_len, len(tensors)): for i in range(start_len, len(tensors)):
datas.append(tensors[i]) datas.append(tensors[i])
......
...@@ -12,22 +12,22 @@ ...@@ -12,22 +12,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import argparse
import logging
import yaml
import os import os
import io import io
import random import random
import time import time
import numpy as np import argparse
from easydict import EasyDict as edict from functools import partial
import numpy as np
import yaml
import paddle import paddle
from pgl.contrib.imperative.graph_tensor import GraphTensor import pgl
from easydict import EasyDict as edict
from paddlenlp.utils.log import logger
from models import ErnieSageForLinkPrediction from models import ErnieSageForLinkPrediction
from data import GraphDataset, TrainData, PredictData, GraphDataLoader from data import TrainData, PredictData, GraphDataLoader, batch_fn
from paddlenlp.utils.log import logger
def set_seed(config): def set_seed(config):
...@@ -36,22 +36,38 @@ def set_seed(config): ...@@ -36,22 +36,38 @@ def set_seed(config):
paddle.seed(config.seed) paddle.seed(config.seed)
def load_data(graph_data_path):
    """Load the base graph and the term-id table from a graph work directory.

    Args:
        graph_data_path (str): directory passed to ``pgl.Graph.load``; the
            file ``term_ids.npy`` is read from the same directory.

    Returns:
        tuple: ``(base_graph, term_ids)``, i.e. the result of
        ``pgl.Graph.load`` and the array returned by ``np.load``.
    """
    graph = pgl.Graph.load(graph_data_path)
    term_ids_file = os.path.join(graph_data_path, "term_ids.npy")
    # mmap_mode="r" opens the array as a read-only memory map.
    table = np.load(term_ids_file, mmap_mode="r")
    return graph, table
def do_train(config): def do_train(config):
paddle.set_device("gpu" if config.n_gpu else "cpu") paddle.set_device("gpu" if config.n_gpu else "cpu")
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env() paddle.distributed.init_parallel_env()
set_seed(config) set_seed(config)
graphs = [GraphTensor() for x in range(len(config.samples))] base_graph, term_ids = load_data(config.graph_work_path)
collate_fn = partial(
batch_fn,
samples=config.samples,
base_graph=base_graph,
term_ids=term_ids)
mode = 'train' mode = 'train'
data = TrainData(config.graph_work_path) train_ds = TrainData(config.graph_work_path)
model = ErnieSageForLinkPrediction.from_pretrained( model = ErnieSageForLinkPrediction.from_pretrained(
config.model_name_or_path, config=config) config.model_name_or_path, config=config)
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
train_dataset = GraphDataset(graphs, data, config.batch_size, train_loader = GraphDataLoader(
config.samples, mode, config.graph_work_path) train_ds,
graph_loader = GraphDataLoader(train_dataset) batch_size=config.batch_size,
shuffle=True,
num_workers=config.sample_workers,
collate_fn=collate_fn)
optimizer = paddle.optimizer.Adam( optimizer = paddle.optimizer.Adam(
learning_rate=config.lr, parameters=model.parameters()) learning_rate=config.lr, parameters=model.parameters())
...@@ -59,7 +75,7 @@ def do_train(config): ...@@ -59,7 +75,7 @@ def do_train(config):
global_step = 0 global_step = 0
tic_train = time.time() tic_train = time.time()
for epoch in range(config.epoch): for epoch in range(config.epoch):
for step, (graphs, datas) in enumerate(graph_loader()): for step, (graphs, datas) in enumerate(train_loader):
global_step += 1 global_step += 1
loss, outputs = model(graphs, datas) loss, outputs = model(graphs, datas)
if global_step % config.log_per_step == 0: if global_step % config.log_per_step == 0:
...@@ -78,36 +94,47 @@ def do_train(config): ...@@ -78,36 +94,47 @@ def do_train(config):
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
model._layers.save_pretrained(output_dir) model._layers.save_pretrained(output_dir)
if (not config.n_gpu > 1) or paddle.distributed.get_rank() == 0:
output_dir = os.path.join(config.output_path, "last")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model._layers.save_pretrained(output_dir)
def tostr(data_array): def tostr(data_array):
return " ".join(["%.5lf" % d for d in data_array]) return " ".join(["%.5lf" % d for d in data_array])
@paddle.no_grad()
def do_predict(config): def do_predict(config):
paddle.set_device("gpu" if config.n_gpu else "cpu") paddle.set_device("gpu" if config.n_gpu else "cpu")
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env() paddle.distributed.init_parallel_env()
set_seed(config) set_seed(config)
graphs = [GraphTensor() for x in range(len(config.samples))]
mode = 'predict' mode = 'predict'
num_nodes = int( num_nodes = int(
np.load(os.path.join(config.graph_work_path, "num_nodes.npy"))) np.load(os.path.join(config.graph_work_path, "num_nodes.npy")))
data = PredictData(num_nodes)
base_graph, term_ids = load_data(config.graph_work_path)
collate_fn = partial(
batch_fn,
samples=config.samples,
base_graph=base_graph,
term_ids=term_ids)
model = ErnieSageForLinkPrediction.from_pretrained( model = ErnieSageForLinkPrediction.from_pretrained(
config.model_name_or_path, config=config) config.infer_model, config=config)
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
predict_ds = PredictData(num_nodes)
train_dataset = GraphDataset( predict_loader = GraphDataLoader(
graphs, predict_ds,
data, batch_size=config.infer_batch_size,
config.batch_size, shuffle=True,
config.samples, num_workers=config.sample_workers,
mode, collate_fn=collate_fn)
config.graph_work_path,
shuffle=False)
graph_loader = GraphDataLoader(train_dataset)
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
id2str = io.open( id2str = io.open(
...@@ -121,12 +148,12 @@ def do_predict(config): ...@@ -121,12 +148,12 @@ def do_predict(config):
global_step = 0 global_step = 0
epoch = 0 epoch = 0
tic_train = time.time() tic_train = time.time()
for step, (graphs, datas) in enumerate(graph_loader()): model.eval()
for step, (graphs, datas) in enumerate(predict_loader):
global_step += 1 global_step += 1
loss, outputs = model(graphs, datas) loss, outputs = model(graphs, datas)
for user_feat, user_real_index in zip(outputs[0].numpy(), for user_feat, user_real_index in zip(outputs[0].numpy(),
outputs[3].numpy()): outputs[3].numpy()):
# user_feat, user_real_index =
sri = id2str[int(user_real_index)].strip("\n") sri = id2str[int(user_real_index)].strip("\n")
line = "{}\t{}\n".format(sri, tostr(user_feat)) line = "{}\t{}\n".format(sri, tostr(user_feat))
fout.write(line) fout.write(line)
......
...@@ -15,29 +15,88 @@ ...@@ -15,29 +15,88 @@
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from pgl.contrib.imperative.message_passing import SageConv
class ErnieSageV2Conv(SageConv): class GraphSageConv(nn.Layer):
""" GraphSAGE is a general inductive framework that leverages node feature
information (e.g., text attributes) to efficiently generate node embeddings
for previously unseen data.
Paper reference:
Hamilton, Will, Zhitao Ying, and Jure Leskovec.
"Inductive representation learning on large graphs."
Advances in neural information processing systems. 2017.
"""
def __init__(self, input_size, hidden_size, learning_rate, aggr_func="sum"):
    """Create the self and neighbour linear layers of the convolution.

    Args:
        input_size (int): first argument of both ``nn.Linear`` layers.
        hidden_size (int): second argument of both ``nn.Linear`` layers.
        learning_rate (float): forwarded as
            ``paddle.ParamAttr(learning_rate=...)`` for both weights.
        aggr_func (str): one of 'sum', 'mean', 'max', 'min'. It is stored
            as ``"reduce_<aggr_func>"`` in ``self.aggr_func``.
    """
    super(GraphSageConv, self).__init__()
    supported = ["sum", "mean", "max", "min"]
    assert aggr_func in supported, \
        "Only support 'sum', 'mean', 'max', 'min' built-in receive function."
    self.aggr_func = "reduce_%s" % aggr_func

    def _projection():
        # Each layer gets its own ParamAttr instance.
        return nn.Linear(
            input_size,
            hidden_size,
            weight_attr=paddle.ParamAttr(learning_rate=learning_rate))

    self.self_linear = _projection()
    self.neigh_linear = _projection()
def forward(self, graph, feature, act=None):
    """Apply the convolution to ``feature`` over ``graph``.

    Source features are sent as messages, reduced with the method named by
    ``self.aggr_func``, and the projected self and neighbour features are
    summed. The sum is optionally activated and then normalized on axis 1.

    Args:
        graph: object providing ``send`` and ``recv``.
        feature (Tensor): node features.
        act (str, optional): name of a function in ``paddle.nn.functional``;
            ``None`` skips the activation. Defaults to None.

    Returns:
        Tensor: the normalized output features.
    """

    def _copy_source(src_feat, dst_feat, edge_feat):
        return {"msg": src_feat["h"]}

    def _aggregate(message):
        reducer = getattr(message, self.aggr_func)
        return reducer(message["msg"])

    sent = graph.send(_copy_source, src_feat={"h": feature})
    gathered = graph.recv(reduce_func=_aggregate, msg=sent)

    output = self.self_linear(feature) + self.neigh_linear(gathered)
    if act is not None:
        activation = getattr(F, act)
        output = activation(output)
    return F.normalize(output, axis=1)
class ErnieSageV2Conv(nn.Layer):
""" ErnieSage (abbreviation of ERNIE SAmple aggreGatE), a model proposed by the PGL team. """ ErnieSage (abbreviation of ERNIE SAmple aggreGatE), a model proposed by the PGL team.
ErnieSageV2: Ernie is applied to the EDGE of the text graph. ErnieSageV2: Ernie is applied to the EDGE of the text graph.
""" """
def __init__(self, ernie, input_size, hidden_size, initializer, def __init__(self,
learning_rate, agg, name): ernie,
input_size,
hidden_size,
learning_rate,
aggr_func='sum'):
"""ErnieSageV2: Ernie is applied to the EDGE of the text graph. """ErnieSageV2: Ernie is applied to the EDGE of the text graph.
Args: Args:
ernie (nn.Layer): the ernie model. ernie (nn.Layer): the ernie model.
input_size (int): input size of feature tensor. input_size (int): input size of feature tensor.
hidden_size (int): hidden size of the Conv layers. hidden_size (int): hidden size of the Conv layers.
initializer (initializer): parameters initializer.
learning_rate (float): learning rate. learning_rate (float): learning rate.
agg (str): aggregate function. 'sum', 'mean', 'max' avaliable. aggr_func (str): aggregate function. 'sum', 'mean', 'max' avaliable.
name (str): layer name.
""" """
super(ErnieSageV2Conv, self).__init__( super(ErnieSageV2Conv, self).__init__()
input_size, hidden_size, initializer, learning_rate, "sum", name) assert aggr_func in ["sum", "mean", "max", "min"], \
"Only support 'sum', 'mean', 'max', 'min' built-in receive function."
self.aggr_func = "reduce_%s" % aggr_func
self.self_linear = nn.Linear(
input_size,
hidden_size,
weight_attr=paddle.ParamAttr(learning_rate=learning_rate))
self.neigh_linear = nn.Linear(
input_size,
hidden_size,
weight_attr=paddle.ParamAttr(learning_rate=learning_rate))
self.ernie = ernie self.ernie = ernie
def ernie_send(self, src_feat, dst_feat, edge_feat): def ernie_send(self, src_feat, dst_feat, edge_feat):
...@@ -72,20 +131,23 @@ class ErnieSageV2Conv(SageConv): ...@@ -72,20 +131,23 @@ class ErnieSageV2Conv(SageConv):
feature = outputs[1] feature = outputs[1]
return {"msg": feature} return {"msg": feature}
def send_recv(self, graph, feature): def send_recv(self, graph, term_ids):
"""Message Passing of erniesage v2. """Message Passing of erniesage v2.
Args: Args:
graph (GraphTensor): the GraphTensor object. graph (Graph): the Graph object.
feature (Tensor): the node feature tensor. feature (Tensor): the node feature tensor.
Returns: Returns:
Tensor: the self and neighbor feature tensors. Tensor: the self and neighbor feature tensors.
""" """
msg = graph.send(self.ernie_send, nfeat_list=[("term_ids", feature)])
neigh_feature = graph.recv(msg, self.agg_func)
term_ids = feature def _recv_func(message):
return getattr(message, self.aggr_func)(message["msg"])
msg = graph.send(self.ernie_send, node_feat={"term_ids": term_ids})
neigh_feature = graph.recv(reduce_func=_recv_func, msg=msg)
cls = paddle.full( cls = paddle.full(
shape=[term_ids.shape[0], 1], dtype="int64", fill_value=1) shape=[term_ids.shape[0], 1], dtype="int64", fill_value=1)
term_ids = paddle.concat([cls, term_ids], 1) term_ids = paddle.concat([cls, term_ids], 1)
...@@ -94,3 +156,24 @@ class ErnieSageV2Conv(SageConv): ...@@ -94,3 +156,24 @@ class ErnieSageV2Conv(SageConv):
self_feature = outputs[1] self_feature = outputs[1]
return self_feature, neigh_feature return self_feature, neigh_feature
def forward(self, graph, term_ids, act='relu'):
    """Forward function of the Conv layer.

    Args:
        graph (Graph): Graph object.
        term_ids (Tensor): term ids of the nodes, handed to ``send_recv``.
        act (str, optional): name of a function in ``paddle.nn.functional``;
            ``None`` skips the activation. Defaults to 'relu'.

    Returns:
        Tensor: feature after conv, normalized on axis 1.
    """
    own, neighbours = self.send_recv(graph, term_ids)
    combined = self.self_linear(own) + self.neigh_linear(neighbours)
    if act is not None:
        activation = getattr(F, act)
        combined = activation(combined)
    return F.normalize(combined, axis=1)
...@@ -12,14 +12,12 @@ ...@@ -12,14 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import pgl
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
import numpy as np import numpy as np
from pgl.contrib.imperative.message_passing import GraphSageConv
from models.conv import ErnieSageV2Conv from models.conv import GraphSageConv, ErnieSageV2Conv
class Encoder(nn.Layer): class Encoder(nn.Layer):
...@@ -80,27 +78,22 @@ class ErnieSageV2Encoder(Encoder): ...@@ -80,27 +78,22 @@ class ErnieSageV2Encoder(Encoder):
ernie, ernie,
ernie.config["hidden_size"], ernie.config["hidden_size"],
self.config.hidden_size, self.config.hidden_size,
initializer,
learning_rate=fc_lr, learning_rate=fc_lr,
agg="sum", aggr_func="sum")
name="ErnieSageV2Conv_0")
self.convs.append(erniesage_conv) self.convs.append(erniesage_conv)
for i in range(1, self.config.num_layers): for i in range(1, self.config.num_layers):
layer = GraphSageConv( layer = GraphSageConv(
self.config.hidden_size * 2,
self.config.hidden_size, self.config.hidden_size,
initializer, self.config.hidden_size,
learning_rate=fc_lr, learning_rate=fc_lr,
agg="sum", aggr_func="sum")
name="%s_%s" % ("GraphSageConv_", i))
self.convs.append(layer) self.convs.append(layer)
if self.config.final_fc: if self.config.final_fc:
self.linear = nn.Linear( self.linear = nn.Linear(
self.config.hidden_size * 2,
self.config.hidden_size, self.config.hidden_size,
weight_attr=paddle.ParamAttr(name="final_fc" + '_w'), self.config.hidden_size,
bias_attr=paddle.ParamAttr(name="final_fc" + '_b')) weight_attr=paddle.ParamAttr(learning_rate=fc_lr))
def take_final_feature(self, feature, index): def take_final_feature(self, feature, index):
"""Gather the final feature. """Gather the final feature.
...@@ -119,17 +112,20 @@ class ErnieSageV2Encoder(Encoder): ...@@ -119,17 +112,20 @@ class ErnieSageV2Encoder(Encoder):
feat = F.normalize(feat, axis=1) feat = F.normalize(feat, axis=1)
return feat return feat
def forward(self, graphs, inputs): def forward(self, graphs, term_ids, inputs):
""" forward train function of the model. """ forward train function of the model.
Args: Args:
graphs (GraphTensor List): list of graph tensors. graphs (Graph List): list of graph tensors.
inputs (Tensor List): list of input tensors. inputs (Tensor List): list of input tensors.
Returns: Returns:
Tensor List: list of final feature tensors. Tensor List: list of final feature tensors.
""" """
feature = graphs[0].node_feat["term_ids"] # term_ids for ErnieSageConv is the raw feature.
feature = term_ids
for i in range(len(graphs), self.config.num_layers):
graphs.append(graphs[0])
for i in range(0, self.config.num_layers): for i in range(0, self.config.num_layers):
if i == self.config.num_layers - 1 and i != 0: if i == self.config.num_layers - 1 and i != 0:
act = None act = None
......
...@@ -12,12 +12,10 @@ ...@@ -12,12 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
import pgl import pgl
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
from pgl.contrib.imperative.graph_tensor import GraphTensor import numpy as np
from paddlenlp.transformers import ErniePretrainedModel from paddlenlp.transformers import ErniePretrainedModel
from models.encoder import Encoder from models.encoder import Encoder
...@@ -52,15 +50,15 @@ class ErnieSageForLinkPrediction(ErniePretrainedModel): ...@@ -52,15 +50,15 @@ class ErnieSageForLinkPrediction(ErniePretrainedModel):
"""Forward function of link prediction task. """Forward function of link prediction task.
Args: Args:
graphs (GraphTensor List): the GraphTensor list. graphs (Graph List): the Graph list.
datas (Tensor List): other input of the model. datas (Tensor List): other input of the model.
Returns: Returns:
Tensor: loss and output tensors. Tensor: loss and output tensors.
""" """
user_index, pos_item_index, neg_item_index, user_real_index, pos_item_real_index = datas term_ids, user_index, pos_item_index, neg_item_index, user_real_index, pos_item_real_index = datas
# encoder model # encoder model
outputs = self.encoder(graphs, outputs = self.encoder(graphs, term_ids,
[user_index, pos_item_index, neg_item_index]) [user_index, pos_item_index, neg_item_index])
user_feat, pos_item_feat, neg_item_feat = outputs user_feat, pos_item_feat, neg_item_feat = outputs
......
...@@ -22,11 +22,12 @@ from functools import partial ...@@ -22,11 +22,12 @@ from functools import partial
from io import open from io import open
import numpy as np import numpy as np
import yaml
import tqdm import tqdm
from easydict import EasyDict as edict
import pgl import pgl
from pgl.graph_kernel import alias_sample_build_table from pgl.graph_kernel import alias_sample_build_table
from pgl.utils.logger import log from pgl.utils.logger import log
from paddlenlp.transformers import ErnieTinyTokenizer from paddlenlp.transformers import ErnieTinyTokenizer
...@@ -35,20 +36,20 @@ def term2id(string, tokenizer, max_seqlen): ...@@ -35,20 +36,20 @@ def term2id(string, tokenizer, max_seqlen):
tokens = tokenizer(string) tokens = tokenizer(string)
ids = tokenizer.convert_tokens_to_ids(tokens) ids = tokenizer.convert_tokens_to_ids(tokens)
ids = ids[:max_seqlen - 1] ids = ids[:max_seqlen - 1]
ids = ids + [2] # ids + [sep] ids = ids + [tokenizer.sep_token_id]
ids = ids + [0] * (max_seqlen - len(ids)) ids = ids + [tokenizer.pad_token_id] * (max_seqlen - len(ids))
return ids return ids
def load_graph(args, str2id, term_file, terms, item_distribution): def load_graph(config, str2id, term_file, terms, item_distribution):
edges = [] edges = []
with io.open(args.graphpath, encoding=args.encoding) as f: with io.open(config.graph_data, encoding=config.encoding) as f:
for idx, line in enumerate(f): for idx, line in enumerate(f):
if idx % 100000 == 0: if idx % 100000 == 0:
log.info("%s readed %s lines" % (args.graphpath, idx)) log.info("%s readed %s lines" % (config.graph_data, idx))
slots = [] slots = []
for col_idx, col in enumerate(line.strip("\n").split("\t")): for col_idx, col in enumerate(line.strip("\n").split("\t")):
s = col[:args.max_seqlen] s = col[:config.max_seqlen]
if s not in str2id: if s not in str2id:
str2id[s] = len(str2id) str2id[s] = len(str2id)
term_file.write(str(col_idx) + "\t" + col + "\n") term_file.write(str(col_idx) + "\t" + col + "\n")
...@@ -64,17 +65,17 @@ def load_graph(args, str2id, term_file, terms, item_distribution): ...@@ -64,17 +65,17 @@ def load_graph(args, str2id, term_file, terms, item_distribution):
return edges return edges
def load_link_prediction_train_data(args, str2id, term_file, terms, def load_link_prediction_train_data(config, str2id, term_file, terms,
item_distribution): item_distribution):
train_data = [] train_data = []
neg_samples = [] neg_samples = []
with io.open(args.inpath, encoding=args.encoding) as f: with io.open(config.train_data, encoding=config.encoding) as f:
for idx, line in enumerate(f): for idx, line in enumerate(f):
if idx % 100000 == 0: if idx % 100000 == 0:
log.info("%s readed %s lines" % (args.inpath, idx)) log.info("%s readed %s lines" % (config.train_data, idx))
slots = [] slots = []
for col_idx, col in enumerate(line.strip("\n").split("\t")): for col_idx, col in enumerate(line.strip("\n").split("\t")):
s = col[:args.max_seqlen] s = col[:config.max_seqlen]
if s not in str2id: if s not in str2id:
str2id[s] = len(str2id) str2id[s] = len(str2id)
term_file.write(str(col_idx) + "\t" + col + "\n") term_file.write(str(col_idx) + "\t" + col + "\n")
...@@ -86,25 +87,27 @@ def load_link_prediction_train_data(args, str2id, term_file, terms, ...@@ -86,25 +87,27 @@ def load_link_prediction_train_data(args, str2id, term_file, terms,
neg_samples.append(slots[2:]) neg_samples.append(slots[2:])
train_data.append((src, dst)) train_data.append((src, dst))
train_data = np.array(train_data, dtype="int64") train_data = np.array(train_data, dtype="int64")
np.save(os.path.join(args.outpath, "train_data.npy"), train_data) np.save(os.path.join(config.graph_work_path, "train_data.npy"), train_data)
if len(neg_samples) != 0: if len(neg_samples) != 0:
np.save( np.save(
os.path.join(args.outpath, "neg_samples.npy"), os.path.join(config.graph_work_path, "neg_samples.npy"),
np.array(neg_samples)) np.array(neg_samples))
def dump_graph(args): def dump_graph(config):
if not os.path.exists(args.outpath): if not os.path.exists(config.graph_work_path):
os.makedirs(args.outpath) os.makedirs(config.graph_work_path)
str2id = dict() str2id = dict()
term_file = io.open( term_file = io.open(
os.path.join(args.outpath, "terms.txt"), "w", encoding=args.encoding) os.path.join(config.graph_work_path, "terms.txt"),
"w",
encoding=config.encoding)
terms = [] terms = []
item_distribution = [] item_distribution = []
edges = load_graph(args, str2id, term_file, terms, item_distribution) edges = load_graph(config, str2id, term_file, terms, item_distribution)
if args.task == "link_prediction": if config.task == "link_prediction":
load_link_prediction_train_data(args, str2id, term_file, terms, load_link_prediction_train_data(config, str2id, term_file, terms,
item_distribution) item_distribution)
else: else:
raise ValueError raise ValueError
...@@ -118,51 +121,42 @@ def dump_graph(args): ...@@ -118,51 +121,42 @@ def dump_graph(args):
indegree = graph.indegree() indegree = graph.indegree()
graph.indegree() graph.indegree()
graph.outdegree() graph.outdegree()
graph.dump(args.outpath) graph.dump(config.graph_work_path)
# dump alias sample table # dump alias sample table
item_distribution = np.array(item_distribution) item_distribution = np.array(item_distribution)
item_distribution = np.sqrt(item_distribution) item_distribution = np.sqrt(item_distribution)
distribution = 1. * item_distribution / item_distribution.sum() distribution = 1. * item_distribution / item_distribution.sum()
alias, events = alias_sample_build_table(distribution) alias, events = alias_sample_build_table(distribution)
np.save(os.path.join(args.outpath, "alias.npy"), alias) np.save(os.path.join(config.graph_work_path, "alias.npy"), alias)
np.save(os.path.join(args.outpath, "events.npy"), events) np.save(os.path.join(config.graph_work_path, "events.npy"), events)
log.info("End Build Graph") log.info("End Build Graph")
def dump_node_feat(config):
    """Tokenize every term in ``terms.txt`` and save the ids as ``term_ids.npy``.

    Reads ``terms.txt`` from ``config.graph_work_path``, takes the last
    tab-separated field of each line as the term text, converts it to a
    fixed-length id sequence with ``term2id`` and writes the resulting array
    to ``term_ids.npy`` in the same directory.

    Args:
        config: Configuration object providing ``graph_work_path``,
            ``encoding``, ``model_name_or_path`` and ``max_seqlen``.
    """
    log.info("Dump node feat starting...")
    terms_path = os.path.join(config.graph_work_path, "terms.txt")
    # Use a context manager so the file handle is closed deterministically
    # instead of being left open until garbage collection.
    with io.open(terms_path, encoding=config.encoding) as term_file:
        id2str = [line.strip("\n").split("\t")[-1] for line in term_file]

    tokenizer = ErnieTinyTokenizer.from_pretrained(config.model_name_or_path)
    fn = partial(term2id, tokenizer=tokenizer, max_seqlen=config.max_seqlen)
    term_ids = [fn(x) for x in id2str]

    # NOTE(review): uint16 storage assumes every token id is < 65536 --
    # confirm against the vocabulary size of the chosen tokenizer.
    np.save(
        os.path.join(config.graph_work_path, "term_ids.npy"),
        np.array(term_ids, np.uint16))
    log.info("Dump node feat done.")
if __name__ == "__main__":
    # Script entry point: load the YAML config named by --conf, then build
    # the graph files and the node features from it.
    parser = argparse.ArgumentParser(description='main')
    parser.add_argument("--conf", type=str, default="./config.yaml")
    args = parser.parse_args()
    # Close the config file explicitly rather than leaking the handle.
    # NOTE(review): yaml.FullLoader is kept for compatibility with existing
    # configs; if the config file can come from an untrusted source,
    # switch to yaml.safe_load.
    with open(args.conf) as conf_file:
        config = edict(yaml.load(conf_file, Loader=yaml.FullLoader))
    dump_graph(config)
    dump_node_feat(config)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册