Unverified · Commit 86d08f99, authored by Hui Zhang, committed by GitHub

Merge pull request #768 from PaddlePaddle/espnet

support kaldi data pipeline
[Chinese](README_cn.md)
# PaddlePaddle ASR toolkit
# PaddlePaddle Speech to Any toolkit
![License](https://img.shields.io/badge/license-Apache%202-red.svg)
![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
![support os](https://img.shields.io/badge/os-linux-yellow.svg)
*PaddleASR* is an open-source implementation of an end-to-end Automatic Speech Recognition (ASR) engine on the [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform. Our vision is to empower both industrial applications and academic research on speech recognition via an easy-to-use, efficient, smaller, and scalable implementation, including training, inference & testing modules, and deployment.
*DeepSpeech* is an open-source implementation of an end-to-end Automatic Speech Recognition engine on the [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform. Our vision is to empower both industrial applications and academic research on speech recognition via an easy-to-use, efficient, smaller, and scalable implementation, including training, inference & testing modules, and deployment.
## Features
......@@ -15,6 +15,8 @@
## Setup
All tested under:
* Ubuntu 16.04
* python>=3.7
* paddlepaddle>=2.1.2
......
[English](README.md)
# PaddlePaddle ASR toolkit
# PaddlePaddle Speech to Any toolkit
![License](https://img.shields.io/badge/license-Apache%202-red.svg)
![python version](https://img.shields.io/badge/python-3.7+-orange.svg)
![support os](https://img.shields.io/badge/os-linux-yellow.svg)
*PaddleASR* is an open-source end-to-end Automatic Speech Recognition (ASR) engine built on the [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform.
*DeepSpeech* is an open-source end-to-end Automatic Speech Recognition engine built on the [PaddlePaddle](https://github.com/PaddlePaddle/Paddle) platform.
Our vision is to provide easy-to-use, efficient, compact, and scalable tools for speech recognition in both industrial applications and academic research, covering training, inference, and deployment.
## Features
......@@ -16,6 +16,9 @@
## Setup
All tested under:
* Ubuntu 16.04
* python>=3.7
* paddlepaddle>=2.1.2
......
......@@ -407,42 +407,3 @@ class GLU(nn.Layer):
if not hasattr(paddle.nn, 'GLU'):
logger.warn("register user GLU to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'GLU', GLU)
# TODO(Hui Zhang): remove this Layer
class ConstantPad2d(nn.Layer):
"""Pads the input tensor boundaries with a constant value.
For N-dimensional padding, use paddle.nn.functional.pad().
"""
def __init__(self, padding: Union[tuple, list, int], value: float):
    """
    Args:
        padding (tuple/list/int): the size of the padding.
            If an int, uses the same padding on all boundaries.
            If a 4-tuple, uses (padding_left, padding_right, padding_top, padding_bottom).
        value (float): pad value
    """
    super().__init__()
    self.padding = padding if isinstance(
        padding, (tuple, list)) else [padding] * 4
    self.value = value
def forward(self, xs: paddle.Tensor) -> paddle.Tensor:
return nn.functional.pad(
xs,
self.padding,
mode='constant',
value=self.value,
data_format='NCHW')
if not hasattr(paddle.nn, 'ConstantPad2d'):
logger.warn(
"register user ConstantPad2d to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'ConstantPad2d', ConstantPad2d)
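For reference, a minimal sketch of what this shim computes, expressed directly with `paddle.nn.functional.pad` on a 4-D NCHW input (shapes and values are illustrative):

```python
import paddle
import paddle.nn.functional as F

x = paddle.ones([1, 1, 2, 3])  # NCHW input
# Pad 1 column left/right and 2 rows on top with a constant 0,
# i.e. what ConstantPad2d((1, 1, 2, 0), 0.) would apply to x.
y = F.pad(x, [1, 1, 2, 0], mode='constant', value=0.0, data_format='NCHW')
print(y.shape)  # [1, 1, 4, 5]
```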
########### hack paddle.jit #############
if not hasattr(paddle.jit, 'export'):
logger.warn("register user export to paddle.jit, remove this when fixed!")
setattr(paddle.jit, 'export', paddle.jit.to_static)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation for U2 model."""
import cProfile

from yacs.config import CfgNode

from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.utility import print_arguments
model_test_alias = {
"u2": "deepspeech.exps.u2.model:U2Tester",
"u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Tester",
}
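For context, `dynamic_import` turns either an alias from the table above or a fully qualified `module:class` string into a class object. A minimal sketch of that resolution under `importlib` semantics (the real helper lives in `deepspeech.utils.dynamic_import`; `dynamic_import_sketch` is a hypothetical stand-in):

```python
import importlib

def dynamic_import_sketch(name, alias_table):
    """Resolve 'alias' or 'pkg.module:ClassName' to a class object."""
    path = alias_table.get(name, name)  # alias -> 'module:class'
    module_name, class_name = path.split(':')
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# e.g. dynamic_import_sketch('u2_kaldi', model_test_alias) -> U2Tester
```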
def main_sp(config, args):
class_obj = dynamic_import(args.model_name, model_test_alias)
exp = class_obj(config, args)
exp.setup()
if args.run_mode == 'test':
exp.run_test()
elif args.run_mode == 'export':
exp.run_export()
elif args.run_mode == 'align':
exp.run_align()
def main(config, args):
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument(
'--model-name',
type=str,
default='u2_kaldi',
help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
parser.add_argument(
'--run-mode',
type=str,
default='test',
help='run mode, e.g. test, align, export')
args = parser.parse_args()
print_arguments(args, globals())
config = CfgNode()
config.set_new_allowed(True)
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
pr.dump_stats('test.profile')
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for U2 model."""
import cProfile
import os
from paddle import distributed as dist
from yacs.config import CfgNode
from deepspeech.training.cli import default_argument_parser
from deepspeech.utils.dynamic_import import dynamic_import
from deepspeech.utils.utility import print_arguments
model_train_alias = {
"u2": "deepspeech.exps.u2.model:U2Trainer",
"u2_kaldi": "deepspeech.exps.u2_kaldi.model:U2Trainer",
}
def main_sp(config, args):
class_obj = dynamic_import(args.model_name, model_train_alias)
exp = class_obj(config, args)
exp.setup()
exp.run()
def main(config, args):
if args.device == "gpu" and args.nprocs > 1:
dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
else:
main_sp(config, args)
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument(
'--model-name',
type=str,
default='u2_kaldi',
help='model name, e.g: deepspeech2, u2, u2_kaldi, u2_st')
args = parser.parse_args()
print_arguments(args, globals())
config = CfgNode()
config.set_new_allowed(True)
config.merge_from_file(args.config)
if args.opts:
config.merge_from_list(args.opts)
config.freeze()
print(config)
if args.dump_config:
with open(args.dump_config, 'w') as f:
print(config, file=f)
# Setting for profiling
pr = cProfile.Profile()
pr.runcall(main, config, args)
pr.dump_stats(os.path.join(args.output, 'train.profile'))
This diff has been collapsed.
......@@ -97,14 +97,14 @@ class AugmentationPipeline():
ValueError: If the augmentation json config is in incorrect format".
"""
SPEC_TYPES = {'specaug'}
def __init__(self, augmentation_config: str, random_seed: int=0):
self._rng = np.random.RandomState(random_seed)
self._spec_types = ('specaug')
if augmentation_config is None:
self.conf = {}
else:
self.conf = json.loads(augmentation_config)
self.conf = {'mode': 'sequential', 'process': []}
if augmentation_config:
process = json.loads(augmentation_config)
self.conf['process'] += process
self._augmentors, self._rates = self._parse_pipeline_from('all')
self._audio_augmentors, self._audio_rates = self._parse_pipeline_from(
......@@ -186,9 +186,9 @@ class AugmentationPipeline():
audio_confs = []
feature_confs = []
all_confs = []
for config in self.conf:
for config in self.conf['process']:
all_confs.append(config)
if config["type"] in self._spec_types:
if config["type"] in self.SPEC_TYPES:
feature_confs.append(config)
else:
audio_confs.append(config)
......
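With this change the user-facing JSON stays a flat list of augmentor configs; the pipeline now wraps it into the `{'mode': 'sequential', 'process': [...]}` layout internally. A minimal sketch, assuming the constructor shown above (the import path is a guess from the file layout):

```python
import json

# from deepspeech.frontend.augmentor.augmentation import AugmentationPipeline

augmentation_config = json.dumps([
    {"type": "shift",
     "params": {"min_shift_ms": -5, "max_shift_ms": 5},
     "prob": 1.0},
])
# pipeline = AugmentationPipeline(augmentation_config)
# pipeline.conf -> {'mode': 'sequential', 'process': [{'type': 'shift', ...}]}
```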
......@@ -30,7 +30,7 @@ class AugmentorBase():
@abstractmethod
def __call__(self, xs):
raise NotImplementedError
raise NotImplementedError("AugmentorBase: Not impl __call__")
@abstractmethod
def transform_audio(self, audio_segment):
......@@ -44,7 +44,7 @@ class AugmentorBase():
:param audio_segment: Audio segment to add effects to.
:type audio_segment: AudioSegmenet|SpeechSegment
"""
raise NotImplementedError
raise NotImplementedError("AugmentorBase: Not impl transform_audio")
@abstractmethod
def transform_feature(self, spec_segment):
......@@ -56,4 +56,4 @@ class AugmentorBase():
Args:
spec_segment (Spectrogram): Spectrogram segment to add effects to.
"""
raise NotImplementedError
raise NotImplementedError("AugmentorBase: Not impl transform_feature")
......@@ -34,6 +34,7 @@ class ImpulseResponseAugmentor(AugmentorBase):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Add impulse response effect.
......
......@@ -40,6 +40,7 @@ class NoisePerturbAugmentor(AugmentorBase):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Add background noise audio.
......
......@@ -48,6 +48,7 @@ class OnlineBayesianNormalizationAugmentor(AugmentorBase):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Normalizes the input audio using the online Bayesian approach.
......
......@@ -35,6 +35,7 @@ class ResampleAugmentor(AugmentorBase):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Resamples the input audio to a target sample rate.
......
......@@ -35,6 +35,7 @@ class ShiftPerturbAugmentor(AugmentorBase):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Shift audio.
......
......@@ -64,7 +64,6 @@ class SpecAugmentor(AugmentorBase):
self.n_freq_masks = n_freq_masks
self.n_time_masks = n_time_masks
self.p = p
#logger.info(f"specaug: F-{F}, T-{T}, F-n-{n_freq_masks}, T-n-{n_time_masks}")
# adaptive SpecAugment
self.adaptive_number_ratio = adaptive_number_ratio
......@@ -121,6 +120,9 @@ class SpecAugmentor(AugmentorBase):
def time_mask(self):
return self._time_mask
def __repr__(self):
    return f"specaug: F-{self.F}, T-{self.T}, F-n-{self.n_freq_masks}, T-n-{self.n_time_masks}"
def time_warp(xs, W=40):
raise NotImplementedError
......@@ -160,7 +162,7 @@ class SpecAugmentor(AugmentorBase):
def __call__(self, x, train=True):
if not train:
return x
self.transform_audio(x)
return self.transform_feature(x)
def transform_feature(self, xs: np.ndarray):
"""
......
......@@ -83,6 +83,7 @@ class SpeedPerturbAugmentor(AugmentorBase):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Sample a new speed rate from the given range and
......
......@@ -41,6 +41,7 @@ class VolumePerturbAugmentor(AugmentorBase):
if not train:
return x
self.transform_audio(x)
return x
def transform_audio(self, audio_segment):
"""Change audio loadness.
......
......@@ -49,6 +49,7 @@ class CustomConverter():
# batch should be located in list
assert len(batch) == 1
(xs, ys), utts = batch[0]
assert xs[0] is not None, "please check Reader and Augmentation impl."
# perform subsampling
if self.subsampling_factor > 1:
......
......@@ -11,6 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from typing import Dict
from typing import List
from typing import Text
import numpy as np
from paddle.io import DataLoader
from deepspeech.frontend.utility import read_manifest
......@@ -25,6 +31,18 @@ __all__ = ["BatchDataLoader"]
logger = Log(__name__).getlog()
def feat_dim_and_vocab_size(data_json: List[Dict[Text, Any]],
mode: Text="asr",
iaxis=0,
oaxis=0):
if mode == 'asr':
feat_dim = data_json[0]['input'][iaxis]['shape'][1]
vocab_size = data_json[0]['output'][oaxis]['shape'][1]
else:
raise ValueError(f"{mode} mode not supported!")
return feat_dim, vocab_size
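For context, the helper reads shape metadata from the first manifest entry. A hypothetical espnet-style entry (field names match the code above; values are illustrative):

```python
# One element of data_json, as produced by read_manifest:
entry = {
    "utt": "utt_0001",
    "input": [{"name": "input1", "shape": [1024, 83]}],    # [frames, feat_dim]
    "output": [{"name": "target1", "shape": [25, 5002]}],  # [tokens, vocab_size]
}
# feat_dim_and_vocab_size([entry], mode='asr') -> (83, 5002)
```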
class BatchDataLoader():
def __init__(self,
json_file: str,
......@@ -62,6 +80,8 @@ class BatchDataLoader():
# read json data
self.data_json = read_manifest(json_file)
self.feat_dim, self.vocab_size = feat_dim_and_vocab_size(
self.data_json, mode='asr')
# make minibatch list (variable length)
self.minibatches = make_batchset(
......@@ -106,7 +126,7 @@ class BatchDataLoader():
self.dataloader = DataLoader(
dataset=self.dataset,
batch_size=1,
shuffle=not use_sortagrad if train_mode else False,
shuffle=not self.use_sortagrad if train_mode else False,
collate_fn=lambda x: x[0],
num_workers=n_iter_processes, )
......
......@@ -66,8 +66,9 @@ class LoadInputsAndTargets():
raise ValueError("Only asr are allowed: mode={}".format(mode))
if preprocess_conf is not None:
self.preprocessing = AugmentationPipeline(preprocess_conf)
logging.warning(
with open(preprocess_conf, 'r') as fin:
self.preprocessing = AugmentationPipeline(fin.read())
logger.warning(
"[Experimental feature] Some preprocessing will be done "
"for the mini-batch creation using {}".format(
self.preprocessing))
......@@ -197,7 +198,7 @@ class LoadInputsAndTargets():
nonzero_sorted_idx = nonzero_idx
if len(nonzero_sorted_idx) != len(xs[0]):
logging.warning(
logger.warning(
"Target sequences include empty tokenid (batch {} -> {}).".
format(len(xs[0]), len(nonzero_sorted_idx)))
......
......@@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False
......
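Both variants rely on the same idiom: `zip(*[iter(seq)] * batch_size)` groups a flat list into fixed-size batches, dropping any remainder, because every positional argument to `zip` is the same iterator. A small standalone example:

```python
indices = list(range(10))
batch_size = 3
batches = list(zip(*[iter(indices)] * batch_size))
print(batches)  # [(0, 1, 2), (3, 4, 5), (6, 7, 8)] -- index 9 is dropped
```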
......@@ -612,32 +612,32 @@ class U2BaseModel(nn.Layer):
best_index = i
return hyps[best_index][0]
#@jit.export
#@jit.to_static
def subsampling_rate(self) -> int:
""" Export interface for c++ call, return subsampling_rate of the
model
"""
return self.encoder.embed.subsampling_rate
#@jit.export
#@jit.to_static
def right_context(self) -> int:
""" Export interface for c++ call, return right_context of the model
"""
return self.encoder.embed.right_context
#@jit.export
#@jit.to_static
def sos_symbol(self) -> int:
""" Export interface for c++ call, return sos symbol id of the model
"""
return self.sos
#@jit.export
#@jit.to_static
def eos_symbol(self) -> int:
""" Export interface for c++ call, return eos symbol id of the model
"""
return self.eos
@jit.export
@jit.to_static
def forward_encoder_chunk(
self,
xs: paddle.Tensor,
......@@ -667,7 +667,7 @@ class U2BaseModel(nn.Layer):
xs, offset, required_cache_size, subsampling_cache,
elayers_output_cache, conformer_cnn_cache)
# @jit.export([
# @jit.to_static([
# paddle.static.InputSpec(shape=[1, None, feat_dim],dtype='float32'), # audio feat, [B,T,D]
# ])
def ctc_activation(self, xs: paddle.Tensor) -> paddle.Tensor:
......@@ -680,7 +680,7 @@ class U2BaseModel(nn.Layer):
"""
return self.ctc.log_softmax(xs)
@jit.export
@jit.to_static
def forward_attention_decoder(
self,
hyps: paddle.Tensor,
......
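For reference, `paddle.jit.to_static` plus `paddle.jit.save` is the usual export path for such interfaces. A minimal sketch on a toy layer (not the U2 model; the `InputSpec` mirrors the commented-out decorator above):

```python
import paddle
from paddle.static import InputSpec

class Toy(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(8, 4)

    @paddle.jit.to_static(
        input_spec=[InputSpec(shape=[None, None, 8], dtype='float32')])
    def forward(self, xs):  # xs: audio feat, [B, T, D]
        return self.fc(xs)

layer = Toy()
paddle.jit.save(layer, 'exported/toy')  # writes program + params for C++ inference
```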
......@@ -69,7 +69,7 @@ class ConvGLUBlock(nn.Layer):
dim=0)
self.dropout_residual = nn.Dropout(p=dropout)
self.pad_left = ConstantPad2d((0, 0, kernel_size - 1, 0), 0)
self.pad_left = nn.Pad2D((0, 0, kernel_size - 1, 0), value=0.0)
layers = OrderedDict()
if bottlececk_dim == 0:
......
......@@ -15,6 +15,7 @@ from typing import Any
from typing import Dict
from typing import Text
import paddle
from paddle.optimizer import Optimizer
from paddle.regularizer import L2Decay
......@@ -43,6 +44,40 @@ def register_optimizer(cls):
return cls
@register_optimizer
class Noam(paddle.optimizer.Adam):
"""Seem to: espnet/nets/pytorch_backend/transformer/optimizer.py """
def __init__(self,
learning_rate=0,
beta1=0.9,
beta2=0.98,
epsilon=1e-9,
parameters=None,
weight_decay=None,
grad_clip=None,
lazy_mode=False,
multi_precision=False,
name=None):
super().__init__(
learning_rate=learning_rate,
beta1=beta1,
beta2=beta2,
epsilon=epsilon,
parameters=parameters,
weight_decay=weight_decay,
grad_clip=grad_clip,
lazy_mode=lazy_mode,
multi_precision=multi_precision,
name=name)
def __repr__(self):
    echo = f"<{self.__class__.__module__}.{self.__class__.__name__} object at {hex(id(self))}> "
    echo += f"learning_rate: {self._learning_rate}, "
    echo += f"(beta1: {self._beta1} beta2: {self._beta2}), "
    echo += f"epsilon: {self._epsilon}"
    return echo
def dynamic_import_optimizer(module):
"""Import Optimizer class dynamically.
......@@ -69,15 +104,18 @@ class OptimizerFactory():
args['grad_clip']) if "grad_clip" in args else None
weight_decay = L2Decay(
args['weight_decay']) if "weight_decay" in args else None
module_class = dynamic_import_optimizer(name.lower())
if weight_decay:
logger.info(f'WeightDecay: {weight_decay}')
logger.info(f'<WeightDecay - {weight_decay}>')
if grad_clip:
logger.info(f'GradClip: {grad_clip}')
logger.info(
f"Optimizer: {module_class.__name__} {args['learning_rate']}")
logger.info(f'<GradClip - {grad_clip}>')
module_class = dynamic_import_optimizer(name.lower())
args.update({"grad_clip": grad_clip, "weight_decay": weight_decay})
return instance_class(module_class, args)
opt = instance_class(module_class, args)
if "__repr__" in vars(opt):
logger.info(f"{opt}")
else:
logger.info(
f"<Optimizer {module_class.__module__}.{module_class.__name__}> LR: {args['learning_rate']}"
)
return opt
......@@ -41,22 +41,6 @@ def register_scheduler(cls):
return cls
def dynamic_import_scheduler(module):
"""Import Scheduler class dynamically.
Args:
module (str): module_name:class_name or alias in `SCHEDULER_DICT`
Returns:
type: Scheduler class
"""
module_class = dynamic_import(module, SCHEDULER_DICT)
assert issubclass(module_class,
LRScheduler), f"{module} does not implement LRScheduler"
return module_class
@register_scheduler
class WarmupLR(LRScheduler):
"""The WarmupLR scheduler
......@@ -102,6 +86,41 @@ class WarmupLR(LRScheduler):
self.step(epoch=step)
@register_scheduler
class ConstantLR(LRScheduler):
"""
Args:
learning_rate (float): The initial learning rate. It is a python float number.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
Returns:
``ConstantLR`` instance to schedule learning rate.
"""
def __init__(self, learning_rate, last_epoch=-1, verbose=False):
super().__init__(learning_rate, last_epoch, verbose)
def get_lr(self):
return self.base_lr
def dynamic_import_scheduler(module):
"""Import Scheduler class dynamically.
Args:
module (str): module_name:class_name or alias in `SCHEDULER_DICT`
Returns:
type: Scheduler class
"""
module_class = dynamic_import(module, SCHEDULER_DICT)
assert issubclass(module_class,
LRScheduler), f"{module} does not implement LRScheduler"
return module_class
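For context, a registered scheduler then becomes reachable through this lookup. A hedged sketch, assuming `@register_scheduler` keys classes in `SCHEDULER_DICT` by lower-cased class name:

```python
sched_cls = dynamic_import_scheduler('constantlr')  # hypothetical alias -> ConstantLR
scheduler = sched_cls(learning_rate=0.004)
scheduler.step()
print(scheduler.get_lr())  # stays at 0.004
```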
class LRSchedulerFactory():
@classmethod
def from_args(cls, name: str, args: Dict[Text, Any]):
......
[
{
"type": "shift",
"params": {
"min_shift_ms": -5,
"max_shift_ms": 5
},
"prob": 1.0
},
{
"type": "speed",
"params": {
"min_speed_rate": 0.9,
"max_speed_rate": 1.1,
"num_rates": 3
},
"prob": 0.0
},
{
"type": "specaug",
"params": {
......
......@@ -19,7 +19,7 @@ collator:
batch_size: 64
raw_wav: True # use raw_wav or kaldi feature
specgram_type: fbank #linear, mfcc, fbank
feat_dim: 80
feat_dim: 83
delta_delta: False
dither: 1.0
target_sample_rate: 16000
......@@ -38,7 +38,7 @@ collator:
# network architecture
model:
cmvn_file: "data/mean_std.json"
cmvn_file:
cmvn_file_type: "json"
# encoder related
encoder: transformer
......@@ -74,20 +74,20 @@ model:
training:
n_epoch: 120
accum_grad: 2
global_grad_clip: 5.0
optim: adam
optim_conf:
lr: 0.004
weight_decay: 1e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
warmup_steps: 25000
lr_decay: 1.0
log_interval: 100
checkpoint:
kbest_n: 50
latest_n: 5
optim: adam
optim_conf:
global_grad_clip: 5.0
weight_decay: 1.0e-06
scheduler: warmuplr # pytorch v1.1.0+ required
scheduler_conf:
lr: 0.004
warmup_steps: 25000
lr_decay: 1.0
decoding:
batch_size: 64
......
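With this layout, the gradient clip and weight decay move under `optim_conf` and the learning rate moves under `scheduler_conf`. A minimal sketch of reading them back, mirroring how `train.py` above merges the file (the YAML path is hypothetical):

```python
from yacs.config import CfgNode

config = CfgNode()
config.set_new_allowed(True)
config.merge_from_file('conf/transformer.yaml')  # hypothetical path

optim_conf = config.training.optim_conf      # global_grad_clip, weight_decay
sched_conf = config.training.scheduler_conf  # lr, warmup_steps, lr_decay
```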
......@@ -21,7 +21,8 @@ mkdir -p ${output_dir}
# align dump in `result_file`
# .tier, .TextGrid dump in `dir of result_file`
python3 -u ${BIN_DIR}/alignment.py \
python3 -u ${BIN_DIR}/test.py \
--run_mode 'align' \
--device ${device} \
--nproc 1 \
--config ${config_path} \
......
......@@ -17,7 +17,8 @@ if [ ${ngpu} == 0 ];then
device=cpu
fi
python3 -u ${BIN_DIR}/export.py \
python3 -u ${BIN_DIR}/test.py \
--run_mode 'export' \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
......
......@@ -38,6 +38,7 @@ for type in attention ctc_greedy_search; do
batch_size=64
fi
python3 -u ${BIN_DIR}/test.py \
--run_mode test \
--device ${device} \
--nproc 1 \
--config ${config_path} \
......@@ -55,6 +56,7 @@ for type in ctc_prefix_beam_search attention_rescoring; do
echo "decoding ${type}"
batch_size=1
python3 -u ${BIN_DIR}/test.py \
--run_mode test \
--device ${device} \
--nproc 1 \
--config ${config_path} \
......
......@@ -20,6 +20,7 @@ echo "using ${device}..."
mkdir -p exp
python3 -u ${BIN_DIR}/train.py \
--model-name u2_kaldi \
--device ${device} \
--nproc ${ngpu} \
--config ${config_path} \
......
......@@ -10,5 +10,5 @@ export PYTHONPATH=${MAIN_ROOT}:${PYTHONPATH}
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/lib/
MODEL=u2
MODEL=u2_kaldi
export BIN_DIR=${MAIN_ROOT}/deepspeech/exps/${MODEL}/bin
......@@ -7,4 +7,3 @@
* https://github.com/NVIDIA/FasterTransformer.git
* https://github.com/idiap/fast-transformers