未验证 提交 0c23e3ff 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix Tracer::NoGrad, test=develop (#23443)

上级 ebae6fb6
...@@ -70,7 +70,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, ...@@ -70,7 +70,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const NameVarBaseMap& outs,
framework::AttributeMap attrs) { framework::AttributeMap attrs) {
TraceOp(type, ins, outs, std::move(attrs), expected_place_, no_grad_); TraceOp(type, ins, outs, std::move(attrs), expected_place_, has_grad_);
} }
bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins, bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
......
...@@ -86,9 +86,9 @@ class Tracer { ...@@ -86,9 +86,9 @@ class Tracer {
void SetExpectedPlace(platform::Place place) { expected_place_ = place; } void SetExpectedPlace(platform::Place place) { expected_place_ = place; }
bool NoGrad() const { return no_grad_; } bool HasGrad() const { return has_grad_; }
void SetNoGrad(bool no_grad) { no_grad_ = no_grad; } void SetHasGrad(bool has_grad) { has_grad_ = has_grad; }
private: private:
std::unique_ptr<BasicEngine> basic_engine_; std::unique_ptr<BasicEngine> basic_engine_;
...@@ -96,7 +96,7 @@ class Tracer { ...@@ -96,7 +96,7 @@ class Tracer {
bool enable_program_desc_tracing_{false}; bool enable_program_desc_tracing_{false};
std::unique_ptr<UniqueNameGenerator> generator_; std::unique_ptr<UniqueNameGenerator> generator_;
platform::Place expected_place_; platform::Place expected_place_;
bool no_grad_{false}; bool has_grad_{true};
}; };
// To access static variable current_tracer // To access static variable current_tracer
......
...@@ -695,8 +695,8 @@ void BindImperative(py::module *m_ptr) { ...@@ -695,8 +695,8 @@ void BindImperative(py::module *m_ptr) {
.def_property("_enable_program_desc_tracing", .def_property("_enable_program_desc_tracing",
&imperative::Tracer::IsProgramDescTracingEnabled, &imperative::Tracer::IsProgramDescTracingEnabled,
&imperative::Tracer::SetEnableProgramDescTracing) &imperative::Tracer::SetEnableProgramDescTracing)
.def_property("_train_mode", &imperative::Tracer::NoGrad, .def_property("_train_mode", &imperative::Tracer::HasGrad,
&imperative::Tracer::SetNoGrad) &imperative::Tracer::SetHasGrad)
.def_property( .def_property(
"_expected_place", "_expected_place",
[](const imperative::Tracer &self) -> py::object { [](const imperative::Tracer &self) -> py::object {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册