diff --git a/ogb_examples/linkproppred/main_pgl.py b/ogb_examples/linkproppred/main_pgl.py deleted file mode 100644 index bb81a248c98fe03dcc44037d211e5e2af06a0716..0000000000000000000000000000000000000000 --- a/ogb_examples/linkproppred/main_pgl.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""test ogb -""" -import argparse -import time -import logging -import numpy as np - -import paddle.fluid as fluid - -import pgl -from pgl.contrib.ogb.linkproppred.dataset_pgl import PglLinkPropPredDataset -from pgl.utils import paddle_helper -from ogb.linkproppred import Evaluator - - -def send_func(src_feat, dst_feat, edge_feat): - """send_func""" - return src_feat["h"] - - -def recv_func(feat): - """recv_func""" - return fluid.layers.sequence_pool(feat, pool_type="sum") - - -class GNNModel(object): - """GNNModel""" - - def __init__(self, name, num_nodes, emb_dim, num_layers): - self.num_nodes = num_nodes - self.emb_dim = emb_dim - self.num_layers = num_layers - self.name = name - - self.src_nodes = fluid.layers.data( - name='src_nodes', - shape=[None], - dtype='int64', ) - - self.dst_nodes = fluid.layers.data( - name='dst_nodes', - shape=[None], - dtype='int64', ) - - self.edge_label = fluid.layers.data( - name='edge_label', - shape=[None, 1], - dtype='float32', ) - - def forward(self, graph): - """forward""" - h = fluid.layers.create_parameter( - shape=[self.num_nodes, self.emb_dim], - dtype="float32", - name=self.name + "_embedding") - - for layer in range(self.num_layers): - msg = graph.send( - send_func, - nfeat_list=[("h", h)], ) - h = graph.recv(msg, recv_func) - h = fluid.layers.fc( - h, - size=self.emb_dim, - bias_attr=False, - param_attr=fluid.ParamAttr(name=self.name + '_%s' % layer)) - h = h * graph.node_feat["norm"] - bias = fluid.layers.create_parameter( - shape=[self.emb_dim], - dtype='float32', - is_bias=True, - name=self.name + '_bias_%s' % layer) - h = fluid.layers.elementwise_add(h, bias, act="relu") - - src = fluid.layers.gather(h, self.src_nodes, overwrite=False) - dst = fluid.layers.gather(h, self.dst_nodes, overwrite=False) - edge_embed = src * dst - pred = fluid.layers.fc(input=edge_embed, - size=1, - name=self.name + "_pred_output") - - prob = fluid.layers.sigmoid(pred) - - loss = fluid.layers.sigmoid_cross_entropy_with_logits(pred, - self.edge_label) - loss = fluid.layers.reduce_sum(loss) - - return pred, prob, loss - - -def main(): - """main - """ - # Training settings - parser = argparse.ArgumentParser(description='Graph Dataset') - parser.add_argument( - '--epochs', - type=int, - default=4, - help='number of epochs to train (default: 100)') - parser.add_argument( - '--dataset', - type=str, - default="ogbl-ppa", - help='dataset name (default: protein protein associations)') - parser.add_argument('--use_cuda', action='store_true') - parser.add_argument('--batch_size', type=int, default=5120) - parser.add_argument('--embed_dim', type=int, default=64) - 
parser.add_argument('--num_layers', type=int, default=2) - parser.add_argument('--lr', type=float, default=0.001) - args = parser.parse_args() - print(args) - - place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace() - - ### automatic dataloading and splitting - print("loadding dataset") - dataset = PglLinkPropPredDataset(name=args.dataset) - splitted_edge = dataset.get_edge_split() - print(splitted_edge['train_edge'].shape) - print(splitted_edge['train_edge_label'].shape) - - print("building evaluator") - ### automatic evaluator. takes dataset name as input - evaluator = Evaluator(args.dataset) - - graph_data = dataset[0] - print("num_nodes: %d" % graph_data.num_nodes) - - train_program = fluid.Program() - startup_program = fluid.Program() - - # degree normalize - indegree = graph_data.indegree() - norm = np.zeros_like(indegree, dtype="float32") - norm[indegree > 0] = np.power(indegree[indegree > 0], -0.5) - graph_data.node_feat["norm"] = np.expand_dims(norm, -1).astype("float32") - # graph_data.node_feat["index"] = np.array([i for i in range(graph_data.num_nodes)], dtype=np.int64).reshape(-1,1) - - with fluid.program_guard(train_program, startup_program): - model = GNNModel( - name="gnn", - num_nodes=graph_data.num_nodes, - emb_dim=args.embed_dim, - num_layers=args.num_layers) - gw = pgl.graph_wrapper.GraphWrapper( - "graph", - place, - node_feat=graph_data.node_feat_info(), - edge_feat=graph_data.edge_feat_info()) - pred, prob, loss = model.forward(gw) - - val_program = train_program.clone(for_test=True) - - with fluid.program_guard(train_program, startup_program): - global_steps = int(splitted_edge['train_edge'].shape[0] / - args.batch_size * 2) - learning_rate = fluid.layers.polynomial_decay(args.lr, global_steps, - 0.00005) - - adam = fluid.optimizer.Adam( - learning_rate=learning_rate, - regularization=fluid.regularizer.L2DecayRegularizer( - regularization_coeff=0.0005)) - adam.minimize(loss) - - exe = fluid.Executor(place) - exe.run(startup_program) - feed = gw.to_feed(graph_data) - - print("evaluate result before training: ") - result = test(exe, val_program, prob, evaluator, feed, splitted_edge) - print(result) - - print("training") - cc = 0 - for epoch in range(1, args.epochs + 1): - for batch_data, batch_label in data_generator( - graph_data, - splitted_edge["train_edge"], - splitted_edge["train_edge_label"], - batch_size=args.batch_size): - feed['src_nodes'] = batch_data[:, 0].reshape(-1, 1) - feed['dst_nodes'] = batch_data[:, 1].reshape(-1, 1) - feed['edge_label'] = batch_label.astype("float32") - - res_loss, y_pred, b_lr = exe.run( - train_program, - feed=feed, - fetch_list=[loss, prob, learning_rate]) - if cc % 1 == 0: - print("epoch %d | step %d | lr %s | Loss %s" % - (epoch, cc, b_lr[0], res_loss[0])) - cc += 1 - - if cc % 20 == 0: - print("Evaluating...") - result = test(exe, val_program, prob, evaluator, feed, - splitted_edge) - print("epoch %d | step %d" % (epoch, cc)) - print(result) - - -def test(exe, val_program, prob, evaluator, feed, splitted_edge): - """Evaluation""" - result = {} - feed['src_nodes'] = splitted_edge["valid_edge"][:, 0].reshape(-1, 1) - feed['dst_nodes'] = splitted_edge["valid_edge"][:, 1].reshape(-1, 1) - feed['edge_label'] = splitted_edge["valid_edge_label"].astype( - "float32").reshape(-1, 1) - y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0] - input_dict = { - "y_pred_pos": - y_pred[splitted_edge["valid_edge_label"] == 1].reshape(-1, ), - "y_pred_neg": - y_pred[splitted_edge["valid_edge_label"] == 0].reshape(-1, ) - } - 
result["valid"] = evaluator.eval(input_dict) - - feed['src_nodes'] = splitted_edge["test_edge"][:, 0].reshape(-1, 1) - feed['dst_nodes'] = splitted_edge["test_edge"][:, 1].reshape(-1, 1) - feed['edge_label'] = splitted_edge["test_edge_label"].astype( - "float32").reshape(-1, 1) - y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0] - input_dict = { - "y_pred_pos": - y_pred[splitted_edge["test_edge_label"] == 1].reshape(-1, ), - "y_pred_neg": - y_pred[splitted_edge["test_edge_label"] == 0].reshape(-1, ) - } - result["test"] = evaluator.eval(input_dict) - return result - - -def data_generator(graph, data, label_data, batch_size, shuffle=True): - """Data Generator""" - perm = np.arange(0, len(data)) - if shuffle: - np.random.shuffle(perm) - - offset = 0 - while offset < len(perm): - batch_index = perm[offset:(offset + batch_size)] - offset += batch_size - pos_data = data[batch_index] - pos_label = label_data[batch_index] - - neg_src_node = pos_data[:, 0] - neg_dst_node = np.random.choice( - pos_data.reshape(-1, ), size=len(neg_src_node)) - neg_data = np.hstack( - [neg_src_node.reshape(-1, 1), neg_dst_node.reshape(-1, 1)]) - exists = graph.has_edges_between(neg_src_node, neg_dst_node) - neg_data = neg_data[np.invert(exists)] - neg_label = np.zeros(shape=len(neg_data), dtype=np.int64) - - batch_data = np.vstack([pos_data, neg_data]) - label = np.vstack([pos_label.reshape(-1, 1), neg_label.reshape(-1, 1)]) - yield batch_data, label - - -if __name__ == "__main__": - main() diff --git a/ogb_examples/linkproppred/ogbl-ppa/README.md b/ogb_examples/linkproppred/ogbl-ppa/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f06b3bc2be13dca9548491c5a152841fd4bb034f --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/README.md @@ -0,0 +1,21 @@ +# Graph Link Prediction for Open Graph Benchmark (OGB) PPA dataset + +[The Open Graph Benchmark (OGB)](https://ogb.stanford.edu/) is a collection of benchmark datasets, data loaders, and evaluators for graph machine learning. Here we complete the Graph Link Prediction task based on PGL. + + +### Requirements + +paddlpaddle >= 1.7.1 + +pgl 1.0.2 + +ogb + + +### How to Run + +``` +CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --use_cuda 1 --num_workers 4 --output_path ./output/model_1 --batch_size 65536 --epoch 1000 --learning_rate 0.005 --hidden_size 256 +``` + +The best record will be saved in ./output/model_1/best.txt. diff --git a/ogb_examples/linkproppred/ogbl-ppa/args.py b/ogb_examples/linkproppred/ogbl-ppa/args.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc51d37f9774fbf50fb7bbb5aa700b9f8aaff7f --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/args.py @@ -0,0 +1,44 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""finetune args""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import absolute_import + +import os +import time +import argparse + +from utils.args import ArgumentGroup + +# yapf: disable +parser = argparse.ArgumentParser(__doc__) +model_g = ArgumentGroup(parser, "model", "model configuration and paths.") +model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.") +model_g.add_arg("init_pretraining_params", str, None, + "Init pre-training params which preforms fine-tuning from. If the " + "arg 'init_checkpoint' has been set, this argument wouldn't be valid.") + +train_g = ArgumentGroup(parser, "training", "training options.") +train_g.add_arg("epoch", int, 3, "Number of epoches for fine-tuning.") +train_g.add_arg("learning_rate", float, 5e-5, "Learning rate used to train with warmup.") + +run_type_g = ArgumentGroup(parser, "run_type", "running type options.") +run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.") +run_type_g.add_arg("num_workers", int, 1, "use multiprocess to generate graph") +run_type_g.add_arg("output_path", str, None, "path to save model") +run_type_g.add_arg("hidden_size", int, 128, "model hidden-size") +run_type_g.add_arg("batch_size", int, 128, "batch_size") diff --git a/ogb_examples/linkproppred/ogbl-ppa/dataloader/__init__.py b/ogb_examples/linkproppred/ogbl-ppa/dataloader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abf198b97e6e818e1fbe59006f98492640bcee54 --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/dataloader/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ogb_examples/linkproppred/ogbl-ppa/dataloader/base_dataloader.py b/ogb_examples/linkproppred/ogbl-ppa/dataloader/base_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..d04f9fd521602bf67f950b3e72ba021fd09c298f --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/dataloader/base_dataloader.py @@ -0,0 +1,148 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Base DataLoader +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import absolute_import + +import os +import sys +import six +from io import open +from collections import namedtuple +import numpy as np +import tqdm +import paddle +from pgl.utils import mp_reader +import collections +import time + +import pgl + +if six.PY3: + import io + sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') + sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') + + +def batch_iter(data, perm, batch_size, fid, num_workers): + """node_batch_iter + """ + size = len(data) + start = 0 + cc = 0 + while start < size: + index = perm[start:start + batch_size] + start += batch_size + cc += 1 + if cc % num_workers != fid: + continue + yield data[index] + + +def scan_batch_iter(data, batch_size, fid, num_workers): + """node_batch_iter + """ + batch = [] + cc = 0 + for line_example in data.scan(): + cc += 1 + if cc % num_workers != fid: + continue + batch.append(line_example) + if len(batch) == batch_size: + yield batch + batch = [] + + if len(batch) > 0: + yield batch + + +class BaseDataGenerator(object): + """Base Data Geneartor""" + + def __init__(self, buf_size, batch_size, num_workers, shuffle=True): + self.num_workers = num_workers + self.batch_size = batch_size + self.line_examples = [] + self.buf_size = buf_size + self.shuffle = shuffle + + def batch_fn(self, batch_examples): + """ batch_fn batch producer""" + raise NotImplementedError("No defined Batch Fn") + + def batch_iter(self, fid, perm): + """ batch iterator""" + if self.shuffle: + for batch in batch_iter(self, perm, self.batch_size, fid, + self.num_workers): + yield batch + else: + for batch in scan_batch_iter(self, self.batch_size, fid, + self.num_workers): + yield batch + + def __len__(self): + return len(self.line_examples) + + def __getitem__(self, idx): + if isinstance(idx, collections.Iterable): + return [self[bidx] for bidx in idx] + else: + return self.line_examples[idx] + + def generator(self): + """batch dict generator""" + + def worker(filter_id, perm): + """ multiprocess worker""" + + def func_run(): + """ func_run """ + pid = os.getpid() + np.random.seed(pid + int(time.time())) + for batch_examples in self.batch_iter(filter_id, perm): + batch_dict = self.batch_fn(batch_examples) + yield batch_dict + + return func_run + + # consume a seed + np.random.rand() + if self.shuffle: + perm = np.arange(0, len(self)) + np.random.shuffle(perm) + else: + perm = None + if self.num_workers == 1: + r = paddle.reader.buffered(worker(0, perm), self.buf_size) + else: + worker_pool = [ + worker(wid, perm) for wid in range(self.num_workers) + ] + worker = mp_reader.multiprocess_reader( + worker_pool, use_pipe=True, queue_size=1000) + r = paddle.reader.buffered(worker, self.buf_size) + + for batch in r(): + yield batch + + def scan(self): + for line_example in self.line_examples: + yield line_example diff --git a/ogb_examples/linkproppred/ogbl-ppa/dataloader/ogbl_ppa_dataloader.py b/ogb_examples/linkproppred/ogbl-ppa/dataloader/ogbl_ppa_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..621db215a6924de338a7dd881ddc54ac82290a33 --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/dataloader/ogbl_ppa_dataloader.py @@ -0,0 +1,118 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import absolute_import + +from dataloader.base_dataloader import BaseDataGenerator +import ssl +ssl._create_default_https_context = ssl._create_unverified_context + +from ogb.linkproppred import LinkPropPredDataset +from ogb.linkproppred import Evaluator +import tqdm +from collections import namedtuple +import pgl +import numpy as np + + +class PPADataGenerator(BaseDataGenerator): + def __init__(self, + graph_wrapper=None, + buf_size=1000, + batch_size=128, + num_workers=1, + shuffle=True, + phase="train"): + super(PPADataGenerator, self).__init__( + buf_size=buf_size, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle) + + self.d_name = "ogbl-ppa" + self.graph_wrapper = graph_wrapper + dataset = LinkPropPredDataset(name=self.d_name) + splitted_edge = dataset.get_edge_split() + self.phase = phase + graph = dataset[0] + edges = graph["edge_index"].T + #self.graph = pgl.graph.Graph(num_nodes=graph["num_nodes"], + # edges=edges, + # node_feat={"nfeat": graph["node_feat"], + # "node_id": np.arange(0, graph["num_nodes"], dtype="int64").reshape(-1, 1) }) + + #self.graph.indegree() + self.num_nodes = graph["num_nodes"] + if self.phase == 'train': + edges = splitted_edge["train"]["edge"] + labels = np.ones(len(edges)) + elif self.phase == "valid": + # Compute the embedding for all the nodes + pos_edges = splitted_edge["valid"]["edge"] + neg_edges = splitted_edge["valid"]["edge_neg"] + pos_labels = np.ones(len(pos_edges)) + neg_labels = np.zeros(len(neg_edges)) + edges = np.vstack([pos_edges, neg_edges]) + labels = pos_labels.tolist() + neg_labels.tolist() + elif self.phase == "test": + # Compute the embedding for all the nodes + pos_edges = splitted_edge["test"]["edge"] + neg_edges = splitted_edge["test"]["edge_neg"] + pos_labels = np.ones(len(pos_edges)) + neg_labels = np.zeros(len(neg_edges)) + edges = np.vstack([pos_edges, neg_edges]) + labels = pos_labels.tolist() + neg_labels.tolist() + + self.line_examples = [] + Example = namedtuple('Example', ['src', "dst", "label"]) + for edge, label in zip(edges, labels): + self.line_examples.append( + Example( + src=edge[0], dst=edge[1], label=label)) + print("Phase", self.phase) + print("Len Examples", len(self.line_examples)) + + def batch_fn(self, batch_ex): + batch_src = [] + batch_dst = [] + join_graph = [] + cc = 0 + batch_node_id = [] + batch_labels = [] + for ex in batch_ex: + batch_src.append(ex.src) + batch_dst.append(ex.dst) + batch_labels.append(ex.label) + + if self.phase == "train": + for num in range(1): + rand_src = np.random.randint( + low=0, high=self.num_nodes, size=len(batch_ex)) + rand_dst = np.random.randint( + low=0, high=self.num_nodes, size=len(batch_ex)) + batch_src = batch_src + rand_src.tolist() + batch_dst = batch_dst + rand_dst.tolist() + batch_labels = batch_labels + np.zeros_like( + 
rand_src, dtype="int64").tolist() + + feed_dict = {} + + feed_dict["batch_src"] = np.array(batch_src, dtype="int64") + feed_dict["batch_dst"] = np.array(batch_dst, dtype="int64") + feed_dict["labels"] = np.array(batch_labels, dtype="int64") + return feed_dict diff --git a/ogb_examples/linkproppred/ogbl-ppa/model.py b/ogb_examples/linkproppred/ogbl-ppa/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9429ea39a900488e1ab65c084e4b133079c56dcb --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/model.py @@ -0,0 +1,108 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""lbs_model""" +import os +import re +import time +from random import random +from functools import reduce, partial + +import numpy as np +import multiprocessing + +import paddle +import paddle.fluid as F +import paddle.fluid.layers as L +from pgl.graph_wrapper import GraphWrapper +from pgl.layers.conv import gcn, gat + + +class BaseGraph(object): + """Base Graph Model""" + + def __init__(self, args): + node_feature = [('nfeat', [None, 58], "float32"), + ('node_id', [None, 1], "int64")] + self.hidden_size = args.hidden_size + self.num_nodes = args.num_nodes + + self.graph_wrapper = None # GraphWrapper( + #name="graph", place=F.CPUPlace(), node_feat=node_feature) + + self.build_model(args) + + def build_model(self, args): + """ build graph model""" + self.batch_src = L.data(name="batch_src", shape=[-1], dtype="int64") + self.batch_src = L.reshape(self.batch_src, [-1, 1]) + self.batch_dst = L.data(name="batch_dst", shape=[-1], dtype="int64") + self.batch_dst = L.reshape(self.batch_dst, [-1, 1]) + self.labels = L.data(name="labels", shape=[-1], dtype="int64") + self.labels = L.reshape(self.labels, [-1, 1]) + self.labels.stop_gradients = True + self.src_repr = L.embedding( + self.batch_src, + size=(self.num_nodes, self.hidden_size), + param_attr=F.ParamAttr( + name="node_embeddings", + initializer=F.initializer.NormalInitializer( + loc=0.0, scale=1.0))) + + self.dst_repr = L.embedding( + self.batch_dst, + size=(self.num_nodes, self.hidden_size), + param_attr=F.ParamAttr( + name="node_embeddings", + initializer=F.initializer.NormalInitializer( + loc=0.0, scale=1.0))) + + self.link_predictor(self.src_repr, self.dst_repr) + + self.bce_loss() + + def link_predictor(self, x, y): + """ siamese network""" + feat = x * y + + feat = L.fc(feat, size=self.hidden_size, name="link_predictor_1") + feat = L.relu(feat) + + feat = L.fc(feat, size=self.hidden_size, name="link_predictor_2") + feat = L.relu(feat) + + self.logits = L.fc(feat, + size=1, + act="sigmoid", + name="link_predictor_logits") + + def bce_loss(self): + """listwise model""" + mask = L.cast(self.labels > 0.5, dtype="float32") + mask.stop_gradients = True + + self.loss = L.log_loss(self.logits, mask, epsilon=1e-15) + self.loss = L.reduce_mean(self.loss) * 2 + proba = L.sigmoid(self.logits) + proba = L.concat([proba * -1 + 1, proba], axis=1) + auc_out, batch_auc_out, _ = \ + 
L.auc(input=proba, label=self.labels, curve='ROC', slide_steps=1) + + self.metrics = { + "loss": self.loss, + "auc": batch_auc_out, + } + + def neighbor_aggregator(self, node_repr): + """neighbor aggregation""" + return node_repr diff --git a/ogb_examples/linkproppred/ogbl-ppa/monitor/__init__.py b/ogb_examples/linkproppred/ogbl-ppa/monitor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d814437561c253c97a95e31187e63a554476364f --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/monitor/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""init""" diff --git a/ogb_examples/linkproppred/ogbl-ppa/monitor/train_monitor.py b/ogb_examples/linkproppred/ogbl-ppa/monitor/train_monitor.py new file mode 100644 index 0000000000000000000000000000000000000000..a517b7c2679f51f4247912df8f661a20792720b8 --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/monitor/train_monitor.py @@ -0,0 +1,184 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
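The monitor that follows scores link prediction with OGB's `Evaluator`, whose ogbl-ppa metric is Hits@K: the fraction of positive edges scored strictly above the K-th highest-scoring negative edge. A small NumPy sketch of that metric for intuition (our paraphrase, not OGB's implementation):

```
import numpy as np

def hits_at_k(y_pred_pos, y_pred_neg, k):
    # A positive edge "hits" if it outscores the k-th best negative edge.
    if len(y_pred_neg) < k:
        return 1.0
    kth_best_neg = np.sort(y_pred_neg)[-k]
    return float(np.mean(y_pred_pos > kth_best_neg))

pos = np.array([0.9, 0.6, 0.3])
neg = np.array([0.8, 0.5, 0.4, 0.2])
print(hits_at_k(pos, neg, 2))  # 0.666...: 0.9 and 0.6 beat the 2nd-best negative (0.5)
```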
+"""train and evaluate""" +import tqdm +import json +import numpy as np +import sys +import os +import paddle.fluid as F +from tensorboardX import SummaryWriter +from ogb.linkproppred import Evaluator +from ogb.linkproppred import LinkPropPredDataset + + +def multi_device(reader, dev_count): + """multi device""" + if dev_count == 1: + for batch in reader: + yield batch + else: + batches = [] + for batch in reader: + batches.append(batch) + if len(batches) == dev_count: + yield batches + batches = [] + + +class OgbEvaluator(object): + def __init__(self): + d_name = "ogbl-ppa" + dataset = LinkPropPredDataset(name=d_name) + splitted_edge = dataset.get_edge_split() + graph = dataset[0] + self.num_nodes = graph["num_nodes"] + self.ogb_evaluator = Evaluator(name="ogbl-ppa") + + def eval(self, scores, labels, phase): + labels = np.reshape(labels, [-1]) + ret = {} + pos = scores[labels > 0.5].squeeze(-1) + neg = scores[labels < 0.5].squeeze(-1) + for K in [10, 50, 100]: + self.ogb_evaluator.K = K + ret['%s_hits@%s' % (phase, K)] = self.ogb_evaluator.eval({ + 'y_pred_pos': pos, + 'y_pred_neg': neg, + })[f'hits@{K}'] + return ret + + +def evaluate(model, valid_exe, valid_ds, valid_prog, dev_count, evaluator, + phase): + """evaluate """ + cc = 0 + scores = [] + labels = [] + + for feed_dict in tqdm.tqdm( + multi_device(valid_ds.generator(), dev_count), desc='evaluating'): + + if dev_count > 1: + output = valid_exe.run(feed=feed_dict, + fetch_list=[model.logits, model.labels]) + else: + output = valid_exe.run(valid_prog, + feed=feed_dict, + fetch_list=[model.logits, model.labels]) + scores.append(output[0]) + labels.append(output[1]) + + scores = np.vstack(scores) + labels = np.vstack(labels) + ret = evaluator.eval(scores, labels, phase) + return ret + + +def _create_if_not_exist(path): + basedir = os.path.dirname(path) + if not os.path.exists(basedir): + os.makedirs(basedir) + + +def train_and_evaluate(exe, + train_exe, + valid_exe, + train_ds, + valid_ds, + test_ds, + train_prog, + valid_prog, + model, + metric, + epoch=20, + dev_count=1, + train_log_step=5, + eval_step=10000, + evaluator=None, + output_path=None): + """train and evaluate""" + + global_step = 0 + + log_path = os.path.join(output_path, "log") + _create_if_not_exist(log_path) + + writer = SummaryWriter(log_path) + + best_model = 0 + for e in range(epoch): + for feed_dict in tqdm.tqdm( + multi_device(train_ds.generator(), dev_count), + desc='Epoch %s' % e): + if dev_count > 1: + ret = train_exe.run(feed=feed_dict, fetch_list=metric.vars) + ret = [[np.mean(v)] for v in ret] + else: + ret = train_exe.run(train_prog, + feed=feed_dict, + fetch_list=metric.vars) + + ret = metric.parse(ret) + if global_step % train_log_step == 0: + for key, value in ret.items(): + writer.add_scalar( + 'train_' + key, value, global_step=global_step) + + global_step += 1 + if global_step % eval_step == 0: + eval_ret = evaluate(model, exe, valid_ds, valid_prog, 1, + evaluator, "valid") + + test_eval_ret = evaluate(model, exe, test_ds, valid_prog, 1, + evaluator, "test") + + eval_ret.update(test_eval_ret) + + sys.stderr.write(json.dumps(eval_ret, indent=4) + "\n") + + for key, value in eval_ret.items(): + writer.add_scalar(key, value, global_step=global_step) + + if eval_ret["valid_hits@100"] > best_model: + F.io.save_persistables( + exe, + os.path.join(output_path, "checkpoint"), train_prog) + eval_ret["step"] = global_step + with open(os.path.join(output_path, "best.txt"), "w") as f: + f.write(json.dumps(eval_ret, indent=2) + '\n') + best_model = 
eval_ret["valid_hits@100"] + # Epoch End + eval_ret = evaluate(model, exe, valid_ds, valid_prog, 1, evaluator, + "valid") + + test_eval_ret = evaluate(model, exe, test_ds, valid_prog, 1, evaluator, + "test") + + eval_ret.update(test_eval_ret) + sys.stderr.write(json.dumps(eval_ret, indent=4) + "\n") + + for key, value in eval_ret.items(): + writer.add_scalar(key, value, global_step=global_step) + + if eval_ret["valid_hits@100"] > best_model: + F.io.save_persistables(exe, + os.path.join(output_path, "checkpoint"), + train_prog) + eval_ret["step"] = global_step + with open(os.path.join(output_path, "best.txt"), "w") as f: + f.write(json.dumps(eval_ret, indent=2) + '\n') + best_model = eval_ret["valid_hits@100"] + + writer.close() diff --git a/ogb_examples/linkproppred/ogbl-ppa/train.py b/ogb_examples/linkproppred/ogbl-ppa/train.py new file mode 100644 index 0000000000000000000000000000000000000000..c70fa4f9dd4987e615f6f935b5108c727fe7abee --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/train.py @@ -0,0 +1,157 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""listwise model +""" + +import torch +import os +import re +import time +import logging +from random import random +from functools import reduce, partial + +# For downloading ogb +import ssl +ssl._create_default_https_context = ssl._create_unverified_context +# SSL + +import numpy as np +import multiprocessing + +import pgl +import paddle +import paddle.fluid as F +import paddle.fluid.layers as L + +from args import parser +from utils.args import print_arguments, check_cuda +from utils.init import init_checkpoint, init_pretraining_params +from model import BaseGraph +from dataloader.ogbl_ppa_dataloader import PPADataGenerator +from monitor.train_monitor import train_and_evaluate, OgbEvaluator + +log = logging.getLogger(__name__) + + +class Metric(object): + """Metric""" + + def __init__(self, **args): + self.args = args + + @property + def vars(self): + """ fetch metric vars""" + values = [self.args[k] for k in self.args.keys()] + return values + + def parse(self, fetch_list): + """parse""" + tup = list(zip(self.args.keys(), [float(v[0]) for v in fetch_list])) + return dict(tup) + + +if __name__ == '__main__': + args = parser.parse_args() + print_arguments(args) + evaluator = OgbEvaluator() + + train_prog = F.Program() + startup_prog = F.Program() + args.num_nodes = evaluator.num_nodes + + if args.use_cuda: + dev_list = F.cuda_places() + place = dev_list[0] + dev_count = len(dev_list) + else: + place = F.CPUPlace() + dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + + with F.program_guard(train_prog, startup_prog): + with F.unique_name.guard(): + graph_model = BaseGraph(args) + test_prog = train_prog.clone(for_test=True) + opt = F.optimizer.Adam(learning_rate=args.learning_rate) + opt.minimize(graph_model.loss) + + #test_prog = F.Program() + #with F.program_guard(test_prog, startup_prog): + # with F.unique_name.guard(): + # _graph_model = 
BaseGraph(args) + + train_ds = PPADataGenerator( + phase="train", + graph_wrapper=graph_model.graph_wrapper, + num_workers=args.num_workers, + batch_size=args.batch_size) + + valid_ds = PPADataGenerator( + phase="valid", + graph_wrapper=graph_model.graph_wrapper, + num_workers=args.num_workers, + batch_size=args.batch_size) + + test_ds = PPADataGenerator( + phase="test", + graph_wrapper=graph_model.graph_wrapper, + num_workers=args.num_workers, + batch_size=args.batch_size) + + exe = F.Executor(place) + exe.run(startup_prog) + + if args.init_pretraining_params is not None: + init_pretraining_params( + exe, args.init_pretraining_params, main_program=startup_prog) + + metric = Metric(**graph_model.metrics) + + nccl2_num_trainers = 1 + nccl2_trainer_id = 0 + if dev_count > 1: + + exec_strategy = F.ExecutionStrategy() + exec_strategy.num_threads = dev_count + + train_exe = F.ParallelExecutor( + use_cuda=args.use_cuda, + loss_name=graph_model.loss.name, + exec_strategy=exec_strategy, + main_program=train_prog, + num_trainers=nccl2_num_trainers, + trainer_id=nccl2_trainer_id) + + test_exe = exe + else: + train_exe, test_exe = exe, exe + + train_and_evaluate( + exe=exe, + train_exe=train_exe, + valid_exe=test_exe, + train_ds=train_ds, + valid_ds=valid_ds, + test_ds=test_ds, + train_prog=train_prog, + valid_prog=test_prog, + train_log_step=5, + output_path=args.output_path, + dev_count=dev_count, + model=graph_model, + epoch=args.epoch, + eval_step=1000000, + evaluator=evaluator, + metric=metric) diff --git a/ogb_examples/linkproppred/ogbl-ppa/utils/__init__.py b/ogb_examples/linkproppred/ogbl-ppa/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1333621cf62da67fcf10016fc848c503f7c254fa --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/utils/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""utils""" diff --git a/ogb_examples/linkproppred/ogbl-ppa/utils/args.py b/ogb_examples/linkproppred/ogbl-ppa/utils/args.py new file mode 100644 index 0000000000000000000000000000000000000000..5131f2ceb88775f12e886402ef205735a1ac1d77 --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/utils/args.py @@ -0,0 +1,97 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
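train.py above fetches and reports its training metrics through the small `Metric` wrapper. A toy round trip with fake fetch results shows the contract (standalone sketch; the stand-in strings replace fluid Variables):

```
import numpy as np

class Metric(object):
    def __init__(self, **args):
        self.args = args

    @property
    def vars(self):
        # The tensors to pass to exe.run(fetch_list=...).
        return [self.args[k] for k in self.args.keys()]

    def parse(self, fetch_list):
        # Map fetched arrays back to named scalars, in the same order.
        return dict(zip(self.args.keys(), [float(v[0]) for v in fetch_list]))

metric = Metric(loss="loss_var", auc="auc_var")
print(metric.vars)                                       # ['loss_var', 'auc_var']
print(metric.parse([np.array([0.7]), np.array([0.9])]))  # {'loss': 0.7, 'auc': 0.9}
```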
+"""Arguments for configuration.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import absolute_import + +import six +import os +import sys +import argparse +import logging + +import paddle.fluid as fluid + +log = logging.getLogger(__name__) + + +def prepare_logger(logger, debug=False, save_to_file=None): + """doc""" + formatter = logging.Formatter( + fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s' + ) + #console_hdl = logging.StreamHandler() + #console_hdl.setFormatter(formatter) + #logger.addHandler(console_hdl) + if save_to_file is not None and not os.path.exists(save_to_file): + file_hdl = logging.FileHandler(save_to_file) + file_hdl.setFormatter(formatter) + logger.addHandler(file_hdl) + logger.setLevel(logging.DEBUG) + logger.propagate = False + + +def str2bool(v): + """doc""" + # because argparse does not support to parse "true, False" as python + # boolean directly + return v.lower() in ("true", "t", "1") + + +class ArgumentGroup(object): + """doc""" + + def __init__(self, parser, title, des): + self._group = parser.add_argument_group(title=title, description=des) + + def add_arg(self, + name, + type, + default, + help, + positional_arg=False, + **kwargs): + """doc""" + prefix = "" if positional_arg else "--" + type = str2bool if type == bool else type + self._group.add_argument( + prefix + name, + default=default, + type=type, + help=help + ' Default: %(default)s.', + **kwargs) + + +def print_arguments(args): + """doc""" + log.info('----------- Configuration Arguments -----------') + for arg, value in sorted(six.iteritems(vars(args))): + log.info('%s: %s' % (arg, value)) + log.info('------------------------------------------------') + + +def check_cuda(use_cuda, err= \ + "\nYou can not set use_cuda=True in the model because you are using paddlepaddle-cpu.\n \ + Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda=False to run models on CPU.\n" + ): + """doc""" + try: + if use_cuda == True and fluid.is_compiled_with_cuda() == False: + log.error(err) + sys.exit(1) + except Exception as e: + pass diff --git a/ogb_examples/linkproppred/ogbl-ppa/utils/cards.py b/ogb_examples/linkproppred/ogbl-ppa/utils/cards.py new file mode 100644 index 0000000000000000000000000000000000000000..2b658a4bf6272f00f48ff447caaaa580189afe60 --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/utils/cards.py @@ -0,0 +1,31 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""cards""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import absolute_import +import os + + +def get_cards(): + """ + get gpu cards number + """ + num = 0 + cards = os.environ.get('CUDA_VISIBLE_DEVICES', '') + if cards != '': + num = len(cards.split(",")) + return num diff --git a/ogb_examples/linkproppred/ogbl-ppa/utils/fp16.py b/ogb_examples/linkproppred/ogbl-ppa/utils/fp16.py new file mode 100644 index 0000000000000000000000000000000000000000..740add267dff2dbf463032bcc47a6741ca9f7c43 --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/utils/fp16.py @@ -0,0 +1,201 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import paddle +import paddle.fluid as fluid + + +def append_cast_op(i, o, prog): + """ + Append a cast op in a given Program to cast input `i` to data type `o.dtype`. + Args: + i (Variable): The input Variable. + o (Variable): The output Variable. + prog (Program): The Program to append cast op. + """ + prog.global_block().append_op( + type="cast", + inputs={"X": i}, + outputs={"Out": o}, + attrs={"in_dtype": i.dtype, + "out_dtype": o.dtype}) + + +def copy_to_master_param(p, block): + v = block.vars.get(p.name, None) + if v is None: + raise ValueError("no param name %s found!" % p.name) + new_p = fluid.framework.Parameter( + block=block, + shape=v.shape, + dtype=fluid.core.VarDesc.VarType.FP32, + type=v.type, + lod_level=v.lod_level, + stop_gradient=p.stop_gradient, + trainable=p.trainable, + optimize_attr=p.optimize_attr, + regularizer=p.regularizer, + gradient_clip_attr=p.gradient_clip_attr, + error_clip=p.error_clip, + name=v.name + ".master") + return new_p + + +def apply_dynamic_loss_scaling(loss_scaling, master_params_grads, + incr_every_n_steps, decr_every_n_nan_or_inf, + incr_ratio, decr_ratio): + _incr_every_n_steps = fluid.layers.fill_constant( + shape=[1], dtype='int32', value=incr_every_n_steps) + _decr_every_n_nan_or_inf = fluid.layers.fill_constant( + shape=[1], dtype='int32', value=decr_every_n_nan_or_inf) + + _num_good_steps = fluid.layers.create_global_var( + name=fluid.unique_name.generate("num_good_steps"), + shape=[1], + value=0, + dtype='int32', + persistable=True) + _num_bad_steps = fluid.layers.create_global_var( + name=fluid.unique_name.generate("num_bad_steps"), + shape=[1], + value=0, + dtype='int32', + persistable=True) + + grads = [fluid.layers.reduce_sum(g) for [_, g] in master_params_grads] + all_grads = fluid.layers.concat(grads) + all_grads_sum = fluid.layers.reduce_sum(all_grads) + is_overall_finite = fluid.layers.isfinite(all_grads_sum) + + update_loss_scaling(is_overall_finite, loss_scaling, _num_good_steps, + _num_bad_steps, _incr_every_n_steps, + _decr_every_n_nan_or_inf, incr_ratio, decr_ratio) + + # apply_gradient append all ops in global block, thus we shouldn't + # apply gradient in the switch branch. 
+ with fluid.layers.Switch() as switch: + with switch.case(is_overall_finite): + pass + with switch.default(): + for _, g in master_params_grads: + fluid.layers.assign(fluid.layers.zeros_like(g), g) + + +def create_master_params_grads(params_grads, main_prog, startup_prog, + loss_scaling): + master_params_grads = [] + for p, g in params_grads: + with main_prog._optimized_guard([p, g]): + # create master parameters + master_param = copy_to_master_param(p, main_prog.global_block()) + startup_master_param = startup_prog.global_block()._clone_variable( + master_param) + startup_p = startup_prog.global_block().var(p.name) + append_cast_op(startup_p, startup_master_param, startup_prog) + # cast fp16 gradients to fp32 before apply gradients + if g.name.find("layer_norm") > -1: + scaled_g = g / loss_scaling + master_params_grads.append([p, scaled_g]) + continue + master_grad = fluid.layers.cast(g, "float32") + master_grad = master_grad / loss_scaling + master_params_grads.append([master_param, master_grad]) + + return master_params_grads + + +def master_param_to_train_param(master_params_grads, params_grads, main_prog): + for idx, m_p_g in enumerate(master_params_grads): + train_p, _ = params_grads[idx] + if train_p.name.find("layer_norm") > -1: + continue + with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]): + append_cast_op(m_p_g[0], train_p, main_prog) + + +def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps, + num_bad_steps, incr_every_n_steps, + decr_every_n_nan_or_inf, incr_ratio, decr_ratio): + """ + Update loss scaling according to overall gradients. If all gradients is + finite after incr_every_n_steps, loss scaling will increase by incr_ratio. + Otherwisw, loss scaling will decrease by decr_ratio after + decr_every_n_nan_or_inf steps and each step some gradients are infinite. + Args: + is_overall_finite (Variable): A boolean variable indicates whether + all gradients are finite. + prev_loss_scaling (Variable): Previous loss scaling. + num_good_steps (Variable): A variable accumulates good steps in which + all gradients are finite. + num_bad_steps (Variable): A variable accumulates bad steps in which + some gradients are infinite. + incr_every_n_steps (Variable): A variable represents increasing loss + scaling every n consecutive steps with + finite gradients. + decr_every_n_nan_or_inf (Variable): A variable represents decreasing + loss scaling every n accumulated + steps with nan or inf gradients. + incr_ratio(float): The multiplier to use when increasing the loss + scaling. + decr_ratio(float): The less-than-one-multiplier to use when decreasing + loss scaling. 
+ """ + zero_steps = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0) + with fluid.layers.Switch() as switch: + with switch.case(is_overall_finite): + should_incr_loss_scaling = fluid.layers.less_than( + incr_every_n_steps, num_good_steps + 1) + with fluid.layers.Switch() as switch1: + with switch1.case(should_incr_loss_scaling): + new_loss_scaling = prev_loss_scaling * incr_ratio + loss_scaling_is_finite = fluid.layers.isfinite( + new_loss_scaling) + with fluid.layers.Switch() as switch2: + with switch2.case(loss_scaling_is_finite): + fluid.layers.assign(new_loss_scaling, + prev_loss_scaling) + with switch2.default(): + pass + fluid.layers.assign(zero_steps, num_good_steps) + fluid.layers.assign(zero_steps, num_bad_steps) + + with switch1.default(): + fluid.layers.increment(num_good_steps) + fluid.layers.assign(zero_steps, num_bad_steps) + + with switch.default(): + should_decr_loss_scaling = fluid.layers.less_than( + decr_every_n_nan_or_inf, num_bad_steps + 1) + with fluid.layers.Switch() as switch3: + with switch3.case(should_decr_loss_scaling): + new_loss_scaling = prev_loss_scaling * decr_ratio + static_loss_scaling = \ + fluid.layers.fill_constant(shape=[1], + dtype='float32', + value=1.0) + less_than_one = fluid.layers.less_than(new_loss_scaling, + static_loss_scaling) + with fluid.layers.Switch() as switch4: + with switch4.case(less_than_one): + fluid.layers.assign(static_loss_scaling, + prev_loss_scaling) + with switch4.default(): + fluid.layers.assign(new_loss_scaling, + prev_loss_scaling) + fluid.layers.assign(zero_steps, num_good_steps) + fluid.layers.assign(zero_steps, num_bad_steps) + with switch3.default(): + fluid.layers.assign(zero_steps, num_good_steps) + fluid.layers.increment(num_bad_steps) diff --git a/ogb_examples/linkproppred/ogbl-ppa/utils/init.py b/ogb_examples/linkproppred/ogbl-ppa/utils/init.py new file mode 100644 index 0000000000000000000000000000000000000000..baa3ba5987cf1cbae20a60ea88e3f3bf0e389f43 --- /dev/null +++ b/ogb_examples/linkproppred/ogbl-ppa/utils/init.py @@ -0,0 +1,97 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""paddle init""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals +from __future__ import absolute_import + +import os +import six +import ast +import copy +import logging + +import numpy as np +import paddle.fluid as fluid + +log = logging.getLogger(__name__) + + +def cast_fp32_to_fp16(exe, main_program): + """doc""" + log.info("Cast parameters to float16 data format.") + for param in main_program.global_block().all_parameters(): + if not param.name.endswith(".master"): + param_t = fluid.global_scope().find_var(param.name).get_tensor() + data = np.array(param_t) + if param.name.startswith("encoder_layer") \ + and "layer_norm" not in param.name: + param_t.set(np.float16(data).view(np.uint16), exe.place) + + #load fp32 + master_param_var = fluid.global_scope().find_var(param.name + + ".master") + if master_param_var is not None: + master_param_var.get_tensor().set(data, exe.place) + + +def init_checkpoint(exe, init_checkpoint_path, main_program, use_fp16=False): + """init""" + assert os.path.exists( + init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path + + def existed_persitables(var): + """existed""" + if not fluid.io.is_persistable(var): + return False + return os.path.exists(os.path.join(init_checkpoint_path, var.name)) + + fluid.io.load_vars( + exe, + init_checkpoint_path, + main_program=main_program, + predicate=existed_persitables) + log.info("Load model from {}".format(init_checkpoint_path)) + + if use_fp16: + cast_fp32_to_fp16(exe, main_program) + + +def init_pretraining_params(exe, + pretraining_params_path, + main_program, + use_fp16=False): + """init""" + assert os.path.exists(pretraining_params_path + ), "[%s] cann't be found." % pretraining_params_path + + def existed_params(var): + """doc""" + if not isinstance(var, fluid.framework.Parameter): + return False + return os.path.exists(os.path.join(pretraining_params_path, var.name)) + + fluid.io.load_vars( + exe, + pretraining_params_path, + main_program=main_program, + predicate=existed_params) + log.info("Load pretraining parameters from {}.".format( + pretraining_params_path)) + + if use_fp16: + cast_fp32_to_fp16(exe, main_program) diff --git a/pgl/graph_wrapper.py b/pgl/graph_wrapper.py index 3f30da477a5e97287315efc4e693eaa022399d84..ddc19c204a87be0993fa3656b561ac1a6ad5ccc2 100644 --- a/pgl/graph_wrapper.py +++ b/pgl/graph_wrapper.py @@ -40,7 +40,6 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes, num_edges): """Recv message from given msg to dst nodes. 
""" - empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype="float32") if reduce_function == "sum": if isinstance(msg, dict): raise TypeError("The message for build-in function" @@ -49,8 +48,9 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes, try: out_dim = msg.shape[-1] init_output = fluid.layers.fill_constant( - shape=[num_nodes, out_dim], value=0, dtype="float32") + shape=[num_nodes, out_dim], value=0, dtype=msg.dtype) init_output.stop_gradient = False + empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=msg.dtype) msg = msg * empty_msg_flag output = paddle_helper.scatter_add(init_output, dst, msg) return output @@ -66,10 +66,12 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes, bucketed_msg = op.nested_lod_reset(msg, bucketing_index) output = reduce_function(bucketed_msg) output_dim = output.shape[-1] + + empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=output.dtype) output = output * empty_msg_flag init_output = fluid.layers.fill_constant( - shape=[num_nodes, output_dim], value=0, dtype="float32") + shape=[num_nodes, output_dim], value=0, dtype=output.dtype) init_output.stop_gradient = True final_output = fluid.layers.scatter(init_output, uniq_dst, output) return final_output