Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f8ca5a9d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f8ca5a9d
编写于
4月 22, 2021
作者:
Y
Yang Zhang
提交者:
GitHub
4月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add `paddle.set_grad_enabled` (#31794)
上级
c3328288
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
64 addition
and
1 deletion
+64
-1
python/paddle/__init__.py
python/paddle/__init__.py
+1
-0
python/paddle/fluid/tests/unittests/test_imperative_basic.py
python/paddle/fluid/tests/unittests/test_imperative_basic.py
+22
-0
python/paddle/framework/__init__.py
python/paddle/framework/__init__.py
+5
-1
python/paddle/framework/framework.py
python/paddle/framework/framework.py
+36
-0
未找到文件。
python/paddle/__init__.py
浏览文件 @
f8ca5a9d
...
...
@@ -257,6 +257,7 @@ from .framework import CUDAPinnedPlace #DEFINE_ALIAS
from
.framework
import
grad
#DEFINE_ALIAS
from
.framework
import
no_grad
#DEFINE_ALIAS
from
.framework
import
set_grad_enabled
#DEFINE_ALIAS
from
.framework
import
save
#DEFINE_ALIAS
from
.framework
import
load
#DEFINE_ALIAS
from
.framework
import
DataParallel
#DEFINE_ALIAS
...
...
python/paddle/fluid/tests/unittests/test_imperative_basic.py
浏览文件 @
f8ca5a9d
...
...
@@ -296,6 +296,28 @@ class TestImperative(unittest.TestCase):
self
.
assertTrue
(
tmp
.
_grad_ivar
()
is
None
)
self
.
assertTrue
(
l0
.
weight
.
_grad_ivar
()
is
not
None
)
def
test_paddle_imperative_set_grad_enabled
(
self
):
data
=
np
.
array
([[
2
,
3
],
[
4
,
5
]]).
astype
(
'float32'
)
with
fluid
.
dygraph
.
guard
():
l0
=
fluid
.
Linear
(
2
,
2
)
self
.
assertTrue
(
l0
.
weight
.
_grad_ivar
()
is
None
)
l1
=
fluid
.
Linear
(
2
,
2
)
with
paddle
.
set_grad_enabled
(
False
):
self
.
assertTrue
(
l1
.
weight
.
stop_gradient
is
False
)
tmp
=
l1
.
weight
*
2
with
paddle
.
set_grad_enabled
(
True
):
tmp2
=
l1
.
weight
*
2
self
.
assertTrue
(
tmp
.
stop_gradient
)
self
.
assertTrue
(
tmp2
.
stop_gradient
is
False
)
x
=
fluid
.
dygraph
.
to_variable
(
data
)
y
=
l0
(
x
)
+
tmp2
o
=
l1
(
y
)
o
.
backward
()
self
.
assertTrue
(
tmp
.
_grad_ivar
()
is
None
)
self
.
assertTrue
(
tmp2
.
_grad_ivar
()
is
not
None
)
self
.
assertTrue
(
l0
.
weight
.
_grad_ivar
()
is
not
None
)
def
test_sum_op
(
self
):
x
=
np
.
ones
([
2
,
2
],
np
.
float32
)
with
fluid
.
dygraph
.
guard
():
...
...
python/paddle/framework/__init__.py
浏览文件 @
f8ca5a9d
...
...
@@ -18,12 +18,16 @@ __all__ = [
'NPUPlace'
,
'get_default_dtype'
,
'set_default_dtype'
]
__all__
+=
[
'grad'
,
'LayerList'
,
'load'
,
'save'
,
'no_grad'
,
'DataParallel'
]
__all__
+=
[
'grad'
,
'set_grad_enabled'
,
'LayerList'
,
'load'
,
'save'
,
'no_grad'
,
'DataParallel'
]
from
.
import
random
from
.random
import
seed
from
.framework
import
get_default_dtype
from
.framework
import
set_default_dtype
from
.framework
import
set_grad_enabled
from
..fluid.param_attr
import
ParamAttr
#DEFINE_ALIAS
# from ..fluid.layers.tensor import create_global_var #DEFINE_ALIAS
...
...
python/paddle/framework/framework.py
浏览文件 @
f8ca5a9d
...
...
@@ -15,7 +15,9 @@
# TODO: define framework api
from
paddle.fluid.layer_helper_base
import
LayerHelperBase
from
paddle.fluid.data_feeder
import
convert_dtype
from
paddle.fluid.framework
import
_dygraph_tracer
import
numpy
as
np
from
contextlib
import
contextmanager
__all__
=
[
'set_default_dtype'
,
'get_default_dtype'
]
...
...
@@ -80,3 +82,37 @@ def get_default_dtype():
paddle.get_default_dtype()
"""
return
LayerHelperBase
.
get_default_dtype
()
@
contextmanager
def
set_grad_enabled
(
mode
):
"""
:api_attr: imperative
Create a context which enables or disables dygraph gradient calculation.
Args:
mode(bool): whether to enable (`True`), or disable (`False`) grad.
Examples:
.. code-block:: python
x = paddle.ones([3, 2])
x.stop_gradient = False
with torch.set_grad_enabled(False):
y = x * 2
with torch.set_grad_enabled(True):
z = x * 2
print(y.stop_gradient) # True
print(z.stop_gradient) # False
"""
tracer
=
_dygraph_tracer
()
if
tracer
:
prev_mode
=
tracer
.
_has_grad
tracer
.
_has_grad
=
mode
try
:
yield
finally
:
tracer
.
_has_grad
=
prev_mode
else
:
yield
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录