Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
fad4744a
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fad4744a
编写于
11月 06, 2020
作者:
T
taixiurong
提交者:
GitHub
11月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix crash in adam in xpu, *test=kunlun (#28433)
上级
6bba8e57
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
41 addition
and
18 deletion
+41
-18
paddle/fluid/operators/optimizers/adam_op_xpu.cc
paddle/fluid/operators/optimizers/adam_op_xpu.cc
+41
-18
未找到文件。
paddle/fluid/operators/optimizers/adam_op_xpu.cc
浏览文件 @
fad4744a
...
...
@@ -74,7 +74,7 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
"output size is 1, but received "
"value is:%d."
,
beta2_pow_out
->
numel
()));
T
beta1
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"beta1"
));
if
(
ctx
.
HasInput
(
"Beta1Tensor"
))
{
auto
*
beta1_tensor
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Beta1Tensor"
);
...
...
@@ -88,30 +88,53 @@ class AdamOpXPUKernel : public framework::OpKernel<T> {
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
&
grad
=
GET_DATA_SAFELY
(
ctx
.
Input
<
LoDTensor
>
(
"Grad"
),
"Input"
,
"Grad"
,
"Adam"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
const
T
*
beta1_pow_ptr
=
beta1_pow
.
template
data
<
T
>();
const
T
*
beta2_pow_ptr
=
beta2_pow
.
template
data
<
T
>();
Tensor
xpu_beta1_pow
;
Tensor
xpu_beta2_pow
;
if
(
beta1_pow
.
place
()
==
platform
::
CPUPlace
()
&&
beta2_pow
.
place
()
==
platform
::
CPUPlace
())
{
TensorCopy
(
beta1_pow
,
ctx
.
GetPlace
(),
dev_ctx
,
&
xpu_beta1_pow
);
TensorCopy
(
beta2_pow
,
ctx
.
GetPlace
(),
dev_ctx
,
&
xpu_beta2_pow
);
dev_ctx
.
Wait
();
beta1_pow_ptr
=
xpu_beta1_pow
.
template
data
<
T
>();
beta2_pow_ptr
=
xpu_beta2_pow
.
template
data
<
T
>();
}
int
r
=
xpu
::
adam
(
dev_ctx
.
x_context
(),
grad
.
template
data
<
T
>(),
mom1
.
template
data
<
T
>(),
mom2
.
template
data
<
T
>(),
param
.
template
data
<
T
>(),
beta1_pow
.
template
data
<
T
>(),
beta2_pow
.
template
data
<
T
>(),
beta1
,
beta2
,
epsilon
,
lr
.
template
data
<
T
>(),
mom2
.
template
data
<
T
>(),
param
.
template
data
<
T
>(),
beta1_pow_ptr
,
beta2_pow_ptr
,
beta1
,
beta2
,
epsilon
,
lr
.
template
data
<
T
>(),
mom1_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
mom2_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
param_out
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
param
.
numel
());
const
float
*
ptr0
=
beta1_pow
.
template
data
<
T
>();
float
*
ptr1
=
beta1_pow_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
float
cpudata
;
xpu_memcpy
(
&
cpudata
,
ptr0
,
sizeof
(
float
),
XPU_DEVICE_TO_HOST
);
cpudata
=
cpudata
*
beta1
;
xpu_memcpy
(
ptr1
,
&
cpudata
,
sizeof
(
float
),
XPU_HOST_TO_DEVICE
);
const
float
*
ptr2
=
beta2_pow
.
template
data
<
T
>();
float
*
ptr3
=
beta2_pow_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
float
cpudata1
;
xpu_memcpy
(
&
cpudata1
,
ptr2
,
sizeof
(
float
),
XPU_DEVICE_TO_HOST
);
cpudata1
=
cpudata1
*
beta2
;
xpu_memcpy
(
ptr3
,
&
cpudata1
,
sizeof
(
float
),
XPU_HOST_TO_DEVICE
);
//update in cpu and then copy to xpu
if
(
beta1_pow
.
place
()
==
platform
::
CPUPlace
()
&&
beta2_pow
.
place
()
==
platform
::
CPUPlace
())
{
const
T
*
beta1_pow_p
=
beta1_pow
.
template
data
<
T
>();
beta1_pow_out
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
]
=
beta1
*
beta1_pow_p
[
0
];
const
T
*
beta2_pow_p
=
beta2_pow
.
template
data
<
T
>();
beta2_pow_out
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
]
=
beta2
*
beta2_pow_p
[
0
];
}
else
{
T
cpu_beta1_pow_out_data
;
T
cpu_beta2_pow_out_data
;
xpu_memcpy
(
&
cpu_beta1_pow_out_data
,
beta1_pow_ptr
,
sizeof
(
T
),
XPU_DEVICE_TO_HOST
);
cpu_beta1_pow_out_data
=
cpu_beta1_pow_out_data
*
beta1
;
xpu_memcpy
(
&
cpu_beta2_pow_out_data
,
beta2_pow_ptr
,
sizeof
(
T
),
XPU_DEVICE_TO_HOST
);
cpu_beta2_pow_out_data
=
cpu_beta2_pow_out_data
*
beta2
;
T
*
beta1_pow_out_p
=
beta1_pow_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
beta2_pow_out_p
=
beta2_pow_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
xpu_memcpy
(
beta1_pow_out_p
,
&
cpu_beta1_pow_out_data
,
sizeof
(
T
),
XPU_HOST_TO_DEVICE
);
xpu_memcpy
(
beta2_pow_out_p
,
&
cpu_beta2_pow_out_data
,
sizeof
(
T
),
XPU_HOST_TO_DEVICE
);
}
PADDLE_ENFORCE_EQ
(
r
==
xpu
::
Error_t
::
SUCCESS
,
true
,
platform
::
errors
::
External
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录