Unverified · Commit fcffd84d, authored by zhaoyingli, committed by GitHub

add gc for multi jobs (#54897)

* add gc for multi jobs

* fix job.h

* update OpInfo to OpInOutInfo

* update get_skip_gc_vars algo order
Parent: bd67209f
@@ -14,6 +14,7 @@
 #pragma once
 #include <glog/logging.h>
+#include <set>
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/errors.h"
@@ -36,6 +37,8 @@ class Job final {
   int64_t MicroBatchId() const { return micro_batch_id_; }

+  std::set<std::string> SkipGcVars() const { return skip_gc_vars_; }
+
   std::vector<int> AllFetchOpIds() const {
     std::vector<int> fetch_op_ids;
     fetch_op_ids.reserve(fetch_op_id_to_col_attr_.size());
@@ -58,10 +61,21 @@ class Job final {
     micro_batch_id_ = micro_batch_id;
   }

+  void SetSkipGcVars(const std::set<std::string>& skip_gc_vars) {
+    PADDLE_ENFORCE_EQ(skip_gc_vars_.empty(),
+                      true,
+                      phi::errors::InvalidArgument(
+                          "skip_gc_vars_ can only be initialized once, now "
+                          "skip_gc_vars_ is not empty, "
+                          "do not call SetSkipGcVars method repeatedly."));
+    skip_gc_vars_ = skip_gc_vars;
+  }
+
  private:
   const std::string type_;
   int64_t micro_batch_id_;
   std::unordered_map<int, int> fetch_op_id_to_col_attr_;
+  std::set<std::string> skip_gc_vars_;
 };

 }  // namespace interpreter
......
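The new setter is deliberately write-once: calling SetSkipGcVars on a job whose skip list is already populated raises InvalidArgument. A minimal sketch of the intended usage from the Python side, assuming a Paddle build that exposes the bindings registered below (the variable names are hypothetical):

    from paddle.fluid.core import Job

    job = Job("forward")
    job.set_micro_batch_id(0)
    job.set_skip_gc_vars({"embedding_out", "linear_1.tmp_0"})  # first call: ok
    # A second call would raise InvalidArgument, since skip_gc_vars_
    # can only be initialized once:
    # job.set_skip_gc_vars({"another_var"})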
@@ -59,12 +59,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place,
     interpreter::ExecutionConfig execution_config;
     execution_config.create_local_scope = false;
-    // TODO(Ruibiao): hack skip gc all vars for multiple jobs, improve it later
-    if (jobs.size() > 1) {
-      for (VarDesc* var : program->Block(0).AllVars()) {
-        execution_config.skip_gc_vars.insert(var->Name());
-      }
-    }
+    execution_config.skip_gc_vars = job->SkipGcVars();

     if (FLAGS_enable_new_ir_in_executor) {
       VLOG(6) << "begin to translate" << std::endl;
......
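Note: before this change, the executor worked around the missing per-job GC information by skipping GC for every variable in block 0 whenever the plan contained more than one job (the TODO hack removed above). With per-job skip sets, each sub_program's interpreter now only keeps alive the variables that later sub_programs still need.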
@@ -1874,7 +1874,8 @@ All parameter, weight, gradient are variables in Paddle.
       .def("type", &framework::interpreter::Job::Type)
       .def("set_col_attr_for_fetch_op",
           &framework::interpreter::Job::SetColAttrForFetchOp)
-      .def("set_micro_batch_id", &framework::interpreter::Job::SetMicroBatchId);
+      .def("set_micro_batch_id", &framework::interpreter::Job::SetMicroBatchId)
+      .def("set_skip_gc_vars", &framework::interpreter::Job::SetSkipGcVars);

   py::class_<framework::interpreter::Plan>(m, "Plan")
       .def(
......
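Passing a Python set of names through set_skip_gc_vars relies on pybind11's STL type casters to build the std::set<std::string> parameter that Job::SetSkipGcVars expects, so the Python callers below can pass the sets returned by get_skip_gc_vars directly.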
@@ -13,6 +13,10 @@
 # limitations under the License.

 from collections import OrderedDict
+from typing import List
+
+from paddle.fluid import core
+from paddle.fluid.framework import Program


 def list_to_ordered_dict(list_obj, ordered_dict=None):
@@ -133,3 +137,109 @@ def split_program(program, op_indices):
             break
     valid_output_vars = [list(item.keys()) for item in valid_output_vars]
     return splitted_programs, input_vars, valid_output_vars
+
+
+class OpInOutInfo:
+    """
+    Record the no-need-buffer input slots of an op, together with the names
+    of all other input/output arguments (those whose buffers are needed).
+    """
+
+    def __init__(self):
+        self._is_build = False
+        self._no_need_buffer_slots = set()
+        self._other_arg_names_set = set()
+
+    @property
+    def is_build(self):
+        return self._is_build
+
+    def _get_op_attrs(self, op):
+        inputs = {}
+        for input_name in op.input_names:
+            inputs[input_name] = op.input(input_name)
+        outputs = {}
+        for output_name in op.output_names:
+            outputs[output_name] = op.output(output_name)
+        attrs = {}
+        for attr_name in op.attr_names:
+            attrs[attr_name] = op.attr(attr_name)
+        return inputs, outputs, attrs
+
+    def build_info(self, op):
+        inputs, outputs, attrs = self._get_op_attrs(op)
+        self._no_need_buffer_slots = core.infer_no_need_buffer_slots(
+            op.type, inputs, outputs, attrs
+        )
+        if len(self._no_need_buffer_slots) == 0:
+            return
+
+        # Collect every argument name that is NOT in a no-need-buffer slot.
+        for slot_name in op.input_names:
+            if slot_name in self._no_need_buffer_slots:
+                continue
+            for in_name in op.input(slot_name):
+                self._other_arg_names_set.add(in_name)
+
+        for slot_name in op.output_names:
+            for out_name in op.output(slot_name):
+                self._other_arg_names_set.add(out_name)
+
+        self._is_build = True
+
+    def is_needed(self, arg_name):
+        return (
+            len(self._no_need_buffer_slots) == 0
+            or arg_name in self._other_arg_names_set
+        )
+
+
+def var_can_be_deleted(var_name, program):
+    var = program.global_block()._find_var_recursive(var_name)
+    if var is None or var.persistable:
+        return False
+
+    return var.type in [
+        core.VarDesc.VarType.LOD_TENSOR,
+        core.VarDesc.VarType.SELECTED_ROWS,
+        core.VarDesc.VarType.LOD_TENSOR_ARRAY,
+    ]
+
+
+def get_skip_gc_vars(program_list: List[Program]):
+    """
+    Get `skip_gc_vars` for every sub_program of program_list.
+
+    The whole program is split into sub_programs according to the schedule
+    mode, so a variable belonging to one sub_program may still be used as an
+    op input of a later sub_program; such variables must not be garbage
+    collected after the current sub_program finishes executing.
+    """
+
+    # step1: for every sub_program, collect the vars that are non-persistable
+    # and are not no-need-buffer inputs of their ops.
+    vars_list = [set() for _ in range(len(program_list))]
+    for ip, program in enumerate(program_list):
+        for op in program.global_block().ops:
+            op_info = OpInOutInfo()
+            for in_name in op.input_arg_names:
+                if not var_can_be_deleted(in_name, program):
+                    continue
+
+                if not op_info.is_build:
+                    op_info.build_info(op)
+
+                if op_info.is_needed(in_name):
+                    vars_list[ip].add(in_name)
+
+            for out_name in op.output_arg_names:
+                if var_can_be_deleted(out_name, program):
+                    vars_list[ip].add(out_name)
+
+    # step2: traverse the sub_programs in reverse order; a var of the current
+    # sub_program is skipped from gc if any later sub_program also uses it.
+    union_set = set()
+    skip_gc_vars = [set()] * len(program_list)
+    for idx, vars_set in reversed(list(enumerate(vars_list))):
+        if idx < len(vars_list) - 1:
+            union_set = union_set.union(vars_list[idx + 1])
+        skip_gc_vars[idx] = vars_set & union_set
+
+    return skip_gc_vars
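To make the reverse pass in step2 concrete, here is a toy run with plain Python sets (the variable names are hypothetical and no Paddle install is required):

    # vars_list as step1 might produce it for three sub_programs
    vars_list = [
        {"x", "h1"},      # sub_program 0: uses x, produces h1
        {"h1", "h2"},     # sub_program 1: consumes h1, produces h2
        {"h2", "loss"},   # sub_program 2: consumes h2, produces loss
    ]

    union_set = set()
    skip_gc_vars = [set()] * len(vars_list)
    for idx, vars_set in reversed(list(enumerate(vars_list))):
        if idx < len(vars_list) - 1:
            union_set = union_set.union(vars_list[idx + 1])
        skip_gc_vars[idx] = vars_set & union_set

    # h1 must outlive sub_program 0, h2 must outlive sub_program 1,
    # and nothing needs to survive past the last sub_program.
    assert skip_gc_vars == [{"h1"}, {"h2"}, set()]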
@@ -23,6 +23,7 @@ from paddle.fluid import core
 from paddle.fluid.framework import Parameter, Program

 from .pass_base import PassBase, PassContext, new_pass, register_pass
+from .pass_utils import get_skip_gc_vars

 __not_shape_var_type__ = [
     core.VarDesc.VarType.READER,
@@ -249,11 +250,20 @@ def _program_for_fthenb_and_1f1b(program):
     bwd_prog._rollback()
     opt_prog._rollback()

+    lr_vars, fwd_vars, bwd_vars, opt_vars = get_skip_gc_vars(
+        [lr_prog, fwd_prog, bwd_prog, opt_prog]
+    )
+
     return {
         "lr": lr_prog.desc,
         "forward": fwd_prog.desc,
         "backward": bwd_prog.desc,
         "optimizer": opt_prog.desc,
-    }
+    }, {
+        "lr": lr_vars,
+        "forward": fwd_vars,
+        "backward": bwd_vars,
+        "optimizer": opt_vars,
+    }
@@ -268,22 +278,26 @@ class PipelineFThenBPass(PassBase):
     def _check_conflict(self, other_pass):
         return True

-    def _create_job_list(self):
+    def _create_job_list(self, type_to_skip_vars):
         job_list = []
         lr_job = core.Job("lr")
+        lr_job.set_skip_gc_vars(type_to_skip_vars["lr"])
         job_list.append(lr_job)

         for i in range(self._num_micro_batches):
             forward_job = core.Job("forward")
             forward_job.set_micro_batch_id(i)
+            forward_job.set_skip_gc_vars(type_to_skip_vars["forward"])
             job_list.append(forward_job)

         for i in range(self._num_micro_batches):
             backward_job = core.Job("backward")
             backward_job.set_micro_batch_id(i)
+            backward_job.set_skip_gc_vars(type_to_skip_vars["backward"])
             job_list.append(backward_job)

         opt_job = core.Job("optimizer")
+        opt_job.set_skip_gc_vars(type_to_skip_vars["optimizer"])
         job_list.append(opt_job)
         return job_list
@@ -292,8 +306,10 @@ class PipelineFThenBPass(PassBase):
         self._program = main_program
         _insert_sync_for_fthenb_1f1b(self._program)

-        type_to_program = _program_for_fthenb_and_1f1b(self._program)
-        job_list = self._create_job_list()
+        type_to_program, type_to_skip_vars = _program_for_fthenb_and_1f1b(
+            self._program
+        )
+        job_list = self._create_job_list(type_to_skip_vars)

         plan = core.Plan(job_list, type_to_program)
         context.set_attr("plan", plan)
......
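For example, with self._num_micro_batches = 2 the pass emits the job sequence lr, forward(mb=0), forward(mb=1), backward(mb=0), backward(mb=1), optimizer. Every job of a given type shares the skip_gc_vars set computed for that type's sub_program, so forward outputs that the backward sub_program still reads survive until the backward jobs run.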
@@ -19,7 +19,7 @@ import unittest
 import numpy as np

 import paddle
-from paddle.distributed.passes.pass_utils import split_program
+from paddle.distributed.passes.pass_utils import get_skip_gc_vars, split_program
 from paddle.fluid import core
 from paddle.fluid.core import Job, Plan
 from paddle.fluid.executor import _add_feed_fetch_ops, _StandaloneExecutor
@@ -180,11 +180,13 @@ class TestEncorderMulitMicroBatchRun(unittest.TestCase):
         job_list = []
         program_num = len(programs)
+        skip_gc_vars = get_skip_gc_vars(programs)

         for micro_batch_id in range(micro_batch_num):
             for program_id in range(program_num):
                 job = Job(f"P{program_id}")
                 job.set_micro_batch_id(micro_batch_id)
+                job.set_skip_gc_vars(skip_gc_vars[program_id])
                 # Set col_attr info for fetch_op to fetch the correct data after running multiple micro batch
                 if program_id == program_num - 1:
                     fetch_op_id_to_col_attr = {}
......