Commit 9a4acfee (unverified)
Project: BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Authored by zhangkaihuo on Jan 31, 2023; committed via GitHub on Jan 31, 2023
optimize 2D sync_batch_norm (#49663)
Parent: 118aee6f

Showing 4 changed files with 388 additions and 253 deletions (+388 -253):

  paddle/phi/kernels/funcs/norm_utils.cu.h            +118    -2
  paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu     +53  -176
  paddle/phi/kernels/gpu/batch_norm_kernel.cu          +16   -70
  paddle/phi/kernels/gpu/sync_batch_norm_utils.h      +201    -5
paddle/phi/kernels/funcs/norm_utils.cu.h

@@ -26,6 +26,7 @@ namespace cub = hipcub;
 #endif
 #include "paddle/phi/common/layout.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/reduce_function.h"
 #ifdef __HIPCC__
 #define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)

@@ -36,8 +37,6 @@ namespace cub = hipcub;
 namespace phi {
 namespace funcs {
 
-using DataLayout = phi::DataLayout;
-
 // math: dx = scale * ((x - mean) * inv_var / NxHxW * (np.mean(ddx,
 // axis=(n,h,w)) *
 //     np.sum(dy, axis=(n,h,w)) -
@@ -670,5 +669,122 @@ void NormDoubleGradFunctor(const DeviceContext &ctx,
     }
   }
 }
+
+template <typename T, typename BnT>
+__device__ __forceinline__ void BlockReduceByVetical(BnT x_sum,
+                                                     BnT x_square_sum,
+                                                     BnT *smem_sum,
+                                                     BnT *smem_square_sum,
+                                                     BnT *x_sum_out,
+                                                     BnT *x_square_sum_out) {
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+#pragma unroll
+  for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
+    if (threadIdx.y < offset * 2) {
+      smem_sum[tid] = x_sum;
+      smem_square_sum[tid] = x_square_sum;
+    }
+    __syncthreads();
+    if (threadIdx.y < offset) {
+      int pair_tid = tid + offset * blockDim.x;
+      x_sum += smem_sum[pair_tid];
+      x_square_sum += smem_square_sum[pair_tid];
+    }
+  }
+  if (threadIdx.y == 0) {
+    *x_sum_out = x_sum;
+    *x_square_sum_out = x_square_sum;
+  }
+}
+
+template <typename T, typename BnT>
+__device__ __forceinline__ void ReduceSumPost(const int C,  // channels
+                                              const int c,  // channel index
+                                              BnT *sum1,
+                                              BnT *sum2,
+                                              bool *is_last_block_done,
+                                              BnT *cache1,
+                                              BnT *cache2,
+                                              BnT *block_data_ptr,
+                                              int *flag_ptr) {
+  volatile BnT *staging_sum = block_data_ptr;
+  volatile BnT *staging_sum2 = &block_data_ptr[C * gridDim.y];
+  // write block data to global memory
+  if (threadIdx.y == 0) {
+    staging_sum[c + blockIdx.y * C] = *sum1;
+    staging_sum2[c + blockIdx.y * C] = *sum2;
+  }
+
+  // make sure write is visible to all blocks
+  __threadfence();
+  __syncthreads();
+
+  // mark block done
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
+    int old = atomicAdd(&flag_ptr[blockIdx.x], 1);
+    *is_last_block_done = (old == (gridDim.y - 1));
+  }
+
+  __syncthreads();
+
+  if (*is_last_block_done) {
+    *sum1 = static_cast<BnT>(0);
+    *sum2 = static_cast<BnT>(0);
+    // thread sum
+    for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
+      *sum1 += staging_sum[c + y * C];
+      *sum2 += staging_sum2[c + y * C];
+    }
+
+    // vertical block sum
+    funcs::BlockReduceByVetical<T, BnT>(
+        *sum1, *sum2, &cache1[0], &cache2[0], sum1, sum2);
+  }
+}
+
+template <typename T, typename BnT, typename Context>
+void SetLaunchConfigInfoForChannelLast(const Context &ctx,
+                                       DenseTensor *block_data_tensor,
+                                       DenseTensor *flag_tensor,
+                                       BnT **block_data_ptr,
+                                       int **flag_ptr,
+                                       const int N,
+                                       const int H,
+                                       const int W,
+                                       const int D,
+                                       const int C,
+                                       const int block_size,
+                                       dim3 *block,
+                                       dim3 *grid) {
+  const int MAX_GRID_SIZE = 128;
+  const int WARP_SIZE = 32;
+
+  int block_x = std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE);
+  int block_y = std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16),
+                         block_size / block_x);
+  if (block_x * block_y != block_size) {
+    block_x =
+        std::min(phi::funcs::details::GetLastPow2(C), block_size / block_y);
+  }
+  int grid_x = (C + block_x - 1) / block_x;
+  int grid_y = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
+                        MAX_GRID_SIZE);
+
+  block->x = block_x;
+  block->y = block_y;
+  grid->x = grid_x;
+  grid->y = grid_y;
+
+  if (grid->y > 1) {
+    *block_data_tensor = phi::Empty<BnT, Context>(ctx, {2 * C * grid->y});
+    *flag_tensor = phi::Empty<int, Context>(ctx, {grid->x});
+
+    *block_data_ptr = block_data_tensor->data<BnT>();
+    *flag_ptr = flag_tensor->data<int>();
+    funcs::SetConstant<Context, int> set_zero;
+    set_zero(ctx, flag_tensor, static_cast<int>(0));
+  }
+}
+
 }  // namespace funcs
 }  // namespace phi
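Taken together, BlockReduceByVetical and ReduceSumPost implement a two-stage sum: each block first folds its per-thread partials across threadIdx.y through shared memory, then publishes the block partial to global staging memory, bumps a per-column flag with atomicAdd, and the block that observes the final count folds every partial into the result. The sketch below is a minimal, self-contained illustration of that "last block done" pattern reduced to a single scalar sum. It is not code from this commit; the kernel name StagedSum is made up, and it assumes a power-of-two blockDim.x of 256 and a zero-initialized flag (the role SetConstant plays for flag_tensor above).

// Minimal standalone sketch (not part of the commit) of the staged reduction
// pattern used by ReduceSumPost.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void StagedSum(const float *x, int n, float *staging, int *flag,
                          float *out) {
  __shared__ float smem[256];          // assumes blockDim.x == 256
  __shared__ bool is_last_block_done;

  // 1) block-local tree reduction over a grid-stride partial sum
  float v = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    v += x[i];
  }
  smem[threadIdx.x] = v;
  __syncthreads();
  for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
    if (threadIdx.x < offset) smem[threadIdx.x] += smem[threadIdx.x + offset];
    __syncthreads();
  }

  // 2) publish the block partial and mark this block done
  if (threadIdx.x == 0) {
    staging[blockIdx.x] = smem[0];
    __threadfence();  // make the staging write visible to other blocks
    int old = atomicAdd(flag, 1);
    is_last_block_done = (old == gridDim.x - 1);
  }
  __syncthreads();

  // 3) the last block to finish folds all partials into the final result
  if (is_last_block_done) {
    float total = 0.f;
    for (int b = threadIdx.x; b < gridDim.x; b += blockDim.x) {
      total += staging[b];
    }
    smem[threadIdx.x] = total;
    __syncthreads();
    for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) {
      if (threadIdx.x < offset) smem[threadIdx.x] += smem[threadIdx.x + offset];
      __syncthreads();
    }
    if (threadIdx.x == 0) *out = smem[0];
  }
}

int main() {
  const int n = 1 << 20, blocks = 64, threads = 256;
  float *x, *staging, *out;
  int *flag;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&staging, blocks * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  cudaMallocManaged(&flag, sizeof(int));
  for (int i = 0; i < n; ++i) x[i] = 1.f;
  *flag = 0;  // flag must start at zero, as SetConstant does for flag_tensor
  StagedSum<<<blocks, threads>>>(x, n, staging, flag, out);
  cudaDeviceSynchronize();
  printf("sum = %f (expected %d)\n", *out, n);
  return 0;
}

The commit lays its staging buffer out per channel and per grid row rather than per block, which is why SetLaunchConfigInfoForChannelLast sizes block_data_tensor as 2 * C * grid->y values and flag_tensor as grid->x counters.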
paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu

@@ -245,34 +245,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
   }
 }
 
-template <typename T>
-__device__ __forceinline__ void BlockReduceByVetical(
-    BatchNormParamType<T> x_sum,
-    BatchNormParamType<T> x_square_sum,
-    BatchNormParamType<T> *smem_sum,
-    BatchNormParamType<T> *smem_square_sum,
-    BatchNormParamType<T> *x_sum_out,
-    BatchNormParamType<T> *x_square_sum_out) {
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-#pragma unroll
-  for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
-    if (threadIdx.y < offset * 2) {
-      smem_sum[tid] = x_sum;
-      smem_square_sum[tid] = x_square_sum;
-    }
-    __syncthreads();
-    if (threadIdx.y < offset) {
-      int pair_tid = tid + offset * blockDim.x;
-      x_sum += smem_sum[pair_tid];
-      x_square_sum += smem_square_sum[pair_tid];
-    }
-  }
-  if (threadIdx.y == 0) {
-    *x_sum_out = x_sum;
-    *x_square_sum_out = x_square_sum;
-  }
-}
-
 template <typename T, int BlockDim>
 static __global__ void BNBackward2DChannelLastStage1(
     const T *x,

@@ -309,7 +281,7 @@ static __global__ void BNBackward2DChannelLastStage1(
   }
 
   // vertical block sum
-  BlockReduceByVetical<T>(x_sum,
+  funcs::BlockReduceByVetical<T, BatchNormParamType<T>>(x_sum,
                           x_square_sum,
                           &smem_sum[0],
                           &smem_square_sum[0],

@@ -317,45 +289,17 @@ static __global__ void BNBackward2DChannelLastStage1(
                           &x_square_sum);
 
   if (gridDim.y > 1) {
-    volatile BatchNormParamType<T> *staging_sum = block_data_ptr;
-    volatile BatchNormParamType<T> *staging_square_sum =
-        &block_data_ptr[C * gridDim.y];
-    // write block data to global memory
-    if (threadIdx.y == 0) {
-      staging_sum[i + blockIdx.y * C] = x_sum;
-      staging_square_sum[i + blockIdx.y * C] = x_square_sum;
-    }
-
-    // make sure write is visible to all blocks
-    __threadfence();
-    __syncthreads();
-
     __shared__ bool is_last_block_done;
-    // mark block done
-    if (threadIdx.x == 0 && threadIdx.y == 0) {
-      int old = atomicAdd(&flag_ptr[blockIdx.x], 1);
-      is_last_block_done = (old == (gridDim.y - 1));
-    }
-
-    __syncthreads();
-
-    if (is_last_block_done) {
-      x_sum = static_cast<BatchNormParamType<T>>(0);
-      x_square_sum = static_cast<BatchNormParamType<T>>(0);
-      // thread sum
-      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
-        x_sum += staging_sum[i + y * C];
-        x_square_sum += staging_square_sum[i + y * C];
-      }
-
-      // vertical block sum
-      BlockReduceByVetical<T>(x_sum,
-                              x_square_sum,
-                              &smem_sum[0],
-                              &smem_square_sum[0],
-                              &x_sum,
-                              &x_square_sum);
+    funcs::ReduceSumPost<T, BatchNormParamType<T>>(C,
+                                                   i,
+                                                   &x_sum,
+                                                   &x_square_sum,
+                                                   &is_last_block_done,
+                                                   smem_sum,
+                                                   smem_square_sum,
+                                                   block_data_ptr,
+                                                   flag_ptr);
+    if (is_last_block_done) {
      // final compute
      if (threadIdx.y == 0) {
        BatchNormParamType<T> compute_mean_val = x_sum / inner_size;

@@ -417,45 +361,21 @@ static __global__ void BNBackward2DChannelLastStage2(
   }
 
   // vertical block sum
-  BlockReduceByVetical<T>(
+  funcs::BlockReduceByVetical<T, BatchNormParamType<T>>(
       ds_sum, db_sum, &smem_ds_sum[0], &smem_db_sum[0], &ds_sum, &db_sum);
 
   if (gridDim.y > 1) {
-    volatile BatchNormParamType<T> *staging_ds_sum = block_data_ptr;
-    volatile BatchNormParamType<T> *staging_db_sum =
-        &block_data_ptr[C * gridDim.y];
-    // write block data to global memory
-    if (threadIdx.y == 0) {
-      staging_ds_sum[i + blockIdx.y * C] = ds_sum;
-      staging_db_sum[i + blockIdx.y * C] = db_sum;
-    }
-
-    // make sure write is visible to all blocks
-    __threadfence();
-    __syncthreads();
-
     __shared__ bool is_last_block_done;
-    // mark block done
-    if (threadIdx.x == 0 && threadIdx.y == 0) {
-      int old = atomicAdd(&flag_ptr[blockIdx.x], 1);
-      is_last_block_done = (old == (gridDim.y - 1));
-    }
-
-    __syncthreads();
-
+    funcs::ReduceSumPost<T, BatchNormParamType<T>>(C,
+                                                   i,
+                                                   &ds_sum,
+                                                   &db_sum,
+                                                   &is_last_block_done,
+                                                   smem_ds_sum,
+                                                   smem_db_sum,
+                                                   block_data_ptr,
+                                                   flag_ptr);
     if (is_last_block_done) {
-      ds_sum = static_cast<BatchNormParamType<T>>(0);
-      db_sum = static_cast<BatchNormParamType<T>>(0);
-      // thread sum
-      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
-        ds_sum += staging_ds_sum[i + y * C];
-        db_sum += staging_db_sum[i + y * C];
-      }
-
-      // vertical block sum
-      BlockReduceByVetical<T>(ds_sum,
-                              db_sum,
-                              &smem_ds_sum[0],
-                              &smem_db_sum[0],
-                              &ds_sum,
-                              &db_sum);
       // final compute
       if (threadIdx.y == 0) {
         dscale[i] = ds_sum * inv_var_val;

@@ -563,51 +483,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
   }
 }
 
-template <typename T, typename Context>
-void SetLaunchConfigInfoForChannelLast(const Context &ctx,
-                                       DenseTensor *block_data_tensor,
-                                       DenseTensor *flag_tensor,
-                                       BatchNormParamType<T> **block_data_ptr,
-                                       int **flag_ptr,
-                                       const int N,
-                                       const int H,
-                                       const int W,
-                                       const int D,
-                                       const int C,
-                                       const int block_size,
-                                       dim3 *block,
-                                       dim3 *grid) {
-  const int MAX_GRID_SIZE = 128;
-  const int WARP_SIZE = 32;
-
-  int block_x = std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE);
-  int block_y = std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16),
-                         block_size / block_x);
-  if (block_x * block_y != block_size) {
-    block_x =
-        std::min(phi::funcs::details::GetLastPow2(C), block_size / block_y);
-  }
-  int grid_x = (C + block_x - 1) / block_x;
-  int grid_y = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
-                        MAX_GRID_SIZE);
-
-  block->x = block_x;
-  block->y = block_y;
-  grid->x = grid_x;
-  grid->y = grid_y;
-
-  if (grid->y > 1) {
-    *block_data_tensor =
-        phi::Empty<BatchNormParamType<T>, Context>(ctx, {2 * C * grid->y});
-    *flag_tensor = phi::Empty<int, Context>(ctx, {grid->x});
-
-    *block_data_ptr = block_data_tensor->data<BatchNormParamType<T>>();
-    *flag_ptr = flag_tensor->data<int>();
-    funcs::SetConstant<Context, int> set_zero;
-    set_zero(ctx, flag_tensor, static_cast<int>(0));
-  }
-}
-
 template <typename T, typename Context>
 void BatchNormGradRawKernel(const Context &ctx,
                             const DenseTensor &x,

@@ -931,7 +806,8 @@ void BatchNormGradRawKernel(const Context &ctx,
         BatchNormParamType<T> *block_data_ptr = nullptr;
         int *flag_ptr = nullptr;
 
-        SetLaunchConfigInfoForChannelLast<T>(ctx,
+        funcs::SetLaunchConfigInfoForChannelLast<T, BatchNormParamType<T>>(
+            ctx,
             &block_data_tensor,
             &flag_tensor,
             &block_data_ptr,

@@ -1294,7 +1170,8 @@ void BatchNormGradRawKernel(const Context &ctx,
         BatchNormParamType<T> *block_data_ptr = nullptr;
         int *flag_ptr = nullptr;
 
-        SetLaunchConfigInfoForChannelLast<T>(ctx,
+        funcs::SetLaunchConfigInfoForChannelLast<T, BatchNormParamType<T>>(
+            ctx,
            &block_data_tensor,
            &flag_tensor,
            &block_data_ptr,
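The Stage2 kernel reduces ds_sum and db_sum per channel and, in the tail of the hunk above, writes dscale[i] = ds_sum * inv_var_val. These are the standard per-channel batch-norm gradients: ds_sum accumulates dy * (x - mean) and db_sum accumulates dy. A plain host-side reference of that arithmetic is sketched below; it is illustrative only, with made-up shapes and statistics, and is not code from this commit.

// Hypothetical host-side reference for the per-channel scale/bias gradients
// that the 2D kernels compute on the GPU. Not part of the commit.
#include <cstdio>
#include <vector>

int main() {
  const int N = 2, HxW = 3, C = 4;  // tiny NHWC example, values made up
  std::vector<float> x(N * HxW * C, 1.f), dy(N * HxW * C, 0.5f);
  std::vector<float> mean(C, 1.f), inv_var(C, 2.f);  // made-up statistics

  for (int c = 0; c < C; ++c) {
    float ds_sum = 0.f, db_sum = 0.f;
    for (int j = 0; j < N * HxW; ++j) {  // reduce over every (n, h, w)
      const int id = j * C + c;          // NHWC flat index
      ds_sum += dy[id] * (x[id] - mean[c]);
      db_sum += dy[id];
    }
    const float dscale = ds_sum * inv_var[c];  // matches dscale[i] = ds_sum * inv_var_val
    const float dbias = db_sum;                // matches dbias[i] = db_sum
    printf("c=%d dscale=%f dbias=%f\n", c, dscale, dbias);
  }
  return 0;
}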
paddle/phi/kernels/gpu/batch_norm_kernel.cu

@@ -30,6 +30,7 @@ namespace cub = hipcub;
 #include "paddle/phi/kernels/batch_norm_kernel.h"
 #include "paddle/phi/kernels/funcs/batch_norm_utils.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/norm_utils.cu.h"
 #include "paddle/phi/kernels/funcs/norm_utils.h"
 #include "paddle/phi/kernels/funcs/reduce_function.h"

@@ -171,34 +172,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
   }
 }
 
-template <typename T>
-__device__ __forceinline__ void merge_block_vertical(
-    BatchNormParamType<T> x_sum,
-    BatchNormParamType<T> x_square_sum,
-    BatchNormParamType<T> *smem_sum,
-    BatchNormParamType<T> *smem_square_sum,
-    BatchNormParamType<T> *x_sum_out,
-    BatchNormParamType<T> *x_square_sum_out) {
-  int tid = threadIdx.x + threadIdx.y * blockDim.x;
-#pragma unroll
-  for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
-    if (threadIdx.y < offset * 2) {
-      smem_sum[tid] = x_sum;
-      smem_square_sum[tid] = x_square_sum;
-    }
-    __syncthreads();
-    if (threadIdx.y < offset) {
-      int pair_tid = tid + offset * blockDim.x;
-      x_sum += smem_sum[pair_tid];
-      x_square_sum += smem_square_sum[pair_tid];
-    }
-  }
-  if (threadIdx.y == 0) {
-    *x_sum_out = x_sum;
-    *x_square_sum_out = x_square_sum;
-  }
-}
-
 template <typename T>
 __device__ __forceinline__ void merge_block_horizonal(
     BatchNormParamType<T> x_sum,

@@ -269,7 +242,7 @@ static __global__ void BNForwardTraining2DChannelLastCompStat(
   }
 
   // vertical block sum
-  merge_block_vertical<T>(x_sum,
+  funcs::BlockReduceByVetical<T, BatchNormParamType<T>>(x_sum,
                           x_square_sum,
                           &smem_sum[0],
                           &smem_square_sum[0],

@@ -277,45 +250,18 @@ static __global__ void BNForwardTraining2DChannelLastCompStat(
                           &x_square_sum);
 
   if (gridDim.y > 1) {
-    volatile BatchNormParamType<T> *staging_sum = block_data_ptr;
-    volatile BatchNormParamType<T> *staging_square_sum =
-        &block_data_ptr[C * gridDim.y];
-    // write block data to global memory
-    if (threadIdx.y == 0) {
-      staging_sum[i + blockIdx.y * C] = x_sum;
-      staging_square_sum[i + blockIdx.y * C] = x_square_sum;
-    }
-
-    // make sure write is visible to all blocks
-    __threadfence();
-    __syncthreads();
-
     __shared__ bool is_last_block_done;
-    // mark block done
-    if (threadIdx.x == 0 && threadIdx.y == 0) {
-      int old = atomicAdd(&flag_ptr[blockIdx.x], 1);
-      is_last_block_done = (old == (gridDim.y - 1));
-    }
-
-    __syncthreads();
-
-    if (is_last_block_done) {
-      x_sum = static_cast<BatchNormParamType<T>>(0);
-      x_square_sum = static_cast<BatchNormParamType<T>>(0);
-      // thread sum
-      for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
-        x_sum += staging_sum[i + y * C];
-        x_square_sum += staging_square_sum[i + y * C];
-      }
-
-      // vertical block sum
-      merge_block_vertical<T>(x_sum,
-                              x_square_sum,
-                              &smem_sum[0],
-                              &smem_square_sum[0],
-                              &x_sum,
-                              &x_square_sum);
+    funcs::ReduceSumPost<T, BatchNormParamType<T>>(C,
+                                                   i,
+                                                   &x_sum,
+                                                   &x_square_sum,
+                                                   &is_last_block_done,
+                                                   smem_sum,
+                                                   smem_square_sum,
+                                                   block_data_ptr,
+                                                   flag_ptr);
+    if (is_last_block_done) {
      // final compute
      if (threadIdx.y == 0) {
        BatchNormParamType<T> compute_mean_val = x_sum / inner_size;
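BNForwardTraining2DChannelLastCompStat reduces x_sum and x_square_sum per channel and, in the visible tail of the hunk, derives compute_mean_val = x_sum / inner_size. The variance side of that computation is not visible in this hunk; it is conventionally obtained as E[x^2] - E[x]^2 from the same two sums. A small sketch of that arithmetic, with made-up numbers:

// Illustrative only: mean/variance from a sum and a sum of squares, the two
// quantities the forward 2D kernel reduces per channel. Numbers are made up.
#include <cstdio>

int main() {
  const float x[] = {1.f, 2.f, 3.f, 4.f};  // one channel's values
  const int inner_size = 4;                // N * H * W for that channel
  float x_sum = 0.f, x_square_sum = 0.f;
  for (float v : x) {
    x_sum += v;
    x_square_sum += v * v;
  }
  const float mean = x_sum / inner_size;                      // 2.5
  const float var = x_square_sum / inner_size - mean * mean;  // 7.5 - 6.25 = 1.25
  printf("mean=%f var=%f\n", mean, var);
  return 0;
}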
paddle/phi/kernels/gpu/sync_batch_norm_utils.h

@@ -34,6 +34,7 @@ namespace cub = hipcub;
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #include "paddle/phi/backends/gpu/gpu_dnn.h"
 #include "paddle/phi/common/layout.h"
+#include "paddle/phi/kernels/funcs/norm_utils.cu.h"
 #include "paddle/phi/kernels/funcs/norm_utils.h"
 
 namespace phi {

@@ -168,6 +169,61 @@ __global__ void KeBackwardLocalStats(const T *dy,
   }
 }
 
+template <typename T, const int BlockDim, DataLayout layout>
+__global__ void KeBackwardLocalStats2D(const T *dy,
+                                       const T *x,
+                                       const BatchNormParamType<T> *means,
+                                       int N,
+                                       int M,
+                                       int C,
+                                       BatchNormParamType<T> *block_data_ptr,
+                                       int *flag_ptr,
+                                       BatchNormParamType<T> *sum_dy_prod) {
+  __shared__ BatchNormParamType<T> smem_sum[BlockDim];
+  __shared__ BatchNormParamType<T> smem_square_sum[BlockDim];
+
+  for (int k = blockIdx.x * blockDim.x + threadIdx.x; k < C;
+       k += gridDim.x * blockDim.x) {
+    BatchNormParamType<T> sum1 = 0.;
+    BatchNormParamType<T> sum2 = 0.;
+    auto mean = means[k];
+    for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < N * M;
+         i += gridDim.y * blockDim.y) {
+      int id = layout == DataLayout::kNCHW ? (i / M) * C * M + k * M + i % M
+                                           : i * C + k;
+      auto g = static_cast<BatchNormParamType<T>>(dy[id]);
+      sum1 += g;
+      auto x_i = static_cast<BatchNormParamType<T>>(x[id]);
+      sum2 += g * (x_i - mean);
+    }
+
+    funcs::BlockReduceByVetical<T, BatchNormParamType<T>>(
+        sum1, sum2, &smem_sum[0], &smem_square_sum[0], &sum1, &sum2);
+
+    if (gridDim.y > 1) {
+      __shared__ bool is_last_block_done;
+      funcs::ReduceSumPost<T, BatchNormParamType<T>>(C,
+                                                     k,
+                                                     &sum1,
+                                                     &sum2,
+                                                     &is_last_block_done,
+                                                     smem_sum,
+                                                     smem_square_sum,
+                                                     block_data_ptr,
+                                                     flag_ptr);
+      if (is_last_block_done) {
+        // final compute
+        if (threadIdx.y == 0) {
+          sum_dy_prod[k] = sum1;
+          sum_dy_prod[k + C] = sum2;
+        }
+      }
+    }
+  }
+  if (blockIdx.y == 0 && blockIdx.x == 0 && threadIdx.y == 0 &&
+      threadIdx.x == 0) {
+    sum_dy_prod[2 * C] = 1.0;
+  }
+}
+
 template <typename T, int BlockDim, DataLayout layout>
 static __global__ void KeBNBackwardScaleBias(
     const T *dy,
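KeBackwardLocalStats2D picks the flat element offset with layout == DataLayout::kNCHW ? (i / M) * C * M + k * M + i % M : i * C + k, where k is the channel and i enumerates the N * M batch/spatial positions. A tiny standalone check of that index arithmetic, with arbitrarily chosen values (not code from the commit):

// Illustrative only: the NCHW vs NHWC flat-offset arithmetic used by
// KeBackwardLocalStats2D, evaluated on the CPU for one (i, k) pair.
#include <cstdio>

int main() {
  const int C = 3, M = 4;  // M = H * W (spatial size), made-up values
  const int k = 1;         // channel index
  const int i = 6;         // position index in [0, N * M)
  // NCHW: element (n, k, hw) with n = i / M and hw = i % M
  const int id_nchw = (i / M) * C * M + k * M + i % M;  // 1*12 + 1*4 + 2 = 18
  // NHWC: channels are innermost
  const int id_nhwc = i * C + k;                        // 6*3 + 1 = 19
  printf("NCHW offset = %d, NHWC offset = %d\n", id_nchw, id_nhwc);
  return 0;
}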
@@ -213,6 +269,68 @@ static __global__ void KeBNBackwardScaleBias(
   }
 }
 
+template <typename T, int BlockDim, DataLayout layout>
+static __global__ void KeBNBackwardScaleBias2D(
+    const T *dy,
+    const T *x,
+    const BatchNormParamType<T> *mean,
+    const BatchNormParamType<T> *inv_variance,
+    const double epsilon,
+    const int N,
+    const int C,
+    const int HxW,
+    BatchNormParamType<T> *block_data_ptr,
+    int *flag_ptr,
+    BatchNormParamType<T> *dscale,
+    BatchNormParamType<T> *dbias) {
+  const int outer_size = C;
+  const int inner_size = N * HxW;
+
+  __shared__ BatchNormParamType<T> smem_sum[BlockDim];
+  __shared__ BatchNormParamType<T> smem_square_sum[BlockDim];
+
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size;
+       i += gridDim.x * blockDim.x) {
+    BatchNormParamType<T> ds_sum = 0.;
+    BatchNormParamType<T> db_sum = 0.;
+
+    auto inv_var_i = inv_variance[i];
+    auto mean_i = mean[i];
+    for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size;
+         j += gridDim.y * blockDim.y) {
+      const int id = layout == DataLayout::kNCHW
+                         ? ((j / HxW) * C + i) * HxW + (j % HxW)
+                         : j * outer_size + i;
+      auto x_i = static_cast<BatchNormParamType<T>>(x[id]);
+      auto dy_i = static_cast<BatchNormParamType<T>>(dy[id]);
+      ds_sum += dy_i * (x_i - mean_i);
+      db_sum += dy_i;
+    }
+
+    funcs::BlockReduceByVetical<T, BatchNormParamType<T>>(
+        ds_sum, db_sum, &smem_sum[0], &smem_square_sum[0], &ds_sum, &db_sum);
+
+    if (gridDim.y > 1) {
+      __shared__ bool is_last_block_done;
+      funcs::ReduceSumPost<T, BatchNormParamType<T>>(C,
+                                                     i,
+                                                     &ds_sum,
+                                                     &db_sum,
+                                                     &is_last_block_done,
+                                                     smem_sum,
+                                                     smem_square_sum,
+                                                     block_data_ptr,
+                                                     flag_ptr);
+      if (is_last_block_done) {
+        // final compute
+        if (threadIdx.y == 0) {
+          dscale[i] = ds_sum * inv_var_i;
+          dbias[i] = db_sum;
+        }
+      }
+    }
+  }
+}
+
 template <typename T, DataLayout layout>
 static __global__ void KeBNRestoreData(T *x,
                                        const BatchNormParamType<T> *scale,
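Both 2D kernels walk channels with a grid-stride loop over the x dimension and the positions being reduced with a grid-stride loop over the y dimension, so any (grid, block) shape covers any (C, N * HxW) extent. A minimal sketch of that 2D grid-stride traversal (hypothetical kernel, not from the commit):

// Minimal sketch (not from the commit) of the 2D grid-stride traversal used by
// KeBackwardLocalStats2D / KeBNBackwardScaleBias2D: x covers channels,
// y covers the positions reduced for each channel.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void Visit2D(int C, int inner_size, int *hits) {
  for (int c = blockIdx.x * blockDim.x + threadIdx.x; c < C;
       c += gridDim.x * blockDim.x) {
    for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size;
         j += gridDim.y * blockDim.y) {
      atomicAdd(&hits[c], 1);  // each (c, j) pair is visited exactly once
    }
  }
}

int main() {
  const int C = 5, inner_size = 1000;
  int *hits;
  cudaMallocManaged(&hits, C * sizeof(int));
  for (int c = 0; c < C; ++c) hits[c] = 0;
  dim3 block(4, 64), grid(2, 8);  // deliberately smaller than (C, inner_size)
  Visit2D<<<grid, block>>>(C, inner_size, hits);
  cudaDeviceSynchronize();
  for (int c = 0; c < C; ++c)
    printf("channel %d visited %d times (expected %d)\n", c, hits[c], inner_size);
  cudaFree(hits);
  return 0;
}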
@@ -409,11 +527,48 @@ void SyncBatchNormGradFunctor(
       KeBackwardLocalStats<T, threads, DataLayout::kNCHW>
           <<<grid, threads, 0, stream>>>(
               dy_d, x_d, saved_mean_ptr, N, fsize, C, stats);
     } else {
-      KeBackwardLocalStats<T, threads, DataLayout::kNHWC>
-          <<<grid, threads, 0, stream>>>(
-              dy_d, x_d, saved_mean_ptr, N, fsize, C, stats);
+      if (x_dims.size() == 2 && N >= 65535) {
+        dim3 block;
+        dim3 grid;
+        const int block_size = 512;
+
+        // init intermediate storage
+        DenseTensor block_data_tensor;
+        DenseTensor flag_tensor;
+        BatchNormParamType<T> *block_data_ptr = nullptr;
+        int *flag_ptr = nullptr;
+
+        funcs::SetLaunchConfigInfoForChannelLast<T, BatchNormParamType<T>>(
+            ctx,
+            &block_data_tensor,
+            &flag_tensor,
+            &block_data_ptr,
+            &flag_ptr,
+            N,
+            H,
+            W,
+            D,
+            C,
+            block_size,
+            &block,
+            &grid);
+        KeBackwardLocalStats2D<T, block_size, DataLayout::kNHWC>
+            <<<grid, block, 0, stream>>>(
+                dy_d, x_d, saved_mean_ptr, N, fsize, C,
+                block_data_ptr, flag_ptr, stats);
+      } else {
+        KeBackwardLocalStats<T, threads, DataLayout::kNHWC>
+            <<<grid, threads, 0, stream>>>(
+                dy_d, x_d, saved_mean_ptr, N, fsize, C, stats);
+      }
     }
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     int global_gid = 0;
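For 2-D (N, C) inputs with N >= 65535, SyncBatchNormGradFunctor switches to the 2D kernels and sizes the launch through funcs::SetLaunchConfigInfoForChannelLast with block_size = 512, MAX_GRID_SIZE = 128 and WARP_SIZE = 32. Below is a worked example of that arithmetic with made-up sizes; LastPow2 is a local stand-in for phi::funcs::details::GetLastPow2 and is assumed to return the largest power of two not exceeding its argument.

// Worked example (made-up sizes) of the block/grid arithmetic performed by
// funcs::SetLaunchConfigInfoForChannelLast.
#include <algorithm>
#include <cstdio>

static int LastPow2(int n) {
  int p = 1;
  while (p * 2 <= n) p *= 2;
  return p;
}

int main() {
  const int N = 100000, H = 1, W = 1, D = 1, C = 64;  // a (N, C) input
  const int block_size = 512, WARP_SIZE = 32, MAX_GRID_SIZE = 128;

  int block_x = std::min(LastPow2(C), WARP_SIZE);         // min(64, 32) = 32
  int block_y = std::min(LastPow2(N * H * W * D / 16),
                         block_size / block_x);           // min(4096, 16) = 16
  if (block_x * block_y != block_size) {
    block_x = std::min(LastPow2(C), block_size / block_y);
  }
  int grid_x = (C + block_x - 1) / block_x;               // 2
  int grid_y = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
                        MAX_GRID_SIZE);                   // min(391, 128) = 128
  printf("block = (%d, %d), grid = (%d, %d)\n", block_x, block_y, grid_x, grid_y);
  return 0;
}

With these numbers grid_y > 1, so the intermediate storage is allocated: 2 * C * grid_y = 16384 staging values and grid_x = 2 flag counters, zero-initialized before the kernel launch.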
@@ -476,8 +631,48 @@ void SyncBatchNormGradFunctor(
       }
     } else {
       if (d_scale && d_bias) {
+        if (x_dims.size() == 2 && N >= 65535) {
+          dim3 block;
+          dim3 grid;
+          const int block_size = 512;
+
+          // init intermediate storage
+          DenseTensor block_data_tensor;
+          DenseTensor flag_tensor;
+          BatchNormParamType<T> *block_data_ptr = nullptr;
+          int *flag_ptr = nullptr;
+
+          funcs::SetLaunchConfigInfoForChannelLast<T, BatchNormParamType<T>>(
+              ctx,
+              &block_data_tensor,
+              &flag_tensor,
+              &block_data_ptr,
+              &flag_ptr,
+              N,
+              H,
+              W,
+              D,
+              C,
+              block_size,
+              &block,
+              &grid);
+          KeBNBackwardScaleBias2D<T, block_size, DataLayout::kNHWC>
+              <<<grid, block, 0, stream>>>(
+                  dy_d,
+                  x_d,
+                  saved_mean_ptr,
+                  saved_inv_var,
+                  epsilon,
+                  N,
+                  C,
+                  fsize,
+                  block_data_ptr,
+                  flag_ptr,
+                  d_scale->data<BatchNormParamType<T>>(),
+                  d_bias->data<BatchNormParamType<T>>());
+        } else {
         KeBNBackwardScaleBias<T, threads, DataLayout::kNHWC>
             <<<grid, threads, 0, stream>>>(dy_d,
                                            x_d,
                                            saved_mean_ptr,
                                            saved_inv_var,

@@ -488,6 +683,7 @@ void SyncBatchNormGradFunctor(
                                            d_scale->data<BatchNormParamType<T>>(),
                                            d_bias->data<BatchNormParamType<T>>());
+        }
       }
     }
     if (d_x) {
       KeBNBackwardData<T, DataLayout::kNHWC><<<grid2, block, 0, stream>>>(
           dy_d,