提交 e4ecfb22 编写于 作者: H Hui Zhang

format code

上级 3fa2e44e
......@@ -355,7 +355,6 @@ if not hasattr(paddle.Tensor, 'tolist'):
"register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist)
########### hack paddle.nn #############
from paddle.nn import Layer
from typing import Optional
......@@ -506,5 +505,3 @@ if not hasattr(paddle.nn, 'LayerDict'):
logger.debug(
"register user LayerDict to paddle.nn, remove this when fixed!")
setattr(paddle.nn, 'LayerDict', LayerDict)
......@@ -12,12 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""V2 backend for `asr_recog.py` using py:class:`decoders.beam_search.BeamSearch`."""
import json
from pathlib import Path
import jsonlines
import paddle
import yaml
from yacs.config import CfgNode
from .beam_search import BatchBeamSearch
......@@ -79,8 +75,7 @@ def recog_v2(args):
sort_in_input_length=False,
preprocess_conf=confs.collator.augmentation_config
if args.preprocess_conf is None else args.preprocess_conf,
preprocess_args={"train": False},
)
preprocess_args={"train": False}, )
if args.rnnlm:
lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
......@@ -113,8 +108,7 @@ def recog_v2(args):
ctc=args.ctc_weight,
lm=args.lm_weight,
ngram=args.ngram_weight,
length_bonus=args.penalty,
)
length_bonus=args.penalty, )
beam_search = BeamSearch(
beam_size=args.beam_size,
vocab_size=len(char_list),
......@@ -123,8 +117,7 @@ def recog_v2(args):
sos=model.sos,
eos=model.eos,
token_list=char_list,
pre_beam_score_key=None if args.ctc_weight == 1.0 else "full",
)
pre_beam_score_key=None if args.ctc_weight == 1.0 else "full", )
# TODO(karita): make all scorers batchfied
if args.batchsize == 1:
......@@ -171,9 +164,10 @@ def recog_v2(args):
logger.info(f'feat: {feat.shape}')
enc = model.encode(paddle.to_tensor(feat).to(dtype))
logger.info(f'eout: {enc.shape}')
nbest_hyps = beam_search(x=enc,
maxlenratio=args.maxlenratio,
minlenratio=args.minlenratio)
nbest_hyps = beam_search(
x=enc,
maxlenratio=args.maxlenratio,
minlenratio=args.minlenratio)
nbest_hyps = [
h.asdict()
for h in nbest_hyps[:min(len(nbest_hyps), args.nbest)]
......@@ -183,9 +177,8 @@ def recog_v2(args):
item = new_js[name]['output'][0] # 1-best
ref = item['text']
rec_text = item['rec_text'].replace('▁',
' ').replace('<eos>',
'').strip()
rec_text = item['rec_text'].replace('▁', ' ').replace(
'<eos>', '').strip()
rec_tokenid = list(map(int, item['rec_tokenid'].split()))
f.write({
"utt": name,
......
......@@ -21,8 +21,6 @@ from distutils.util import strtobool
import configargparse
import numpy as np
from .recog import recog_v2
def get_parser():
"""Get default arguments."""
......
......@@ -110,7 +110,7 @@ class ASRInterface:
@property
def ctc_plot_class(self):
"""Get CTC plot class."""
from deepspeech.training.extensions.plot import PlotCTCReport
from deepspeech.training.extensions.plot import PlotCTCReport
return PlotCTCReport
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -20,11 +20,11 @@ import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from deepspeech.modules.mask import subsequent_mask
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.decoders.scorers.scorer_interface import BatchScorerInterface
from deepspeech.models.lm_interface import
#LMInterface
from deepspeech.models.lm_interface import LMInterface
from deepspeech.modules.encoder import TransformerEncoder
from deepspeech.modules.mask import subsequent_mask
class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
def __init__(
......@@ -36,10 +36,10 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
head: int=2,
unit: int=1024,
layer: int=4,
dropout_rate: float=0.5,
emb_dropout_rate: float = 0.0,
att_dropout_rate: float = 0.0,
tie_weights: bool = False,):
dropout_rate: float=0.5,
emb_dropout_rate: float=0.0,
att_dropout_rate: float=0.0,
tie_weights: bool=False, ):
nn.Layer.__init__(self)
if pos_enc == "sinusoidal":
......@@ -89,9 +89,8 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
m = subsequent_mask(ys_mask.size(-1)).unsqueeze(0)
return ys_mask.unsqueeze(-2) & m
def forward(
self, x: paddle.Tensor, t: paddle.Tensor
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
def forward(self, x: paddle.Tensor, t: paddle.Tensor
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Compute LM loss value from buffer sequences.
Args:
......@@ -117,7 +116,8 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
emb = self.embed(x)
h, _ = self.encoder(emb, xlen)
y = self.decoder(h)
loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
loss = F.cross_entropy(
y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
mask = xm.to(dtype=loss.dtype)
logp = loss * mask.view(-1)
logp = logp.sum()
......@@ -148,16 +148,16 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
emb = self.embed(y)
h, _, cache = self.encoder.forward_one_step(
emb, self._target_mask(y), cache=state
)
emb, self._target_mask(y), cache=state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(axis=-1).squeeze(0)
return logp, cache
# batch beam search API (see BatchScorerInterface)
def batch_score(
self, ys: paddle.Tensor, states: List[Any], xs: paddle.Tensor
) -> Tuple[paddle.Tensor, List[Any]]:
def batch_score(self,
ys: paddle.Tensor,
states: List[Any],
xs: paddle.Tensor) -> Tuple[paddle.Tensor, List[Any]]:
"""Score new token batch (required).
Args:
......@@ -191,13 +191,13 @@ class TransformerLM(nn.Layer, LMInterface, BatchScorerInterface):
# batch decoding
h, _, states = self.encoder.forward_one_step(
emb, self._target_mask(ys), cache=batch_state
)
emb, self._target_mask(ys), cache=batch_state)
h = self.decoder(h[:, -1])
logp = h.log_softmax(axi=-1)
# transpose state of [layer, batch] into [batch, layer]
state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
state_list = [[states[i][b] for i in range(n_layers)]
for b in range(n_batch)]
return logp, state_list
......@@ -212,17 +212,17 @@ if __name__ == "__main__":
layer=16,
dropout_rate=0.5, )
# n_vocab: int,
# pos_enc: str=None,
# embed_unit: int=128,
# att_unit: int=256,
# head: int=2,
# unit: int=1024,
# layer: int=4,
# dropout_rate: float=0.5,
# emb_dropout_rate: float = 0.0,
# att_dropout_rate: float = 0.0,
# tie_weights: bool = False,):
# n_vocab: int,
# pos_enc: str=None,
# embed_unit: int=128,
# att_unit: int=256,
# head: int=2,
# unit: int=1024,
# layer: int=4,
# dropout_rate: float=0.5,
# emb_dropout_rate: float = 0.0,
# att_dropout_rate: float = 0.0,
# tie_weights: bool = False,):
paddle.set_device("cpu")
model_dict = paddle.load("transformerLM.pdparams")
tlm.set_state_dict(model_dict)
......@@ -256,4 +256,4 @@ if __name__ == "__main__":
print("output", output)
#print("cache", cache)
#np.save("output_pd.npy", output)
"""
\ No newline at end of file
"""
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Language model interface."""
import argparse
from deepspeech.decoders.scorers.scorer_interface import ScorerInterface
from deepspeech.utils.dynamic_import import dynamic_import
class LMInterface(ScorerInterface):
"""LM Interface model implementation."""
......@@ -52,6 +65,7 @@ predefined_lms = {
"transformer": "deepspeech.models.lm.transformer:TransformerLM",
}
def dynamic_import_lm(module):
"""Import LM class dynamically.
......@@ -63,7 +77,6 @@ def dynamic_import_lm(module):
"""
model_class = dynamic_import(module, predefined_lms)
assert issubclass(
model_class, LMInterface
), f"{module} does not implement LMInterface"
assert issubclass(model_class,
LMInterface), f"{module} does not implement LMInterface"
return model_class
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ST Interface module."""
import argparse
from deepspeech.utils.dynamic_import import dynamic_import
from .asr_interface import ASRInterface
from deepspeech.utils.dynamic_import import dynamic_import
class STInterface(ASRInterface):
"""ST Interface model implementation.
......@@ -13,7 +24,12 @@ class STInterface(ASRInterface):
"""
def translate(self, x, trans_args, char_list=None, rnnlm=None, ensemble_models=[]):
def translate(self,
x,
trans_args,
char_list=None,
rnnlm=None,
ensemble_models=[]):
"""Recognize x for evaluation.
:param ndarray x: input acouctic feature (B, T, D) or (T, D)
......@@ -42,6 +58,7 @@ predefined_st = {
"transformer": "deepspeech.models.u2_st:U2STModel",
}
def dynamic_import_st(module):
"""Import ST models dynamically.
......@@ -53,7 +70,6 @@ def dynamic_import_st(module):
"""
model_class = dynamic_import(module, predefined_st)
assert issubclass(
model_class, STInterface
), f"{module} does not implement STInterface"
assert issubclass(model_class,
STInterface), f"{module} does not implement STInterface"
return model_class
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .u2_st import U2STInferModel
from .u2_st import U2STModel
from .u2_st import U2STInferModel
\ No newline at end of file
......@@ -11,9 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
import paddle
from paddle import nn
from typing import Union
from paddle.nn import functional as F
from typeguard import check_argument_types
......
......@@ -22,7 +22,10 @@ from deepspeech.utils.log import Log
logger = Log(__name__).getlog()
__all__ = ["NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"]
__all__ = [
"NoPositionalEncoding", "PositionalEncoding", "RelPositionalEncoding"
]
class NoPositionalEncoding(nn.Layer):
def __init__(self,
......
......@@ -24,9 +24,9 @@ from deepspeech.modules.activation import get_activation
from deepspeech.modules.attention import MultiHeadedAttention
from deepspeech.modules.attention import RelPositionMultiHeadedAttention
from deepspeech.modules.conformer_convolution import ConvolutionModule
from deepspeech.modules.embedding import NoPositionalEncoding
from deepspeech.modules.embedding import PositionalEncoding
from deepspeech.modules.embedding import RelPositionalEncoding
from deepspeech.modules.embedding import NoPositionalEncoding
from deepspeech.modules.encoder_layer import ConformerEncoderLayer
from deepspeech.modules.encoder_layer import TransformerEncoderLayer
from deepspeech.modules.mask import add_optional_chunk_mask
......@@ -103,7 +103,7 @@ class BaseEncoder(nn.Layer):
pos_enc_class = PositionalEncoding
elif pos_enc_layer_type == "rel_pos":
pos_enc_class = RelPositionalEncoding
elif pos_enc_layer_type is "no_pos":
elif pos_enc_layer_type == "no_pos":
pos_enc_class = NoPositionalEncoding
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
......@@ -378,8 +378,7 @@ class TransformerEncoder(BaseEncoder):
self,
xs: paddle.Tensor,
masks: paddle.Tensor,
cache=None,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
cache=None, ) -> Tuple[paddle.Tensor, paddle.Tensor]:
"""Encode input frame.
Args:
......@@ -397,7 +396,8 @@ class TransformerEncoder(BaseEncoder):
if isinstance(self.embed, Conv2dSubsampling):
#TODO(Hui Zhang): self.embed(xs, masks, offset=0), stride_slice not support bool tensor
xs, pos_emb, masks = self.embed(xs, masks.astype(xs.dtype), offset=0)
xs, pos_emb, masks = self.embed(
xs, masks.astype(xs.dtype), offset=0)
else:
xs = self.embed(xs)
#TODO(Hui Zhang): remove mask.astype, stride_slice not support bool tensor
......
......@@ -60,8 +60,8 @@ class LinearNoSubsampling(BaseSubsampling):
self.out = nn.Sequential(
nn.Linear(idim, odim),
nn.LayerNorm(odim, epsilon=1e-12),
nn.Dropout(dropout_rate),
nn.ReLU(),)
nn.Dropout(dropout_rate),
nn.ReLU(), )
self.right_context = 0
self.subsampling_rate = 1
......@@ -83,10 +83,12 @@ class LinearNoSubsampling(BaseSubsampling):
x, pos_emb = self.pos_enc(x, offset)
return x, pos_emb, x_mask
class Conv2dSubsampling(BaseSubsampling):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class Conv2dSubsampling4(Conv2dSubsampling):
"""Convolutional 2D subsampling (to 1/4 length)."""
......
import argparse
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
import shutil
import tempfile
import numpy as np
import numpy as np
from . import extension
from ..updaters.trainer import Trainer
class PlotAttentionReport(extension.Extension):
......@@ -37,20 +44,19 @@ class PlotAttentionReport(extension.Extension):
"""
def __init__(
self,
att_vis_fn,
data,
outdir,
converter,
transform,
device,
reverse=False,
ikey="input",
iaxis=0,
okey="output",
oaxis=0,
subsampling_factor=1,
):
self,
att_vis_fn,
data,
outdir,
converter,
transform,
device,
reverse=False,
ikey="input",
iaxis=0,
okey="output",
oaxis=0,
subsampling_factor=1, ):
self.att_vis_fn = att_vis_fn
self.data = copy.deepcopy(data)
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
......@@ -77,44 +83,30 @@ class PlotAttentionReport(extension.Extension):
for i in range(num_encs):
for idx, att_w in enumerate(att_ws[i]):
filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
self.outdir,
uttid_list[idx],
i + 1,
)
self.outdir, uttid_list[idx], i + 1, )
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
self.outdir,
uttid_list[idx],
i + 1,
)
self.outdir, uttid_list[idx], i + 1, )
np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(att_w, filename.format(trainer))
self._plot_and_save_attention(att_w,
filename.format(trainer))
# han
for idx, att_w in enumerate(att_ws[num_encs]):
filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
self.outdir,
uttid_list[idx],
)
self.outdir, uttid_list[idx], )
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
self.outdir,
uttid_list[idx],
)
self.outdir, uttid_list[idx], )
np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(
att_w, filename.format(trainer), han_mode=True
)
att_w, filename.format(trainer), han_mode=True)
else:
for idx, att_w in enumerate(att_ws):
filename = "%s/%s.ep.{.updater.epoch}.png" % (
self.outdir,
uttid_list[idx],
)
filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
uttid_list[idx], )
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
self.outdir,
uttid_list[idx],
)
self.outdir, uttid_list[idx], )
np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(att_w, filename.format(trainer))
......@@ -131,8 +123,7 @@ class PlotAttentionReport(extension.Extension):
logger.add_figure(
"%s_att%d" % (uttid_list[idx], i + 1),
plot.gcf(),
step,
)
step, )
# han
for idx, att_w in enumerate(att_ws[num_encs]):
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
......@@ -140,8 +131,7 @@ class PlotAttentionReport(extension.Extension):
logger.add_figure(
"%s_han" % (uttid_list[idx]),
plot.gcf(),
step,
)
step, )
else:
for idx, att_w in enumerate(att_ws):
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
......@@ -286,20 +276,19 @@ class PlotCTCReport(extension.Extension):
"""
def __init__(
self,
ctc_vis_fn,
data,
outdir,
converter,
transform,
device,
reverse=False,
ikey="input",
iaxis=0,
okey="output",
oaxis=0,
subsampling_factor=1,
):
self,
ctc_vis_fn,
data,
outdir,
converter,
transform,
device,
reverse=False,
ikey="input",
iaxis=0,
okey="output",
oaxis=0,
subsampling_factor=1, ):
self.ctc_vis_fn = ctc_vis_fn
self.data = copy.deepcopy(data)
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
......@@ -325,29 +314,19 @@ class PlotCTCReport(extension.Extension):
for i in range(num_encs):
for idx, ctc_prob in enumerate(ctc_probs[i]):
filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
self.outdir,
uttid_list[idx],
i + 1,
)
self.outdir, uttid_list[idx], i + 1, )
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
self.outdir,
uttid_list[idx],
i + 1,
)
self.outdir, uttid_list[idx], i + 1, )
np.save(np_filename.format(trainer), ctc_prob)
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
else:
for idx, ctc_prob in enumerate(ctc_probs):
filename = "%s/%s.ep.{.updater.epoch}.png" % (
self.outdir,
uttid_list[idx],
)
filename = "%s/%s.ep.{.updater.epoch}.png" % (self.outdir,
uttid_list[idx], )
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
self.outdir,
uttid_list[idx],
)
self.outdir, uttid_list[idx], )
np.save(np_filename.format(trainer), ctc_prob)
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
......@@ -363,8 +342,7 @@ class PlotCTCReport(extension.Extension):
logger.add_figure(
"%s_ctc%d" % (uttid_list[idx], i + 1),
plot.gcf(),
step,
)
step, )
else:
for idx, ctc_prob in enumerate(ctc_probs):
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
......@@ -420,8 +398,11 @@ class PlotCTCReport(extension.Extension):
for idx in set(topk_ids.reshape(-1).tolist()):
if idx == 0:
plt.plot(
times_probs, ctc_prob[:, 0], ":", label="<blank>", color="grey"
)
times_probs,
ctc_prob[:, 0],
":",
label="<blank>",
color="grey")
else:
plt.plot(times_probs, ctc_prob[:, idx])
plt.xlabel(u"Input [frame]", fontsize=12)
......@@ -434,4 +415,4 @@ class PlotCTCReport(extension.Extension):
def _plot_and_save_ctc(self, ctc_prob, filename):
plt = self.draw_ctc_plot(ctc_prob)
plt.savefig(filename)
plt.close()
\ No newline at end of file
plt.close()
......@@ -36,4 +36,4 @@ class VisualDL(extension.Extension):
self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)
def finalize(self, trainer):
self.writer.close()
\ No newline at end of file
self.writer.close()
......@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils import get_trigger
from ..reporter import DictSummary
from .utils import get_trigger
class CompareValueTrigger():
"""Trigger invoked when key value getting bigger or lower than before.
......@@ -24,6 +24,7 @@ class CompareValueTrigger():
trigger (tuple(int, str)) : Trigger that decide the comparison interval.
"""
def __init__(self, key, compare_fn, trigger=(1, "epoch")):
self._key = key
self._best_value = None
......@@ -57,4 +58,4 @@ class CompareValueTrigger():
return False
def _init_summary(self):
self._summary = DictSummary()
\ No newline at end of file
self._summary = DictSummary()
......@@ -28,4 +28,4 @@ class LimitTrigger():
state = trainer.updater.state
index = getattr(state, self.unit)
fire = index >= self.limit
return fire
\ No newline at end of file
return fire
......@@ -38,4 +38,4 @@ class TimeTrigger():
return state_dict
def set_state_dict(self, state_dict):
self._next_time = state_dict['next_time']
\ No newline at end of file
self._next_time = state_dict['next_time']
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .interval_trigger import IntervalTrigger
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import numpy as np
__all__ = ["label_smoothing_dist"]
......@@ -33,8 +34,7 @@ def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
if lsm_type == "unigram":
assert transcript is not None, (
"transcript is required for %s label smoothing" % lsm_type
)
"transcript is required for %s label smoothing" % lsm_type)
labelcount = np.zeros(odim)
for k, v in trans_json.items():
ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
......
......@@ -14,9 +14,9 @@
"""This module provides functions to calculate bleu score in different level.
e.g. wer for word-level, cer for char-level.
"""
import sacrebleu
import nltk
import numpy as np
import sacrebleu
__all__ = ['bleu', 'char_bleu', "ErrorCalculator"]
......@@ -106,11 +106,14 @@ class ErrorCalculator():
# NOTE: padding index (-1) in y_true is used to pad y_hat
# because y_hats is not padded with -1
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.space, " ")
seq_hat_text = seq_hat_text.replace(self.pad, "")
seq_true_text = "".join(seq_true).replace(self.space, " ")
seqs_hat.append(seq_hat_text)
seqs_true.append(seq_true_text)
bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true], seqs_hat)
bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true],
seqs_hat)
return bleu * 100
......@@ -14,11 +14,10 @@
"""This module provides functions to calculate error rate in different level.
e.g. wer for word-level, cer for char-level.
"""
from itertools import groupby
import editdistance
import numpy as np
import logging
import sys
from itertools import groupby
__all__ = ['word_errors', 'char_errors', 'wer', 'cer', "ErrorCalculator"]
......@@ -225,9 +224,12 @@ class ErrorCalculator():
:return:
"""
def __init__(
self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
):
def __init__(self,
char_list,
sym_space,
sym_blank,
report_cer=False,
report_wer=False):
"""Construct an ErrorCalculator object."""
super().__init__()
......@@ -317,7 +319,9 @@ class ErrorCalculator():
ymax = eos_true[0] if len(eos_true) > 0 else len(y_true)
# NOTE: padding index (-1) in y_true is used to pad y_hat
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]]
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
seq_true = [
self.char_list[int(idx)] for idx in y_true if int(idx) != -1
]
seq_hat_text = "".join(seq_hat).replace(self.space, " ")
seq_hat_text = seq_hat_text.replace(self.blank, "")
seq_true_text = "".join(seq_true).replace(self.space, " ")
......
......@@ -15,7 +15,6 @@ import contextlib
import inspect
import io
import os
import re
import subprocess as sp
import sys
from pathlib import Path
......@@ -84,7 +83,7 @@ def _post_install(install_lib_dir):
tools_extrs_dir = HERE / 'tools/extras'
with pushd(tools_extrs_dir):
print(os.getcwd())
check_call(f"./install_autolog.sh")
check_call("./install_autolog.sh")
print("autolog install.")
# ctcdecoder
......
......@@ -4,7 +4,6 @@
# 2018 Xuankai Chang (Shanghai Jiao Tong University)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import json
import logging
import sys
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册