#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved
#
# File: dump_graph.py
# Author: suweiyue(suweiyue@baidu.com)
# Date: 2020/03/01 22:17:13
#
########################################################################
"""
    Comment.
"""
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
#from __future__ import unicode_literals

import io
import os
import sys
import argparse
import logging
import multiprocessing
from functools import partial
from io import open

import numpy as np
import tqdm
import pgl
from pgl.graph_kernel import alias_sample_build_table
from pgl.utils.logger import log

from tokenization import FullTokenizer


def term2id(string, tokenizer, max_seqlen):
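    """Convert a term string into a fixed-length list of token ids.

    The token sequence is truncated to max_seqlen - 1, a [SEP] token (id 2)
    is appended, and the result is zero-padded to max_seqlen. For example,
    with max_seqlen=5 and hypothetical token ids [101, 102] for the input,
    the result is [101, 102, 2, 0, 0].
    """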
    #string = string.split("\t")[1]
    tokens = tokenizer.tokenize(string)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    ids = ids[:max_seqlen-1]
    ids = ids + [2]  # append the [SEP] token (id 2)
    ids = ids + [0] * (max_seqlen - len(ids))
    return ids

def load_graph(args, str2id, term_file, terms, item_distribution):
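    """Read --graphpath and return an int64 edge array of shape (2E, 2).

    Each line is parsed as tab-separated term strings; the first two columns
    are the source and destination terms. Nodes are assigned ids on first
    sight (and written to terms.txt), every edge is added in both directions,
    and item_distribution counts how often each node appears as a destination.
    """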
    edges = []
    with io.open(args.graphpath, encoding=args.encoding) as f:
        for idx, line in enumerate(f):
            if idx % 100000 == 0:
                log.info("%s read %s lines" % (args.graphpath, idx))
            slots = []
            for col_idx, col in enumerate(line.strip("\n").split("\t")):
                s = col[:args.max_seqlen]
                if s not in str2id:
                    str2id[s] = len(str2id)
                    term_file.write(str(col_idx) + "\t" + col + "\n")
                    item_distribution.append(0)
                slots.append(str2id[s])

            src = slots[0]
            dst = slots[1]
            edges.append((src, dst))
            edges.append((dst, src))
            item_distribution[dst] += 1  # destination frequency, used for sampling
    edges = np.array(edges, dtype="int64")
    return edges

def load_link_predict_train_data(args, str2id, term_file, terms, item_distribution):
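    """Load link-prediction training pairs from --inpath.

    Each line holds tab-separated terms: source, destination, and any
    remaining columns as negative samples. Pairs are saved to train_data.npy
    and negatives (if present) to neg_samples.npy under --outpath.
    """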
    train_data = []
    neg_samples = []
    with io.open(args.inpath, encoding=args.encoding) as f:
        for idx, line in enumerate(f):
            if idx % 100000 == 0:
                log.info("%s read %s lines" % (args.inpath, idx))
            slots = []
            for col_idx, col in enumerate(line.strip("\n").split("\t")):
                s = col[:args.max_seqlen]
                if s not in str2id:
                    str2id[s] = len(str2id)
                    term_file.write(str(col_idx) + "\t" + col + "\n")
                    item_distribution.append(0)
                slots.append(str2id[s])

            src = slots[0]
            dst = slots[1]
            neg_samples.append(slots[2:])
            train_data.append((src, dst))
    train_data = np.array(train_data, dtype="int64")
    np.save(os.path.join(args.outpath, "train_data.npy"), train_data)
    if len(neg_samples) != 0:
        np.save(os.path.join(args.outpath, "neg_samples.npy"), np.array(neg_samples))

def load_node_classification_train_data(args, str2id, term_file, terms, item_distribution):
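    """Load node-classification training examples from --inpath.

    Each line holds a term string and an integer label separated by a tab;
    (node_id, label) pairs are saved to train_data.npy under --outpath.
    """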
    train_data = []
    with io.open(args.inpath, encoding=args.encoding) as f:
        for idx, line in enumerate(f):
            if idx % 100000 == 0:
                log.info("%s read %s lines" % (args.inpath, idx))
            col_idx = 0
            slots = line.strip("\n").split("\t")
            col = slots[0]
            label = int(slots[1])
            text = col[:args.max_seqlen]
            if text not in str2id:
                str2id[text] = len(str2id)
                term_file.write(str(col_idx) + "\t" + col + "\n")
                item_distribution.append(0)
            src = str2id[text]
            train_data.append([src, label])
    train_data = np.array(train_data, dtype="int64")
    np.save(os.path.join(args.outpath, "train_data.npy"), train_data)

def dump_graph(args):
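    """Build the graph and training data, then dump everything to --outpath,
    including an alias table for sampling nodes by smoothed frequency."""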
    if not os.path.exists(args.outpath):
        os.makedirs(args.outpath)
    str2id = dict()
    term_file = io.open(os.path.join(args.outpath, "terms.txt"), "w", encoding=args.encoding)
    terms = []
    item_distribution = []

    edges = load_graph(args, str2id, term_file, terms, item_distribution)
    #load_train_data(args, str2id, term_file, terms, item_distribution)
    if args.task == "link_predict":
        load_link_predict_train_data(args, str2id, term_file, terms, item_distribution)
    elif args.task == "node_classification":
        load_node_classification_train_data(args, str2id, term_file, terms, item_distribution)
    else:
        raise ValueError("unknown task: %s" % args.task)

    term_file.close()
    num_nodes = len(str2id)
    str2id.clear()

    log.info("building graph...")
    graph = pgl.graph.Graph(num_nodes=num_nodes, edges=edges)
    # Precompute node degrees before dumping the graph.
    graph.indegree()
    graph.outdegree()
    graph.dump(args.outpath)
    
    # dump alias sample table
    item_distribution = np.array(item_distribution)
    item_distribution = np.sqrt(item_distribution)  # square-root smoothing
    distribution = 1. * item_distribution / item_distribution.sum()  # normalize
    alias, events = alias_sample_build_table(distribution)
    np.save(os.path.join(args.outpath, "alias.npy"), alias)
    np.save(os.path.join(args.outpath, "events.npy"), events)
    log.info("End Build Graph")
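
# A minimal sketch (not executed by this script) of how the saved alias table
# can be consumed, assuming the standard alias-method convention where
# alias[j] is the probability of keeping bucket j and events[j] holds its
# alternative outcome:
#
#   alias = np.load("alias.npy")
#   events = np.load("events.npy")
#   j = np.random.randint(len(alias))                          # uniform bucket
#   sample = j if np.random.rand() < alias[j] else events[j]   # O(1) draw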

def dump_node_feat(args):
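    """Tokenize every term in terms.txt in parallel and save the resulting
    id matrix to term_ids.npy as uint16 (so vocab ids must be < 65536)."""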
    log.info("Dump node feat starting...")
    id2str = [line.strip("\n").split("\t")[-1] for line in io.open(os.path.join(args.outpath, "terms.txt"), encoding=args.encoding)]
    pool = multiprocessing.Pool()
    tokenizer = FullTokenizer(args.vocab_file)
    term_ids = pool.map(partial(term2id, tokenizer=tokenizer, max_seqlen=args.max_seqlen), id2str)
    np.save(os.path.join(args.outpath, "term_ids.npy"), np.array(term_ids, np.uint16))
    log.info("Dump node feat done.")
    pool.terminate()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Dump graph, training data, and tokenized node features.")
    parser.add_argument("-i", "--inpath", type=str, default=None)
    parser.add_argument("-g", "--graphpath", type=str, default=None)
    parser.add_argument("-l", "--max_seqlen", type=int, default=30)
    parser.add_argument("--vocab_file", type=str, default="./vocab.txt")
    parser.add_argument("--encoding", type=str, default="utf8")
    parser.add_argument("--task", type=str, default="link_predict", choices=["link_predict", "node_classification"])
    parser.add_argument("-o", "--outpath", type=str, default=None)
    args = parser.parse_args()
    dump_graph(args)
    dump_node_feat(args)
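
# Example invocation (paths are illustrative):
#   python dump_graph.py --task link_predict -g graph.txt -i train.txt \
#       -o ./workdir --vocab_file ./vocab.txt -l 30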