Commit 7bec37a3 authored by: liweibin

add gatne and metapath2vec model

Parent: 01412765
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file loads and preprocesses the dataset for the GATNE model.
"""
import sys
import os
import tqdm
import numpy as np
import logging
import random
from pgl.contrib import heter_graph
import pickle as pkl
class Dataset(object):
"""Implementation of Dataset class
This is a simple implementation of loading and processing dataset for GATNE model.
Args:
config: dict, some configure parameters.
"""
def __init__(self, config):
self.train_edges_file = config['data_path'] + 'train.txt'
self.valid_edges_file = config['data_path'] + 'valid.txt'
self.test_edges_file = config['data_path'] + 'test.txt'
self.nodes_file = config['data_path'] + 'nodes.txt'
self.config = config
self.word2index = self.load_word2index()
self.build_graph()
self.valid_data = self.load_test_data(self.valid_edges_file)
self.test_data = self.load_test_data(self.test_edges_file)
def build_graph(self):
"""Build pgl heterogeneous graph.
"""
edge_data_by_type, all_edges, all_nodes = self.load_training_data(
self.train_edges_file,
slf_loop=self.config['slf_loop'],
symmetry_edge=self.config['symmetry_edge'])
num_nodes = len(all_nodes)
node_features = {
'index': np.array(
[i for i in range(num_nodes)], dtype=np.int64).reshape(-1, 1)
}
self.graph = heter_graph.HeterGraph(
num_nodes=num_nodes,
edges=edge_data_by_type,
node_types=None,
node_feat=node_features)
self.edge_types = sorted(self.graph.edge_types_info())
logging.info('total %d nodes are loaded' % (self.graph.num_nodes))
def load_training_data(self, file_, slf_loop=True, symmetry_edge=True):
"""Load train data from file and preprocess them.
Args:
file_: str, file name for loading data
slf_loop: bool, if true, add self loop edge for every node
symmetry_edge: bool, if true, add symmetry edge for every edge
"""
logging.info('loading data from %s' % file_)
edge_data_by_type = dict()
all_edges = list()
all_nodes = list()
with open(file_, 'r') as reader:
for line in reader:
words = line.strip().split(' ')
if words[0] not in edge_data_by_type:
edge_data_by_type[words[0]] = []
src, dst = words[1], words[2]
edge_data_by_type[words[0]].append((src, dst))
all_edges.append((src, dst))
all_nodes.append(src)
all_nodes.append(dst)
if symmetry_edge:
edge_data_by_type[words[0]].append((dst, src))
all_edges.append((dst, src))
all_nodes = list(set(all_nodes))
all_edges = list(set(all_edges))
# edge_data_by_type['Base'] = all_edges
if slf_loop:
for e_type in edge_data_by_type.keys():
for n in all_nodes:
edge_data_by_type[e_type].append((n, n))
# remapping to index
edges_by_type = {}
for edge_type, edges in edge_data_by_type.items():
res_edges = []
for edge in edges:
res_edges.append(
(self.word2index[edge[0]], self.word2index[edge[1]]))
edges_by_type[edge_type] = res_edges
return edges_by_type, all_edges, all_nodes
def load_test_data(self, file_):
"""Load testing data from file.
"""
logging.info('loading data from %s' % file_)
true_edge_data_by_type = {}
fake_edge_data_by_type = {}
with open(file_, 'r') as reader:
for line in reader:
words = line.strip().split(' ')
src, dst = self.word2index[words[1]], self.word2index[words[2]]
e_type = words[0]
if int(words[3]) == 1: # true edges
if e_type not in true_edge_data_by_type:
true_edge_data_by_type[e_type] = list()
true_edge_data_by_type[e_type].append((src, dst))
else: # fake edges
if e_type not in fake_edge_data_by_type:
fake_edge_data_by_type[e_type] = list()
fake_edge_data_by_type[e_type].append((src, dst))
return (true_edge_data_by_type, fake_edge_data_by_type)
def load_word2index(self):
"""Load words(nodes) from file and map to index.
"""
word2index = {}
with open(self.nodes_file, 'r') as reader:
for index, line in enumerate(reader):
node = line.strip()
word2index[node] = index
return word2index
def generate_walks(self):
"""Generate random walks for every edge type.
"""
all_walks = {}
for e_type in self.edge_types:
layer_walks = self.simulate_walks(
edge_type=e_type,
num_walks=self.config['num_walks'],
walk_length=self.config['walk_length'])
all_walks[e_type] = layer_walks
return all_walks
def simulate_walks(self, edge_type, num_walks, walk_length, schema=None):
"""Generate random walks in specified edge type.
"""
walks = []
nodes = list(range(0, self.graph[edge_type].num_nodes))
for walk_iter in tqdm.tqdm(range(num_walks)):
random.shuffle(nodes)
for node in nodes:
walk = self.graph[edge_type].random_walk(
[node], max_depth=walk_length - 1)
for i in range(len(walk)):
walks.append(walk[i])
return walks
def generate_pairs(self, all_walks):
"""Generate word pairs for training.
"""
logging.info(['edge_types before generate pairs', self.edge_types])
pairs = []
skip_window = self.config['win_size'] // 2
for layer_id, e_type in enumerate(self.edge_types):
walks = all_walks[e_type]
for walk in tqdm.tqdm(walks):
for i in range(len(walk)):
for j in range(1, skip_window + 1):
if i - j >= 0 and walk[i] != walk[i - j]:
neg_nodes = self.graph[e_type].sample_nodes(
self.config['neg_num'])
pairs.append(
(walk[i], walk[i - j], *neg_nodes, layer_id))
if i + j < len(walk) and walk[i] != walk[i + j]:
neg_nodes = self.graph[e_type].sample_nodes(
self.config['neg_num'])
pairs.append(
(walk[i], walk[i + j], *neg_nodes, layer_id))
return pairs
def fetch_batch(self, pairs, batch_size, for_test=False):
"""Produce batch pairs data for training.
"""
np.random.shuffle(pairs)
n_batches = (len(pairs) + (batch_size - 1)) // batch_size
neg_num = len(pairs[0]) - 3
result = []
        for i in range(1, n_batches + 1):
batch_pairs = np.array(
pairs[batch_size * (i - 1):batch_size * i], dtype=np.int64)
x = batch_pairs[:, 0].reshape(-1, ).astype(np.int64)
y = batch_pairs[:, 1].reshape(-1, 1, 1).astype(np.int64)
neg = batch_pairs[:, 2:2 + neg_num].reshape(-1, neg_num,
1).astype(np.int64)
t = batch_pairs[:, -1].reshape(-1, 1).astype(np.int64)
result.append((x, y, neg, t))
return result
if __name__ == "__main__":
config = {
'data_path': './data/youtube/',
'train_pairs_file': 'train_pairs.pkl',
'slf_loop': True,
'symmetry_edge': True,
'num_walks': 20,
'walk_length': 10,
'win_size': 5,
'neg_num': 5,
}
log_format = '%(asctime)s-%(levelname)s-%(name)s: %(message)s'
logging.basicConfig(level='INFO', format=log_format)
dataset = Dataset(config)
logging.info('generating walks')
all_walks = dataset.generate_walks()
logging.info('finishing generate walks')
    logging.info(['edge types of all walks: ', list(all_walks.keys())])
train_pairs = dataset.generate_pairs(all_walks)
pkl.dump(train_pairs,
open(config['data_path'] + config['train_pairs_file'], 'wb'))
logging.info('finishing generate train_pairs')
# PGL Examples for GATNE
[GATNE](https://arxiv.org/pdf/1905.01669.pdf) is an algorithmic framework for embedding large-scale Attributed Multiplex Heterogeneous Networks (AMHN). Given a heterogeneous graph that consists of nodes and edges of multiple types, it learns a continuous feature representation for every node. Based on PGL, we reproduce the GATNE algorithm.
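At the core of the GATNE model implemented here (see `forward` in model.py), the embedding of a node under edge type r is its base embedding plus an attention-weighted, linearly transformed combination of its per-edge-type embeddings. Below is a minimal NumPy sketch of this aggregation for a single node; shapes and names are illustrative, not the model's actual API.

```python
import numpy as np

def gatne_node_embedding(base_embed, edge_type_embeds, W_r, w_r, M_r):
    """Toy sketch of the GATNE aggregation for one node and one edge type r.

    base_embed:       (d,)    base node embedding
    edge_type_embeds: (m, s)  one s-dim edge embedding per edge type (m types)
    W_r:              (s, a)  attention projection for edge type r
    w_r:              (a,)    attention vector for edge type r
    M_r:              (s, d)  transformation matrix for edge type r
    """
    scores = np.tanh(edge_type_embeds @ W_r) @ w_r  # (m,) attention logits
    att = np.exp(scores) / np.exp(scores).sum()     # softmax over edge types
    u_r = att @ edge_type_embeds                    # (s,) aggregated edge embedding
    return base_embed + u_r @ M_r                   # (d,) final embedding

# tiny usage example with random values
d, s, a, m = 8, 4, 3, 2
emb = gatne_node_embedding(
    np.random.randn(d), np.random.randn(m, s),
    np.random.randn(s, a), np.random.randn(a), np.random.randn(s, d))
print(emb.shape)  # (8,)
```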
## Datasets
The YouTube dataset contains 2,000 nodes, 1,310,617 edges and 5 edge types. We use it as the example dataset.
You can download the YouTube dataset from [here](https://github.com/THUDM/GATNE/tree/master/data).
After downloading the data, put it in, say, ./data/ (with the current directory being the root directory of the GATNE model). The ./data/youtube/ directory then contains three files (the expected line format is sketched below):
* train.txt
* valid.txt
* test.txt
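Judging from `Dataset.load_training_data` and `load_test_data`, each line of train.txt is expected to hold `edge_type src dst`, while valid.txt and test.txt carry an extra 0/1 label marking fake/true edges. A small parsing sketch under that assumption (the sample line is made up for illustration):

```python
def parse_edge_line(line, with_label=False):
    """Parse one space-separated edge line: edge_type src dst [label]."""
    tokens = line.strip().split(' ')
    edge_type, src, dst = tokens[0], tokens[1], tokens[2]
    label = int(tokens[3]) if with_label else None
    return edge_type, src, dst, label

# a test.txt line such as "1 4367 318 1" would denote a true edge of type "1"
print(parse_edge_line("1 4367 318 1", with_label=True))
```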
Then run the following command to preprocess the data.
```sh
python data_process.py --input_file ./data/youtube/train.txt --output_file ./data/youtube/nodes.txt
```
## Dependencies
- paddlepaddle>=1.6
- pgl>=1.0.0
## Hyperparameters
All the hyperparameters are saved in the config.yaml file, so before training the GATNE model you can open config.yaml and modify them as you like.
For example, you can set "use_cuda" to True to train on GPU, or change "data_path" to use a different dataset.
Some important hyperparameters in config.yaml (see the loading sketch after this list):
- use_cuda: use GPU to train model
- data_path: the directory of dataset
- lr: learning rate
- neg_num: number of negative samples
- num_walks: number of walks started from each node
- walk_length: walk length
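For reference, main.py loads config.yaml through the `Config` helper in utils.py and passes the nested `args` sections to the `Dataset` and `GATNE` constructors. A minimal sketch of that loading step (run from the model's root directory; `yaml.safe_load` is used here in place of the plain `yaml.load` call in utils.py):

```python
import yaml

with open('config.yaml') as f:
    config = yaml.safe_load(f)

print(config['use_cuda'])                          # whether to train on GPU
print(config['data_loader']['args']['data_path'])  # ./data/youtube/
print(config['optimizer']['args']['lr'])           # learning rate
```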
## How to run
Then run the following command:
```sh
python main.py -c config.yaml
```
### Experiment results
| | PGL result | Reported result |
|:---:|------------|-----------------|
| AUC | 84.83 | 84.61 |
| PR | 82.77 | 81.93 |
| F1 | 76.98 | 76.83 |
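After each epoch, `save_model` in main.py dumps a dict that maps every edge type to an array of L2-normalized node embeddings (dict_embed_model_epoch_N.pkl under the checkpoint directory). A hedged sketch of loading such a file and scoring a candidate edge with the same cosine similarity used by `get_score` (the file path and node indices below are illustrative):

```python
import pickle as pkl
import numpy as np

ckpt = './checkpoints/train.gatne/dict_embed_model_epoch_2.pkl'  # adjust to your run
final_model = pkl.load(open(ckpt, 'rb'))

def edge_score(embeds, src, dst):
    """Cosine similarity between two node embeddings, as in main.get_score."""
    v1, v2 = embeds[src], embeds[dst]
    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))

# score a (src, dst) pair under edge type '1'; indices follow nodes.txt order
print(edge_score(final_model['1'], 0, 1))
```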
task_name: train.gatne
use_cuda: True
log_level: info
seed: 1667
optimizer:
type:
args:
lr: 0.005
trainer:
type: trainer
args:
epochs: 2
log_dir: logs/
save_dir: checkpoints/
output_dir: outputs/
data_loader:
type: Dataset
args:
data_path: ./data/youtube/
train_pairs_file: train_pairs.pkl
batch_size: 256
num_walks: 20
walk_length: 10
win_size: 5
neg_num: 5
slf_loop: True
symmetry_edge: True
model:
type: GATNE
args:
dimensions: 200
edge_dim: 32
att_dim: 32
att_head: 1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file preprocesses the data before training.
"""
import sys
import argparse
def gen_nodes_file(file_, result_file):
"""calculate the total number of nodes and save them for latter processing.
"""
nodes = []
with open(file_, 'r') as reader:
for line in reader:
tokens = line.strip().split(' ')
nodes.append(tokens[1])
nodes.append(tokens[2])
nodes = list(set(nodes))
nodes.sort(key=int)
print('total number of nodes: %d' % len(nodes))
print('saving nodes file in %s' % (result_file))
with open(result_file, 'w') as writer:
for n in nodes:
writer.write(n + '\n')
print('finished')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='GATNE')
parser.add_argument(
'--input_file',
default='./data/youtube/train.txt',
type=str,
help='input file')
parser.add_argument(
'--output_file',
default='./data/youtube/nodes.txt',
type=str,
help='output file')
args = parser.parse_args()
print('generating nodes file')
gen_nodes_file(args.input_file, args.output_file)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements the training process of the GATNE model.
"""
import os
import argparse
import time
import numpy as np
import logging
import random
import pickle as pkl
import pgl
from pgl.utils import paddle_helper
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as fl
from utils import *
import Dataset
import model as Model
from sklearn.metrics import (auc, f1_score, precision_recall_curve,
roc_auc_score)
def set_seed(seed):
"""Set random seed.
"""
random.seed(seed)
np.random.seed(seed)
def produce_model(exe, program, dataset, model, feed_dict):
"""Output the learned model parameters for testing.
"""
edge_types = dataset.edge_types
num_nodes = dataset.graph[edge_types[0]].num_nodes
edge_types_count = len(edge_types)
neg_num = dataset.config['neg_num']
final_model = {}
feed_dict['train_inputs'] = np.array(
[n for n in range(num_nodes)], dtype=np.int64).reshape(-1, )
feed_dict['train_labels'] = np.array(
[n for n in range(num_nodes)], dtype=np.int64).reshape(-1, 1, 1)
feed_dict['train_negs'] = np.tile(feed_dict['train_labels'],
(1, neg_num)).reshape(-1, neg_num, 1)
for i in range(edge_types_count):
feed_dict['train_types'] = np.array(
[i for _ in range(num_nodes)], dtype=np.int64).reshape(-1, 1)
edge_node_embed = exe.run(program,
feed=feed_dict,
fetch_list=[model.last_node_embed],
return_numpy=True)[0]
final_model[edge_types[i]] = edge_node_embed
return final_model
def evaluate(final_model, edge_types, data):
"""Calculate the AUC score, F1 score and PR score of the final model
"""
edge_types_count = len(edge_types)
AUC, F1, PR = [], [], []
true_edge_data_by_type = data[0]
fake_edge_data_by_type = data[1]
for i in range(edge_types_count):
try:
local_model = final_model[edge_types[i]]
true_edges = true_edge_data_by_type[edge_types[i]]
fake_edges = fake_edge_data_by_type[edge_types[i]]
except Exception as e:
            logging.warn('edge type does not exist. %s' % str(e))
continue
tmp_auc, tmp_f1, tmp_pr = calculate_score(local_model, true_edges,
fake_edges)
AUC.append(tmp_auc)
F1.append(tmp_f1)
PR.append(tmp_pr)
return {'AUC': np.mean(AUC), 'F1': np.mean(F1), 'PR': np.mean(PR)}
def calculate_score(model, true_edges, fake_edges):
"""Calculate the AUC score, F1 score and PR score of specified edge type
"""
true_list = list()
prediction_list = list()
true_num = 0
for edge in true_edges:
tmp_score = get_score(model, edge)
if tmp_score is not None:
true_list.append(1)
prediction_list.append(tmp_score)
true_num += 1
for edge in fake_edges:
tmp_score = get_score(model, edge)
if tmp_score is not None:
true_list.append(0)
prediction_list.append(tmp_score)
sorted_pred = prediction_list[:]
sorted_pred.sort()
threshold = sorted_pred[-true_num]
y_pred = np.zeros(len(prediction_list), dtype=np.int32)
for i in range(len(prediction_list)):
if prediction_list[i] >= threshold:
y_pred[i] = 1
y_true = np.array(true_list)
y_scores = np.array(prediction_list)
ps, rs, _ = precision_recall_curve(y_true, y_scores)
return roc_auc_score(y_true, y_scores), f1_score(y_true, y_pred), auc(rs,
ps)
def get_score(local_model, edge):
"""Calculate the cosine similarity score between two nodes.
"""
try:
vector1 = local_model[edge[0]]
vector2 = local_model[edge[1]]
return np.dot(vector1, vector2) / (np.linalg.norm(vector1) *
np.linalg.norm(vector2))
except Exception as e:
logging.warn('get_score warning: %s' % str(e))
return None
def run_epoch(epoch,
config,
dataset,
data,
train_prog,
test_prog,
model,
feed_dict,
exe,
for_test=False):
"""Run training process of every epoch.
"""
total_loss = []
for idx, batch_data in enumerate(data):
feed_dict['train_inputs'] = batch_data[0]
feed_dict['train_labels'] = batch_data[1]
feed_dict['train_negs'] = batch_data[2]
feed_dict['train_types'] = batch_data[3]
loss, lr = exe.run(train_prog,
feed=feed_dict,
fetch_list=[model.loss, model.lr],
return_numpy=True)
total_loss.append(loss[0])
if (idx + 1) % 500 == 0:
avg_loss = np.mean(total_loss)
logging.info("epoch %d | step %d | lr %.4f | train_loss %f " %
(epoch, idx + 1, lr, avg_loss))
total_loss = []
return avg_loss
def save_model(program, exe, dataset, model, feed_dict, filename):
"""Save model.
"""
final_model = produce_model(exe, program, dataset, model, feed_dict)
logging.info('saving model in %s' % (filename))
pkl.dump(final_model, open(filename, 'wb'))
def test(program, exe, dataset, model, feed_dict):
"""Testing and validating.
"""
final_model = produce_model(exe, program, dataset, model, feed_dict)
valid_result = evaluate(final_model, dataset.edge_types,
dataset.valid_data)
test_result = evaluate(final_model, dataset.edge_types, dataset.test_data)
logging.info("valid_AUC %.4f | valid_PR %.4f | valid_F1 %.4f" %
(valid_result['AUC'], valid_result['PR'], valid_result['F1']))
logging.info("test_AUC %.4f | test_PR %.4f | test_F1 %.4f" %
(test_result['AUC'], test_result['PR'], test_result['F1']))
return test_result
def main(config):
"""main function for training GATNE model.
"""
logging.info(config)
set_seed(config['seed'])
dataset = getattr(
Dataset, config['data_loader']['type'])(config['data_loader']['args'])
edge_types = dataset.graph.edge_types_info()
logging.info(['total edge types: ', edge_types])
# train_pairs is a list of tuple: [(src1, dst1, neg, e1), (src2, dst2, neg, e2)]
    # e (int): edge type index, used to select the corresponding edge-type embedding
train_pairs_file = config['data_loader']['args']['data_path'] + \
config['data_loader']['args']['train_pairs_file']
if os.path.exists(train_pairs_file):
logging.info('loading train pairs from pkl file %s' % train_pairs_file)
train_pairs = pkl.load(open(train_pairs_file, 'rb'))
else:
logging.info('generating walks')
all_walks = dataset.generate_walks()
logging.info('generating train pairs')
train_pairs = dataset.generate_pairs(all_walks)
logging.info('dumping train pairs to %s' % (train_pairs_file))
pkl.dump(train_pairs, open(train_pairs_file, 'wb'))
logging.info('total train pairs: %d' % (len(train_pairs)))
data = dataset.fetch_batch(train_pairs,
config['data_loader']['args']['batch_size'])
place = fluid.CUDAPlace(0) if config['use_cuda'] else fluid.CPUPlace()
train_program = fluid.Program()
startup_program = fluid.Program()
test_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
model = getattr(Model, config['model']['type'])(
config['model']['args'], dataset, place)
test_program = train_program.clone(for_test=True)
with fluid.program_guard(train_program, startup_program):
global_steps = len(data) * config['trainer']['args']['epochs']
model.backward(global_steps, config['optimizer']['args'])
# train
exe = fluid.Executor(place)
exe.run(startup_program)
feed_dict = model.gw.to_feed(dataset.graph)
logging.info('test before training...')
test(test_program, exe, dataset, model, feed_dict)
logging.info('training...')
for epoch in range(1, 1 + config['trainer']['args']['epochs']):
train_result = run_epoch(epoch, config['trainer']['args'], dataset,
data, train_program, test_program, model,
feed_dict, exe)
logging.info('validating and testing...')
test_result = test(test_program, exe, dataset, model, feed_dict)
filename = os.path.join(config['trainer']['args']['save_dir'],
'dict_embed_model_epoch_%d.pkl' % (epoch))
save_model(test_program, exe, dataset, model, feed_dict, filename)
logging.info(
"final_test_AUC %.4f | final_test_PR %.4f | fianl_test_F1 %.4f" % (
test_result['AUC'], test_result['PR'], test_result['F1']))
logging.info('training finished')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='GATNE')
parser.add_argument(
'-c',
'--config',
default=None,
type=str,
help='config file path (default: None)')
parser.add_argument(
'-n',
'--taskname',
default=None,
type=str,
help='task name(default: None)')
args = parser.parse_args()
if args.config:
# load config file
config = Config(args.config, isCreate=True, isSave=True)
config = config()
else:
raise AssertionError(
"Configuration file need to be specified. Add '-c config.yaml', for example."
)
log_format = '%(asctime)s-%(levelname)s-%(name)s: %(message)s'
logging.basicConfig(
level=getattr(logging, config['log_level'].upper()), format=log_format)
main(config)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements the GATNE model.
"""
import numpy as np
import math
import logging
import paddle.fluid as fluid
import paddle.fluid.layers as fl
from pgl.contrib import heter_graph_wrapper
class GATNE(object):
"""Implementation of GATNE model.
Args:
config: dict, some configure parameters.
dataset: instance of Dataset class
place: GPU or CPU place
"""
def __init__(self, config, dataset, place):
logging.info(['model is: ', self.__class__.__name__])
self.config = config
self.graph = dataset.graph
        self.place = place
self.edge_types = sorted(self.graph.edge_types_info())
logging.info('edge_types in model: %s' % str(self.edge_types))
neg_num = dataset.config['neg_num']
# hyper parameters
self.num_nodes = self.graph.num_nodes
self.embedding_size = self.config['dimensions']
self.embedding_u_size = self.config['edge_dim']
self.dim_a = self.config['att_dim']
self.att_head = self.config['att_head']
self.edge_type_count = len(self.edge_types)
self.u_num = self.edge_type_count
self.gw = heter_graph_wrapper.HeterGraphWrapper(
name="heter_graph",
place=place,
edge_types=self.graph.edge_types_info(),
node_feat=self.graph.node_feat_info(),
edge_feat=self.graph.edge_feat_info())
self.train_inputs = fl.data(
'train_inputs', shape=[None], dtype='int64')
self.train_labels = fl.data(
'train_labels', shape=[None, 1, 1], dtype='int64')
self.train_types = fl.data(
'train_types', shape=[None, 1], dtype='int64')
self.train_negs = fl.data(
'train_negs', shape=[None, neg_num, 1], dtype='int64')
self.forward()
def forward(self):
"""Build the GATNE net.
"""
param_attr_init = fluid.initializer.Uniform(
low=-1.0, high=1.0, seed=np.random.randint(100))
embed_param_attrs = fluid.ParamAttr(
name='Base_node_embed', initializer=param_attr_init)
# node_embeddings
base_node_embed = fl.embedding(
input=fl.reshape(
self.train_inputs, shape=[-1, 1]),
size=[self.num_nodes, self.embedding_size],
param_attr=embed_param_attrs)
node_features = []
for edge_type in self.edge_types:
param_attr_init = fluid.initializer.Uniform(
low=-1.0, high=1.0, seed=np.random.randint(100))
embed_param_attrs = fluid.ParamAttr(
name='%s_node_embed' % edge_type, initializer=param_attr_init)
features = fl.embedding(
input=self.gw[edge_type].node_feat['index'],
size=[self.num_nodes, self.embedding_u_size],
param_attr=embed_param_attrs)
node_features.append(features)
# mp_output: list of embedding(self.num_nodes, dim)
mp_output = self.message_passing(self.gw, self.edge_types,
node_features)
# U : (num_type[m], num_nodes, dim[s])
node_type_embed = fl.stack(mp_output, axis=0)
# U : (num_nodes, num_type[m], dim[s])
node_type_embed = fl.transpose(node_type_embed, perm=[1, 0, 2])
#gather node_type_embed from train_inputs
node_type_embed = fl.gather(node_type_embed, self.train_inputs)
# M_r
trans_weights = fl.create_parameter(
shape=[
self.edge_type_count, self.embedding_u_size,
self.embedding_size // self.att_head
],
attr=fluid.initializer.TruncatedNormalInitializer(
loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)),
dtype='float32',
name='trans_w')
# W_r
trans_weights_s1 = fl.create_parameter(
shape=[self.edge_type_count, self.embedding_u_size, self.dim_a],
attr=fluid.initializer.TruncatedNormalInitializer(
loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)),
dtype='float32',
name='trans_w_s1')
# w_r
trans_weights_s2 = fl.create_parameter(
shape=[self.edge_type_count, self.dim_a, self.att_head],
attr=fluid.initializer.TruncatedNormalInitializer(
loc=0.0, scale=1.0 / math.sqrt(self.embedding_size)),
dtype='float32',
name='trans_w_s2')
trans_w = fl.gather(trans_weights, self.train_types)
trans_w_s1 = fl.gather(trans_weights_s1, self.train_types)
trans_w_s2 = fl.gather(trans_weights_s2, self.train_types)
attention = self.attention(node_type_embed, trans_w_s1, trans_w_s2)
node_type_embed = fl.matmul(attention, node_type_embed)
node_embed = base_node_embed + fl.reshape(
fl.matmul(node_type_embed, trans_w), [-1, self.embedding_size])
self.last_node_embed = fl.l2_normalize(node_embed, axis=1)
nce_weight_initializer = fluid.initializer.TruncatedNormalInitializer(
loc=0.0, scale=1.0 / math.sqrt(self.embedding_size))
nce_weight_attrs = fluid.ParamAttr(
name='nce_weight', initializer=nce_weight_initializer)
weight_pos = fl.embedding(
input=self.train_labels,
size=[self.num_nodes, self.embedding_size],
param_attr=nce_weight_attrs)
weight_neg = fl.embedding(
input=self.train_negs,
size=[self.num_nodes, self.embedding_size],
param_attr=nce_weight_attrs)
tmp_node_embed = fl.unsqueeze(self.last_node_embed, axes=[1])
pos_logits = fl.matmul(
tmp_node_embed, weight_pos, transpose_y=True) # [B, 1, 1]
neg_logits = fl.matmul(
tmp_node_embed, weight_neg, transpose_y=True) # [B, 1, neg_num]
pos_score = fl.squeeze(pos_logits, axes=[1])
pos_score = fl.clip(pos_score, min=-10, max=10)
pos_score = -1.0 * fl.logsigmoid(pos_score)
neg_score = fl.squeeze(neg_logits, axes=[1])
neg_score = fl.clip(neg_score, min=-10, max=10)
neg_score = -1.0 * fl.logsigmoid(-1.0 * neg_score)
neg_score = fl.reduce_sum(neg_score, dim=1, keep_dim=True)
self.loss = fl.reduce_mean(pos_score + neg_score)
def attention(self, node_type_embed, trans_w_s1, trans_w_s2):
"""Calculate attention weights.
"""
attention = fl.tanh(fl.matmul(node_type_embed, trans_w_s1))
attention = fl.matmul(attention, trans_w_s2)
attention = fl.reshape(attention, [-1, self.u_num])
attention = fl.softmax(attention)
attention = fl.reshape(attention, [-1, self.att_head, self.u_num])
return attention
def message_passing(self, gw, edge_types, features, name=''):
"""Message passing from source nodes to dstination nodes
"""
def __message(src_feat, dst_feat, edge_feat):
"""send function
"""
return src_feat['h']
def __reduce(feat):
"""recv function
"""
return fluid.layers.sequence_pool(feat, pool_type='average')
if not isinstance(edge_types, list):
edge_types = [edge_types]
if not isinstance(features, list):
features = [features]
assert len(edge_types) == len(features)
output = []
for i in range(len(edge_types)):
msg = gw[edge_types[i]].send(
__message, nfeat_list=[('h', features[i])])
neigh_feat = gw[edge_types[i]].recv(msg, __reduce)
neigh_feat = fl.fc(neigh_feat,
size=neigh_feat.shape[-1],
name='neigh_fc_%d' % (i),
act='sigmoid')
slf_feat = fl.fc(features[i],
size=neigh_feat.shape[-1],
name='slf_fc_%d' % (i),
act='sigmoid')
out = fluid.layers.concat([slf_feat, neigh_feat], axis=1)
out = fl.fc(out, size=neigh_feat.shape[-1], name='fc', act=None)
out = fluid.layers.l2_normalize(out, axis=1)
output.append(out)
# list of matrix
return output
def backward(self, global_steps, opt_config):
"""Build the optimizer.
"""
self.lr = fl.polynomial_decay(opt_config['lr'], global_steps, 0.001)
adam = fluid.optimizer.Adam(learning_rate=self.lr)
adam.minimize(self.loss)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements a class for model configuration.
"""
import datetime
import os
import yaml
import random
import shutil
class Config(object):
"""Implementation of Config class for model configure.
Args:
config_file(str): configure filename, which is a yaml file.
        isCreate(bool): if true, create the necessary directories for saving models, logs and other outputs.
isSave(bool): if true, save config_file in order to record the configure message.
"""
def __init__(self, config_file, isCreate=False, isSave=False):
self.config_file = config_file
self.config = self.get_config_from_yaml(config_file)
if isCreate:
self.create_necessary_dirs()
if isSave:
self.save_config_file()
def get_config_from_yaml(self, yaml_file):
"""Get the configure hyperparameters from yaml file.
"""
try:
with open(yaml_file, 'r') as f:
config = yaml.load(f)
except Exception:
raise IOError("Error in parsing config file '%s'" % yaml_file)
return config
def create_necessary_dirs(self):
"""Create some necessary directories to save some important files.
"""
time_stamp = datetime.datetime.now().strftime('%m%d_%H%M')
self.config['trainer']['args']['log_dir'] = ''.join(
(self.config['trainer']['args']['log_dir'],
self.config['task_name'], '/')) # , '.%s/' % (time_stamp)))
self.config['trainer']['args']['save_dir'] = ''.join(
(self.config['trainer']['args']['save_dir'],
self.config['task_name'], '/')) # , '.%s/' % (time_stamp)))
self.config['trainer']['args']['output_dir'] = ''.join(
(self.config['trainer']['args']['output_dir'],
self.config['task_name'], '/')) # , '.%s/' % (time_stamp)))
# if os.path.exists(self.config['trainer']['args']['save_dir']):
# input('save_dir is existed, do you really want to continue?')
self.make_dir(self.config['trainer']['args']['log_dir'])
self.make_dir(self.config['trainer']['args']['save_dir'])
self.make_dir(self.config['trainer']['args']['output_dir'])
def save_config_file(self):
"""Save config file so that we can know the config when we look back
"""
filename = self.config_file.split('/')[-1]
targetpath = self.config['trainer']['args']['save_dir']
shutil.copyfile(self.config_file, targetpath + filename)
def make_dir(self, path):
"""Build directory"""
if not os.path.exists(path):
os.makedirs(path)
def __getitem__(self, key):
"""Return the configure dict"""
return self.config[key]
def __call__(self):
"""__call__"""
return self.config
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file loads and preprocesses the dataset for the metapath2vec model.
"""
import sys
import os
import glob
import numpy as np
import tqdm
import time
import logging
import random
from pgl.contrib import heter_graph
import pickle as pkl
class Dataset(object):
"""Implementation of Dataset class
This is a simple implementation of loading and processing dataset for metapath2vec model.
Args:
config: dict, some configure parameters.
"""
NEGATIVE_TABLE_SIZE = 1e8
def __init__(self, config):
self.config = config
self.walk_files = config['input_path'] + config['walk_path']
self.word2id_file = config['input_path'] + config['word2id_file']
self.word2freq = {}
self.word2id = {}
self.id2word = {}
self.sentences_count = 0
self.token_count = 0
self.negatives = []
self.discards = []
logging.info('reading sentences')
self.read_words()
logging.info('initializing discards')
self.initDiscards()
logging.info('initializing negatives')
self.initNegatives()
def read_words(self):
"""Read words(nodes) from walk files which are produced by sampler.
"""
word_freq = dict()
for walk_file in glob.glob(self.walk_files):
with open(walk_file, 'r') as reader:
for walk in reader:
walk = walk.strip().split(' ')
if len(walk) > 1:
self.sentences_count += 1
for word in walk:
self.token_count += 1
word_freq[word] = word_freq.get(word, 0) + 1
wid = 0
logging.info('Read %d sentences.' % self.sentences_count)
logging.info('Read %d words.' % self.token_count)
logging.info('%d words have been sampled.' % len(word_freq))
for w, c in word_freq.items():
if c < self.config['min_count']:
continue
self.word2id[w] = wid
self.id2word[wid] = w
self.word2freq[wid] = c
wid += 1
self.word_count = len(self.word2id)
logging.info(
            '%d words appearing fewer than %d (min_count) times have been discarded.' %
(len(word_freq) - len(self.word2id), self.config['min_count']))
pkl.dump(self.word2id, open(self.word2id_file, 'wb'))
def initDiscards(self):
"""Get a frequency table for sub-sampling.
"""
t = 0.0001
f = np.array(list(self.word2freq.values())) / self.token_count
self.discards = np.sqrt(t / f) + (t / f)
def initNegatives(self):
"""Get a table for negative sampling
"""
pow_freq = np.array(list(self.word2freq.values()))**0.75
words_pow = sum(pow_freq)
ratio = pow_freq / words_pow
count = np.round(ratio * Dataset.NEGATIVE_TABLE_SIZE)
for wid, c in enumerate(count):
self.negatives += [wid] * int(c)
self.negatives = np.array(self.negatives)
np.random.shuffle(self.negatives)
self.sampling_prob = ratio
def getNegatives(self, size):
"""Get negative samples from negative samling table.
"""
return np.random.choice(self.negatives, size)
def walk_from_files(self, walkpath_files):
"""Generate walks from files.
"""
bucket = []
for filename in walkpath_files:
with open(filename) as reader:
for line in reader:
words = line.strip().split(' ')
if len(words) > 1:
word_ids = [
self.word2id[w] for w in words if w in self.word2id
]
bucket.append(word_ids)
if len(bucket) == self.config['batch_size']:
yield bucket
bucket = []
if len(bucket):
yield bucket
def pairs_generator(self, walkpath_files):
"""Generate train pairs(src, pos, negs) for training model.
"""
def wrapper():
"""wrapper for multiprocess calling.
"""
for walks in self.walk_from_files(walkpath_files):
res = self.gen_pairs(walks)
yield res
return wrapper
def gen_pairs(self, walks):
"""Generate train pairs data for training model.
"""
src = []
pos = []
negs = []
skip_window = self.config['win_size'] // 2
for walk in walks:
for i in range(len(walk)):
for j in range(1, skip_window + 1):
if i - j >= 0:
src.append(walk[i])
pos.append(walk[i - j])
negs.append(
self.getNegatives(size=self.config['neg_num']))
if i + j < len(walk):
src.append(walk[i])
pos.append(walk[i + j])
negs.append(
self.getNegatives(size=self.config['neg_num']))
src = np.array(src, dtype=np.int64).reshape(-1, 1, 1)
pos = np.array(pos, dtype=np.int64).reshape(-1, 1, 1)
negs = np.expand_dims(np.array(negs, dtype=np.int64), -1)
return {"src": src, "pos": pos, "negs": negs}
if __name__ == "__main__":
config = {
'input_path': './data/out_aminer_CPAPC/',
'walk_path': 'aminer_walks_CPAPC_500num_100len/*',
'author_label_file': 'author_label.txt',
'venue_label_file': 'venue_label.txt',
'remapping_author_label_file': 'multi_class_author_label.txt',
'remapping_venue_label_file': 'multi_class_venue_label.txt',
'word2id_file': 'word2id.pkl',
'win_size': 7,
'neg_num': 5,
'min_count': 2,
'batch_size': 1,
}
log_format = '%(asctime)s-%(levelname)s-%(name)s: %(message)s'
logging.basicConfig(level=getattr(logging, 'INFO'), format=log_format)
dataset = Dataset(config)
# PGL examples for metapath2vec
[metapath2vec](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf) is an algorithmic framework for representation learning in heterogeneous networks that contain multiple types of nodes and links. Given a heterogeneous graph, metapath2vec first generates meta-path-based random walks and then trains a skipgram model on them. Based on PGL, we reproduce the metapath2vec algorithm.
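The sampler in this example relies on `pgl.sample.metapath_randomwalk`; the key idea is that at every step the walker may only move to neighbors whose node type matches the next type in the meta-path scheme (e.g. conf-paper-author-paper-conf). A toy sketch of that constraint on a plain adjacency dict (data structures and names here are illustrative, not PGL's API):

```python
import random

def toy_metapath_walk(adj, node_type, start, metapath):
    """One pass over the meta-path scheme, picking a random type-matching neighbor each step."""
    types = metapath.split('-')
    assert node_type[start] == types[0]
    walk = [start]
    for next_type in types[1:]:
        candidates = [n for n in adj[walk[-1]] if node_type[n] == next_type]
        if not candidates:
            break
        walk.append(random.choice(candidates))
    return walk

# tiny toy graph: two conferences, one paper, one author
adj = {'c0': ['p0'], 'c1': ['p0'], 'p0': ['c0', 'c1', 'a0'], 'a0': ['p0']}
node_type = {'c0': 'conf', 'c1': 'conf', 'p0': 'paper', 'a0': 'author'}
print(toy_metapath_walk(adj, node_type, 'c0', 'conf-paper-author-paper-conf'))
```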
## Datasets
You can download the datasets from [here](https://ericdongyx.github.io/metapath2vec/m2v.html).
We use the "aminer" data as an example. After downloading the aminer data, put it in, say, ./data/net_aminer/ . You also need to put the "label/" directory in ./data/.
## Dependencies
- paddlepaddle>=1.6
- pgl>=1.0.0
## Hyperparameters
All the hyperparameters are saved in the config.yaml file, so before training you can open config.yaml and modify them as you like.
For example, you can set "use_cuda" to True to train on GPU, or change "data_path" to specify the data you want.
Some important hyperparameters in config.yaml:
- **use_cuda**: use GPU to train model
- **data_path**: the directory of dataset that you want to load
- **lr**: learning rate
- **neg_num**: number of negative samples
- **num_walks**: number of walks started from each node
- **walk_length**: walk length
- **metapath**: meta path scheme
## Metapath random walk sampling
Before training, we need to generate metapath random walks for the skipgram model. Run the command below to produce the metapath random walk data.
```sh
python sample.py -c config.yaml
```
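Each sampling worker writes its walks to files named 0000, 0001, ... under the walk directory (./data/out_aminer_CPAPC/walks/ with the default config.yaml); every line is one walk, i.e. space-separated node indices with the intermediate paper nodes dropped (see `generate_walks` in sample.py). A quick sanity check, assuming those default paths:

```python
import glob

walk_files = glob.glob('./data/out_aminer_CPAPC/walks/*')
n_walks = sum(1 for f in walk_files for _ in open(f))
print('%d walk files, %d walks in total' % (len(walk_files), n_walks))
```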
## Training and Testing
After metapath random walk sampling finishes, run the commands below to train the model and then test the learned embeddings.
```sh
python main.py -c config.yaml
python multi_class.py --dataset ./data/out_aminer_CPAPC/author_label.txt --word2id ./checkpoints/train.metapath2vec/word2id.pkl --ckpt_path ./checkpoints/train.metapath2vec/model_epoch5/
```
## Experiment results
| train_percent | Metric | PGL Result | Reported Result |
|---------------|----------|------------|-----------------|
| 50% | macro-F1 | 0.9249 | 0.9314 |
| 50% | micro-F1 | 0.9283 | 0.9365 |
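During training, main.py saves the `content` embedding table with `save_param` (one content.npy per epoch under checkpoints/train.metapath2vec/model_epochN/) and moves word2id.pkl into the same checkpoint directory; multi_class.py reloads them via `load_param`. A hedged sketch of fetching one node's embedding by name (paths follow the default config.yaml; the chosen node is arbitrary):

```python
import pickle as pkl
import numpy as np

word2id = pkl.load(open('./checkpoints/train.metapath2vec/word2id.pkl', 'rb'))
embeddings = np.load('./checkpoints/train.metapath2vec/model_epoch5/content.npy')

node = next(iter(word2id))  # any node id that appeared in the sampled walks
print(node, embeddings[word2id[node]][:5])
```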
task_name: train.metapath2vec
use_cuda: True
log_level: info
seed: 1667
sampler:
type:
args:
data_path: ./data/net_aminer/
author_label_file: ./data/label/googlescholar.8area.author.label.txt
venue_label_file: ./data/label/googlescholar.8area.venue.label.txt
output_path: ./data/out_aminer_CPAPC/
new_author_label_file: author_label.txt
new_venue_label_file: venue_label.txt
walk_saved_path: walks/
num_walks: 1000
walk_length: 100
metapath: conf-paper-author-paper-conf
optimizer:
type: Adam
args:
lr: 0.005
end_lr: 0.0001
trainer:
type: trainer
args:
epochs: 5
log_dir: logs/
save_dir: checkpoints/
output_dir: outputs/
num_sample_workers: 8
data_loader:
type: Dataset
args:
input_path: ./data/out_aminer_CPAPC/ # same path as output_path in sampler
walk_path: walks/*
word2id_file: word2id.pkl
batch_size: 32
win_size: 7 # default: 7
neg_num: 5
min_count: 10
model:
type: SkipgramModel
args:
embed_dim: 128
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements the training process of the metapath2vec model.
"""
import os
import sys
import argparse
import time
import numpy as np
import logging
import random
import pickle as pkl
import shutil
import glob
import pgl
from pgl.utils import paddle_helper
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as fl
from utils import *
import Dataset
import model as Models
from pgl.utils import mp_reader
from sklearn.metrics import (auc, f1_score, precision_recall_curve,
roc_auc_score)
def set_seed(seed):
"""Set global random seed."""
random.seed(seed)
np.random.seed(seed)
def save_param(dirname, var_name_list):
"""save_param"""
if not os.path.exists(dirname):
os.makedirs(dirname)
for var_name in var_name_list:
var = fluid.global_scope().find_var(var_name)
var_tensor = var.get_tensor()
np.save(os.path.join(dirname, var_name + '.npy'), np.array(var_tensor))
def multiprocess_data_generator(config, dataset):
"""Using multiprocess to generate training data.
"""
num_sample_workers = config['trainer']['args']['num_sample_workers']
walkpath_files = [[] for i in range(num_sample_workers)]
for idx, f in enumerate(glob.glob(dataset.walk_files)):
walkpath_files[idx % num_sample_workers].append(f)
gen_data_pool = [
dataset.pairs_generator(files) for files in walkpath_files
]
if num_sample_workers == 1:
gen_data_func = gen_data_pool[0]
else:
gen_data_func = mp_reader.multiprocess_reader(
gen_data_pool, use_pipe=True, queue_size=100)
return gen_data_func
def run_epoch(epoch,
config,
data_generator,
train_prog,
model,
feed_dict,
exe,
for_test=False):
"""Run training process of every epoch.
"""
total_loss = []
for idx, batch_data in enumerate(data_generator()):
feed_dict['train_inputs'] = batch_data['src']
feed_dict['train_labels'] = batch_data['pos']
feed_dict['train_negs'] = batch_data['negs']
loss, lr = exe.run(train_prog,
feed=feed_dict,
fetch_list=[model.loss, model.lr],
return_numpy=True)
total_loss.append(loss[0])
if (idx + 1) % 500 == 0:
avg_loss = np.mean(total_loss)
logging.info("epoch %d | step %d | lr %.4f | train_loss %f " %
(epoch, idx + 1, lr, avg_loss))
total_loss = []
def main(config):
"""main function for training metapath2vec model.
"""
logging.info(config)
set_seed(config['seed'])
dataset = getattr(
Dataset, config['data_loader']['type'])(config['data_loader']['args'])
data_generator = multiprocess_data_generator(config, dataset)
# move word2id file to checkpoints directory
src_word2id_file = dataset.word2id_file
dst_wor2id_file = config['trainer']['args']['save_dir'] + config[
'data_loader']['args']['word2id_file']
logging.info('backup word2id file to %s' % dst_wor2id_file)
shutil.move(src_word2id_file, dst_wor2id_file)
place = fluid.CUDAPlace(0) if config['use_cuda'] else fluid.CPUPlace()
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
model = getattr(Models, config['model']['type'])(
dataset=dataset, config=config['model']['args'], place=place)
with fluid.program_guard(train_program, startup_program):
global_steps = int(dataset.sentences_count *
config['trainer']['args']['epochs'] /
config['data_loader']['args']['batch_size'])
model.backward(global_steps, config['optimizer']['args'])
# train
exe = fluid.Executor(place)
exe.run(startup_program)
feed_dict = {}
logging.info('training...')
for epoch in range(1, 1 + config['trainer']['args']['epochs']):
run_epoch(epoch, config['trainer']['args'], data_generator,
train_program, model, feed_dict, exe)
logging.info('saving model...')
cur_save_path = os.path.join(config['trainer']['args']['save_dir'],
"model_epoch%d" % (epoch))
save_param(cur_save_path, ['content'])
logging.info('finishing training')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='metapath2vec')
parser.add_argument(
'-c',
'--config',
default=None,
type=str,
help='config file path (default: None)')
parser.add_argument(
'-n',
'--taskname',
default=None,
type=str,
help='task name(default: None)')
args = parser.parse_args()
if args.config:
# load config file
config = Config(args.config, isCreate=True, isSave=True)
config = config()
else:
raise AssertionError(
"Configuration file need to be specified. Add '-c config.yaml', for example."
)
log_format = '%(asctime)s-%(levelname)s-%(name)s: %(message)s'
logging.basicConfig(
level=getattr(logging, config['log_level'].upper()), format=log_format)
main(config)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements the skipgram model for training metapath2vec.
"""
import argparse
import time
import math
import os
import io
from multiprocessing import Pool
import logging
import numpy as np
import glob
import pgl
from pgl import data_loader
from pgl.utils import op
from pgl.utils.logger import log
import paddle.fluid as fluid
import paddle.fluid.layers as fl
class SkipgramModel(object):
"""Implemetation of skipgram model.
Args:
config: dict, some configure parameters.
dataset: instance of Dataset class
place: GPU or CPU place
"""
def __init__(self, config, dataset, place):
self.config = config
self.dataset = dataset
self.place = place
self.neg_num = self.dataset.config['neg_num']
self.num_nodes = len(dataset.word2id)
self.train_inputs = fl.data(
'train_inputs', shape=[None, 1, 1], dtype='int64')
self.train_labels = fl.data(
'train_labels', shape=[None, 1, 1], dtype='int64')
self.train_negs = fl.data(
'train_negs', shape=[None, self.neg_num, 1], dtype='int64')
self.forward()
def backward(self, global_steps, opt_config):
"""Build the optimizer.
"""
self.lr = fl.polynomial_decay(opt_config['lr'], global_steps,
opt_config['end_lr'])
adam = fluid.optimizer.Adam(learning_rate=self.lr)
adam.minimize(self.loss)
def forward(self):
"""Build the skipgram model.
"""
initrange = 1.0 / self.config['embed_dim']
embed_init = fluid.initializer.UniformInitializer(
low=-initrange, high=initrange)
weight_init = fluid.initializer.TruncatedNormal(
scale=1.0 / math.sqrt(self.config['embed_dim']))
embed_src = fl.embedding(
input=self.train_inputs,
size=[self.num_nodes, self.config['embed_dim']],
param_attr=fluid.ParamAttr(
name='content', initializer=embed_init))
weight_pos = fl.embedding(
input=self.train_labels,
size=[self.num_nodes, self.config['embed_dim']],
param_attr=fluid.ParamAttr(
name='weight', initializer=weight_init))
weight_negs = fl.embedding(
input=self.train_negs,
size=[self.num_nodes, self.config['embed_dim']],
param_attr=fluid.ParamAttr(
name='weight', initializer=weight_init))
pos_logits = fl.matmul(
embed_src, weight_pos, transpose_y=True) # [batch_size, 1, 1]
pos_score = fl.squeeze(pos_logits, axes=[1])
pos_score = fl.clip(pos_score, min=-10, max=10)
pos_score = -1.0 * fl.logsigmoid(pos_score)
neg_logits = fl.matmul(
embed_src, weight_negs,
transpose_y=True) # [batch_size, 1, neg_num]
neg_score = fl.squeeze(neg_logits, axes=[1])
neg_score = fl.clip(neg_score, min=-10, max=10)
neg_score = -1.0 * fl.logsigmoid(-1.0 * neg_score)
neg_score = fl.reduce_sum(neg_score, dim=1, keep_dim=True)
self.loss = fl.reduce_mean(pos_score + neg_score)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file provides the multi-class classification task for testing the embeddings learned by the metapath2vec model.
"""
import argparse
import sys
import os
import tqdm
import time
import math
import logging
import random
import pickle as pkl
import numpy as np
import sklearn.metrics
from sklearn.metrics import f1_score
import pgl
import paddle.fluid as fluid
import paddle.fluid.layers as fl
import Dataset
from utils import *
def load_param(dirname, var_name_list):
"""load_param"""
for var_name in var_name_list:
var = fluid.global_scope().find_var(var_name)
var_tensor = var.get_tensor()
var_tmp = np.load(os.path.join(dirname, var_name + '.npy'))
var_tensor.set(var_tmp, fluid.CPUPlace())
def load_data(file_, word2id):
"""Load data for node classification.
"""
words_label = []
line_count = 0
with open(file_, 'r') as reader:
for line in reader:
line_count += 1
tokens = line.strip().split(' ')
word, label = tokens[0], int(tokens[1]) - 1
if word in word2id:
words_label.append((word2id[word], label))
words_label = np.array(words_label, dtype=np.int64)
np.random.shuffle(words_label)
logging.info('%d/%d word_label pairs have been loaded' %
(len(words_label), line_count))
return words_label
def node_classify_model(word2id, num_labels, embed_dim=16):
"""Build node classify model.
Args:
word2id(dict): map word(node) to its corresponding index
num_labels: The number of labels.
embed_dim: The dimension of embedding.
"""
nodes = fl.data('nodes', shape=[None, 1], dtype='int64')
labels = fl.data('labels', shape=[None, 1], dtype='int64')
embed_nodes = fl.embedding(
input=nodes,
size=[len(word2id), embed_dim],
param_attr=fluid.ParamAttr(name='content'))
embed_nodes.stop_gradient = True
probs = fl.fc(input=embed_nodes, size=num_labels, act='softmax')
predict = fl.argmax(probs, axis=-1)
loss = fl.cross_entropy(input=probs, label=labels)
loss = fl.reduce_mean(loss)
return {
'loss': loss,
'probs': probs,
'predict': predict,
'labels': labels,
}
def run_epoch(exe, prog, model, feed_dict, lr):
"""Run training process of every epoch.
"""
if lr is None:
loss, predict = exe.run(prog,
feed=feed_dict,
fetch_list=[model['loss'], model['predict']],
return_numpy=True)
lr_ = 0
else:
loss, predict, lr_ = exe.run(
prog,
feed=feed_dict,
fetch_list=[model['loss'], model['predict'], lr],
return_numpy=True)
macro_f1 = f1_score(feed_dict['labels'], predict, average="macro")
micro_f1 = f1_score(feed_dict['labels'], predict, average="micro")
return {
'loss': loss,
'pred': predict,
'lr': lr_,
'macro_f1': macro_f1,
'micro_f1': micro_f1
}
def main(args):
"""main function for training node classification task.
"""
word2id = pkl.load(open(args.word2id, 'rb'))
words_label = load_data(args.dataset, word2id)
# split data for training and testing
split_position = int(words_label.shape[0] * args.train_percent)
train_words_label = words_label[0:split_position, :]
test_words_label = words_label[split_position:, :]
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
train_prog = fluid.Program()
test_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
model = node_classify_model(
word2id, args.num_labels, embed_dim=args.embed_dim)
test_prog = train_prog.clone(for_test=True)
with fluid.program_guard(train_prog, startup_prog):
lr = fl.polynomial_decay(args.lr, 1000, 0.001)
adam = fluid.optimizer.Adam(lr)
adam.minimize(model['loss'])
exe = fluid.Executor(place)
exe.run(startup_prog)
load_param(args.ckpt_path, ['content'])
feed_dict = {}
X = train_words_label[:, 0].reshape(-1, 1)
labels = train_words_label[:, 1].reshape(-1, 1)
logging.info('%d/%d data to train' %
(labels.shape[0], words_label.shape[0]))
test_feed_dict = {}
test_X = test_words_label[:, 0].reshape(-1, 1)
test_labels = test_words_label[:, 1].reshape(-1, 1)
logging.info('%d/%d data to test' %
(test_labels.shape[0], words_label.shape[0]))
for epoch in range(args.epochs):
feed_dict['nodes'] = X
feed_dict['labels'] = labels
train_result = run_epoch(exe, train_prog, model, feed_dict, lr)
test_feed_dict['nodes'] = test_X
test_feed_dict['labels'] = test_labels
test_result = run_epoch(exe, test_prog, model, test_feed_dict, lr=None)
logging.info(
'epoch %d | lr %.4f | train_loss %.5f | train_macro_F1 %.4f | train_micro_F1 %.4f | test_loss %.5f | test_macro_F1 %.4f | test_micro_F1 %.4f'
% (epoch, train_result['lr'], train_result['loss'],
train_result['macro_f1'], train_result['micro_f1'],
test_result['loss'], test_result['macro_f1'],
test_result['micro_f1']))
logging.info(
'final_test_macro_f1 score: %.4f | final_test_micro_f1 score: %.4f' %
(test_result['macro_f1'], test_result['micro_f1']))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='multi_class')
parser.add_argument(
'--dataset',
default=None,
type=str,
help='training and testing data file(default: None)')
parser.add_argument(
'--word2id',
default=None,
type=str,
help='word2id file (default: None)')
parser.add_argument(
        '--ckpt_path', default=None, type=str, help='checkpoint path (default: None)')
parser.add_argument("--use_cuda", action='store_true', help="use_cuda")
parser.add_argument(
'--train_percent',
default=0.5,
type=float,
help='train_percent(default: 0.5)')
parser.add_argument(
'--num_labels',
default=8,
type=int,
help='number of labels(default: 8)')
parser.add_argument(
'--epochs',
default=100,
type=int,
        help='number of epochs for training (default: 100)')
parser.add_argument(
'--lr',
default=0.025,
type=float,
help='learning rate(default: 0.025)')
parser.add_argument(
'--embed_dim',
default=128,
type=int,
help='dimension of embedding(default: 128)')
args = parser.parse_args()
log_format = '%(asctime)s-%(levelname)s-%(name)s: %(message)s'
logging.basicConfig(level='INFO', format=log_format)
main(args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements the sampler that generates metapath random walk sequences for
training the metapath2vec model.
"""
import multiprocessing
from multiprocessing import Pool
import argparse
import sys
import os
import numpy as np
import pickle as pkl
import tqdm
import time
import logging
import random
from pgl.contrib import heter_graph
from pgl.sample import metapath_randomwalk
from utils import *
class Sampler(object):
"""Implemetation of sampler in order to sample metapath random walk.
Args:
config: dict, some configure parameters.
"""
def __init__(self, config):
self.config = config
self.build_graph()
def build_graph(self):
"""Build pgl heterogeneous graph.
"""
self.conf_id2index, self.conf_name2index, conf_node_type = self.remapping_id(
self.config['data_path'] + 'id_conf.txt',
start_index=0,
node_type='conf')
logging.info('%d venues have been loaded.' % (len(self.conf_id2index)))
self.author_id2index, self.author_name2index, author_node_type = self.remapping_id(
self.config['data_path'] + 'id_author.txt',
start_index=len(self.conf_id2index),
node_type='author')
logging.info('%d authors have been loaded.' %
(len(self.author_id2index)))
self.paper_id2index, self.paper_name2index, paper_node_type = self.remapping_id(
self.config['data_path'] + 'paper.txt',
start_index=(len(self.conf_id2index) + len(self.author_id2index)),
node_type='paper',
separator='\t')
logging.info('%d papers have been loaded.' %
(len(self.paper_id2index)))
node_types = conf_node_type + author_node_type + paper_node_type
num_nodes = len(node_types)
edges_by_types = {}
paper_author_edges = self.load_edges(
self.config['data_path'] + 'paper_author.txt', self.paper_id2index,
self.author_id2index)
paper_conf_edges = self.load_edges(
self.config['data_path'] + 'paper_conf.txt', self.paper_id2index,
self.conf_id2index)
edges_by_types['edge'] = paper_author_edges + paper_conf_edges
logging.info('%d edges have been loaded.' %
(len(edges_by_types['edge'])))
node_features = {
'index': np.array([i for i in range(num_nodes)]).reshape(
-1, 1).astype(np.int64)
}
self.graph = heter_graph.HeterGraph(
num_nodes=num_nodes,
edges=edges_by_types,
node_types=node_types,
node_feat=node_features)
def remapping_id(self, file_, start_index, node_type, separator='\t'):
"""Mapp the ID and name of nodes to index.
"""
node_types = []
id2index = {}
name2index = {}
index = start_index
with open(file_, encoding="ISO-8859-1") as reader:
for line in reader:
tokens = line.strip().split(separator)
id2index[tokens[0]] = index
if len(tokens) == 2:
name2index[tokens[1]] = index
node_types.append((index, node_type))
index += 1
return id2index, name2index, node_types
def load_edges(self, file_, src2index, dst2index, symmetry=True):
"""Load edges from file.
"""
edges = []
with open(file_, 'r') as reader:
for line in reader:
items = line.strip().split()
src, dst = src2index[items[0]], dst2index[items[1]]
edges.append((src, dst))
if symmetry:
edges.append((dst, src))
edges = list(set(edges))
return edges
def generate_multi_class_data(self, name_label_file):
"""Mapp the data that will be used in multi class task to index.
"""
if 'author' in name_label_file:
name2index = self.author_name2index
else:
name2index = self.conf_name2index
index_label_list = []
with open(name_label_file, encoding="ISO-8859-1") as reader:
for line in reader:
tokens = line.strip().split(' ')
name, label = tokens[0], int(tokens[1])
index = name2index[name]
index_label_list.append((index, label))
return index_label_list
def generate_walks(args):
"""Generate metapath random walk and save to file.
"""
g, meta_path, filename, walk_length = args
walks = []
node_types = g._node_types
first_type = meta_path.split('-')[0]
nodes = np.where(node_types == first_type)[0]
if len(nodes) > 4000:
nodes = np.random.choice(nodes, 4000, replace=False)
logging.info('%d number of start nodes' % (len(nodes)))
logging.info('save walks in file: %s' % (filename))
with open(filename, 'w') as writer:
for start_node in nodes:
walk = metapath_randomwalk(g, start_node, meta_path, walk_length)
walk = [str(walk[i]) for i in range(0, len(walk), 2)] # skip paper
writer.write(' '.join(walk) + '\n')
def multiprocess_generate_walks(sampler, edge_type, meta_path, num_walks,
walk_length, saved_path):
"""Use multiprocess to generate metapath random walk.
"""
args = []
for i in range(num_walks):
filename = saved_path + '%04d' % (i)
args.append(
(sampler.graph[edge_type], meta_path, filename, walk_length))
pool = Pool(16)
pool.map(generate_walks, args)
pool.close()
pool.join()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='metapath2vec')
parser.add_argument(
'-c',
'--config',
default=None,
type=str,
help='config file path (default: None)')
args = parser.parse_args()
if args.config:
# load config file
config = Config(args.config, isCreate=False, isSave=False)
config = config()
config = config['sampler']['args']
else:
raise AssertionError(
"Configuration file need to be specified. Add '-c config.yaml', for example."
)
log_format = '%(asctime)s-%(levelname)s-%(name)s: %(message)s'
logging.basicConfig(level="INFO", format=log_format)
logging.info(config)
log_format = '%(asctime)s-%(levelname)s-%(name)s: %(message)s'
logging.basicConfig(level=getattr(logging, 'INFO'), format=log_format)
if not os.path.exists(config['output_path']):
os.makedirs(config['output_path'])
config['walk_saved_path'] = config['output_path'] + config[
'walk_saved_path']
if not os.path.exists(config['walk_saved_path']):
os.makedirs(config['walk_saved_path'])
sampler = Sampler(config)
begin = time.time()
logging.info('multi process sampling')
multiprocess_generate_walks(
sampler=sampler,
edge_type='edge',
meta_path=config['metapath'],
num_walks=config['num_walks'],
walk_length=config['walk_length'],
saved_path=config['walk_saved_path'])
logging.info('total time: %.4f' % (time.time() - begin))
logging.info('generating multi class data')
word_label_list = sampler.generate_multi_class_data(config[
'author_label_file'])
with open(config['output_path'] + config['new_author_label_file'],
'w') as writer:
for line in word_label_list:
line = [str(i) for i in line]
writer.write(' '.join(line) + '\n')
word_label_list = sampler.generate_multi_class_data(config[
'venue_label_file'])
with open(config['output_path'] + config['new_venue_label_file'],
'w') as writer:
for line in word_label_list:
line = [str(i) for i in line]
writer.write(' '.join(line) + '\n')
logging.info('finished')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file implements a class for model configuration.
"""
import datetime
import os
import yaml
import random
import shutil
class Config(object):
"""Implementation of Config class for model configure.
Args:
config_file(str): configure filename, which is a yaml file.
        isCreate(bool): if true, create the necessary directories for saving models, logs and other outputs.
isSave(bool): if true, save config_file in order to record the configure message.
"""
def __init__(self, config_file, isCreate=False, isSave=False):
self.config_file = config_file
self.config = self.get_config_from_yaml(config_file)
if isCreate:
self.create_necessary_dirs()
if isSave:
self.save_config_file()
def get_config_from_yaml(self, yaml_file):
"""Get the configure hyperparameters from yaml file.
"""
try:
with open(yaml_file, 'r') as f:
config = yaml.load(f)
except Exception:
raise IOError("Error in parsing config file '%s'" % yaml_file)
return config
def create_necessary_dirs(self):
"""Create some necessary directories to save some important files.
"""
time_stamp = datetime.datetime.now().strftime('%m%d_%H%M')
self.config['trainer']['args']['log_dir'] = ''.join(
(self.config['trainer']['args']['log_dir'],
self.config['task_name'], '/')) # , '.%s/' % (time_stamp)))
self.config['trainer']['args']['save_dir'] = ''.join(
(self.config['trainer']['args']['save_dir'],
self.config['task_name'], '/')) # , '.%s/' % (time_stamp)))
self.config['trainer']['args']['output_dir'] = ''.join(
(self.config['trainer']['args']['output_dir'],
self.config['task_name'], '/')) # , '.%s/' % (time_stamp)))
# if os.path.exists(self.config['trainer']['args']['save_dir']):
# input('save_dir is existed, do you really want to continue?')
self.make_dir(self.config['trainer']['args']['log_dir'])
self.make_dir(self.config['trainer']['args']['save_dir'])
self.make_dir(self.config['trainer']['args']['output_dir'])
def save_config_file(self):
"""Save config file so that we can know the config when we look back
"""
filename = self.config_file.split('/')[-1]
targetpath = self.config['trainer']['args']['save_dir']
shutil.copyfile(self.config_file, targetpath + filename)
def make_dir(self, path):
"""Build directory"""
if not os.path.exists(path):
os.makedirs(path)
def __getitem__(self, key):
"""Return the configure dict"""
return self.config[key]
def __call__(self):
"""__call__"""
return self.config