diff --git a/deepwalk/__main__.py b/deepwalk/__main__.py index 00f7bc33aea28a69156ddc48fb085b384c8df29f..c331aef8cb5255e3cd884c0606a0dceb95505d03 100644 --- a/deepwalk/__main__.py +++ b/deepwalk/__main__.py @@ -1,6 +1,7 @@ #! /usr/bin/env python # -*- coding: utf-8 -*- +import os import sys import random from io import open @@ -16,6 +17,13 @@ from skipgram import Skipgram from six import text_type as unicode from six import iteritems +from six.moves import range + +import psutil +from multiprocessing import cpu_count + +p = psutil.Process(os.getpid()) +p.set_cpu_affinity(list(range(cpu_count()))) logger = logging.getLogger(__name__) LOGFORMAT = "%(asctime).19s %(levelname)s %(filename)s: %(lineno)s %(message)s" @@ -68,11 +76,15 @@ def process(args): path_length=args.walk_length, alpha=0, rand=random.Random(args.seed), num_workers=args.workers) - # use degree distribution for frequency in tree - vertex_frequency = G.degree(nodes=G.iterkeys()) + print("Counting vertex frequency...") + if not args.vertex_freq_degree: + vertex_counts = serialized_walks.count_textfiles(walk_files, args.workers) + else: + # use degree distribution for frequency in tree + vertex_counts = G.degree(nodes=G.iterkeys()) print("Training...") - model = Skipgram(sentences=serialized_walks.combine_files_iter(walk_files), vocabulary_counts=vertex_frequency, + model = Skipgram(sentences=serialized_walks.combine_files_iter(walk_files), vocabulary_counts=vertex_counts, size=args.representation_size, window=args.window_size, min_count=0, workers=args.workers) @@ -117,6 +129,11 @@ def main(): parser.add_argument('--undirected', default=True, type=bool, help='Treat graph as undirected.') + parser.add_argument('--vertex-freq-degree', default=False, action='store_true', + help='Use vertex degree to estimate the frequency of nodes ' + 'in the random walks. 
def count_words(file):
    """Count the whitespace-separated tokens in a single text file.

    Args:
        file: path of a text file; each line is treated as one sentence
            and is split on whitespace.

    Returns:
        A `Counter` mapping each token (as a string) to the number of times
        it occurs in the file.

    Note:
        This stays a module-level function so that it can be pickled and
        sent to the worker processes started by `count_textfiles`.
    """
    counts = Counter()
    with open(file, 'r') as handle:
        for line in handle:
            # split() with no argument already ignores leading/trailing
            # whitespace and produces nothing for blank lines.
            counts.update(line.split())
    return counts


def count_textfiles(files, workers=1):
    """Count token frequencies over several text files.

    Args:
        files: iterable of file paths, each readable by `count_words`.
        workers: maximum number of worker processes. `None` lets
            `ProcessPoolExecutor` pick its own default. Values of 1 or less
            are counted in the calling process.

    Returns:
        A `Counter` holding the summed per-file token counts; empty when
        `files` is empty.
    """
    files = list(files)  # may be a one-shot iterator; we need its length
    total = Counter()
    if not files:
        return total

    # A process pool only pays off with several files AND several workers;
    # otherwise skip the process start-up and pickling overhead entirely.
    run_serially = len(files) == 1 or (workers is not None and workers <= 1)
    if run_serially:
        for name in files:
            total.update(count_words(name))
        return total

    with ProcessPoolExecutor(max_workers=workers) as executor:
        for partial in executor.map(count_words, files):
            total.update(partial)
    return total
for file in file_list: with open(file, 'r') as f: for line in f: - yield line.split() \ No newline at end of file + yield line.split() diff --git a/requirements.txt b/requirements.txt index b607e1d2f4736ab85c9e05f4c52e6acbeb761af2..3138a4d8a8adf2fb9aee42d64db10f1ea0ae3f6e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ futures>=2.1.6 six>=1.7.3 gensim>=0.10.0 scipy>=0.7.0 +psutil>=2.1.1