提交 ee90c2d2 编写于 作者: F fengjiayi

add slice_dim draft

上级 8bcd1faf
...@@ -401,5 +401,20 @@ HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) { ...@@ -401,5 +401,20 @@ HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) {
return result; return result;
} }
template <int D, int S>
Dim<D> slice(const Dim<S>& dim, int begin, int end) {
PADDLE_ENFORCE(begin < end,
"Begin index must be less than end index in Dim slice.");
PADDLE_ENFORCE(begin >= 0 && end <= S && end - begin == D,
"Index error occurs in Dim slice.");
if (begin > 0) {
return slice<D>(dim.tail, begin - 1, end - 1);
}
if (D == 1) {
return Dim<1>(dim.head);
}
return Dim<D>(dim.head, slice<D - 1>(dim.tail, 0, end - 1));
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册