#!/usr/bin/env python
# -*- coding: utf-8 -*-
########################################################################
#
# Copyright (c) 2020 Baidu.com, Inc. All Rights Reserved
#
# File: dump_graph.py
# Author: suweiyue(suweiyue@baidu.com)
# Date: 2020/03/01 22:17:13
#
########################################################################
"""
    Comment.
"""
from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
#from __future__ import unicode_literals

import io
import os
import sys
import argparse
import logging
import multiprocessing
from functools import partial
from io import open

import numpy as np
import pgl
from pgl.graph_kernel import alias_sample_build_table
from pgl.utils.logger import log

from tokenization import FullTokenizer


def term2id(string, tokenizer, max_seqlen):
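    """Tokenize one line of terms.txt (tab-separated col_idx and term)
    into exactly max_seqlen ids: tokens, then [SEP], then zero padding.
    """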
    # terms.txt lines are "col_idx<TAB>term"; keep only the term text.
    string = string.split("\t")[1]
    tokens = tokenizer.tokenize(string)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    ids = ids[:max_seqlen - 1]  # reserve one position for [SEP]
    ids = ids + [2]  # append [SEP] (id 2 in this vocab)
    ids = ids + [0] * (max_seqlen - len(ids))  # zero-pad up to max_seqlen
    return ids


def dump_graph(args):
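    """Assign an integer id to every distinct term, then build and dump
    the PGL graph, the alias-sampling table and the negative samples.
    """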
    if not os.path.exists(args.outpath):
        os.makedirs(args.outpath)
    neg_samples = []
    str2id = dict()
    term_file = io.open(os.path.join(args.outpath, "terms.txt"), "w", encoding=args.encoding)
    terms = []
    count = 0

    with io.open(args.inpath, encoding=args.encoding) as f:
        edges = []
        for idx, line in enumerate(f):
            if idx % 100000 == 0:
                log.info("%s read %s lines" % (args.inpath, idx))
            slots = []
            # Columns: src term, dst term, then optional negative-sample terms.
            for col_idx, col in enumerate(line.strip("\n").split("\t")):
                # Deduplicate on the first max_seqlen characters; the full
                # term is still written to terms.txt.
                s = col[:args.max_seqlen]
                if s not in str2id:
                    str2id[s] = count
                    count += 1
                    term_file.write(str(col_idx) + "\t" + col + "\n")

                slots.append(str2id[s])

            src = slots[0]
            dst = slots[1]
            # Any remaining columns are pre-generated negative samples.
            neg_samples.append(slots[2:])
            # Store both directions so the graph is undirected.
            edges.append((src, dst))
            edges.append((dst, src))

        term_file.close()
        edges = np.array(edges, dtype="int64")
        num_nodes = len(str2id)
        # The term-to-id mapping is no longer needed; free it early.
        str2id.clear()
    log.info("building graph...")
    graph = pgl.graph.Graph(num_nodes=num_nodes, edges=edges)
    indegree = graph.indegree()
    # outdegree() is computed for its side effects before dumping;
    # the return value itself is unused.
    graph.outdegree()
    graph.dump(args.outpath)
    
    # Dump the alias-sampling table: alias/events allow O(1) sampling of
    # nodes with probability proportional to sqrt(indegree).
    sqrt_indegree = np.sqrt(indegree)
    distribution = 1. * sqrt_indegree / sqrt_indegree.sum()
    alias, events = alias_sample_build_table(distribution)
    np.save(os.path.join(args.outpath, "alias.npy"), alias)
    np.save(os.path.join(args.outpath, "events.npy"), events)
    np.save(os.path.join(args.outpath, "neg_samples.npy"), np.array(neg_samples))
    log.info("Done building graph.")

def dump_id2str_map(args):
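    """Load terms.txt into a numpy string array indexed by node id and
    save it as id2str.npy.
    """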
    log.info("Dump id2str map starting...")
    id2str = np.array([
        line.strip("\n")
        for line in open(
            os.path.join(args.outpath, "terms.txt"), "r", encoding=args.encoding)
    ])
    np.save(os.path.join(args.outpath, "id2str.npy"), id2str)
    log.info("Dump id2str map done.")

def dump_node_feat(args):
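    """Tokenize every term in parallel and save the [num_nodes, max_seqlen]
    id matrix as term_ids.npy.
    """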
    log.info("Dump node feat starting...")
    id2str = np.load(os.path.join(args.outpath, "id2str.npy"), mmap_mode="r")
    pool = multiprocessing.Pool()
    tokenizer = FullTokenizer(args.vocab_file)
    term_ids = pool.map(
        partial(term2id, tokenizer=tokenizer, max_seqlen=args.max_seqlen),
        id2str)
    np.save(os.path.join(args.outpath, "term_ids.npy"), np.array(term_ids))
    log.info("Dump node feat done.")
    # close() + join() lets the workers drain and exit cleanly.
    pool.close()
    pool.join()
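
# Example invocation (paths are illustrative):
#   python dump_graph.py -i edges.tsv -o ./graph_workdir \
#       --vocab_file ./vocab.txt --max_seqlen 30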

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Dump graph, id2str map and node features.")
    parser.add_argument("-i", "--inpath", type=str, required=True)
    parser.add_argument("-l", "--max_seqlen", type=int, default=30)
    parser.add_argument("--vocab_file", type=str, default="./vocab.txt")
    parser.add_argument("--encoding", type=str, default="utf8")
    parser.add_argument("-o", "--outpath", type=str, required=True)
    args = parser.parse_args()
    dump_graph(args)
    dump_id2str_map(args)
    dump_node_feat(args)