提交 bc6da7a1 编写于 作者: H Hui Zhang

not hack size since it exists

上级 df1d44f5
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typeing import Union from typing import Union
from typeing import Any from typing import Any
import paddle import paddle
from paddle import nn from paddle import nn
...@@ -21,6 +21,7 @@ from paddle.nn import functional as F ...@@ -21,6 +21,7 @@ from paddle.nn import functional as F
from paddle.nn import initializer as I from paddle.nn import initializer as I
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.warn = logging.warning
# TODO(Hui Zhang): remove this hack # TODO(Hui Zhang): remove this hack
paddle.bool = 'bool' paddle.bool = 'bool'
...@@ -52,11 +53,10 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor: ...@@ -52,11 +53,10 @@ def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
return s return s
if not hasattr(paddle.Tensor, 'size'): # logger.warn(
logger.warn( # "override size of paddle.Tensor if exists or register, remove this when fixed!"
"override size of paddle.Tensor if exists or register, remove this when fixed!" # )
) # paddle.Tensor.size = size
paddle.Tensor.size = size
def masked_fill(xs: paddle.Tensor, def masked_fill(xs: paddle.Tensor,
......
...@@ -272,6 +272,6 @@ def mask_finished_preds(pred: paddle.Tensor, flag: paddle.Tensor, ...@@ -272,6 +272,6 @@ def mask_finished_preds(pred: paddle.Tensor, flag: paddle.Tensor,
Returns: Returns:
paddle.Tensor: (batch_size * beam_size). paddle.Tensor: (batch_size * beam_size).
""" """
beam_size = pred.size(-1) beam_size = pred.shape[-1]
finished = flag.repeat([1, beam_size]) finished = flag.repeat(1, beam_size)
return pred.masked_fill_(finished, eos) return pred.masked_fill_(finished, eos)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册