diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index 3d97d68b5c7dfd66e80620b3cbc2d6dc6f00d5b0..9dc9c4d90acaba81faf2d8438cf26c498661d8df 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -30,6 +30,8 @@ DECLARE_string(tracer_mkldnn_ops_off);
 namespace paddle {
 namespace imperative {
 
+thread_local bool Tracer::has_grad_ = true;
+
 static std::shared_ptr<Tracer> g_current_tracer(nullptr);
 
 const std::shared_ptr<Tracer>& GetCurrentTracer() { return g_current_tracer; }
diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h
index 8f50550878262f1d37c34923e4c8bc55460b08d6..b734ae5c4993690d4ccfb9a02dcbae888cb0fe38 100644
--- a/paddle/fluid/imperative/tracer.h
+++ b/paddle/fluid/imperative/tracer.h
@@ -118,9 +118,9 @@ class Tracer {
   bool enable_program_desc_tracing_{false};
   std::unique_ptr<UniqueNameGenerator> generator_;
   platform::Place expected_place_;
-  bool has_grad_{true};
   bool enable_autocast_{false};
   GarbageCollectorMap gcs_;
+  static thread_local bool has_grad_;
 };
 
 // To access static variable current_tracer
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py
new file mode 100644
index 0000000000000000000000000000000000000000..d81849725d75aadd96466af4d2d0f935b7e60ec0
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_imperative_thread_local_has_grad.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+import time
+import paddle.nn as nn
+import numpy as np
+import threading
+
+
+class SimpleNet(nn.Layer):
+    def __init__(self, in_dim, out_dim):
+        super(SimpleNet, self).__init__()
+        self.fc = nn.Linear(in_dim, out_dim)
+
+    def forward(self, x):
+        return self.fc(x)
+
+
+class TestCases(unittest.TestCase):
+    @paddle.no_grad()
+    def thread_1_main(self):
+        time.sleep(8)
+
+    def thread_2_main(self):
+        in_dim = 10
+        out_dim = 3
+        net = SimpleNet(in_dim, out_dim)
+        for _ in range(1000):
+            x = paddle.to_tensor(np.random.rand(32, in_dim).astype('float32'))
+            self.assertTrue(x.stop_gradient)
+            x = net(x)
+            self.assertFalse(x.stop_gradient)
+
+    def test_main(self):
+        threads = []
+        for _ in range(10):
+            threads.append(threading.Thread(target=self.thread_1_main))
+            threads.append(threading.Thread(target=self.thread_2_main))
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+
+if __name__ == "__main__":
+    unittest.main()
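
For reference (not part of the patch), here is a minimal standalone sketch of the behavior the thread-local `has_grad_` flag is meant to guarantee: entering `paddle.no_grad()` in one thread should not disable gradient tracking in a concurrently running thread. The worker function names below are made up for illustration; with the flag being thread-local, both assertions should hold regardless of how the threads interleave.

```python
# Illustrative sketch only (not part of the patch): with has_grad_ being
# thread-local, paddle.no_grad() in one thread must not disable gradient
# tracking in another thread. Worker names are made up for this example.
import threading

import numpy as np
import paddle
import paddle.nn as nn


def no_grad_worker():
    # Runs entirely under no_grad: outputs produced here carry no gradient.
    with paddle.no_grad():
        fc = nn.Linear(4, 2)
        y = fc(paddle.to_tensor(np.ones((1, 4), dtype='float32')))
        assert y.stop_gradient


def grad_worker():
    # Runs without no_grad: gradient tracking should stay enabled here even
    # while the other thread is inside its no_grad scope.
    fc = nn.Linear(4, 2)
    y = fc(paddle.to_tensor(np.ones((1, 4), dtype='float32')))
    assert not y.stop_gradient


threads = [threading.Thread(target=no_grad_worker),
           threading.Thread(target=grad_worker)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```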