Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
68377b44
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
68377b44
编写于
8月 05, 2021
作者:
Z
Zeng Jinle
提交者:
GitHub
8月 05, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix dygraph has_grad (#34649)
上级
4a52c0cc
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
62 addition
and
1 deletion
+62
-1
paddle/fluid/imperative/tracer.cc
paddle/fluid/imperative/tracer.cc
+2
-0
paddle/fluid/imperative/tracer.h
paddle/fluid/imperative/tracer.h
+1
-1
python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py
.../tests/unittests/test_imperative_thread_local_has_grad.py
+59
-0
未找到文件。
paddle/fluid/imperative/tracer.cc
浏览文件 @
68377b44
...
...
@@ -30,6 +30,8 @@ DECLARE_string(tracer_mkldnn_ops_off);
namespace
paddle
{
namespace
imperative
{
thread_local
bool
Tracer
::
has_grad_
=
true
;
static
std
::
shared_ptr
<
Tracer
>
g_current_tracer
(
nullptr
);
const
std
::
shared_ptr
<
Tracer
>&
GetCurrentTracer
()
{
return
g_current_tracer
;
}
...
...
paddle/fluid/imperative/tracer.h
浏览文件 @
68377b44
...
...
@@ -118,9 +118,9 @@ class Tracer {
bool
enable_program_desc_tracing_
{
false
};
std
::
unique_ptr
<
UniqueNameGenerator
>
generator_
;
platform
::
Place
expected_place_
;
bool
has_grad_
{
true
};
bool
enable_autocast_
{
false
};
GarbageCollectorMap
gcs_
;
static
thread_local
bool
has_grad_
;
};
// To access static variable current_tracer
...
...
python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py
0 → 100644
浏览文件 @
68377b44
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle
import
time
import
paddle.nn
as
nn
import
numpy
as
np
import
threading
class
SimpleNet
(
nn
.
Layer
):
def
__init__
(
self
,
in_dim
,
out_dim
):
super
(
SimpleNet
,
self
).
__init__
()
self
.
fc
=
nn
.
Linear
(
in_dim
,
out_dim
)
def
forward
(
self
,
x
):
return
self
.
fc
(
x
)
class
TestCases
(
unittest
.
TestCase
):
@
paddle
.
no_grad
()
def
thread_1_main
(
self
):
time
.
sleep
(
8
)
def
thread_2_main
(
self
):
in_dim
=
10
out_dim
=
3
net
=
SimpleNet
(
in_dim
,
out_dim
)
for
_
in
range
(
1000
):
x
=
paddle
.
to_tensor
(
np
.
random
.
rand
(
32
,
in_dim
).
astype
(
'float32'
))
self
.
assertTrue
(
x
.
stop_gradient
)
x
=
net
(
x
)
self
.
assertFalse
(
x
.
stop_gradient
)
def
test_main
(
self
):
threads
=
[]
for
_
in
range
(
10
):
threads
.
append
(
threading
.
Thread
(
target
=
self
.
thread_1_main
))
threads
.
append
(
threading
.
Thread
(
target
=
self
.
thread_2_main
))
for
t
in
threads
:
t
.
start
()
for
t
in
threads
:
t
.
join
()
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录