diff --git a/paddle/framework/dim.h b/paddle/framework/dim.h index 883fdc55eb929ebc51e8ae05938e9d07374406ce..8dc1bab06df21ff3d77de5cd34adbdbb0df62ab6 100644 --- a/paddle/framework/dim.h +++ b/paddle/framework/dim.h @@ -401,5 +401,20 @@ HOSTDEVICE Dim linear_to_dimension(int linear_index, Dim extents) { return result; } +template +Dim slice(const Dim& dim, int begin, int end) { + PADDLE_ENFORCE(begin < end, + "Begin index must be less than end index in Dim slice."); + PADDLE_ENFORCE(begin >= 0 && end <= S && end - begin == D, + "Index error occurs in Dim slice."); + if (begin > 0) { + return slice(dim.tail, begin - 1, end - 1); + } + if (D == 1) { + return Dim<1>(dim.head); + } + return Dim(dim.head, slice(dim.tail, 0, end - 1)); +} + } // namespace framework } // namespace paddle