Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
63fd7d66
P
Paddle
项目概览
PaddlePaddle
/
Paddle
接近 2 年 前同步成功
通知
2323
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
63fd7d66
编写于
10月 14, 2021
作者:
Z
Zeng Jinle
提交者:
GitHub
10月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine merge lars (#36428)
上级
3e6d9dbb
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
24 addition
and
24 deletion
+24
-24
paddle/fluid/operators/optimizers/lars_momentum_op.cu
paddle/fluid/operators/optimizers/lars_momentum_op.cu
+24
-24
未找到文件。
paddle/fluid/operators/optimizers/lars_momentum_op.cu
浏览文件 @
63fd7d66
...
@@ -28,7 +28,7 @@ limitations under the License. */
...
@@ -28,7 +28,7 @@ limitations under the License. */
#define LARS_BLOCK_SIZE 512
#define LARS_BLOCK_SIZE 512
#endif
#endif
#define LARS_MAX_MERGED_OPS
15
0
#define LARS_MAX_MERGED_OPS
6
0
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -256,11 +256,8 @@ template <typename T, typename MT>
...
@@ -256,11 +256,8 @@ template <typename T, typename MT>
struct
LarsParamWarpper
{
struct
LarsParamWarpper
{
int64_t
numel_arr
[
LARS_MAX_MERGED_OPS
];
int64_t
numel_arr
[
LARS_MAX_MERGED_OPS
];
int
repeat_arr
[
LARS_MAX_MERGED_OPS
];
int
repeat_arr
[
LARS_MAX_MERGED_OPS
];
const
T
*
__restrict__
p_arr
[
LARS_MAX_MERGED_OPS
];
const
T
*
__restrict__
g_arr
[
LARS_MAX_MERGED_OPS
];
const
T
*
__restrict__
g_arr
[
LARS_MAX_MERGED_OPS
];
const
MT
*
__restrict__
v_arr
[
LARS_MAX_MERGED_OPS
];
const
MT
*
__restrict__
lr_arr
[
LARS_MAX_MERGED_OPS
];
const
MT
*
__restrict__
lr_arr
[
LARS_MAX_MERGED_OPS
];
const
MT
*
__restrict__
master_p_arr
[
LARS_MAX_MERGED_OPS
];
T
*
__restrict__
p_out_arr
[
LARS_MAX_MERGED_OPS
];
T
*
__restrict__
p_out_arr
[
LARS_MAX_MERGED_OPS
];
MT
*
__restrict__
v_out_arr
[
LARS_MAX_MERGED_OPS
];
MT
*
__restrict__
v_out_arr
[
LARS_MAX_MERGED_OPS
];
MT
*
__restrict__
master_p_out_arr
[
LARS_MAX_MERGED_OPS
];
MT
*
__restrict__
master_p_out_arr
[
LARS_MAX_MERGED_OPS
];
...
@@ -268,7 +265,7 @@ struct LarsParamWarpper {
...
@@ -268,7 +265,7 @@ struct LarsParamWarpper {
};
};
template
<
typename
T
,
typename
MT
>
template
<
typename
T
,
typename
MT
>
__global__
void
MergedMomentumLarsKernel
(
LarsParamWarpper
<
T
,
MT
>
*
lars_warpper
,
__global__
void
MergedMomentumLarsKernel
(
LarsParamWarpper
<
T
,
MT
>
lars_warpper
,
MT
*
__restrict__
p_buffer
,
MT
*
__restrict__
p_buffer
,
MT
*
__restrict__
g_buffer
,
MT
*
__restrict__
g_buffer
,
const
int
op_num
,
const
MT
mu
,
const
int
op_num
,
const
MT
mu
,
...
@@ -279,18 +276,18 @@ __global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT>* lars_warpper,
...
@@ -279,18 +276,18 @@ __global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT>* lars_warpper,
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
const
cooperative_groups
::
grid_group
cg
=
cooperative_groups
::
this_grid
();
const
cooperative_groups
::
grid_group
cg
=
cooperative_groups
::
this_grid
();
for
(
int
i
=
0
;
i
<
op_num
;
++
i
)
{
for
(
int
i
=
0
;
i
<
op_num
;
++
i
)
{
int
numel
=
lars_warpper
->
numel_arr
[
i
];
int
numel
=
lars_warpper
.
numel_arr
[
i
];
MT
param_norm
=
static_cast
<
MT
>
(
0
);
MT
param_norm
=
static_cast
<
MT
>
(
0
);
MT
grad_norm
=
static_cast
<
MT
>
(
0
);
MT
grad_norm
=
static_cast
<
MT
>
(
0
);
L2NormKernel
<
T
,
MT
>
(
&
cg
,
lars_warpper
->
p_arr
[
i
],
lars_warpper
->
g_arr
[
i
],
L2NormKernel
<
T
,
MT
>
(
&
cg
,
lars_warpper
.
p_out_arr
[
i
],
lars_warpper
.
g_arr
[
i
],
p_buffer
,
g_buffer
,
numel
,
lars_warpper
->
repeat_arr
[
i
],
p_buffer
,
g_buffer
,
numel
,
lars_warpper
.
repeat_arr
[
i
],
rescale_grad
,
0
,
&
param_norm
,
&
grad_norm
);
rescale_grad
,
0
,
&
param_norm
,
&
grad_norm
);
MomentumUpdate
<
T
,
MT
>
(
MomentumUpdate
<
T
,
MT
>
(
lars_warpper
->
p_arr
[
i
],
lars_warpper
->
g_arr
[
i
],
lars_warpper
.
p_out_arr
[
i
],
lars_warpper
.
g_arr
[
i
],
lars_warpper
->
v_out_arr
[
i
],
lars_warpper
->
p_out_arr
[
i
],
lars_warpper
.
v_out_arr
[
i
],
lars_warpper
.
p_out_arr
[
i
],
lars_warpper
->
v_out_arr
[
i
],
lars_warpper
->
master_p
_arr
[
i
],
lars_warpper
.
v_out_arr
[
i
],
lars_warpper
.
master_p_out
_arr
[
i
],
lars_warpper
->
master_p_out_arr
[
i
],
lars_warpper
->
lr_arr
[
i
],
mu
,
lars_warpper
.
master_p_out_arr
[
i
],
lars_warpper
.
lr_arr
[
i
],
mu
,
lars_warpper
->
weight_decay_arr
[
i
],
lars_coeff
,
epsilon
,
rescale_grad
,
lars_warpper
.
weight_decay_arr
[
i
],
lars_coeff
,
epsilon
,
rescale_grad
,
param_norm
,
grad_norm
,
tid
,
grid_stride
,
numel
,
is_amp
);
param_norm
,
grad_norm
,
tid
,
grid_stride
,
numel
,
is_amp
);
}
}
}
}
...
@@ -410,15 +407,21 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -410,15 +407,21 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
size_t
temp_numel
=
param
[
i
]
->
numel
();
size_t
temp_numel
=
param
[
i
]
->
numel
();
total_numel
+=
temp_numel
;
total_numel
+=
temp_numel
;
lars_warpper
.
numel_arr
[
i
]
=
temp_numel
;
lars_warpper
.
numel_arr
[
i
]
=
temp_numel
;
lars_warpper
.
p_arr
[
i
]
=
param
[
i
]
->
data
<
T
>
();
lars_warpper
.
g_arr
[
i
]
=
grad
[
i
]
->
data
<
T
>
();
lars_warpper
.
g_arr
[
i
]
=
grad
[
i
]
->
data
<
T
>
();
lars_warpper
.
v_arr
[
i
]
=
velocity
[
i
]
->
data
<
MT
>
();
lars_warpper
.
lr_arr
[
i
]
=
learning_rate
[
i
]
->
data
<
MT
>
();
lars_warpper
.
lr_arr
[
i
]
=
learning_rate
[
i
]
->
data
<
MT
>
();
lars_warpper
.
p_out_arr
[
i
]
=
lars_warpper
.
p_out_arr
[
i
]
=
param_out
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
param_out
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
lars_warpper
.
v_out_arr
[
i
]
=
lars_warpper
.
v_out_arr
[
i
]
=
velocity_out
[
i
]
->
mutable_data
<
MT
>
(
ctx
.
GetPlace
());
velocity_out
[
i
]
->
mutable_data
<
MT
>
(
ctx
.
GetPlace
());
lars_warpper
.
weight_decay_arr
[
i
]
=
static_cast
<
MT
>
(
weight_decay_arr
[
i
]);
lars_warpper
.
weight_decay_arr
[
i
]
=
static_cast
<
MT
>
(
weight_decay_arr
[
i
]);
PADDLE_ENFORCE_EQ
(
param
[
i
]
->
data
<
T
>
(),
lars_warpper
.
p_out_arr
[
i
],
platform
::
errors
::
InvalidArgument
(
"Input(Param) and Output(ParamOut) must be the same Tensors."
));
PADDLE_ENFORCE_EQ
(
velocity
[
i
]
->
data
<
MT
>
(),
lars_warpper
.
v_out_arr
[
i
],
platform
::
errors
::
InvalidArgument
(
"Input(Velocity) and Output(VelocityOut) must be "
"the same Tensors."
));
}
}
int64_t
avg_numel
=
total_numel
/
op_num
;
int64_t
avg_numel
=
total_numel
/
op_num
;
LarsThreadConfig
<
float
>
lars_thread_config
(
avg_numel
,
sm_num
,
LarsThreadConfig
<
float
>
lars_thread_config
(
avg_numel
,
sm_num
,
...
@@ -429,19 +432,16 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -429,19 +432,16 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
}
}
if
(
multi_precision
)
{
if
(
multi_precision
)
{
for
(
int
i
=
0
;
i
<
op_num
;
++
i
)
{
for
(
int
i
=
0
;
i
<
op_num
;
++
i
)
{
lars_warpper
.
master_p_arr
[
i
]
=
master_param
[
i
]
->
data
<
MT
>
();
lars_warpper
.
master_p_out_arr
[
i
]
=
lars_warpper
.
master_p_out_arr
[
i
]
=
master_param_out
[
i
]
->
mutable_data
<
MT
>
(
ctx
.
GetPlace
());
master_param_out
[
i
]
->
mutable_data
<
MT
>
(
ctx
.
GetPlace
());
PADDLE_ENFORCE_EQ
(
master_param
[
i
]
->
data
<
MT
>
(),
lars_warpper
.
master_p_out_arr
[
i
],
platform
::
errors
::
InvalidArgument
(
"Input(MasterParam) and Output(MasterParamOut) "
"must be the same Tensors."
));
}
}
}
}
auto
merged_buf
=
memory
::
Alloc
(
cuda_ctx
,
sizeof
(
lars_warpper
));
void
*
cuda_param
[]
=
{
reinterpret_cast
<
void
*>
(
&
lars_warpper
),
auto
*
merged_ptr
=
reinterpret_cast
<
LarsParamWarpper
<
T
,
MT
>*>
(
merged_buf
->
ptr
());
memory
::
Copy
(
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
cuda_ctx
.
GetPlace
()),
reinterpret_cast
<
void
*>
(
merged_ptr
),
platform
::
CPUPlace
(),
reinterpret_cast
<
void
*>
(
&
lars_warpper
),
sizeof
(
lars_warpper
),
cuda_ctx
.
stream
());
void
*
cuda_param
[]
=
{
reinterpret_cast
<
void
*>
(
&
merged_ptr
),
reinterpret_cast
<
void
*>
(
&
p_buffer
),
reinterpret_cast
<
void
*>
(
&
p_buffer
),
reinterpret_cast
<
void
*>
(
&
g_buffer
),
reinterpret_cast
<
void
*>
(
&
g_buffer
),
reinterpret_cast
<
void
*>
(
&
op_num
),
reinterpret_cast
<
void
*>
(
&
op_num
),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录