From ea22515ae4ba1b21cc2d10869f8efda19659ce11 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 9 Apr 2020 05:41:02 -0500 Subject: [PATCH] pimpl to polish code, test=develop (#23597) --- .../fluid/imperative/partial_grad_engine.cc | 30 ++++++++----------- paddle/fluid/imperative/partial_grad_engine.h | 17 ++++------- 2 files changed, 18 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index 135d54d1b7..342d046db7 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -1015,34 +1015,28 @@ PartialGradEngine::PartialGradEngine( const std::vector> &no_grad_vars, const platform::Place &place, const detail::BackwardStrategy &strategy, bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs) - : input_targets_(input_targets), - output_targets_(output_targets), - output_grads_(output_grads), - no_grad_vars_(no_grad_vars), - place_(place), - strategy_(strategy), - create_graph_(create_graph), - retain_graph_(retain_graph), - allow_unused_(allow_unused), - only_inputs_(only_inputs) {} + : task_(new PartialGradTask(input_targets, output_targets, output_grads, + no_grad_vars, place, strategy, create_graph, + retain_graph, allow_unused, only_inputs)) {} + +PartialGradEngine::~PartialGradEngine() { Clear(); } std::vector> PartialGradEngine::GetResult() const { return results_; } void PartialGradEngine::Clear() { - input_targets_.clear(); - output_targets_.clear(); - output_grads_.clear(); - no_grad_vars_.clear(); + if (task_) { + delete task_; + task_ = nullptr; + } } void PartialGradEngine::Execute() { - PartialGradTask task(input_targets_, output_targets_, output_grads_, - no_grad_vars_, place_, strategy_, create_graph_, - retain_graph_, allow_unused_, only_inputs_); + PADDLE_ENFORCE_NOT_NULL(task_, platform::errors::PermissionDenied( + "PartialGradEngine has been destructed")); VLOG(10) << "Starts to execute PartialGradEngine"; - results_ = task.Run(); + results_ = task_->Run(); Clear(); } diff --git a/paddle/fluid/imperative/partial_grad_engine.h b/paddle/fluid/imperative/partial_grad_engine.h index 1a0fdcca4f..a7f28c49ec 100644 --- a/paddle/fluid/imperative/partial_grad_engine.h +++ b/paddle/fluid/imperative/partial_grad_engine.h @@ -25,6 +25,8 @@ namespace imperative { class VarBase; +class PartialGradTask; + class PartialGradEngine : public Engine { public: PartialGradEngine(const std::vector> &input_targets, @@ -35,6 +37,8 @@ class PartialGradEngine : public Engine { const detail::BackwardStrategy &strategy, bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs); + ~PartialGradEngine(); + void Execute() override; std::vector> GetResult() const; @@ -43,17 +47,8 @@ class PartialGradEngine : public Engine { void Clear(); private: - std::vector> input_targets_; - std::vector> output_targets_; - std::vector> output_grads_; - std::vector> no_grad_vars_; - platform::Place place_; - detail::BackwardStrategy strategy_; - bool create_graph_; - bool retain_graph_; - bool allow_unused_; - bool only_inputs_; - + // Pimpl for fast compilation and stable ABI + PartialGradTask *task_{nullptr}; std::vector> results_; }; -- GitLab