未验证 提交 50df0170 编写于 作者: L limingshu 提交者: GitHub

first commit (#51683)

上级 0b1a8a83
...@@ -86,7 +86,8 @@ template <typename IndexType = int> ...@@ -86,7 +86,8 @@ template <typename IndexType = int>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType
FlatTensorIndex(const Index3& index, const Dim3& dims) { FlatTensorIndex(const Index3& index, const Dim3& dims) {
IndexType flat_index = index[0]; IndexType flat_index = index[0];
for (int i = 1; i < 3; i++) { #pragma unroll
for (int i = 1; i < 3; ++i) {
flat_index = flat_index * dims[i] + index[i]; flat_index = flat_index * dims[i] + index[i];
} }
return flat_index; return flat_index;
...@@ -97,7 +98,8 @@ template <typename IndexType = int> ...@@ -97,7 +98,8 @@ template <typename IndexType = int>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3
ConvertTensorIndex(IndexType index, const Dim3& dims) { ConvertTensorIndex(IndexType index, const Dim3& dims) {
Index3 tensor_index; Index3 tensor_index;
for (int i = 2; i >= 0; i--) { #pragma unroll
for (int i = 2; i >= 0; --i) {
IndexType new_index = index / dims[i]; IndexType new_index = index / dims[i];
tensor_index[i] = static_cast<int>(index - dims[i] * new_index); tensor_index[i] = static_cast<int>(index - dims[i] * new_index);
index = new_index; index = new_index;
......
...@@ -153,7 +153,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input, ...@@ -153,7 +153,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
if (x < in_effective_thread_num) { if (x < in_effective_thread_num) {
// Read a tile from input using block. // Read a tile from input using block.
int x_i = x / TileY; int x_i = x / TileY;
int x_j = x % TileY; int x_j = x - x_i * TileY;
IndexType input_ind = IndexType input_ind =
input_origin_block_flat_index + x_i * input_dims[2] + x_j; input_origin_block_flat_index + x_i * input_dims[2] + x_j;
IndexType input_inc = BlockReadRows * input_dims[2]; IndexType input_inc = BlockReadRows * input_dims[2];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册