Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
7541579a
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7541579a
编写于
10月 11, 2022
作者:
C
Chenxiao Niu
提交者:
GitHub
10月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[MLU] add masterparam support for mlu adamw. (#46804)
上级
28ef0fff
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
48 addition
and
5 deletion
+48
-5
paddle/fluid/operators/optimizers/adam_op_mlu.cc
paddle/fluid/operators/optimizers/adam_op_mlu.cc
+48
-5
未找到文件。
paddle/fluid/operators/optimizers/adam_op_mlu.cc
浏览文件 @
7541579a
...
...
@@ -291,11 +291,38 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
skip_update
=
skip_update_vec
[
0
];
}
bool
with_decay
=
ctx
.
Attr
<
bool
>
(
"with_decay"
);
const
bool
multi_precision
=
ctx
.
Attr
<
bool
>
(
"multi_precision"
);
auto
*
param_out
=
ctx
.
Output
<
LoDTensor
>
(
"ParamOut"
);
auto
*
master_param_out
=
ctx
.
Output
<
LoDTensor
>
(
"MasterParamOut"
);
const
auto
*
master_param
=
ctx
.
Input
<
LoDTensor
>
(
"MasterParam"
);
VLOG
(
3
)
<<
"Skip update: "
<<
skip_update
<<
", With decay: "
<<
with_decay
;
if
(
!
skip_update
&&
with_decay
)
{
if
(
ctx
.
HasInput
(
"MasterParam"
))
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Master Param is not supported on MLU"
));
auto
*
param
=
ctx
.
Input
<
LoDTensor
>
(
"Param"
);
MLUCnnlTensorDesc
param_desc
(
*
param
);
if
(
multi_precision
)
{
VLOG
(
3
)
<<
"[adamw] multi_precision, cast masterparam to param."
;
bool
has_master
=
ctx
.
HasInput
(
"MasterParam"
)
&&
ctx
.
HasOutput
(
"MasterParamOut"
);
PADDLE_ENFORCE_EQ
(
has_master
,
true
,
platform
::
errors
::
InvalidArgument
(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"
));
// cast masterparam (fp32) to param (fp16), then paramout (fp16) to
// masterparamout (fp32)
MLUCnnlTensorDesc
master_param_desc
(
*
master_param
);
cnnlCastDataType_t
cast_type
=
GetCastDataType
(
framework
::
TransToProtoVarType
(
master_param
->
dtype
()),
framework
::
TransToProtoVarType
(
param
->
dtype
()));
MLUCnnl
::
Cast
(
ctx
,
cast_type
,
master_param_desc
.
get
(),
GetBasePtr
(
master_param
),
param_desc
.
get
(),
const_cast
<
void
*>
(
GetBasePtr
(
param
)));
}
else
{
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
PADDLE_ENFORCE_EQ
(
param_var
->
IsType
<
phi
::
DenseTensor
>
(),
...
...
@@ -305,13 +332,12 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
"but the received is %s"
,
ctx
.
InputNames
(
"Param"
).
front
(),
framework
::
ToTypeName
(
param_var
->
Type
())));
auto
*
param
=
ctx
.
Input
<
LoDTensor
>
(
"Param"
);
auto
*
lr
=
ctx
.
Input
<
LoDTensor
>
(
"LearningRate"
);
float
coeff
=
ctx
.
Attr
<
float
>
(
"coeff"
);
// update param with decay coeff: mul(-1 * lr, coeff * param) + param
MLUCnnlTensorDesc
lr_desc
(
*
lr
);
MLUCnnlTensorDesc
param_desc
(
*
param
);
MLUCnnlOpTensorDesc
mul_op_desc
(
CNNL_OP_TENSOR_MUL
,
ToCnnlDataType
<
T
>
(),
CNNL_NOT_PROPAGATE_NAN
);
...
...
@@ -330,6 +356,23 @@ class AdamWMLUKernel : public AdamMLUKernel<T> {
}
}
AdamMLUKernel
<
T
>::
Compute
(
ctx
);
if
(
multi_precision
)
{
VLOG
(
3
)
<<
"[adamw] multi_precision, cast paramout to masterparamout."
;
// cast paramout to masterparamout
master_param_out
->
mutable_data
<
float
>
(
ctx
.
GetPlace
());
cnnlCastDataType_t
cast_type
=
GetCastDataType
(
framework
::
TransToProtoVarType
(
param_out
->
dtype
()),
framework
::
TransToProtoVarType
(
master_param_out
->
dtype
()));
MLUCnnlTensorDesc
param_out_desc
(
*
param_out
);
MLUCnnlTensorDesc
master_param_out_desc
(
*
master_param_out
);
MLUCnnl
::
Cast
(
ctx
,
cast_type
,
param_out_desc
.
get
(),
GetBasePtr
(
param_out
),
master_param_out_desc
.
get
(),
GetBasePtr
(
master_param_out
));
}
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录