Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9b70c556
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9b70c556
编写于
9月 07, 2022
作者:
Y
Yuang Liu
提交者:
GitHub
9月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rename the template type name for tranpose (#45834)
上级
420d186a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
72 addition
and
67 deletion
+72
-67
paddle/fluid/framework/gpu_utils.h
paddle/fluid/framework/gpu_utils.h
+7
-7
paddle/fluid/operators/transpose_op.cu.h
paddle/fluid/operators/transpose_op.cu.h
+65
-60
未找到文件。
paddle/fluid/framework/gpu_utils.h
浏览文件 @
9b70c556
...
...
@@ -82,10 +82,10 @@ struct Index3 : DeviceArray<int, 3, 0> {
};
// Flat index with real dimension
template
<
typename
I
DX_T
=
int
>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE
I
DX_T
FlatTensorIndex
(
const
Index3
&
index
,
const
Dim3
&
dims
)
{
I
DX_T
flat_index
=
index
[
0
];
template
<
typename
I
ndexType
=
int
>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE
I
ndexType
FlatTensorIndex
(
const
Index3
&
index
,
const
Dim3
&
dims
)
{
I
ndexType
flat_index
=
index
[
0
];
for
(
int
i
=
1
;
i
<
3
;
i
++
)
{
flat_index
=
flat_index
*
dims
[
i
]
+
index
[
i
];
}
...
...
@@ -93,12 +93,12 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IDX_T FlatTensorIndex(const Index3& index,
}
// Convert index to tensor index with dimension.
template
<
typename
I
DX_T
=
int
>
template
<
typename
I
ndexType
=
int
>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE
Index3
ConvertTensorIndex
(
I
DX_T
index
,
const
Dim3
&
dims
)
{
ConvertTensorIndex
(
I
ndexType
index
,
const
Dim3
&
dims
)
{
Index3
tensor_index
;
for
(
int
i
=
2
;
i
>=
0
;
i
--
)
{
I
DX_T
new_index
=
index
/
dims
[
i
];
I
ndexType
new_index
=
index
/
dims
[
i
];
tensor_index
[
i
]
=
static_cast
<
int
>
(
index
-
dims
[
i
]
*
new_index
);
index
=
new_index
;
}
...
...
paddle/fluid/operators/transpose_op.cu.h
浏览文件 @
9b70c556
...
...
@@ -83,7 +83,7 @@ template <typename T,
int
NumThreads
,
int
TileX
,
int
TileY
,
typename
I
DX_T
=
int
>
typename
I
ndexType
=
int
>
__global__
void
TilingSwapDim1And2
(
const
T
*
__restrict__
input
,
Dim3
input_dims
,
T
*
__restrict__
output
)
{
...
...
@@ -119,8 +119,8 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
};
// Converts block idx to tile index, each block process a tile
Index3
input_block_tile_index
=
framework
::
ConvertTensorIndex
<
IDX_T
>
(
blockIdx
.
x
,
tile_aligned_input_dim
);
Index3
input_block_tile_index
=
framework
::
ConvertTensorIndex
<
IndexType
>
(
blockIdx
.
x
,
tile_aligned_input_dim
);
// Compute real index align to tile:0, 32, 64...
Index3
block_tile_index_in_input
=
{
...
...
@@ -130,11 +130,12 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
};
// Compute block flat index against input dims.
IDX_T
input_origin_block_flat_index
=
framework
::
FlatTensorIndex
<
IDX_T
>
(
block_tile_index_in_input
,
input_dims
);
IndexType
input_origin_block_flat_index
=
framework
::
FlatTensorIndex
<
IndexType
>
(
block_tile_index_in_input
,
input_dims
);
bool
full_tile
=
true
;
I
DX_T
tile_width
=
TileY
;
I
ndexType
tile_width
=
TileY
;
// Last row is not full.
if
(
input_block_tile_index
[
2
]
==
tile_aligned_input_dim
[
2
]
-
1
)
{
...
...
@@ -142,21 +143,22 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
full_tile
&=
false
;
}
I
DX_T
tile_height
=
TileX
;
I
ndexType
tile_height
=
TileX
;
if
(
input_block_tile_index
[
1
]
==
tile_aligned_input_dim
[
1
]
-
1
)
{
tile_height
=
input_dims
[
1
]
-
(
tile_aligned_input_dim
[
1
]
-
1
)
*
TileX
;
full_tile
&=
false
;
}
constexpr
I
DX_T
in_effective_thread_num
=
NumThreads
/
TileY
*
TileY
;
constexpr
I
ndexType
in_effective_thread_num
=
NumThreads
/
TileY
*
TileY
;
if
(
x
<
in_effective_thread_num
)
{
// Read a tile from input using block.
int
x_i
=
x
/
TileY
;
int
x_j
=
x
%
TileY
;
IDX_T
input_ind
=
input_origin_block_flat_index
+
x_i
*
input_dims
[
2
]
+
x_j
;
IDX_T
input_inc
=
BlockReadRows
*
input_dims
[
2
];
IndexType
input_ind
=
input_origin_block_flat_index
+
x_i
*
input_dims
[
2
]
+
x_j
;
IndexType
input_inc
=
BlockReadRows
*
input_dims
[
2
];
if
(
full_tile
)
{
#pragma unroll
...
...
@@ -167,7 +169,8 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
}
else
{
if
(
x_j
<
tile_width
)
{
#pragma unroll
for
(
IDX_T
ind_i
=
x_i
;
ind_i
<
(
tile_height
);
ind_i
+=
BlockReadRows
)
{
for
(
IndexType
ind_i
=
x_i
;
ind_i
<
(
tile_height
);
ind_i
+=
BlockReadRows
)
{
tile_sm
[
ind_i
][
x_j
]
=
input
[
input_ind
];
input_ind
+=
input_inc
;
}
...
...
@@ -190,17 +193,18 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
output_block_tile_index
[
2
]
*
TileX
,
};
IDX_T
output_origin_block_flat_index
=
framework
::
FlatTensorIndex
<
IDX_T
>
(
block_tile_index_in_output
,
output_dims
);
IndexType
output_origin_block_flat_index
=
framework
::
FlatTensorIndex
<
IndexType
>
(
block_tile_index_in_output
,
output_dims
);
constexpr
I
DX_T
out_effective_thread_num
=
NumThreads
/
TileX
*
TileX
;
constexpr
I
ndexType
out_effective_thread_num
=
NumThreads
/
TileX
*
TileX
;
if
(
x
<
out_effective_thread_num
)
{
int
x_i
=
x
/
TileX
;
int
x_j
=
x
%
TileX
;
I
DX_T
output_ind
=
I
ndexType
output_ind
=
output_origin_block_flat_index
+
x_i
*
output_dims
[
2
]
+
x_j
;
I
DX_T
output_inc
=
BlockWriteRows
*
output_dims
[
2
];
I
ndexType
output_inc
=
BlockWriteRows
*
output_dims
[
2
];
if
(
full_tile
)
{
#pragma unroll
...
...
@@ -211,7 +215,8 @@ __global__ void TilingSwapDim1And2(const T* __restrict__ input,
}
else
{
if
(
x_j
<
tile_height
)
{
#pragma unroll
for
(
IDX_T
ind_i
=
x_i
;
ind_i
<
(
tile_width
);
ind_i
+=
BlockWriteRows
)
{
for
(
IndexType
ind_i
=
x_i
;
ind_i
<
(
tile_width
);
ind_i
+=
BlockWriteRows
)
{
output
[
output_ind
]
=
tile_sm
[
x_j
][
ind_i
];
output_ind
+=
output_inc
;
}
...
...
@@ -276,21 +281,21 @@ struct SystemElemType<16> {
using
type
=
float4
;
};
template
<
typename
T
,
int
tile_long
,
int
tile_short
,
typename
I
DX_T
=
int
>
template
<
typename
T
,
int
tile_long
,
int
tile_short
,
typename
I
ndexType
=
int
>
void
LaunchNarrowDims2TransposeKernel
(
const
phi
::
GPUContext
&
d
,
int
tile_size_i
,
int
tile_size_j
,
I
DX_T
total_tiles_count
,
I
ndexType
total_tiles_count
,
const
T
*
input
,
const
Dim3
&
input_dims
,
T
*
output
)
{
constexpr
int
NumThreads
=
tile_long
;
if
(
tile_size_i
<=
tile_long
&&
tile_size_j
<=
tile_short
)
{
TilingSwapDim1And2
<
T
,
NumThreads
,
tile_long
,
tile_short
,
I
DX_T
>
TilingSwapDim1And2
<
T
,
NumThreads
,
tile_long
,
tile_short
,
I
ndexType
>
<<<
total_tiles_count
,
NumThreads
,
0
,
d
.
stream
()
>>>
(
input
,
input_dims
,
output
);
}
else
{
TilingSwapDim1And2
<
T
,
NumThreads
,
tile_short
,
tile_long
,
I
DX_T
>
TilingSwapDim1And2
<
T
,
NumThreads
,
tile_short
,
tile_long
,
I
ndexType
>
<<<
total_tiles_count
,
NumThreads
,
0
,
d
.
stream
()
>>>
(
input
,
input_dims
,
output
);
}
...
...
@@ -299,13 +304,13 @@ void LaunchNarrowDims2TransposeKernel(const phi::GPUContext& d,
template
<
typename
T
,
int
tile_long
,
int
tile_short
,
typename
I
DX_T
=
int
,
typename
I
ndexType
=
int
,
typename
dummy
=
void
>
struct
NarrowDims2TransposeDispatch
{
static
void
DoTranspose
(
const
phi
::
GPUContext
&
d
,
int
tile_size_i
,
int
tile_size_j
,
I
DX_T
total_tiles_count
,
I
ndexType
total_tiles_count
,
const
T
*
input
,
const
Dim3
&
input_dims
,
T
*
output
)
{
...
...
@@ -321,7 +326,7 @@ struct NarrowDims2TransposeDispatch {
std
::
min
(
tile_size_i
,
tile_size_j
)
<=
tile_short
;
if
(
request_satisfied
)
{
LaunchNarrowDims2TransposeKernel
<
T
,
tile_long
,
tile_short
,
I
DX_T
>
(
LaunchNarrowDims2TransposeKernel
<
T
,
tile_long
,
tile_short
,
I
ndexType
>
(
d
,
tile_size_i
,
tile_size_j
,
...
...
@@ -336,7 +341,7 @@ struct NarrowDims2TransposeDispatch {
std
::
max
(
tile_size_i
,
tile_size_j
)
>
tile_long
;
if
(
long_side_request_not_satisfied
)
{
NarrowDims2TransposeDispatch
<
T
,
tile_long
*
2
,
tile_short
,
I
DX_T
>::
NarrowDims2TransposeDispatch
<
T
,
tile_long
*
2
,
tile_short
,
I
ndexType
>::
DoTranspose
(
d
,
tile_size_i
,
tile_size_j
,
...
...
@@ -345,7 +350,7 @@ struct NarrowDims2TransposeDispatch {
input_dims
,
output
);
}
else
{
NarrowDims2TransposeDispatch
<
T
,
tile_long
,
tile_short
+
1
,
I
DX_T
>::
NarrowDims2TransposeDispatch
<
T
,
tile_long
,
tile_short
+
1
,
I
ndexType
>::
DoTranspose
(
d
,
tile_size_i
,
tile_size_j
,
...
...
@@ -358,19 +363,19 @@ struct NarrowDims2TransposeDispatch {
};
// If Not long tile size, goto this function when compile.
template
<
typename
T
,
int
tile_long
,
int
tile_short
,
typename
I
DX_T
>
template
<
typename
T
,
int
tile_long
,
int
tile_short
,
typename
I
ndexType
>
struct
NarrowDims2TransposeDispatch
<
T
,
tile_long
,
tile_short
,
I
DX_T
,
I
ndexType
,
typename
std
::
enable_if
<
CheckNonLongTileSize
(
tile_long
,
tile_short
,
sizeof
(
T
)),
void
>::
type
>
{
static
void
DoTranspose
(
const
phi
::
GPUContext
&
d
,
int
tile_size_i
,
int
tile_size_j
,
I
DX_T
total_tiles_count
,
I
ndexType
total_tiles_count
,
const
T
*
input
,
const
Dim3
&
input_dims
,
T
*
output
)
{
...
...
@@ -386,7 +391,7 @@ struct NarrowDims2TransposeDispatch<
std
::
min
(
tile_size_i
,
tile_size_j
)
<=
tile_short
;
if
(
request_satisfied
)
{
LaunchNarrowDims2TransposeKernel
<
T
,
tile_long
,
tile_short
,
I
DX_T
>
(
LaunchNarrowDims2TransposeKernel
<
T
,
tile_long
,
tile_short
,
I
ndexType
>
(
d
,
tile_size_i
,
tile_size_j
,
...
...
@@ -397,7 +402,7 @@ struct NarrowDims2TransposeDispatch<
return
;
}
NarrowDims2TransposeDispatch
<
T
,
tile_long
,
tile_short
+
1
,
I
DX_T
>::
NarrowDims2TransposeDispatch
<
T
,
tile_long
,
tile_short
+
1
,
I
ndexType
>::
DoTranspose
(
d
,
tile_size_i
,
tile_size_j
,
...
...
@@ -409,18 +414,18 @@ struct NarrowDims2TransposeDispatch<
};
// If long tile size, goto this function when compile.
template
<
typename
T
,
int
tile_long
,
int
tile_short
,
typename
I
DX_T
>
template
<
typename
T
,
int
tile_long
,
int
tile_short
,
typename
I
ndexType
>
struct
NarrowDims2TransposeDispatch
<
T
,
tile_long
,
tile_short
,
I
DX_T
,
I
ndexType
,
typename
std
::
enable_if
<
CheckLongTileSize
(
tile_long
,
tile_short
,
sizeof
(
T
)),
void
>::
type
>
{
static
void
DoTranspose
(
const
phi
::
GPUContext
&
d
,
int
tile_size_i
,
int
tile_size_j
,
I
DX_T
total_tiles_count
,
I
ndexType
total_tiles_count
,
const
T
*
input
,
const
Dim3
&
input_dims
,
T
*
output
)
{
...
...
@@ -432,7 +437,7 @@ struct NarrowDims2TransposeDispatch<
" but received is:%d."
,
tile_long
));
LaunchNarrowDims2TransposeKernel
<
T
,
tile_long
,
tile_short
,
I
DX_T
>
(
LaunchNarrowDims2TransposeKernel
<
T
,
tile_long
,
tile_short
,
I
ndexType
>
(
d
,
tile_size_i
,
tile_size_j
,
...
...
@@ -443,7 +448,7 @@ struct NarrowDims2TransposeDispatch<
}
};
template
<
typename
T
,
bool
conjugate
=
false
,
typename
I
DX_T
=
int
>
template
<
typename
T
,
bool
conjugate
=
false
,
typename
I
ndexType
=
int
>
void
SwapDim1And2InNarrow
(
const
phi
::
GPUContext
&
d
,
const
T
*
input
,
const
Dim3
&
input_dims
,
...
...
@@ -514,14 +519,14 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
framework
::
CeilOrFloor
<
int
,
true
>
(
input_dims
[
2
],
select_tile_size_j
),
};
I
DX_T
total_tiles_count
=
input_dims_aligned
[
0
];
I
ndexType
total_tiles_count
=
input_dims_aligned
[
0
];
total_tiles_count
*=
input_dims_aligned
[
1
];
total_tiles_count
*=
input_dims_aligned
[
2
];
// Suppose T can be replaced by system builtin types
using
ElemType
=
typename
SystemElemType
<
sizeof
(
T
)
>::
type
;
NarrowDims2TransposeDispatch
<
ElemType
,
32
,
2
,
I
DX_T
>::
DoTranspose
(
NarrowDims2TransposeDispatch
<
ElemType
,
32
,
2
,
I
ndexType
>::
DoTranspose
(
d
,
select_tile_size_i
,
select_tile_size_j
,
...
...
@@ -533,8 +538,8 @@ void SwapDim1And2InNarrow(const phi::GPUContext& d,
// This is for case that cannot do coalescing read and write.
// Or input is too small to split into tiles.
template
<
typename
T
,
int
pos0
,
int
pos1
,
int
pos2
,
typename
I
DX_T
=
int
>
__global__
void
TransposeSimpleKernel
(
I
DX_T
nthreads
,
template
<
typename
T
,
int
pos0
,
int
pos1
,
int
pos2
,
typename
I
ndexType
=
int
>
__global__
void
TransposeSimpleKernel
(
I
ndexType
nthreads
,
const
T
*
__restrict__
input
,
Dim3
input_dims
,
T
*
__restrict__
output
)
{
...
...
@@ -543,24 +548,24 @@ __global__ void TransposeSimpleKernel(IDX_T nthreads,
output_dims
[
pos1
]
=
input_dims
[
1
];
output_dims
[
pos2
]
=
input_dims
[
2
];
CUDA_KERNEL_LOOP_TYPE
(
output_index
,
nthreads
,
I
DX_T
)
{
CUDA_KERNEL_LOOP_TYPE
(
output_index
,
nthreads
,
I
ndexType
)
{
Index3
output_tensor_index
=
framework
::
ConvertTensorIndex
<
I
DX_T
>
(
output_index
,
output_dims
);
framework
::
ConvertTensorIndex
<
I
ndexType
>
(
output_index
,
output_dims
);
Index3
input_tensor_index
;
input_tensor_index
[
0
]
=
output_tensor_index
[
pos0
];
input_tensor_index
[
1
]
=
output_tensor_index
[
pos1
];
input_tensor_index
[
2
]
=
output_tensor_index
[
pos2
];
I
DX_T
input_index
=
framework
::
FlatTensorIndex
<
I
DX_T
>
(
input_tensor_index
,
input_dims
);
I
ndexType
input_index
=
framework
::
FlatTensorIndex
<
I
ndexType
>
(
input_tensor_index
,
input_dims
);
output
[
output_index
]
=
input
[
input_index
];
}
}
// Here suppose convert all tensor to dim3, so just change dim1 and 2.
template
<
typename
T
,
typename
I
DX_T
=
int
>
template
<
typename
T
,
typename
I
ndexType
=
int
>
void
SendSwapDim1And2InTranspose
(
const
phi
::
GPUContext
&
d
,
const
T
*
input
,
const
Dim3
&
input_dims
,
...
...
@@ -585,11 +590,11 @@ void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
framework
::
CeilOrFloor
<
int
,
true
>
(
input_dims
[
2
],
kTileSize
),
};
I
DX_T
total_tiles_count
=
input_dims_aligned
[
0
];
I
ndexType
total_tiles_count
=
input_dims_aligned
[
0
];
total_tiles_count
*=
input_dims_aligned
[
1
];
total_tiles_count
*=
input_dims_aligned
[
2
];
TilingSwapDim1And2
<
T
,
kNumThreads
,
kTileSize
,
kTileSize
,
I
DX_T
>
TilingSwapDim1And2
<
T
,
kNumThreads
,
kTileSize
,
kTileSize
,
I
ndexType
>
<<<
total_tiles_count
,
kNumThreads
,
0
,
d
.
stream
()
>>>
(
input
,
input_dims
,
output
);
...
...
@@ -597,21 +602,21 @@ void SendSwapDim1And2InTranspose(const phi::GPUContext& d,
// If input shape is like Rect, such as 2X100, use Narrow tile size.
// It makes things complicated, because need to find a tile can coverr
// input and also reach best coalescing.
SwapDim1And2InNarrow
<
T
,
false
,
I
DX_T
>
(
SwapDim1And2InNarrow
<
T
,
false
,
I
ndexType
>
(
d
,
input
,
input_dims
,
output
,
kMinTileSize
);
}
else
{
// If input shape is small, such as 8X8, just do simple copy
I
DX_T
total_elements
=
input_dims
[
0
];
I
ndexType
total_elements
=
input_dims
[
0
];
total_elements
*=
input_dims
[
1
];
total_elements
*=
input_dims
[
2
];
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
d
,
total_elements
);
TransposeSimpleKernel
<
T
,
0
,
2
,
1
,
I
DX_T
>
TransposeSimpleKernel
<
T
,
0
,
2
,
1
,
I
ndexType
>
<<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
d
.
stream
()
>>>
(
total_elements
,
input
,
input_dims
,
output
);
}
}
template
<
typename
T
,
typename
I
DX_T
=
int
>
template
<
typename
T
,
typename
I
ndexType
=
int
>
struct
SwapDim1And2InTranspose
{
typedef
phi
::
GPUContext
Device
;
void
operator
()(
const
Device
&
d
,
...
...
@@ -621,11 +626,11 @@ struct SwapDim1And2InTranspose {
Dim3
input_dims
=
{
static_cast
<
int
>
(
combined_dims
[
0
]),
static_cast
<
int
>
(
combined_dims
[
1
]),
static_cast
<
int
>
(
combined_dims
[
2
])};
SendSwapDim1And2InTranspose
<
T
,
I
DX_T
>
(
d
,
in
,
input_dims
,
out
);
SendSwapDim1And2InTranspose
<
T
,
I
ndexType
>
(
d
,
in
,
input_dims
,
out
);
}
};
template
<
typename
T
,
typename
I
DX_T
=
int
>
template
<
typename
T
,
typename
I
ndexType
=
int
>
struct
SwapDim0And2InTranspose
{
typedef
phi
::
GPUContext
Device
;
void
operator
()(
const
Device
&
d
,
...
...
@@ -636,12 +641,12 @@ struct SwapDim0And2InTranspose {
static_cast
<
int
>
(
combined_dims
[
1
]),
static_cast
<
int
>
(
combined_dims
[
2
])};
I
DX_T
total_size
=
combined_dims
[
0
];
I
ndexType
total_size
=
combined_dims
[
0
];
total_size
*=
combined_dims
[
1
];
total_size
*=
combined_dims
[
2
];
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
d
,
total_size
);
TransposeSimpleKernel
<
T
,
2
,
1
,
0
,
I
DX_T
>
TransposeSimpleKernel
<
T
,
2
,
1
,
0
,
I
ndexType
>
<<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
d
.
stream
()
>>>
(
total_size
,
in
,
input_dims
,
out
);
}
...
...
@@ -705,7 +710,7 @@ inline void CombineTransposeDim3(const framework::DDim& shape,
*
new_dims
=
phi
::
make_ddim
(
dim_vec
);
}
template
<
typename
T
,
typename
I
DX_T
=
int
>
template
<
typename
T
,
typename
I
ndexType
=
int
>
struct
TransposeSimple
{
static
bool
run
(
const
phi
::
GPUContext
&
ctx
,
const
Tensor
&
in
,
...
...
@@ -728,7 +733,7 @@ struct TransposeSimple {
if
(
new_perm
[
0
]
==
1
&&
new_perm
[
1
]
==
0
)
{
// Add the first dimension size as 1.
new_dim_vec
.
insert
(
new_dim_vec
.
begin
(),
1
);
SwapDim1And2InTranspose
<
T
,
I
DX_T
>
()(
SwapDim1And2InTranspose
<
T
,
I
ndexType
>
()(
ctx
,
in_data
,
new_dim_vec
,
out_data
);
return
true
;
}
...
...
@@ -736,7 +741,7 @@ struct TransposeSimple {
case
3
:
// In this case, suppose we can do coalescing read and write in tile.
if
(
new_perm
==
std
::
vector
<
int
>
({
0
,
2
,
1
}))
{
SwapDim1And2InTranspose
<
T
,
I
DX_T
>
()(
SwapDim1And2InTranspose
<
T
,
I
ndexType
>
()(
ctx
,
in_data
,
new_dim_vec
,
out_data
);
return
true
;
}
else
if
(
new_perm
==
std
::
vector
<
int
>
({
2
,
1
,
0
}))
{
...
...
@@ -744,7 +749,7 @@ struct TransposeSimple {
// But I think it depends on the data size. If span is not large,
// maybe
// can do coalescing.
SwapDim0And2InTranspose
<
T
,
I
DX_T
>
()(
SwapDim0And2InTranspose
<
T
,
I
ndexType
>
()(
ctx
,
in_data
,
new_dim_vec
,
out_data
);
return
true
;
}
else
{
...
...
@@ -1183,7 +1188,7 @@ void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
const
int
rank
=
perm
.
size
();
int64_t
numel
=
in
.
numel
();
bool
ret
{
false
};
if
(
numel
>=
INT32_MAX
)
{
if
(
numel
>=
std
::
numeric_limits
<
int32_t
>::
max
()
)
{
ret
=
TransposeSimple
<
T
,
int64_t
>::
run
(
ctx
,
in
,
perm
,
out
);
}
else
{
ret
=
TransposeSimple
<
T
>::
run
(
ctx
,
in
,
perm
,
out
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录