提交 660efceb 编写于 作者: H Hui Zhang

format code

上级 2112ee1e
......@@ -24,7 +24,6 @@ from typing import Optional
from typing import Tuple
import paddle
from paddle import jit
from paddle import nn
from yacs.config import CfgNode
......@@ -48,6 +47,7 @@ from deepspeech.utils.tensor_utils import add_sos_eos
from deepspeech.utils.tensor_utils import pad_sequence
from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.utility import log_add
# from paddle import jit
__all__ = ["U2Model", "U2InferModel"]
......
......@@ -175,7 +175,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
(x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
x_padded = paddle.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2])
x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
x.shape[2])
x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]
if zero_triu:
......
......@@ -16,8 +16,8 @@ import json
import os
import re
from pathlib import Path
from typing import Union
from typing import Text
from typing import Union
import paddle
from paddle import distributed as dist
......@@ -72,13 +72,13 @@ class Checkpoint():
if isinstance(tag_or_iteration, int):
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
def load_parameters(self,
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None,
record_file="checkpoint_latest"):
model,
optimizer=None,
checkpoint_dir=None,
checkpoint_path=None,
record_file="checkpoint_latest"):
"""Load a last model checkpoint from disk.
Args:
model (Layer): model to load parameters.
......
......@@ -151,8 +151,9 @@ def th_accuracy(pad_outputs: paddle.Tensor,
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(
pad_targets.shape[0], pad_targets.size(1), pad_outputs.size(1)).argmax(2)
pad_pred = pad_outputs.view(pad_targets.shape[0],
pad_targets.size(1),
pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label
#TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册