未验证 提交 ea22515a 编写于 作者: Z Zeng Jinle 提交者: GitHub

pimpl to polish code, test=develop (#23597)

上级 42d67dac
...@@ -1015,34 +1015,28 @@ PartialGradEngine::PartialGradEngine( ...@@ -1015,34 +1015,28 @@ PartialGradEngine::PartialGradEngine(
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars, const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place, const detail::BackwardStrategy &strategy, const platform::Place &place, const detail::BackwardStrategy &strategy,
bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs) bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs)
: input_targets_(input_targets), : task_(new PartialGradTask(input_targets, output_targets, output_grads,
output_targets_(output_targets), no_grad_vars, place, strategy, create_graph,
output_grads_(output_grads), retain_graph, allow_unused, only_inputs)) {}
no_grad_vars_(no_grad_vars),
place_(place), PartialGradEngine::~PartialGradEngine() { Clear(); }
strategy_(strategy),
create_graph_(create_graph),
retain_graph_(retain_graph),
allow_unused_(allow_unused),
only_inputs_(only_inputs) {}
std::vector<std::shared_ptr<VarBase>> PartialGradEngine::GetResult() const { std::vector<std::shared_ptr<VarBase>> PartialGradEngine::GetResult() const {
return results_; return results_;
} }
void PartialGradEngine::Clear() { void PartialGradEngine::Clear() {
input_targets_.clear(); if (task_) {
output_targets_.clear(); delete task_;
output_grads_.clear(); task_ = nullptr;
no_grad_vars_.clear(); }
} }
void PartialGradEngine::Execute() { void PartialGradEngine::Execute() {
PartialGradTask task(input_targets_, output_targets_, output_grads_, PADDLE_ENFORCE_NOT_NULL(task_, platform::errors::PermissionDenied(
no_grad_vars_, place_, strategy_, create_graph_, "PartialGradEngine has been destructed"));
retain_graph_, allow_unused_, only_inputs_);
VLOG(10) << "Starts to execute PartialGradEngine"; VLOG(10) << "Starts to execute PartialGradEngine";
results_ = task.Run(); results_ = task_->Run();
Clear(); Clear();
} }
......
...@@ -25,6 +25,8 @@ namespace imperative { ...@@ -25,6 +25,8 @@ namespace imperative {
class VarBase; class VarBase;
class PartialGradTask;
class PartialGradEngine : public Engine { class PartialGradEngine : public Engine {
public: public:
PartialGradEngine(const std::vector<std::shared_ptr<VarBase>> &input_targets, PartialGradEngine(const std::vector<std::shared_ptr<VarBase>> &input_targets,
...@@ -35,6 +37,8 @@ class PartialGradEngine : public Engine { ...@@ -35,6 +37,8 @@ class PartialGradEngine : public Engine {
const detail::BackwardStrategy &strategy, bool create_graph, const detail::BackwardStrategy &strategy, bool create_graph,
bool retain_graph, bool allow_unused, bool only_inputs); bool retain_graph, bool allow_unused, bool only_inputs);
~PartialGradEngine();
void Execute() override; void Execute() override;
std::vector<std::shared_ptr<VarBase>> GetResult() const; std::vector<std::shared_ptr<VarBase>> GetResult() const;
...@@ -43,17 +47,8 @@ class PartialGradEngine : public Engine { ...@@ -43,17 +47,8 @@ class PartialGradEngine : public Engine {
void Clear(); void Clear();
private: private:
std::vector<std::shared_ptr<VarBase>> input_targets_; // Pimpl for fast compilation and stable ABI
std::vector<std::shared_ptr<VarBase>> output_targets_; PartialGradTask *task_{nullptr};
std::vector<std::shared_ptr<VarBase>> output_grads_;
std::vector<std::shared_ptr<VarBase>> no_grad_vars_;
platform::Place place_;
detail::BackwardStrategy strategy_;
bool create_graph_;
bool retain_graph_;
bool allow_unused_;
bool only_inputs_;
std::vector<std::shared_ptr<VarBase>> results_; std::vector<std::shared_ptr<VarBase>> results_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册