提交 008ed65f 编写于 作者: L Leo Chen 提交者: Zeng Jinle

Add c++ global current tracer for dygraph (#20882)

* Add c++ global current tracer for dygraph, test=develop

* add tracer property in c++, test=develop

* support different place, test=develop

* add unittest for tracer, test=develop
上级 5aae5959
...@@ -278,6 +278,24 @@ TEST(test_tracer, test_unique_name_generator) { ...@@ -278,6 +278,24 @@ TEST(test_tracer, test_unique_name_generator) {
ASSERT_STREQ("fc_2", fc_2.c_str()); ASSERT_STREQ("fc_2", fc_2.c_str());
} }
TEST(test_tracer, test_current_tracer) {
// use current_tracer
auto tracer = std::make_shared<imperative::Tracer>();
imperative::SetCurrentTracer(tracer);
auto current_tracer = imperative::GetCurrentTracer();
ASSERT_EQ(current_tracer, tracer);
}
TEST(test_tracer, test_expected_place) {
// default expected place is CPUPlace
imperative::Tracer tracer;
ASSERT_EQ(platform::is_cpu_place(tracer.ExpectedPlace()), true);
// set to CUDAPlace
platform::CUDAPlace gpu_place(0);
tracer.SetExpectedPlace(gpu_place);
ASSERT_EQ(platform::is_gpu_place(tracer.ExpectedPlace()), true);
}
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
......
...@@ -16,10 +16,18 @@ ...@@ -16,10 +16,18 @@
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace imperative { namespace imperative {
static std::shared_ptr<Tracer> g_current_tracer(nullptr);
const std::shared_ptr<Tracer>& GetCurrentTracer() { return g_current_tracer; }
void SetCurrentTracer(const std::shared_ptr<Tracer>& tracer) {
g_current_tracer = tracer;
VLOG(6) << "Set current tracer: " << g_current_tracer;
}
static void ClearNoNeedBufferInputs(OpBase* op) { static void ClearNoNeedBufferInputs(OpBase* op) {
auto& inferer = op->Info().NoNeedBufferVarsInferer(); auto& inferer = op->Info().NoNeedBufferVarsInferer();
if (!inferer) return; if (!inferer) return;
......
...@@ -48,7 +48,9 @@ class Tracer { ...@@ -48,7 +48,9 @@ class Tracer {
Tracer() Tracer()
: engine_(new BasicEngine()), : engine_(new BasicEngine()),
program_desc_tracer_(new jit::ProgramDescTracer()), program_desc_tracer_(new jit::ProgramDescTracer()),
generator_(new UniqueNameGenerator()) {} generator_(new UniqueNameGenerator()) {
expected_place_ = platform::CPUPlace();
}
~Tracer() = default; ~Tracer() = default;
...@@ -80,6 +82,17 @@ class Tracer { ...@@ -80,6 +82,17 @@ class Tracer {
return generator_->Generate(key); return generator_->Generate(key);
} }
platform::Place ExpectedPlace() const { return expected_place_; }
template <typename PlaceType>
void SetExpectedPlace(PlaceType place) {
expected_place_ = place;
}
bool NoGrad() const { return no_grad_; }
void SetNoGrad(bool no_grad) { no_grad_ = no_grad; }
private: private:
static size_t GenerateUniqueId() { static size_t GenerateUniqueId() {
static std::atomic<size_t> id{0}; static std::atomic<size_t> id{0};
...@@ -91,7 +104,13 @@ class Tracer { ...@@ -91,7 +104,13 @@ class Tracer {
std::unique_ptr<jit::ProgramDescTracer> program_desc_tracer_; std::unique_ptr<jit::ProgramDescTracer> program_desc_tracer_;
bool enable_program_desc_tracing_{false}; bool enable_program_desc_tracing_{false};
std::unique_ptr<UniqueNameGenerator> generator_; std::unique_ptr<UniqueNameGenerator> generator_;
platform::Place expected_place_;
bool no_grad_{false};
}; };
// To access static variable current_tracer
const std::shared_ptr<Tracer>& GetCurrentTracer();
void SetCurrentTracer(const std::shared_ptr<Tracer>& tracer_);
} // namespace imperative } // namespace imperative
} // namespace paddle } // namespace paddle
...@@ -229,6 +229,10 @@ void BindImperative(py::module *m_ptr) { ...@@ -229,6 +229,10 @@ void BindImperative(py::module *m_ptr) {
m.def("_is_dygraph_debug_enabled", m.def("_is_dygraph_debug_enabled",
[]() { return imperative::IsDebugEnabled(); }); []() { return imperative::IsDebugEnabled(); });
m.def("_dygraph_debug_level", []() { return imperative::GetDebugLevel(); }); m.def("_dygraph_debug_level", []() { return imperative::GetDebugLevel(); });
m.def("_switch_tracer",
[](const std::shared_ptr<imperative::Tracer> &tracer) {
imperative::SetCurrentTracer(tracer);
});
py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>( py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
m, "VarBase", m, "VarBase",
...@@ -332,12 +336,38 @@ void BindImperative(py::module *m_ptr) { ...@@ -332,12 +336,38 @@ void BindImperative(py::module *m_ptr) {
&imperative::jit::ProgramDescTracer::CreateProgramDesc) &imperative::jit::ProgramDescTracer::CreateProgramDesc)
.def("reset", &imperative::jit::ProgramDescTracer::Reset); .def("reset", &imperative::jit::ProgramDescTracer::Reset);
py::class_<imperative::Tracer>(m, "Tracer", "") py::class_<imperative::Tracer, std::shared_ptr<imperative::Tracer>>(
m, "Tracer",
R"DOC()DOC")
.def("__init__", .def("__init__",
[](imperative::Tracer &self) { new (&self) imperative::Tracer(); }) [](imperative::Tracer &self) { new (&self) imperative::Tracer(); })
.def_property("_enable_program_desc_tracing", .def_property("_enable_program_desc_tracing",
&imperative::Tracer::IsProgramDescTracingEnabled, &imperative::Tracer::IsProgramDescTracingEnabled,
&imperative::Tracer::SetEnableProgramDescTracing) &imperative::Tracer::SetEnableProgramDescTracing)
.def_property("_train_mode", &imperative::Tracer::NoGrad,
&imperative::Tracer::SetNoGrad)
.def_property(
"_expected_place",
[](const imperative::Tracer &self) -> py::object {
return py::cast(self.ExpectedPlace());
},
[](imperative::Tracer &self, const py::object &obj) {
if (py::isinstance<platform::CUDAPlace>(obj)) {
auto p = obj.cast<platform::CUDAPlace *>();
self.SetExpectedPlace<platform::CUDAPlace>(*p);
} else if (py::isinstance<platform::CPUPlace>(obj)) {
auto p = obj.cast<platform::CPUPlace *>();
self.SetExpectedPlace<platform::CPUPlace>(*p);
} else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
auto p = obj.cast<platform::CUDAPinnedPlace *>();
self.SetExpectedPlace<platform::CUDAPinnedPlace>(*p);
} else {
PADDLE_THROW(
"Incompatible Place Type: supports CUDAPlace, CPUPlace, "
"CUDAPinnedPlace, "
"but got Unknown Type!");
}
})
.def("_get_program_desc_tracer", .def("_get_program_desc_tracer",
&imperative::Tracer::GetProgramDescTracer, &imperative::Tracer::GetProgramDescTracer,
py::return_value_policy::reference) py::return_value_policy::reference)
......
...@@ -176,6 +176,7 @@ if avx_supported(): ...@@ -176,6 +176,7 @@ if avx_supported():
from .core_avx import _set_fuse_parameter_memory_size from .core_avx import _set_fuse_parameter_memory_size
from .core_avx import _is_dygraph_debug_enabled from .core_avx import _is_dygraph_debug_enabled
from .core_avx import _dygraph_debug_level from .core_avx import _dygraph_debug_level
from .core_avx import _switch_tracer
from .core_avx import _set_paddle_lib_path from .core_avx import _set_paddle_lib_path
from .core_avx import _save_static_dict from .core_avx import _save_static_dict
from .core_avx import _load_static_dict from .core_avx import _load_static_dict
...@@ -210,6 +211,7 @@ if load_noavx: ...@@ -210,6 +211,7 @@ if load_noavx:
from .core_noavx import _set_fuse_parameter_memory_size from .core_noavx import _set_fuse_parameter_memory_size
from .core_noavx import _is_dygraph_debug_enabled from .core_noavx import _is_dygraph_debug_enabled
from .core_noavx import _dygraph_debug_level from .core_noavx import _dygraph_debug_level
from .core_noavx import _switch_tracer
from .core_noavx import _set_paddle_lib_path from .core_noavx import _set_paddle_lib_path
from .core_noavx import _save_static_dict from .core_noavx import _save_static_dict
from .core_noavx import _load_static_dict from .core_noavx import _load_static_dict
......
...@@ -127,12 +127,14 @@ def guard(place=None): ...@@ -127,12 +127,14 @@ def guard(place=None):
train = framework.Program() train = framework.Program()
startup = framework.Program() startup = framework.Program()
tracer = Tracer() tracer = Tracer()
core._switch_tracer(tracer)
if place is None: if place is None:
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
place = core.CUDAPlace(0) place = core.CUDAPlace(0)
else: else:
place = core.CPUPlace() place = core.CPUPlace()
tracer._expected_place = place
with framework.program_guard(train, startup): with framework.program_guard(train, startup):
with framework.unique_name.guard(): with framework.unique_name.guard():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册