Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
047971f0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
047971f0
编写于
11月 08, 2022
作者:
Z
zhangyikun02
提交者:
GitHub
11月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add adadelta op for xpu, test=kunlun (#47661)
上级
6a6a3ff1
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
348 addition
and
41 deletion
+348
-41
paddle/fluid/platform/device/xpu/xpu2_op_list.h
paddle/fluid/platform/device/xpu/xpu2_op_list.h
+3
-0
paddle/phi/kernels/xpu/adadelta_kernel.cc
paddle/phi/kernels/xpu/adadelta_kernel.cc
+53
-0
python/paddle/fluid/tests/unittests/xpu/test_adadelta_op_xpu.py
.../paddle/fluid/tests/unittests/xpu/test_adadelta_op_xpu.py
+239
-0
python/paddle/fluid/tests/unittests/xpu/test_clip_by_norm_op_xpu.py
...dle/fluid/tests/unittests/xpu/test_clip_by_norm_op_xpu.py
+53
-41
未找到文件。
paddle/fluid/platform/device/xpu/xpu2_op_list.h
浏览文件 @
047971f0
...
...
@@ -33,6 +33,7 @@ XPUOpMap& get_kl2_ops() {
{
"abs_grad"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
()),
pOpKernelType
(
vartype
::
FP16
,
XPUPlace
())})},
{
"adadelta"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"adamw"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"adam"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
()),
...
...
@@ -109,6 +110,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
()),
pOpKernelType
(
vartype
::
FP16
,
XPUPlace
())})},
{
"clip"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"clip_by_norm"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"coalesce_tensor"
,
XPUKernelSet
({
pOpKernelType
(
vartype
::
FP32
,
XPUPlace
())})},
{
"concat_grad"
,
...
...
paddle/phi/kernels/xpu/adadelta_kernel.cc
0 → 100644
浏览文件 @
047971f0
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/adadelta_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
AdadeltaKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
param
,
const
DenseTensor
&
grad
,
const
DenseTensor
&
avg_squared_grad
,
const
DenseTensor
&
avg_squared_update
,
float
rho
,
float
epsilon
,
DenseTensor
*
param_out
,
DenseTensor
*
avg_squared_grad_out
,
DenseTensor
*
avg_squared_update_out
)
{
dev_ctx
.
template
Alloc
<
T
>(
param_out
);
dev_ctx
.
template
Alloc
<
T
>(
avg_squared_grad_out
);
dev_ctx
.
template
Alloc
<
T
>(
avg_squared_update_out
);
int
r
=
xpu
::
adadelta
<
T
,
T
>
(
dev_ctx
.
x_context
(),
param
.
data
<
T
>
(),
grad
.
data
<
T
>
(),
avg_squared_grad
.
data
<
T
>
(),
avg_squared_update
.
data
<
T
>
(),
param_out
->
data
<
T
>
(),
avg_squared_grad_out
->
data
<
T
>
(),
avg_squared_update_out
->
data
<
T
>
(),
param
.
numel
(),
rho
,
epsilon
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"adadelta"
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
adadelta
,
XPU
,
ALL_LAYOUT
,
phi
::
AdadeltaKernel
,
float
)
{}
python/paddle/fluid/tests/unittests/xpu/test_adadelta_op_xpu.py
0 → 100644
浏览文件 @
047971f0
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
sys
sys
.
path
.
append
(
".."
)
from
op_test
import
OpTest
import
paddle
import
paddle.fluid
as
fluid
from
op_test_xpu
import
XPUOpTest
from
xpu.get_test_cover_info
import
(
create_test_class
,
get_xpu_op_support_types
,
XPUOpTestWrapper
,
)
paddle
.
enable_static
()
class
XPUTestAdadelta
(
XPUOpTestWrapper
):
def
__init__
(
self
):
self
.
op_name
=
'adadelta'
class
TestAdadeltaOp1
(
XPUOpTest
):
def
setUp
(
self
):
self
.
op_type
=
"adadelta"
self
.
dtype
=
self
.
in_type
self
.
place
=
paddle
.
XPUPlace
(
0
)
param
=
np
.
random
.
uniform
(
-
1
,
1
,
(
102
,
105
)).
astype
(
self
.
dtype
)
grad
=
np
.
random
.
uniform
(
-
1
,
1
,
(
102
,
105
)).
astype
(
self
.
dtype
)
# The squared gradient is positive
avg_squared_grad
=
np
.
random
.
random
((
102
,
105
)).
astype
(
self
.
dtype
)
# The squared update is positive
avg_squared_update
=
np
.
random
.
random
((
102
,
105
)).
astype
(
self
.
dtype
)
rho
=
0.95
epsilon
=
1e-6
self
.
inputs
=
{
'Param'
:
param
,
'Grad'
:
grad
,
'AvgSquaredGrad'
:
avg_squared_grad
,
'AvgSquaredUpdate'
:
avg_squared_update
,
}
self
.
attrs
=
{
'rho'
:
rho
,
'epsilon'
:
epsilon
}
avg_squared_grad_out
=
rho
*
avg_squared_grad
+
(
1
-
rho
)
*
np
.
square
(
grad
)
update
=
-
np
.
multiply
(
np
.
sqrt
(
np
.
divide
(
avg_squared_update
+
epsilon
,
avg_squared_grad_out
+
epsilon
,
)
),
grad
,
)
avg_squared_update_out
=
rho
*
avg_squared_update
+
(
1
-
rho
)
*
np
.
square
(
update
)
param_out
=
param
+
update
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'AvgSquaredGradOut'
:
avg_squared_grad_out
,
'AvgSquaredUpdateOut'
:
avg_squared_update_out
,
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestAdadeltaOp2
(
OpTest
):
'''Test Adadelta op with default attribute values'''
def
setUp
(
self
):
self
.
op_type
=
"adadelta"
self
.
dtype
=
self
.
in_type
self
.
place
=
paddle
.
XPUPlace
(
0
)
param
=
np
.
random
.
uniform
(
-
1
,
1
,
(
102
,
105
)).
astype
(
self
.
dtype
)
grad
=
np
.
random
.
uniform
(
-
1
,
1
,
(
102
,
105
)).
astype
(
self
.
dtype
)
# The squared gradient is positive
avg_squared_grad
=
np
.
random
.
random
((
102
,
105
)).
astype
(
self
.
dtype
)
# The squared update is positive
avg_squared_update
=
np
.
random
.
random
((
102
,
105
)).
astype
(
self
.
dtype
)
rho
=
0.95
epsilon
=
1e-6
self
.
inputs
=
{
'Param'
:
param
,
'Grad'
:
grad
,
'AvgSquaredGrad'
:
avg_squared_grad
,
'AvgSquaredUpdate'
:
avg_squared_update
,
}
avg_squared_grad_out
=
rho
*
avg_squared_grad
+
(
1
-
rho
)
*
np
.
square
(
grad
)
update
=
-
np
.
multiply
(
np
.
sqrt
(
np
.
divide
(
avg_squared_update
+
epsilon
,
avg_squared_grad_out
+
epsilon
,
)
),
grad
,
)
avg_squared_update_out
=
rho
*
avg_squared_update
+
(
1
-
rho
)
*
np
.
square
(
update
)
param_out
=
param
+
update
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'AvgSquaredGradOut'
:
avg_squared_grad_out
,
'AvgSquaredUpdateOut'
:
avg_squared_update_out
,
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestAdadeltaV2
(
unittest
.
TestCase
):
def
test_adadelta_dygraph
(
self
):
self
.
dtype
=
self
.
in_type
self
.
place
=
paddle
.
XPUPlace
(
0
)
paddle
.
disable_static
(
self
.
place
)
value
=
np
.
arange
(
26
).
reshape
(
2
,
13
).
astype
(
self
.
dtype
)
a
=
paddle
.
to_tensor
(
value
)
linear
=
paddle
.
nn
.
Linear
(
13
,
5
)
# This can be any optimizer supported by dygraph.
adam
=
paddle
.
optimizer
.
Adadelta
(
learning_rate
=
0.01
,
parameters
=
linear
.
parameters
(),
weight_decay
=
0.01
,
)
out
=
linear
(
a
)
out
.
backward
()
adam
.
step
()
adam
.
clear_gradients
()
def
test_adadelta
(
self
):
self
.
dtype
=
self
.
in_type
paddle
.
enable_static
()
place
=
fluid
.
XPUPlace
(
0
)
main
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
13
],
dtype
=
self
.
dtype
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
self
.
dtype
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
paddle
.
mean
(
cost
)
rms_optimizer
=
paddle
.
optimizer
.
Adadelta
(
learning_rate
=
0.1
)
rms_optimizer
.
minimize
(
avg_cost
)
fetch_list
=
[
avg_cost
]
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
uci_housing
.
train
(),
batch_size
=
1
)
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
[
x
,
y
])
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
for
data
in
train_reader
():
exe
.
run
(
main
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
fetch_list
)
def
test_raise_error
(
self
):
self
.
assertRaises
(
ValueError
,
paddle
.
optimizer
.
Adadelta
,
None
)
self
.
assertRaises
(
ValueError
,
paddle
.
optimizer
.
Adadelta
,
learning_rate
=
0.1
,
rho
=
None
,
)
self
.
assertRaises
(
ValueError
,
paddle
.
optimizer
.
Adadelta
,
learning_rate
=
0.1
,
epsilon
=
None
,
)
class
TestAdadeltaV2Group
(
TestAdadeltaV2
):
def
test_adadelta_dygraph
(
self
):
self
.
dtype
=
self
.
in_type
self
.
place
=
paddle
.
XPUPlace
(
0
)
paddle
.
disable_static
(
self
.
place
)
value
=
np
.
arange
(
26
).
reshape
(
2
,
13
).
astype
(
self
.
dtype
)
a
=
paddle
.
to_tensor
(
value
)
linear_1
=
paddle
.
nn
.
Linear
(
13
,
5
)
linear_2
=
paddle
.
nn
.
Linear
(
5
,
5
)
# This can be any optimizer supported by dygraph.
adam
=
paddle
.
optimizer
.
Adadelta
(
learning_rate
=
0.01
,
parameters
=
[
{
'params'
:
linear_1
.
parameters
()},
{
'params'
:
linear_2
.
parameters
(),
'weight_decay'
:
0.001
,
},
],
weight_decay
=
0.1
,
)
out
=
linear_1
(
a
)
out
=
linear_2
(
out
)
out
.
backward
()
adam
.
step
()
adam
.
clear_gradients
()
support_types
=
get_xpu_op_support_types
(
'adadelta'
)
for
stype
in
support_types
:
create_test_class
(
globals
(),
XPUTestAdadelta
,
stype
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/xpu/test_clip_by_norm_op_xpu.py
浏览文件 @
047971f0
# Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 202
2
PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -19,56 +19,68 @@ import unittest
import
numpy
as
np
from
op_test_xpu
import
XPUOpTest
import
paddle
from
xpu.get_test_cover_info
import
(
create_test_class
,
get_xpu_op_support_types
,
XPUOpTestWrapper
,
)
class
TestXPUClipByNormOp
(
XPUOpTest
):
def
setUp
(
self
):
self
.
op_type
=
"clip_by_norm"
self
.
dtype
=
np
.
float32
self
.
use_xpu
=
True
self
.
max_relative_error
=
0.006
self
.
initTestCase
()
input
=
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
)
input
[
np
.
abs
(
input
)
<
self
.
max_relative_error
]
=
0.5
self
.
inputs
=
{
'X'
:
input
,
}
self
.
attrs
=
{}
self
.
attrs
[
'max_norm'
]
=
self
.
max_norm
norm
=
np
.
sqrt
(
np
.
sum
(
np
.
square
(
input
)))
if
norm
>
self
.
max_norm
:
output
=
self
.
max_norm
*
input
/
norm
else
:
output
=
input
self
.
outputs
=
{
'Out'
:
output
}
class
XPUTestClipByNormOp
(
XPUOpTestWrapper
):
def
__init__
(
self
):
self
.
op_name
=
'clip_by_norm'
self
.
use_dynamic_create_class
=
False
def
test_check_output
(
self
):
if
paddle
.
is_compiled_with_xpu
():
paddle
.
enable_static
()
place
=
paddle
.
XPUPlace
(
0
)
self
.
check_output_with_place
(
place
)
class
TestClipByNormOp
(
XPUOpTest
):
def
setUp
(
self
):
self
.
op_type
=
"clip_by_norm"
self
.
dtype
=
self
.
in_type
self
.
place
=
paddle
.
XPUPlace
(
0
)
self
.
use_xpu
=
True
self
.
max_relative_error
=
0.006
self
.
initTestCase
()
input
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
input
[
np
.
abs
(
input
)
<
self
.
max_relative_error
]
=
0.5
self
.
inputs
=
{
'X'
:
input
,
}
self
.
attrs
=
{}
self
.
attrs
[
'max_norm'
]
=
self
.
max_norm
norm
=
np
.
sqrt
(
np
.
sum
(
np
.
square
(
input
)))
if
norm
>
self
.
max_norm
:
output
=
self
.
max_norm
*
input
/
norm
else
:
output
=
input
self
.
outputs
=
{
'Out'
:
output
}
def
initTestCase
(
self
):
self
.
shape
=
(
100
,)
self
.
max_norm
=
1.0
def
test_check_output
(
self
):
if
paddle
.
is_compiled_with_xpu
():
paddle
.
enable_static
()
self
.
check_output_with_place
(
self
.
place
)
def
initTestCase
(
self
):
self
.
shape
=
(
100
,)
self
.
max_norm
=
1.0
class
TestCase1
(
TestXPU
ClipByNormOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
100
,)
self
.
max_norm
=
1e20
class
TestCase1
(
Test
ClipByNormOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
100
,)
self
.
max_norm
=
1e20
class
TestCase2
(
TestClipByNormOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
16
,
16
)
self
.
max_norm
=
0.1
class
TestCase2
(
TestXPU
ClipByNormOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
16
,
16
)
self
.
max_norm
=
0.1
class
TestCase3
(
Test
ClipByNormOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
4
,
8
,
16
)
self
.
max_norm
=
1.0
class
TestCase3
(
TestXPUClipByNormOp
):
def
initTestCase
(
self
):
self
.
shape
=
(
4
,
8
,
16
)
self
.
max_norm
=
1.0
support_types
=
get_xpu_op_support_types
(
'clip_by_norm'
)
for
stype
in
support_types
:
create_test_class
(
globals
(),
XPUTestClipByNormOp
,
stype
)
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录