提交 1dfca4ef 编写于 作者: T tianhao zhang

fix multigpu training

上级 ed80b0e2
...@@ -21,10 +21,10 @@ import paddle ...@@ -21,10 +21,10 @@ import paddle
from numpy import float32 from numpy import float32
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.cli.asr.infer import ASRExecutor from paddlespeech.cli.asr.infer import ASRExecutor
from paddlespeech.cli.log import logger from paddlespeech.cli.log import logger
from paddlespeech.resource import CommonTaskResource from paddlespeech.resource import CommonTaskResource
from paddlespeech.audio.transform.transformation import Transformation
from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.modules.ctc import CTCDecoder from paddlespeech.s2t.modules.ctc import CTCDecoder
from paddlespeech.s2t.utils.tensor_utils import add_sos_eos from paddlespeech.s2t.utils.tensor_utils import add_sos_eos
...@@ -130,8 +130,8 @@ class PaddleASRConnectionHanddler: ...@@ -130,8 +130,8 @@ class PaddleASRConnectionHanddler:
## conformer ## conformer
# cache for conformer online # cache for conformer online
self.att_cache = paddle.zeros([0,0,0,0]) self.att_cache = paddle.zeros([0, 0, 0, 0])
self.cnn_cache = paddle.zeros([0,0,0,0]) self.cnn_cache = paddle.zeros([0, 0, 0, 0])
self.encoder_out = None self.encoder_out = None
# conformer decoding state # conformer decoding state
...@@ -474,9 +474,10 @@ class PaddleASRConnectionHanddler: ...@@ -474,9 +474,10 @@ class PaddleASRConnectionHanddler:
# cur chunk # cur chunk
chunk_xs = self.cached_feat[:, cur:end, :] chunk_xs = self.cached_feat[:, cur:end, :]
# forward chunk # forward chunk
(y, self.att_cache, self.cnn_cache) = self.model.encoder.forward_chunk( (y, self.att_cache,
chunk_xs, self.offset, required_cache_size, self.cnn_cache) = self.model.encoder.forward_chunk(
self.att_cache, self.cnn_cache) chunk_xs, self.offset, required_cache_size, self.att_cache,
self.cnn_cache, paddle.ones([0, 0, 0], dtype=paddle.bool))
outputs.append(y) outputs.append(y)
# update the global offset, in decoding frame unit # update the global offset, in decoding frame unit
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册