Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2e6548a9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2e6548a9
编写于
3月 02, 2022
作者:
S
sneaxiy
提交者:
GitHub
3月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
vec scale kernel (#40011)
上级
5898e9ab
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
39 addition
and
10 deletion
+39
-10
paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
...e/fluid/operators/optimizers/distributed_fused_lamb_op.cu
+39
-10
未找到文件。
paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cu
浏览文件 @
2e6548a9
...
...
@@ -304,14 +304,30 @@ struct AndFunctor {
HOSTDEVICE
bool
operator
()(
bool
x
,
bool
y
)
const
{
return
x
&&
y
;
}
};
template
<
typename
T1
,
typename
T2
>
template
<
typename
T1
,
typename
T2
,
int
VecSize
>
static
__global__
void
ScaleCUDAKernel
(
const
T1
*
__restrict__
x
,
const
T2
*
__restrict__
scale
,
T1
*
__restrict__
y
,
int
num
)
{
static_assert
(
sizeof
(
T1
)
<=
sizeof
(
T2
),
"sizeof(T1) must be not greater than sizeof(T2)."
);
T2
s
=
scale
[
0
];
CUDA_KERNEL_LOOP
(
i
,
num
)
{
int
i
=
(
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
)
*
VecSize
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
*
VecSize
;
for
(;
i
+
VecSize
<=
num
;
i
+=
stride
)
{
platform
::
AlignedVector
<
T1
,
VecSize
>
x_vec
;
platform
::
AlignedVector
<
T1
,
VecSize
>
y_vec
;
platform
::
Load
(
x
+
i
,
&
x_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
VecSize
;
++
j
)
{
y_vec
[
j
]
=
static_cast
<
T1
>
(
static_cast
<
T2
>
(
x_vec
[
j
])
*
s
);
}
platform
::
Store
(
y_vec
,
y
+
i
);
}
for
(;
i
<
num
;
++
i
)
{
y
[
i
]
=
static_cast
<
T1
>
(
static_cast
<
T2
>
(
x
[
i
])
*
s
);
}
}
...
...
@@ -396,7 +412,6 @@ static __global__ void UpdateLambMomentAndTrustRatioDivCUDAKernel(
for
(;
i
+
VecSize
<=
num
;
i
+=
stride
)
{
platform
::
AlignedVector
<
T
,
VecSize
>
param_vec
;
platform
::
AlignedVector
<
GradT
,
VecSize
>
grad_vec
;
platform
::
AlignedVector
<
T
,
VecSize
>
weight_decay_vec
;
platform
::
AlignedVector
<
T
,
VecSize
>
mom1_vec
;
platform
::
AlignedVector
<
T
,
VecSize
>
mom2_vec
;
platform
::
AlignedVector
<
T
,
VecSize
>
trust_ratio_div_vec
;
...
...
@@ -760,6 +775,24 @@ static bool CreatePreMulScaleOpIfSupported(ncclDataType_t dtype,
return
false
;
}
template
<
typename
T1
,
typename
T2
>
static
void
LaunchScaleKernel
(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
const
T1
*
x
,
const
T2
*
scale
,
T1
*
y
,
int
n
,
gpuStream_t
stream
)
{
int
vec_size
=
std
::
min
(
GetChunkedVecSize
(
x
,
0
),
GetChunkedVecSize
(
y
,
0
));
auto
config
=
platform
::
GetGpuLaunchConfig1D
(
dev_ctx
,
n
,
vec_size
);
#define PD_LAMB_VEC_SCALE_KERNEL_CASE \
do { \
ScaleCUDAKernel<T1, T2, kVecSize><<<config.block_per_grid, \
config.thread_per_block, 0, stream>>>( \
x, scale, y, n); \
} while (0)
PD_VEC_LAUNCH_KERNEL
(
vec_size
,
PD_LAMB_VEC_SCALE_KERNEL_CASE
);
#undef PD_LAMB_VEC_SCALE_KERNEL_CASE
}
template
<
typename
T
>
static
void
NCCLReduceScatterWithScale
(
const
T
*
sendbuff
,
T
*
recvbuff
,
size_t
recvcount
,
size_t
nranks
,
...
...
@@ -775,10 +808,8 @@ static void NCCLReduceScatterWithScale(
PADDLE_ENFORCE_EQ
(
nranks
,
1
,
platform
::
errors
::
InvalidArgument
(
"nranks must be 1 when scale != nullptr."
));
auto
numel
=
recvcount
*
nranks
;
auto
config
=
platform
::
GetGpuLaunchConfig1D
(
dev_ctx
,
numel
);
ScaleCUDAKernel
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
stream
>>>
(
sendbuff
,
scale
,
recvbuff
,
numel
);
LaunchScaleKernel
(
dev_ctx
,
sendbuff
,
scale
,
recvbuff
,
recvcount
*
nranks
,
stream
);
}
return
;
}
...
...
@@ -792,9 +823,7 @@ static void NCCLReduceScatterWithScale(
if
(
scale
&&
!
should_destroy_op
)
{
size_t
numel
=
recvcount
*
nranks
;
T
*
new_sendbuff
=
buffer
.
Alloc
<
T
>
(
numel
);
auto
config
=
platform
::
GetGpuLaunchConfig1D
(
dev_ctx
,
numel
);
ScaleCUDAKernel
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
stream
>>>
(
sendbuff
,
scale
,
new_sendbuff
,
numel
);
LaunchScaleKernel
(
dev_ctx
,
sendbuff
,
scale
,
new_sendbuff
,
numel
,
stream
);
sendbuff
=
new_sendbuff
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录