Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
18745e6f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
18745e6f
编写于
1月 17, 2023
作者:
X
xiongkun
提交者:
GitHub
1月 17, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2Static] fix switch static graph affects dataloader (#49821)
* rebase merge * code fix * fix bugs
上级
611da7fc
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
117 addition
and
81 deletion
+117
-81
python/paddle/__init__.py
python/paddle/__init__.py
+1
-1
python/paddle/autograd/backward_mode.py
python/paddle/autograd/backward_mode.py
+2
-2
python/paddle/fluid/dygraph/base.py
python/paddle/fluid/dygraph/base.py
+14
-17
python/paddle/fluid/dygraph/math_op_patch.py
python/paddle/fluid/dygraph/math_op_patch.py
+4
-4
python/paddle/fluid/dygraph/varbase_patch_methods.py
python/paddle/fluid/dygraph/varbase_patch_methods.py
+21
-15
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+56
-23
python/paddle/fluid/lazy_init.py
python/paddle/fluid/lazy_init.py
+3
-3
python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py
.../tests/unittests/dygraph_to_static/test_break_continue.py
+1
-1
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
...le/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
+2
-2
python/paddle/fluid/tests/unittests/npu/test_run_program_op_npu.py
...ddle/fluid/tests/unittests/npu/test_run_program_op_npu.py
+3
-3
python/paddle/fluid/tests/unittests/test_run_program_op.py
python/paddle/fluid/tests/unittests/test_run_program_op.py
+3
-3
python/paddle/jit/dy2static/partial_program.py
python/paddle/jit/dy2static/partial_program.py
+5
-5
python/paddle/tensor/logic.py
python/paddle/tensor/logic.py
+2
-2
未找到文件。
python/paddle/__init__.py
浏览文件 @
18745e6f
...
@@ -55,7 +55,7 @@ from .framework.dtype import bool # noqa: F401
...
@@ -55,7 +55,7 @@ from .framework.dtype import bool # noqa: F401
from
.framework.dtype
import
complex64
# noqa: F401
from
.framework.dtype
import
complex64
# noqa: F401
from
.framework.dtype
import
complex128
# noqa: F401
from
.framework.dtype
import
complex128
# noqa: F401
if
fluid
.
framework
.
_in_eager_mode_
:
if
fluid
.
framework
.
global_var
.
_in_eager_mode_
:
Tensor
=
framework
.
core
.
eager
.
Tensor
Tensor
=
framework
.
core
.
eager
.
Tensor
else
:
else
:
from
.framework
import
VarBase
as
Tensor
# noqa: F401
from
.framework
import
VarBase
as
Tensor
# noqa: F401
...
...
python/paddle/autograd/backward_mode.py
浏览文件 @
18745e6f
...
@@ -107,7 +107,7 @@ def backward(tensors, grad_tensors=None, retain_graph=False):
...
@@ -107,7 +107,7 @@ def backward(tensors, grad_tensors=None, retain_graph=False):
each_tensor
,
(
paddle
.
Tensor
,
core
.
eager
.
Tensor
)
each_tensor
,
(
paddle
.
Tensor
,
core
.
eager
.
Tensor
)
),
"The argument 'grad_tensors' of paddle.autograd.backward is invalid, it can be 'None', 'paddle.Tensor' or 'list[None/paddle.Tensor]'."
),
"The argument 'grad_tensors' of paddle.autograd.backward is invalid, it can be 'None', 'paddle.Tensor' or 'list[None/paddle.Tensor]'."
else
:
else
:
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
grad_tensors
=
[]
grad_tensors
=
[]
else
:
else
:
grad_tensors
=
[
None
]
*
len
(
tensors
)
grad_tensors
=
[
None
]
*
len
(
tensors
)
...
@@ -119,7 +119,7 @@ def backward(tensors, grad_tensors=None, retain_graph=False):
...
@@ -119,7 +119,7 @@ def backward(tensors, grad_tensors=None, retain_graph=False):
assert
isinstance
(
retain_graph
,
bool
),
"retain_graph must be True or False"
assert
isinstance
(
retain_graph
,
bool
),
"retain_graph must be True or False"
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
core
.
eager
.
run_backward
(
tensors
,
grad_tensors
,
retain_graph
)
core
.
eager
.
run_backward
(
tensors
,
grad_tensors
,
retain_graph
)
else
:
else
:
core
.
dygraph_run_backward
(
core
.
dygraph_run_backward
(
...
...
python/paddle/fluid/dygraph/base.py
浏览文件 @
18745e6f
...
@@ -20,6 +20,7 @@ import sys
...
@@ -20,6 +20,7 @@ import sys
import
numpy
as
np
import
numpy
as
np
from
paddle.fluid
import
core
from
paddle.fluid
import
core
from
paddle.fluid
import
framework
from
paddle.fluid
import
framework
from
paddle.fluid.framework
import
global_var
from
paddle.fluid.multiprocess_utils
import
CleanupFuncRegistrar
from
paddle.fluid.multiprocess_utils
import
CleanupFuncRegistrar
from
.tracer
import
Tracer
from
.tracer
import
Tracer
import
logging
import
logging
...
@@ -44,7 +45,6 @@ __all__ = [
...
@@ -44,7 +45,6 @@ __all__ = [
]
]
# Flag that indicates whether running code under `@to_static`
# Flag that indicates whether running code under `@to_static`
_in_declarative_mode_
=
False
def
in_declarative_mode
():
def
in_declarative_mode
():
...
@@ -52,7 +52,7 @@ def in_declarative_mode():
...
@@ -52,7 +52,7 @@ def in_declarative_mode():
Return a bool value that indicates whether running code under `@to_static`
Return a bool value that indicates whether running code under `@to_static`
"""
"""
return
_in_declarative_mode_
return
global_var
.
_in_declarative_mode_
def
declarative_unsupport_argument_warning
(
def
declarative_unsupport_argument_warning
(
...
@@ -86,11 +86,11 @@ switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)
...
@@ -86,11 +86,11 @@ switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)
@
signature_safe_contextmanager
@
signature_safe_contextmanager
def
_switch_declarative_mode_guard_
(
is_declarative
=
True
):
def
_switch_declarative_mode_guard_
(
is_declarative
=
True
):
global
_in_declarative_mode_
global
global_var
original_val
=
_in_declarative_mode_
original_val
=
global_var
.
_in_declarative_mode_
_in_declarative_mode_
=
is_declarative
global_var
.
_in_declarative_mode_
=
is_declarative
yield
yield
_in_declarative_mode_
=
original_val
global_var
.
_in_declarative_mode_
=
original_val
@
signature_safe_contextmanager
@
signature_safe_contextmanager
...
@@ -106,9 +106,6 @@ def program_desc_tracing_guard(enable):
...
@@ -106,9 +106,6 @@ def program_desc_tracing_guard(enable):
tracer
.
_enable_program_desc_tracing
=
original_val
tracer
.
_enable_program_desc_tracing
=
original_val
_functional_dygraph_context_manager
=
None
@
signature_safe_contextmanager
@
signature_safe_contextmanager
def
param_guard
(
parameters
):
def
param_guard
(
parameters
):
# Note: parameters is a reference of self._parameters or self._buffers
# Note: parameters is a reference of self._parameters or self._buffers
...
@@ -228,12 +225,12 @@ def enable_dygraph(place=None):
...
@@ -228,12 +225,12 @@ def enable_dygraph(place=None):
print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode
print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode
"""
"""
global
_functional_dygraph_context_manage
r
global
global_va
r
if
_functional_dygraph_context_manager
is
None
:
if
global_var
.
_functional_dygraph_context_manager
is
None
:
_functional_dygraph_context_manager
=
guard
(
global_var
.
_functional_dygraph_context_manager
=
guard
(
place
=
_get_paddle_place
(
place
)
place
=
_get_paddle_place
(
place
)
)
)
_functional_dygraph_context_manager
.
__enter__
()
global_var
.
_functional_dygraph_context_manager
.
__enter__
()
# call disable_dygraph when Python exit
# call disable_dygraph when Python exit
CleanupFuncRegistrar
.
register
(
disable_dygraph
)
CleanupFuncRegistrar
.
register
(
disable_dygraph
)
...
@@ -263,10 +260,10 @@ def disable_dygraph():
...
@@ -263,10 +260,10 @@ def disable_dygraph():
print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode
print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode
"""
"""
global
_functional_dygraph_context_manage
r
global
global_va
r
if
_functional_dygraph_context_manager
is
not
None
:
if
global_var
.
_functional_dygraph_context_manager
is
not
None
:
_functional_dygraph_context_manager
.
__exit__
(
*
sys
.
exc_info
())
global_var
.
_functional_dygraph_context_manager
.
__exit__
(
*
sys
.
exc_info
())
_functional_dygraph_context_manager
=
None
global_var
.
_functional_dygraph_context_manager
=
None
@
signature_safe_contextmanager
@
signature_safe_contextmanager
...
...
python/paddle/fluid/dygraph/math_op_patch.py
浏览文件 @
18745e6f
...
@@ -74,7 +74,7 @@ def monkey_patch_math_varbase():
...
@@ -74,7 +74,7 @@ def monkey_patch_math_varbase():
@
no_grad
@
no_grad
def
create_tensor
(
value
,
dtype
,
shape
):
def
create_tensor
(
value
,
dtype
,
shape
):
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
out
=
_C_ops
.
full
(
out
=
_C_ops
.
full
(
shape
,
value
,
dtype
,
framework
.
_current_expected_place
()
shape
,
value
,
dtype
,
framework
.
_current_expected_place
()
)
)
...
@@ -251,7 +251,7 @@ def monkey_patch_math_varbase():
...
@@ -251,7 +251,7 @@ def monkey_patch_math_varbase():
# 2. create varbase for scalar
# 2. create varbase for scalar
lhs_dtype
=
self
.
dtype
lhs_dtype
=
self
.
dtype
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
other_var_should_be
=
core
.
eager
.
Tensor
other_var_should_be
=
core
.
eager
.
Tensor
else
:
else
:
other_var_should_be
=
core
.
VarBase
other_var_should_be
=
core
.
VarBase
...
@@ -486,7 +486,7 @@ def monkey_patch_math_varbase():
...
@@ -486,7 +486,7 @@ def monkey_patch_math_varbase():
global
_already_patch_varbase
global
_already_patch_varbase
global
_already_patch_eager_tensor
global
_already_patch_eager_tensor
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
local_already_patch
=
_already_patch_eager_tensor
local_already_patch
=
_already_patch_eager_tensor
_already_patch_eager_tensor
=
True
_already_patch_eager_tensor
=
True
local_tensor
=
core
.
eager
.
Tensor
local_tensor
=
core
.
eager
.
Tensor
...
@@ -496,7 +496,7 @@ def monkey_patch_math_varbase():
...
@@ -496,7 +496,7 @@ def monkey_patch_math_varbase():
local_tensor
=
core
.
VarBase
local_tensor
=
core
.
VarBase
if
not
local_already_patch
:
if
not
local_already_patch
:
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
for
method_name
in
eager_cpp_level_patch
:
for
method_name
in
eager_cpp_level_patch
:
method_impl
=
getattr
(
local_tensor
,
method_name
,
None
)
method_impl
=
getattr
(
local_tensor
,
method_name
,
None
)
if
method_impl
:
if
method_impl
:
...
...
python/paddle/fluid/dygraph/varbase_patch_methods.py
浏览文件 @
18745e6f
...
@@ -54,7 +54,9 @@ class TensorHookRemoveHelper:
...
@@ -54,7 +54,9 @@ class TensorHookRemoveHelper:
def
__init__
(
self
,
tensor
,
hook_id
):
def
__init__
(
self
,
tensor
,
hook_id
):
self
.
_tensor
=
(
self
.
_tensor
=
(
tensor
if
framework
.
_in_eager_mode_
else
weakref
.
ref
(
tensor
)
tensor
if
framework
.
global_var
.
_in_eager_mode_
else
weakref
.
ref
(
tensor
)
)
)
self
.
_hook_id
=
hook_id
self
.
_hook_id
=
hook_id
...
@@ -65,7 +67,11 @@ class TensorHookRemoveHelper:
...
@@ -65,7 +67,11 @@ class TensorHookRemoveHelper:
Returns:
Returns:
bool: Return True if removed successfully
bool: Return True if removed successfully
"""
"""
tensor
=
self
.
_tensor
if
framework
.
_in_eager_mode_
else
self
.
_tensor
()
tensor
=
(
self
.
_tensor
if
framework
.
global_var
.
_in_eager_mode_
else
self
.
_tensor
()
)
if
tensor
is
not
None
:
if
tensor
is
not
None
:
res
=
tensor
.
_remove_grad_hook
(
self
.
_hook_id
)
res
=
tensor
.
_remove_grad_hook
(
self
.
_hook_id
)
if
res
is
True
:
if
res
is
True
:
...
@@ -178,7 +184,7 @@ def monkey_patch_varbase():
...
@@ -178,7 +184,7 @@ def monkey_patch_varbase():
out = linear(t) # call with different weight
out = linear(t) # call with different weight
"""
"""
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
base_tensor
=
core
.
eager
.
Tensor
base_tensor
=
core
.
eager
.
Tensor
else
:
else
:
base_tensor
=
core
.
VarBase
base_tensor
=
core
.
VarBase
...
@@ -282,7 +288,7 @@ def monkey_patch_varbase():
...
@@ -282,7 +288,7 @@ def monkey_patch_varbase():
)
)
record_event
.
begin
()
record_event
.
begin
()
if
grad_tensor
is
not
None
:
if
grad_tensor
is
not
None
:
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
assert
isinstance
(
assert
isinstance
(
grad_tensor
,
core
.
eager
.
Tensor
grad_tensor
,
core
.
eager
.
Tensor
),
"The type of grad_tensor must be paddle.Tensor"
),
"The type of grad_tensor must be paddle.Tensor"
...
@@ -296,7 +302,7 @@ def monkey_patch_varbase():
...
@@ -296,7 +302,7 @@ def monkey_patch_varbase():
grad_tensor
.
name
,
grad_tensor
.
shape
,
self
.
name
,
self
.
shape
grad_tensor
.
name
,
grad_tensor
.
shape
,
self
.
name
,
self
.
shape
)
)
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
if
grad_tensor
is
None
:
if
grad_tensor
is
None
:
grad_tensor
=
[]
grad_tensor
=
[]
else
:
else
:
...
@@ -311,7 +317,7 @@ def monkey_patch_varbase():
...
@@ -311,7 +317,7 @@ def monkey_patch_varbase():
):
):
# TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
# TODO(liuyuhui): Currently only for xpu. Will be removed in the future.
scaled_loss
=
scale_loss
(
self
)
scaled_loss
=
scale_loss
(
self
)
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
core
.
eager
.
run_backward
(
core
.
eager
.
run_backward
(
[
scaled_loss
],
grad_tensor
,
retain_graph
[
scaled_loss
],
grad_tensor
,
retain_graph
)
)
...
@@ -323,7 +329,7 @@ def monkey_patch_varbase():
...
@@ -323,7 +329,7 @@ def monkey_patch_varbase():
framework
.
_dygraph_tracer
(),
framework
.
_dygraph_tracer
(),
)
)
else
:
else
:
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
core
.
eager
.
run_backward
([
self
],
grad_tensor
,
retain_graph
)
core
.
eager
.
run_backward
([
self
],
grad_tensor
,
retain_graph
)
else
:
else
:
core
.
dygraph_run_backward
(
core
.
dygraph_run_backward
(
...
@@ -368,7 +374,7 @@ def monkey_patch_varbase():
...
@@ -368,7 +374,7 @@ def monkey_patch_varbase():
# [500.]
# [500.]
"""
"""
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
if
self
.
grad
is
None
:
if
self
.
grad
is
None
:
return
None
return
None
if
self
.
grad
.
is_selected_rows
():
if
self
.
grad
.
is_selected_rows
():
...
@@ -673,7 +679,7 @@ def monkey_patch_varbase():
...
@@ -673,7 +679,7 @@ def monkey_patch_varbase():
# [[0.30574632, 0.55739117, 0.30902600, 0.39413780, 0.44830436],
# [[0.30574632, 0.55739117, 0.30902600, 0.39413780, 0.44830436],
# [0.79010487, 0.53972793, 0.09495186, 0.44267157, 0.72112119]])
# [0.79010487, 0.53972793, 0.09495186, 0.44267157, 0.72112119]])
"""
"""
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
from
paddle.tensor.to_string
import
tensor_to_string
from
paddle.tensor.to_string
import
tensor_to_string
return
tensor_to_string
(
self
)
return
tensor_to_string
(
self
)
...
@@ -707,7 +713,7 @@ def monkey_patch_varbase():
...
@@ -707,7 +713,7 @@ def monkey_patch_varbase():
raise
RuntimeError
(
raise
RuntimeError
(
"Only Leaf Tensor support the deepcopy at the moment, non-Leaf Tensors contains graph information that does't support deepcopy"
"Only Leaf Tensor support the deepcopy at the moment, non-Leaf Tensors contains graph information that does't support deepcopy"
)
)
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
new_varbase
=
core
.
eager
.
Tensor
()
new_varbase
=
core
.
eager
.
Tensor
()
else
:
else
:
new_varbase
=
core
.
VarBase
()
new_varbase
=
core
.
VarBase
()
...
@@ -725,7 +731,7 @@ def monkey_patch_varbase():
...
@@ -725,7 +731,7 @@ def monkey_patch_varbase():
assert
(
assert
(
numel
==
1
numel
==
1
),
"When Variable is used as the condition of if/while , Variable can only contain one element."
),
"When Variable is used as the condition of if/while , Variable can only contain one element."
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
assert
self
.
_is_initialized
(),
"tensor not initialized"
assert
self
.
_is_initialized
(),
"tensor not initialized"
return
bool
(
np
.
all
(
self
.
numpy
()
>
0
))
return
bool
(
np
.
all
(
self
.
numpy
()
>
0
))
else
:
else
:
...
@@ -850,7 +856,7 @@ def monkey_patch_varbase():
...
@@ -850,7 +856,7 @@ def monkey_patch_varbase():
return
_setitem_impl_
(
self
,
item
,
value
)
return
_setitem_impl_
(
self
,
item
,
value
)
else
:
else
:
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
return
self
.
__setitem_eager_tensor__
(
item
,
value
)
return
self
.
__setitem_eager_tensor__
(
item
,
value
)
else
:
else
:
# Call c++ func __setitem_varbase__ to speedup.
# Call c++ func __setitem_varbase__ to speedup.
...
@@ -1020,7 +1026,7 @@ def monkey_patch_varbase():
...
@@ -1020,7 +1026,7 @@ def monkey_patch_varbase():
def
__hash__
(
self
):
def
__hash__
(
self
):
return
hash
(
id
(
self
))
return
hash
(
id
(
self
))
if
framework
.
_in_eager_mode_
and
not
hasattr
(
core
,
"eager"
):
if
framework
.
global_var
.
_in_eager_mode_
and
not
hasattr
(
core
,
"eager"
):
return
return
for
method_name
,
method
in
(
for
method_name
,
method
in
(
...
@@ -1047,12 +1053,12 @@ def monkey_patch_varbase():
...
@@ -1047,12 +1053,12 @@ def monkey_patch_varbase():
(
"to_dense"
,
to_dense
),
(
"to_dense"
,
to_dense
),
(
"to_sparse_coo"
,
to_sparse_coo
),
(
"to_sparse_coo"
,
to_sparse_coo
),
):
):
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
setattr
(
core
.
eager
.
Tensor
,
method_name
,
method
)
setattr
(
core
.
eager
.
Tensor
,
method_name
,
method
)
else
:
else
:
setattr
(
core
.
VarBase
,
method_name
,
method
)
setattr
(
core
.
VarBase
,
method_name
,
method
)
if
framework
.
_in_eager_mode_
:
if
framework
.
global_var
.
_in_eager_mode_
:
setattr
(
core
.
eager
.
Tensor
,
"_set_grad_ivar"
,
_set_grad_ivar
)
setattr
(
core
.
eager
.
Tensor
,
"_set_grad_ivar"
,
_set_grad_ivar
)
setattr
(
core
.
eager
.
Tensor
,
"value"
,
value
)
setattr
(
core
.
eager
.
Tensor
,
"value"
,
value
)
setattr
(
core
.
eager
.
Tensor
,
"cpu"
,
cpu
)
setattr
(
core
.
eager
.
Tensor
,
"cpu"
,
cpu
)
...
...
python/paddle/fluid/framework.py
浏览文件 @
18745e6f
...
@@ -36,6 +36,7 @@ import paddle.version as fluid_version
...
@@ -36,6 +36,7 @@ import paddle.version as fluid_version
import
warnings
import
warnings
import
functools
import
functools
from
.variable_index
import
_getitem_impl_
,
_setitem_impl_
from
.variable_index
import
_getitem_impl_
,
_setitem_impl_
import
threading
__all__
=
[
__all__
=
[
'Program'
,
'Program'
,
...
@@ -70,8 +71,42 @@ GRAD_VAR_SUFFIX = core.kGradVarSuffix()
...
@@ -70,8 +71,42 @@ GRAD_VAR_SUFFIX = core.kGradVarSuffix()
ZERO_VAR_SUFFIX
=
core
.
kZeroVarSuffix
()
ZERO_VAR_SUFFIX
=
core
.
kZeroVarSuffix
()
CONTROL_DEP_VAR_PREFIX
=
core
.
kControlDepVarName
()
CONTROL_DEP_VAR_PREFIX
=
core
.
kControlDepVarName
()
# use thread local to create thread save global variables.
class
GlobalThreadLocal
(
threading
.
local
):
def
__init__
(
self
):
"""
init the thread local data.
TODO(xiongkun): how to access another thread local data ?
"""
global
_dygraph_tracer_
self
.
_in_declarative_mode_
=
False
self
.
_functional_dygraph_context_manager
=
None
self
.
_dygraph_tracer_
=
_dygraph_tracer_
self
.
_in_eager_mode_
=
True
def
__str__
(
self
):
strings
=
[]
strings
.
append
(
"_in_declarative_mode_:"
+
str
(
self
.
_in_declarative_mode_
)
)
strings
.
append
(
"_functional_dygraph_context_manager:"
+
str
(
self
.
_functional_dygraph_context_manager
)
)
strings
.
append
(
"_dygraph_tracer_:"
+
str
(
self
.
_dygraph_tracer_
))
strings
.
append
(
"_in_eager_mode_:"
+
str
(
self
.
_in_eager_mode_
))
return
"
\n
"
.
join
(
strings
)
def
__setattr__
(
self
,
name
,
val
):
if
name
==
'_dygraph_tracer_'
:
global
_dygraph_tracer_
_dygraph_tracer_
=
val
self
.
__dict__
[
name
]
=
val
_dygraph_tracer_
=
None
_dygraph_tracer_
=
None
_in_eager_mode_
=
True
global_var
=
GlobalThreadLocal
()
_global_expected_place_
=
None
_global_expected_place_
=
None
_current_device
=
None
_current_device
=
None
global_prog_seed
=
0
global_prog_seed
=
0
...
@@ -155,20 +190,17 @@ def _switch_tensor_bind_type(is_eager):
...
@@ -155,20 +190,17 @@ def _switch_tensor_bind_type(is_eager):
def
_enable_legacy_dygraph
():
def
_enable_legacy_dygraph
():
global
_in_eager_mode_
global_var
.
_in_eager_mode_
=
False
_in_eager_mode_
=
False
_update_monkey_methods
(
is_eager
=
False
)
_update_monkey_methods
(
is_eager
=
False
)
def
_disable_legacy_dygraph
():
def
_disable_legacy_dygraph
():
global
_in_eager_mode_
global_var
.
_in_eager_mode_
=
True
_in_eager_mode_
=
True
_update_monkey_methods
(
is_eager
=
True
)
_update_monkey_methods
(
is_eager
=
True
)
def
_in_eager_without_dygraph_check
():
def
_in_eager_without_dygraph_check
():
global
_in_eager_mode_
return
global_var
.
_in_eager_mode_
return
_in_eager_mode_
# FIXME(dev): We haven't fully verified eager mode on XPU/NPU et.al but
# FIXME(dev): We haven't fully verified eager mode on XPU/NPU et.al but
...
@@ -177,7 +209,6 @@ _is_first_import_ = True
...
@@ -177,7 +209,6 @@ _is_first_import_ = True
def
_fallback_legacy_dygraph
():
def
_fallback_legacy_dygraph
():
global
_in_eager_mode_
global
_is_first_import_
global
_is_first_import_
need_fallback
=
False
need_fallback
=
False
# Only enable eager on CPU/GPU/XPU
# Only enable eager on CPU/GPU/XPU
...
@@ -187,12 +218,12 @@ def _fallback_legacy_dygraph():
...
@@ -187,12 +218,12 @@ def _fallback_legacy_dygraph():
or
core
.
is_compiled_with_mlu
()
or
core
.
is_compiled_with_mlu
()
)
)
if
_in_eager_mode_
and
is_not_support
:
if
global_var
.
_in_eager_mode_
and
is_not_support
:
# switch into legacy dygraph mode
# switch into legacy dygraph mode
warnings
.
warn
(
warnings
.
warn
(
"We will fallback into legacy dygraph on NPU/XPU/MLU/IPU/ROCM devices. Because we only support new eager dygraph mode on CPU/GPU currently. "
"We will fallback into legacy dygraph on NPU/XPU/MLU/IPU/ROCM devices. Because we only support new eager dygraph mode on CPU/GPU currently. "
)
)
_in_eager_mode_
=
False
global_var
.
_in_eager_mode_
=
False
if
not
_is_first_import_
:
if
not
_is_first_import_
:
_enable_legacy_dygraph
()
_enable_legacy_dygraph
()
need_fallback
=
True
need_fallback
=
True
...
@@ -234,11 +265,13 @@ def in_dygraph_mode():
...
@@ -234,11 +265,13 @@ def in_dygraph_mode():
print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode
print(paddle.in_dynamic_mode()) # True, Now we are in dynamic mode
"""
"""
return
(
_dygraph_tracer_
is
not
None
)
and
_in_eager_mode_
return
(
global_var
.
_dygraph_tracer_
is
not
None
)
and
global_var
.
_in_eager_mode_
def
_non_static_mode
():
def
_non_static_mode
():
return
_dygraph_tracer_
is
not
None
return
global_var
.
_dygraph_tracer_
is
not
None
@
signature_safe_contextmanager
@
signature_safe_contextmanager
...
@@ -603,7 +636,7 @@ non_static_only = wrap_decorator(_non_static_only_)
...
@@ -603,7 +636,7 @@ non_static_only = wrap_decorator(_non_static_only_)
def
_dygraph_tracer
():
def
_dygraph_tracer
():
return
_dygraph_tracer_
return
global_var
.
_dygraph_tracer_
def
_global_flags
():
def
_global_flags
():
...
@@ -671,9 +704,8 @@ def _current_expected_place():
...
@@ -671,9 +704,8 @@ def _current_expected_place():
def
_set_dygraph_tracer_expected_place
(
place
):
def
_set_dygraph_tracer_expected_place
(
place
):
global
_dygraph_tracer_
if
global_var
.
_dygraph_tracer_
is
not
None
:
if
_dygraph_tracer_
is
not
None
:
global_var
.
_dygraph_tracer_
.
_expected_place
=
place
_dygraph_tracer_
.
_expected_place
=
place
def
_set_expected_place
(
place
):
def
_set_expected_place
(
place
):
...
@@ -1315,7 +1347,7 @@ def _varbase_creator(
...
@@ -1315,7 +1347,7 @@ def _varbase_creator(
if
not
isinstance
(
dtype
,
core
.
VarDesc
.
VarType
):
if
not
isinstance
(
dtype
,
core
.
VarDesc
.
VarType
):
dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
dtype
=
convert_np_dtype_to_dtype_
(
dtype
)
if
_in_eager_mode_
:
if
global_var
.
_in_eager_mode_
:
eager_tensor
=
core
.
eager
.
Tensor
(
eager_tensor
=
core
.
eager
.
Tensor
(
dtype
if
dtype
else
core
.
VarDesc
.
VarType
.
FP32
,
dtype
if
dtype
else
core
.
VarDesc
.
VarType
.
FP32
,
list
(
shape
)
if
shape
else
[],
list
(
shape
)
if
shape
else
[],
...
@@ -7460,16 +7492,17 @@ def _get_var(name, program=None):
...
@@ -7460,16 +7492,17 @@ def _get_var(name, program=None):
@
signature_safe_contextmanager
@
signature_safe_contextmanager
def
_dygraph_guard
(
tracer
):
def
_dygraph_guard
(
tracer
):
global
_dygraph_tracer_
tmp_tracer
=
global_var
.
_dygraph_tracer_
tmp_tracer
=
_dygraph_tracer_
global_var
.
_dygraph_tracer_
=
tracer
_dygraph_tracer_
=
tracer
if
tracer
is
not
None
:
core
.
_switch_tracer
(
tracer
)
core
.
_switch_tracer
(
tracer
)
try
:
try
:
yield
yield
finally
:
finally
:
core
.
_switch_tracer
(
tmp_tracer
)
if
tmp_tracer
is
not
None
:
_dygraph_tracer_
=
tmp_tracer
core
.
_switch_tracer
(
tmp_tracer
)
global_var
.
_dygraph_tracer_
=
tmp_tracer
@
signature_safe_contextmanager
@
signature_safe_contextmanager
...
...
python/paddle/fluid/lazy_init.py
浏览文件 @
18745e6f
...
@@ -59,8 +59,8 @@ class LazyInitHelper:
...
@@ -59,8 +59,8 @@ class LazyInitHelper:
self
.
enable
()
self
.
enable
()
if
self
.
_in_guard
:
if
self
.
_in_guard
:
return
return
self
.
_tracer
=
framework
.
_dygraph_tracer_
self
.
_tracer
=
framework
.
global_var
.
_dygraph_tracer_
framework
.
_dygraph_tracer_
=
None
framework
.
global_var
.
_dygraph_tracer_
=
None
self
.
_in_guard
=
True
self
.
_in_guard
=
True
def
__exit__
(
self
,
*
args
,
**
kwargs
):
def
__exit__
(
self
,
*
args
,
**
kwargs
):
...
@@ -71,7 +71,7 @@ class LazyInitHelper:
...
@@ -71,7 +71,7 @@ class LazyInitHelper:
if
not
self
.
_in_guard
:
if
not
self
.
_in_guard
:
return
return
assert
self
.
_tracer
is
not
None
assert
self
.
_tracer
is
not
None
framework
.
_dygraph_tracer_
=
self
.
_tracer
framework
.
global_var
.
_dygraph_tracer_
=
self
.
_tracer
self
.
_tracer
=
None
self
.
_tracer
=
None
self
.
_in_guard
=
False
self
.
_in_guard
=
False
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_break_continue.py
浏览文件 @
18745e6f
...
@@ -36,7 +36,7 @@ class TestDy2staticException(unittest.TestCase):
...
@@ -36,7 +36,7 @@ class TestDy2staticException(unittest.TestCase):
with
self
.
assertRaisesRegex
(
Dygraph2StaticException
,
self
.
error
):
with
self
.
assertRaisesRegex
(
Dygraph2StaticException
,
self
.
error
):
paddle
.
jit
.
enable_to_static
(
True
)
paddle
.
jit
.
enable_to_static
(
True
)
self
.
assertTrue
(
to_static
(
self
.
dyfunc
)(
self
.
x
))
self
.
assertTrue
(
to_static
(
self
.
dyfunc
)(
self
.
x
))
paddle
.
fluid
.
dygraph
.
base
.
_in_declarative_mode_
=
False
paddle
.
fluid
.
dygraph
.
base
.
global_var
.
_in_declarative_mode_
=
False
paddle
.
jit
.
enable_to_static
(
False
)
paddle
.
jit
.
enable_to_static
(
False
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_ifelse.py
浏览文件 @
18745e6f
...
@@ -65,7 +65,7 @@ class TestDy2staticException(unittest.TestCase):
...
@@ -65,7 +65,7 @@ class TestDy2staticException(unittest.TestCase):
with
self
.
assertRaisesRegex
(
Dygraph2StaticException
,
self
.
error
):
with
self
.
assertRaisesRegex
(
Dygraph2StaticException
,
self
.
error
):
paddle
.
jit
.
enable_to_static
(
True
)
paddle
.
jit
.
enable_to_static
(
True
)
self
.
assertTrue
(
paddle
.
jit
.
to_static
(
self
.
dyfunc
)(
self
.
x
))
self
.
assertTrue
(
paddle
.
jit
.
to_static
(
self
.
dyfunc
)(
self
.
x
))
paddle
.
fluid
.
dygraph
.
base
.
_in_declarative_mode_
=
False
paddle
.
fluid
.
dygraph
.
base
.
global_var
.
_in_declarative_mode_
=
False
paddle
.
jit
.
enable_to_static
(
False
)
paddle
.
jit
.
enable_to_static
(
False
)
...
@@ -463,7 +463,7 @@ class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
...
@@ -463,7 +463,7 @@ class TestDy2StIfElseRetInt4(TestDy2StIfElseRetInt1):
# that the code block is under @to_static, but in this UT
# that the code block is under @to_static, but in this UT
# an exception is thrown during Dy2St, making the `_in_declarative_mode_`
# an exception is thrown during Dy2St, making the `_in_declarative_mode_`
# a wrong value. So We need set `_in_declarative_mode_` to False manually.
# a wrong value. So We need set `_in_declarative_mode_` to False manually.
paddle
.
fluid
.
dygraph
.
base
.
_in_declarative_mode_
=
False
paddle
.
fluid
.
dygraph
.
base
.
global_var
.
_in_declarative_mode_
=
False
paddle
.
jit
.
enable_to_static
(
False
)
paddle
.
jit
.
enable_to_static
(
False
)
...
...
python/paddle/fluid/tests/unittests/npu/test_run_program_op_npu.py
浏览文件 @
18745e6f
...
@@ -25,7 +25,7 @@ from paddle import _C_ops, _legacy_C_ops
...
@@ -25,7 +25,7 @@ from paddle import _C_ops, _legacy_C_ops
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid
import
core
,
framework
,
executor
from
paddle.fluid
import
core
,
framework
,
executor
from
paddle.fluid.layers.utils
import
_hash_with_id
from
paddle.fluid.layers.utils
import
_hash_with_id
from
paddle.fluid.framework
import
_in_eager_mode_
from
paddle.fluid.framework
import
global_var
paddle
.
enable_static
()
paddle
.
enable_static
()
np
.
random
.
seed
(
1243
)
np
.
random
.
seed
(
1243
)
...
@@ -135,7 +135,7 @@ class RunProgramNPUOpTest(unittest.TestCase):
...
@@ -135,7 +135,7 @@ class RunProgramNPUOpTest(unittest.TestCase):
def
prepare_dygraph_input
(
self
,
place
,
return_param_list
=
False
):
def
prepare_dygraph_input
(
self
,
place
,
return_param_list
=
False
):
def
create_var_base
(
is_input
,
name
,
np_value
,
stop_gradient
):
def
create_var_base
(
is_input
,
name
,
np_value
,
stop_gradient
):
if
_in_eager_mode_
:
if
global_var
.
_in_eager_mode_
:
var
=
core
.
eager
.
Tensor
(
var
=
core
.
eager
.
Tensor
(
value
=
np_value
,
name
=
name
,
place
=
place
,
zero_copy
=
True
value
=
np_value
,
name
=
name
,
place
=
place
,
zero_copy
=
True
)
)
...
@@ -176,7 +176,7 @@ class RunProgramNPUOpTest(unittest.TestCase):
...
@@ -176,7 +176,7 @@ class RunProgramNPUOpTest(unittest.TestCase):
for
name
in
self
.
output_names
[
'Out'
]:
for
name
in
self
.
output_names
[
'Out'
]:
outputs
[
'Out'
].
append
(
create_var_base
(
False
,
name
))
outputs
[
'Out'
].
append
(
create_var_base
(
False
,
name
))
if
_in_eager_mode_
:
if
global_var
.
_in_eager_mode_
:
outputs
[
'OutScope'
]
=
[
core
.
Scope
()]
outputs
[
'OutScope'
]
=
[
core
.
Scope
()]
else
:
else
:
outputs
[
'OutScope'
]
=
framework
.
_varbase_creator
(
outputs
[
'OutScope'
]
=
framework
.
_varbase_creator
(
...
...
python/paddle/fluid/tests/unittests/test_run_program_op.py
浏览文件 @
18745e6f
...
@@ -26,7 +26,7 @@ from paddle.fluid.executor import (
...
@@ -26,7 +26,7 @@ from paddle.fluid.executor import (
_is_dy2st_enable_standalone_executor
,
_is_dy2st_enable_standalone_executor
,
_is_enable_standalone_executor
,
_is_enable_standalone_executor
,
)
)
from
paddle.fluid.framework
import
_in_eager_mode_
from
paddle.fluid.framework
import
global_var
from
paddle.fluid.layers.utils
import
_hash_with_id
from
paddle.fluid.layers.utils
import
_hash_with_id
paddle
.
enable_static
()
paddle
.
enable_static
()
...
@@ -177,7 +177,7 @@ class RunProgramOpTest(unittest.TestCase):
...
@@ -177,7 +177,7 @@ class RunProgramOpTest(unittest.TestCase):
def
prepare_dygraph_input
(
self
,
place
,
return_param_list
=
False
):
def
prepare_dygraph_input
(
self
,
place
,
return_param_list
=
False
):
def
create_var_base
(
is_input
,
name
,
np_value
,
stop_gradient
):
def
create_var_base
(
is_input
,
name
,
np_value
,
stop_gradient
):
if
_in_eager_mode_
:
if
global_var
.
_in_eager_mode_
:
var
=
core
.
eager
.
Tensor
(
var
=
core
.
eager
.
Tensor
(
value
=
np_value
,
name
=
name
,
place
=
place
,
zero_copy
=
True
value
=
np_value
,
name
=
name
,
place
=
place
,
zero_copy
=
True
)
)
...
@@ -218,7 +218,7 @@ class RunProgramOpTest(unittest.TestCase):
...
@@ -218,7 +218,7 @@ class RunProgramOpTest(unittest.TestCase):
for
name
in
self
.
output_names
[
'Out'
]:
for
name
in
self
.
output_names
[
'Out'
]:
outputs
[
'Out'
].
append
(
create_var_base
(
False
,
name
))
outputs
[
'Out'
].
append
(
create_var_base
(
False
,
name
))
if
_in_eager_mode_
:
if
global_var
.
_in_eager_mode_
:
outputs
[
'OutScope'
]
=
[
core
.
Scope
()]
outputs
[
'OutScope'
]
=
[
core
.
Scope
()]
else
:
else
:
outputs
[
'OutScope'
]
=
framework
.
_varbase_creator
(
outputs
[
'OutScope'
]
=
framework
.
_varbase_creator
(
...
...
python/paddle/jit/dy2static/partial_program.py
浏览文件 @
18745e6f
...
@@ -619,7 +619,7 @@ class PartialProgramLayer:
...
@@ -619,7 +619,7 @@ class PartialProgramLayer:
if
"@GRAD"
in
name
:
if
"@GRAD"
in
name
:
var_desc
=
block
.
vars
[
name
].
desc
var_desc
=
block
.
vars
[
name
].
desc
var_base
=
None
var_base
=
None
if
not
framework
.
_in_eager_mode_
:
if
not
framework
.
global_var
.
_in_eager_mode_
:
var_base
=
core
.
VarBase
(
var_base
=
core
.
VarBase
(
var_desc
.
dtype
(),
var_desc
.
dtype
(),
var_desc
.
shape
(),
var_desc
.
shape
(),
...
@@ -874,7 +874,7 @@ class PartialProgramLayer:
...
@@ -874,7 +874,7 @@ class PartialProgramLayer:
for
i
,
value
in
enumerate
(
flatten_inputs
):
for
i
,
value
in
enumerate
(
flatten_inputs
):
if
isinstance
(
value
,
np
.
ndarray
):
if
isinstance
(
value
,
np
.
ndarray
):
var
=
None
var
=
None
if
not
framework
.
_in_eager_mode_
:
if
not
framework
.
global_var
.
_in_eager_mode_
:
var
=
core
.
VarBase
(
var
=
core
.
VarBase
(
value
=
value
,
value
=
value
,
name
=
self
.
_inputs
[
i
].
desc
.
name
(),
name
=
self
.
_inputs
[
i
].
desc
.
name
(),
...
@@ -918,7 +918,7 @@ class PartialProgramLayer:
...
@@ -918,7 +918,7 @@ class PartialProgramLayer:
if
var_desc
.
name
()
in
out_varbase_map
:
if
var_desc
.
name
()
in
out_varbase_map
:
return
out_varbase_map
[
var_desc
.
name
()]
return
out_varbase_map
[
var_desc
.
name
()]
if
not
framework
.
_in_eager_mode_
:
if
not
framework
.
global_var
.
_in_eager_mode_
:
var_base
=
core
.
VarBase
(
var_base
=
core
.
VarBase
(
var_desc
.
dtype
(),
var_desc
.
dtype
(),
var_desc
.
shape
(),
var_desc
.
shape
(),
...
@@ -949,7 +949,7 @@ class PartialProgramLayer:
...
@@ -949,7 +949,7 @@ class PartialProgramLayer:
inner_scope
=
self
.
_get_scope
(
inner_scope
=
self
.
_get_scope
(
program_id
=
program_id
,
use_scope_cache
=
use_scope_cache
program_id
=
program_id
,
use_scope_cache
=
use_scope_cache
)
)
if
not
framework
.
_in_eager_mode_
:
if
not
framework
.
global_var
.
_in_eager_mode_
:
tmp_scope_vec
=
core
.
VarBase
(
tmp_scope_vec
=
core
.
VarBase
(
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP32
,
[],
[],
...
@@ -1102,7 +1102,7 @@ def _create_fake_var():
...
@@ -1102,7 +1102,7 @@ def _create_fake_var():
"""
"""
Create a fake_var (force on CPU) to handle empty input or output
Create a fake_var (force on CPU) to handle empty input or output
"""
"""
if
not
framework
.
_in_eager_mode_
:
if
not
framework
.
global_var
.
_in_eager_mode_
:
return
[
return
[
core
.
VarBase
(
core
.
VarBase
(
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP32
,
...
...
python/paddle/tensor/logic.py
浏览文件 @
18745e6f
...
@@ -17,11 +17,11 @@
...
@@ -17,11 +17,11 @@
import
paddle
import
paddle
from
..fluid.data_feeder
import
check_type
,
check_variable_and_dtype
from
..fluid.data_feeder
import
check_type
,
check_variable_and_dtype
from
..fluid.framework
import
_in_eager_mode_
from
..fluid.framework
import
global_var
from
..static
import
Variable
from
..static
import
Variable
from
.layer_function_generator
import
templatedoc
from
.layer_function_generator
import
templatedoc
if
_in_eager_mode_
:
if
global_var
.
_in_eager_mode_
:
Tensor
=
paddle
.
fluid
.
framework
.
core
.
eager
.
Tensor
Tensor
=
paddle
.
fluid
.
framework
.
core
.
eager
.
Tensor
else
:
else
:
from
..framework
import
VarBase
as
Tensor
from
..framework
import
VarBase
as
Tensor
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录