Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ac47d003
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2301
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ac47d003
编写于
3月 20, 2023
作者:
W
Weilong Wu
提交者:
GitHub
3月 20, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sigmoid custom grad for prim (#51768)
上级
52e1742f
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
194 addition
and
0 deletion
+194
-0
paddle/fluid/prim/api/composite_backward/composite_backward_api.h
...luid/prim/api/composite_backward/composite_backward_api.h
+7
-0
paddle/phi/api/yaml/backward.yaml
paddle/phi/api/yaml/backward.yaml
+1
-0
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sigmoid_grad.py
...tests/prim/prim/vjp/eager/test_comp_eager_sigmoid_grad.py
+73
-0
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sigmoid_grad.py
.../unittests/prim/prim/vjp/static/test_comp_sigmoid_grad.py
+113
-0
未找到文件。
paddle/fluid/prim/api/composite_backward/composite_backward_api.h
浏览文件 @
ac47d003
...
@@ -504,6 +504,13 @@ void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
...
@@ -504,6 +504,13 @@ void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
}
}
}
}
template
<
typename
T
>
void
sigmoid_grad
(
const
Tensor
&
out
,
const
Tensor
&
out_grad
,
Tensor
*
x_grad
)
{
if
(
x_grad
)
{
set_output
<
T
>
(
out_grad
*
(
out
*
(
1
-
out
)),
x_grad
);
}
}
template
<
typename
T
>
template
<
typename
T
>
void
abs_grad
(
const
Tensor
&
x
,
const
Tensor
&
out_grad
,
Tensor
*
x_grad
)
{
void
abs_grad
(
const
Tensor
&
x
,
const
Tensor
&
out_grad
,
Tensor
*
x_grad
)
{
if
(
x_grad
)
{
if
(
x_grad
)
{
...
...
paddle/phi/api/yaml/backward.yaml
浏览文件 @
ac47d003
...
@@ -1298,6 +1298,7 @@
...
@@ -1298,6 +1298,7 @@
func
:
sigmoid_grad
func
:
sigmoid_grad
backward
:
sigmoid_double_grad
backward
:
sigmoid_double_grad
inplace
:
(out_grad -> x_grad)
inplace
:
(out_grad -> x_grad)
composite
:
sigmoid_grad(out, out_grad, x_grad)
-
backward_op
:
sigmoid_triple_grad
-
backward_op
:
sigmoid_triple_grad
forward
:
sigmoid_double_grad (Tensor out, Tensor fwd_grad_out, Tensor grad_grad_x) -> Tensor(grad_out), Tensor(grad_grad_out)
forward
:
sigmoid_double_grad (Tensor out, Tensor fwd_grad_out, Tensor grad_grad_x) -> Tensor(grad_out), Tensor(grad_grad_out)
...
...
python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sigmoid_grad.py
0 → 100644
浏览文件 @
ac47d003
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
parameterized
as
param
import
paddle
import
paddle.nn.functional
as
F
from
paddle.fluid
import
core
@
param
.
parameterized_class
(
(
'primal'
,
'cotangent'
,
'dtype'
),
[
(
np
.
random
.
rand
(
10
,
10
),
np
.
random
.
rand
(
10
,
10
),
np
.
float32
),
],
)
class
TestExpGradComp
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
core
.
set_prim_eager_enabled
(
True
)
cls
.
primal
=
cls
.
primal
.
astype
(
cls
.
dtype
)
if
cls
.
cotangent
is
not
None
:
cls
.
cotangent
=
cls
.
cotangent
.
astype
(
cls
.
dtype
)
def
setUp
(
self
):
paddle
.
enable_static
()
def
tearDown
(
self
):
paddle
.
disable_static
()
def
test_sigmoid_grad_comp
(
self
):
def
actual
(
primal
,
cotangent
):
core
.
set_prim_eager_enabled
(
True
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal
)
dout
=
paddle
.
to_tensor
(
cotangent
)
x
.
stop_gradient
=
False
return
paddle
.
grad
(
F
.
sigmoid
(
x
),
x
,
dout
)[
0
]
def
desired
(
primal
,
cotangent
):
core
.
set_prim_eager_enabled
(
False
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
primal
)
dout
=
paddle
.
to_tensor
(
cotangent
)
x
.
stop_gradient
=
False
return
paddle
.
grad
(
F
.
sigmoid
(
x
),
x
,
dout
)[
0
]
np
.
testing
.
assert_allclose
(
actual
=
actual
(
self
.
primal
,
self
.
cotangent
),
desired
=
desired
(
self
.
primal
,
self
.
cotangent
),
rtol
=
1e-6
,
atol
=
0
,
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sigmoid_grad.py
0 → 100644
浏览文件 @
ac47d003
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
parameterized
as
param
import
paddle
import
paddle.nn.functional
as
F
from
paddle.fluid
import
core
@
param
.
parameterized_class
(
(
'primal'
,
'cotangent'
,
'dtype'
),
[
(
np
.
random
.
rand
(
10
,
10
),
np
.
random
.
rand
(
10
,
10
),
np
.
float32
),
],
)
class
TestExpGradComp
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
core
.
set_prim_eager_enabled
(
True
)
cls
.
primal
=
cls
.
primal
.
astype
(
cls
.
dtype
)
if
cls
.
cotangent
is
not
None
:
cls
.
cotangent
=
cls
.
cotangent
.
astype
(
cls
.
dtype
)
def
setUp
(
self
):
paddle
.
enable_static
()
def
tearDown
(
self
):
paddle
.
disable_static
()
def
test_sigmoid_grad_comp
(
self
):
def
actual
(
primal
,
cotangent
):
core
.
_set_prim_backward_enabled
(
True
)
paddle
.
enable_static
()
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
'primal'
,
primal
.
shape
,
primal
.
dtype
)
dout
=
paddle
.
static
.
data
(
'cotangent'
,
cotangent
.
shape
,
cotangent
.
dtype
)
x
.
stop_gradient
=
False
res
=
F
.
sigmoid
(
x
)
x_grad
=
paddle
.
static
.
gradients
(
res
,
[
x
],
dout
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
sp
)
out
=
exe
.
run
(
program
=
mp
,
feed
=
{
'primal'
:
primal
,
'cotangent'
:
cotangent
,
},
fetch_list
=
[
x_grad
[
0
].
name
,
],
)
return
out
[
0
]
def
desired
(
primal
,
cotangent
):
core
.
_set_prim_backward_enabled
(
False
)
paddle
.
enable_static
()
mp
,
sp
=
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
mp
,
sp
):
x
=
paddle
.
static
.
data
(
'primal'
,
primal
.
shape
,
primal
.
dtype
)
dout
=
paddle
.
static
.
data
(
'cotangent'
,
cotangent
.
shape
,
cotangent
.
dtype
)
x
.
stop_gradient
=
False
res
=
F
.
sigmoid
(
x
)
x_grad
=
paddle
.
static
.
gradients
(
res
,
[
x
],
dout
)
exe
=
paddle
.
static
.
Executor
()
exe
.
run
(
sp
)
out
=
exe
.
run
(
program
=
mp
,
feed
=
{
'primal'
:
primal
,
'cotangent'
:
cotangent
,
},
fetch_list
=
[
x_grad
[
0
].
name
,
],
)
return
out
[
0
]
np
.
testing
.
assert_allclose
(
actual
=
actual
(
self
.
primal
,
self
.
cotangent
),
desired
=
desired
(
self
.
primal
,
self
.
cotangent
),
rtol
=
1e-6
,
atol
=
0
,
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录