diff --git a/deep_speech_2/data/librispeech/librispeech.py b/deep_speech_2/data/librispeech/librispeech.py
index d963a7d5372d64f3abb1dcbdd16dbdafc1888de0..a485904a73cba059e2c3df173efd691c3641036f 100644
--- a/deep_speech_2/data/librispeech/librispeech.py
+++ b/deep_speech_2/data/librispeech/librispeech.py
@@ -12,12 +12,12 @@ from __future__ import print_function
 import distutils.util
 import os
 import sys
-import tarfile
 import argparse
 import soundfile
 import json
 import codecs
 from paddle.v2.dataset.common import md5file
+from data_utils.utility import download, unpack
 
 DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
 
@@ -59,33 +59,6 @@ parser.add_argument(
 args = parser.parse_args()
 
 
-def download(url, md5sum, target_dir):
-    """
-    Download file from url to target_dir, and check md5sum.
-    """
-    if not os.path.exists(target_dir): os.makedirs(target_dir)
-    filepath = os.path.join(target_dir, url.split("/")[-1])
-    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
-        print("Downloading %s ..." % url)
-        os.system("wget -c " + url + " -P " + target_dir)
-        print("\nMD5 Chesksum %s ..." % filepath)
-        if not md5file(filepath) == md5sum:
-            raise RuntimeError("MD5 checksum failed.")
-    else:
-        print("File exists, skip downloading. (%s)" % filepath)
-    return filepath
-
-
-def unpack(filepath, target_dir):
-    """
-    Unpack the file to the target_dir.
-    """
-    print("Unpacking %s ..." % filepath)
-    tar = tarfile.open(filepath)
-    tar.extractall(target_dir)
-    tar.close()
-
-
 def create_manifest(data_dir, manifest_path):
     """
     Create a manifest json file summarizing the data set, with each line
diff --git a/deep_speech_2/data_utils/utility.py b/deep_speech_2/data_utils/utility.py
index f970ff55adeee0e1a4613143db1145e617b3699c..e1e3b55e75b5085d36d495dff7363aab62960134 100644
--- a/deep_speech_2/data_utils/utility.py
+++ b/deep_speech_2/data_utils/utility.py
@@ -5,6 +5,9 @@ from __future__ import print_function
 
 import json
 import codecs
+import os
+import tarfile
+from paddle.v2.dataset.common import md5file
 
 
 def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
@@ -33,3 +36,45 @@ def read_manifest(manifest_path, max_duration=float('inf'), min_duration=0.0):
                 json_data["duration"] >= min_duration):
             manifest.append(json_data)
     return manifest
+
+
+def download(url, md5sum, target_dir):
+    """Download file from url to target_dir, and check md5sum.
+
+    The download is skipped when the target file already exists and its
+    md5file() digest equals md5sum.
+
+    :param url: URL to fetch; its last "/"-separated part is the file name.
+    :param md5sum: Expected value of md5file() for the downloaded file.
+    :param target_dir: Directory to save the file in; created if missing.
+    :return: Path of the file inside target_dir.
+    :raises RuntimeError: If the digest of the fetched file mismatches.
+    """
+    if not os.path.exists(target_dir): os.makedirs(target_dir)
+    filepath = os.path.join(target_dir, url.split("/")[-1])
+    if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
+        print("Downloading %s ..." % url)
+        # NOTE: url and target_dir are pasted into the shell command unquoted,
+        # so callers must pass trusted values without shell metacharacters.
+        os.system("wget -c " + url + " -P " + target_dir)
+        print("\nMD5 Checksum %s ..." % filepath)
+        if not md5file(filepath) == md5sum:
+            raise RuntimeError("MD5 checksum failed.")
+    else:
+        print("File exists, skip downloading. (%s)" % filepath)
+    return filepath
+
+
+def unpack(filepath, target_dir, rm_tar=False):
+    """Unpack the file to the target_dir.
+
+    :param filepath: Path of the tar archive to extract.
+    :param target_dir: Directory to extract the archive into.
+    :param rm_tar: If true, delete the archive after extraction.
+    """
+    print("Unpacking %s ..." % filepath)
+    # The with-block closes the archive even when extraction raises.
+    with tarfile.open(filepath) as tar:
+        tar.extractall(target_dir)
+    if rm_tar:
+        os.remove(filepath)