Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d9a9e638
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d9a9e638
编写于
9月 07, 2022
作者:
Y
Yuang Liu
提交者:
GitHub
9月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[alphafold] Transpose support large tensors where there numel is bigger than INT32_MAX (#45753)
上级
0ddcf30c
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
116 addition
and
86 deletion
+116
-86
paddle/fluid/framework/gpu_utils.h
paddle/fluid/framework/gpu_utils.h
+8
-6
paddle/fluid/operators/transpose_op.cu.h
paddle/fluid/operators/transpose_op.cu.h
+108
-80
未找到文件。
paddle/fluid/framework/gpu_utils.h
浏览文件 @
d9a9e638
...
...
@@ -82,9 +82,10 @@ struct Index3 : DeviceArray<int, 3, 0> {
};
// Flat index with real dimension
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE
int
FlatTensorIndex
(
const
Index3
&
index
,
const
Dim3
&
dims
)
{
int
flat_index
=
index
[
0
];
template
<
typename
IDX_T
=
int
>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE
IDX_T
FlatTensorIndex
(
const
Index3
&
index
,
const
Dim3
&
dims
)
{
IDX_T
flat_index
=
index
[
0
];
for
(
int
i
=
1
;
i
<
3
;
i
++
)
{
flat_index
=
flat_index
*
dims
[
i
]
+
index
[
i
];
}
...
...
@@ -92,12 +93,13 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int FlatTensorIndex(const Index3& index,
}
// Convert index to tensor index with dimension.
template
<
typename
IDX_T
=
int
>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE
Index3
ConvertTensorIndex
(
int
index
,
const
Dim3
&
dims
)
{
ConvertTensorIndex
(
IDX_T
index
,
const
Dim3
&
dims
)
{
Index3
tensor_index
;
for
(
int
i
=
2
;
i
>=
0
;
i
--
)
{
int
new_index
=
index
/
dims
[
i
];
tensor_index
[
i
]
=
index
-
dims
[
i
]
*
new_index
;
IDX_T
new_index
=
index
/
dims
[
i
];
tensor_index
[
i
]
=
static_cast
<
int
>
(
index
-
dims
[
i
]
*
new_index
)
;
index
=
new_index
;
}
return
tensor_index
;
...
...
paddle/fluid/operators/transpose_op.cu.h
浏览文件 @
d9a9e638
...
...
@@ -79,7 +79,11 @@ constexpr bool CheckNonLongTileSize(int tile_long, int tile_short, int size_T) {
// Use SM to do data transfer, load a tile into SM then store out.
// All tile read and write are colascing, so can speedup memory copy
template
<
typename
T
,
int
NumThreads
,
int
TileX
,
int
TileY
>
template
<
typename
T
,
int
NumThreads
,
int
TileX
,
int
TileY
,
typename
IDX_T
=
int
>
__global__
void
TilingSwapDim1And2
(
const
T
*
__restrict__
input
,
Dim3
input_dims
,
T
*
__restrict__
output
)
{
...
...
@@ -116,7 +120,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
// Converts block idx to tile index, each block process a tile
Index3
input_block_tile_index
=
ConvertTensorIndex
(
blockIdx
.
x
,
tile_aligned_input_dim
);
framework
::
ConvertTensorIndex
<
IDX_T
>
(
blockIdx
.
x
,
tile_aligned_input_dim
);
// Compute real index align to tile:0, 32, 64...
Index3
block_tile_index_in_input
=
{
...
...
@@ -126,11 +130,11 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
};
// Compute block flat index against input dims.
int
input_origin_block_flat_index
=
FlatTensorIndex
(
block_tile_index_in_input
,
input_dims
);
IDX_T
input_origin_block_flat_index
=
framework
::
FlatTensorIndex
<
IDX_T
>
(
block_tile_index_in_input
,
input_dims
);
bool
full_tile
=
true
;
int
tile_width
=
TileY
;
IDX_T
tile_width
=
TileY
;
// Last row is not full.
if
(
input_block_tile_index
[
2
]
==
tile_aligned_input_dim
[
2
]
-
1
)
{
...
...
@@ -138,21 +142,21 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
full_tile
&=
false
;
}
int
tile_height
=
TileX
;
IDX_T
tile_height
=
TileX
;
if
(
input_block_tile_index
[
1
]
==
tile_aligned_input_dim
[
1
]
-
1
)
{
tile_height
=
input_dims
[
1
]
-
(
tile_aligned_input_dim
[
1
]
-
1
)
*
TileX
;
full_tile
&=
false
;
}
constexpr
int
in_effective_thread_num
=
NumThreads
/
TileY
*
TileY
;
constexpr
IDX_T
in_effective_thread_num
=
NumThreads
/
TileY
*
TileY
;
if
(
x
<
in_effective_thread_num
)
{
// Read a tile from input using block.
int
x_i
=
x
/
TileY
;
int
x_j
=
x
%
TileY
;
int
input_ind
=
input_origin_block_flat_index
+
x_i
*
input_dims
[
2
]
+
x_j
;
int
input_inc
=
BlockReadRows
*
input_dims
[
2
];
IDX_T
input_ind
=
input_origin_block_flat_index
+
x_i
*
input_dims
[
2
]
+
x_j
;
IDX_T
input_inc
=
BlockReadRows
*
input_dims
[
2
];
if
(
full_tile
)
{
#pragma unroll
...
...
@@ -163,7 +167,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
}
else
{
if
(
x_j
<
tile_width
)
{
#pragma unroll
for
(
int
ind_i
=
x_i
;
ind_i
<
(
tile_height
);
ind_i
+=
BlockReadRows
)
{
for
(
IDX_T
ind_i
=
x_i
;
ind_i
<
(
tile_height
);
ind_i
+=
BlockReadRows
)
{
tile_sm
[
ind_i
][
x_j
]
=
input
[
input_ind
];
input_ind
+=
input_inc
;
}
...
...
@@ -186,17 +190,17 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
output_block_tile_index
[
2
]
*
TileX
,
};
int
output_origin_block_flat_index
=
FlatTensorIndex
(
block_tile_index_in_output
,
output_dims
);
IDX_T
output_origin_block_flat_index
=
framework
::
FlatTensorIndex
<
IDX_T
>
(
block_tile_index_in_output
,
output_dims
);
constexpr
int
out_effective_thread_num
=
NumThreads
/
TileX
*
TileX
;
constexpr
IDX_T
out_effective_thread_num
=
NumThreads
/
TileX
*
TileX
;
if
(
x
<
out_effective_thread_num
)
{
int
x_i
=
x
/
TileX
;
int
x_j
=
x
%
TileX
;
int
output_ind
=
IDX_T
output_ind
=
output_origin_block_flat_index
+
x_i
*
output_dims
[
2
]
+
x_j
;
int
output_inc
=
BlockWriteRows
*
output_dims
[
2
];
IDX_T
output_inc
=
BlockWriteRows
*
output_dims
[
2
];
if
(
full_tile
)
{
#pragma unroll
...
...
@@ -207,7 +211,7 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
}
else
{
if
(
x_j
<
tile_height
)
{
#pragma unroll
for
(
int
ind_i
=
x_i
;
ind_i
<
(
tile_width
);
ind_i
+=
BlockWriteRows
)
{
for
(
IDX_T
ind_i
=
x_i
;
ind_i
<
(
tile_width
);
ind_i
+=
BlockWriteRows
)
{
output
[
output_ind
]
=
tile_sm
[
x_j
][
ind_i
];
output_ind
+=
output_inc
;
}
...
...
@@ -272,32 +276,36 @@ struct SystemElemType<16> {
using
type
=
float4
;
};
template
<
typename
T
,
int
tile_long
,
int
tile_short
>
template
<
typename
T
,
int
tile_long
,
int
tile_short
,
typename
IDX_T
=
int
>
void
LaunchNarrowDims2TransposeKernel
(
const
phi
::
GPUContext
&
d
,
int
tile_size_i
,
int
tile_size_j
,
int
total_tiles_count
,
IDX_T
total_tiles_count
,
const
T
*
input
,
const
Dim3
&
input_dims
,
T
*
output
)
{
constexpr
int
NumThreads
=
tile_long
;
if
(
tile_size_i
<=
tile_long
&&
tile_size_j
<=
tile_short
)
{
TilingSwapDim1And2
<
T
,
NumThreads
,
tile_long
,
tile_short
>
TilingSwapDim1And2
<
T
,
NumThreads
,
tile_long
,
tile_short
,
IDX_T
>
<<<
total_tiles_count
,
NumThreads
,
0
,
d
.
stream
()
>>>
(
input
,
input_dims
,
output
);
}
else
{
TilingSwapDim1And2
<
T
,
NumThreads
,
tile_short
,
tile_long
>
TilingSwapDim1And2
<
T
,
NumThreads
,
tile_short
,
tile_long
,
IDX_T
>
<<<
total_tiles_count
,
NumThreads
,
0
,
d
.
stream
()
>>>
(
input
,
input_dims
,
output
);
}
}
template
<
typename
T
,
int
tile_long
,
int
tile_short
,
typename
dummy
=
void
>
template
<
typename
T
,
int
tile_long
,
int
tile_short
,
typename
IDX_T
=
int
,
typename
dummy
=
void
>
struct
NarrowDims2TransposeDispatch
{
static
void
DoTranspose
(
const
phi
::
GPUContext
&
d
,
int
tile_size_i
,
int
tile_size_j
,
int
total_tiles_count
,
IDX_T
total_tiles_count
,
const
T
*
input
,
const
Dim3
&
input_dims
,
T
*
output
)
{
...
...
@@ -313,7 +321,7 @@ struct NarrowDims2TransposeDispatch {
std
::
min
(
tile_size_i
,
tile_size_j
)
<=
tile_short
;
if
(
request_satisfied
)
{
LaunchNarrowDims2TransposeKernel
<
T
,
tile_long
,
tile_short
>
(
LaunchNarrowDims2TransposeKernel
<
T
,
tile_long
,
tile_short
,
IDX_T
>
(
d
,
tile_size_i
,
tile_size_j
,
...
...
@@ -328,40 +336,41 @@ struct NarrowDims2TransposeDispatch {
std
::
max
(
tile_size_i
,
tile_size_j
)
>
tile_long
;
if
(
long_side_request_not_satisfied
)
{
NarrowDims2TransposeDispatch
<
T
,
tile_long
*
2
,
tile_short
>::
DoTranspose
(
d
,
tile_size_i
,
tile_size_j
,
total_tiles_count
,
input
,
input_dims
,
output
);
NarrowDims2TransposeDispatch
<
T
,
tile_long
*
2
,
tile_short
,
IDX_T
>::
DoTranspose
(
d
,
tile_size_i
,
tile_size_j
,
total_tiles_count
,
input
,
input_dims
,
output
);
}
else
{
NarrowDims2TransposeDispatch
<
T
,
tile_long
,
tile_short
+
1
>::
DoTranspose
(
d
,
tile_size_i
,
tile_size_j
,
total_tiles_count
,
input
,
input_dims
,
output
);
NarrowDims2TransposeDispatch
<
T
,
tile_long
,
tile_short
+
1
,
IDX_T
>::
DoTranspose
(
d
,
tile_size_i
,
tile_size_j
,
total_tiles_count
,
input
,
input_dims
,
output
);
}
}
};
// If Not long tile size, goto this function when compile.
template
<
typename
T
,
int
tile_long
,
int
tile_short
>
template
<
typename
T
,
int
tile_long
,
int
tile_short
,
typename
IDX_T
>
struct
NarrowDims2TransposeDispatch
<
T
,
tile_long
,
tile_short
,
IDX_T
,
typename
std
::
enable_if
<
CheckNonLongTileSize
(
tile_long
,
tile_short
,
sizeof
(
T
)),
void
>::
type
>
{
static
void
DoTranspose
(
const
phi
::
GPUContext
&
d
,
int
tile_size_i
,
int
tile_size_j
,
int
total_tiles_count
,
IDX_T
total_tiles_count
,
const
T
*
input
,
const
Dim3
&
input_dims
,
T
*
output
)
{
...
...
@@ -377,7 +386,7 @@ struct NarrowDims2TransposeDispatch<
std
::
min
(
tile_size_i
,
tile_size_j
)
<=
tile_short
;
if
(
request_satisfied
)
{
LaunchNarrowDims2TransposeKernel
<
T
,
tile_long
,
tile_short
>
(
LaunchNarrowDims2TransposeKernel
<
T
,
tile_long
,
tile_short
,
IDX_T
>
(
d
,
tile_size_i
,
tile_size_j
,
...
...
@@ -388,29 +397,30 @@ struct NarrowDims2TransposeDispatch<
return
;
}
NarrowDims2TransposeDispatch
<
T
,
tile_long
,
tile_short
+
1
>::
DoTranspose
(
d
,
tile_size_i
,
tile_size_j
,
total_tiles_count
,
input
,
input_dims
,
output
);
NarrowDims2TransposeDispatch
<
T
,
tile_long
,
tile_short
+
1
,
IDX_T
>::
DoTranspose
(
d
,
tile_size_i
,
tile_size_j
,
total_tiles_count
,
input
,
input_dims
,
output
);
}
};
// If long tile size, goto this function when compile.
template
<
typename
T
,
int
tile_long
,
int
tile_short
>
template
<
typename
T
,
int
tile_long
,
int
tile_short
,
typename
IDX_T
>
struct
NarrowDims2TransposeDispatch
<
T
,
tile_long
,
tile_short
,
IDX_T
,
typename
std
::
enable_if
<
CheckLongTileSize
(
tile_long
,
tile_short
,
sizeof
(
T
)),
void
>::
type
>
{
static
void
DoTranspose
(
const
phi
::
GPUContext
&
d
,
int
tile_size_i
,
int
tile_size_j
,
int
total_tiles_count
,
IDX_T
total_tiles_count
,
const
T
*
input
,
const
Dim3
&
input_dims
,
T
*
output
)
{
...
...
@@ -422,7 +432,7 @@ struct NarrowDims2TransposeDispatch<
" but received is:%d."
,
tile_long
));
LaunchNarrowDims2TransposeKernel
<
T
,
tile_long
,
tile_short
>
(
LaunchNarrowDims2TransposeKernel
<
T
,
tile_long
,
tile_short
,
IDX_T
>
(
d
,
tile_size_i
,
tile_size_j
,
...
...
@@ -433,7 +443,7 @@ struct NarrowDims2TransposeDispatch<
}
};
template
<
typename
T
,
bool
conjugate
=
false
>
template
<
typename
T
,
bool
conjugate
=
false
,
typename
IDX_T
=
int
>
void
SwapDim1And2InNarrow
(
const
phi
::
GPUContext
&
d
,
const
T
*
input
,
const
Dim3
&
input_dims
,
...
...
@@ -504,13 +514,14 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
framework
::
CeilOrFloor
<
int
,
true
>
(
input_dims
[
2
],
select_tile_size_j
),
};
int
total_tiles_count
=
input_dims_aligned
[
0
]
*
input_dims_aligned
[
1
]
*
input_dims_aligned
[
2
];
IDX_T
total_tiles_count
=
input_dims_aligned
[
0
];
total_tiles_count
*=
input_dims_aligned
[
1
];
total_tiles_count
*=
input_dims_aligned
[
2
];
// Suppose T can be replaced by system builtin types
using
ElemType
=
typename
SystemElemType
<
sizeof
(
T
)
>::
type
;
NarrowDims2TransposeDispatch
<
ElemType
,
32
,
2
>::
DoTranspose
(
NarrowDims2TransposeDispatch
<
ElemType
,
32
,
2
,
IDX_T
>::
DoTranspose
(
d
,
select_tile_size_i
,
select_tile_size_j
,
...
...
@@ -522,8 +533,8 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
// This is for case that cannot do coalescing read and write.
// Or input is too small to split into tiles.
template
<
typename
T
,
int
pos0
,
int
pos1
,
int
pos2
>
__global__
void
TransposeSimpleKernel
(
int
nthreads
,
template
<
typename
T
,
int
pos0
,
int
pos1
,
int
pos2
,
typename
IDX_T
=
int
>
__global__
void
TransposeSimpleKernel
(
IDX_T
nthreads
,
const
T
*
__restrict__
input
,
Dim3
input_dims
,
T
*
__restrict__
output
)
{
...
...
@@ -532,22 +543,24 @@ __global__ void TransposeSimpleKernel(int nthreads,
output_dims
[
pos1
]
=
input_dims
[
1
];
output_dims
[
pos2
]
=
input_dims
[
2
];
CUDA_KERNEL_LOOP
(
output_index
,
nthreads
)
{
Index3
output_tensor_index
=
ConvertTensorIndex
(
output_index
,
output_dims
);
CUDA_KERNEL_LOOP_TYPE
(
output_index
,
nthreads
,
IDX_T
)
{
Index3
output_tensor_index
=
framework
::
ConvertTensorIndex
<
IDX_T
>
(
output_index
,
output_dims
);
Index3
input_tensor_index
;
input_tensor_index
[
0
]
=
output_tensor_index
[
pos0
];
input_tensor_index
[
1
]
=
output_tensor_index
[
pos1
];
input_tensor_index
[
2
]
=
output_tensor_index
[
pos2
];
int
input_index
=
FlatTensorIndex
(
input_tensor_index
,
input_dims
);
IDX_T
input_index
=
framework
::
FlatTensorIndex
<
IDX_T
>
(
input_tensor_index
,
input_dims
);
output
[
output_index
]
=
input
[
input_index
];
}
}
// Here suppose convert all tensor to dim3, so just change dim1 and 2.
template
<
typename
T
>
template
<
typename
T
,
typename
IDX_T
=
int
>
void
SendSwapDim1And2InTranspose
(
const
phi
::
GPUContext
&
d
,
const
T
*
input
,
const
Dim3
&
input_dims
,
...
...
@@ -572,10 +585,11 @@ void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
framework
::
CeilOrFloor
<
int
,
true
>
(
input_dims
[
2
],
kTileSize
),
};
int
total_tiles_count
=
input_dims_aligned
[
0
]
*
input_dims_aligned
[
1
]
*
input_dims_aligned
[
2
];
IDX_T
total_tiles_count
=
input_dims_aligned
[
0
];
total_tiles_count
*=
input_dims_aligned
[
1
];
total_tiles_count
*=
input_dims_aligned
[
2
];
TilingSwapDim1And2
<
T
,
kNumThreads
,
kTileSize
,
kTileSize
>
TilingSwapDim1And2
<
T
,
kNumThreads
,
kTileSize
,
kTileSize
,
IDX_T
>
<<<
total_tiles_count
,
kNumThreads
,
0
,
d
.
stream
()
>>>
(
input
,
input_dims
,
output
);
...
...
@@ -583,18 +597,21 @@ void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
// If input shape is like Rect, such as 2X100, use Narrow tile size.
// It makes things complicated, because need to find a tile can coverr
// input and also reach best coalescing.
SwapDim1And2InNarrow
<
T
>
(
d
,
input
,
input_dims
,
output
,
kMinTileSize
);
SwapDim1And2InNarrow
<
T
,
false
,
IDX_T
>
(
d
,
input
,
input_dims
,
output
,
kMinTileSize
);
}
else
{
// If input shape is small, such as 8X8, just do simple copy
int
total_elements
=
input_dims
[
0
]
*
input_dims
[
1
]
*
input_dims
[
2
];
IDX_T
total_elements
=
input_dims
[
0
];
total_elements
*=
input_dims
[
1
];
total_elements
*=
input_dims
[
2
];
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
d
,
total_elements
);
TransposeSimpleKernel
<
T
,
0
,
2
,
1
>
TransposeSimpleKernel
<
T
,
0
,
2
,
1
,
IDX_T
>
<<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
d
.
stream
()
>>>
(
total_elements
,
input
,
input_dims
,
output
);
}
}
template
<
typename
T
>
template
<
typename
T
,
typename
IDX_T
=
int
>
struct
SwapDim1And2InTranspose
{
typedef
phi
::
GPUContext
Device
;
void
operator
()(
const
Device
&
d
,
...
...
@@ -604,11 +621,11 @@ struct SwapDim1And2InTranspose {
Dim3
input_dims
=
{
static_cast
<
int
>
(
combined_dims
[
0
]),
static_cast
<
int
>
(
combined_dims
[
1
]),
static_cast
<
int
>
(
combined_dims
[
2
])};
SendSwapDim1And2InTranspose
<
T
>
(
d
,
in
,
input_dims
,
out
);
SendSwapDim1And2InTranspose
<
T
,
IDX_T
>
(
d
,
in
,
input_dims
,
out
);
}
};
template
<
typename
T
>
template
<
typename
T
,
typename
IDX_T
=
int
>
struct
SwapDim0And2InTranspose
{
typedef
phi
::
GPUContext
Device
;
void
operator
()(
const
Device
&
d
,
...
...
@@ -619,10 +636,12 @@ struct SwapDim0And2InTranspose {
static_cast
<
int
>
(
combined_dims
[
1
]),
static_cast
<
int
>
(
combined_dims
[
2
])};
size_t
total_size
=
combined_dims
[
0
]
*
combined_dims
[
1
]
*
combined_dims
[
2
];
IDX_T
total_size
=
combined_dims
[
0
];
total_size
*=
combined_dims
[
1
];
total_size
*=
combined_dims
[
2
];
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
d
,
total_size
);
TransposeSimpleKernel
<
T
,
2
,
1
,
0
>
TransposeSimpleKernel
<
T
,
2
,
1
,
0
,
IDX_T
>
<<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
d
.
stream
()
>>>
(
total_size
,
in
,
input_dims
,
out
);
}
...
...
@@ -652,7 +671,7 @@ inline void CombineTransposeDim3(const framework::DDim& shape,
return
;
}
std
::
vector
<
int
>
new_dim_pos
(
shape
.
size
(),
-
1
);
std
::
vector
<
int
>
combined_dims
(
shape
.
size
(),
0
);
std
::
vector
<
int
64_t
>
combined_dims
(
shape
.
size
(),
0
);
int
cur_head
=
perm
[
0
];
new_dim_pos
[
cur_head
]
=
0
;
combined_dims
[
0
]
=
shape
[
cur_head
];
...
...
@@ -686,7 +705,7 @@ inline void CombineTransposeDim3(const framework::DDim& shape,
*
new_dims
=
phi
::
make_ddim
(
dim_vec
);
}
template
<
typename
T
>
template
<
typename
T
,
typename
IDX_T
=
int
>
struct
TransposeSimple
{
static
bool
run
(
const
phi
::
GPUContext
&
ctx
,
const
Tensor
&
in
,
...
...
@@ -709,21 +728,24 @@ struct TransposeSimple {
if
(
new_perm
[
0
]
==
1
&&
new_perm
[
1
]
==
0
)
{
// Add the first dimension size as 1.
new_dim_vec
.
insert
(
new_dim_vec
.
begin
(),
1
);
SwapDim1And2InTranspose
<
T
>
()(
ctx
,
in_data
,
new_dim_vec
,
out_data
);
SwapDim1And2InTranspose
<
T
,
IDX_T
>
()(
ctx
,
in_data
,
new_dim_vec
,
out_data
);
return
true
;
}
break
;
case
3
:
// In this case, suppose we can do coalescing read and write in tile.
if
(
new_perm
==
std
::
vector
<
int
>
({
0
,
2
,
1
}))
{
SwapDim1And2InTranspose
<
T
>
()(
ctx
,
in_data
,
new_dim_vec
,
out_data
);
SwapDim1And2InTranspose
<
T
,
IDX_T
>
()(
ctx
,
in_data
,
new_dim_vec
,
out_data
);
return
true
;
}
else
if
(
new_perm
==
std
::
vector
<
int
>
({
2
,
1
,
0
}))
{
// Maybe can optimize later, find a way to do coalescing memory copy.
// But I think it depends on the data size. If span is not large,
// maybe
// can do coalescing.
SwapDim0And2InTranspose
<
T
>
()(
ctx
,
in_data
,
new_dim_vec
,
out_data
);
SwapDim0And2InTranspose
<
T
,
IDX_T
>
()(
ctx
,
in_data
,
new_dim_vec
,
out_data
);
return
true
;
}
else
{
return
false
;
...
...
@@ -1159,7 +1181,13 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
const
std
::
vector
<
int32_t
>&
perm
,
Tensor
*
out
)
{
const
int
rank
=
perm
.
size
();
auto
ret
=
TransposeSimple
<
T
>::
run
(
ctx
,
in
,
perm
,
out
);
int64_t
numel
=
in
.
numel
();
bool
ret
{
false
};
if
(
numel
>=
INT32_MAX
)
{
ret
=
TransposeSimple
<
T
,
int64_t
>::
run
(
ctx
,
in
,
perm
,
out
);
}
else
{
ret
=
TransposeSimple
<
T
>::
run
(
ctx
,
in
,
perm
,
out
);
}
if
(
!
ret
)
{
auto
*
tuner
=
phi
::
autotune
::
MakeTransposeTuner
<
T
>
(
TransCompute
<
phi
::
GPUContext
,
T
>
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录