Commit 63fd7d66 (unverified)
Authored Oct 14, 2021 by Zeng Jinle; committed by GitHub on Oct 14, 2021
refine merge lars (#36428)
Parent 3e6d9dbb
Showing 1 changed file with 24 additions and 24 deletions.

paddle/fluid/operators/optimizers/lars_momentum_op.cu (+24 -24) @ 63fd7d66
@@ -28,7 +28,7 @@ limitations under the License. */
 #define LARS_BLOCK_SIZE 512
 #endif
-#define LARS_MAX_MERGED_OPS 150
+#define LARS_MAX_MERGED_OPS 60

 namespace paddle {
 namespace operators {
@@ -256,11 +256,8 @@ template <typename T, typename MT>
 struct LarsParamWarpper {
   int64_t numel_arr[LARS_MAX_MERGED_OPS];
   int repeat_arr[LARS_MAX_MERGED_OPS];
-  const T* __restrict__ p_arr[LARS_MAX_MERGED_OPS];
   const T* __restrict__ g_arr[LARS_MAX_MERGED_OPS];
-  const MT* __restrict__ v_arr[LARS_MAX_MERGED_OPS];
   const MT* __restrict__ lr_arr[LARS_MAX_MERGED_OPS];
-  const MT* __restrict__ master_p_arr[LARS_MAX_MERGED_OPS];
   T* __restrict__ p_out_arr[LARS_MAX_MERGED_OPS];
   MT* __restrict__ v_out_arr[LARS_MAX_MERGED_OPS];
   MT* __restrict__ master_p_out_arr[LARS_MAX_MERGED_OPS];
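Note on the shrunken constant: the wrapper is about to be passed to the kernel by value (next hunk), and CUDA caps a kernel's total parameter space at 4 KB. Dropping `p_arr`, `v_arr`, and `master_p_arr` and cutting `LARS_MAX_MERGED_OPS` from 150 to 60 plausibly serves exactly that constraint: roughly 480 B of `numel_arr`, 240 B of `repeat_arr`, five 480 B pointer arrays, and a 240 B weight-decay array come to about 3.3 KB, whereas a 150-entry wrapper would overflow the limit. A hypothetical compile-time guard (not part of this commit) could make the constraint explicit:

// Hypothetical guard (not in the commit): a kernel argument passed by value
// must fit in CUDA's 4 KB kernel-parameter space.
static_assert(sizeof(LarsParamWarpper<float, float>) <= 4096,
              "LarsParamWarpper is passed to MergedMomentumLarsKernel by "
              "value and must fit in the 4 KB kernel-parameter space");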
@@ -268,7 +265,7 @@ struct LarsParamWarpper {
 };

 template <typename T, typename MT>
-__global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT>* lars_warpper,
+__global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT> lars_warpper,
                                          MT* __restrict__ p_buffer,
                                          MT* __restrict__ g_buffer,
                                          const int op_num, const MT mu,
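This is the core of the commit: `MergedMomentumLarsKernel` now receives the wrapper by value rather than through a device pointer, so the launch no longer needs a device allocation and an explicit host-to-device copy; the CUDA runtime marshals the argument into the kernel's parameter space automatically. A minimal standalone sketch of the pattern (hypothetical names, not Paddle's API):

#include <cuda_runtime.h>

struct Params {  // small POD struct, well under the 4 KB parameter limit
  int n;
  float scale;
};

// `p` is passed by value: the runtime copies it into the kernel's parameter
// space at launch, so no cudaMalloc/cudaMemcpy for the struct is required.
__global__ void ScaleKernel(Params p, float* data) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < p.n) data[i] *= p.scale;
}

int main() {
  const int n = 1024;
  float* d_data;
  cudaMalloc(&d_data, n * sizeof(float));
  Params p{n, 2.0f};  // lives on the host stack
  ScaleKernel<<<(n + 255) / 256, 256>>>(p, d_data);
  cudaDeviceSynchronize();
  cudaFree(d_data);
  return 0;
}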
@@ -279,18 +276,18 @@ __global__ void MergedMomentumLarsKernel(LarsParamWarpper<T, MT>* lars_warpper,
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
   const cooperative_groups::grid_group cg = cooperative_groups::this_grid();
   for (int i = 0; i < op_num; ++i) {
-    int numel = lars_warpper->numel_arr[i];
+    int numel = lars_warpper.numel_arr[i];
     MT param_norm = static_cast<MT>(0);
     MT grad_norm = static_cast<MT>(0);
-    L2NormKernel<T, MT>(&cg, lars_warpper->p_arr[i], lars_warpper->g_arr[i],
-                        p_buffer, g_buffer, numel, lars_warpper->repeat_arr[i],
+    L2NormKernel<T, MT>(&cg, lars_warpper.p_out_arr[i], lars_warpper.g_arr[i],
+                        p_buffer, g_buffer, numel, lars_warpper.repeat_arr[i],
                         rescale_grad, 0, &param_norm, &grad_norm);
     MomentumUpdate<T, MT>(
-        lars_warpper->p_arr[i], lars_warpper->g_arr[i],
-        lars_warpper->v_out_arr[i], lars_warpper->p_out_arr[i],
-        lars_warpper->v_out_arr[i], lars_warpper->master_p_arr[i],
-        lars_warpper->master_p_out_arr[i], lars_warpper->lr_arr[i], mu,
-        lars_warpper->weight_decay_arr[i], lars_coeff, epsilon, rescale_grad,
+        lars_warpper.p_out_arr[i], lars_warpper.g_arr[i],
+        lars_warpper.v_out_arr[i], lars_warpper.p_out_arr[i],
+        lars_warpper.v_out_arr[i], lars_warpper.master_p_out_arr[i],
+        lars_warpper.master_p_out_arr[i], lars_warpper.lr_arr[i], mu,
+        lars_warpper.weight_decay_arr[i], lars_coeff, epsilon, rescale_grad,
         param_norm, grad_norm, tid, grid_stride, numel, is_amp);
   }
 }
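The kernel now reads parameters through `p_out_arr` and `master_p_out_arr` instead of the removed input arrays. That is only sound because the op requires Param/ParamOut (and Velocity/VelocityOut, MasterParam/MasterParamOut) to alias the same tensors, which the host code asserts in the next hunk; within a grid-stride loop each element is then read and overwritten by exactly one thread. A simplified sketch of the in-place pattern (hypothetical; not the actual `MomentumUpdate` signature):

// In-place update: `param_out` aliases the parameter input, so each thread
// reads its own element before overwriting it; no other thread touches
// index i, so no stale or torn reads occur.
template <typename T>
__device__ void InplaceUpdate(T* __restrict__ param_out,
                              const T* __restrict__ grad, T lr, int tid,
                              int grid_stride, int numel) {
  for (int i = tid; i < numel; i += grid_stride) {
    T p = param_out[i];               // read the aliased parameter
    param_out[i] = p - lr * grad[i];  // write back in place
  }
}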
@@ -410,15 +407,21 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
        size_t temp_numel = param[i]->numel();
        total_numel += temp_numel;
        lars_warpper.numel_arr[i] = temp_numel;
-       lars_warpper.p_arr[i] = param[i]->data<T>();
        lars_warpper.g_arr[i] = grad[i]->data<T>();
-       lars_warpper.v_arr[i] = velocity[i]->data<MT>();
        lars_warpper.lr_arr[i] = learning_rate[i]->data<MT>();
        lars_warpper.p_out_arr[i] =
            param_out[i]->mutable_data<T>(ctx.GetPlace());
        lars_warpper.v_out_arr[i] =
            velocity_out[i]->mutable_data<MT>(ctx.GetPlace());
        lars_warpper.weight_decay_arr[i] = static_cast<MT>(weight_decay_arr[i]);
+       PADDLE_ENFORCE_EQ(
+           param[i]->data<T>(), lars_warpper.p_out_arr[i],
+           platform::errors::InvalidArgument(
+               "Input(Param) and Output(ParamOut) must be the same Tensors."));
+       PADDLE_ENFORCE_EQ(
+           velocity[i]->data<MT>(), lars_warpper.v_out_arr[i],
+           platform::errors::InvalidArgument(
+               "Input(Velocity) and Output(VelocityOut) must be "
+               "the same Tensors."));
      }
      int64_t avg_numel = total_numel / op_num;
      LarsThreadConfig<float> lars_thread_config(avg_numel, sm_num,
@@ -429,19 +432,16 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
      }
      if (multi_precision) {
        for (int i = 0; i < op_num; ++i) {
-         lars_warpper.master_p_arr[i] = master_param[i]->data<MT>();
          lars_warpper.master_p_out_arr[i] =
              master_param_out[i]->mutable_data<MT>(ctx.GetPlace());
+         PADDLE_ENFORCE_EQ(master_param[i]->data<MT>(),
+                           lars_warpper.master_p_out_arr[i],
+                           platform::errors::InvalidArgument(
+                               "Input(MasterParam) and Output(MasterParamOut) "
+                               "must be the same Tensors."));
        }
      }
-     auto merged_buf = memory::Alloc(cuda_ctx, sizeof(lars_warpper));
-     auto* merged_ptr =
-         reinterpret_cast<LarsParamWarpper<T, MT>*>(merged_buf->ptr());
-     memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, cuda_ctx.GetPlace()),
-                  reinterpret_cast<void*>(merged_ptr), platform::CPUPlace(),
-                  reinterpret_cast<void*>(&lars_warpper), sizeof(lars_warpper),
-                  cuda_ctx.stream());
-     void* cuda_param[] = {reinterpret_cast<void*>(&merged_ptr),
+     void* cuda_param[] = {reinterpret_cast<void*>(&lars_warpper),
                           reinterpret_cast<void*>(&p_buffer),
                           reinterpret_cast<void*>(&g_buffer),
                           reinterpret_cast<void*>(&op_num),
...
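The launch site mirrors the signature change: the `memory::Alloc` + `memory::Copy` staging of the wrapper is gone, and `cuda_param` now holds the address of the host-stack struct itself, which the cooperative launch copies into the kernel's parameter space. A standalone sketch of that launch pattern (hypothetical `Wrapper`/`CoopKernel`/`Launch` names, not Paddle's wrappers):

#include <cstdint>
#include <cuda_runtime.h>

struct Wrapper {
  int64_t numel[4];
  float* data[4];
};

__global__ void CoopKernel(Wrapper w, int op_num) {
  // cooperative-groups body elided
}

void Launch(Wrapper& w, int op_num, dim3 grid, dim3 block,
            cudaStream_t stream) {
  // Each args[] entry points at a kernel argument on the host; at launch the
  // runtime copies sizeof(Wrapper) bytes from &w into the kernel's parameter
  // space, so no device buffer or explicit memcpy is needed for the struct.
  void* args[] = {reinterpret_cast<void*>(&w),
                  reinterpret_cast<void*>(&op_num)};
  cudaLaunchCooperativeKernel(reinterpret_cast<void*>(CoopKernel), grid,
                              block, args, /*sharedMemBytes=*/0, stream);
}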