未验证 提交 50df0170 编写于 作者: L limingshu 提交者: GitHub

first commit (#51683)

上级 0b1a8a83
......@@ -86,7 +86,8 @@ template <typename IndexType = int>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType
FlatTensorIndex(const Index3& index, const Dim3& dims) {
IndexType flat_index = index[0];
-for (int i = 1; i < 3; i++) {
+#pragma unroll
+for (int i = 1; i < 3; ++i) {
flat_index = flat_index * dims[i] + index[i];
}
return flat_index;
......@@ -97,7 +98,8 @@ template <typename IndexType = int>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index3
ConvertTensorIndex(IndexType index, const Dim3& dims) {
Index3 tensor_index;
-for (int i = 2; i >= 0; i--) {
+#pragma unroll
+for (int i = 2; i >= 0; --i) {
IndexType new_index = index / dims[i];
tensor_index[i] = static_cast<int>(index - dims[i] * new_index);
index = new_index;
......
......@@ -153,7 +153,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
if (x < in_effective_thread_num) {
// Read a tile from input using block.
int x_i = x / TileY;
-int x_j = x % TileY;
+int x_j = x - x_i * TileY;
IndexType input_ind =
input_origin_block_flat_index + x_i * input_dims[2] + x_j;
IndexType input_inc = BlockReadRows * input_dims[2];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册