Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9a4acfee
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
9a4acfee
编写于
1月 31, 2023
作者:
Z
zhangkaihuo
提交者:
GitHub
1月 31, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize 2D sync_batch_norm (#49663)
上级
118aee6f
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
388 addition
and
253 deletion
+388
-253
paddle/phi/kernels/funcs/norm_utils.cu.h
paddle/phi/kernels/funcs/norm_utils.cu.h
+118
-2
paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
+53
-176
paddle/phi/kernels/gpu/batch_norm_kernel.cu
paddle/phi/kernels/gpu/batch_norm_kernel.cu
+16
-70
paddle/phi/kernels/gpu/sync_batch_norm_utils.h
paddle/phi/kernels/gpu/sync_batch_norm_utils.h
+201
-5
未找到文件。
paddle/phi/kernels/funcs/norm_utils.cu.h
浏览文件 @
9a4acfee
...
@@ -26,6 +26,7 @@ namespace cub = hipcub;
...
@@ -26,6 +26,7 @@ namespace cub = hipcub;
#endif
#endif
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#ifdef __HIPCC__
#ifdef __HIPCC__
#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
#define LAUNCH_BOUNDS(BlockDim) __launch_bounds__(BlockDim)
...
@@ -36,8 +37,6 @@ namespace cub = hipcub;
...
@@ -36,8 +37,6 @@ namespace cub = hipcub;
namespace
phi
{
namespace
phi
{
namespace
funcs
{
namespace
funcs
{
using
DataLayout
=
phi
::
DataLayout
;
// math: dx = scale * ((x - mean) * inv_var / NxHxW * (np.mean(ddx,
// math: dx = scale * ((x - mean) * inv_var / NxHxW * (np.mean(ddx,
// axis=(n,h,w)) *
// axis=(n,h,w)) *
// np.sum(dy, axis=(n,h,w)) -
// np.sum(dy, axis=(n,h,w)) -
...
@@ -670,5 +669,122 @@ void NormDoubleGradFunctor(const DeviceContext &ctx,
...
@@ -670,5 +669,122 @@ void NormDoubleGradFunctor(const DeviceContext &ctx,
}
}
}
}
}
}
template
<
typename
T
,
typename
BnT
>
__device__
__forceinline__
void
BlockReduceByVetical
(
BnT
x_sum
,
BnT
x_square_sum
,
BnT
*
smem_sum
,
BnT
*
smem_square_sum
,
BnT
*
x_sum_out
,
BnT
*
x_square_sum_out
)
{
int
tid
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
#pragma unroll
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
>>=
1
)
{
if
(
threadIdx
.
y
<
offset
*
2
)
{
smem_sum
[
tid
]
=
x_sum
;
smem_square_sum
[
tid
]
=
x_square_sum
;
}
__syncthreads
();
if
(
threadIdx
.
y
<
offset
)
{
int
pair_tid
=
tid
+
offset
*
blockDim
.
x
;
x_sum
+=
smem_sum
[
pair_tid
];
x_square_sum
+=
smem_square_sum
[
pair_tid
];
}
}
if
(
threadIdx
.
y
==
0
)
{
*
x_sum_out
=
x_sum
;
*
x_square_sum_out
=
x_square_sum
;
}
}
template
<
typename
T
,
typename
BnT
>
__device__
__forceinline__
void
ReduceSumPost
(
const
int
C
,
// channels
const
int
c
,
// channel index
BnT
*
sum1
,
BnT
*
sum2
,
bool
*
is_last_block_done
,
BnT
*
cache1
,
BnT
*
cache2
,
BnT
*
block_data_ptr
,
int
*
flag_ptr
)
{
volatile
BnT
*
staging_sum
=
block_data_ptr
;
volatile
BnT
*
staging_sum2
=
&
block_data_ptr
[
C
*
gridDim
.
y
];
// write block data to global memory
if
(
threadIdx
.
y
==
0
)
{
staging_sum
[
c
+
blockIdx
.
y
*
C
]
=
*
sum1
;
staging_sum2
[
c
+
blockIdx
.
y
*
C
]
=
*
sum2
;
}
// make sure write is visible to all blocks
__threadfence
();
__syncthreads
();
// mark block done
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
int
old
=
atomicAdd
(
&
flag_ptr
[
blockIdx
.
x
],
1
);
*
is_last_block_done
=
(
old
==
(
gridDim
.
y
-
1
));
}
__syncthreads
();
if
(
*
is_last_block_done
)
{
*
sum1
=
static_cast
<
BnT
>
(
0
);
*
sum2
=
static_cast
<
BnT
>
(
0
);
// thread sum
for
(
int
y
=
threadIdx
.
y
;
y
<
gridDim
.
y
;
y
+=
blockDim
.
y
)
{
*
sum1
+=
staging_sum
[
c
+
y
*
C
];
*
sum2
+=
staging_sum2
[
c
+
y
*
C
];
}
// vertical block sum
funcs
::
BlockReduceByVetical
<
T
,
BnT
>
(
*
sum1
,
*
sum2
,
&
cache1
[
0
],
&
cache2
[
0
],
sum1
,
sum2
);
}
}
template
<
typename
T
,
typename
BnT
,
typename
Context
>
void
SetLaunchConfigInfoForChannelLast
(
const
Context
&
ctx
,
DenseTensor
*
block_data_tensor
,
DenseTensor
*
flag_tensor
,
BnT
**
block_data_ptr
,
int
**
flag_ptr
,
const
int
N
,
const
int
H
,
const
int
W
,
const
int
D
,
const
int
C
,
const
int
block_size
,
dim3
*
block
,
dim3
*
grid
)
{
const
int
MAX_GRID_SIZE
=
128
;
const
int
WARP_SIZE
=
32
;
int
block_x
=
std
::
min
(
phi
::
funcs
::
details
::
GetLastPow2
(
C
),
WARP_SIZE
);
int
block_y
=
std
::
min
(
phi
::
funcs
::
details
::
GetLastPow2
(
N
*
H
*
W
*
D
/
16
),
block_size
/
block_x
);
if
(
block_x
*
block_y
!=
block_size
)
{
block_x
=
std
::
min
(
phi
::
funcs
::
details
::
GetLastPow2
(
C
),
block_size
/
block_y
);
}
int
grid_x
=
(
C
+
block_x
-
1
)
/
block_x
;
int
grid_y
=
std
::
min
((
N
*
H
*
W
*
D
+
block_y
*
16
-
1
)
/
(
block_y
*
16
),
MAX_GRID_SIZE
);
block
->
x
=
block_x
;
block
->
y
=
block_y
;
grid
->
x
=
grid_x
;
grid
->
y
=
grid_y
;
if
(
grid
->
y
>
1
)
{
*
block_data_tensor
=
phi
::
Empty
<
BnT
,
Context
>
(
ctx
,
{
2
*
C
*
grid
->
y
});
*
flag_tensor
=
phi
::
Empty
<
int
,
Context
>
(
ctx
,
{
grid
->
x
});
*
block_data_ptr
=
block_data_tensor
->
data
<
BnT
>
();
*
flag_ptr
=
flag_tensor
->
data
<
int
>
();
funcs
::
SetConstant
<
Context
,
int
>
set_zero
;
set_zero
(
ctx
,
flag_tensor
,
static_cast
<
int
>
(
0
));
}
}
}
// namespace funcs
}
// namespace funcs
}
// namespace phi
}
// namespace phi
paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
浏览文件 @
9a4acfee
...
@@ -245,34 +245,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
...
@@ -245,34 +245,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
}
}
}
}
template
<
typename
T
>
__device__
__forceinline__
void
BlockReduceByVetical
(
BatchNormParamType
<
T
>
x_sum
,
BatchNormParamType
<
T
>
x_square_sum
,
BatchNormParamType
<
T
>
*
smem_sum
,
BatchNormParamType
<
T
>
*
smem_square_sum
,
BatchNormParamType
<
T
>
*
x_sum_out
,
BatchNormParamType
<
T
>
*
x_square_sum_out
)
{
int
tid
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
#pragma unroll
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
>>=
1
)
{
if
(
threadIdx
.
y
<
offset
*
2
)
{
smem_sum
[
tid
]
=
x_sum
;
smem_square_sum
[
tid
]
=
x_square_sum
;
}
__syncthreads
();
if
(
threadIdx
.
y
<
offset
)
{
int
pair_tid
=
tid
+
offset
*
blockDim
.
x
;
x_sum
+=
smem_sum
[
pair_tid
];
x_square_sum
+=
smem_square_sum
[
pair_tid
];
}
}
if
(
threadIdx
.
y
==
0
)
{
*
x_sum_out
=
x_sum
;
*
x_square_sum_out
=
x_square_sum
;
}
}
template
<
typename
T
,
int
BlockDim
>
template
<
typename
T
,
int
BlockDim
>
static
__global__
void
BNBackward2DChannelLastStage1
(
static
__global__
void
BNBackward2DChannelLastStage1
(
const
T
*
x
,
const
T
*
x
,
...
@@ -309,7 +281,7 @@ static __global__ void BNBackward2DChannelLastStage1(
...
@@ -309,7 +281,7 @@ static __global__ void BNBackward2DChannelLastStage1(
}
}
// vertical block sum
// vertical block sum
BlockReduceByVetical
<
T
>
(
x_sum
,
funcs
::
BlockReduceByVetical
<
T
,
BatchNormParamType
<
T
>
>
(
x_sum
,
x_square_sum
,
x_square_sum
,
&
smem_sum
[
0
],
&
smem_sum
[
0
],
&
smem_square_sum
[
0
],
&
smem_square_sum
[
0
],
...
@@ -317,45 +289,17 @@ static __global__ void BNBackward2DChannelLastStage1(
...
@@ -317,45 +289,17 @@ static __global__ void BNBackward2DChannelLastStage1(
&
x_square_sum
);
&
x_square_sum
);
if
(
gridDim
.
y
>
1
)
{
if
(
gridDim
.
y
>
1
)
{
volatile
BatchNormParamType
<
T
>
*
staging_sum
=
block_data_ptr
;
volatile
BatchNormParamType
<
T
>
*
staging_square_sum
=
&
block_data_ptr
[
C
*
gridDim
.
y
];
// write block data to global memory
if
(
threadIdx
.
y
==
0
)
{
staging_sum
[
i
+
blockIdx
.
y
*
C
]
=
x_sum
;
staging_square_sum
[
i
+
blockIdx
.
y
*
C
]
=
x_square_sum
;
}
// make sure write is visible to all blocks
__threadfence
();
__syncthreads
();
__shared__
bool
is_last_block_done
;
__shared__
bool
is_last_block_done
;
// mark block done
funcs
::
ReduceSumPost
<
T
,
BatchNormParamType
<
T
>>
(
C
,
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
i
,
int
old
=
atomicAdd
(
&
flag_ptr
[
blockIdx
.
x
],
1
);
is_last_block_done
=
(
old
==
(
gridDim
.
y
-
1
));
}
__syncthreads
();
if
(
is_last_block_done
)
{
x_sum
=
static_cast
<
BatchNormParamType
<
T
>>
(
0
);
x_square_sum
=
static_cast
<
BatchNormParamType
<
T
>>
(
0
);
// thread sum
for
(
int
y
=
threadIdx
.
y
;
y
<
gridDim
.
y
;
y
+=
blockDim
.
y
)
{
x_sum
+=
staging_sum
[
i
+
y
*
C
];
x_square_sum
+=
staging_square_sum
[
i
+
y
*
C
];
}
// vertical block sum
BlockReduceByVetical
<
T
>
(
x_sum
,
x_square_sum
,
&
smem_sum
[
0
],
&
smem_square_sum
[
0
],
&
x_sum
,
&
x_sum
,
&
x_square_sum
);
&
x_square_sum
,
&
is_last_block_done
,
smem_sum
,
smem_square_sum
,
block_data_ptr
,
flag_ptr
);
if
(
is_last_block_done
)
{
// final compute
// final compute
if
(
threadIdx
.
y
==
0
)
{
if
(
threadIdx
.
y
==
0
)
{
BatchNormParamType
<
T
>
compute_mean_val
=
x_sum
/
inner_size
;
BatchNormParamType
<
T
>
compute_mean_val
=
x_sum
/
inner_size
;
...
@@ -417,45 +361,21 @@ static __global__ void BNBackward2DChannelLastStage2(
...
@@ -417,45 +361,21 @@ static __global__ void BNBackward2DChannelLastStage2(
}
}
// vertical block sum
// vertical block sum
BlockReduceByVetical
<
T
>
(
funcs
::
BlockReduceByVetical
<
T
,
BatchNormParamType
<
T
>
>
(
ds_sum
,
db_sum
,
&
smem_ds_sum
[
0
],
&
smem_db_sum
[
0
],
&
ds_sum
,
&
db_sum
);
ds_sum
,
db_sum
,
&
smem_ds_sum
[
0
],
&
smem_db_sum
[
0
],
&
ds_sum
,
&
db_sum
);
if
(
gridDim
.
y
>
1
)
{
if
(
gridDim
.
y
>
1
)
{
volatile
BatchNormParamType
<
T
>
*
staging_ds_sum
=
block_data_ptr
;
volatile
BatchNormParamType
<
T
>
*
staging_db_sum
=
&
block_data_ptr
[
C
*
gridDim
.
y
];
// write block data to global memory
if
(
threadIdx
.
y
==
0
)
{
staging_ds_sum
[
i
+
blockIdx
.
y
*
C
]
=
ds_sum
;
staging_db_sum
[
i
+
blockIdx
.
y
*
C
]
=
db_sum
;
}
// make sure write is visible to all blocks
__threadfence
();
__syncthreads
();
__shared__
bool
is_last_block_done
;
__shared__
bool
is_last_block_done
;
// mark block done
funcs
::
ReduceSumPost
<
T
,
BatchNormParamType
<
T
>>
(
C
,
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
i
,
int
old
=
atomicAdd
(
&
flag_ptr
[
blockIdx
.
x
],
1
);
&
ds_sum
,
is_last_block_done
=
(
old
==
(
gridDim
.
y
-
1
));
&
db_sum
,
}
&
is_last_block_done
,
smem_ds_sum
,
__syncthreads
();
smem_db_sum
,
block_data_ptr
,
flag_ptr
);
if
(
is_last_block_done
)
{
if
(
is_last_block_done
)
{
ds_sum
=
static_cast
<
BatchNormParamType
<
T
>>
(
0
);
db_sum
=
static_cast
<
BatchNormParamType
<
T
>>
(
0
);
// thread sum
for
(
int
y
=
threadIdx
.
y
;
y
<
gridDim
.
y
;
y
+=
blockDim
.
y
)
{
ds_sum
+=
staging_ds_sum
[
i
+
y
*
C
];
db_sum
+=
staging_db_sum
[
i
+
y
*
C
];
}
// vertical block sum
BlockReduceByVetical
<
T
>
(
ds_sum
,
db_sum
,
&
smem_ds_sum
[
0
],
&
smem_db_sum
[
0
],
&
ds_sum
,
&
db_sum
);
// final compute
// final compute
if
(
threadIdx
.
y
==
0
)
{
if
(
threadIdx
.
y
==
0
)
{
dscale
[
i
]
=
ds_sum
*
inv_var_val
;
dscale
[
i
]
=
ds_sum
*
inv_var_val
;
...
@@ -563,51 +483,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
...
@@ -563,51 +483,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
}
}
}
}
template
<
typename
T
,
typename
Context
>
void
SetLaunchConfigInfoForChannelLast
(
const
Context
&
ctx
,
DenseTensor
*
block_data_tensor
,
DenseTensor
*
flag_tensor
,
BatchNormParamType
<
T
>
**
block_data_ptr
,
int
**
flag_ptr
,
const
int
N
,
const
int
H
,
const
int
W
,
const
int
D
,
const
int
C
,
const
int
block_size
,
dim3
*
block
,
dim3
*
grid
)
{
const
int
MAX_GRID_SIZE
=
128
;
const
int
WARP_SIZE
=
32
;
int
block_x
=
std
::
min
(
phi
::
funcs
::
details
::
GetLastPow2
(
C
),
WARP_SIZE
);
int
block_y
=
std
::
min
(
phi
::
funcs
::
details
::
GetLastPow2
(
N
*
H
*
W
*
D
/
16
),
block_size
/
block_x
);
if
(
block_x
*
block_y
!=
block_size
)
{
block_x
=
std
::
min
(
phi
::
funcs
::
details
::
GetLastPow2
(
C
),
block_size
/
block_y
);
}
int
grid_x
=
(
C
+
block_x
-
1
)
/
block_x
;
int
grid_y
=
std
::
min
((
N
*
H
*
W
*
D
+
block_y
*
16
-
1
)
/
(
block_y
*
16
),
MAX_GRID_SIZE
);
block
->
x
=
block_x
;
block
->
y
=
block_y
;
grid
->
x
=
grid_x
;
grid
->
y
=
grid_y
;
if
(
grid
->
y
>
1
)
{
*
block_data_tensor
=
phi
::
Empty
<
BatchNormParamType
<
T
>
,
Context
>
(
ctx
,
{
2
*
C
*
grid
->
y
});
*
flag_tensor
=
phi
::
Empty
<
int
,
Context
>
(
ctx
,
{
grid
->
x
});
*
block_data_ptr
=
block_data_tensor
->
data
<
BatchNormParamType
<
T
>>
();
*
flag_ptr
=
flag_tensor
->
data
<
int
>
();
funcs
::
SetConstant
<
Context
,
int
>
set_zero
;
set_zero
(
ctx
,
flag_tensor
,
static_cast
<
int
>
(
0
));
}
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
BatchNormGradRawKernel
(
const
Context
&
ctx
,
void
BatchNormGradRawKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
x
,
...
@@ -931,7 +806,8 @@ void BatchNormGradRawKernel(const Context &ctx,
...
@@ -931,7 +806,8 @@ void BatchNormGradRawKernel(const Context &ctx,
BatchNormParamType
<
T
>
*
block_data_ptr
=
nullptr
;
BatchNormParamType
<
T
>
*
block_data_ptr
=
nullptr
;
int
*
flag_ptr
=
nullptr
;
int
*
flag_ptr
=
nullptr
;
SetLaunchConfigInfoForChannelLast
<
T
>
(
ctx
,
funcs
::
SetLaunchConfigInfoForChannelLast
<
T
,
BatchNormParamType
<
T
>>
(
ctx
,
&
block_data_tensor
,
&
block_data_tensor
,
&
flag_tensor
,
&
flag_tensor
,
&
block_data_ptr
,
&
block_data_ptr
,
...
@@ -1294,7 +1170,8 @@ void BatchNormGradRawKernel(const Context &ctx,
...
@@ -1294,7 +1170,8 @@ void BatchNormGradRawKernel(const Context &ctx,
BatchNormParamType
<
T
>
*
block_data_ptr
=
nullptr
;
BatchNormParamType
<
T
>
*
block_data_ptr
=
nullptr
;
int
*
flag_ptr
=
nullptr
;
int
*
flag_ptr
=
nullptr
;
SetLaunchConfigInfoForChannelLast
<
T
>
(
ctx
,
funcs
::
SetLaunchConfigInfoForChannelLast
<
T
,
BatchNormParamType
<
T
>>
(
ctx
,
&
block_data_tensor
,
&
block_data_tensor
,
&
flag_tensor
,
&
flag_tensor
,
&
block_data_ptr
,
&
block_data_ptr
,
...
...
paddle/phi/kernels/gpu/batch_norm_kernel.cu
浏览文件 @
9a4acfee
...
@@ -30,6 +30,7 @@ namespace cub = hipcub;
...
@@ -30,6 +30,7 @@ namespace cub = hipcub;
#include "paddle/phi/kernels/batch_norm_kernel.h"
#include "paddle/phi/kernels/batch_norm_kernel.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/norm_utils.cu.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
...
@@ -171,34 +172,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
...
@@ -171,34 +172,6 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
}
}
}
}
template
<
typename
T
>
__device__
__forceinline__
void
merge_block_vertical
(
BatchNormParamType
<
T
>
x_sum
,
BatchNormParamType
<
T
>
x_square_sum
,
BatchNormParamType
<
T
>
*
smem_sum
,
BatchNormParamType
<
T
>
*
smem_square_sum
,
BatchNormParamType
<
T
>
*
x_sum_out
,
BatchNormParamType
<
T
>
*
x_square_sum_out
)
{
int
tid
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
#pragma unroll
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
>>=
1
)
{
if
(
threadIdx
.
y
<
offset
*
2
)
{
smem_sum
[
tid
]
=
x_sum
;
smem_square_sum
[
tid
]
=
x_square_sum
;
}
__syncthreads
();
if
(
threadIdx
.
y
<
offset
)
{
int
pair_tid
=
tid
+
offset
*
blockDim
.
x
;
x_sum
+=
smem_sum
[
pair_tid
];
x_square_sum
+=
smem_square_sum
[
pair_tid
];
}
}
if
(
threadIdx
.
y
==
0
)
{
*
x_sum_out
=
x_sum
;
*
x_square_sum_out
=
x_square_sum
;
}
}
template
<
typename
T
>
template
<
typename
T
>
__device__
__forceinline__
void
merge_block_horizonal
(
__device__
__forceinline__
void
merge_block_horizonal
(
BatchNormParamType
<
T
>
x_sum
,
BatchNormParamType
<
T
>
x_sum
,
...
@@ -269,7 +242,7 @@ static __global__ void BNForwardTraining2DChannelLastCompStat(
...
@@ -269,7 +242,7 @@ static __global__ void BNForwardTraining2DChannelLastCompStat(
}
}
// vertical block sum
// vertical block sum
merge_block_vertical
<
T
>
(
x_sum
,
funcs
::
BlockReduceByVetical
<
T
,
BatchNormParamType
<
T
>
>
(
x_sum
,
x_square_sum
,
x_square_sum
,
&
smem_sum
[
0
],
&
smem_sum
[
0
],
&
smem_square_sum
[
0
],
&
smem_square_sum
[
0
],
...
@@ -277,45 +250,18 @@ static __global__ void BNForwardTraining2DChannelLastCompStat(
...
@@ -277,45 +250,18 @@ static __global__ void BNForwardTraining2DChannelLastCompStat(
&
x_square_sum
);
&
x_square_sum
);
if
(
gridDim
.
y
>
1
)
{
if
(
gridDim
.
y
>
1
)
{
volatile
BatchNormParamType
<
T
>
*
staging_sum
=
block_data_ptr
;
volatile
BatchNormParamType
<
T
>
*
staging_square_sum
=
&
block_data_ptr
[
C
*
gridDim
.
y
];
// write block data to global memory
if
(
threadIdx
.
y
==
0
)
{
staging_sum
[
i
+
blockIdx
.
y
*
C
]
=
x_sum
;
staging_square_sum
[
i
+
blockIdx
.
y
*
C
]
=
x_square_sum
;
}
// make sure write is visible to all blocks
__threadfence
();
__syncthreads
();
__shared__
bool
is_last_block_done
;
__shared__
bool
is_last_block_done
;
// mark block done
funcs
::
ReduceSumPost
<
T
,
BatchNormParamType
<
T
>>
(
C
,
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
i
,
int
old
=
atomicAdd
(
&
flag_ptr
[
blockIdx
.
x
],
1
);
is_last_block_done
=
(
old
==
(
gridDim
.
y
-
1
));
}
__syncthreads
();
if
(
is_last_block_done
)
{
x_sum
=
static_cast
<
BatchNormParamType
<
T
>>
(
0
);
x_square_sum
=
static_cast
<
BatchNormParamType
<
T
>>
(
0
);
// thread sum
for
(
int
y
=
threadIdx
.
y
;
y
<
gridDim
.
y
;
y
+=
blockDim
.
y
)
{
x_sum
+=
staging_sum
[
i
+
y
*
C
];
x_square_sum
+=
staging_square_sum
[
i
+
y
*
C
];
}
// vertical block sum
merge_block_vertical
<
T
>
(
x_sum
,
x_square_sum
,
&
smem_sum
[
0
],
&
smem_square_sum
[
0
],
&
x_sum
,
&
x_sum
,
&
x_square_sum
);
&
x_square_sum
,
&
is_last_block_done
,
smem_sum
,
smem_square_sum
,
block_data_ptr
,
flag_ptr
);
if
(
is_last_block_done
)
{
// final compute
// final compute
if
(
threadIdx
.
y
==
0
)
{
if
(
threadIdx
.
y
==
0
)
{
BatchNormParamType
<
T
>
compute_mean_val
=
x_sum
/
inner_size
;
BatchNormParamType
<
T
>
compute_mean_val
=
x_sum
/
inner_size
;
...
...
paddle/phi/kernels/gpu/sync_batch_norm_utils.h
浏览文件 @
9a4acfee
...
@@ -34,6 +34,7 @@ namespace cub = hipcub;
...
@@ -34,6 +34,7 @@ namespace cub = hipcub;
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/kernels/funcs/norm_utils.cu.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"
#include "paddle/phi/kernels/funcs/norm_utils.h"
namespace
phi
{
namespace
phi
{
...
@@ -168,6 +169,61 @@ __global__ void KeBackwardLocalStats(const T *dy,
...
@@ -168,6 +169,61 @@ __global__ void KeBackwardLocalStats(const T *dy,
}
}
}
}
template
<
typename
T
,
const
int
BlockDim
,
DataLayout
layout
>
__global__
void
KeBackwardLocalStats2D
(
const
T
*
dy
,
const
T
*
x
,
const
BatchNormParamType
<
T
>
*
means
,
int
N
,
int
M
,
int
C
,
BatchNormParamType
<
T
>
*
block_data_ptr
,
int
*
flag_ptr
,
BatchNormParamType
<
T
>
*
sum_dy_prod
)
{
__shared__
BatchNormParamType
<
T
>
smem_sum
[
BlockDim
];
__shared__
BatchNormParamType
<
T
>
smem_square_sum
[
BlockDim
];
for
(
int
k
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
k
<
C
;
k
+=
gridDim
.
x
*
blockDim
.
x
)
{
BatchNormParamType
<
T
>
sum1
=
0.
;
BatchNormParamType
<
T
>
sum2
=
0.
;
auto
mean
=
means
[
k
];
for
(
int
i
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
i
<
N
*
M
;
i
+=
gridDim
.
y
*
blockDim
.
y
)
{
int
id
=
layout
==
DataLayout
::
kNCHW
?
(
i
/
M
)
*
C
*
M
+
k
*
M
+
i
%
M
:
i
*
C
+
k
;
auto
g
=
static_cast
<
BatchNormParamType
<
T
>>
(
dy
[
id
]);
sum1
+=
g
;
auto
x_i
=
static_cast
<
BatchNormParamType
<
T
>>
(
x
[
id
]);
sum2
+=
g
*
(
x_i
-
mean
);
}
funcs
::
BlockReduceByVetical
<
T
,
BatchNormParamType
<
T
>>
(
sum1
,
sum2
,
&
smem_sum
[
0
],
&
smem_square_sum
[
0
],
&
sum1
,
&
sum2
);
if
(
gridDim
.
y
>
1
)
{
__shared__
bool
is_last_block_done
;
funcs
::
ReduceSumPost
<
T
,
BatchNormParamType
<
T
>>
(
C
,
k
,
&
sum1
,
&
sum2
,
&
is_last_block_done
,
smem_sum
,
smem_square_sum
,
block_data_ptr
,
flag_ptr
);
if
(
is_last_block_done
)
{
// final compute
if
(
threadIdx
.
y
==
0
)
{
sum_dy_prod
[
k
]
=
sum1
;
sum_dy_prod
[
k
+
C
]
=
sum2
;
}
}
}
}
if
(
blockIdx
.
y
==
0
&&
blockIdx
.
x
==
0
&&
threadIdx
.
y
==
0
&&
threadIdx
.
x
==
0
)
{
sum_dy_prod
[
2
*
C
]
=
1.0
;
}
}
template
<
typename
T
,
int
BlockDim
,
DataLayout
layout
>
template
<
typename
T
,
int
BlockDim
,
DataLayout
layout
>
static
__global__
void
KeBNBackwardScaleBias
(
static
__global__
void
KeBNBackwardScaleBias
(
const
T
*
dy
,
const
T
*
dy
,
...
@@ -213,6 +269,68 @@ static __global__ void KeBNBackwardScaleBias(
...
@@ -213,6 +269,68 @@ static __global__ void KeBNBackwardScaleBias(
}
}
}
}
template
<
typename
T
,
int
BlockDim
,
DataLayout
layout
>
static
__global__
void
KeBNBackwardScaleBias2D
(
const
T
*
dy
,
const
T
*
x
,
const
BatchNormParamType
<
T
>
*
mean
,
const
BatchNormParamType
<
T
>
*
inv_variance
,
const
double
epsilon
,
const
int
N
,
const
int
C
,
const
int
HxW
,
BatchNormParamType
<
T
>
*
block_data_ptr
,
int
*
flag_ptr
,
BatchNormParamType
<
T
>
*
dscale
,
BatchNormParamType
<
T
>
*
dbias
)
{
const
int
outer_size
=
C
;
const
int
inner_size
=
N
*
HxW
;
__shared__
BatchNormParamType
<
T
>
smem_sum
[
BlockDim
];
__shared__
BatchNormParamType
<
T
>
smem_square_sum
[
BlockDim
];
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
outer_size
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
BatchNormParamType
<
T
>
ds_sum
=
0.
;
BatchNormParamType
<
T
>
db_sum
=
0.
;
auto
inv_var_i
=
inv_variance
[
i
];
auto
mean_i
=
mean
[
i
];
for
(
int
j
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
j
<
inner_size
;
j
+=
gridDim
.
y
*
blockDim
.
y
)
{
const
int
id
=
layout
==
DataLayout
::
kNCHW
?
((
j
/
HxW
)
*
C
+
i
)
*
HxW
+
(
j
%
HxW
)
:
j
*
outer_size
+
i
;
auto
x_i
=
static_cast
<
BatchNormParamType
<
T
>>
(
x
[
id
]);
auto
dy_i
=
static_cast
<
BatchNormParamType
<
T
>>
(
dy
[
id
]);
ds_sum
+=
dy_i
*
(
x_i
-
mean_i
);
db_sum
+=
dy_i
;
}
funcs
::
BlockReduceByVetical
<
T
,
BatchNormParamType
<
T
>>
(
ds_sum
,
db_sum
,
&
smem_sum
[
0
],
&
smem_square_sum
[
0
],
&
ds_sum
,
&
db_sum
);
if
(
gridDim
.
y
>
1
)
{
__shared__
bool
is_last_block_done
;
funcs
::
ReduceSumPost
<
T
,
BatchNormParamType
<
T
>>
(
C
,
i
,
&
ds_sum
,
&
db_sum
,
&
is_last_block_done
,
smem_sum
,
smem_square_sum
,
block_data_ptr
,
flag_ptr
);
if
(
is_last_block_done
)
{
// final compute
if
(
threadIdx
.
y
==
0
)
{
dscale
[
i
]
=
ds_sum
*
inv_var_i
;
dbias
[
i
]
=
db_sum
;
}
}
}
}
}
template
<
typename
T
,
DataLayout
layout
>
template
<
typename
T
,
DataLayout
layout
>
static
__global__
void
KeBNRestoreData
(
T
*
x
,
static
__global__
void
KeBNRestoreData
(
T
*
x
,
const
BatchNormParamType
<
T
>
*
scale
,
const
BatchNormParamType
<
T
>
*
scale
,
...
@@ -409,11 +527,48 @@ void SyncBatchNormGradFunctor(
...
@@ -409,11 +527,48 @@ void SyncBatchNormGradFunctor(
KeBackwardLocalStats
<
T
,
threads
,
DataLayout
::
kNCHW
>
KeBackwardLocalStats
<
T
,
threads
,
DataLayout
::
kNCHW
>
<<<
grid
,
threads
,
0
,
stream
>>>
(
<<<
grid
,
threads
,
0
,
stream
>>>
(
dy_d
,
x_d
,
saved_mean_ptr
,
N
,
fsize
,
C
,
stats
);
dy_d
,
x_d
,
saved_mean_ptr
,
N
,
fsize
,
C
,
stats
);
}
else
{
if
(
x_dims
.
size
()
==
2
&&
N
>=
65535
)
{
dim3
block
;
dim3
grid
;
const
int
block_size
=
512
;
// init intermediate storage
DenseTensor
block_data_tensor
;
DenseTensor
flag_tensor
;
BatchNormParamType
<
T
>
*
block_data_ptr
=
nullptr
;
int
*
flag_ptr
=
nullptr
;
funcs
::
SetLaunchConfigInfoForChannelLast
<
T
,
BatchNormParamType
<
T
>>
(
ctx
,
&
block_data_tensor
,
&
flag_tensor
,
&
block_data_ptr
,
&
flag_ptr
,
N
,
H
,
W
,
D
,
C
,
block_size
,
&
block
,
&
grid
);
KeBackwardLocalStats2D
<
T
,
block_size
,
DataLayout
::
kNHWC
>
<<<
grid
,
block
,
0
,
stream
>>>
(
dy_d
,
x_d
,
saved_mean_ptr
,
N
,
fsize
,
C
,
block_data_ptr
,
flag_ptr
,
stats
);
}
else
{
}
else
{
KeBackwardLocalStats
<
T
,
threads
,
DataLayout
::
kNHWC
>
KeBackwardLocalStats
<
T
,
threads
,
DataLayout
::
kNHWC
>
<<<
grid
,
threads
,
0
,
stream
>>>
(
<<<
grid
,
threads
,
0
,
stream
>>>
(
dy_d
,
x_d
,
saved_mean_ptr
,
N
,
fsize
,
C
,
stats
);
dy_d
,
x_d
,
saved_mean_ptr
,
N
,
fsize
,
C
,
stats
);
}
}
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int
global_gid
=
0
;
int
global_gid
=
0
;
...
@@ -476,8 +631,48 @@ void SyncBatchNormGradFunctor(
...
@@ -476,8 +631,48 @@ void SyncBatchNormGradFunctor(
}
}
}
else
{
}
else
{
if
(
d_scale
&&
d_bias
)
{
if
(
d_scale
&&
d_bias
)
{
if
(
x_dims
.
size
()
==
2
&&
N
>=
65535
)
{
dim3
block
;
dim3
grid
;
const
int
block_size
=
512
;
// init intermediate storage
DenseTensor
block_data_tensor
;
DenseTensor
flag_tensor
;
BatchNormParamType
<
T
>
*
block_data_ptr
=
nullptr
;
int
*
flag_ptr
=
nullptr
;
funcs
::
SetLaunchConfigInfoForChannelLast
<
T
,
BatchNormParamType
<
T
>>
(
ctx
,
&
block_data_tensor
,
&
flag_tensor
,
&
block_data_ptr
,
&
flag_ptr
,
N
,
H
,
W
,
D
,
C
,
block_size
,
&
block
,
&
grid
);
KeBNBackwardScaleBias2D
<
T
,
block_size
,
DataLayout
::
kNHWC
>
<<<
grid
,
block
,
0
,
stream
>>>
(
dy_d
,
x_d
,
saved_mean_ptr
,
saved_inv_var
,
epsilon
,
N
,
C
,
fsize
,
block_data_ptr
,
flag_ptr
,
d_scale
->
data
<
BatchNormParamType
<
T
>>
(),
d_bias
->
data
<
BatchNormParamType
<
T
>>
());
}
else
{
KeBNBackwardScaleBias
<
T
,
threads
,
DataLayout
::
kNHWC
>
KeBNBackwardScaleBias
<
T
,
threads
,
DataLayout
::
kNHWC
>
<<<
grid
,
threads
,
0
,
stream
>>>
(
dy_d
,
<<<
grid
,
threads
,
0
,
stream
>>>
(
dy_d
,
x_d
,
x_d
,
saved_mean_ptr
,
saved_mean_ptr
,
saved_inv_var
,
saved_inv_var
,
...
@@ -488,6 +683,7 @@ void SyncBatchNormGradFunctor(
...
@@ -488,6 +683,7 @@ void SyncBatchNormGradFunctor(
d_scale
->
data
<
BatchNormParamType
<
T
>>
(),
d_scale
->
data
<
BatchNormParamType
<
T
>>
(),
d_bias
->
data
<
BatchNormParamType
<
T
>>
());
d_bias
->
data
<
BatchNormParamType
<
T
>>
());
}
}
}
if
(
d_x
)
{
if
(
d_x
)
{
KeBNBackwardData
<
T
,
DataLayout
::
kNHWC
><<<
grid2
,
block
,
0
,
stream
>>>
(
KeBNBackwardData
<
T
,
DataLayout
::
kNHWC
><<<
grid2
,
block
,
0
,
stream
>>>
(
dy_d
,
dy_d
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录