From d104eccf6784585aa54d931b95db9364cac7744e Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Tue, 20 Jun 2017 18:13:46 +0800 Subject: [PATCH] Update the default num_threads for DS2 data generator. --- data_utils/data.py | 3 ++- infer.py | 3 ++- train.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/data_utils/data.py b/data_utils/data.py index 8391dacc..44af7ffa 100644 --- a/data_utils/data.py +++ b/data_utils/data.py @@ -7,6 +7,7 @@ from __future__ import print_function import random import numpy as np +import multiprocessing import paddle.v2 as paddle from data_utils import utils from data_utils.augmentor.augmentation import AugmentationPipeline @@ -60,7 +61,7 @@ class DataGenerator(object): window_ms=20.0, max_freq=None, specgram_type='linear', - num_threads=12, + num_threads=multiprocessing.cpu_count(), random_seed=0): self._max_duration = max_duration self._min_duration = min_duration diff --git a/infer.py b/infer.py index 7fc84829..71518133 100644 --- a/infer.py +++ b/infer.py @@ -6,6 +6,7 @@ from __future__ import print_function import argparse import gzip import distutils.util +import multiprocessing import paddle.v2 as paddle from data_utils.data import DataGenerator from model import deep_speech2 @@ -40,7 +41,7 @@ parser.add_argument( help="Use gpu or not. (default: %(default)s)") parser.add_argument( "--num_threads_data", - default=12, + default=multiprocessing.cpu_count(), type=int, help="Number of cpu threads for preprocessing data. (default: %(default)s)") parser.add_argument( diff --git a/train.py b/train.py index 2c3b8ce7..fc23ec72 100644 --- a/train.py +++ b/train.py @@ -9,6 +9,7 @@ import argparse import gzip import time import distutils.util +import multiprocessing import paddle.v2 as paddle from model import deep_speech2 from data_utils.data import DataGenerator @@ -77,7 +78,7 @@ parser.add_argument( help="Trainer number. (default: %(default)s)") parser.add_argument( "--num_threads_data", - default=12, + default=multiprocessing.cpu_count(), type=int, help="Number of cpu threads for preprocessing data. (default: %(default)s)") parser.add_argument( -- GitLab