未验证 提交 c736fef9 编写于 作者: H hong 提交者: GitHub

dygraph backward engine accelerate (#22808)

* fix loaded program load bug; test=develop

* first version

* speed backward engin; test=develop

* remove useless code; test=develop

* reconvery io.py; test=develop

* remove useless code; test=develop

* remove useless code; test=develop
上级 d41d802b
...@@ -34,8 +34,6 @@ void Engine::RunOp(paddle::imperative::OpBase* op, ...@@ -34,8 +34,6 @@ void Engine::RunOp(paddle::imperative::OpBase* op,
const paddle::imperative::NameVarBaseMap& ins, const paddle::imperative::NameVarBaseMap& ins,
const paddle::imperative::NameVarBaseMap& outs, const paddle::imperative::NameVarBaseMap& outs,
const paddle::platform::Place& place) { const paddle::platform::Place& place) {
platform::RecordEvent event(op->Type());
op->Run(ins, outs); op->Run(ins, outs);
} }
...@@ -62,7 +60,6 @@ void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) { ...@@ -62,7 +60,6 @@ void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
} }
} }
init_ops_ = ops; init_ops_ = ops;
platform::RecordEvent record_event("Imperative Backward");
VLOG(3) << "start backward"; VLOG(3) << "start backward";
PADDLE_ENFORCE_EQ(var->HasGradVar(), true, PADDLE_ENFORCE_EQ(var->HasGradVar(), true,
...@@ -194,39 +191,36 @@ void BasicEngine::Execute() { ...@@ -194,39 +191,36 @@ void BasicEngine::Execute() {
auto& bwd_ins = cur_op->GetInsMap(); auto& bwd_ins = cur_op->GetInsMap();
auto& bwd_outs = cur_op->GetOutsMap(); auto& bwd_outs = cur_op->GetOutsMap();
NameVarBaseMap tmp_outs; NameVarBaseMap tmp_outs(bwd_outs);
// 1. construct the output map 2. replace the element in the map
// A var may be coresponding to several grad var in one op // A var may be coresponding to several grad var in one op
std::unordered_map<VarBase*, std::vector<std::shared_ptr<VarBase>>> var_map; for (auto it = tmp_outs.begin(); it != tmp_outs.end(); ++it) {
for (auto& bwd_out : bwd_outs) { for (size_t i = 0; i < it->second.size(); ++i) {
auto& tmp_var_list = tmp_outs[bwd_out.first];
tmp_var_list.reserve(bwd_out.second.size());
for (auto& var : bwd_out.second) {
auto tmp_var = auto tmp_var =
std::make_shared<VarBase>(false, "Gtmp@"); // Do not need grad std::make_shared<VarBase>(false, "Gtmp@"); // Do not need grad
tmp_var_list.emplace_back(tmp_var);
if (var) {
var_map[var.get()].emplace_back(std::move(tmp_var));
auto var = it->second[i];
it->second[i] = tmp_var;
if (var) {
need_accu_var_list_.emplace_back(
make_pair(var.get(), std::move(tmp_var)));
var->ClearGradOps(); var->ClearGradOps();
} }
} }
} }
VLOG(3) << "Start to execute grad op " << cur_op->Type(); VLOG(3) << "Start to execute grad op " << cur_op->Type();
RunOp(cur_op, bwd_ins, tmp_outs, cur_op->place()); RunOp(cur_op, bwd_ins, tmp_outs, cur_op->place());
// Step 2: Sum Gradient // Step 2: Sum Gradient
{
platform::RecordEvent record_event("merge_grads"); if (need_accu_var_list_.size() > 0) {
for (auto& var_pair : var_map) { for (auto& pair : need_accu_var_list_) {
auto* dst_var = var_pair.first; SumGradient(cur_op, std::move(pair.second), pair.first);
if (dst_var == nullptr) continue;
for (auto& src_var : var_pair.second) {
VLOG(3) << "Sum gradient of variable " << dst_var->Name()
<< " after op " << cur_op->Type();
SumGradient(cur_op, std::move(src_var), dst_var);
}
} }
} }
need_accu_var_list_.clear();
// Step 3: Collect ready ops // Step 3: Collect ready ops
for (auto* grad_pending_op : cur_op->GradPendingOps()) { for (auto* grad_pending_op : cur_op->GradPendingOps()) {
......
...@@ -107,6 +107,9 @@ class BasicEngine : public Engine { ...@@ -107,6 +107,9 @@ class BasicEngine : public Engine {
std::unordered_map<OpBase*, size_t> op_deps_; std::unordered_map<OpBase*, size_t> op_deps_;
std::unordered_map<VarBase*, std::unique_ptr<GradientAccumulator>> std::unordered_map<VarBase*, std::unique_ptr<GradientAccumulator>>
accumulators_; accumulators_;
std::vector<std::pair<VarBase*, std::shared_ptr<VarBase>>>
need_accu_var_list_;
}; };
} // namespace imperative } // namespace imperative
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册