未验证 提交 53208f52 编写于 作者: Y yingyibiao 提交者: GitHub

Support padding direction selection for collate (#5263)

add an optional boolean argument pad_rght to Pad class to support padding direction selection
上级 a1d2fd54
...@@ -81,6 +81,9 @@ class Pad(object): ...@@ -81,6 +81,9 @@ class Pad(object):
data type of returned length. Default: False. data type of returned length. Default: False.
dtype (numpy.dtype, optional): The value type of the output. If it is dtype (numpy.dtype, optional): The value type of the output. If it is
set to None, the input data type is used. Default: None. set to None, the input data type is used. Default: None.
pad_right (bool, optional): Boolean argument indicating whether the
padding direction is right-side. If True, it indicates we pad to the right side,
while False indicates we pad to the left side. Default: True.
Example: Example:
.. code-block:: python .. code-block:: python
from paddle.incubate.hapi.text.data_utils import Pad from paddle.incubate.hapi.text.data_utils import Pad
...@@ -96,11 +99,12 @@ class Pad(object): ...@@ -96,11 +99,12 @@ class Pad(object):
''' '''
""" """
def __init__(self, pad_val=0, axis=0, ret_length=None, dtype=None): def __init__(self, pad_val=0, axis=0, ret_length=None, dtype=None, pad_right=True):
self._pad_val = pad_val self._pad_val = pad_val
self._axis = axis self._axis = axis
self._ret_length = ret_length self._ret_length = ret_length
self._dtype = dtype self._dtype = dtype
self._pad_right = pad_right
def __call__(self, data): def __call__(self, data):
""" """
...@@ -132,7 +136,11 @@ class Pad(object): ...@@ -132,7 +136,11 @@ class Pad(object):
ret[i] = arr ret[i] = arr
else: else:
slices = [slice(None) for _ in range(arr.ndim)] slices = [slice(None) for _ in range(arr.ndim)]
slices[self._axis] = slice(0, arr.shape[self._axis]) if self._pad_right:
slices[self._axis] = slice(0, arr.shape[self._axis])
else:
slices[self._axis] = slice(max_size - arr.shape[self._axis], max_size)
if slices[self._axis].start != slices[self._axis].stop: if slices[self._axis].start != slices[self._axis].stop:
slices = [slice(i, i + 1)] + slices slices = [slice(i, i + 1)] + slices
ret[tuple(slices)] = arr ret[tuple(slices)] = arr
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册