提交 660efceb 编写于 作者: H Hui Zhang

format code

上级 2112ee1e
...@@ -24,7 +24,6 @@ from typing import Optional ...@@ -24,7 +24,6 @@ from typing import Optional
from typing import Tuple from typing import Tuple
import paddle import paddle
from paddle import jit
from paddle import nn from paddle import nn
from yacs.config import CfgNode from yacs.config import CfgNode
...@@ -48,6 +47,7 @@ from deepspeech.utils.tensor_utils import add_sos_eos ...@@ -48,6 +47,7 @@ from deepspeech.utils.tensor_utils import add_sos_eos
from deepspeech.utils.tensor_utils import pad_sequence from deepspeech.utils.tensor_utils import pad_sequence
from deepspeech.utils.tensor_utils import th_accuracy from deepspeech.utils.tensor_utils import th_accuracy
from deepspeech.utils.utility import log_add from deepspeech.utils.utility import log_add
# from paddle import jit
__all__ = ["U2Model", "U2InferModel"] __all__ = ["U2Model", "U2InferModel"]
......
...@@ -175,7 +175,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention): ...@@ -175,7 +175,8 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
(x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype) (x.shape[0], x.shape[1], x.shape[2], 1), dtype=x.dtype)
x_padded = paddle.cat([zero_pad, x], dim=-1) x_padded = paddle.cat([zero_pad, x], dim=-1)
x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1, x.shape[2]) x_padded = x_padded.view(x.shape[0], x.shape[1], x.shape[3] + 1,
x.shape[2])
x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1] x = x_padded[:, :, 1:].view_as(x) # [B, H, T1, T1]
if zero_triu: if zero_triu:
......
...@@ -16,8 +16,8 @@ import json ...@@ -16,8 +16,8 @@ import json
import os import os
import re import re
from pathlib import Path from pathlib import Path
from typing import Union
from typing import Text from typing import Text
from typing import Union
import paddle import paddle
from paddle import distributed as dist from paddle import distributed as dist
...@@ -72,13 +72,13 @@ class Checkpoint(): ...@@ -72,13 +72,13 @@ class Checkpoint():
if isinstance(tag_or_iteration, int): if isinstance(tag_or_iteration, int):
self._save_checkpoint_record(checkpoint_dir, tag_or_iteration) self._save_checkpoint_record(checkpoint_dir, tag_or_iteration)
def load_parameters(self, def load_parameters(self,
model, model,
optimizer=None, optimizer=None,
checkpoint_dir=None, checkpoint_dir=None,
checkpoint_path=None, checkpoint_path=None,
record_file="checkpoint_latest"): record_file="checkpoint_latest"):
"""Load a last model checkpoint from disk. """Load a last model checkpoint from disk.
Args: Args:
model (Layer): model to load parameters. model (Layer): model to load parameters.
......
...@@ -151,8 +151,9 @@ def th_accuracy(pad_outputs: paddle.Tensor, ...@@ -151,8 +151,9 @@ def th_accuracy(pad_outputs: paddle.Tensor,
Returns: Returns:
float: Accuracy value (0.0 - 1.0). float: Accuracy value (0.0 - 1.0).
""" """
pad_pred = pad_outputs.view( pad_pred = pad_outputs.view(pad_targets.shape[0],
pad_targets.shape[0], pad_targets.size(1), pad_outputs.size(1)).argmax(2) pad_targets.size(1),
pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label mask = pad_targets != ignore_label
#TODO(Hui Zhang): sum not support bool type #TODO(Hui Zhang): sum not support bool type
# numerator = paddle.sum( # numerator = paddle.sum(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册