Commit 7ffd806e authored by Rami Al-Rfou

Adding parallel support to counting vertex frequency

Parent 0a5ce370
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 import os
 import sys
 import random
 from io import open
@@ -16,6 +17,13 @@ from skipgram import Skipgram
 from six import text_type as unicode
 from six import iteritems
 from six.moves import range
+
+import psutil
+from multiprocessing import cpu_count
+
+p = psutil.Process(os.getpid())
+p.set_cpu_affinity(list(range(cpu_count())))
+
 logger = logging.getLogger(__name__)
 LOGFORMAT = "%(asctime).19s %(levelname)s %(filename)s: %(lineno)s %(message)s"
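Note: the affinity reset added above presumably undoes the core pinning that some BLAS builds apply when numpy/scipy are imported, which would otherwise confine the process (and its forked workers) to a single core. psutil 3.0 later removed `set_cpu_affinity` in favor of a combined getter/setter; a minimal sketch of the equivalent call under the newer API:

    import os
    from multiprocessing import cpu_count

    import psutil

    # Re-allow this process on every core (psutil >= 3.0 spelling).
    p = psutil.Process(os.getpid())
    p.cpu_affinity(list(range(cpu_count())))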
@@ -68,11 +76,15 @@ def process(args):
                                                     path_length=args.walk_length, alpha=0, rand=random.Random(args.seed),
                                                     num_workers=args.workers)

-  # use degree distribution for frequency in tree
-  vertex_frequency = G.degree(nodes=G.iterkeys())
+  print("Counting vertex frequency...")
+  if not args.vertex_freq_degree:
+    vertex_counts = serialized_walks.count_textfiles(walk_files, args.workers)
+  else:
+    # use degree distribution for frequency in tree
+    vertex_counts = G.degree(nodes=G.iterkeys())

   print("Training...")
-  model = Skipgram(sentences=serialized_walks.combine_files_iter(walk_files), vocabulary_counts=vertex_frequency,
+  model = Skipgram(sentences=serialized_walks.combine_files_iter(walk_files), vocabulary_counts=vertex_counts,
                    size=args.representation_size,
                    window=args.window_size, min_count=0, workers=args.workers)
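Note: the branch above makes exact counting the default and keeps the degree heuristic opt-in. The heuristic is defensible because, for random walks on an undirected graph, the stationary visit frequency of a vertex is proportional to its degree, so the two branches produce interchangeable (if not identical) frequency tables. A toy illustration of what each source yields, assuming a deepwalk-style graph `G` and serialized `walk_files` on disk:

    from collections import Counter

    # Degree heuristic: no pass over the walks; matches the true visit
    # frequency only in expectation.
    estimated = G.degree(nodes=G.iterkeys())   # {vertex: degree}

    # Exact counts: one full pass over every serialized walk file.
    exact = Counter()
    for name in walk_files:
        with open(name) as f:
            for line in f:
                exact.update(line.split())     # {vertex: occurrences}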
@@ -117,6 +129,11 @@ def main():
   parser.add_argument('--undirected', default=True, type=bool,
                       help='Treat graph as undirected.')

+  parser.add_argument('--vertex-freq-degree', default=False, action='store_true',
+                      help='Use vertex degree to estimate the frequency of nodes '
+                           'in the random walks. This option is faster than '
+                           'calculating the vocabulary.')
+
   parser.add_argument('--walk-length', default=40, type=int,
                       help='Length of the random walk started at each node')
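Note: with the new flag the degree heuristic becomes opt-in from the command line. A hypothetical invocation (input/output paths are placeholders; the other flags already existed):

    python -m deepwalk --input karate.adjlist --output karate.embeddings \
           --workers 4 --vertex-freq-degree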
......
@@ -2,10 +2,12 @@ import logging
 from io import open
 from os import path
 from time import time
-from itertools import izip
 from multiprocessing import cpu_count
 import random
 from concurrent.futures import ProcessPoolExecutor
+from collections import Counter
+
+from six.moves import zip

 from deepwalk import graph
@@ -16,6 +18,29 @@ __current_graph = None

 # speed up the string encoding
 __vertex2str = None

+def count_words(file):
+  """ Counts the word frequencies in a list of sentences.
+
+  Note:
+    This is a helper function for parallel execution of the
+    `Vocabulary.from_text` method.
+  """
+  c = Counter()
+
+  with open(file, 'r') as f:
+    for l in f:
+      words = l.strip().split()
+      c.update(words)
+  return c
+
+
+def count_textfiles(files, workers=1):
+  c = Counter()
+  with ProcessPoolExecutor(max_workers=workers) as executor:
+    for c_ in executor.map(count_words, files):
+      c.update(c_)
+  return c
+
 def count_lines(f):
   if path.isfile(f):
     num_lines = sum(1 for line in open(f))
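Note: `count_textfiles` fans the per-file counting out to a process pool and merges the partial `Counter`s; because the merge is just addition, the result is independent of completion order. A quick usage sketch with hypothetical file names (one walk per line, vertices separated by whitespace):

    walk_files = ['walks.0.txt', 'walks.1.txt']   # placeholders
    counts = count_textfiles(walk_files, workers=4)
    print(counts.most_common(5))                  # five most visited vertices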
@@ -52,7 +77,7 @@ def write_walks_to_disk(G, filebase, num_paths, path_length, alpha=0, rand=rando
                       for x in graph.grouper(int(num_paths / num_workers)+1, range(1, num_paths+1))]

   with ProcessPoolExecutor(max_workers=num_workers) as executor:
-    for size, file_, ppw in izip(executor.map(count_lines, files_list), files_list, paths_per_worker):
+    for size, file_, ppw in zip(executor.map(count_lines, files_list), files_list, paths_per_worker):
      if always_rebuild or size != (ppw*expected_size):
        args_list.append((ppw, path_length, alpha, random.Random(rand.randint(0, 2**31)), file_))
      else:
@@ -68,4 +93,4 @@ def combine_files_iter(file_list):
   for file in file_list:
     with open(file, 'r') as f:
       for line in f:
-        yield line.split()
\ No newline at end of file
+        yield line.split()
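Note: `combine_files_iter` streams all walk files as one corpus, but as a generator it is exhausted after a single pass. That is fine here because the vocabulary counts are supplied separately, so training needs only one scan; if a consumer ever had to re-read the corpus, a common pattern is to wrap the same loop in a restartable iterable. A minimal sketch, not part of this commit:

    class WalksCorpus(object):
        def __init__(self, file_list):
            self.file_list = file_list

        def __iter__(self):
            # A fresh pass over every file each time iteration restarts.
            for file in self.file_list:
                with open(file) as f:
                    for line in f:
                        yield line.split()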
@@ -5,3 +5,4 @@ futures>=2.1.6
 six>=1.7.3
 gensim>=0.10.0
 scipy>=0.7.0
+psutil>=2.1.1