提交 a10bb833 编写于 作者: W Webbley

add distribute metapath2vec

上级 8b7534c2
# Distributed metapath2vec in PGL
[metapath2vec](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf) is a algorithm framework for representation learning in heterogeneous networks which contains multiple types of nodes and links. Given a heterogeneous graph, metapath2vec algorithm first generates meta-path-based random walks and then use skipgram model to train a language model. Based on PGL, we reproduce metapath2vec algorithm in distributed mode.
## Datasets
DBLP: The dataset contains 14376 papers (P), 20 conferences (C), 14475 authors (A), and 8920 terms (T). There are 33791 nodes in this dataset.
You can dowload datasets from [here](https://github.com/librahu/HIN-Datasets-for-Recommendation-and-Network-Embedding)
We use the ```DBLP``` dataset for example. After downloading the dataset, put them, let's say, in ```./data/DBLP/``` .
## Dependencies
- paddlepaddle>=1.6
- pgl>=1.0.0
## How to run
Before training, run the below command to do data preprocessing.
```sh
python data_process.py --data_path ./data/DBLP --output_path ./data/data_processed
```
We adopt [PaddlePaddle Fleet](https://github.com/PaddlePaddle/Fleet) as our distributed training frameworks. ```config.yaml``` is a configure file for metapath2vec hyperparameters and ```local_config``` is a configure file for parameter servers of PaddlePaddle. By default, we have 2 pservers and 2 trainers. One can use ```cloud_run.sh``` to help startup the parameter servers and model trainers.
For examples, train metapath2vec in distributed mode on DBLP dataset.
```sh
# train metapath2vec in distributed mode.
sh cloud_run.sh
# multiclass task example
python multi_class.py --dataset ./data/data_processed/author_label.txt --ckpt_path ./checkpoints/2000 --num_nodes 33791
```
## Hyperparameters
All the hyper parameters are saved in ```config.yaml``` file. So before training, you can open the config.yaml to modify the hyper parameters as you like.
Some important hyper parameters in config.yaml:
- **edge_path**: the directory of graph data that you want to load
- **lr**: learning rate
- **neg_num**: number of negative samples.
- **num_walks**: number of walks started from each node
- **walk_len**: walk length
- **meta_path**: meta path scheme
#!/bin/bash
set -x
mode=${1}
source ./utils.sh
unset http_proxy https_proxy
source ./local_config
if [ ! -d ${log_dir} ]; then
mkdir ${log_dir}
fi
for((i=0;i<${PADDLE_PSERVERS_NUM};i++))
do
echo "start ps server: ${i}"
echo $log_dir
TRAINING_ROLE="PSERVER" PADDLE_TRAINER_ID=${i} sh job.sh &> $log_dir/pserver.$i.log &
done
sleep 10s
for((j=0;j<${PADDLE_TRAINERS_NUM};j++))
do
echo "start ps work: ${j}"
TRAINING_ROLE="TRAINER" PADDLE_TRAINER_ID=${j} sh job.sh &> $log_dir/worker.$j.log &
done
tail -f $log_dir/worker.0.log
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import time
import os
import math
import numpy as np
import paddle.fluid as F
import paddle.fluid.layers as L
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from pgl.utils.logger import log
from model import Metapath2vecModel
from graph import m2vGraph
from utils import load_config
from walker import multiprocess_data_generator
def init_role():
# reset the place according to role of parameter server
training_role = os.getenv("TRAINING_ROLE", "TRAINER")
paddle_role = role_maker.Role.WORKER
place = F.CPUPlace()
if training_role == "PSERVER":
paddle_role = role_maker.Role.SERVER
# set the fleet runtime environment according to configure
ports = os.getenv("PADDLE_PORT", "6174").split(",")
pserver_ips = os.getenv("PADDLE_PSERVERS").split(",") # ip,ip...
eplist = []
if len(ports) > 1:
# local debug mode, multi port
for port in ports:
eplist.append(':'.join([pserver_ips[0], port]))
else:
# distributed mode, multi ip
for ip in pserver_ips:
eplist.append(':'.join([ip, ports[0]]))
pserver_endpoints = eplist # ip:port,ip:port...
worker_num = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
role = role_maker.UserDefinedRoleMaker(
current_id=trainer_id,
role=paddle_role,
worker_num=worker_num,
server_endpoints=pserver_endpoints)
fleet.init(role)
def optimization(base_lr, loss, train_steps, optimizer='sgd'):
decayed_lr = L.learning_rate_scheduler.polynomial_decay(
learning_rate=base_lr,
decay_steps=train_steps,
end_learning_rate=0.0001 * base_lr,
power=1.0,
cycle=False)
if optimizer == 'sgd':
optimizer = F.optimizer.SGD(decayed_lr)
elif optimizer == 'adam':
optimizer = F.optimizer.Adam(decayed_lr, lazy_mode=True)
else:
raise ValueError
log.info('learning rate:%f' % (base_lr))
#create the DistributeTranspiler configure
config = DistributeTranspilerConfig()
config.sync_mode = False
#config.runtime_split_send_recv = False
config.slice_var_up = False
#create the distributed optimizer
optimizer = fleet.distributed_optimizer(optimizer, config)
optimizer.minimize(loss)
def build_complied_prog(train_program, model_loss):
num_threads = int(os.getenv("CPU_NUM", 10))
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
exec_strategy = F.ExecutionStrategy()
exec_strategy.num_threads = num_threads
#exec_strategy.use_experimental_executor = True
build_strategy = F.BuildStrategy()
build_strategy.enable_inplace = True
#build_strategy.memory_optimize = True
build_strategy.memory_optimize = False
build_strategy.remove_unnecessary_lock = False
if num_threads > 1:
build_strategy.reduce_strategy = F.BuildStrategy.ReduceStrategy.Reduce
compiled_prog = F.compiler.CompiledProgram(
train_program).with_data_parallel(loss_name=model_loss.name)
return compiled_prog
def train_prog(exe, program, loss, node2vec_pyreader, args, train_steps):
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
step = 0
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
while True:
try:
begin_time = time.time()
loss_val, = exe.run(program, fetch_list=[loss])
log.info("step %s: loss %.5f speed: %.5f s/step" %
(step, np.mean(loss_val), time.time() - begin_time))
step += 1
except F.core.EOFException:
node2vec_pyreader.reset()
if step % args.steps_per_save == 0 or step == train_steps:
save_path = args.save_path
if trainer_id == 0:
model_path = os.path.join(save_path, "%s" % step)
fleet.save_persistables(exe, model_path)
if step == train_steps:
break
def main(args):
log.info("start")
worker_num = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
num_devices = int(os.getenv("CPU_NUM", 10))
model = Metapath2vecModel(config=args)
pyreader = model.pyreader
loss = model.forward()
# init fleet
init_role()
train_steps = math.ceil(args.num_nodes * args.epochs / args.batch_size /
num_devices / worker_num)
log.info("Train step: %s" % train_steps)
real_batch_size = args.batch_size * args.walk_len * args.win_size
if args.optimizer == "sgd":
args.lr *= real_batch_size
optimization(args.lr, loss, train_steps, args.optimizer)
# init and run server or worker
if fleet.is_server():
fleet.init_server(args.warm_start_from_dir)
fleet.run_server()
if fleet.is_worker():
log.info("start init worker done")
fleet.init_worker()
#just the worker, load the sample
log.info("init worker done")
exe = F.Executor(F.CPUPlace())
exe.run(fleet.startup_program)
log.info("Startup done")
dataset = m2vGraph(args)
log.info("Build graph done.")
data_generator = multiprocess_data_generator(args, dataset)
cur_time = time.time()
for idx, _ in enumerate(data_generator()):
log.info("iter %s: %s s" % (idx, time.time() - cur_time))
cur_time = time.time()
if idx == 100:
break
pyreader.decorate_tensor_provider(data_generator)
pyreader.start()
compiled_prog = build_complied_prog(fleet.main_program, loss)
train_prog(exe, compiled_prog, loss, pyreader, args, train_steps)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='metapath2vec')
parser.add_argument("-c", "--config", type=str, default="./config.yaml")
args = parser.parse_args()
config = load_config(args.config)
log.info(config)
main(config)
# graph data config
edge_path: "./data/data_processed"
edge_files: "p2a:paper_author.txt,p2c:paper_conference.txt,p2t:paper_type.txt"
node_types_file: "node_types.txt"
num_nodes: 37791
symmetry: True
# skipgram pair data config
win_size: 5
neg_num: 5
# average; m2v_plus
neg_sample_type: "average"
# random walk config
# m2v; multi_m2v;
walk_mode: "m2v"
meta_path: "c2p-p2a-a2p-p2c"
first_node_type: "c"
walk_len: 24
batch_size: 4
node_shuffle: True
node_files: null
num_sample_workers: 2
# model config
embed_dim: 64
is_sparse: True
# only use when num_nodes > 100,000,000, slower than noraml embedding
is_distributed: False
# trainging config
epochs: 10
optimizer: "sgd"
lr: 1.0
warm_start_from_dir: null
walkpath_files: "None"
train_files: "None"
steps_per_save: 1000
save_path: "./checkpoints"
log_dir: "./logs"
CPU_NUM: 16
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data preprocessing for DBLP dataset"""
import sys
import os
import argparse
import numpy as np
from collections import OrderedDict
AUTHOR = 14475
PAPER = 14376
CONF = 20
TYPE = 8920
LABEL = 4
def build_node_types(meta_node, outfile):
"""build_node_types"""
nt_ori2new = {}
with open(outfile, 'w') as writer:
offset = 0
for node_type, num_nodes in meta_node.items():
ori_id2new_id = {}
for i in range(num_nodes):
writer.write("%d\t%s\n" % (offset + i, node_type))
ori_id2new_id[i + 1] = offset + i
nt_ori2new[node_type] = ori_id2new_id
offset += num_nodes
return nt_ori2new
def remapping_index(args, src_dict, dst_dict, ori_file, new_file):
"""remapping_index"""
ori_file = os.path.join(args.data_path, ori_file)
new_file = os.path.join(args.output_path, new_file)
with open(ori_file, 'r') as reader, open(new_file, 'w') as writer:
for line in reader:
slots = line.strip().split()
s = int(slots[0])
d = int(slots[1])
new_s = src_dict[s]
new_d = dst_dict[d]
writer.write("%d\t%d\n" % (new_s, new_d))
def author_label(args, ori_id2pgl_id, ori_file, real_file, new_file):
"""author_label"""
ori_file = os.path.join(args.data_path, ori_file)
real_file = os.path.join(args.data_path, real_file)
new_file = os.path.join(args.output_path, new_file)
real_id2pgl_id = {}
with open(ori_file, 'r') as reader:
for line in reader:
slots = line.strip().split()
ori_id = int(slots[0])
real_id = int(slots[1])
pgl_id = ori_id2pgl_id[ori_id]
real_id2pgl_id[real_id] = pgl_id
with open(real_file, 'r') as reader, open(new_file, 'w') as writer:
for line in reader:
slots = line.strip().split()
real_id = int(slots[0])
label = int(slots[1])
pgl_id = real_id2pgl_id[real_id]
writer.write("%d\t%d\n" % (pgl_id, label))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='DBLP data preprocessing')
parser.add_argument(
'--data_path',
default=None,
type=str,
help='original data path(default: None)')
parser.add_argument(
'--output_path',
default=None,
type=str,
help='output path(default: None)')
args = parser.parse_args()
meta_node = OrderedDict()
meta_node['a'] = AUTHOR
meta_node['p'] = PAPER
meta_node['c'] = CONF
meta_node['t'] = TYPE
if not os.path.exists(args.output_path):
os.makedirs(args.output_path)
node_types_file = os.path.join(args.output_path, "node_types.txt")
nt_ori2new = build_node_types(meta_node, node_types_file)
remapping_index(args, nt_ori2new['p'], nt_ori2new['a'], 'paper_author.dat',
'paper_author.txt')
remapping_index(args, nt_ori2new['p'], nt_ori2new['c'],
'paper_conference.dat', 'paper_conference.txt')
remapping_index(args, nt_ori2new['p'], nt_ori2new['t'], 'paper_type.dat',
'paper_type.txt')
author_label(args, nt_ori2new['a'], 'author_map_id.dat',
'author_label.dat', 'author_label.txt')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import sys
import os
import numpy as np
import pickle as pkl
import tqdm
import time
import random
from pgl.utils.logger import log
from pgl import heter_graph
class m2vGraph(object):
"""Implemetation of graph in order to sample metapath random walk.
"""
def __init__(self, config):
self.edge_path = config.edge_path
self.num_nodes = config.num_nodes
self.symmetry = config.symmetry
edge_files = config.edge_files
node_types_file = config.node_types_file
self.edge_file_list = []
for pair in edge_files.split(','):
e_type, filename = pair.split(':')
filename = os.path.join(self.edge_path, filename)
self.edge_file_list.append((e_type, filename))
self.node_types_file = os.path.join(self.edge_path, node_types_file)
self.build_graph()
def build_graph(self):
"""Build pgl heterogeneous graph.
"""
edges_by_types = {}
npy = self.edge_file_list[0][1] + ".npy"
if os.path.exists(npy):
log.info("load data from numpy file")
for pair in self.edge_file_list:
edges_by_types[pair[0]] = np.load(pair[1] + ".npy")
else:
log.info("load data from txt file")
for pair in self.edge_file_list:
edges_by_types[pair[0]] = self.load_edges(pair[1])
# np.save(pair[1] + ".npy", edges_by_types[pair[0]])
for e_type, edges in edges_by_types.items():
log.info(["number of %s edges: " % e_type, len(edges)])
if self.symmetry:
tmp = {}
for key, edges in edges_by_types.items():
n_list = key.split('2')
re_key = n_list[1] + '2' + n_list[0]
tmp[re_key] = edges_by_types[key][:, [1, 0]]
edges_by_types.update(tmp)
log.info(["finished loadding symmetry edges."])
node_types = self.load_node_types(self.node_types_file)
assert len(node_types) == self.num_nodes, \
"num_nodes should be equal to the length of node_types"
log.info(["number of nodes: ", len(node_types)])
node_features = {
'index': np.array([i for i in range(self.num_nodes)]).reshape(
-1, 1).astype(np.int64)
}
self.graph = heter_graph.HeterGraph(
num_nodes=self.num_nodes,
edges=edges_by_types,
node_types=node_types,
node_feat=node_features)
def load_edges(self, file_, symmetry=False):
"""Load edges from file.
"""
edges = []
with open(file_, 'r') as reader:
for line in reader:
items = line.strip().split()
src, dst = int(items[0]), int(items[1])
edges.append((src, dst))
if symmetry:
edges.append((dst, src))
edges = np.array(list(set(edges)), dtype=np.int64)
# edges = list(set(edges))
return edges
def load_node_types(self, file_):
"""Load node types
"""
node_types = []
log.info("node_types_file name: %s" % file_)
with open(file_, 'r') as reader:
for line in reader:
items = line.strip().split()
node_id = int(items[0])
n_type = items[1]
node_types.append((node_id, n_type))
return node_types
#!/bin/bash
set -x
source ./utils.sh
export CPU_NUM=$CPU_NUM
export FLAGS_rpc_deadline=3000000
export FLAGS_communicator_send_queue_size=1
export FLAGS_communicator_min_send_grad_num_before_recv=0
export FLAGS_communicator_max_merge_var_num=1
export FLAGS_communicator_merge_sparse_grad=0
python -u cluster_train.py -c config.yaml
#!/bin/bash
export PADDLE_TRAINERS_NUM=2
export PADDLE_PSERVERS_NUM=2
export PADDLE_PORT=6184,6185
export PADDLE_PSERVERS="127.0.0.1"
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
metapath2vec model.
"""
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import math
import paddle.fluid.layers as L
import paddle.fluid as F
def distributed_embedding(input,
dict_size,
hidden_size,
initializer,
name,
num_part=16,
is_sparse=False,
learning_rate=1.0):
_part_size = hidden_size // num_part
if hidden_size % num_part != 0:
_part_size += 1
output_embedding = []
p_num = 0
while hidden_size > 0:
_part_size = min(_part_size, hidden_size)
hidden_size -= _part_size
print("part", p_num, "size=", (dict_size, _part_size))
part_embedding = L.embedding(
input=input,
size=(dict_size, int(_part_size)),
is_sparse=is_sparse,
is_distributed=False,
param_attr=F.ParamAttr(
name=name + '_part%s' % p_num,
initializer=initializer,
learning_rate=learning_rate))
p_num += 1
output_embedding.append(part_embedding)
return L.concat(output_embedding, -1)
class Metapath2vecModel(object):
def __init__(self, config, embedding_lr=1.0):
self.config = config
self.neg_num = self.config.neg_num
self.num_nodes = self.config.num_nodes
self.embed_dim = self.config.embed_dim
self.is_sparse = self.config.is_sparse
self.is_distributed = self.config.is_distributed
self.embedding_lr = embedding_lr
self.pyreader = L.py_reader(
capacity=70,
shapes=[[-1, 1, 1], [-1, self.neg_num + 1, 1]],
dtypes=['int64', 'int64'],
lod_levels=[0, 0],
name='train',
use_double_buffer=True)
bound = 1. / math.sqrt(self.embed_dim)
self.embed_init = F.initializer.Uniform(low=-bound, high=bound)
self.loss = None
max_hidden_size = int(math.pow(2, 31) / 4 / self.num_nodes)
self.num_part = int(math.ceil(1. * self.embed_dim / max_hidden_size))
def forward(self):
src, dsts = L.read_file(self.pyreader)
if self.is_sparse:
src = L.reshape(src, [-1, 1])
dsts = L.reshape(dsts, [-1, 1])
if self.num_part is not None and self.num_part != 1 and not self.is_distributed:
src_embed = distributed_embedding(
src,
self.num_nodes,
self.embed_dim,
self.embed_init,
"weight",
self.num_part,
self.is_sparse,
learning_rate=self.embedding_lr)
dsts_embed = distributed_embedding(
dsts,
self.num_nodes,
self.embed_dim,
self.embed_init,
"weight",
self.num_part,
self.is_sparse,
learning_rate=self.embedding_lr)
else:
src_embed = L.embedding(
src, (self.num_nodes, self.embed_dim),
self.is_sparse,
self.is_distributed,
param_attr=F.ParamAttr(
name="weight",
learning_rate=self.embedding_lr,
initializer=self.embed_init))
dsts_embed = L.embedding(
dsts, (self.num_nodes, self.embed_dim),
self.is_sparse,
self.is_distributed,
param_attr=F.ParamAttr(
name="weight",
learning_rate=self.embedding_lr,
initializer=self.embed_init))
if self.is_sparse:
src_embed = L.reshape(src_embed, [-1, 1, self.embed_dim])
dsts_embed = L.reshape(dsts_embed,
[-1, self.neg_num + 1, self.embed_dim])
logits = L.matmul(
src_embed, dsts_embed,
transpose_y=True) # [batch_size, 1, neg_num+1]
pos_label = L.fill_constant_batch_size_like(logits, [-1, 1, 1],
"float32", 1)
neg_label = L.fill_constant_batch_size_like(
logits, [-1, 1, self.neg_num], "float32", 0)
label = L.concat([pos_label, neg_label], -1)
pos_weight = L.fill_constant_batch_size_like(logits, [-1, 1, 1],
"float32", self.neg_num)
neg_weight = L.fill_constant_batch_size_like(
logits, [-1, 1, self.neg_num], "float32", 1)
weight = L.concat([pos_weight, neg_weight], -1)
weight.stop_gradient = True
label.stop_gradient = True
loss = L.sigmoid_cross_entropy_with_logits(logits, label)
loss = loss * weight
loss = L.reduce_mean(loss)
loss = loss * ((self.neg_num + 1) / 2 / self.neg_num)
loss.persistable = True
self.loss = loss
return loss
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimized Multiprocessing Reader for PaddlePaddle
"""
import multiprocessing
import numpy as np
import time
import paddle.fluid as fluid
import pyarrow
def _serialize_serializable(obj):
"""Serialize Feed Dict
"""
return {"type": type(obj), "data": obj.__dict__}
def _deserialize_serializable(obj):
"""Deserialize Feed Dict
"""
val = obj["type"].__new__(obj["type"])
val.__dict__.update(obj["data"])
return val
context = pyarrow.default_serialization_context()
context.register_type(
object,
"object",
custom_serializer=_serialize_serializable,
custom_deserializer=_deserialize_serializable)
def serialize_data(data):
"""serialize_data"""
return pyarrow.serialize(data, context=context).to_buffer().to_pybytes()
def deserialize_data(data):
"""deserialize_data"""
return pyarrow.deserialize(data, context=context)
def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
"""
multiprocess_reader use python multi process to read data from readers
and then use multiprocess.Queue or multiprocess.Pipe to merge all
data. The process number is equal to the number of input readers, each
process call one reader.
Multiprocess.Queue require the rw access right to /dev/shm, some
platform does not support.
you need to create multiple readers first, these readers should be independent
to each other so that each process can work independently.
An example:
.. code-block:: python
reader0 = reader(["file01", "file02"])
reader1 = reader(["file11", "file12"])
reader1 = reader(["file21", "file22"])
reader = multiprocess_reader([reader0, reader1, reader2],
queue_size=100, use_pipe=False)
"""
assert type(readers) is list and len(readers) > 0
def _read_into_queue(reader, queue):
"""read_into_queue"""
for sample in reader():
if sample is None:
raise ValueError("sample has None")
queue.put(serialize_data(sample))
queue.put(serialize_data(None))
def queue_reader():
"""queue_reader"""
queue = multiprocessing.Queue(queue_size)
for reader in readers:
p = multiprocessing.Process(
target=_read_into_queue, args=(reader, queue))
p.start()
reader_num = len(readers)
finish_num = 0
while finish_num < reader_num:
sample = deserialize_data(queue.get())
if sample is None:
finish_num += 1
else:
yield sample
def _read_into_pipe(reader, conn):
"""read_into_pipe"""
for sample in reader():
if sample is None:
raise ValueError("sample has None!")
conn.send(serialize_data(sample))
conn.send(serialize_data(None))
conn.close()
def pipe_reader():
"""pipe_reader"""
conns = []
for reader in readers:
parent_conn, child_conn = multiprocessing.Pipe()
conns.append(parent_conn)
p = multiprocessing.Process(
target=_read_into_pipe, args=(reader, child_conn))
p.start()
reader_num = len(readers)
finish_num = 0
conn_to_remove = []
finish_flag = np.zeros(len(conns), dtype="int32")
while finish_num < reader_num:
for conn_id, conn in enumerate(conns):
if finish_flag[conn_id] > 0:
continue
buff = conn.recv()
now = time.time()
sample = deserialize_data(buff)
out = time.time() - now
if sample is None:
finish_num += 1
conn.close()
finish_flag[conn_id] = 1
else:
yield sample
if use_pipe:
return pipe_reader
else:
return queue_reader
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file provides the multi class task for testing the embedding learned by metapath2vec model.
"""
import argparse
import sys
import os
import tqdm
import time
import math
import logging
import random
import pickle as pkl
import numpy as np
import sklearn.metrics
from sklearn.metrics import f1_score
import pgl
import paddle.fluid as fluid
import paddle.fluid.layers as fl
def load_data(file_):
"""Load data for node classification.
"""
words_label = []
line_count = 0
with open(file_, 'r') as reader:
for line in reader:
line_count += 1
tokens = line.strip().split()
word, label = int(tokens[0]), int(tokens[1]) - 1
words_label.append((word, label))
words_label = np.array(words_label, dtype=np.int64)
np.random.shuffle(words_label)
logging.info('%d/%d word_label pairs have been loaded' %
(len(words_label), line_count))
return words_label
def node_classify_model(config):
"""Build node classify model.
"""
nodes = fl.data('nodes', shape=[None, 1], dtype='int64')
labels = fl.data('labels', shape=[None, 1], dtype='int64')
embed_nodes = fl.embedding(
input=nodes,
size=[config.num_nodes, config.embed_dim],
param_attr=fluid.ParamAttr(name='weight'))
embed_nodes.stop_gradient = True
probs = fl.fc(input=embed_nodes, size=config.num_labels, act='softmax')
predict = fl.argmax(probs, axis=-1)
loss = fl.cross_entropy(input=probs, label=labels)
loss = fl.reduce_mean(loss)
return {
'loss': loss,
'probs': probs,
'predict': predict,
'labels': labels,
}
def run_epoch(exe, prog, model, feed_dict, lr):
"""Run training process of every epoch.
"""
if lr is None:
loss, predict = exe.run(prog,
feed=feed_dict,
fetch_list=[model['loss'], model['predict']],
return_numpy=True)
lr_ = 0
else:
loss, predict, lr_ = exe.run(
prog,
feed=feed_dict,
fetch_list=[model['loss'], model['predict'], lr],
return_numpy=True)
macro_f1 = f1_score(feed_dict['labels'], predict, average="macro")
micro_f1 = f1_score(feed_dict['labels'], predict, average="micro")
return {
'loss': loss,
'pred': predict,
'lr': lr_,
'macro_f1': macro_f1,
'micro_f1': micro_f1
}
def main(args):
"""main function for training node classification task.
"""
words_label = load_data(args.dataset)
# split data for training and testing
split_position = int(words_label.shape[0] * args.train_percent)
train_words_label = words_label[0:split_position, :]
test_words_label = words_label[split_position:, :]
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
train_prog = fluid.Program()
test_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
model = node_classify_model(args)
test_prog = train_prog.clone(for_test=True)
with fluid.program_guard(train_prog, startup_prog):
lr = fl.polynomial_decay(args.lr, 1000, 0.001)
adam = fluid.optimizer.Adam(lr)
adam.minimize(model['loss'])
exe = fluid.Executor(place)
exe.run(startup_prog)
def existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
return os.path.exists(os.path.join(args.ckpt_path, var.name))
fluid.io.load_vars(
exe, args.ckpt_path, main_program=train_prog, predicate=existed_params)
# load_param(args.ckpt_path, ['content'])
feed_dict = {}
X = train_words_label[:, 0].reshape(-1, 1)
labels = train_words_label[:, 1].reshape(-1, 1)
logging.info('%d/%d data to train' %
(labels.shape[0], words_label.shape[0]))
test_feed_dict = {}
test_X = test_words_label[:, 0].reshape(-1, 1)
test_labels = test_words_label[:, 1].reshape(-1, 1)
logging.info('%d/%d data to test' %
(test_labels.shape[0], words_label.shape[0]))
for epoch in range(args.epochs):
feed_dict['nodes'] = X
feed_dict['labels'] = labels
train_result = run_epoch(exe, train_prog, model, feed_dict, lr)
test_feed_dict['nodes'] = test_X
test_feed_dict['labels'] = test_labels
test_result = run_epoch(exe, test_prog, model, test_feed_dict, lr=None)
logging.info(
'epoch %d | lr %.4f | train_loss %.5f | train_macro_F1 %.4f | train_micro_F1 %.4f | test_loss %.5f | test_macro_F1 %.4f | test_micro_F1 %.4f'
% (epoch, train_result['lr'], train_result['loss'],
train_result['macro_f1'], train_result['micro_f1'],
test_result['loss'], test_result['macro_f1'],
test_result['micro_f1']))
logging.info(
'final_test_macro_f1 score: %.4f | final_test_micro_f1 score: %.4f' %
(test_result['macro_f1'], test_result['micro_f1']))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='multi_class')
parser.add_argument(
'--dataset',
default=None,
type=str,
help='training and testing data file(default: None)')
parser.add_argument(
'--ckpt_path', default=None, type=str, help='task name(default: None)')
parser.add_argument("--use_cuda", action='store_true', help="use_cuda")
parser.add_argument(
'--train_percent',
default=0.5,
type=float,
help='train_percent(default: 0.5)')
parser.add_argument(
'--num_labels',
default=4,
type=int,
help='number of labels(default: 4)')
parser.add_argument(
'--epochs',
default=100,
type=int,
help='number of epochs for training(default: 100)')
parser.add_argument(
'--lr',
default=0.025,
type=float,
help='learning rate(default: 0.025)')
parser.add_argument(
'--num_nodes', default=0, type=int, help='number of nodes')
parser.add_argument(
'--embed_dim',
default=64,
type=int,
help='dimension of embedding(default: 64)')
args = parser.parse_args()
log_format = '%(asctime)s-%(levelname)s-%(name)s: %(message)s'
logging.basicConfig(level='INFO', format=log_format)
main(args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of some helper functions"""
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals
import os
import time
import yaml
import numpy as np
from pgl.utils.logger import log
class AttrDict(dict):
"""Attr dict
"""
def __init__(self, d):
self.dict = d
def __getattr__(self, attr):
value = self.dict[attr]
if isinstance(value, dict):
return AttrDict(value)
else:
return value
def __str__(self):
return str(self.dict)
def load_config(config_file):
"""Load config file"""
with open(config_file) as f:
if hasattr(yaml, 'FullLoader'):
config = yaml.load(f, Loader=yaml.FullLoader)
else:
config = yaml.load(f)
return AttrDict(config)
# parse yaml file
function parse_yaml {
local prefix=$2
local s='[[:space:]]*' w='[a-zA-Z0-9_]*' fs=$(echo @|tr @ '\034')
sed -ne "s|^\($s\):|\1|" \
-e "s|^\($s\)\($w\)$s:$s[\"']\(.*\)[\"']$s\$|\1$fs\2$fs\3|p" \
-e "s|^\($s\)\($w\)$s:$s\(.*\)$s\$|\1$fs\2$fs\3|p" $1 |
awk -F$fs '{
indent = length($1)/2;
vname[indent] = $2;
for (i in vname) {if (i > indent) {delete vname[i]}}
if (length($3) > 0) {
vn=""; for (i=0; i<indent; i++) {vn=(vn)(vname[i])("_")}
printf("%s%s%s=\"%s\"\n", "'$prefix'",vn, $2, $3);
}
}'
}
eval $(parse_yaml "$(dirname "${BASH_SOURCE}")"/config.yaml)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""doc
"""
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
import time
import io
import os
import numpy as np
import random
from pgl.utils.logger import log
from pgl.sample import metapath_randomwalk
from pgl.graph_kernel import skip_gram_gen_pair
from pgl.graph_kernel import alias_sample_build_table
from utils import load_config
from graph import m2vGraph
import mp_reader
class NodeGenerator(object):
"""Node generator"""
def __init__(self, config, graph):
self.config = config
self.graph = graph
self.batch_size = self.config.batch_size
self.shuffle = self.config.node_shuffle
self.node_files = self.config.node_files
self.first_node_type = self.config.first_node_type
self.walk_mode = self.config.walk_mode
def __call__(self):
if self.walk_mode == "m2v":
generator = self.m2v_node_generate
log.info("node gen mode is : %s" % (self.walk_mode))
elif self.walk_mode == "multi_m2v":
generator = self.multi_m2v_node_generate
log.info("node gen mode is : %s" % (self.walk_mode))
elif self.walk_mode == "files":
generator = self.files_node_generate
log.info("node gen mode is : %s" % (self.walk_mode))
else:
generator = self.m2v_node_generate
log.info("node gen mode is : %s" % (self.walk_mode))
while True:
for nodes in generator():
yield nodes
def m2v_node_generate(self):
"""m2v_node_generate"""
for nodes in self.graph.node_batch_iter(
batch_size=self.batch_size,
n_type=self.first_node_type,
shuffle=self.shuffle):
yield nodes
def multi_m2v_node_generate(self):
"""multi_m2v_node_generate"""
n_type_list = self.first_node_type.split(';')
num_n_type = len(n_type_list)
node_types = np.unique(self.graph.node_types).tolist()
node_generators = {}
for n_type in node_types:
node_generators[n_type] = \
self.graph.node_batch_iter(self.batch_size, n_type=n_type)
cc = 0
while True:
idx = cc % num_n_type
n_type = n_type_list[idx]
try:
nodes = node_generators[n_type].next()
except StopIteration as e:
log.info("exception when iteration")
break
yield (nodes, idx)
cc += 1
def files_node_generate(self):
"""files_node_generate"""
nodes = []
for filename in self.node_files:
with io.open(filename) as inf:
for line in inf:
node = int(line.strip('\n\t'))
nodes.append(node)
if len(nodes) == self.batch_size:
yield nodes
nodes = []
if len(nodes):
yield nodes
class WalkGenerator(object):
"""Walk generator"""
def __init__(self, config, dataset):
self.config = config
self.dataset = dataset
self.graph = self.dataset.graph
self.walk_mode = self.config.walk_mode
self.node_generator = NodeGenerator(self.config, self.graph)
if self.walk_mode == "multi_m2v":
num_path = len(self.config.meta_path.split(';'))
num_first_node_type = len(self.config.first_node_type.split(';'))
assert num_first_node_type == num_path, \
"In [multi_m2v] walk_mode, the number of metapath should be the same \
as the number of first_node_type"
assert num_path > 1, "In [multi_m2v] walk_mode, the number of metapath\
should be greater than 1"
def __call__(self):
np.random.seed(os.getpid())
if self.walk_mode == "m2v":
walk_generator = self.m2v_walk
log.info("walk mode is : %s" % (self.walk_mode))
elif self.walk_mode == "multi_m2v":
walk_generator = self.multi_m2v_walk
log.info("walk mode is : %s" % (self.walk_mode))
else:
raise ValueError("walk_mode [%s] is not matched" % self.walk_mode)
for walks in walk_generator():
yield walks
def m2v_walk(self):
"""Metapath2vec walker"""
for nodes in self.node_generator():
walks = metapath_randomwalk(
self.graph, nodes, self.config.meta_path, self.config.walk_len)
yield walks
def multi_m2v_walk(self):
"""Multi metapath2vec walker"""
meta_paths = self.config.meta_path.split(';')
for nodes, idx in self.node_generator():
walks = metapath_randomwalk(self.graph, nodes, meta_paths[idx],
self.config.walk_len)
yield walks
class DataGenerator(object):
def __init__(self, config, dataset):
self.config = config
self.dataset = dataset
self.graph = self.dataset.graph
self.walk_generator = WalkGenerator(self.config, self.dataset)
def __call__(self):
generator = self.pair_generate
for src, pos, negs in generator():
dst = np.concatenate([pos, negs], 1)
yield src, dst
def pair_generate(self):
for walks in self.walk_generator():
try:
src_list, pos_list = [], []
for walk in walks:
s, p = skip_gram_gen_pair(walk, self.config.win_size)
src_list.append(s), pos_list.append(p)
src = [s for x in src_list for s in x]
pos = [s for x in pos_list for s in x]
if len(src) == 0:
continue
negs = self.negative_sample(
src,
pos,
neg_num=self.config.neg_num,
neg_sample_type=self.config.neg_sample_type)
src = np.array(src, dtype=np.int64).reshape(-1, 1, 1)
pos = np.array(pos, dtype=np.int64).reshape(-1, 1, 1)
yield src, pos, negs
except Exception as e:
log.exception(e)
def negative_sample(self, src, pos, neg_num, neg_sample_type):
if neg_sample_type == "average":
neg_sample_size = [len(pos), neg_num, 1]
negs = np.random.randint(
low=0, high=self.graph.num_nodes, size=neg_sample_size)
elif neg_sample_type == "m2v_plus":
negs = []
for s in src:
neg = self.graph.sample_nodes(
sample_num=neg_num, n_type=self.graph.node_types[s])
negs.append(neg)
negs = np.vstack(negs).reshape(-1, neg_num, 1)
else: # equal to "average"
neg_sample_size = [len(pos), neg_num, 1]
negs = np.random.randint(
low=0, high=self.graph.num_nodes, size=neg_sample_size)
negs = negs.astype(np.int64)
return negs
def multiprocess_data_generator(config, dataset):
"""Multiprocess data generator.
"""
if config.num_sample_workers == 1:
data_generator = DataGenerator(config, dataset)
else:
pool = [
DataGenerator(config, dataset)
for i in range(config.num_sample_workers)
]
data_generator = mp_reader.multiprocess_reader(
pool, use_pipe=True, queue_size=100)
return data_generator
if __name__ == "__main__":
config_file = "./config.yaml"
config = load_config(config_file)
dataset = m2vGraph(config)
data_generator = multiprocess_data_generator(config, dataset)
start = time.time()
cc = 0
for src, dst in data_generator():
log.info(src.shape)
log.info("time: %.6f" % (time.time() - start))
start = time.time()
cc += 1
if cc == 100:
break
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册