From 53208f52ff1927db6eb27e2101d49ff952640d77 Mon Sep 17 00:00:00 2001 From: yingyibiao <52000032+yingyibiao@users.noreply.github.com> Date: Tue, 2 Feb 2021 21:07:13 +0800 Subject: [PATCH] Support padding direction selection for collate (#5263) add an optional boolean argument pad_right to Pad class to support padding direction selection --- PaddleNLP/paddlenlp/data/collate.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/PaddleNLP/paddlenlp/data/collate.py b/PaddleNLP/paddlenlp/data/collate.py index ed5bc65d..15b4fc8f 100644 --- a/PaddleNLP/paddlenlp/data/collate.py +++ b/PaddleNLP/paddlenlp/data/collate.py @@ -81,6 +81,9 @@ class Pad(object): data type of returned length. Default: False. dtype (numpy.dtype, optional): The value type of the output. If it is set to None, the input data type is used. Default: None. + pad_right (bool, optional): Boolean argument indicating whether the + padding direction is right-side. If True, it indicates we pad to the right side, + while False indicates we pad to the left side. Default: True. Example: .. code-block:: python from paddle.incubate.hapi.text.data_utils import Pad @@ -96,11 +99,12 @@ class Pad(object): ''' """ - def __init__(self, pad_val=0, axis=0, ret_length=None, dtype=None): + def __init__(self, pad_val=0, axis=0, ret_length=None, dtype=None, pad_right=True): self._pad_val = pad_val self._axis = axis self._ret_length = ret_length self._dtype = dtype + self._pad_right = pad_right def __call__(self, data): """ @@ -132,7 +136,11 @@ class Pad(object): ret[i] = arr else: slices = [slice(None) for _ in range(arr.ndim)] - slices[self._axis] = slice(0, arr.shape[self._axis]) + if self._pad_right: + slices[self._axis] = slice(0, arr.shape[self._axis]) + else: + slices[self._axis] = slice(max_size - arr.shape[self._axis], max_size) + if slices[self._axis].start != slices[self._axis].stop: slices = [slice(i, i + 1)] + slices ret[tuple(slices)] = arr -- GitLab