提交 e3d73acd 编写于 作者: H Hui Zhang

fix io; add test

上级 4b5410ee
# Locales
export LC_ALL=en_US.UTF-8
export LANG=en_US.UTF-8
export LANGUAGE=en_US.UTF-8
# Aliases
alias nvs="nvidia-smi"
alias rsync="rsync --progress -raz"
alias his="history"
此差异已折叠。
......@@ -347,7 +347,7 @@ def make_batchset(
Note that if any utts doesn't have "category",
perform as same as batchfy_by_{count}
:param Dict[str, Dict[str, Any]] data: dictionary loaded from data.json
:param List[Dict[str, Any]] data: dictionary loaded from data.json
:param int batch_size: maximum number of sequences in a minibatch.
:param int batch_bins: maximum number of bins (frames x dim) in a minibatch.
:param int batch_frames_in: maximum number of input frames in a minibatch.
......@@ -374,7 +374,6 @@ def make_batchset(
reserved for future research, -1 means all axis.)
:return: List[List[Tuple[str, dict]]] list of batches
"""
# check args
if count not in BATCH_COUNT_CHOICES:
raise ValueError(
......@@ -386,7 +385,6 @@ def make_batchset(
ikey = "input"
okey = "output"
batch_sort_axis = 0 # index of list
if count == "auto":
if batch_size != 0:
count = "seq"
......@@ -405,7 +403,8 @@ def make_batchset(
"batch_sort_key=shuffle is only available if batch_count=seq")
category2data = {} # Dict[str, dict]
for k, v in data.items():
for v in data:
k = v['utt']
category2data.setdefault(v.get("category"), {})[k] = v
batches_list = [] # List[List[List[Tuple[str, dict]]]]
......@@ -422,6 +421,7 @@ def make_batchset(
key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]),
reverse=not shortest_first, )
logger.info("# utts: " + str(len(sorted_data)))
if count == "seq":
batches = batchfy_by_seq(
sorted_data,
......@@ -466,4 +466,4 @@ def make_batchset(
logger.info("# minibatches: " + str(len(batches)))
# batch: List[List[Tuple[str, dict]]]
return batches
return batches
\ No newline at end of file
......@@ -16,7 +16,7 @@ from typing import Optional
from paddle.io import Dataset
from yacs.config import CfgNode
from deepspeech.frontend.utility import read_manifest
from deepspeech.utils.log import Log
__all__ = ["ManifestDataset", "TripletManifestDataset", "TransformDataset"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册