diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index 135d54d1b77e977f91349ca0f22ce2821c438320..342d046db73ea065c2605c98c06aa33d41b892e1 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -1015,34 +1015,28 @@ PartialGradEngine::PartialGradEngine( const std::vector> &no_grad_vars, const platform::Place &place, const detail::BackwardStrategy &strategy, bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs) - : input_targets_(input_targets), - output_targets_(output_targets), - output_grads_(output_grads), - no_grad_vars_(no_grad_vars), - place_(place), - strategy_(strategy), - create_graph_(create_graph), - retain_graph_(retain_graph), - allow_unused_(allow_unused), - only_inputs_(only_inputs) {} + : task_(new PartialGradTask(input_targets, output_targets, output_grads, + no_grad_vars, place, strategy, create_graph, + retain_graph, allow_unused, only_inputs)) {} + +PartialGradEngine::~PartialGradEngine() { Clear(); } std::vector> PartialGradEngine::GetResult() const { return results_; } void PartialGradEngine::Clear() { - input_targets_.clear(); - output_targets_.clear(); - output_grads_.clear(); - no_grad_vars_.clear(); + if (task_) { + delete task_; + task_ = nullptr; + } } void PartialGradEngine::Execute() { - PartialGradTask task(input_targets_, output_targets_, output_grads_, - no_grad_vars_, place_, strategy_, create_graph_, - retain_graph_, allow_unused_, only_inputs_); + PADDLE_ENFORCE_NOT_NULL(task_, platform::errors::PermissionDenied( + "PartialGradEngine has been destructed")); VLOG(10) << "Starts to execute PartialGradEngine"; - results_ = task.Run(); + results_ = task_->Run(); Clear(); } diff --git a/paddle/fluid/imperative/partial_grad_engine.h b/paddle/fluid/imperative/partial_grad_engine.h index 1a0fdcca4f839161950e83e69f63280c89473856..a7f28c49ec3950674cd43127f51934089a497412 100644 --- a/paddle/fluid/imperative/partial_grad_engine.h +++ b/paddle/fluid/imperative/partial_grad_engine.h @@ -25,6 +25,8 @@ namespace imperative { class VarBase; +class PartialGradTask; + class PartialGradEngine : public Engine { public: PartialGradEngine(const std::vector> &input_targets, @@ -35,6 +37,8 @@ class PartialGradEngine : public Engine { const detail::BackwardStrategy &strategy, bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs); + ~PartialGradEngine(); + void Execute() override; std::vector> GetResult() const; @@ -43,17 +47,8 @@ class PartialGradEngine : public Engine { void Clear(); private: - std::vector> input_targets_; - std::vector> output_targets_; - std::vector> output_grads_; - std::vector> no_grad_vars_; - platform::Place place_; - detail::BackwardStrategy strategy_; - bool create_graph_; - bool retain_graph_; - bool allow_unused_; - bool only_inputs_; - + // Pimpl for fast compilation and stable ABI + PartialGradTask *task_{nullptr}; std::vector> results_; };