未验证 提交 0c23e3ff 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix Tracer::NoGrad, test=develop (#23443)

上级 ebae6fb6
......@@ -70,7 +70,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
framework::AttributeMap attrs) {
TraceOp(type, ins, outs, std::move(attrs), expected_place_, no_grad_);
TraceOp(type, ins, outs, std::move(attrs), expected_place_, has_grad_);
}
bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
......
......@@ -86,9 +86,9 @@ class Tracer {
void SetExpectedPlace(platform::Place place) { expected_place_ = place; }
bool NoGrad() const { return no_grad_; }
bool HasGrad() const { return has_grad_; }
void SetNoGrad(bool no_grad) { no_grad_ = no_grad; }
void SetHasGrad(bool has_grad) { has_grad_ = has_grad; }
private:
std::unique_ptr<BasicEngine> basic_engine_;
......@@ -96,7 +96,7 @@ class Tracer {
bool enable_program_desc_tracing_{false};
std::unique_ptr<UniqueNameGenerator> generator_;
platform::Place expected_place_;
bool no_grad_{false};
bool has_grad_{true};
};
// To access static variable current_tracer
......
......@@ -695,8 +695,8 @@ void BindImperative(py::module *m_ptr) {
.def_property("_enable_program_desc_tracing",
&imperative::Tracer::IsProgramDescTracingEnabled,
&imperative::Tracer::SetEnableProgramDescTracing)
.def_property("_train_mode", &imperative::Tracer::NoGrad,
&imperative::Tracer::SetNoGrad)
.def_property("_train_mode", &imperative::Tracer::HasGrad,
&imperative::Tracer::SetHasGrad)
.def_property(
"_expected_place",
[](const imperative::Tracer &self) -> py::object {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册