Commit d647cde8 authored by huangyuxin

change the lm dataset dir

Parent db96da3e
@@ -14,6 +14,10 @@ export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
 MODEL=u2_kaldi
 export BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/${MODEL}/bin
+LM_MODEL=transformer
+export LM_BIN_DIR=${MAIN_ROOT}/paddlespeech/s2t/exps/lm/${LM_MODEL}/bin
 # srilm
 export LIBLBFGS=${MAIN_ROOT}/tools/liblbfgs-1.10
 export LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}:${LIBLBFGS}/lib/.libs
@@ -25,4 +29,4 @@ export KALDI_ROOT=${MAIN_ROOT}/tools/kaldi
 [ -f $KALDI_ROOT/tools/env.sh ] && . $KALDI_ROOT/tools/env.sh
 export PATH=$PWD/utils/:$KALDI_ROOT/tools/openfst/bin:$PWD:$PATH
 [ ! -f $KALDI_ROOT/tools/config/common_path.sh ] && echo >&2 "The standard file $KALDI_ROOT/tools/config/common_path.sh is not present, cannot use Kaldi!"
 [ -f $KALDI_ROOT/tools/config/common_path.sh ] && . $KALDI_ROOT/tools/config/common_path.sh
\ No newline at end of file
@@ -19,8 +19,8 @@ import paddle
 from paddle.io import DataLoader
 from yacs.config import CfgNode
-from paddlespeech.s2t.io.collator import TextCollatorSpm
-from paddlespeech.s2t.io.dataset import TextDataset
+from paddlespeech.s2t.models.lm.dataset import TextCollatorSpm
+from paddlespeech.s2t.models.lm.dataset import TextDataset
 from paddlespeech.s2t.models.lm_interface import dynamic_import_lm
 from paddlespeech.s2t.utils.log import Log
...
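Both classes now live in paddlespeech.s2t.models.lm.dataset, so the LM trainer imports the dataset and its collator from a single module. Below is a minimal sketch of how they could be wired into a paddle DataLoader; the file paths, batch size, and unit settings are hypothetical placeholders, not values taken from this commit:

```python
from paddle.io import DataLoader

# New import location introduced by this commit.
from paddlespeech.s2t.models.lm.dataset import TextCollatorSpm, TextDataset

# Hypothetical inputs: a text file with "utt-key token token ..." lines,
# plus vocab/SentencePiece artifacts (all paths assumed for illustration).
dataset = TextDataset.from_file("data/local/lm_train.txt")
collate_fn = TextCollatorSpm(
    unit_type="spm",  # assumed unit type
    vocab_filepath="data/lang_char/vocab.txt",
    spm_model_prefix="data/lang_char/bpe_unigram_5000")

loader = DataLoader(
    dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

for keys, ys_input_pad, ys_output_pad, y_lens in loader:
    pass  # feed (input, target) pairs to the language model
```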
@@ -19,7 +19,6 @@ from yacs.config import CfgNode
 from paddlespeech.s2t.frontend.augmentor.augmentation import AugmentationPipeline
 from paddlespeech.s2t.frontend.featurizer.speech_featurizer import SpeechFeaturizer
-from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.frontend.normalizer import FeatureNormalizer
 from paddlespeech.s2t.frontend.speech import SpeechSegment
 from paddlespeech.s2t.frontend.utility import IGNORE_ID
@@ -46,43 +45,6 @@ def _tokenids(text, keep_transcription_text):
     return tokens
-
-
-class TextCollatorSpm():
-    def __init__(self, unit_type, vocab_filepath, spm_model_prefix):
-        assert (vocab_filepath is not None)
-        self.text_featurizer = TextFeaturizer(
-            unit_type=unit_type,
-            vocab_filepath=vocab_filepath,
-            spm_model_prefix=spm_model_prefix)
-        self.eos_id = self.text_featurizer.eos_id
-        self.blank_id = self.text_featurizer.blank_id
-
-    def __call__(self, batch):
-        """
-        return type [List, np.array [B, T], np.array [B, T], np.array[B]]
-        """
-        keys = []
-        texts = []
-        texts_input = []
-        texts_output = []
-        text_lens = []
-        for idx, item in enumerate(batch):
-            key = item.split(" ")[0].strip()
-            text = " ".join(item.split(" ")[1:])
-            keys.append(key)
-            token_ids = self.text_featurizer.featurize(text)
-            texts_input.append(
-                np.array([self.eos_id] + token_ids).astype(np.int64))
-            texts_output.append(
-                np.array(token_ids + [self.eos_id]).astype(np.int64))
-            text_lens.append(len(token_ids) + 1)
-        ys_input_pad = pad_list(texts_input, self.blank_id).astype(np.int64)
-        ys_output_pad = pad_list(texts_output, self.blank_id).astype(np.int64)
-        y_lens = np.array(text_lens).astype(np.int64)
-        return keys, ys_input_pad, ys_output_pad, y_lens
 class SpeechCollatorBase():
     def __init__(
             self,
...
@@ -24,25 +24,6 @@ __all__ = ["ManifestDataset", "TransformDataset"]
 logger = Log(__name__).getlog()
-
-
-class TextDataset(Dataset):
-    @classmethod
-    def from_file(cls, file_path):
-        dataset = cls(file_path)
-        return dataset
-
-    def __init__(self, file_path):
-        self._manifest = []
-        with open(file_path) as f:
-            for line in f:
-                self._manifest.append(line.strip())
-
-    def __len__(self):
-        return len(self._manifest)
-
-    def __getitem__(self, idx):
-        return self._manifest[idx]
-
 class ManifestDataset(Dataset):
     @classmethod
     def params(cls, config: Optional[CfgNode]=None) -> CfgNode:
...
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+from paddle.io import Dataset
+
+from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
+from paddlespeech.s2t.io.utility import pad_list
+
+
+class TextDataset(Dataset):
+    @classmethod
+    def from_file(cls, file_path):
+        dataset = cls(file_path)
+        return dataset
+
+    def __init__(self, file_path):
+        self._manifest = []
+        with open(file_path) as f:
+            for line in f:
+                self._manifest.append(line.strip())
+
+    def __len__(self):
+        return len(self._manifest)
+
+    def __getitem__(self, idx):
+        return self._manifest[idx]
+
+
+class TextCollatorSpm():
+    def __init__(self, unit_type, vocab_filepath, spm_model_prefix):
+        assert (vocab_filepath is not None)
+        self.text_featurizer = TextFeaturizer(
+            unit_type=unit_type,
+            vocab_filepath=vocab_filepath,
+            spm_model_prefix=spm_model_prefix)
+        self.eos_id = self.text_featurizer.eos_id
+        self.blank_id = self.text_featurizer.blank_id
+
+    def __call__(self, batch):
+        """
+        return type [List, np.array [B, T], np.array [B, T], np.array[B]]
+        """
+        keys = []
+        texts = []
+        texts_input = []
+        texts_output = []
+        text_lens = []
+        for idx, item in enumerate(batch):
+            key = item.split(" ")[0].strip()
+            text = " ".join(item.split(" ")[1:])
+            keys.append(key)
+            token_ids = self.text_featurizer.featurize(text)
+            texts_input.append(
+                np.array([self.eos_id] + token_ids).astype(np.int64))
+            texts_output.append(
+                np.array(token_ids + [self.eos_id]).astype(np.int64))
+            text_lens.append(len(token_ids) + 1)
+        ys_input_pad = pad_list(texts_input, self.blank_id).astype(np.int64)
+        ys_output_pad = pad_list(texts_output, self.blank_id).astype(np.int64)
+        y_lens = np.array(text_lens).astype(np.int64)
+        return keys, ys_input_pad, ys_output_pad, y_lens
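The collator implements standard next-token LM teacher forcing: for token ids [5, 9] it emits input [eos, 5, 9] and target [5, 9, eos], then right-pads every sequence in the batch with blank_id. The self-contained sketch below reproduces those shapes; its local pad_list is only a stand-in mimicking the behavior this file relies on, not the actual paddlespeech.s2t.io.utility implementation:

```python
import numpy as np

def pad_list(arrays, pad_value):
    # Stand-in (assumed behavior): right-pad 1-D arrays to the batch max length.
    max_len = max(a.shape[0] for a in arrays)
    out = np.full((len(arrays), max_len), pad_value, dtype=arrays[0].dtype)
    for i, a in enumerate(arrays):
        out[i, :a.shape[0]] = a
    return out

eos_id, blank_id = 2, 0                # hypothetical token ids
batch_token_ids = [[5, 9], [7, 4, 8]]  # two featurized utterances

texts_input = [np.array([eos_id] + t, dtype=np.int64) for t in batch_token_ids]
texts_output = [np.array(t + [eos_id], dtype=np.int64) for t in batch_token_ids]

print(pad_list(texts_input, blank_id))        # [[2 5 9 0] [2 7 4 8]]
print(pad_list(texts_output, blank_id))       # [[5 9 2 0] [7 4 8 2]]
print([len(t) + 1 for t in batch_token_ids])  # y_lens: [3, 4]
```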