提交 dee672a7 编写于 作者: H Hui Zhang

flake8

上级 64f177cc
[flake8]
########## OPTIONS ##########
# Set the maximum length that any line (with some exceptions) may be.
max-line-length = 120
################### FILE PATTERNS ##########################
# Provide a comma-separated list of glob patterns to exclude from checks.
exclude =
# git folder
.git,
# python cache
__pycache__,
# Provide a comma-separate list of glob patterns to include for checks.
filename =
*.py
########## RULES ##########
# ERROR CODES
#
# E/W - PEP8 errors/warnings (pycodestyle)
# F - linting errors (pyflakes)
# C - McCabe complexity error (mccabe)
#
# W503 - line break before binary operator
# Specify a list of codes to ignore.
ignore =
W503
E252,E262,E127,E265,E126,E266,E241,E261,E128,E125
W291,W293,W605
E203,E305,E402,E501,E721,E741,F403,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
EXE001,
# these ignores are from flake8-bugbear; please fix!
B007,B008,
# these ignores are from flake8-comprehensions; please fix!
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415
# Specify the list of error codes you wish Flake8 to report.
select =
E,
W,
F,
C
\ No newline at end of file
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Union from typing import Union
from typing import Optional
from typing import List from typing import List
from typing import Tuple from typing import Tuple
from typing import Any from typing import Any
...@@ -21,7 +20,6 @@ from typing import Any ...@@ -21,7 +20,6 @@ from typing import Any
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
from paddle.nn import initializer as I
#TODO(Hui Zhang): remove fluid import #TODO(Hui Zhang): remove fluid import
from paddle.fluid import core from paddle.fluid import core
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -242,7 +240,7 @@ def is_broadcastable(shp1, shp2): ...@@ -242,7 +240,7 @@ def is_broadcastable(shp1, shp2):
def masked_fill(xs: paddle.Tensor, def masked_fill(xs: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
value: Union[float, int]): value: Union[float, int]):
assert is_broadcastable(xs.shape, mask.shape) == True assert is_broadcastable(xs.shape, mask.shape) is True
bshape = paddle.broadcast_shape(xs.shape, mask.shape) bshape = paddle.broadcast_shape(xs.shape, mask.shape)
mask = mask.broadcast_to(bshape) mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value trues = paddle.ones_like(xs) * value
...@@ -259,7 +257,7 @@ if not hasattr(paddle.Tensor, 'masked_fill'): ...@@ -259,7 +257,7 @@ if not hasattr(paddle.Tensor, 'masked_fill'):
def masked_fill_(xs: paddle.Tensor, def masked_fill_(xs: paddle.Tensor,
mask: paddle.Tensor, mask: paddle.Tensor,
value: Union[float, int]): value: Union[float, int]):
assert is_broadcastable(xs.shape, mask.shape) == True assert is_broadcastable(xs.shape, mask.shape) is True
bshape = paddle.broadcast_shape(xs.shape, mask.shape) bshape = paddle.broadcast_shape(xs.shape, mask.shape)
mask = mask.broadcast_to(bshape) mask = mask.broadcast_to(bshape)
trues = paddle.ones_like(xs) * value trues = paddle.ones_like(xs) * value
......
...@@ -104,14 +104,14 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -104,14 +104,14 @@ def ctc_beam_search_decoder(probs_seq,
global ext_nproc_scorer global ext_nproc_scorer
ext_scoring_func = ext_nproc_scorer ext_scoring_func = ext_nproc_scorer
## initialize # initialize
# prefix_set_prev: the set containing selected prefixes # prefix_set_prev: the set containing selected prefixes
# probs_b_prev: prefixes' probability ending with blank in previous step # probs_b_prev: prefixes' probability ending with blank in previous step
# probs_nb_prev: prefixes' probability ending with non-blank in previous step # probs_nb_prev: prefixes' probability ending with non-blank in previous step
prefix_set_prev = {'\t': 1.0} prefix_set_prev = {'\t': 1.0}
probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0} probs_b_prev, probs_nb_prev = {'\t': 1.0}, {'\t': 0.0}
## extend prefix in loop # extend prefix in loop
for time_step in range(len(probs_seq)): for time_step in range(len(probs_seq)):
# prefix_set_next: the set containing candidate prefixes # prefix_set_next: the set containing candidate prefixes
# probs_b_cur: prefixes' probability ending with blank in current step # probs_b_cur: prefixes' probability ending with blank in current step
...@@ -120,7 +120,7 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -120,7 +120,7 @@ def ctc_beam_search_decoder(probs_seq,
prob_idx = list(enumerate(probs_seq[time_step])) prob_idx = list(enumerate(probs_seq[time_step]))
cutoff_len = len(prob_idx) cutoff_len = len(prob_idx)
#If pruning is enabled # If pruning is enabled
if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len: if cutoff_prob < 1.0 or cutoff_top_n < cutoff_len:
prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True) prob_idx = sorted(prob_idx, key=lambda asd: asd[1], reverse=True)
cutoff_len, cum_prob = 0, 0.0 cutoff_len, cum_prob = 0, 0.0
...@@ -172,7 +172,7 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -172,7 +172,7 @@ def ctc_beam_search_decoder(probs_seq,
# update probs # update probs
probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur probs_b_prev, probs_nb_prev = probs_b_cur, probs_nb_cur
## store top beam_size prefixes # store top beam_size prefixes
prefix_set_prev = sorted( prefix_set_prev = sorted(
prefix_set_next.items(), key=lambda asd: asd[1], reverse=True) prefix_set_next.items(), key=lambda asd: asd[1], reverse=True)
if beam_size < len(prefix_set_prev): if beam_size < len(prefix_set_prev):
...@@ -191,7 +191,7 @@ def ctc_beam_search_decoder(probs_seq, ...@@ -191,7 +191,7 @@ def ctc_beam_search_decoder(probs_seq,
else: else:
beam_result.append((float('-inf'), '')) beam_result.append((float('-inf'), ''))
## output top beam_size decoding results # output top beam_size decoding results
beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True) beam_result = sorted(beam_result, key=lambda asd: asd[0], reverse=True)
return beam_result return beam_result
......
...@@ -71,7 +71,7 @@ class Scorer(object): ...@@ -71,7 +71,7 @@ class Scorer(object):
""" """
lm = self._language_model_score(sentence) lm = self._language_model_score(sentence)
word_cnt = self._word_count(sentence) word_cnt = self._word_count(sentence)
if log == False: if log is False:
score = np.power(lm, self._alpha) * np.power(word_cnt, self._beta) score = np.power(lm, self._alpha) * np.power(word_cnt, self._beta)
else: else:
score = self._alpha * np.log(lm) + self._beta * np.log(word_cnt) score = self._alpha * np.log(lm) + self._beta * np.log(word_cnt)
......
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
from setuptools import setup, Extension, distutils from setuptools import setup, Extension, distutils
import glob import glob
import platform import platform
import os, sys import os
import sys
import multiprocessing.pool import multiprocessing.pool
import argparse import argparse
......
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
# limitations under the License. # limitations under the License.
"""Client-end for the ASR demo.""" """Client-end for the ASR demo."""
import keyboard import keyboard
import struct
import socket
import sys import sys
import argparse import argparse
import pyaudio import pyaudio
...@@ -49,7 +47,7 @@ def on_press_release(x): ...@@ -49,7 +47,7 @@ def on_press_release(x):
sys.stdout.flush() sys.stdout.flush()
is_recording = True is_recording = True
if x.event_type == 'up' and x.name == release.name: if x.event_type == 'up' and x.name == release.name:
if is_recording == True: if is_recording:
is_recording = False is_recording = False
......
...@@ -12,9 +12,6 @@ ...@@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Server-end for the ASR demo.""" """Server-end for the ASR demo."""
import os
import time
import argparse
import functools import functools
import paddle import paddle
import numpy as np import numpy as np
...@@ -26,7 +23,6 @@ from deepspeech.utils.socket_server import AsrRequestHandler ...@@ -26,7 +23,6 @@ from deepspeech.utils.socket_server import AsrRequestHandler
from deepspeech.training.cli import default_argument_parser from deepspeech.training.cli import default_argument_parser
from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.utility import add_arguments, print_arguments from deepspeech.utils.utility import add_arguments, print_arguments
from deepspeech.models.deepspeech2 import DeepSpeech2Model from deepspeech.models.deepspeech2 import DeepSpeech2Model
...@@ -159,15 +155,13 @@ if __name__ == "__main__": ...@@ -159,15 +155,13 @@ if __name__ == "__main__":
"--params_file", "--params_file",
type=str, type=str,
default="", default="",
help= help="Parameter filename, Specify this when your model is a combined model."
"Parameter filename, Specify this when your model is a combined model."
) )
add_arg( add_arg(
"--model_dir", "--model_dir",
type=str, type=str,
default=None, default=None,
help= help="Model dir, If you load a non-combined model, specify the directory of the model."
"Model dir, If you load a non-combined model, specify the directory of the model."
) )
add_arg("--use_gpu", add_arg("--use_gpu",
type=bool, type=bool,
......
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Socket client to send wav to ASR server.""" """Socket client to send wav to ASR server."""
import struct
import socket
import argparse import argparse
import wave import wave
......
...@@ -12,9 +12,6 @@ ...@@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Server-end for the ASR demo.""" """Server-end for the ASR demo."""
import os
import time
import argparse
import functools import functools
import paddle import paddle
import numpy as np import numpy as np
...@@ -26,7 +23,6 @@ from deepspeech.utils.socket_server import AsrRequestHandler ...@@ -26,7 +23,6 @@ from deepspeech.utils.socket_server import AsrRequestHandler
from deepspeech.training.cli import default_argument_parser from deepspeech.training.cli import default_argument_parser
from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.utility import add_arguments, print_arguments from deepspeech.utils.utility import add_arguments, print_arguments
from deepspeech.models.deepspeech2 import DeepSpeech2Model from deepspeech.models.deepspeech2 import DeepSpeech2Model
......
...@@ -12,17 +12,8 @@ ...@@ -12,17 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Export for DeepSpeech2 model.""" """Export for DeepSpeech2 model."""
import io
import logging
import argparse
import functools
from paddle import distributed as dist
from deepspeech.training.cli import default_argument_parser from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
from deepspeech.utils.error_rate import char_errors, word_errors
from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester as Tester from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester as Tester
......
...@@ -12,17 +12,8 @@ ...@@ -12,17 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inferer for DeepSpeech2 model.""" """Inferer for DeepSpeech2 model."""
import io
import logging
import argparse
import functools
from paddle import distributed as dist
from deepspeech.training.cli import default_argument_parser from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
from deepspeech.utils.error_rate import char_errors, word_errors
# TODO(hui zhang): dynamic load # TODO(hui zhang): dynamic load
from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.exps.deepspeech2.config import get_cfg_defaults
......
...@@ -12,17 +12,8 @@ ...@@ -12,17 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Evaluation for DeepSpeech2 model.""" """Evaluation for DeepSpeech2 model."""
import io
import logging
import argparse
import functools
from paddle import distributed as dist
from deepspeech.training.cli import default_argument_parser from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
from deepspeech.utils.error_rate import char_errors, word_errors
from deepspeech.exps.deepspeech2.config import get_cfg_defaults from deepspeech.exps.deepspeech2.config import get_cfg_defaults
from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester as Tester from deepspeech.exps.deepspeech2.model import DeepSpeech2Tester as Tester
......
...@@ -12,12 +12,6 @@ ...@@ -12,12 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Trainer for DeepSpeech2 model.""" """Trainer for DeepSpeech2 model."""
import io
import logging
import argparse
import functools
from paddle import distributed as dist from paddle import distributed as dist
from deepspeech.utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
......
...@@ -14,12 +14,8 @@ ...@@ -14,12 +14,8 @@
"""Beam search parameters tuning for DeepSpeech2 model.""" """Beam search parameters tuning for DeepSpeech2 model."""
import sys import sys
import os
import numpy as np import numpy as np
import argparse
import functools import functools
import gzip
import logging
from paddle.io import DataLoader from paddle.io import DataLoader
...@@ -122,7 +118,7 @@ def tune(config, args): ...@@ -122,7 +118,7 @@ def tune(config, args):
if index % 2 == 0: if index % 2 == 0:
sys.stdout.write('.') sys.stdout.write('.')
sys.stdout.flush() sys.stdout.flush()
print(f"tuneing: one grid done!") print("tuneing: one grid done!")
# output on-line tuning result at the end of current batch # output on-line tuning result at the end of current batch
err_ave_min = min(err_ave) err_ave_min = min(err_ave)
......
...@@ -14,13 +14,10 @@ ...@@ -14,13 +14,10 @@
"""Contains DeepSpeech2 model.""" """Contains DeepSpeech2 model."""
import io import io
import sys
import os
import time import time
import logging import logging
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from functools import partial
from pathlib import Path from pathlib import Path
import paddle import paddle
...@@ -39,7 +36,6 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler ...@@ -39,7 +36,6 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.io.sampler import SortagradBatchSampler from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.dataset import ManifestDataset from deepspeech.io.dataset import ManifestDataset
from deepspeech.modules.loss import CTCLoss
from deepspeech.models.deepspeech2 import DeepSpeech2Model from deepspeech.models.deepspeech2 import DeepSpeech2Model
from deepspeech.models.deepspeech2 import DeepSpeech2InferModel from deepspeech.models.deepspeech2 import DeepSpeech2InferModel
......
...@@ -13,16 +13,8 @@ ...@@ -13,16 +13,8 @@
# limitations under the License. # limitations under the License.
"""Export for U2 model.""" """Export for U2 model."""
import io
import logging
import argparse
import functools
from paddle import distributed as dist
from deepspeech.training.cli import default_argument_parser from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
from deepspeech.utils.error_rate import char_errors, word_errors
from deepspeech.exps.u2.config import get_cfg_defaults from deepspeech.exps.u2.config import get_cfg_defaults
from deepspeech.exps.u2.model import U2Tester as Tester from deepspeech.exps.u2.model import U2Tester as Tester
......
...@@ -13,16 +13,8 @@ ...@@ -13,16 +13,8 @@
# limitations under the License. # limitations under the License.
"""Evaluation for U2 model.""" """Evaluation for U2 model."""
import io
import logging
import argparse
import functools
from paddle import distributed as dist
from deepspeech.training.cli import default_argument_parser from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
from deepspeech.utils.error_rate import char_errors, word_errors
# TODO(hui zhang): dynamic load # TODO(hui zhang): dynamic load
from deepspeech.exps.u2.config import get_cfg_defaults from deepspeech.exps.u2.config import get_cfg_defaults
......
...@@ -13,11 +13,6 @@ ...@@ -13,11 +13,6 @@
# limitations under the License. # limitations under the License.
"""Trainer for U2 model.""" """Trainer for U2 model."""
import io
import logging
import argparse
import functools
from paddle import distributed as dist from paddle import distributed as dist
from deepspeech.utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
......
...@@ -13,14 +13,10 @@ ...@@ -13,14 +13,10 @@
# limitations under the License. # limitations under the License.
"""Contains U2 model.""" """Contains U2 model."""
import io
import sys
import os
import time import time
import logging import logging
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from functools import partial
from pathlib import Path from pathlib import Path
import paddle import paddle
...@@ -40,8 +36,6 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler ...@@ -40,8 +36,6 @@ from deepspeech.io.sampler import SortagradDistributedBatchSampler
from deepspeech.io.sampler import SortagradBatchSampler from deepspeech.io.sampler import SortagradBatchSampler
from deepspeech.io.dataset import ManifestDataset from deepspeech.io.dataset import ManifestDataset
from deepspeech.modules.loss import CTCLoss
from deepspeech.models.u2 import U2Model from deepspeech.models.u2 import U2Model
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -22,7 +22,6 @@ import resampy ...@@ -22,7 +22,6 @@ import resampy
from scipy import signal from scipy import signal
import random import random
import copy import copy
import io
class AudioSegment(object): class AudioSegment(object):
......
...@@ -14,8 +14,6 @@ ...@@ -14,8 +14,6 @@
"""Contains the audio featurizer class.""" """Contains the audio featurizer class."""
import numpy as np import numpy as np
from deepspeech.frontend.utility import read_manifest
from deepspeech.frontend.audio import AudioSegment
from python_speech_features import mfcc from python_speech_features import mfcc
from python_speech_features import logfbank from python_speech_features import logfbank
from python_speech_features import delta from python_speech_features import delta
...@@ -320,7 +318,7 @@ class AudioFeaturizer(object): ...@@ -320,7 +318,7 @@ class AudioFeaturizer(object):
if stride_ms > window_ms: if stride_ms > window_ms:
raise ValueError("Stride size must not be greater than " raise ValueError("Stride size must not be greater than "
"window size.") "window size.")
#(T, D) # (T, D)
fbank_feat = logfbank( fbank_feat = logfbank(
signal=samples, signal=samples,
samplerate=sample_rate, samplerate=sample_rate,
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
"""Contains the text featurizer class.""" """Contains the text featurizer class."""
import os
import sentencepiece as spm import sentencepiece as spm
from deepspeech.frontend.utility import UNK from deepspeech.frontend.utility import UNK
......
...@@ -16,15 +16,7 @@ import numpy as np ...@@ -16,15 +16,7 @@ import numpy as np
import math import math
import json import json
import codecs import codecs
import os
import tarfile
import time
import logging import logging
from typing import List
from threading import Thread
from multiprocessing import Process, Manager, Value
from paddle.dataset.common import md5file
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import functools
import numpy as np import numpy as np
from paddle.io import DataLoader from paddle.io import DataLoader
...@@ -131,7 +130,7 @@ def create_dataloader(manifest_path, ...@@ -131,7 +130,7 @@ def create_dataloader(manifest_path,
if keep_transcription_text: if keep_transcription_text:
padded_text[:len(text)] = [ord(t) for t in text] # string padded_text[:len(text)] = [ord(t) for t in text] # string
else: else:
padded_text[:len(text)] = text #ids padded_text[:len(text)] = text # ids
texts.append(padded_text) texts.append(padded_text)
text_lens.append(len(text)) text_lens.append(len(text))
...@@ -141,7 +140,7 @@ def create_dataloader(manifest_path, ...@@ -141,7 +140,7 @@ def create_dataloader(manifest_path,
text_lens = np.array(text_lens).astype('int64') text_lens = np.array(text_lens).astype('int64')
return padded_audios, audio_lens, texts, text_lens return padded_audios, audio_lens, texts, text_lens
#collate_fn=functools.partial(padding_batch, keep_transcription_text=keep_transcription_text), # collate_fn=functools.partial(padding_batch, keep_transcription_text=keep_transcription_text),
collate_fn = SpeechCollator(keep_transcription_text=keep_transcription_text) collate_fn = SpeechCollator(keep_transcription_text=keep_transcription_text)
loader = DataLoader( loader = DataLoader(
dataset, dataset,
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import logging import logging
import numpy as np import numpy as np
from collections import namedtuple
from deepspeech.io.utility import pad_sequence from deepspeech.io.utility import pad_sequence
from deepspeech.frontend.utility import IGNORE_ID from deepspeech.frontend.utility import IGNORE_ID
......
...@@ -13,13 +13,10 @@ ...@@ -13,13 +13,10 @@
# limitations under the License. # limitations under the License.
import io import io
import math
import random import random
import tarfile import tarfile
import logging import logging
import numpy as np
from collections import namedtuple from collections import namedtuple
from functools import partial
from yacs.config import CfgNode from yacs.config import CfgNode
from paddle.io import Dataset from paddle.io import Dataset
......
...@@ -13,14 +13,9 @@ ...@@ -13,14 +13,9 @@
# limitations under the License. # limitations under the License.
import math import math
import random
import tarfile
import logging import logging
import numpy as np import numpy as np
from collections import namedtuple
from functools import partial
import paddle
from paddle.io import BatchSampler from paddle.io import BatchSampler
from paddle.io import DistributedBatchSampler from paddle.io import DistributedBatchSampler
from paddle import distributed as dist from paddle import distributed as dist
...@@ -59,7 +54,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False): ...@@ -59,7 +54,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size)) batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices) rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch] batch_indices = [item for batch in batch_indices for item in batch]
assert (clipped == False) assert clipped is False
if not clipped: if not clipped:
res_len = len(indices) - shift_len - len(batch_indices) res_len = len(indices) - shift_len - len(batch_indices)
# when res_len is 0, will return whole list, len(List[-0:]) = len(List[:]) # when res_len is 0, will return whole list, len(List[-0:]) = len(List[:])
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import logging import logging
import numpy as np import numpy as np
from collections import namedtuple
from typing import List from typing import List
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -12,20 +12,12 @@ ...@@ -12,20 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Deepspeech2 ASR Model""" """Deepspeech2 ASR Model"""
import math
import collections
import numpy as np
import logging import logging
from typing import Optional from typing import Optional
from yacs.config import CfgNode from yacs.config import CfgNode
import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.mask import sequence_mask
from deepspeech.modules.activation import brelu
from deepspeech.modules.conv import ConvStack from deepspeech.modules.conv import ConvStack
from deepspeech.modules.rnn import RNNStack from deepspeech.modules.rnn import RNNStack
from deepspeech.modules.ctc import CTCDecoder from deepspeech.modules.ctc import CTCDecoder
......
...@@ -15,10 +15,8 @@ ...@@ -15,10 +15,8 @@
Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition
(https://arxiv.org/pdf/2012.05481.pdf) (https://arxiv.org/pdf/2012.05481.pdf)
""" """
import math
import collections
from collections import defaultdict from collections import defaultdict
import numpy as np
import logging import logging
from yacs.config import CfgNode from yacs.config import CfgNode
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
...@@ -26,8 +24,6 @@ from typing import List, Optional, Tuple ...@@ -26,8 +24,6 @@ from typing import List, Optional, Tuple
import paddle import paddle
from paddle import jit from paddle import jit
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.mask import make_pad_mask from deepspeech.modules.mask import make_pad_mask
from deepspeech.modules.mask import mask_finished_preds from deepspeech.modules.mask import mask_finished_preds
...@@ -54,7 +50,7 @@ from deepspeech.utils.ctc_utils import remove_duplicates_and_blank ...@@ -54,7 +50,7 @@ from deepspeech.utils.ctc_utils import remove_duplicates_and_blank
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__all__ = ['U2TransformerModel', "U2ConformerModel"] __all__ = ["U2Model", "U2InferModel"]
class U2BaseModel(nn.Module): class U2BaseModel(nn.Module):
......
...@@ -12,16 +12,11 @@ ...@@ -12,16 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Union
import logging import logging
import numpy as np
import math
from collections import OrderedDict from collections import OrderedDict
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -18,7 +18,6 @@ from typing import Optional, Tuple ...@@ -18,7 +18,6 @@ from typing import Optional, Tuple
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I from paddle.nn import initializer as I
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -16,8 +16,6 @@ import logging ...@@ -16,8 +16,6 @@ import logging
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -19,8 +19,6 @@ import logging ...@@ -19,8 +19,6 @@ import logging
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -14,10 +14,8 @@ ...@@ -14,10 +14,8 @@
import logging import logging
import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.mask import sequence_mask from deepspeech.modules.mask import sequence_mask
from deepspeech.modules.activation import brelu from deepspeech.modules.activation import brelu
......
...@@ -18,7 +18,6 @@ from typeguard import check_argument_types ...@@ -18,7 +18,6 @@ from typeguard import check_argument_types
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.loss import CTCLoss from deepspeech.modules.loss import CTCLoss
from deepspeech.utils import ctc_utils from deepspeech.utils import ctc_utils
...@@ -151,7 +150,7 @@ class CTCDecoder(nn.Layer): ...@@ -151,7 +150,7 @@ class CTCDecoder(nn.Layer):
:type vocab_list: list :type vocab_list: list
""" """
# init once # init once
if self._ext_scorer != None: if self._ext_scorer is not None:
return return
if language_model_path != '': if language_model_path != '':
...@@ -199,7 +198,7 @@ class CTCDecoder(nn.Layer): ...@@ -199,7 +198,7 @@ class CTCDecoder(nn.Layer):
:return: List of transcription texts. :return: List of transcription texts.
:rtype: List of str :rtype: List of str
""" """
if self._ext_scorer != None: if self._ext_scorer is not None:
self._ext_scorer.reset_params(beam_alpha, beam_beta) self._ext_scorer.reset_params(beam_alpha, beam_beta)
# beam search decode # beam search decode
......
...@@ -18,8 +18,6 @@ import logging ...@@ -18,8 +18,6 @@ import logging
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.attention import MultiHeadedAttention from deepspeech.modules.attention import MultiHeadedAttention
from deepspeech.modules.decoder_layer import DecoderLayer from deepspeech.modules.decoder_layer import DecoderLayer
...@@ -125,7 +123,7 @@ class TransformerDecoder(nn.Module): ...@@ -125,7 +123,7 @@ class TransformerDecoder(nn.Module):
m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0) m = subsequent_mask(tgt_mask.size(-1)).unsqueeze(0)
# tgt_mask: (B, L, L) # tgt_mask: (B, L, L)
# TODO(Hui Zhang): not support & for tensor # TODO(Hui Zhang): not support & for tensor
#tgt_mask = tgt_mask & m # tgt_mask = tgt_mask & m
tgt_mask = tgt_mask.logical_and(m) tgt_mask = tgt_mask.logical_and(m)
x, _ = self.embed(tgt) x, _ = self.embed(tgt)
...@@ -137,8 +135,8 @@ class TransformerDecoder(nn.Module): ...@@ -137,8 +135,8 @@ class TransformerDecoder(nn.Module):
if self.use_output_layer: if self.use_output_layer:
x = self.output_layer(x) x = self.output_layer(x)
#TODO(Hui Zhang): reduce_sum not support bool type # TODO(Hui Zhang): reduce_sum not support bool type
#olens = tgt_mask.sum(1) # olens = tgt_mask.sum(1)
olens = tgt_mask.astype(paddle.int).sum(1) olens = tgt_mask.astype(paddle.int).sum(1)
return x, olens return x, olens
......
...@@ -17,8 +17,6 @@ import logging ...@@ -17,8 +17,6 @@ import logging
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -15,13 +15,10 @@ ...@@ -15,13 +15,10 @@
import math import math
import logging import logging
import numpy as np
from typing import Tuple from typing import Tuple
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -18,8 +18,6 @@ from typeguard import check_argument_types ...@@ -18,8 +18,6 @@ from typeguard import check_argument_types
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.attention import MultiHeadedAttention from deepspeech.modules.attention import MultiHeadedAttention
from deepspeech.modules.attention import RelPositionMultiHeadedAttention from deepspeech.modules.attention import RelPositionMultiHeadedAttention
......
...@@ -17,8 +17,6 @@ import logging ...@@ -17,8 +17,6 @@ import logging
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -17,7 +17,6 @@ import logging ...@@ -17,7 +17,6 @@ import logging
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -15,9 +15,6 @@ ...@@ -15,9 +15,6 @@
import logging import logging
import paddle import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -16,8 +16,6 @@ import logging ...@@ -16,8 +16,6 @@ import logging
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -18,8 +18,6 @@ import logging ...@@ -18,8 +18,6 @@ import logging
import paddle import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I
from deepspeech.modules.embedding import PositionalEncoding from deepspeech.modules.embedding import PositionalEncoding
......
...@@ -11,5 +11,3 @@ ...@@ -11,5 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from deepspeech.training.trainer import *
...@@ -58,12 +58,15 @@ def default_argument_parser(): ...@@ -58,12 +58,15 @@ def default_argument_parser():
parser.add_argument("--export_path", type=str, help="path of the jit model to save") parser.add_argument("--export_path", type=str, help="path of the jit model to save")
# running # running
parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"], help="device type to use, cpu and gpu are supported.") parser.add_argument("--device", type=str, default='gpu', choices=["cpu", "gpu"],
help="device type to use, cpu and gpu are supported.")
parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.") parser.add_argument("--nprocs", type=int, default=1, help="number of parallel processes to use.")
# overwrite extra config and default config # overwrite extra config and default config
#parser.add_argument("--opts", nargs=argparse.REMAINDER, help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") # parser.add_argument("--opts", nargs=argparse.REMAINDER,
parser.add_argument("--opts", type=str, default=[], nargs='+', help="options to overwrite --config file and the default config, passing in KEY VALUE pairs") # help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
parser.add_argument("--opts", type=str, default=[], nargs='+',
help="options to overwrite --config file and the default config, passing in KEY VALUE pairs")
# yapd: enable # yapd: enable
return parser return parser
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
import logging import logging
import paddle
from paddle.optimizer.lr import LRScheduler from paddle.optimizer.lr import LRScheduler
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -16,12 +16,9 @@ import time ...@@ -16,12 +16,9 @@ import time
import logging import logging
import logging.handlers import logging.handlers
from pathlib import Path from pathlib import Path
import numpy as np
from collections import defaultdict
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
from paddle.distributed.utils import get_gpus
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from deepspeech.utils import checkpoint from deepspeech.utils import checkpoint
......
...@@ -13,15 +13,12 @@ ...@@ -13,15 +13,12 @@
# limitations under the License. # limitations under the License.
import os import os
import time
import logging import logging
import numpy as np
import re import re
import json import json
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
from paddle.nn import Layer
from paddle.optimizer import Optimizer from paddle.optimizer import Optimizer
from deepspeech.utils import mp_tools from deepspeech.utils import mp_tools
......
...@@ -81,7 +81,7 @@ def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '): ...@@ -81,7 +81,7 @@ def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
:return: Levenshtein distance and word number of reference sentence. :return: Levenshtein distance and word number of reference sentence.
:rtype: list :rtype: list
""" """
if ignore_case == True: if ignore_case:
reference = reference.lower() reference = reference.lower()
hypothesis = hypothesis.lower() hypothesis = hypothesis.lower()
...@@ -107,12 +107,12 @@ def char_errors(reference, hypothesis, ignore_case=False, remove_space=False): ...@@ -107,12 +107,12 @@ def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
:return: Levenshtein distance and length of reference sentence. :return: Levenshtein distance and length of reference sentence.
:rtype: list :rtype: list
""" """
if ignore_case == True: if ignore_case:
reference = reference.lower() reference = reference.lower()
hypothesis = hypothesis.lower() hypothesis = hypothesis.lower()
join_char = ' ' join_char = ' '
if remove_space == True: if remove_space:
join_char = '' join_char = ''
reference = join_char.join(list(filter(None, reference.split(' ')))) reference = join_char.join(list(filter(None, reference.split(' '))))
......
...@@ -51,7 +51,7 @@ def recursively_remove_weight_norm(layer: nn.Layer): ...@@ -51,7 +51,7 @@ def recursively_remove_weight_norm(layer: nn.Layer):
for layer in layer.sublayers(): for layer in layer.sublayers():
try: try:
nn.utils.remove_weight_norm(layer) nn.utils.remove_weight_norm(layer)
except: except ValueError as e:
# ther is not weight norm hoom in this layer # ther is not weight norm hoom in this layer
pass pass
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle
from paddle import distributed as dist from paddle import distributed as dist
from functools import wraps from functools import wraps
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Unility functions for Transformer.""" """Unility functions for Transformer."""
import math
import logging import logging
from typing import Tuple, List from typing import Tuple, List
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
"""Contains common utility functions.""" """Contains common utility functions."""
import math import math
import numpy as np
import distutils.util import distutils.util
from typing import List from typing import List
......
...@@ -55,7 +55,8 @@ def create_manifest(data_dir, manifest_path_prefix): ...@@ -55,7 +55,8 @@ def create_manifest(data_dir, manifest_path_prefix):
transcript_dict = {} transcript_dict = {}
for line in codecs.open(transcript_path, 'r', 'utf-8'): for line in codecs.open(transcript_path, 'r', 'utf-8'):
line = line.strip() line = line.strip()
if line == '': continue if line == '':
continue
audio_id, text = line.split(' ', 1) audio_id, text = line.split(' ', 1)
# remove withespace # remove withespace
text = ''.join(text.split()) text = ''.join(text.split())
...@@ -82,7 +83,7 @@ def create_manifest(data_dir, manifest_path_prefix): ...@@ -82,7 +83,7 @@ def create_manifest(data_dir, manifest_path_prefix):
os.path.splitext(os.path.basename(audio_path))[0], os.path.splitext(os.path.basename(audio_path))[0],
'feat': 'feat':
audio_path, audio_path,
'feat_shape': (duration, ), #second 'feat_shape': (duration, ), # second
'text': 'text':
text text
}, },
......
...@@ -19,7 +19,6 @@ meta data (i.e. audio filepath, transcript and audio duration) ...@@ -19,7 +19,6 @@ meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set. of each audio file in the data set.
""" """
import distutils.util
import os import os
import wget import wget
import zipfile import zipfile
...@@ -29,7 +28,7 @@ import json ...@@ -29,7 +28,7 @@ import json
import io import io
from paddle.v2.dataset.common import md5file from paddle.v2.dataset.common import md5file
#DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') # DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
DATA_HOME = os.path.expanduser('.') DATA_HOME = os.path.expanduser('.')
URL = "https://d4s.myairbridge.com/packagev2/AG0Y3DNBE5IWRRTV/?dlid=W19XG7T0NNHB027139H0EQ" URL = "https://d4s.myairbridge.com/packagev2/AG0Y3DNBE5IWRRTV/?dlid=W19XG7T0NNHB027139H0EQ"
...@@ -51,9 +50,10 @@ args = parser.parse_args() ...@@ -51,9 +50,10 @@ args = parser.parse_args()
def download(url, md5sum, target_dir, filename=None): def download(url, md5sum, target_dir, filename=None):
"""Download file from url to target_dir, and check md5sum.""" """Download file from url to target_dir, and check md5sum."""
if filename == None: if filename is None:
filename = url.split("/")[-1] filename = url.split("/")[-1]
if not os.path.exists(target_dir): os.makedirs(target_dir) if not os.path.exists(target_dir):
os.makedirs(target_dir)
filepath = os.path.join(target_dir, filename) filepath = os.path.join(target_dir, filename)
if not (os.path.exists(filepath) and md5file(filepath) == md5sum): if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
print("Downloading %s ..." % url) print("Downloading %s ..." % url)
...@@ -100,7 +100,7 @@ def create_manifest(data_dir, manifest_path): ...@@ -100,7 +100,7 @@ def create_manifest(data_dir, manifest_path):
'utt': os.path.splitext(os.path.basename(filepath))[ 'utt': os.path.splitext(os.path.basename(filepath))[
0], 0],
'feat': filepath, 'feat': filepath,
'feat_shape': (duration, ), #second 'feat_shape': (duration, ), # second
'type': 'background' 'type': 'background'
})) }))
with io.open(manifest_path, mode='w', encoding='utf8') as out_file: with io.open(manifest_path, mode='w', encoding='utf8') as out_file:
......
...@@ -21,7 +21,6 @@ of each audio file in the data set. ...@@ -21,7 +21,6 @@ of each audio file in the data set.
import distutils.util import distutils.util
import os import os
import sys
import argparse import argparse
import soundfile import soundfile
import json import json
......
...@@ -19,9 +19,7 @@ meta data (i.e. audio filepath, transcript and audio duration) ...@@ -19,9 +19,7 @@ meta data (i.e. audio filepath, transcript and audio duration)
of each audio file in the data set. of each audio file in the data set.
""" """
import distutils.util
import os import os
import sys
import argparse import argparse
import soundfile import soundfile
import json import json
......
...@@ -27,7 +27,7 @@ import codecs ...@@ -27,7 +27,7 @@ import codecs
import soundfile import soundfile
import json import json
import argparse import argparse
from utils.utility import download, unpack, unzip from utils.utility import download, unzip
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech') DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset/speech')
......
...@@ -26,11 +26,11 @@ class TestDeepSpeech2Model(unittest.TestCase): ...@@ -26,11 +26,11 @@ class TestDeepSpeech2Model(unittest.TestCase):
self.feat_dim = 161 self.feat_dim = 161
max_len = 64 max_len = 64
#(B, T, D) # (B, T, D)
audio = np.random.randn(self.batch_size, max_len, self.feat_dim) audio = np.random.randn(self.batch_size, max_len, self.feat_dim)
audio_len = np.random.randint(max_len, size=self.batch_size) audio_len = np.random.randint(max_len, size=self.batch_size)
audio_len[-1] = max_len audio_len[-1] = max_len
#(B, U) # (B, U)
text = np.array([[1, 2], [1, 2]]) text = np.array([[1, 2], [1, 2]])
text_len = np.array([2] * self.batch_size) text_len = np.array([2] * self.batch_size)
......
...@@ -17,10 +17,8 @@ Each item in vocabulary file is a character. ...@@ -17,10 +17,8 @@ Each item in vocabulary file is a character.
import argparse import argparse
import functools import functools
import json
from collections import Counter from collections import Counter
import os import os
import copy
import tempfile import tempfile
from deepspeech.frontend.utility import read_manifest from deepspeech.frontend.utility import read_manifest
...@@ -48,10 +46,8 @@ add_arg('manifest_paths', str, ...@@ -48,10 +46,8 @@ add_arg('manifest_paths', str,
required=True) required=True)
# bpe # bpe
add_arg('vocab_size', int, 0, "Vocab size for spm.") add_arg('vocab_size', int, 0, "Vocab size for spm.")
add_arg('spm_mode', str, 'unigram', add_arg('spm_mode', str, 'unigram', "spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm")
"spm model type, e.g. unigram, spm, char, word. only need when `unit_type` is spm") add_arg('spm_model_prefix', str, "spm_model_%(spm_mode)_%(count_threshold)", "spm model prefix, only need when `unit_type` is spm")
add_arg('spm_model_prefix', str, "spm_model_%(spm_mode)_%(count_threshold)",
"spm model prefix, only need when `unit_type` is spm")
# yapf: disable # yapf: disable
args = parser.parse_args() args = parser.parse_args()
...@@ -104,7 +100,8 @@ def main(): ...@@ -104,7 +100,8 @@ def main():
count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True) count_sorted = sorted(counter.items(), key=lambda x: x[1], reverse=True)
tokens = [] tokens = []
for token, count in count_sorted: for token, count in count_sorted:
if count < args.count_threshold: break if count < args.count_threshold:
break
tokens.append(token) tokens.append(token)
tokens = sorted(tokens) tokens = sorted(tokens)
......
...@@ -15,15 +15,8 @@ ...@@ -15,15 +15,8 @@
import argparse import argparse
import functools import functools
import json import json
from collections import Counter
import os
import copy
import tempfile
from deepspeech.frontend.utility import read_manifest from deepspeech.frontend.utility import read_manifest
from deepspeech.frontend.utility import UNK
from deepspeech.frontend.utility import BLANK
from deepspeech.frontend.utility import SOS
from deepspeech.frontend.utility import load_cmvn from deepspeech.frontend.utility import load_cmvn
from deepspeech.utils.utility import add_arguments from deepspeech.utils.utility import add_arguments
from deepspeech.utils.utility import print_arguments from deepspeech.utils.utility import print_arguments
...@@ -82,7 +75,7 @@ def main(): ...@@ -82,7 +75,7 @@ def main():
if args.feat_type == 'raw': if args.feat_type == 'raw':
feat_shape.append(feat_dim) feat_shape.append(feat_dim)
else: # kaldi else: # kaldi
raise NotImplemented('no support kaldi feat now!') raise NotImplementedError('no support kaldi feat now!')
fout.write(json.dumps(line_json) + '\n') fout.write(json.dumps(line_json) + '\n')
count += 1 count += 1
......
...@@ -30,7 +30,8 @@ def getfile_insensitive(path): ...@@ -30,7 +30,8 @@ def getfile_insensitive(path):
def download_multi(url, target_dir, extra_args): def download_multi(url, target_dir, extra_args):
"""Download multiple files from url to target_dir.""" """Download multiple files from url to target_dir."""
if not os.path.exists(target_dir): os.makedirs(target_dir) if not os.path.exists(target_dir):
os.makedirs(target_dir)
print("Downloading %s ..." % url) print("Downloading %s ..." % url)
ret_code = os.system("wget -c " + url + ' ' + extra_args + " -P " + ret_code = os.system("wget -c " + url + ' ' + extra_args + " -P " +
target_dir) target_dir)
...@@ -39,7 +40,8 @@ def download_multi(url, target_dir, extra_args): ...@@ -39,7 +40,8 @@ def download_multi(url, target_dir, extra_args):
def download(url, md5sum, target_dir): def download(url, md5sum, target_dir):
"""Download file from url to target_dir, and check md5sum.""" """Download file from url to target_dir, and check md5sum."""
if not os.path.exists(target_dir): os.makedirs(target_dir) if not os.path.exists(target_dir):
os.makedirs(target_dir)
filepath = os.path.join(target_dir, url.split("/")[-1]) filepath = os.path.join(target_dir, url.split("/")[-1])
if not (os.path.exists(filepath) and md5file(filepath) == md5sum): if not (os.path.exists(filepath) and md5file(filepath) == md5sum):
print("Downloading %s ..." % url) print("Downloading %s ..." % url)
...@@ -58,7 +60,7 @@ def unpack(filepath, target_dir, rm_tar=False): ...@@ -58,7 +60,7 @@ def unpack(filepath, target_dir, rm_tar=False):
tar = tarfile.open(filepath) tar = tarfile.open(filepath)
tar.extractall(target_dir) tar.extractall(target_dir)
tar.close() tar.close()
if rm_tar == True: if rm_tar:
os.remove(filepath) os.remove(filepath)
...@@ -68,5 +70,5 @@ def unzip(filepath, target_dir, rm_tar=False): ...@@ -68,5 +70,5 @@ def unzip(filepath, target_dir, rm_tar=False):
tar = zipfile.ZipFile(filepath, 'r') tar = zipfile.ZipFile(filepath, 'r')
tar.extractall(target_dir) tar.extractall(target_dir)
tar.close() tar.close()
if rm_tar == True: if rm_tar:
os.remove(filepath) os.remove(filepath)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册