diff --git a/examples/kg/README.md b/examples/kg/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..2dcde4b12a3c34088693b45f3f6831d555ced6da
--- /dev/null
+++ b/examples/kg/README.md
@@ -0,0 +1,44 @@
+# PGL - Knowledge Graph Embedding
+
+## Introduction
+This package provides efficient computation of node and relation embeddings for knowledge graphs.
+
+It reproduces the following knowledge graph embedding models:
+- TransE
+- TransR
+- RotatE
+
+## Dataset
+
+The WN18 and FB15k datasets were originally published with the TransE paper and can be downloaded [here](https://everest.hds.utc.fr/doku.php?id=en:transe).
+
+## Dependencies
+To use PGL-KGE with PaddlePaddle, please install the following packages:
+- paddlepaddle>=1.7
+- pgl
+
+## Experiment results
+FB15k dataset
+
+| Models | Mean Rank | MRR | Hits@1 | Hits@3 | Hits@10 | Filtered MR | Filtered Hits@10 |
+|--------|-----------|-----|--------|--------|---------|-------------|------------------|
+| TransE | 214 | -- | -- | -- | 0.491 | 118 | 0.668 |
+| TransR | 202 | -- | -- | -- | 0.502 | 115 | 0.683 |
+| RotatE | 156 | -- | -- | -- | 0.498 | 52 | 0.710 |
+
+WN18 dataset
+
+| Models | Mean Rank | MRR | Hits@1 | Hits@3 | Hits@10 | Filtered MR | Filtered Hits@10 |
+|--------|-----------|-----|--------|--------|---------|-------------|------------------|
+| TransE | 257 | -- | -- | -- | 0.800 | 245 | 0.915 |
+| TransR | 255 | -- | -- | -- | 0.8012 | 243 | 0.9371 |
+| RotatE | 188 | -- | -- | -- | 0.8325 | 176 | 0.9601 |
+
+## References
+
+[1] TransE: https://ieeexplore.ieee.org/abstract/document/8047276
+[2] TransR: http://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/viewFile/9571/9523
+[3] RotatE: https://arxiv.org/abs/1902.10197
diff --git a/examples/kg/data_loader.py b/examples/kg/data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..09f2435aa1f16a78819c9ed9a0e27a29e2a0b130
--- /dev/null
+++ b/examples/kg/data_loader.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Data loader for knowledge graph datasets.
+"""
+import os
+import numpy as np
+from pgl.utils.logger import log
+# `pybloom` is only needed by the (commented-out) BloomFilter triple pool in
+# `load_data`; uncomment this import together with those lines to use it.
+# from pybloom import BloomFilter
+
+
+class KBloader:
+    """
+    Load a knowledge graph dataset (e.g. FB15k or WN18) and serve batches.
+    """
+
+    def __init__(self, data_dir, batch_size, neg_mode, neg_times):
+        """init"""
+        self.name = os.path.split(data_dir)[-1]
+        self._feed_list = ["pos_triple", "neg_triple"]
+        self._data_dir = data_dir
+        self._batch_size = batch_size
+        self._neg_mode = neg_mode
+        self._neg_times = neg_times
+
+        self._entity2id = {}
+        self._relation2id = {}
+        self.training_triple_pool = set()
+
+        self._triple_train = None
+        self._triple_test = None
+        self._triple_valid = None
+
+        self.entity_total = 0
+        self.relation_total = 0
+        self.train_num = 0
+        self.test_num = 0
+        self.valid_num = 0
+
+        self.load_data()
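+
+    # Expected layout of `data_dir` (as parsed in load_data/load_kg_triple):
+    #   entity2id.txt / relation2id.txt: one "<name>\t<integer id>" per line;
+    #   train.txt / test.txt / valid.txt: one "<head>\t<tail>\t<relation>"
+    #   per line (note the head, tail, relation column order).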
+
+    def test_data_batch(self, batch_size=None):
+        """
+        Test data reader: yields one test triple at a time.
+        :param batch_size: unused; TODO: support batch_size > 1.
+        :return: a generator of single-element lists, each holding one
+            flattened test triple of shape (3,).
+        """
+        for i in range(self.test_num):
+            data = np.array(self._triple_test[i])
+            data = data.reshape((-1))
+            yield [data]
+
+    def training_data_no_filter(self, train_triple_positive):
+        """Fast negative sampling that does not filter out existing triples."""
+        size = len(train_triple_positive)
+        # `+ 0` forces a copy so the positive triples are left untouched.
+        train_triple_negative = train_triple_positive + 0
+        replace_head_probability = 0.5 * np.ones(size)
+        replace_entity_id = np.random.randint(self.entity_total, size=size)
+        random_num = np.random.random(size=size)
+        # index_t[i] == 1 corrupts the head of triple i, 0 corrupts the tail.
+        index_t = (random_num < replace_head_probability) * 1
+        train_triple_negative[:, 0] = train_triple_negative[:, 0] + (
+            replace_entity_id - train_triple_negative[:, 0]) * index_t
+        train_triple_negative[:, 2] = replace_entity_id + (
+            train_triple_negative[:, 2] - replace_entity_id) * index_t
+        train_triple_positive = np.expand_dims(train_triple_positive, axis=2)
+        train_triple_negative = np.expand_dims(train_triple_negative, axis=2)
+        return train_triple_positive, train_triple_negative
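+
+    # Worked example of the vectorized corruption above (illustrative):
+    #   positives         = [[h0, r0, t0], [h1, r1, t1]]
+    #   replace_entity_id = [e0, e1], index_t = [1, 0]
+    #   negatives         = [[e0, r0, t0], [h1, r1, e1]]
+    # head + (e - head) * index_t equals e exactly when index_t == 1, and
+    # e + (tail - e) * index_t keeps the original tail in that same case.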
+
+    def training_data_map(self, train_triple_positive):
+        """
+        Map function for negative sampling.
+        :param train_triple_positive: batch of positive triples.
+        :return: the positive and negative triples (and, if neg_mode is set,
+            a flag telling whether heads or tails were corrupted).
+        """
+        size = len(train_triple_positive)
+        train_triple_negative = []
+        for i in range(size):
+            corrupt_head_prob = np.random.binomial(1, 0.5)
+            # Head, relation and tail of the positive triple; used to check
+            # whether a candidate corruption is already a known true triple.
+            head_neg = train_triple_positive[i][0]
+            relation = train_triple_positive[i][1]
+            tail_neg = train_triple_positive[i][2]
+            for j in range(0, self._neg_times):
+                # `+ 0` forces a copy of the positive triple.
+                sample = train_triple_positive[i] + 0
+                while True:
+                    rand_id = np.random.randint(self.entity_total)
+                    if corrupt_head_prob:
+                        if (rand_id, relation, tail_neg
+                            ) not in self.training_triple_pool:
+                            sample[0] = rand_id
+                            train_triple_negative.append(sample)
+                            break
+                    else:
+                        if (head_neg, relation, rand_id
+                            ) not in self.training_triple_pool:
+                            sample[2] = rand_id
+                            train_triple_negative.append(sample)
+                            break
+        train_triple_positive = np.expand_dims(train_triple_positive, axis=2)
+        train_triple_negative = np.expand_dims(train_triple_negative, axis=2)
+        if self._neg_mode:
+            return train_triple_positive, train_triple_negative, np.array(
+                [corrupt_head_prob], dtype="float32")
+        return train_triple_positive, train_triple_negative
+
+    def training_data_batch(self):
+        """
+        Yield shuffled batches of positive training triples.
+        :return: a generator of arrays of shape (batch_size, 3).
+        """
+        n = len(self._triple_train)
+        rand_idx = np.random.permutation(n)
+        n_triple = len(rand_idx)
+        start = 0
+        while start < n_triple:
+            end = min(start + self._batch_size, n_triple)
+            train_triple_positive = self._triple_train[rand_idx[start:end]]
+            start = end
+            yield train_triple_positive
+
+    def load_kg_triple(self, file):
+        """
+        Read a triple file; each line is "<head>\t<tail>\t<relation>".
+        :param file: file name relative to the data directory.
+        :return: an array of (head, relation, tail) id triples.
+        """
+        triples = []
+        with open(os.path.join(self._data_dir, file), "r") as f:
+            for line in f.readlines():
+                line_list = line.strip().split('\t')
+                assert len(line_list) == 3
+                head = self._entity2id[line_list[0]]
+                tail = self._entity2id[line_list[1]]
+                relation = self._relation2id[line_list[2]]
+                triples.append((head, relation, tail))
+        return np.array(triples)
+
+    def load_data(self):
+        """
+        Load the knowledge graph dataset: id vocabularies and triple files.
+        """
+        log.info("Start loading the {} dataset".format(self.name))
+        with open(os.path.join(self._data_dir, 'entity2id.txt'), "r") as f:
+            for line in f.readlines():
+                line = line.strip().split('\t')
+                self._entity2id[line[0]] = int(line[1])
+        with open(os.path.join(self._data_dir, 'relation2id.txt'), "r") as f:
+            for line in f.readlines():
+                line = line.strip().split('\t')
+                self._relation2id[line[0]] = int(line[1])
+        self._triple_train = self.load_kg_triple('train.txt')
+        self._triple_test = self.load_kg_triple('test.txt')
+        self._triple_valid = self.load_kg_triple('valid.txt')
+
+        self.relation_total = len(self._relation2id)
+        self.entity_total = len(self._entity2id)
+        self.train_num = len(self._triple_train)
+        self.test_num = len(self._triple_test)
+        self.valid_num = len(self._triple_valid)
+
+        # The pool holds train, test AND valid triples, so that the filtered
+        # metrics can exclude every known true triple, not just training ones.
+        # A BloomFilter can replace the set to save memory on large graphs:
+        #bloom_capacity = len(self._triple_train) + len(self._triple_test) + len(self._triple_valid)
+        #self.training_triple_pool = BloomFilter(capacity=bloom_capacity, error_rate=0.01)
+        for i in range(len(self._triple_train)):
+            self.training_triple_pool.add(
+                (self._triple_train[i, 0], self._triple_train[i, 1],
+                 self._triple_train[i, 2]))
+
+        for i in range(len(self._triple_test)):
+            self.training_triple_pool.add(
+                (self._triple_test[i, 0], self._triple_test[i, 1],
+                 self._triple_test[i, 2]))
+
+        for i in range(len(self._triple_valid)):
+            self.training_triple_pool.add(
+                (self._triple_valid[i, 0], self._triple_valid[i, 1],
+                 self._triple_valid[i, 2]))
+        log.info('entity number: {}'.format(self.entity_total))
+        log.info('relation number: {}'.format(self.relation_total))
+        log.info('training triple number: {}'.format(self.train_num))
+        log.info('testing triple number: {}'.format(self.test_num))
+        log.info('valid triple number: {}'.format(self.valid_num))
diff --git a/examples/kg/evalutate.py b/examples/kg/evalutate.py
new file mode 100644
index 0000000000000000000000000000000000000000..389b211f43517460488498b945b2e0863ce88796
--- /dev/null
+++ b/examples/kg/evalutate.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Evaluator for the results of knowledge graph embeddings.
+"""
+import numpy as np
+import timeit
+from mp_mapper import mp_reader_mapper
+from pgl.utils.logger import log
+
+
+class Evaluate:
+    """
+    Evaluator for trained models.
+    """
+
+    def __init__(self, reader):
+        self.reader = reader
+        self.training_triple_pool = self.reader.training_triple_pool
+
+    @staticmethod
+    def rank_extract(results, training_triple_pool):
+        """
+        Compute raw and filtered head/tail ranks for one test triple.
+        :param results: (eval_triple, head_score, tail_score) for one triple.
+        :param training_triple_pool: the set of all known true triples.
+        :return: the raw and filtered ranks.
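+
+        Example (illustrative): if the true head is the 3rd-best candidate in
+        head_order and one of the two higher-scored candidates already forms
+        a known triple, then head_rank_raw = 3 and head_rank_filter = 2.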
+ """ + eval_triple, head_score, tail_score = results + head_order = np.argsort(head_score) + tail_order = np.argsort(tail_score) + head, relation, tail = eval_triple[0], eval_triple[1], eval_triple[2] + head_rank_raw = 1 + tail_rank_raw = 1 + head_rank_filter = 1 + tail_rank_filter = 1 + for candidate in head_order: + if candidate == head: + break + else: + head_rank_raw += 1 + if (candidate, relation, tail) in training_triple_pool: + continue + else: + head_rank_filter += 1 + for candidate in tail_order: + if candidate == tail: + break + else: + tail_rank_raw += 1 + if (head, relation, candidate) in training_triple_pool: + continue + else: + tail_rank_filter += 1 + return head_rank_raw, tail_rank_raw, head_rank_filter, tail_rank_filter + + def launch_evaluation(self, + exe, + program, + reader, + fetch_list, + num_workers=4): + """ + launch_evaluation + :param exe: executor. + :param program: paddle program. + :param reader: test reader. + :param fetch_list: fetch list. + :param num_workers: num of workers. + :return: None + """ + + def func(training_triple_pool): + """func""" + + def run_func(results): + """run_func""" + return self.rank_extract(results, training_triple_pool) + + return run_func + + def iterator(): + """iterator""" + n_used_eval_triple = 0 + start = timeit.default_timer() + for batch_feed_dict in reader(): + head_score, tail_score = exe.run(program=program, + fetch_list=fetch_list, + feed=batch_feed_dict) + yield batch_feed_dict["test_triple"], head_score, tail_score + n_used_eval_triple += 1 + print('[{:.3f}s] #evaluation triple: {}/{}'.format( + timeit.default_timer() - start, n_used_eval_triple, 5000)) + + res_reader = mp_reader_mapper( + reader=iterator, + func=func(self.training_triple_pool), + num_works=num_workers) + self.result(res_reader) + + @staticmethod + def result(rank_result_iter): + """ + Calculate the final results. + :param rank_result_iter: results iter. + :return: None + """ + all_rank = [[], []] + for data in rank_result_iter(): + for i in range(4): + all_rank[i // 2].append(data[i]) + + raw_rank = np.array(all_rank[0]) + filter_rank = np.array(all_rank[1]) + log.info("-----Raw-Average-Results") + log.info( + 'MeanRank: {:.2f}, MRR: {:.4f}, Hits@1: {:.4f}, Hits@3: {:.4f}, Hits@10: {:.4f}'. + format(raw_rank.mean(), (1 / raw_rank).mean(), (raw_rank <= 1). + mean(), (raw_rank <= 3).mean(), (raw_rank <= 10).mean())) + log.info("-----Filter-Average-Results") + log.info( + 'MeanRank: {:.2f}, MRR: {:.4f}, Hits@1: {:.4f}, Hits@3: {:.4f}, Hits@10: {:.4f}'. + format(filter_rank.mean(), (1 / filter_rank).mean(), ( + filter_rank <= 1).mean(), (filter_rank <= 3).mean(), ( + filter_rank <= 10).mean())) diff --git a/examples/kg/main.py b/examples/kg/main.py new file mode 100644 index 0000000000000000000000000000000000000000..84e0add710a40e53c6f0a36897b44c6e07c4fd69 --- /dev/null +++ b/examples/kg/main.py @@ -0,0 +1,282 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
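+# Example invocation (illustrative; every flag below is defined in main()):
+#   CUDA_VISIBLE_DEVICES=0 python main.py --use_cuda --model TransE \
+#       --data_dir ./data/FB15k --optimizer adam --batch_size 512 \
+#       --learning_rate 0.001 --epoch 100 --margin 4
+# See run.sh for the configurations used for the README results. Caution:
+# --neg_mode is parsed with type=bool, so any non-empty string (including
+# "False") enables it; pass it only as `--neg_mode True` or omit it entirely.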
+""" +The script to run these models. +""" +import argparse +import timeit +import paddle.fluid as fluid +from data_loader import KBloader +from evalutate import Evaluate +from model import model_dict +from mp_mapper import mp_reader_mapper +from pgl.utils.logger import log + + +def run_round(batch_iter, + program, + exe, + fetch_list, + epoch, + prefix="train", + log_per_step=1000): + """ + Run the program for one epoch. + :param batch_iter: the batch_iter of prepared data. + :param program: the running program, train_program or test program. + :param exe: the executor of paddle. + :param fetch_list: the variables to fetch. + :param epoch: the epoch number of train process. + :param prefix: the prefix name, type `string`. + :param log_per_step: log per step. + :return: None + """ + batch = 0 + tmp_epoch = 0 + loss = 0 + tmp_loss = 0 + run_time = 0 + data_time = 0 + t2 = timeit.default_timer() + for batch_feed_dict in batch_iter(): + batch += 1 + t1 = timeit.default_timer() + data_time += (t1 - t2) + batch_fetch = exe.run(program, + fetch_list=fetch_list, + feed=batch_feed_dict) + if prefix == "train": + loss += batch_fetch[0] + tmp_loss += batch_fetch[0] + if batch % log_per_step == 0: + tmp_epoch += 1 + if prefix == "train": + log.info("Epoch %s Ava Loss %s" % + (epoch + tmp_epoch, tmp_loss / batch)) + else: + log.info("Batch %s" % batch) + batch = 0 + tmp_loss = 0 + + t2 = timeit.default_timer() + run_time += (t2 - t1) + + if prefix == "train": + log.info("GPU run time {}, Data prepare extra time {}".format( + run_time, data_time)) + log.info("Epoch %s \t All Loss %s" % (epoch + tmp_epoch, loss)) + + +def train(args): + """ + Train the knowledge graph embedding model. + :param args: all args. + :return: None + """ + kgreader = KBloader( + batch_size=args.batch_size, + data_dir=args.data_dir, + neg_mode=args.neg_mode, + neg_times=args.neg_times) + if args.model in model_dict: + Model = model_dict[args.model] + else: + raise ValueError("No model for name {}".format(args.model)) + model = Model( + data_reader=kgreader, + hidden_size=args.hidden_size, + margin=args.margin, + learning_rate=args.learning_rate, + args=args, + optimizer=args.optimizer) + + def iter_map_wrapper(data_batch, repeat=1): + """ + wrapper for multiprocess reader + :param data_batch: the source data iter. + :param repeat: repeat data for multi epoch + :return: iterator of feed data + """ + + def data_repeat(): + """repeat data for multi epoch""" + for i in range(repeat): + for d in data_batch(): + yield d + + reader = mp_reader_mapper( + data_repeat, + func=kgreader.training_data_map, + #func=kgreader.training_data_no_filter, + num_works=args.sample_workers) + + return reader + + def iter_wrapper(data_batch, feed_list): + """ + Decorator of make up the feed dict + :param data_batch: the source data iter. + :param feed_list: the feed list (names of variables). + :return: iterator of feed data. 
+        """
+
+        def work():
+            """work"""
+            for batch in data_batch():
+                feed_dict = {}
+                for k, v in zip(feed_list, batch):
+                    feed_dict[k] = v
+                yield feed_dict
+
+        return work
+
+    loader = fluid.io.DataLoader.from_generator(
+        feed_list=model.train_feed_vars, capacity=20, iterable=True)
+
+    places = fluid.cuda_places() if args.use_cuda else fluid.cpu_places()
+    exe = fluid.Executor(places[0])
+    exe.run(model.startup_program)
+    # Also run the default startup program in case any op was added to it.
+    exe.run(fluid.default_startup_program())
+
+    prog = fluid.CompiledProgram(model.train_program).with_data_parallel(
+        loss_name=model.train_fetch_vars[0].name)
+
+    if args.only_evaluate:
+        s = timeit.default_timer()
+        fluid.io.load_params(
+            exe, dirname=args.checkpoint, main_program=model.train_program)
+        Evaluate(kgreader).launch_evaluation(
+            exe=exe,
+            reader=iter_wrapper(kgreader.test_data_batch,
+                                model.test_feed_list),
+            fetch_list=model.test_fetch_vars,
+            program=model.test_program,
+            num_workers=10)
+        log.info(timeit.default_timer() - s)
+        return None
+
+    batch_iter = iter_map_wrapper(
+        kgreader.training_data_batch,
+        repeat=args.evaluate_per_iteration)
+    loader.set_batch_generator(batch_iter, places=places)
+
+    for epoch in range(0, args.epoch // args.evaluate_per_iteration):
+        run_round(
+            batch_iter=loader,
+            exe=exe,
+            prefix="train",
+            # program=model.train_program,
+            program=prog,
+            fetch_list=model.train_fetch_vars,
+            log_per_step=kgreader.train_num // args.batch_size,
+            epoch=epoch * args.evaluate_per_iteration)
+        log.info("epoch\t%s" % ((1 + epoch) * args.evaluate_per_iteration))
+        # Save a checkpoint and evaluate after every round of
+        # `evaluate_per_iteration` epochs.
+        fluid.io.save_params(
+            exe, dirname=args.checkpoint, main_program=model.train_program)
+        eva = Evaluate(kgreader)
+        eva.launch_evaluation(
+            exe=exe,
+            reader=iter_wrapper(kgreader.test_data_batch,
+                                model.test_feed_list),
+            fetch_list=model.test_fetch_vars,
+            program=model.test_program,
+            num_workers=10)
+
+
+def main():
+    """
+    The main entry point.
+ :return: None + """ + parser = argparse.ArgumentParser( + description="Knowledge Graph Embedding for PGL") + parser.add_argument('--use_cuda', action='store_true', help="use_cuda") + parser.add_argument( + '--data_dir', + dest='data_dir', + type=str, + help='the directory of dataset', + default='./data/WN18/') + parser.add_argument( + '--model', + dest='model', + type=str, + help="model to run", + default="TransE") + parser.add_argument( + '--learning_rate', + dest='learning_rate', + type=float, + help='learning rate', + default=0.001) + parser.add_argument( + '--epoch', dest='epoch', type=int, help='epoch to run', default=400) + parser.add_argument( + '--sample_workers', + dest='sample_workers', + type=int, + help='sample workers', + default=4) + parser.add_argument( + '--batch_size', + dest='batch_size', + type=int, + help="batch size", + default=1000) + parser.add_argument( + '--optimizer', + dest='optimizer', + type=str, + help='optimizer', + default='adam') + parser.add_argument( + '--hidden_size', + dest='hidden_size', + type=int, + help='embedding dimension', + default=50) + parser.add_argument( + '--margin', dest='margin', type=float, help='margin', default=4.0) + parser.add_argument( + '--checkpoint', + dest='checkpoint', + type=str, + help='directory to save checkpoint directory', + default='output/') + parser.add_argument( + '--evaluate_per_iteration', + dest='evaluate_per_iteration', + type=int, + help='evaluate the training result per x iteration', + default=50) + parser.add_argument( + '--only_evaluate', + dest='only_evaluate', + action='store_true', + help='only do the evaluate program', + default=False) + parser.add_argument( + '--adv_temp_value', type=float, help='adv_temp_value', default=2.0) + parser.add_argument('--neg_times', type=int, help='neg_times', default=1) + parser.add_argument( + '--neg_mode', type=bool, help='return neg mode flag', default=False) + + args = parser.parse_args() + log.info(args) + train(args) + + +if __name__ == '__main__': + main() diff --git a/examples/kg/model/Model.py b/examples/kg/model/Model.py new file mode 100644 index 0000000000000000000000000000000000000000..2769404c181f867e6677a0dff189b227c85c6c2c --- /dev/null +++ b/examples/kg/model/Model.py @@ -0,0 +1,127 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Base model of the knowledge graph embedding model. +""" +from paddle import fluid + + +class Model(object): + """ + Base model. 
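+
+    Subclasses are built from keyword arguments; the required keys are
+    model_name, data_reader, hidden_size, learning_rate, optimizer and args,
+    while margin is optional (see __init__ below).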
+    """
+
+    def __init__(self, **kwargs):
+        """
+        Initialize the model from keyword arguments.
+        """
+        # Required parameters
+        self.model_name = kwargs["model_name"]
+        self.data_reader = kwargs["data_reader"]
+        self._hidden_size = kwargs["hidden_size"]
+        self._learning_rate = kwargs["learning_rate"]
+        self._optimizer = kwargs["optimizer"]
+        self.args = kwargs["args"]
+
+        # Optional parameters
+        if "margin" in kwargs:
+            self._margin = kwargs["margin"]
+        self._prefix = "%s_%s_dim=%d_" % (
+            self.model_name, self.data_reader.name, self._hidden_size)
+        self.ent_name = self._prefix + "entity_embeddings"
+        self.rel_name = self._prefix + "relation_embeddings"
+
+        self._entity_total = self.data_reader.entity_total
+        self._relation_total = self.data_reader.relation_total
+        self._ent_shape = [self._entity_total, self._hidden_size]
+        self._rel_shape = [self._relation_total, self._hidden_size]
+
+    def construct(self):
+        """
+        Construct the train and test programs.
+        :return: None
+        """
+        self.startup_program = fluid.Program()
+        self.train_program = fluid.Program()
+        self.test_program = fluid.Program()
+
+        with fluid.program_guard(self.train_program, self.startup_program):
+            self.train_pos_input = fluid.layers.data(
+                "pos_triple",
+                dtype="int64",
+                shape=[None, 3, 1],
+                append_batch_size=False)
+            self.train_neg_input = fluid.layers.data(
+                "neg_triple",
+                dtype="int64",
+                shape=[None, 3, 1],
+                append_batch_size=False)
+            self.train_feed_list = ["pos_triple", "neg_triple"]
+            self.train_feed_vars = [self.train_pos_input, self.train_neg_input]
+            self.train_fetch_vars = self.construct_train_program()
+            loss = self.train_fetch_vars[0]
+            self.apply_optimizer(loss, opt=self._optimizer)
+
+        with fluid.program_guard(self.test_program, self.startup_program):
+            self.test_input = fluid.layers.data(
+                "test_triple",
+                dtype="int64",
+                shape=[3],
+                append_batch_size=False)
+            self.test_feed_list = ["test_triple"]
+            self.test_fetch_vars = self.construct_test_program()
+
+    def apply_optimizer(self, loss, opt="sgd"):
+        """
+        Construct the backward pass of the train program.
+        :param loss: `type: variable` final loss of the model.
+        :param opt: `type: string` the optimizer name.
+        :return: the result of optimizer.minimize(loss).
+        """
+        optimizer_available = {
+            "adam": fluid.optimizer.Adam,
+            "sgd": fluid.optimizer.SGD,
+            "momentum": fluid.optimizer.Momentum
+        }
+        if opt not in optimizer_available:
+            raise ValueError("You should choose the optimizer from %s" %
+                             list(optimizer_available.keys()))
+        optimizer = optimizer_available[opt](
+            learning_rate=self._learning_rate)
+        return optimizer.minimize(loss)
+
+    def construct_train_program(self):
+        """
+        This function should construct the train program with `self.train_pos_input`
+        and `self.train_neg_input`. These inputs are batches of triples.
+        :return: list of variables to fetch. Make sure the loss variable comes
+            first, e.g. [loss, variable1, variable2, ...].
+        """
+        raise NotImplementedError(
+            "You should define the construct_train_program"
+            " function before using it!")
+
+    def construct_test_program(self):
+        """
+        This function should construct the test (or evaluation) program with
+        `self.test_input`. For now, only a single triple at a time is supported
+        when evaluating ranks.
+        :return: the distances of all entities to the test triple (for both the
+            head and the tail entity).
+        """
+        raise NotImplementedError(
+            "You should define the construct_test_program"
+            " function before using it")
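+
+
+# Minimal subclass sketch (illustrative only; see TransE.py for a complete,
+# working implementation of this contract):
+#
+#     class MyModel(Model):
+#         def __init__(self, **kwargs):
+#             super(MyModel, self).__init__(model_name="MyModel", **kwargs)
+#             self.construct()
+#
+#         def construct_train_program(self):
+#             # build a loss from self.train_pos_input / self.train_neg_input
+#             return [loss]
+#
+#         def construct_test_program(self):
+#             # score every entity as head/tail against self.test_input
+#             return [head_distances, tail_distances]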
diff --git a/examples/kg/model/RotatE.py b/examples/kg/model/RotatE.py
new file mode 100644
index 0000000000000000000000000000000000000000..610a2c0c95d41c5538bd3a9a6ecd380a0dcb09d5
--- /dev/null
+++ b/examples/kg/model/RotatE.py
@@ -0,0 +1,268 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+RotatE:
+"RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space."
+Sun, Zhiqing, et al.
+https://arxiv.org/abs/1902.10197
+"""
+import paddle.fluid as fluid
+from .Model import Model
+from .utils import lookup_table
+
+
+class RotatE(Model):
+    """
+    RotatE model.
+    """
+
+    def __init__(self,
+                 data_reader,
+                 hidden_size,
+                 margin,
+                 learning_rate,
+                 args,
+                 optimizer="adam"):
+        super(RotatE, self).__init__(
+            model_name="RotatE",
+            data_reader=data_reader,
+            hidden_size=hidden_size,
+            margin=margin,
+            learning_rate=learning_rate,
+            args=args,
+            optimizer=optimizer)
+
+        self._neg_times = self.args.neg_times
+        self._adv_temp_value = self.args.adv_temp_value
+
+        # Entities live in complex space, so they need twice the dimensions.
+        self._relation_hidden_size = self._hidden_size
+        self._entity_hidden_size = self._hidden_size * 2
+        self._entity_embedding_margin = (
+            self._margin + 2) / self._entity_hidden_size
+        self._relation_embedding_margin = (
+            self._margin + 2) / self._relation_hidden_size
+        self._rel_shape = [self._relation_total, self._relation_hidden_size]
+        self._ent_shape = [self._entity_total, self._entity_hidden_size]
+        self._pi = 3.141592654
+
+        self.construct_program()
+
+    def construct_program(self):
+        """
+        Construct the train and test programs. RotatE builds its own programs
+        (instead of calling Model.construct) because it feeds an extra
+        `neg_mode` input during training.
+        """
+        self.startup_program = fluid.Program()
+        self.train_program = fluid.Program()
+        self.test_program = fluid.Program()
+
+        with fluid.program_guard(self.train_program, self.startup_program):
+            self.train_pos_input = fluid.layers.data(
+                "pos_triple",
+                dtype="int64",
+                shape=[None, 3, 1],
+                append_batch_size=False)
+            self.train_neg_input = fluid.layers.data(
+                "neg_triple",
+                dtype="int64",
+                shape=[None, 3, 1],
+                append_batch_size=False)
+            self.train_neg_mode = fluid.layers.data(
+                "neg_mode",
+                dtype='float32',
+                shape=[1],
+                append_batch_size=False)
+            self.train_feed_vars = [
+                self.train_pos_input, self.train_neg_input, self.train_neg_mode
+            ]
+            self.train_fetch_vars = self.construct_train_program()
+            loss = self.train_fetch_vars[0]
+            self.apply_optimizer(loss, opt=self._optimizer)
+
+        with fluid.program_guard(self.test_program, self.startup_program):
+            self.test_input = fluid.layers.data(
+                "test_triple",
+                dtype="int64",
+                shape=[3],
+                append_batch_size=False)
+            self.test_feed_list = ["test_triple"]
+            self.test_fetch_vars = self.construct_test_program()
+
+    def creat_share_variables(self):
+        """
+        Share variables for the train and test programs.
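+
+        The Uniform initializer bound (margin + 2) / hidden_size mirrors the
+        "embedding range" (gamma + epsilon) / d used by the RotatE paper's
+        reference implementation, with epsilon assumed here to be 2.0.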
+ """ + entity_embedding = fluid.layers.create_parameter( + shape=self._ent_shape, + dtype="float32", + name=self.ent_name, + default_initializer=fluid.initializer.Uniform( + low=-1.0 * self._entity_embedding_margin, + high=1.0 * self._entity_embedding_margin)) + relation_embedding = fluid.layers.create_parameter( + shape=self._rel_shape, + dtype="float32", + name=self.rel_name, + default_initializer=fluid.initializer.Uniform( + low=-1.0 * self._relation_embedding_margin, + high=1.0 * self._relation_embedding_margin)) + + return entity_embedding, relation_embedding + + def score_with_l2_normalize(self, head, tail, rel, epsilon_var, + train_neg_mode): + """ + Score function of RotatE + """ + one_var = fluid.layers.fill_constant( + shape=[1], dtype='float32', value=1.0) + re_head, im_head = fluid.layers.split(head, num_or_sections=2, dim=-1) + re_tail, im_tail = fluid.layers.split(tail, num_or_sections=2, dim=-1) + + phase_relation = rel / (self._relation_embedding_margin / self._pi) + re_relation = fluid.layers.cos(phase_relation) + im_relation = fluid.layers.sin(phase_relation) + + re_score = re_relation * re_tail + im_relation * im_tail + im_score = re_relation * im_tail - im_relation * re_tail + re_score = re_score - re_head + im_score = im_score - im_head + #with fluid.layers.control_flow.Switch() as switch: + # with switch.case(train_neg_mode == one_var): + # re_score = re_relation * re_tail + im_relation * im_tail + # im_score = re_relation * im_tail - im_relation * re_tail + # re_score = re_score - re_head + # im_score = im_score - im_head + # with switch.default(): + # re_score = re_head * re_relation - im_head * im_relation + # im_score = re_head * im_relation + im_head * re_relation + # re_score = re_score - re_tail + # im_score = im_score - im_tail + + re_score = re_score * re_score + im_score = im_score * im_score + + score = re_score + im_score + score = score + epsilon_var + score = fluid.layers.sqrt(score) + score = fluid.layers.reduce_sum(score, dim=-1) + return self._margin - score + + def adverarial_weight(self, score): + """ + adverarial the weight for softmax + """ + adv_score = self._adv_temp_value * score + adv_softmax = fluid.layers.softmax(adv_score) + return adv_softmax + + def construct_train_program(self): + """ + Construct train program + """ + zero_var = fluid.layers.fill_constant( + shape=[1], dtype='float32', value=0.0) + epsilon_var = fluid.layers.fill_constant( + shape=[1], dtype='float32', value=1e-12) + entity_embedding, relation_embedding = self.creat_share_variables() + pos_head = lookup_table(self.train_pos_input[:, 0], entity_embedding) + pos_tail = lookup_table(self.train_pos_input[:, 2], entity_embedding) + pos_rel = lookup_table(self.train_pos_input[:, 1], relation_embedding) + neg_head = lookup_table(self.train_neg_input[:, 0], entity_embedding) + neg_tail = lookup_table(self.train_neg_input[:, 2], entity_embedding) + neg_rel = lookup_table(self.train_neg_input[:, 1], relation_embedding) + + pos_score = self.score_with_l2_normalize(pos_head, pos_tail, pos_rel, + epsilon_var, zero_var) + neg_score = self.score_with_l2_normalize( + neg_head, neg_tail, neg_rel, epsilon_var, self.train_neg_mode) + + neg_score = fluid.layers.reshape( + neg_score, shape=[-1, self._neg_times], inplace=True) + + if self._adv_temp_value > 0.0: + sigmoid_pos_score = fluid.layers.logsigmoid(1.0 * pos_score) + sigmoid_neg_score = fluid.layers.logsigmoid( + -1.0 * neg_score) * self.adverarial_weight(neg_score) + sigmoid_neg_score = fluid.layers.reduce_sum( + sigmoid_neg_score, 
dim=-1) + else: + sigmoid_pos_score = fluid.layers.logsigmoid(pos_score) + sigmoid_neg_score = fluid.layers.logsigmoid(-1.0 * neg_score) + + loss_1 = fluid.layers.mean(sigmoid_pos_score) + loss_2 = fluid.layers.mean(sigmoid_neg_score) + loss = -1.0 * (loss_1 + loss_2) / 2 + return [loss] + + def score_with_l2_normalize_with_validate(self, entity_embedding, head, + rel, tail, epsilon_var): + """ + the score function for validation + """ + re_entity_embedding, im_entity_embedding = fluid.layers.split( + entity_embedding, num_or_sections=2, dim=-1) + re_head, im_head = fluid.layers.split(head, num_or_sections=2, dim=-1) + re_tail, im_tail = fluid.layers.split(tail, num_or_sections=2, dim=-1) + phase_relation = rel / (self._relation_embedding_margin / self._pi) + re_relation = fluid.layers.cos(phase_relation) + im_relation = fluid.layers.sin(phase_relation) + + re_score = re_relation * re_tail + im_relation * im_tail + im_score = re_relation * im_tail - im_relation * re_tail + re_score = re_entity_embedding - re_score + im_score = im_entity_embedding - im_score + + re_score = re_score * re_score + im_score = im_score * im_score + head_score = re_score + im_score + head_score += epsilon_var + head_score = fluid.layers.sqrt(head_score) + head_score = fluid.layers.reduce_sum(head_score, dim=-1) + + re_score = re_head * re_relation - im_head * im_relation + im_score = re_head * im_relation + im_head * re_relation + re_score = re_entity_embedding - re_score + im_score = im_entity_embedding - im_score + + re_score = re_score * re_score + im_score = im_score * im_score + tail_score = re_score + im_score + tail_score += epsilon_var + tail_score = fluid.layers.sqrt(tail_score) + tail_score = fluid.layers.reduce_sum(tail_score, dim=-1) + + return head_score, tail_score + + def construct_test_program(self): + """ + Construct test program + """ + epsilon_var = fluid.layers.fill_constant( + shape=[1], dtype='float32', value=1e-12) + entity_embedding, relation_embedding = self.creat_share_variables() + + head_vec = lookup_table(self.test_input[0], entity_embedding) + rel_vec = lookup_table(self.test_input[1], relation_embedding) + tail_vec = lookup_table(self.test_input[2], entity_embedding) + head_vec = fluid.layers.unsqueeze(head_vec, axes=[0]) + rel_vec = fluid.layers.unsqueeze(rel_vec, axes=[0]) + tail_vec = fluid.layers.unsqueeze(tail_vec, axes=[0]) + + id_replace_head, id_replace_tail = self.score_with_l2_normalize_with_validate( + entity_embedding, head_vec, rel_vec, tail_vec, epsilon_var) + + id_replace_head = fluid.layers.logsigmoid(id_replace_head) + id_replace_tail = fluid.layers.logsigmoid(id_replace_tail) + + return [id_replace_head, id_replace_tail] diff --git a/examples/kg/model/TransE.py b/examples/kg/model/TransE.py new file mode 100644 index 0000000000000000000000000000000000000000..232b16253bcd5239439cf26ce24c8a97c97cc4da --- /dev/null +++ b/examples/kg/model/TransE.py @@ -0,0 +1,109 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +TransE: +"Translating embeddings for modeling multi-relational data." +Bordes, Antoine, et al. +https://www.utc.fr/~bordesan/dokuwiki/_media/en/transe_nips13.pdf +""" +import paddle.fluid as fluid +from .Model import Model +from .utils import lookup_table + + +class TransE(Model): + """ + The TransE Model. + """ + + def __init__(self, + data_reader, + hidden_size, + margin, + learning_rate, + args, + optimizer="adam"): + super(TransE, self).__init__( + model_name="TransE", + data_reader=data_reader, + hidden_size=hidden_size, + margin=margin, + learning_rate=learning_rate, + args=args, + optimizer=optimizer) + self.construct() + + def creat_share_variables(self): + """ + Share variables for train and test programs. + """ + entity_embedding = fluid.layers.create_parameter( + shape=self._ent_shape, dtype="float32", name=self.ent_name) + relation_embedding = fluid.layers.create_parameter( + shape=self._rel_shape, dtype="float32", name=self.rel_name) + return entity_embedding, relation_embedding + + @staticmethod + def score_with_l2_normalize(head, rel, tail): + """ + Score function of TransE + """ + head = fluid.layers.l2_normalize(head, axis=-1) + rel = fluid.layers.l2_normalize(rel, axis=-1) + tail = fluid.layers.l2_normalize(tail, axis=-1) + score = head + rel - tail + return score + + def construct_train_program(self): + """ + Construct train program. + """ + entity_embedding, relation_embedding = self.creat_share_variables() + pos_head = lookup_table(self.train_pos_input[:, 0], entity_embedding) + pos_tail = lookup_table(self.train_pos_input[:, 2], entity_embedding) + pos_rel = lookup_table(self.train_pos_input[:, 1], relation_embedding) + neg_head = lookup_table(self.train_neg_input[:, 0], entity_embedding) + neg_tail = lookup_table(self.train_neg_input[:, 2], entity_embedding) + neg_rel = lookup_table(self.train_neg_input[:, 1], relation_embedding) + + pos_score = self.score_with_l2_normalize(pos_head, pos_rel, pos_tail) + neg_score = self.score_with_l2_normalize(neg_head, neg_rel, neg_tail) + + pos = fluid.layers.reduce_sum( + fluid.layers.abs(pos_score), 1, keep_dim=False) + neg = fluid.layers.reduce_sum( + fluid.layers.abs(neg_score), 1, keep_dim=False) + loss = fluid.layers.reduce_mean( + fluid.layers.relu(pos - neg + self._margin)) + return [loss] + + def construct_test_program(self): + """ + Construct test program + """ + entity_embedding, relation_embedding = self.creat_share_variables() + entity_embedding = fluid.layers.l2_normalize(entity_embedding, axis=-1) + relation_embedding = fluid.layers.l2_normalize( + relation_embedding, axis=-1) + head_vec = lookup_table(self.test_input[0], entity_embedding) + rel_vec = lookup_table(self.test_input[1], relation_embedding) + tail_vec = lookup_table(self.test_input[2], entity_embedding) + # The paddle fluid.layers.topk GPU OP is very inefficient + # we do sort operation in the evaluation step using multiprocessing. 
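+        # Broadcasting: entity_embedding is [entity_total, d] while rel_vec /
+        # head_vec / tail_vec are [1, d], so each reduce_sum below yields the
+        # L1 distance of every candidate entity substituted as head (or tail).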
+ id_replace_head = fluid.layers.reduce_sum( + fluid.layers.abs(entity_embedding + rel_vec - tail_vec), dim=1) + id_replace_tail = fluid.layers.reduce_sum( + fluid.layers.abs(entity_embedding - rel_vec - head_vec), dim=1) + + return [id_replace_head, id_replace_tail] diff --git a/examples/kg/model/TransR.py b/examples/kg/model/TransR.py new file mode 100644 index 0000000000000000000000000000000000000000..ae9daf5d043a33b6f3e9f6e800c1640a9ab5ee14 --- /dev/null +++ b/examples/kg/model/TransR.py @@ -0,0 +1,167 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +TransR: +"Learning entity and relation embeddings for knowledge graph completion." +Lin, Yankai, et al. +https://www.aaai.org/ocs/index.php/AAAI/AAAI15/paper/view/9571/9523 +""" +import numpy as np +import paddle.fluid as fluid +from .Model import Model +from .utils import lookup_table + + +class TransR(Model): + """ + TransR model. + """ + + def __init__(self, + data_reader, + hidden_size, + margin, + learning_rate, + args, + optimizer="adam"): + """init""" + super(TransR, self).__init__( + model_name="TransR", + data_reader=data_reader, + hidden_size=hidden_size, + margin=margin, + learning_rate=learning_rate, + args=args, + optimizer=optimizer) + self.construct() + + def creat_share_variables(self): + """ + Share variables for train and test programs. + """ + entity_embedding = fluid.layers.create_parameter( + shape=self._ent_shape, + dtype="float32", + name=self.ent_name, + default_initializer=fluid.initializer.Xavier()) + relation_embedding = fluid.layers.create_parameter( + shape=self._rel_shape, + dtype="float32", + name=self.rel_name, + default_initializer=fluid.initializer.Xavier()) + transfer_matrix = fluid.layers.create_parameter( + shape=[ + self._relation_total, self._hidden_size * self._hidden_size + ], + dtype="float32", + name=self._prefix + "transfer_matrix", ) + # Here is a trick, must init with identity matrix to get good hit@10 performance. 
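+        # Each row of transfer_matrix holds a flattened hidden_size x
+        # hidden_size projection; starting every row at the identity makes
+        # TransR behave like plain TransE at step 0, which the comment above
+        # notes is needed for good Hits@10.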
+ fluid.layers.assign( + np.tile( + np.identity( + self._hidden_size, dtype="float32").reshape(-1), + (self._relation_total, 1)), + transfer_matrix) + return entity_embedding, relation_embedding, transfer_matrix + + def score_with_l2_normalize(self, head, rel, tail): + """ + Score function of TransR + """ + head = fluid.layers.l2_normalize(head, axis=-1) + rel = fluid.layers.l2_normalize(rel, axis=-1) + tail = fluid.layers.l2_normalize(tail, axis=-1) + score = head + rel - tail + return score + + @staticmethod + def matmul_with_expend_dims(x, y): + """matmul_with_expend_dims""" + x = fluid.layers.unsqueeze(x, axes=[1]) + res = fluid.layers.matmul(x, y) + return fluid.layers.squeeze(res, axes=[1]) + + def construct_train_program(self): + """ + Construct train program + """ + entity_embedding, relation_embedding, transfer_matrix = self.creat_share_variables( + ) + pos_head = lookup_table(self.train_pos_input[:, 0], entity_embedding) + pos_tail = lookup_table(self.train_pos_input[:, 2], entity_embedding) + pos_rel = lookup_table(self.train_pos_input[:, 1], relation_embedding) + neg_head = lookup_table(self.train_neg_input[:, 0], entity_embedding) + neg_tail = lookup_table(self.train_neg_input[:, 2], entity_embedding) + neg_rel = lookup_table(self.train_neg_input[:, 1], relation_embedding) + + rel_matrix = fluid.layers.reshape( + lookup_table(self.train_pos_input[:, 1], transfer_matrix), + [-1, self._hidden_size, self._hidden_size]) + pos_head_trans = self.matmul_with_expend_dims(pos_head, rel_matrix) + pos_tail_trans = self.matmul_with_expend_dims(pos_tail, rel_matrix) + + trans_neg = False + if trans_neg: + rel_matrix_neg = fluid.layers.reshape( + lookup_table(self.train_neg_input[:, 1], transfer_matrix), + [-1, self._hidden_size, self._hidden_size]) + neg_head_trans = self.matmul_with_expend_dims(neg_head, + rel_matrix_neg) + neg_tail_trans = self.matmul_with_expend_dims(neg_tail, + rel_matrix_neg) + else: + neg_head_trans = self.matmul_with_expend_dims(neg_head, rel_matrix) + neg_tail_trans = self.matmul_with_expend_dims(neg_tail, rel_matrix) + + pos_score = self.score_with_l2_normalize(pos_head_trans, pos_rel, + pos_tail_trans) + neg_score = self.score_with_l2_normalize(neg_head_trans, neg_rel, + neg_tail_trans) + + pos = fluid.layers.reduce_sum( + fluid.layers.abs(pos_score), -1, keep_dim=False) + neg = fluid.layers.reduce_sum( + fluid.layers.abs(neg_score), -1, keep_dim=False) + loss = fluid.layers.reduce_mean( + fluid.layers.relu(pos - neg + self._margin)) + return [loss] + + def construct_test_program(self): + """ + Construct test program + """ + entity_embedding, relation_embedding, transfer_matrix = self.creat_share_variables( + ) + rel_matrix = fluid.layers.reshape( + lookup_table(self.test_input[1], transfer_matrix), + [self._hidden_size, self._hidden_size]) + entity_embedding_trans = fluid.layers.matmul(entity_embedding, + rel_matrix, False, False) + rel_vec = lookup_table(self.test_input[1], relation_embedding) + entity_embedding_trans = fluid.layers.l2_normalize( + entity_embedding_trans, axis=-1) + rel_vec = fluid.layers.l2_normalize(rel_vec, axis=-1) + head_vec = lookup_table(self.test_input[0], entity_embedding_trans) + tail_vec = lookup_table(self.test_input[2], entity_embedding_trans) + + # The paddle fluid.layers.topk GPU OP is very inefficient + # we do sort operation in the evaluation step using multiprocessing + id_replace_head = fluid.layers.reduce_sum( + fluid.layers.abs(entity_embedding_trans + rel_vec - tail_vec), + dim=1) + id_replace_tail = 
fluid.layers.reduce_sum( + fluid.layers.abs(entity_embedding_trans - rel_vec - head_vec), + dim=1) + + return [id_replace_head, id_replace_tail] diff --git a/examples/kg/model/__init__.py b/examples/kg/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b827d6aa95ff7b27f9fa38696ca6028459b5b83e --- /dev/null +++ b/examples/kg/model/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""import all models""" + +from .TransE import TransE +from .TransR import TransR +from .RotatE import RotatE +model_dict = { + "TransE": TransE, + "transe": TransE, + "TransR": TransR, + "transr": TransR, + "RotatE": RotatE +} diff --git a/examples/kg/model/utils.py b/examples/kg/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b6952316db3353362e57dca4e9d08204134ad4be --- /dev/null +++ b/examples/kg/model/utils.py @@ -0,0 +1,58 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utils for the models. +""" +import paddle.fluid as fluid +from paddle.fluid.layer_helper import LayerHelper + + +def lookup_table(input, embedding_table, dtype='float32'): + """ + lookup table support for paddle. + :param input: + :param embedding_table: + :param dtype: + :return: + """ + is_sparse = False + is_distributed = False + helper = LayerHelper('embedding', **locals()) + remote_prefetch = is_sparse and (not is_distributed) + if remote_prefetch: + assert is_sparse is True and is_distributed is False + tmp = helper.create_variable_for_type_inference(dtype) + padding_idx = -1 + helper.append_op( + type='lookup_table', + inputs={'Ids': input, + 'W': embedding_table}, + outputs={'Out': tmp}, + attrs={ + 'is_sparse': is_sparse, + 'is_distributed': is_distributed, + 'remote_prefetch': remote_prefetch, + 'padding_idx': padding_idx + }) + return tmp + + +def lookup_table_gather(index, input): + """ + lookup table support for paddle by gather. + :param index: + :param input: + :return: + """ + return fluid.layers.gather(index=index, input=input, overwrite=False) diff --git a/examples/kg/mp_mapper.py b/examples/kg/mp_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..9c4fbeb436eff3d105723ba0447fca6574ae3ec7 --- /dev/null +++ b/examples/kg/mp_mapper.py @@ -0,0 +1,99 @@ +# Copyright (c) 2019 PaddlePaddle Authors. 
All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file uses multiprocessing to perform the following process:
+    `
+    for data in reader():
+        yield func(data)
+    `
+"""
+import numpy as np
+import multiprocessing as mp
+import traceback
+from pgl.utils.logger import log
+
+
+def mp_reader_mapper(reader, func, num_works=4):
+    """
+    Use multiprocessing to perform the following process:
+    `
+    for data in reader():
+        yield func(data)
+    `
+    The in-stream is `reader`; the mapper maps the in-stream to an out-stream.
+    Make sure `func` returns a value that is not `None`, since `None` is used
+    as the end-of-stream sentinel in the pipe protocol below.
+    :param reader: the data iterator.
+    :param func: the map function.
+    :param num_works: number of worker processes.
+    :return: a new iterator.
+    """
+
+    def _read_into_pipe(func, conn):
+        """
+        Read from the pipe and apply `func` to produce the final data.
+        """
+        while True:
+            data = conn.recv()
+            if data is None:
+                conn.send(None)
+                conn.close()
+                break
+            conn.send(func(data))
+
+    def pipe_reader():
+        """pipe_reader"""
+        conns = []
+        all_process = []
+        for w in range(num_works):
+            parent_conn, child_conn = mp.Pipe()
+            conns.append(parent_conn)
+            p = mp.Process(target=_read_into_pipe, args=(func, child_conn))
+            p.start()
+            all_process.append(p)
+
+        data_iter = reader()
+
+        def next_data():
+            """next_data"""
+            _next = None
+            try:
+                # `next()` works in both Python 2 and 3 (`.next()` is Py2-only).
+                _next = next(data_iter)
+            except StopIteration:
+                # log.debug(traceback.format_exc())
+                pass
+            except Exception:
+                log.debug(traceback.format_exc())
+            return _next
+
+        for i in range(num_works):
+            conns[i].send(next_data())
+
+        finish_num = 0
+        finish_flag = np.zeros(len(conns), dtype="int32")
+        while finish_num < num_works:
+            for conn_id, conn in enumerate(conns):
+                if finish_flag[conn_id] > 0:
+                    continue
+                sample = conn.recv()
+                if sample is None:
+                    finish_num += 1
+                    conn.close()
+                    finish_flag[conn_id] = 1
+                else:
+                    yield sample
+                    conns[conn_id].send(next_data())
+
+    return pipe_reader
diff --git a/examples/kg/run.sh b/examples/kg/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fbc53e02aade4d9cb317ac374bb45247333b8520
--- /dev/null
+++ b/examples/kg/run.sh
@@ -0,0 +1,44 @@
+#CUDA_VISIBLE_DEVICES=2 \
+#FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+#python main.py \
+#    --use_cuda \
+#    --model TransE \
+#    --optimizer adam \
+#    --batch_size=512 \
+#    --learning_rate=0.001 \
+#    --epoch 100 \
+#    --evaluate_per_iteration 20 \
+#    --sample_workers 4 \
+#    --margin 4 \
+#    #--only_evaluate
+
+#CUDA_VISIBLE_DEVICES=2 \
+#FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+#python main.py \
+#    --use_cuda \
+#    --model RotatE \
+#    --data_dir ./data/WN18 \
+#    --optimizer adam \
+#    --batch_size=512 \
+#    --learning_rate=0.001 \
+#    --epoch 100 \
+#    --evaluate_per_iteration 100 \
+#    --sample_workers 10 \
+#    --margin 6 \
+#    --neg_times 10
+
+CUDA_VISIBLE_DEVICES=2 \
+FLAGS_fraction_of_gpu_memory_to_use=0.01 \
+python main.py \
+    --use_cuda \
+    --model RotatE \
+    --data_dir ./data/FB15k \
+    --optimizer adam \
+    --batch_size=512 \
+    --learning_rate=0.001 \
--epoch 100 \ + --evaluate_per_iteration 100 \ + --sample_workers 10 \ + --margin 8 \ + --neg_times 10 \ + --neg_mode True
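
#Evaluate a saved checkpoint only (illustrative; flags as defined in main.py;
#hidden_size and margin must match the values used for training):
#CUDA_VISIBLE_DEVICES=2 \
#python main.py \
#    --use_cuda \
#    --model RotatE \
#    --data_dir ./data/FB15k \
#    --checkpoint output/ \
#    --hidden_size 50 \
#    --margin 8 \
#    --only_evaluate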