diff --git a/PaddleNLP/paddlenlp/data/collate.py b/PaddleNLP/paddlenlp/data/collate.py index ed5bc65dc0e94e80c4d69d758a71237d494405ee..15b4fc8f82f9488aceca6222c8bb2c55580998be 100644 --- a/PaddleNLP/paddlenlp/data/collate.py +++ b/PaddleNLP/paddlenlp/data/collate.py @@ -81,6 +81,9 @@ class Pad(object): data type of returned length. Default: False. dtype (numpy.dtype, optional): The value type of the output. If it is set to None, the input data type is used. Default: None. + pad_right (bool, optional): Boolean argument indicating whether the + padding direction is right-side. If True, it indicates we pad to the right side, + while False indicates we pad to the left side. Default: True. Example: .. code-block:: python from paddle.incubate.hapi.text.data_utils import Pad @@ -96,11 +99,12 @@ class Pad(object): ''' """ - def __init__(self, pad_val=0, axis=0, ret_length=None, dtype=None): + def __init__(self, pad_val=0, axis=0, ret_length=None, dtype=None, pad_right=True): self._pad_val = pad_val self._axis = axis self._ret_length = ret_length self._dtype = dtype + self._pad_right = pad_right def __call__(self, data): """ @@ -132,7 +136,11 @@ class Pad(object): ret[i] = arr else: slices = [slice(None) for _ in range(arr.ndim)] - slices[self._axis] = slice(0, arr.shape[self._axis]) + if self._pad_right: + slices[self._axis] = slice(0, arr.shape[self._axis]) + else: + slices[self._axis] = slice(max_size - arr.shape[self._axis], max_size) + if slices[self._axis].start != slices[self._axis].stop: slices = [slice(i, i + 1)] + slices ret[tuple(slices)] = arr