提交 bc6da7a1 编写于 作者: H Hui Zhang

not hack size since it exists

上级 df1d44f5
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typeing import Union
from typeing import Any
from typing import Union
from typing import Any
import paddle
from paddle import nn
......@@ -21,6 +21,7 @@ from paddle.nn import functional as F
from paddle.nn import initializer as I
logger = logging.getLogger(__name__)
logger.warn = logging.warning
# TODO(Hui Zhang): remove this hack
paddle.bool = 'bool'
......@@ -52,11 +53,10 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
return s
if not hasattr(paddle.Tensor, 'size'):
logger.warn(
"override size of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.size = size
# logger.warn(
# "override size of paddle.Tensor if exists or register, remove this when fixed!"
# )
# paddle.Tensor.size = size
def masked_fill(xs: paddle.Tensor,
......
......@@ -272,6 +272,6 @@ def mask_finished_preds(pred: paddle.Tensor, flag: paddle.Tensor,
Returns:
paddle.Tensor: (batch_size * beam_size).
"""
beam_size = pred.size(-1)
finished = flag.repeat([1, beam_size])
beam_size = pred.shape[-1]
finished = flag.repeat(1, beam_size)
return pred.masked_fill_(finished, eos)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册