提交 e3d73acd 编写于 作者: H Hui Zhang

fix io; add test

上级 4b5410ee
# Locales
export LC_ALL=en_US.UTF-8
export LANG=en_US.UTF-8
export LANGUAGE=en_US.UTF-8
# Aliases
alias nvs="nvidia-smi"
alias rsync="rsync --progress -raz"
alias his="history"
此差异已折叠。
...@@ -347,7 +347,7 @@ def make_batchset( ...@@ -347,7 +347,7 @@ def make_batchset(
Note that if any utts doesn't have "category", Note that if any utts doesn't have "category",
perform as same as batchfy_by_{count} perform as same as batchfy_by_{count}
:param Dict[str, Dict[str, Any]] data: dictionary loaded from data.json :param List[Dict[str, Any]] data: dictionary loaded from data.json
:param int batch_size: maximum number of sequences in a minibatch. :param int batch_size: maximum number of sequences in a minibatch.
:param int batch_bins: maximum number of bins (frames x dim) in a minibatch. :param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
:param int batch_frames_in: maximum number of input frames in a minibatch. :param int batch_frames_in: maximum number of input frames in a minibatch.
...@@ -374,7 +374,6 @@ def make_batchset( ...@@ -374,7 +374,6 @@ def make_batchset(
reserved for future research, -1 means all axis.) reserved for future research, -1 means all axis.)
:return: List[List[Tuple[str, dict]]] list of batches :return: List[List[Tuple[str, dict]]] list of batches
""" """
# check args # check args
if count not in BATCH_COUNT_CHOICES: if count not in BATCH_COUNT_CHOICES:
raise ValueError( raise ValueError(
...@@ -386,7 +385,6 @@ def make_batchset( ...@@ -386,7 +385,6 @@ def make_batchset(
ikey = "input" ikey = "input"
okey = "output" okey = "output"
batch_sort_axis = 0 # index of list batch_sort_axis = 0 # index of list
if count == "auto": if count == "auto":
if batch_size != 0: if batch_size != 0:
count = "seq" count = "seq"
...@@ -405,7 +403,8 @@ def make_batchset( ...@@ -405,7 +403,8 @@ def make_batchset(
"batch_sort_key=shuffle is only available if batch_count=seq") "batch_sort_key=shuffle is only available if batch_count=seq")
category2data = {} # Dict[str, dict] category2data = {} # Dict[str, dict]
for k, v in data.items(): for v in data:
k = v['utt']
category2data.setdefault(v.get("category"), {})[k] = v category2data.setdefault(v.get("category"), {})[k] = v
batches_list = [] # List[List[List[Tuple[str, dict]]]] batches_list = [] # List[List[List[Tuple[str, dict]]]]
...@@ -422,6 +421,7 @@ def make_batchset( ...@@ -422,6 +421,7 @@ def make_batchset(
key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]), key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]),
reverse=not shortest_first, ) reverse=not shortest_first, )
logger.info("# utts: " + str(len(sorted_data))) logger.info("# utts: " + str(len(sorted_data)))
if count == "seq": if count == "seq":
batches = batchfy_by_seq( batches = batchfy_by_seq(
sorted_data, sorted_data,
......
...@@ -16,7 +16,7 @@ from typing import Optional ...@@ -16,7 +16,7 @@ from typing import Optional
from paddle.io import Dataset from paddle.io import Dataset
from yacs.config import CfgNode from yacs.config import CfgNode
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log from deepspeech.utils.log import Log
__all__ = ["ManifestDataset", "TripletManifestDataset", "TransformDataset"] __all__ = ["ManifestDataset", "TripletManifestDataset", "TransformDataset"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册