From 04285ab4dd8a54183117fcd002c8218d693e9328 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Wed, 26 Jan 2022 15:43:22 +0800 Subject: [PATCH] [AMP] support setting amp_level in multi-thread (#39198) --- paddle/fluid/imperative/tracer.cc | 2 ++ paddle/fluid/imperative/tracer.h | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index f4e535de108..e845ce10453 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -32,6 +32,8 @@ namespace imperative { thread_local bool Tracer::has_grad_ = true; +thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0; + static std::shared_ptr g_current_tracer(nullptr); const std::shared_ptr& GetCurrentTracer() { return g_current_tracer; } diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index 93f68f2054b..bd8521dabde 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -126,7 +126,7 @@ class Tracer { platform::Place expected_place_; GarbageCollectorMap gcs_; static thread_local bool has_grad_; - AmpLevel amp_level_{AmpLevel::O0}; + static thread_local AmpLevel amp_level_; }; // To access static variable current_tracer -- GitLab