diff --git a/paddle/fluid/imperative/tests/test_tracer.cc b/paddle/fluid/imperative/tests/test_tracer.cc index 4e97f97853e07438c833f1debdc3c84e01d62577..e9d62e376a69e632e9824157358ba614c4df9340 100644 --- a/paddle/fluid/imperative/tests/test_tracer.cc +++ b/paddle/fluid/imperative/tests/test_tracer.cc @@ -278,6 +278,24 @@ TEST(test_tracer, test_unique_name_generator) { ASSERT_STREQ("fc_2", fc_2.c_str()); } +TEST(test_tracer, test_current_tracer) { + // use current_tracer + auto tracer = std::make_shared(); + imperative::SetCurrentTracer(tracer); + auto current_tracer = imperative::GetCurrentTracer(); + ASSERT_EQ(current_tracer, tracer); +} + +TEST(test_tracer, test_expected_place) { + // default expected place is CPUPlace + imperative::Tracer tracer; + ASSERT_EQ(platform::is_cpu_place(tracer.ExpectedPlace()), true); + // set to CUDAPlace + platform::CUDAPlace gpu_place(0); + tracer.SetExpectedPlace(gpu_place); + ASSERT_EQ(platform::is_gpu_place(tracer.ExpectedPlace()), true); +} + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 04412c2cacf667e9b30312ae3acd804232b2c3c5..a81af74d559a3dc3403827e12adecb7b0c63bf2e 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -16,10 +16,18 @@ #include #include #include "paddle/fluid/platform/profiler.h" - namespace paddle { namespace imperative { +static std::shared_ptr g_current_tracer(nullptr); + +const std::shared_ptr& GetCurrentTracer() { return g_current_tracer; } + +void SetCurrentTracer(const std::shared_ptr& tracer) { + g_current_tracer = tracer; + VLOG(6) << "Set current tracer: " << g_current_tracer; +} + static void ClearNoNeedBufferInputs(OpBase* op) { auto& inferer = op->Info().NoNeedBufferVarsInferer(); if (!inferer) return; diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index b58f2a817476695598f35f4227e52e8f09476139..5d5988981bb37b93054863221c04d03823af919e 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -48,7 +48,9 @@ class Tracer { Tracer() : engine_(new BasicEngine()), program_desc_tracer_(new jit::ProgramDescTracer()), - generator_(new UniqueNameGenerator()) {} + generator_(new UniqueNameGenerator()) { + expected_place_ = platform::CPUPlace(); + } ~Tracer() = default; @@ -80,6 +82,17 @@ class Tracer { return generator_->Generate(key); } + platform::Place ExpectedPlace() const { return expected_place_; } + + template + void SetExpectedPlace(PlaceType place) { + expected_place_ = place; + } + + bool NoGrad() const { return no_grad_; } + + void SetNoGrad(bool no_grad) { no_grad_ = no_grad; } + private: static size_t GenerateUniqueId() { static std::atomic id{0}; @@ -91,7 +104,13 @@ class Tracer { std::unique_ptr program_desc_tracer_; bool enable_program_desc_tracing_{false}; std::unique_ptr generator_; + platform::Place expected_place_; + bool no_grad_{false}; }; +// To access static variable current_tracer +const std::shared_ptr& GetCurrentTracer(); +void SetCurrentTracer(const std::shared_ptr& tracer_); + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 275c1f047b47e89874157f02e5d3937665115885..36739ae462f76286fa264ea6330866e353840919 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -229,6 +229,10 @@ void BindImperative(py::module *m_ptr) { m.def("_is_dygraph_debug_enabled", []() { return imperative::IsDebugEnabled(); }); m.def("_dygraph_debug_level", []() { return imperative::GetDebugLevel(); }); + m.def("_switch_tracer", + [](const std::shared_ptr &tracer) { + imperative::SetCurrentTracer(tracer); + }); py::class_>( m, "VarBase", @@ -332,12 +336,38 @@ void BindImperative(py::module *m_ptr) { &imperative::jit::ProgramDescTracer::CreateProgramDesc) .def("reset", &imperative::jit::ProgramDescTracer::Reset); - py::class_(m, "Tracer", "") + py::class_>( + m, "Tracer", + R"DOC()DOC") .def("__init__", [](imperative::Tracer &self) { new (&self) imperative::Tracer(); }) .def_property("_enable_program_desc_tracing", &imperative::Tracer::IsProgramDescTracingEnabled, &imperative::Tracer::SetEnableProgramDescTracing) + .def_property("_train_mode", &imperative::Tracer::NoGrad, + &imperative::Tracer::SetNoGrad) + .def_property( + "_expected_place", + [](const imperative::Tracer &self) -> py::object { + return py::cast(self.ExpectedPlace()); + }, + [](imperative::Tracer &self, const py::object &obj) { + if (py::isinstance(obj)) { + auto p = obj.cast(); + self.SetExpectedPlace(*p); + } else if (py::isinstance(obj)) { + auto p = obj.cast(); + self.SetExpectedPlace(*p); + } else if (py::isinstance(obj)) { + auto p = obj.cast(); + self.SetExpectedPlace(*p); + } else { + PADDLE_THROW( + "Incompatible Place Type: supports CUDAPlace, CPUPlace, " + "CUDAPinnedPlace, " + "but got Unknown Type!"); + } + }) .def("_get_program_desc_tracer", &imperative::Tracer::GetProgramDescTracer, py::return_value_policy::reference) diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index b0908534a7cd79b8b2f132041ce605a91aa5a6ca..930feeee2bba73f5a1a43bbbff075630e815870b 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -176,6 +176,7 @@ if avx_supported(): from .core_avx import _set_fuse_parameter_memory_size from .core_avx import _is_dygraph_debug_enabled from .core_avx import _dygraph_debug_level + from .core_avx import _switch_tracer from .core_avx import _set_paddle_lib_path from .core_avx import _save_static_dict from .core_avx import _load_static_dict @@ -210,6 +211,7 @@ if load_noavx: from .core_noavx import _set_fuse_parameter_memory_size from .core_noavx import _is_dygraph_debug_enabled from .core_noavx import _dygraph_debug_level + from .core_noavx import _switch_tracer from .core_noavx import _set_paddle_lib_path from .core_noavx import _save_static_dict from .core_noavx import _load_static_dict diff --git a/python/paddle/fluid/dygraph/base.py b/python/paddle/fluid/dygraph/base.py index 5a3774c8993dc703007701049c6d96d24144d4ba..03708c085f235666560bbb6f8892e1a78ecf7235 100644 --- a/python/paddle/fluid/dygraph/base.py +++ b/python/paddle/fluid/dygraph/base.py @@ -127,12 +127,14 @@ def guard(place=None): train = framework.Program() startup = framework.Program() tracer = Tracer() + core._switch_tracer(tracer) if place is None: if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) else: place = core.CPUPlace() + tracer._expected_place = place with framework.program_guard(train, startup): with framework.unique_name.guard():