From 660efcebbe09d2812a0e736872083e9b4700c731 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 5 Jul 2021 04:09:55 +0000 Subject: [PATCH] format code --- deepspeech/models/u2.py | 2 +- deepspeech/modules/attention.py | 3 ++- deepspeech/utils/checkpoint.py | 14 +++++++------- deepspeech/utils/tensor_utils.py | 5 +++-- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/deepspeech/models/u2.py b/deepspeech/models/u2.py index a9f37833..c3d93d8a 100644 --- a/deepspeech/models/u2.py +++ b/deepspeech/models/u2.py @@ -24,7 +24,6 @@ from typing import Optional from typing import Tuple import paddle -from paddle import jit from paddle import nn from yacs.config import CfgNode @@ -48,6 +47,7 @@ from deepspeech.utils.tensor_utils import add_sos_eos from deepspeech.utils.tensor_utils import pad_sequence from deepspeech.utils.tensor_utils import th_accuracy from deepspeech.utils.utility import log_add +# from paddle import jit __all__ = ["U2Model", "U2InferModel"] diff --git a/deepspeech/modules/attention.py b/deepspeech/modules/attention.py index ac1e9b75..afc70214 100644 --- a/deepspeech/modules/attention.py +++ b/deepspeech/modules/attention.py @@ -175,7 +175,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype) x_padded = paddle.cat([zero_pad, x], dim=-1) - x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]) + x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1, + x.shape[2]) x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] if zero_triu: diff --git a/deepspeech/utils/checkpoint.py b/deepspeech/utils/checkpoint.py index a2f7e18a..3bb04e7d 100644 --- a/deepspeech/utils/checkpoint.py +++ b/deepspeech/utils/checkpoint.py @@ -16,8 +16,8 @@ import json import os import re from pathlib import Path -from typing import Union from typing import Text +from typing import Union import paddle from paddle import distributed as dist @@ -72,13 +72,13 @@ class Checkpoint(): if isinstance(tag_or_iteration, int): self._save_checkpoint_record(checkpoint_dir, tag_or_iteration) - + def load_parameters(self, - model, - optimizer=None, - checkpoint_dir=None, - checkpoint_path=None, - record_file="checkpoint_latest"): + model, + optimizer=None, + checkpoint_dir=None, + checkpoint_path=None, + record_file="checkpoint_latest"): """Load a last model checkpoint from disk. Args: model (Layer): model to load parameters. diff --git a/deepspeech/utils/tensor_utils.py b/deepspeech/utils/tensor_utils.py index 2e605999..17becf6d 100644 --- a/deepspeech/utils/tensor_utils.py +++ b/deepspeech/utils/tensor_utils.py @@ -151,8 +151,9 @@ def th_accuracy(pad_outputs: paddle.Tensor, Returns: float: Accuracy value (0.0 - 1.0). """ - pad_pred = pad_outputs.view( - pad_targets.shape[0], pad_targets.size(1), pad_outputs.size(1)).argmax(2) + pad_pred = pad_outputs.view(pad_targets.shape[0], + pad_targets.size(1), + pad_outputs.size(1)).argmax(2) mask = pad_targets != ignore_label #TODO(Hui Zhang): sum not support bool type # numerator = paddle.sum( -- GitLab