From c6115f950f26c8fe482481f135723b8d1d59451a Mon Sep 17 00:00:00 2001
From: githubutilities
Date: Thu, 14 May 2020 10:37:10 +0800
Subject: [PATCH] Add redis+distributed graphsage

---
 examples/distribute_graphsage/README.md       |  44 +--
 examples/distribute_graphsage/cloud_run.sh    |  25 ++
 .../distribute_graphsage/cluster_train.py     | 191 ++++++++++++
 examples/distribute_graphsage/config.yaml     |  19 ++
 examples/distribute_graphsage/job.sh          |  14 +
 examples/distribute_graphsage/local_config    |   6 +
 examples/distribute_graphsage/model.py        |  98 ++++++-
 examples/distribute_graphsage/reader.py       |  47 +++
 .../redis_setup/before_hook.sh                |  31 ++
 .../redis_setup/redis_graph.cfg               |   6 +
 .../redis_setup/src/build_graph.py            | 275 ++++++++++++++++++
 .../redis_setup/src/dump_data.sh              |  63 ++++
 .../redis_setup/src/gen_redis_conf.py         |  72 +++++
 .../redis_setup/src/preprocess.py             |  35 +++
 .../redis_setup/src/requirements.txt          |   6 +
 .../redis_setup/src/run_server.sh             |  14 +
 .../redis_setup/src/start_cluster.py          |  37 +++
 .../redis_setup/test/test.sh                  |   7 +
 .../redis_setup/test/test_redis_graph.py      |  40 +++
 .../distribute_graphsage/requirements.txt     |   4 +
 examples/distribute_graphsage/train.py        | 263 -----------------
 examples/distribute_graphsage/utils.py        |  55 ++++
 examples/distribute_graphsage/utils.sh        |  20 ++
 23 files changed, 1075 insertions(+), 297 deletions(-)
 create mode 100755 examples/distribute_graphsage/cloud_run.sh
 create mode 100644 examples/distribute_graphsage/cluster_train.py
 create mode 100644 examples/distribute_graphsage/config.yaml
 create mode 100644 examples/distribute_graphsage/job.sh
 create mode 100644 examples/distribute_graphsage/local_config
 create mode 100644 examples/distribute_graphsage/redis_setup/before_hook.sh
 create mode 100644 examples/distribute_graphsage/redis_setup/redis_graph.cfg
 create mode 100644 examples/distribute_graphsage/redis_setup/src/build_graph.py
 create mode 100644 examples/distribute_graphsage/redis_setup/src/dump_data.sh
 create mode 100644 examples/distribute_graphsage/redis_setup/src/gen_redis_conf.py
 create mode 100644 examples/distribute_graphsage/redis_setup/src/preprocess.py
 create mode 100644 examples/distribute_graphsage/redis_setup/src/requirements.txt
 create mode 100644 examples/distribute_graphsage/redis_setup/src/run_server.sh
 create mode 100644 examples/distribute_graphsage/redis_setup/src/start_cluster.py
 create mode 100644 examples/distribute_graphsage/redis_setup/test/test.sh
 create mode 100644 examples/distribute_graphsage/redis_setup/test/test_redis_graph.py
 delete mode 100644 examples/distribute_graphsage/train.py
 create mode 100644 examples/distribute_graphsage/utils.py
 create mode 100644 examples/distribute_graphsage/utils.sh

diff --git a/examples/distribute_graphsage/README.md b/examples/distribute_graphsage/README.md
index 0ce196f..fcfe50a 100644
--- a/examples/distribute_graphsage/README.md
+++ b/examples/distribute_graphsage/README.md
@@ -6,54 +6,32 @@ information (e.g., text attributes) to efficiently generate node embeddings for
-For purpose of high scalability, we use redis as distribute graph storage solution and training graphsage against redis server.
+For the purpose of high scalability, we use Redis as a distributed graph storage backend and train GraphSAGE against the Redis server.
 
-### Datasets(Quickstart)
-The reddit dataset should be downloaded from [reddit_adj.npz](https://drive.google.com/open?id=174vb0Ws7Vxk_QTUtxqTgDHSQ4El4qDHt) and [reddit.npz](https://drive.google.com/open?id=19SphVl_Oe8SJ1r87Hr5a6znx3nJu1F2Jthe). The details for Reddit Dataset can be found [here](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf).
+### Datasets (Quickstart)
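+In this example, the raw reddit dataset is preprocessed into an edge list and a node-feature archive, loaded into a local Redis cluster, and then served to the distributed trainers.
+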
+The reddit dataset should be downloaded from [reddit_adj.npz](https://drive.google.com/open?id=174vb0Ws7Vxk_QTUtxqTgDHSQ4El4qDHt) and [reddit.npz](https://drive.google.com/open?id=19SphVl_Oe8SJ1r87Hr5a6znx3nJu1F2J). The details of the Reddit dataset can be found [here](https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf).
 
-Alternatively, reddit dataset has been preprocessed and packed into docker image, which can be instantly pulled using following commands.
+- reddit.npz: https://drive.google.com/open?id=19SphVl_Oe8SJ1r87Hr5a6znx3nJu1F2J
+- reddit_adj.npz: https://drive.google.com/open?id=174vb0Ws7Vxk_QTUtxqTgDHSQ4El4qDHt
 
-```sh
-docker pull githubutilities/reddit_redis_demo:v0.1
-```
+Download `reddit.npz` and `reddit_adj.npz` into the `data` directory for further preprocessing.
 
 ### Dependencies
 
-```txt
-- paddlepaddle>=1.6
-- pgl
-- scipy
-- redis==2.10.6
-- redis-py-cluster==1.3.6
+```sh
+pip install -r requirements.txt
 ```
 
 ### How to run
 
-#### 1. Start reddit data service
+#### 1. Preprocess the data and start the Reddit data service
 
 ```sh
-docker run \
-    --net=host \
-    -d --rm \
-    --name reddit_demo \
-    -it githubutilities/reddit_redis_demo:v0.1 \
-    /bin/bash -c "/bin/bash ./before_hook.sh && /bin/bash"
-docker logs -f `docker ps -aqf "name=reddit_demo"`
+pushd ./redis_setup
+    /bin/bash ./before_hook.sh
+popd
 ```
 
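+To sanity-check that the Redis cluster is up after this step (optional; `redis-cli` is built under `redis_setup/redis-5.0.5/src`, and 7430 is the first of the default ports used by `run_server.sh`):
+
+```sh
+redis-cli -p 7430 cluster info
+```
+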
-#### 2. training GraphSAGE model
+#### 2. Train the GraphSAGE model
 
 ```sh
-python train.py --use_cuda --epoch 10 --graphsage_type graphsage_mean --sample_workers 10
+sh ./cloud_run.sh
 ```
 
-#### Hyperparameters
-
-- epoch: Number of epochs default (10)
-- use_cuda: Use gpu if assign use_cuda.
-- graphsage_type: We support 4 aggregator types including "graphsage_mean", "graphsage_maxpool", "graphsage_meanpool" and "graphsage_lstm".
-- sample_workers: The number of workers for multiprocessing subgraph sample.
-- lr: Learning rate.
-- batch_size: Batch size.
-- samples_1: The max neighbors for the first hop neighbor sampling. (default: 25)
-- samples_2: The max neighbors for the second hop neighbor sampling. (default: 10)
-- hidden_size: The hidden size of the GraphSAGE models.
-
-
diff --git a/examples/distribute_graphsage/cloud_run.sh b/examples/distribute_graphsage/cloud_run.sh
new file mode 100755
index 0000000..c5b5e45
--- /dev/null
+++ b/examples/distribute_graphsage/cloud_run.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+set -x
+mode=${1}
+
+source ./utils.sh
+unset http_proxy https_proxy
+
+source ./local_config
+if [ ! -d ${log_dir} ]; then
+    mkdir ${log_dir}
+fi
+
+for((i=0;i<${PADDLE_PSERVERS_NUM};i++))
+do
+    echo "start ps server: ${i}"
+    echo $log_dir
+    TRAINING_ROLE="PSERVER" PADDLE_TRAINER_ID=${i} sh job.sh &> $log_dir/pserver.$i.log &
+done
+sleep 10s
+for((j=0;j<${PADDLE_TRAINERS_NUM};j++))
+do
+    echo "start worker: ${j}"
+    TRAINING_ROLE="TRAINER" PADDLE_TRAINER_ID=${j} sh job.sh &> $log_dir/worker.$j.log &
+done
+tail -f $log_dir/worker.0.log
diff --git a/examples/distribute_graphsage/cluster_train.py b/examples/distribute_graphsage/cluster_train.py
new file mode 100644
index 0000000..1ff2695
--- /dev/null
+++ b/examples/distribute_graphsage/cluster_train.py
@@ -0,0 +1,191 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import time
+import os
+import math
+import numpy as np
+
+import paddle.fluid as F
+import paddle.fluid.layers as L
+from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
+from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
+import paddle.fluid.incubate.fleet.base.role_maker as role_maker
+from pgl.utils.logger import log
+
+from model import GraphsageModel
+from utils import load_config
+import reader
+
+
+def init_role():
+    # reset the place according to the role of the parameter server
+    training_role = os.getenv("TRAINING_ROLE", "TRAINER")
+    paddle_role = role_maker.Role.WORKER
+    place = F.CPUPlace()
+    if training_role == "PSERVER":
+        paddle_role = role_maker.Role.SERVER
+
+    # set the fleet runtime environment according to the configuration
+    ports = os.getenv("PADDLE_PORT", "6174").split(",")
+    pserver_ips = os.getenv("PADDLE_PSERVERS").split(",")  # ip,ip...
+    eplist = []
+    if len(ports) > 1:
+        # local debug mode, multi port
+        for port in ports:
+            eplist.append(':'.join([pserver_ips[0], port]))
+    else:
+        # distributed mode, multi ip
+        for ip in pserver_ips:
+            eplist.append(':'.join([ip, ports[0]]))
+
+    pserver_endpoints = eplist  # ip:port,ip:port...
+    worker_num = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
+    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+    role = role_maker.UserDefinedRoleMaker(
+        current_id=trainer_id,
+        role=paddle_role,
+        worker_num=worker_num,
+        server_endpoints=pserver_endpoints)
+    fleet.init(role)
+
+
+def optimization(base_lr, loss, optimizer='adam'):
+    if optimizer == 'sgd':
+        optimizer = F.optimizer.SGD(base_lr)
+    elif optimizer == 'adam':
+        optimizer = F.optimizer.Adam(base_lr, lazy_mode=True)
+    else:
+        raise ValueError("unsupported optimizer: %s" % optimizer)
+
+    log.info('learning rate: %f' % base_lr)
+    # create the DistributeTranspiler config
+    config = DistributeTranspilerConfig()
+    config.sync_mode = False
+    #config.runtime_split_send_recv = False
+
+    config.slice_var_up = False
+    # create the distributed optimizer
+    optimizer = fleet.distributed_optimizer(optimizer, config)
+    optimizer.minimize(loss)
+
+
+def build_compiled_prog(train_program, model_loss):
+    num_threads = int(os.getenv("CPU_NUM", 10))
+    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
+    exec_strategy = F.ExecutionStrategy()
+    exec_strategy.num_threads = num_threads
+    #exec_strategy.use_experimental_executor = True
+    build_strategy = F.BuildStrategy()
+    build_strategy.enable_inplace = True
+    #build_strategy.memory_optimize = True
+    build_strategy.memory_optimize = False
+    build_strategy.remove_unnecessary_lock = False
+    if num_threads > 1:
+        build_strategy.reduce_strategy = F.BuildStrategy.ReduceStrategy.Reduce
+
+    compiled_prog = F.compiler.CompiledProgram(
+        train_program).with_data_parallel(loss_name=model_loss.name)
+    return compiled_prog
+
+
+def fake_py_reader(data_iter, num):
+    def fake_iter():
+        queue = []
+        for idx, data in enumerate(data_iter()):
+            queue.append(data)
+            if len(queue) == num:
+                yield queue
+                queue = []
+        if len(queue) > 0:
+            while len(queue) < num:
+                queue.append(queue[-1])
+            yield queue
+    return fake_iter
+
+
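+# fake_py_reader groups single-worker batches into lists of `num` feed dicts
+# and pads the last incomplete group by repeating its final batch, because
+# the data-parallel CompiledProgram consumes one feed dict per device
+# (CPU_NUM) on every exe.run call in train_prog below.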
+def train_prog(exe, program, model, pyreader, args):
+    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+    start = time.time()
+    batch = 0
+    total_loss = 0.
+    total_acc = 0.
+    total_sample = 0
+    for epoch_idx in range(args.num_epoch):
+        for step, batch_feed_dict in enumerate(pyreader()):
+            try:
+                cpu_time = time.time()
+                batch += 1
+                batch_loss, batch_acc = exe.run(
+                    program,
+                    feed=batch_feed_dict,
+                    fetch_list=[model.loss, model.acc])
+
+                end = time.time()
+                if batch % args.log_per_step == 0:
+                    log.info(
+                        "Batch %s Loss %s Acc %s \t Speed(per batch) %.5lf/%.5lf sec"
+                        % (batch, np.mean(batch_loss), np.mean(batch_acc),
+                           (end - start) / batch, (end - cpu_time)))
+
+                if step % args.steps_per_save == 0:
+                    save_path = args.save_path
+                    if trainer_id == 0:
+                        model_path = os.path.join(save_path, "%s" % step)
+                        fleet.save_persistables(exe, model_path)
+            except Exception as e:
+                log.info("Pyreader train error")
+                log.exception(e)
+
+
+def main(args):
+    log.info("start")
+
+    worker_num = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
+    num_devices = int(os.getenv("CPU_NUM", 10))
+
+    model = GraphsageModel(args)
+    loss = model.forward()
+    train_iter = reader.get_iter(args, model.graph_wrapper, 'train')
+    pyreader = fake_py_reader(train_iter, num_devices)
+
+    # init fleet
+    init_role()
+
+    optimization(args.lr, loss, args.optimizer)
+
+    # init and run server or worker
+    if fleet.is_server():
+        fleet.init_server(args.warm_start_from_dir)
+        fleet.run_server()
+
+    if fleet.is_worker():
+        log.info("start init worker")
+        fleet.init_worker()
+        # only the worker loads the training samples
+        log.info("init worker done")
+
+        exe = F.Executor(F.CPUPlace())
+        exe.run(fleet.startup_program)
+        log.info("Startup done")
+
+        compiled_prog = build_compiled_prog(fleet.main_program, loss)
+        train_prog(exe, compiled_prog, model, pyreader, args)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='graphsage')
+    parser.add_argument("-c", "--config", type=str, default="./config.yaml")
+    args = parser.parse_args()
+    config = load_config(args.config)
+    log.info(config)
+    main(config)
+
diff --git a/examples/distribute_graphsage/config.yaml b/examples/distribute_graphsage/config.yaml
new file mode 100644
index 0000000..915ef18
--- /dev/null
+++ b/examples/distribute_graphsage/config.yaml
@@ -0,0 +1,19 @@
+# model config
+hidden_size: 128
+num_class: 41
+samples: [25, 10]
+graphsage_type: "graphsage_mean"
+
+# training config
+num_epoch: 10
+batch_size: 128
+num_sample_workers: 10
+optimizer: "adam"
+lr: 0.01
+warm_start_from_dir: null
+steps_per_save: 1000
+log_per_step: 1
+save_path: "./checkpoints"
+log_dir: "./logs"
+CPU_NUM: 1
+
diff --git a/examples/distribute_graphsage/job.sh b/examples/distribute_graphsage/job.sh
new file mode 100644
index 0000000..8b9ee4d
--- /dev/null
+++ b/examples/distribute_graphsage/job.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+set -x
+source ./utils.sh
+
+export CPU_NUM=$CPU_NUM
+export FLAGS_rpc_deadline=3000000
+
+export FLAGS_communicator_send_queue_size=1
+export FLAGS_communicator_min_send_grad_num_before_recv=0
+export FLAGS_communicator_max_merge_var_num=1
+export FLAGS_communicator_merge_sparse_grad=0
+
+python -u cluster_train.py -c config.yaml
diff --git a/examples/distribute_graphsage/local_config b/examples/distribute_graphsage/local_config
new file mode 100644
index 0000000..0e8ca14
--- /dev/null
+++ b/examples/distribute_graphsage/local_config
@@ -0,0 +1,6 @@
+#!/bin/bash
+export PADDLE_TRAINERS_NUM=2
+export PADDLE_PSERVERS_NUM=2
+export PADDLE_PORT=6184,6185
+export PADDLE_PSERVERS="127.0.0.1"
+
diff --git a/examples/distribute_graphsage/model.py b/examples/distribute_graphsage/model.py
index 145a979..7f5eb99 100644
--- a/examples/distribute_graphsage/model.py
+++ b/examples/distribute_graphsage/model.py
@@ -11,10 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+    graphsage model.
+"""
+from __future__ import division
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import unicode_literals
+import math
+
+import pgl
+import numpy as np
 import paddle
+import paddle.fluid.layers as L
+import paddle.fluid as F
 import paddle.fluid as fluid
 
-
 def copy_send(src_feat, dst_feat, edge_feat):
     return src_feat["h"]
 
@@ -128,3 +140,87 @@ def graphsage_lstm(gw, feature, hidden_size, act, name):
     output = fluid.layers.concat([self_feature, neigh_feature], axis=1)
     output = fluid.layers.l2_normalize(output, axis=1)
     return output
+
+
+def build_graph_model(graph_wrapper, num_class, k_hop, graphsage_type,
+                      hidden_size):
+    node_index = fluid.layers.data(
+        "node_index", shape=[None], dtype="int64", append_batch_size=False)
+
+    node_label = fluid.layers.data(
+        "node_label", shape=[None, 1], dtype="int64", append_batch_size=False)
+
+    #feature = fluid.layers.gather(feature, graph_wrapper.node_feat['feats'])
+    feature = graph_wrapper.node_feat['feats']
+    feature.stop_gradient = True
+
+    for i in range(k_hop):
+        if graphsage_type == 'graphsage_mean':
+            feature = graphsage_mean(
+                graph_wrapper,
+                feature,
+                hidden_size,
+                act="relu",
+                name="graphsage_mean_%s" % i)
+        elif graphsage_type == 'graphsage_meanpool':
+            feature = graphsage_meanpool(
+                graph_wrapper,
+                feature,
+                hidden_size,
+                act="relu",
+                name="graphsage_meanpool_%s" % i)
+        elif graphsage_type == 'graphsage_maxpool':
+            feature = graphsage_maxpool(
+                graph_wrapper,
+                feature,
+                hidden_size,
+                act="relu",
+                name="graphsage_maxpool_%s" % i)
+        elif graphsage_type == 'graphsage_lstm':
+            feature = graphsage_lstm(
+                graph_wrapper,
+                feature,
+                hidden_size,
+                act="relu",
+                name="graphsage_lstm_%s" % i)
+        else:
+            raise ValueError("graphsage type %s is not"
+                             " implemented" % graphsage_type)
+
+    feature = fluid.layers.gather(feature, node_index)
+    logits = fluid.layers.fc(feature,
+                             num_class,
+                             act=None,
+                             name='classification_layer')
+    proba = fluid.layers.softmax(logits)
+
+    loss = fluid.layers.softmax_with_cross_entropy(
+        logits=logits, label=node_label)
+    loss = fluid.layers.mean(loss)
+    acc = fluid.layers.accuracy(input=proba, label=node_label, k=1)
+    return loss, acc
+
+
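+# GraphsageModel assembles the supervised GraphSAGE program: a pgl
+# GraphWrapper named "sub_graph" declares the subgraph placeholders
+# (including the 602-dim "feats" node feature), which reader.py feeds
+# with sampled subgraphs at every step.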
+class GraphsageModel(object):
+    def __init__(self, args):
+        self.args = args
+
+    def forward(self):
+        args = self.args
+
+        graph_wrapper = pgl.graph_wrapper.GraphWrapper(
+            "sub_graph",
+            node_feat=[('feats', [None, 602], np.dtype('float32'))])
+        loss, acc = build_graph_model(
+            graph_wrapper,
+            num_class=args.num_class,
+            hidden_size=args.hidden_size,
+            graphsage_type=args.graphsage_type,
+            k_hop=len(args.samples))
+
+        loss.persistable = True
+
+        self.graph_wrapper = graph_wrapper
+        self.loss = loss
+        self.acc = acc
+        return loss
+
diff --git a/examples/distribute_graphsage/reader.py b/examples/distribute_graphsage/reader.py
index 6617b6b..88556d3 100644
--- a/examples/distribute_graphsage/reader.py
+++ b/examples/distribute_graphsage/reader.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+import sys
 import numpy as np
 import pickle as pkl
 import paddle
@@ -147,3 +149,48 @@ def multiprocess_graph_reader(
 
     return reader()
 
+
+def load_data():
+    """
+    data from https://github.com/matenure/FastGCN/issues/8
+    reddit.npz: https://drive.google.com/open?id=19SphVl_Oe8SJ1r87Hr5a6znx3nJu1F2J
+    reddit_index_label is preprocessed from reddit.npz without the feats key.
+    """
+    data_dir = os.path.dirname(os.path.abspath(__file__))
+    data = np.load(os.path.join(data_dir, "data/reddit_index_label.npz"))
+
+    num_class = 41
+
+    train_label = data['y_train']
+    val_label = data['y_val']
+    test_label = data['y_test']
+
+    train_index = data['train_index']
+    val_index = data['val_index']
+    test_index = data['test_index']
+
+    return {
+        "train_index": train_index,
+        "train_label": train_label,
+        "val_label": val_label,
+        "val_index": val_index,
+        "test_index": test_index,
+        "test_label": test_label,
+        "num_class": num_class
+    }
+
+
+def get_iter(args, graph_wrapper, mode):
+    # mode selects the node split to iterate over: train, val or test
+    data = load_data()
+    batch_iter = multiprocess_graph_reader(
+        graph_wrapper,
+        samples=args.samples,
+        num_workers=args.num_sample_workers,
+        batch_size=args.batch_size,
+        node_index=data['%s_index' % mode],
+        node_label=data['%s_label' % mode])
+    return batch_iter
+
+
+if __name__ == '__main__':
+    data = load_data()
+    print({k: v for k, v in data.items()})
+
diff --git a/examples/distribute_graphsage/redis_setup/before_hook.sh b/examples/distribute_graphsage/redis_setup/before_hook.sh
new file mode 100644
index 0000000..1b5c101
--- /dev/null
+++ b/examples/distribute_graphsage/redis_setup/before_hook.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+set -x
+
+srcdir=./src
+
+# Data preprocessing
+python ./src/preprocess.py
+
+# Download and compile redis
+export PATH=$PWD/redis-5.0.5/src:$PATH
+if [ ! -f ./redis.tar.gz ]; then
+    curl https://codeload.github.com/antirez/redis/tar.gz/5.0.5 -o ./redis.tar.gz
+fi
+tar -xzf ./redis.tar.gz
+cd ./redis-5.0.5/
+make
+cd -
+
+# Install python deps
+python -m pip install -U pip
+pip install -r ./src/requirements.txt -U
+
+# Run redis servers
+sh ./src/run_server.sh
+
+# Dump data into redis
+source ./redis_graph.cfg
+sh ./src/dump_data.sh $edge_path $server_list $num_nodes $node_feat_path
+
+exit 0
+
diff --git a/examples/distribute_graphsage/redis_setup/redis_graph.cfg b/examples/distribute_graphsage/redis_setup/redis_graph.cfg
new file mode 100644
index 0000000..382ca17
--- /dev/null
+++ b/examples/distribute_graphsage/redis_setup/redis_graph.cfg
@@ -0,0 +1,6 @@
+# dump config
+edge_path=../data/edge.txt
+node_feat_path=../data/feats.npz
+num_nodes=232965
+server_list=./server.list
+
diff --git a/examples/distribute_graphsage/redis_setup/src/build_graph.py b/examples/distribute_graphsage/redis_setup/src/build_graph.py
new file mode 100644
index 0000000..9fb1c65
--- /dev/null
+++ b/examples/distribute_graphsage/redis_setup/src/build_graph.py
@@ -0,0 +1,275 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import json +import logging +from collections import defaultdict +import tqdm +import redis +from redis._compat import b, unicode, bytes, long, basestring +from rediscluster.nodemanager import NodeManager +from rediscluster.crc import crc16 +import argparse +import time +import pickle +import numpy as np +import scipy.sparse as sp + +log = logging.getLogger(__name__) +root = logging.getLogger() +root.setLevel(logging.DEBUG) + +handler = logging.StreamHandler(sys.stdout) +handler.setLevel(logging.DEBUG) +formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') +handler.setFormatter(formatter) +root.addHandler(handler) + + +def encode(value): + """ + Return a bytestring representation of the value. + This method is copied from Redis' connection.py:Connection.encode + """ + if isinstance(value, bytes): + return value + elif isinstance(value, (int, long)): + value = b(str(value)) + elif isinstance(value, float): + value = b(repr(value)) + elif not isinstance(value, basestring): + value = unicode(value) + if isinstance(value, unicode): + value = value.encode('utf-8') + return value + + +def crc16_hash(data): + return crc16(encode(data)) + + +def get_redis(startup_host, startup_port): + startup_nodes = [{"host": startup_host, "port": startup_port}, ] + nodemanager = NodeManager(startup_nodes=startup_nodes) + nodemanager.initialize() + rs = {} + for node, config in nodemanager.nodes.items(): + rs[node] = redis.Redis( + host=config["host"], port=config["port"], decode_responses=False) + return rs, nodemanager + + +def load_data(edge_path): + src, dst = [], [] + with open(edge_path, "r") as f: + for i in tqdm.tqdm(f): + s, d, _ = i.split() + s = int(s) + d = int(d) + src.append(s) + dst.append(d) + dst.append(s) + src.append(d) + src = np.array(src, dtype="int64") + dst = np.array(dst, dtype="int64") + return src, dst + + +def build_edge_index(edge_path, num_nodes, startup_host, startup_port, + num_bucket): + #src, dst = load_data(edge_path) + rs, nodemanager = get_redis(startup_host, startup_port) + + dst_mp, edge_mp = defaultdict(list), defaultdict(list) + with open(edge_path) as f: + for l in tqdm.tqdm(f): + a, b, idx = l.rstrip().split('\t') + a, b, idx = int(a), int(b), int(idx) + dst_mp[a].append(b) + edge_mp[a].append(idx) + part_dst_dicts = {} + for i in tqdm.tqdm(range(num_nodes)): + #if len(edge_index.v[i]) == 0: + # continue + #v = edge_index.v[i].astype("int64").reshape([-1, 1]) + #e = edge_index.eid[i].astype("int64").reshape([-1, 1]) + if i not in dst_mp: + continue + v = np.array(dst_mp[i]).astype('int64').reshape([-1, 1]) + e = np.array(edge_mp[i]).astype('int64').reshape([-1, 1]) + o = np.hstack([v, e]) + key = "d:%s" % i + part = crc16_hash(key) % num_bucket + if part not in part_dst_dicts: + part_dst_dicts[part] = {} + dst_dicts = part_dst_dicts[part] + dst_dicts["d:%s" % i] = o.tobytes() + if len(dst_dicts) > 10000: + slot = nodemanager.keyslot("part-%s" % part) + node = nodemanager.slots[slot][0]['name'] + while True: + res = rs[node].hmset("part-%s" % part, dst_dicts) + if res: + break + log.info("HMSET FAILED RETRY connected %s" % node) + time.sleep(1) + part_dst_dicts[part] = {} + + for part, dst_dicts in part_dst_dicts.items(): + if len(dst_dicts) > 0: + slot = nodemanager.keyslot("part-%s" % part) + node = nodemanager.slots[slot][0]['name'] + while True: + res = rs[node].hmset("part-%s" % part, dst_dicts) + if res: + break + 
log.info("HMSET FAILED RETRY connected %s" % node) + time.sleep(1) + part_dst_dicts[part] = {} + log.info("dst_dict Done") + + +def build_edge_id(edge_path, num_nodes, startup_host, startup_port, + num_bucket): + src, dst = load_data(edge_path) + rs, nodemanager = get_redis(startup_host, startup_port) + part_edge_dict = {} + for i in tqdm.tqdm(range(len(src))): + key = "e:%s" % i + part = crc16_hash(key) % num_bucket + if part not in part_edge_dict: + part_edge_dict[part] = {} + edge_dict = part_edge_dict[part] + edge_dict["e:%s" % i] = int(src[i]) * num_nodes + int(dst[i]) + if len(edge_dict) > 10000: + slot = nodemanager.keyslot("part-%s" % part) + node = nodemanager.slots[slot][0]['name'] + while True: + res = rs[node].hmset("part-%s" % part, edge_dict) + if res: + break + log.info("HMSET FAILED RETRY connected %s" % node) + time.sleep(1) + + part_edge_dict[part] = {} + + for part, edge_dict in part_edge_dict.items(): + if len(edge_dict) > 0: + slot = nodemanager.keyslot("part-%s" % part) + node = nodemanager.slots[slot][0]['name'] + while True: + res = rs[node].hmset("part-%s" % part, edge_dict) + if res: + break + log.info("HMSET FAILED RETRY connected %s" % node) + time.sleep(1) + part_edge_dict[part] = {} + + +def build_infos(edge_path, num_nodes, startup_host, startup_port, num_bucket): + src, dst = load_data(edge_path) + rs, nodemanager = get_redis(startup_host, startup_port) + slot = nodemanager.keyslot("num_nodes") + node = nodemanager.slots[slot][0]['name'] + res = rs[node].set("num_nodes", num_nodes) + + slot = nodemanager.keyslot("num_edges") + node = nodemanager.slots[slot][0]['name'] + rs[node].set("num_edges", len(src)) + + slot = nodemanager.keyslot("nf:infos") + node = nodemanager.slots[slot][0]['name'] + rs[node].set("nf:infos", json.dumps([['feats', [-1, 602], 'float32'], ])) + + slot = nodemanager.keyslot("ef:infos") + node = nodemanager.slots[slot][0]['name'] + rs[node].set("ef:infos", json.dumps([])) + + +def build_node_feat(node_feat_path, num_nodes, startup_host, startup_port, num_bucket): + assert node_feat_path != "", "node_feat_path empty!" 
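+    # Storage layout: each node's feature row lives under "nf:<name>:<id>".
+    # Keys are hashed into num_bucket "part-<p>" buckets with crc16 and
+    # written via batched HMSET calls (retried until they succeed), mirroring
+    # how build_edge_index stores adjacency lists above.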
+    feat_dict = np.load(node_feat_path)
+    for k in feat_dict.keys():
+        feat = feat_dict[k]
+        assert feat.shape[0] == num_nodes, "num_nodes invalid"
+
+    rs, nodemanager = get_redis(startup_host, startup_port)
+    part_feat_dict = {}
+    for k in feat_dict.keys():
+        feat = feat_dict[k]
+        for i in tqdm.tqdm(range(num_nodes)):
+            key = "nf:%s:%i" % (k, i)
+            value = feat[i].tobytes()
+            part = crc16_hash(key) % num_bucket
+            if part not in part_feat_dict:
+                part_feat_dict[part] = {}
+            part_feat = part_feat_dict[part]
+            part_feat[key] = value
+            if len(part_feat) > 100:
+                slot = nodemanager.keyslot("part-%s" % part)
+                node = nodemanager.slots[slot][0]['name']
+                while True:
+                    res = rs[node].hmset("part-%s" % part, part_feat)
+                    if res:
+                        break
+                    log.info("HMSET FAILED RETRY connected %s" % node)
+                    time.sleep(1)
+                part_feat_dict[part] = {}
+
+    for part, part_feat in part_feat_dict.items():
+        if len(part_feat) > 0:
+            slot = nodemanager.keyslot("part-%s" % part)
+            node = nodemanager.slots[slot][0]['name']
+            while True:
+                res = rs[node].hmset("part-%s" % part, part_feat)
+                if res:
+                    break
+                log.info("HMSET FAILED RETRY connected %s" % node)
+                time.sleep(1)
+            part_feat_dict[part] = {}
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='build_graph')
+    parser.add_argument('--startup_port', type=int, required=True)
+    parser.add_argument('--startup_host', type=str, required=True)
+    parser.add_argument('--edge_path', type=str, default="")
+    parser.add_argument('--node_feat_path', type=str, default="")
+    parser.add_argument('--num_nodes', type=int, default=0)
+    parser.add_argument('--num_bucket', type=int, default=64)
+    parser.add_argument(
+        '--mode',
+        type=str,
+        required=True,
+        help="choose one of the following modes (edge_index, edge_id, graph_attr, node_feat)"
+    )
+    args = parser.parse_args()
+    log.info("Mode: {}".format(args.mode))
+    if args.mode == 'edge_index':
+        build_edge_index(args.edge_path, args.num_nodes, args.startup_host,
+                         args.startup_port, args.num_bucket)
+    elif args.mode == 'edge_id':
+        build_edge_id(args.edge_path, args.num_nodes, args.startup_host,
+                      args.startup_port, args.num_bucket)
+    elif args.mode == 'graph_attr':
+        build_infos(args.edge_path, args.num_nodes, args.startup_host,
+                    args.startup_port, args.num_bucket)
+    elif args.mode == 'node_feat':
+        build_node_feat(args.node_feat_path, args.num_nodes,
+                        args.startup_host, args.startup_port,
+                        args.num_bucket)
+    else:
+        raise ValueError("%s mode not found" % args.mode)
+
diff --git a/examples/distribute_graphsage/redis_setup/src/dump_data.sh b/examples/distribute_graphsage/redis_setup/src/dump_data.sh
new file mode 100644
index 0000000..052064b
--- /dev/null
+++ b/examples/distribute_graphsage/redis_setup/src/dump_data.sh
@@ -0,0 +1,63 @@
+filter(){
+    lines=`cat $1`
+    rm $1
+    for line in $lines; do
+        remote_host=`echo $line | cut -d":" -f1`
+        remote_port=`echo $line | cut -d":" -f2`
+        nc -z $remote_host $remote_port
+        if [[ $? == 0 ]]; then
+            echo $line >> $1
+        fi
+    done
+}
+
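+# dump_data: create the Redis cluster from the servers that are still
+# reachable, then load node features, the edge index, and global graph
+# attributes through build_graph.py.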
+dump_data(){
+    filter $server_list
+
+    python ./src/start_cluster.py --server_list $server_list --replicas 0
+
+    address=`head -n 1 $server_list`
+
+    ip=`echo $address | cut -d":" -f1`
+    port=`echo $address | cut -d":" -f2`
+
+    # build node feature
+    python ./src/build_graph.py --startup_host $ip \
+        --startup_port $port \
+        --mode node_feat \
+        --node_feat_path $feat_fn \
+        --num_nodes $num_nodes
+
+    # build edge index
+    python ./src/build_graph.py --startup_host $ip \
+        --startup_port $port \
+        --mode edge_index \
+        --edge_path $edge_path \
+        --num_nodes $num_nodes
+
+    # build edge id
+    #python ./src/build_graph.py --startup_host $ip \
+    #    --startup_port $port \
+    #    --mode edge_id \
+    #    --edge_path $edge_path \
+    #    --num_nodes $num_nodes
+
+    # build graph attr
+    python ./src/build_graph.py --startup_host $ip \
+        --startup_port $port \
+        --mode graph_attr \
+        --edge_path $edge_path \
+        --num_nodes $num_nodes
+
+}
+
+if [ $# -ne 4 ]; then
+    echo 'usage: sh dump_data.sh edge_path server_list num_nodes feat_fn'
+    exit
+fi
+num_nodes=$3
+server_list=$2
+edge_path=$1
+feat_fn=$4
+
+dump_data
+
diff --git a/examples/distribute_graphsage/redis_setup/src/gen_redis_conf.py b/examples/distribute_graphsage/redis_setup/src/gen_redis_conf.py
new file mode 100644
index 0000000..5ded1f3
--- /dev/null
+++ b/examples/distribute_graphsage/redis_setup/src/gen_redis_conf.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import sys +import socket +import argparse +import os +temp = """port %s +bind %s +daemonize yes +pidfile /var/run/redis_%s.pid +cluster-enabled yes +cluster-config-file nodes.conf +cluster-node-timeout 50000 +logfile "redis.log" +appendonly yes""" + + +def gen_config(ports): + if len(ports) == 0: + raise ValueError("No ports") + ip = socket.gethostbyname(socket.gethostname()) + print("Generate redis conf") + for port in ports: + try: + os.mkdir("%s" % port) + except: + print("port %s directory already exists" % port) + pass + with open("%s/redis.conf" % port, 'w') as f: + f.write(temp % (port, ip, port)) + + print("Generate Start Server Scripts") + with open("start_server.sh", "w") as f: + f.write("set -x\n") + for ind, port in enumerate(ports): + f.write("# %s %s start\n" % (ip, port)) + if ind > 0: + f.write("cd ..\n") + f.write("cd %s\n" % port) + f.write("redis-server redis.conf\n") + f.write("\n") + + print("Generate Stop Server Scripts") + with open("stop_server.sh", "w") as f: + f.write("set -x\n") + for ind, port in enumerate(ports): + f.write("# %s %s shutdown\n" % (ip, port)) + f.write("redis-cli -h %s -p %s shutdown\n" % (ip, port)) + f.write("\n") + + with open("server.list", "w") as f: + for ind, port in enumerate(ports): + f.write("%s:%s\n" % (ip, port)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='gen_redis_conf') + parser.add_argument('--ports', nargs='+', type=int, default=[]) + args = parser.parse_args() + gen_config(args.ports) diff --git a/examples/distribute_graphsage/redis_setup/src/preprocess.py b/examples/distribute_graphsage/redis_setup/src/preprocess.py new file mode 100644 index 0000000..42c641c --- /dev/null +++ b/examples/distribute_graphsage/redis_setup/src/preprocess.py @@ -0,0 +1,35 @@ +import os +import sys + +import numpy as np +import scipy.sparse as sp + +def _load_config(fn): + ret = {} + with open(fn) as f: + for l in f: + if l.strip() == '' or l.startswith('#'): + continue + k, v = l.strip().split('=') + ret[k] = v + return ret + +def _prepro(config): + data = np.load("../data/reddit.npz") + adj = sp.load_npz("../data/reddit_adj.npz") + adj = adj.tocoo() + src = adj.row + dst = adj.col + + with open(config['edge_path'], 'w') as f: + for idx, e in enumerate(zip(src, dst)): + s, d = e + l = "{}\t{}\t{}\n".format(s, d, idx) + f.write(l) + feats = data['feats'].astype(np.float32) + np.savez(config['node_feat_path'], feats=feats) + +if __name__ == '__main__': + config = _load_config('./redis_graph.cfg') + _prepro(config) + diff --git a/examples/distribute_graphsage/redis_setup/src/requirements.txt b/examples/distribute_graphsage/redis_setup/src/requirements.txt new file mode 100644 index 0000000..2f955a8 --- /dev/null +++ b/examples/distribute_graphsage/redis_setup/src/requirements.txt @@ -0,0 +1,6 @@ +numpy +scipy +tqdm +redis==2.10.6 +redis-py-cluster==1.3.6 + diff --git a/examples/distribute_graphsage/redis_setup/src/run_server.sh b/examples/distribute_graphsage/redis_setup/src/run_server.sh new file mode 100644 index 0000000..e1268ad --- /dev/null +++ b/examples/distribute_graphsage/redis_setup/src/run_server.sh @@ -0,0 +1,14 @@ +start_server(){ + ports="" + for i in {7430..7439}; do + nc -z localhost $i + if [[ $? 
!= 0 ]]; then
+            ports="$ports $i"
+        fi
+    done
+    python ./src/gen_redis_conf.py --ports $ports
+    bash ./start_server.sh  # start the redis servers
+}
+
+start_server
+
diff --git a/examples/distribute_graphsage/redis_setup/src/start_cluster.py b/examples/distribute_graphsage/redis_setup/src/start_cluster.py
new file mode 100644
index 0000000..570765d
--- /dev/null
+++ b/examples/distribute_graphsage/redis_setup/src/start_cluster.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+
+
+def build_clusters(server_list, replicas):
+    servers = []
+    with open(server_list) as f:
+        for line in f:
+            servers.append(line.strip())
+    cmd = "echo yes | redis-cli --cluster create"
+    for server in servers:
+        cmd += ' %s ' % server
+    cmd += '--cluster-replicas %s' % replicas
+    print(cmd)
+    os.system(cmd)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='start_cluster')
+    parser.add_argument('--server_list', type=str, required=True)
+    parser.add_argument('--replicas', type=int, default=0)
+    args = parser.parse_args()
+    build_clusters(args.server_list, args.replicas)
diff --git a/examples/distribute_graphsage/redis_setup/test/test.sh b/examples/distribute_graphsage/redis_setup/test/test.sh
new file mode 100644
index 0000000..ec8695b
--- /dev/null
+++ b/examples/distribute_graphsage/redis_setup/test/test.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+source ./redis_graph.cfg
+
+url=`head -n1 $server_list`
+shuf $edge_path | head -n 1000 | python ./test/test_redis_graph.py $url
+
diff --git a/examples/distribute_graphsage/redis_setup/test/test_redis_graph.py b/examples/distribute_graphsage/redis_setup/test/test_redis_graph.py
new file mode 100644
index 0000000..f4a3a25
--- /dev/null
+++ b/examples/distribute_graphsage/redis_setup/test/test_redis_graph.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+########################################################################
+#
+# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
+#
+# File: test_redis_graph.py
+# Author: suweiyue(suweiyue@baidu.com)
+# Date: 2019/08/19 16:28:18
+#
+########################################################################
+"""
+    Sanity check: predecessors sampled from the Redis graph must contain
+    each edge's ground-truth neighbor.
+""" +from __future__ import division +from __future__ import absolute_import +from __future__ import print_function +from __future__ import unicode_literals + +import sys + +import numpy as np +import tqdm +from pgl.redis_graph import RedisGraph + +if __name__ == '__main__': + host, port = sys.argv[1].split(':') + port = int(port) + redis_configs = [{"host": host, "port": port}, ] + graph = RedisGraph("reddit-graph", redis_configs, num_parts=64) + #nodes = np.arange(0, 100) + #for i in range(0, 100): + for l in tqdm.tqdm(sys.stdin): + l_sp = l.rstrip().split('\t') + if len(l_sp) != 2: + continue + i, j = int(l_sp[0]), int(l_sp[1]) + nodes = graph.sample_predecessor(np.array([i]), 10000) + assert j in nodes + diff --git a/examples/distribute_graphsage/requirements.txt b/examples/distribute_graphsage/requirements.txt index 7bda67a..e0d28f3 100644 --- a/examples/distribute_graphsage/requirements.txt +++ b/examples/distribute_graphsage/requirements.txt @@ -1,3 +1,7 @@ +pgl==1.1.0 +pyyaml +paddlepaddle==1.6.1 + scipy redis==2.10.6 redis-py-cluster==1.3.6 diff --git a/examples/distribute_graphsage/train.py b/examples/distribute_graphsage/train.py deleted file mode 100644 index 4faafdd..0000000 --- a/examples/distribute_graphsage/train.py +++ /dev/null @@ -1,263 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import argparse -import time - -import numpy as np -import scipy.sparse as sp -from sklearn.preprocessing import StandardScaler - -import pgl -from pgl.utils.logger import log -from pgl.utils import paddle_helper -import paddle -import paddle.fluid as fluid -import reader -from model import graphsage_mean, graphsage_meanpool,\ - graphsage_maxpool, graphsage_lstm - - -def load_data(): - """ - data from https://github.com/matenure/FastGCN/issues/8 - reddit.npz: https://drive.google.com/open?id=19SphVl_Oe8SJ1r87Hr5a6znx3nJu1F2J - reddit_index_label is preprocess from reddit.npz without feats key. 
- """ - data_dir = os.path.dirname(os.path.abspath(__file__)) - data = np.load(os.path.join(data_dir, "data/reddit_index_label.npz")) - - num_class = 41 - - train_label = data['y_train'] - val_label = data['y_val'] - test_label = data['y_test'] - - train_index = data['train_index'] - val_index = data['val_index'] - test_index = data['test_index'] - - return { - "train_index": train_index, - "train_label": train_label, - "val_label": val_label, - "val_index": val_index, - "test_index": test_index, - "test_label": test_label, - "num_class": 41 - } - - -def build_graph_model(graph_wrapper, num_class, k_hop, graphsage_type, - hidden_size): - node_index = fluid.layers.data( - "node_index", shape=[None], dtype="int64", append_batch_size=False) - - node_label = fluid.layers.data( - "node_label", shape=[None, 1], dtype="int64", append_batch_size=False) - - #feature = fluid.layers.gather(feature, graph_wrapper.node_feat['feats']) - feature = graph_wrapper.node_feat['feats'] - feature.stop_gradient = True - - for i in range(k_hop): - if graphsage_type == 'graphsage_mean': - feature = graphsage_mean( - graph_wrapper, - feature, - hidden_size, - act="relu", - name="graphsage_mean_%s" % i) - elif graphsage_type == 'graphsage_meanpool': - feature = graphsage_meanpool( - graph_wrapper, - feature, - hidden_size, - act="relu", - name="graphsage_meanpool_%s" % i) - elif graphsage_type == 'graphsage_maxpool': - feature = graphsage_maxpool( - graph_wrapper, - feature, - hidden_size, - act="relu", - name="graphsage_maxpool_%s" % i) - elif graphsage_type == 'graphsage_lstm': - feature = graphsage_lstm( - graph_wrapper, - feature, - hidden_size, - act="relu", - name="graphsage_maxpool_%s" % i) - else: - raise ValueError("graphsage type %s is not" - " implemented" % graphsage_type) - - feature = fluid.layers.gather(feature, node_index) - logits = fluid.layers.fc(feature, - num_class, - act=None, - name='classification_layer') - proba = fluid.layers.softmax(logits) - - loss = fluid.layers.softmax_with_cross_entropy( - logits=logits, label=node_label) - loss = fluid.layers.mean(loss) - acc = fluid.layers.accuracy(input=proba, label=node_label, k=1) - return loss, acc - - -def run_epoch(batch_iter, - exe, - program, - prefix, - model_loss, - model_acc, - epoch, - log_per_step=100): - batch = 0 - total_loss = 0. - total_acc = 0. 
- total_sample = 0 - start = time.time() - for batch_feed_dict in batch_iter(): - batch += 1 - batch_loss, batch_acc = exe.run(program, - fetch_list=[model_loss, model_acc], - feed=batch_feed_dict) - - if batch % log_per_step == 0: - log.info("Batch %s %s-Loss %s %s-Acc %s" % - (batch, prefix, batch_loss, prefix, batch_acc)) - - num_samples = len(batch_feed_dict["node_index"]) - total_loss += batch_loss * num_samples - total_acc += batch_acc * num_samples - total_sample += num_samples - end = time.time() - - log.info("%s Epoch %s Loss %.5lf Acc %.5lf Speed(per batch) %.5lf sec" % - (prefix, epoch, total_loss / total_sample, - total_acc / total_sample, (end - start) / batch)) - - -def main(args): - data = load_data() - log.info("preprocess finish") - log.info("Train Examples: %s" % len(data["train_index"])) - log.info("Val Examples: %s" % len(data["val_index"])) - log.info("Test Examples: %s" % len(data["test_index"])) - - place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace() - train_program = fluid.Program() - startup_program = fluid.Program() - samples = [] - if args.samples_1 > 0: - samples.append(args.samples_1) - if args.samples_2 > 0: - samples.append(args.samples_2) - - with fluid.program_guard(train_program, startup_program): - graph_wrapper = pgl.graph_wrapper.GraphWrapper( - "sub_graph", node_feat=[('feats', [None, 602], np.dtype('float32'))]) - model_loss, model_acc = build_graph_model( - graph_wrapper, - num_class=data["num_class"], - hidden_size=args.hidden_size, - graphsage_type=args.graphsage_type, - k_hop=len(samples)) - - test_program = train_program.clone(for_test=True) - - with fluid.program_guard(train_program, startup_program): - adam = fluid.optimizer.Adam(learning_rate=args.lr) - adam.minimize(model_loss) - - exe = fluid.Executor(place) - exe.run(startup_program) - - train_iter = reader.multiprocess_graph_reader( - graph_wrapper, - samples=samples, - num_workers=args.sample_workers, - batch_size=args.batch_size, - node_index=data['train_index'], - node_label=data["train_label"]) - - val_iter = reader.multiprocess_graph_reader( - graph_wrapper, - samples=samples, - num_workers=args.sample_workers, - batch_size=args.batch_size, - node_index=data['val_index'], - node_label=data["val_label"]) - - test_iter = reader.multiprocess_graph_reader( - graph_wrapper, - samples=samples, - num_workers=args.sample_workers, - batch_size=args.batch_size, - node_index=data['test_index'], - node_label=data["test_label"]) - - for epoch in range(args.epoch): - run_epoch( - train_iter, - program=train_program, - exe=exe, - prefix="train", - model_loss=model_loss, - model_acc=model_acc, - log_per_step=1, - epoch=epoch) - - run_epoch( - val_iter, - program=test_program, - exe=exe, - prefix="val", - model_loss=model_loss, - model_acc=model_acc, - log_per_step=10000, - epoch=epoch) - - run_epoch( - test_iter, - program=test_program, - prefix="test", - exe=exe, - model_loss=model_loss, - model_acc=model_acc, - log_per_step=10000, - epoch=epoch) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='graphsage') - parser.add_argument("--use_cuda", action='store_true', help="use_cuda") - parser.add_argument( - "--normalize", action='store_true', help="normalize features") - parser.add_argument( - "--symmetry", action='store_true', help="undirect graph") - parser.add_argument("--graphsage_type", type=str, default="graphsage_mean") - parser.add_argument("--sample_workers", type=int, default=10) - parser.add_argument("--epoch", type=int, default=10) - 
parser.add_argument("--hidden_size", type=int, default=128)
-    parser.add_argument("--batch_size", type=int, default=128)
-    parser.add_argument("--lr", type=float, default=0.01)
-    parser.add_argument("--samples_1", type=int, default=25)
-    parser.add_argument("--samples_2", type=int, default=10)
-    args = parser.parse_args()
-    log.info(args)
-    main(args)
diff --git a/examples/distribute_graphsage/utils.py b/examples/distribute_graphsage/utils.py
new file mode 100644
index 0000000..f5810f7
--- /dev/null
+++ b/examples/distribute_graphsage/utils.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Implementation of some helper functions"""
+
+from __future__ import division
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import os
+import time
+import yaml
+import numpy as np
+
+from pgl.utils.logger import log
+
+
+class AttrDict(dict):
+    """Attr dict
+    """
+
+    def __init__(self, d):
+        self.dict = d
+
+    def __getattr__(self, attr):
+        value = self.dict[attr]
+        if isinstance(value, dict):
+            return AttrDict(value)
+        else:
+            return value
+
+    def __str__(self):
+        return str(self.dict)
+
+
+def load_config(config_file):
+    """Load config file"""
+    with open(config_file) as f:
+        if hasattr(yaml, 'FullLoader'):
+            config = yaml.load(f, Loader=yaml.FullLoader)
+        else:
+            config = yaml.load(f)
+
+    return AttrDict(config)
diff --git a/examples/distribute_graphsage/utils.sh b/examples/distribute_graphsage/utils.sh
new file mode 100644
index 0000000..6f6daa8
--- /dev/null
+++ b/examples/distribute_graphsage/utils.sh
@@ -0,0 +1,20 @@
+
+# parse yaml file
+function parse_yaml {
+    local prefix=$2
+    local s='[[:space:]]*' w='[a-zA-Z0-9_]*' fs=$(echo @|tr @ '\034')
+    sed -ne "s|^\($s\):|\1|" \
+        -e "s|^\($s\)\($w\)$s:$s[\"']\(.*\)[\"']$s\$|\1$fs\2$fs\3|p" \
+        -e "s|^\($s\)\($w\)$s:$s\(.*\)$s\$|\1$fs\2$fs\3|p" $1 |
+    awk -F$fs '{
+        indent = length($1)/2;
+        vname[indent] = $2;
+        for (i in vname) {if (i > indent) {delete vname[i]}}
+        if (length($3) > 0) {
+            vn=""; for (i=0; i<indent; i++) {vn=(vn)(vname[i])("_")}
+            printf("%s%s%s=\"%s\"\n", "'$prefix'", vn, $2, $3);
+        }
+    }'
+}
+
+eval $(parse_yaml config.yaml)
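+
+# Usage sketch (hypothetical prefix): parse_yaml can also namespace the
+# generated variables, e.g. `eval $(parse_yaml config.yaml "conf_")`
+# defines conf_hidden_size=128, conf_log_dir="./logs", and so on.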