From 0c23e3ff4d5022a0cb7fb3279b4f9511909a62c5 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Fri, 3 Apr 2020 06:09:10 -0500 Subject: [PATCH] fix Tracer::NoGrad, test=develop (#23443) --- paddle/fluid/imperative/tracer.cc | 2 +- paddle/fluid/imperative/tracer.h | 6 +++--- paddle/fluid/pybind/imperative.cc | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 9db241fb0e9..873963db1a1 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -70,7 +70,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, const NameVarBaseMap& outs, framework::AttributeMap attrs) { - TraceOp(type, ins, outs, std::move(attrs), expected_place_, no_grad_); + TraceOp(type, ins, outs, std::move(attrs), expected_place_, has_grad_); } bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins, diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index 90758c4acb9..49aa39d2b0f 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -86,9 +86,9 @@ class Tracer { void SetExpectedPlace(platform::Place place) { expected_place_ = place; } - bool NoGrad() const { return no_grad_; } + bool HasGrad() const { return has_grad_; } - void SetNoGrad(bool no_grad) { no_grad_ = no_grad; } + void SetHasGrad(bool has_grad) { has_grad_ = has_grad; } private: std::unique_ptr basic_engine_; @@ -96,7 +96,7 @@ class Tracer { bool enable_program_desc_tracing_{false}; std::unique_ptr generator_; platform::Place expected_place_; - bool no_grad_{false}; + bool has_grad_{true}; }; // To access static variable current_tracer diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 18b82292603..6b466c2639e 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -695,8 +695,8 @@ void BindImperative(py::module *m_ptr) { .def_property("_enable_program_desc_tracing", &imperative::Tracer::IsProgramDescTracingEnabled, &imperative::Tracer::SetEnableProgramDescTracing) - .def_property("_train_mode", &imperative::Tracer::NoGrad, - &imperative::Tracer::SetNoGrad) + .def_property("_train_mode", &imperative::Tracer::HasGrad, + &imperative::Tracer::SetHasGrad) .def_property( "_expected_place", [](const imperative::Tracer &self) -> py::object { -- GitLab