Created by: sneaxiy
Refactor tracer by adding VariableWrapper
to avoid circular reference of std::shared_ptr
. BTW, remove the ownership of VarBase
and OpBase
apart from Engine
, so that all operators can be destructed automatically when all variables are destructed. Users do not need to set layer.eval()
when doing inference to avoid memory leak.
187 209 public: 188 210 using GradOpBaseMakerBase::GradOpBaseMakerBase; 189 211 190 public: 191 std::vector<std::unique_ptr<imperative::OpBase>> operator()() const { 192 std::vector<std::unique_ptr<imperative::OpBase>> retv; 193 retv.emplace_back(this->Apply()); 194 212 std::vector<std::shared_ptr<imperative::OpBase>> operator()() const { 189 190 ~TracedGradOp() { 191 op_->SetGradPendingOps( 192 {grad_pending_ops_.begin(), grad_pending_ops_.end()}); 193 op_->CheckAttrs(); 194 } 195 196 template <TracedVarRole kRole> 197 void SetInput(const std::string& name, 198 const TracedVarList<VarBase, kRole>& vars) { 199 if (kRole == TracedVarRole::kBackward) { 200 for (auto& var : vars) { 201 var->AddGradOp(op_); 202 } 203 } 204 op_->SetInput(name, ToVarWrapperList(vars)); 189 190 ~TracedGradOp() { 191 op_->SetGradPendingOps( 192 {grad_pending_ops_.begin(), grad_pending_ops_.end()}); 193 op_->CheckAttrs(); 194 } 195 196 template <TracedVarRole kRole> 197 void SetInput(const std::string& name, 198 const TracedVarList<VarBase, kRole>& vars) { 199 if (kRole == TracedVarRole::kBackward) { 200 for (auto& var : vars) { 201 var->AddGradOp(op_); 202 } 203 } 204 op_->SetInput(name, ToVarWrapperList(vars)); 187 209 public: 188 210 using GradOpBaseMakerBase::GradOpBaseMakerBase; 189 211 190 public: 191 std::vector<std::unique_ptr<imperative::OpBase>> operator()() const { 192 std::vector<std::unique_ptr<imperative::OpBase>> retv; 193 retv.emplace_back(this->Apply()); 194 212 std::vector<std::shared_ptr<imperative::OpBase>> operator()() const {