From 68377b4494ca4735760f0a93ff880fc771ead3f1 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 5 Aug 2021 22:50:02 +0800 Subject: [PATCH] fix dygraph has_grad (#34649) --- paddle/fluid/imperative/tracer.cc | 2 + paddle/fluid/imperative/tracer.h | 2 +- .../test_imperative_thread_local_has_grad.py | 59 +++++++++++++++++++ 3 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 3d97d68b5c7..9dc9c4d90ac 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -30,6 +30,8 @@ DECLARE_string(tracer_mkldnn_ops_off); namespace paddle { namespace imperative { +thread_local bool Tracer::has_grad_ = true; + static std::shared_ptr g_current_tracer(nullptr); const std::shared_ptr& GetCurrentTracer() { return g_current_tracer; } diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index 8f505508782..b734ae5c499 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -118,9 +118,9 @@ class Tracer { bool enable_program_desc_tracing_{false}; std::unique_ptr generator_; platform::Place expected_place_; - bool has_grad_{true}; bool enable_autocast_{false}; GarbageCollectorMap gcs_; + static thread_local bool has_grad_; }; // To access static variable current_tracer diff --git a/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py new file mode 100644 index 00000000000..d81849725d7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import time +import paddle.nn as nn +import numpy as np +import threading + + +class SimpleNet(nn.Layer): + def __init__(self, in_dim, out_dim): + super(SimpleNet, self).__init__() + self.fc = nn.Linear(in_dim, out_dim) + + def forward(self, x): + return self.fc(x) + + +class TestCases(unittest.TestCase): + @paddle.no_grad() + def thread_1_main(self): + time.sleep(8) + + def thread_2_main(self): + in_dim = 10 + out_dim = 3 + net = SimpleNet(in_dim, out_dim) + for _ in range(1000): + x = paddle.to_tensor(np.random.rand(32, in_dim).astype('float32')) + self.assertTrue(x.stop_gradient) + x = net(x) + self.assertFalse(x.stop_gradient) + + def test_main(self): + threads = [] + for _ in range(10): + threads.append(threading.Thread(target=self.thread_1_main)) + threads.append(threading.Thread(target=self.thread_2_main)) + for t in threads: + t.start() + for t in threads: + t.join() + + +if __name__ == "__main__": + unittest.main() -- GitLab