未验证 提交 ea22515a 编写于 作者: Z Zeng Jinle 提交者: GitHub

pimpl to polish code, test=develop (#23597)

上级 42d67dac
......@@ -1015,34 +1015,28 @@ PartialGradEngine::PartialGradEngine(
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place, const detail::BackwardStrategy &strategy,
bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs)
: input_targets_(input_targets),
output_targets_(output_targets),
output_grads_(output_grads),
no_grad_vars_(no_grad_vars),
place_(place),
strategy_(strategy),
create_graph_(create_graph),
retain_graph_(retain_graph),
allow_unused_(allow_unused),
only_inputs_(only_inputs) {}
: task_(new PartialGradTask(input_targets, output_targets, output_grads,
no_grad_vars, place, strategy, create_graph,
retain_graph, allow_unused, only_inputs)) {}
PartialGradEngine::~PartialGradEngine() { Clear(); }
std::vector<std::shared_ptr<VarBase>> PartialGradEngine::GetResult() const {
return results_;
}
void PartialGradEngine::Clear() {
input_targets_.clear();
output_targets_.clear();
output_grads_.clear();
no_grad_vars_.clear();
if (task_) {
delete task_;
task_ = nullptr;
}
}
void PartialGradEngine::Execute() {
PartialGradTask task(input_targets_, output_targets_, output_grads_,
no_grad_vars_, place_, strategy_, create_graph_,
retain_graph_, allow_unused_, only_inputs_);
PADDLE_ENFORCE_NOT_NULL(task_, platform::errors::PermissionDenied(
"PartialGradEngine has been destructed"));
VLOG(10) << "Starts to execute PartialGradEngine";
results_ = task.Run();
results_ = task_->Run();
Clear();
}
......
......@@ -25,6 +25,8 @@ namespace imperative {
class VarBase;
class PartialGradTask;
class PartialGradEngine : public Engine {
public:
PartialGradEngine(const std::vector<std::shared_ptr<VarBase>> &input_targets,
......@@ -35,6 +37,8 @@ class PartialGradEngine : public Engine {
const detail::BackwardStrategy &strategy, bool create_graph,
bool retain_graph, bool allow_unused, bool only_inputs);
~PartialGradEngine();
void Execute() override;
std::vector<std::shared_ptr<VarBase>> GetResult() const;
......@@ -43,17 +47,8 @@ class PartialGradEngine : public Engine {
void Clear();
private:
std::vector<std::shared_ptr<VarBase>> input_targets_;
std::vector<std::shared_ptr<VarBase>> output_targets_;
std::vector<std::shared_ptr<VarBase>> output_grads_;
std::vector<std::shared_ptr<VarBase>> no_grad_vars_;
platform::Place place_;
detail::BackwardStrategy strategy_;
bool create_graph_;
bool retain_graph_;
bool allow_unused_;
bool only_inputs_;
// Pimpl for fast compilation and stable ABI
PartialGradTask *task_{nullptr};
std::vector<std::shared_ptr<VarBase>> results_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册