Unverified commit 26980b7b, authored by zhaoyingli, committed by GitHub

add skip_gc_vars for 1f1b schedule mode (#54938)

* add skip_gc_vars for 1f1b schedule mode

* add pp_degree and pp_stage
Parent 16ff63a5
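Background for the change: with the standalone executor, the pipeline scheduler pass splits the main program into per-type sub-programs (lr / forward / backward / optimizer) and runs each as a core.Job. A variable produced by one job but still read by a later job (for example an activation needed by the matching backward job) must be excluded from garbage collection when the producing job finishes; the set_skip_gc_vars calls added in this commit mark exactly those variables per job type. A rough sketch of the relationship, assuming a hypothetical variable name:

    # Illustrative only: the variable name is made up; the Job API calls mirror
    # the ones used in the diff below.
    forward_job = core.Job("forward")
    forward_job.set_micro_batch_id(0)
    # 'hidden_act@micro0' stands in for an activation that the backward job of
    # the same micro-batch still needs, so the forward job must not free it.
    forward_job.set_skip_gc_vars({"hidden_act@micro0"})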
@@ -25,7 +25,7 @@ from ..random import init_auto_parallel_rng
 from .partitioner import Partitioner
 from .process_group import get_world_process_group
 from .reshard import Resharder
-from .utils import set_grad_var_shape, use_new_executor
+from .utils import get_pp_stage, set_grad_var_shape, use_new_executor


 class Parallelizer:
@@ -400,4 +400,6 @@ class Parallelizer:
             main_program._pipeline_opt["standalone_opt"] = {
                 "schedule_mode": self._strategy.pipeline.schedule_mode,
                 "num_micro_batches": self._strategy.pipeline.accumulate_steps,
+                "pp_degree": len(self._dist_context.process_meshes),
+                "pp_stage": get_pp_stage(self._dist_context, rank),
             }
@@ -2382,6 +2382,15 @@ def use_new_executor():
     ]


+def get_pp_stage(dist_context, rank):
+    pp_idx = None
+    for idx, process_mesh in enumerate(dist_context.process_meshes):
+        if rank in process_mesh.process_ids:
+            pp_idx = idx
+            break
+    return pp_idx
+
+
 def wrap_data_for_completion(
     dist_op, input_names: list, output_names: list, attr_names: list
 ):
...
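For illustration, get_pp_stage returns the index of the first process mesh whose process_ids contain the given rank, i.e. the pipeline stage that rank belongs to, or None if the rank is in no mesh. A minimal self-contained sketch with stub objects (FakeDistContext and the namedtuple ProcessMesh are stand-ins, not Paddle's real classes):

    from collections import namedtuple

    ProcessMesh = namedtuple("ProcessMesh", ["process_ids"])

    class FakeDistContext:
        # Two pipeline stages: ranks 0 and 1 form stage 0, ranks 2 and 3 form stage 1.
        process_meshes = [
            ProcessMesh(process_ids=[0, 1]),
            ProcessMesh(process_ids=[2, 3]),
        ]

    def get_pp_stage(dist_context, rank):
        pp_idx = None
        for idx, process_mesh in enumerate(dist_context.process_meshes):
            if rank in process_mesh.process_ids:
                pp_idx = idx
                break
        return pp_idx

    assert get_pp_stage(FakeDistContext(), rank=2) == 1
    assert get_pp_stage(FakeDistContext(), rank=5) is None  # rank not in any mesh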
@@ -326,9 +326,10 @@ class Pipeline1F1BPass(PassBase):
     def _check_conflict(self, other_pass):
         return True

-    def _create_job_list(self):
+    def _create_job_list(self, type_to_skip_vars):
         job_list = []
         lr_job = core.Job("lr")
+        lr_job.set_skip_gc_vars(type_to_skip_vars["lr"])
         job_list.append(lr_job)

         assert (
@@ -342,6 +343,7 @@ class Pipeline1F1BPass(PassBase):
         for i in range(micro_batch_in_warmup):
             forward_job = core.Job("forward")
             forward_job.set_micro_batch_id(forward_micro_batch_id)
+            forward_job.set_skip_gc_vars(type_to_skip_vars["forward"])
             job_list.append(forward_job)
             forward_micro_batch_id += 1

@@ -349,20 +351,24 @@
         for i in range(micro_batch_in_1f1b):
             backward_job = core.Job("backward")
             backward_job.set_micro_batch_id(backward_micro_batch_id)
+            backward_job.set_skip_gc_vars(type_to_skip_vars["backward"])
             job_list.append(backward_job)
             backward_micro_batch_id += 1
             forward_job = core.Job("forward")
             forward_job.set_micro_batch_id(forward_micro_batch_id)
+            forward_job.set_skip_gc_vars(type_to_skip_vars["forward"])
             job_list.append(forward_job)
             forward_micro_batch_id += 1

         for i in range(micro_batch_in_warmup):
             backward_job = core.Job("backward")
             backward_job.set_micro_batch_id(backward_micro_batch_id)
+            backward_job.set_skip_gc_vars(type_to_skip_vars["backward"])
             job_list.append(backward_job)
             backward_micro_batch_id += 1

         opt_job = core.Job("optimizer")
+        opt_job.set_skip_gc_vars(type_to_skip_vars["optimizer"])
         job_list.append(opt_job)
         return job_list

@@ -373,8 +379,10 @@ class Pipeline1F1BPass(PassBase):
         self._program = main_program
         _insert_sync_for_fthenb_1f1b(self._program)

-        type_to_program = _program_for_fthenb_and_1f1b(self._program)
-        job_list = self._create_job_list()
+        type_to_program, type_to_skip_vars = _program_for_fthenb_and_1f1b(
+            self._program
+        )
+        job_list = self._create_job_list(type_to_skip_vars)

         plan = core.Plan(job_list, type_to_program)
         context.set_attr("plan", plan)
...
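As a rough illustration of the 1F1B schedule this pass emits, the sketch below reproduces only the job ordering, using plain strings and tuples instead of core.Job objects. The warm-up size pp_degree - pp_stage is an assumption about how micro_batch_in_warmup is computed and is not taken from this diff:

    def sketch_job_order(num_micro_batches, pp_stage, pp_degree):
        # Assumed warm-up size, capped at the number of micro-batches.
        warmup = min(pp_degree - pp_stage, num_micro_batches)
        steady = num_micro_batches - warmup
        jobs, fwd, bwd = ["lr"], 0, 0
        for _ in range(warmup):      # warm-up: forwards only
            jobs.append(("forward", fwd))
            fwd += 1
        for _ in range(steady):      # steady 1F1B: one backward, then one forward
            jobs.append(("backward", bwd))
            bwd += 1
            jobs.append(("forward", fwd))
            fwd += 1
        for _ in range(warmup):      # cool-down: remaining backwards
            jobs.append(("backward", bwd))
            bwd += 1
        jobs.append("optimizer")
        return jobs

    # Last of 2 stages, 4 micro-batches:
    # ['lr', ('forward', 0), ('backward', 0), ('forward', 1), ('backward', 1),
    #  ('forward', 2), ('backward', 2), ('forward', 3), ('backward', 3), 'optimizer']
    print(sketch_job_order(num_micro_batches=4, pp_stage=1, pp_degree=2))

With skip_gc_vars attached, each of these jobs also carries the set of variables it must not garbage-collect, so data that crosses a forward/backward boundary survives between jobs.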
@@ -878,10 +878,9 @@ class _ExecutorCache:
             standalone_opt = new_program._pipeline_opt["standalone_opt"]
             pass_name = standalone_opt["schedule_mode"]
-            pass_attr = {
-                "num_micro_batches": standalone_opt["num_micro_batches"]
-            }
-            plan = apply_pass(new_program, new_program, pass_name, pass_attr)
+            plan = apply_pass(
+                new_program, new_program, pass_name, standalone_opt
+            )
         else:
             default_job = core.Job("default")
             type_to_program = {"default": new_program.desc}
...
@@ -68,9 +68,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_auto_tuner MODULES test_auto_tuner)
   set_tests_properties(test_auto_tuner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                   TIMEOUT 100)
-  py_test_modules(test_pipeline_scheduler_FThenB MODULES
-                  test_pipeline_scheduler_FThenB)
-  set_tests_properties(test_pipeline_scheduler_FThenB
+  py_test_modules(test_pipeline_scheduler MODULES test_pipeline_scheduler)
+  set_tests_properties(test_pipeline_scheduler
                        PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
   py_test_modules(test_auto_tuner_compare MODULES test_auto_tuner_compare)
   set_tests_properties(test_auto_tuner_compare
...
@@ -26,14 +26,14 @@ from paddle.distributed.fleet import auto
 paddle.enable_static()


-def apply_pass(use_standalone_exe=False):
+def apply_pass(schedule_mode="FThenB"):
     strategy = auto.Strategy()
     strategy.auto_mode = "semi"
     strategy.reinit = True

     pipeline = strategy.pipeline
     pipeline.enable = True
-    pipeline.schedule_mode = "1F1B" if not use_standalone_exe else "FThenB"
+    pipeline.schedule_mode = schedule_mode
     pipeline.accumulate_steps = 2

     return strategy
@@ -61,10 +61,10 @@ class Test1F1BPass(unittest.TestCase):
         place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
         engine._executor = paddle.static.Executor(place)

-    def get_engine(self, use_standalone_exe=False):
+    def get_engine(self, schedule_mode="FThenB"):
         reset_prog()

-        strategy = apply_pass(use_standalone_exe)
+        strategy = apply_pass(schedule_mode)
         clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
         opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
         model, loss = generate_model("pp")
@@ -87,31 +87,44 @@ class Test1F1BPass(unittest.TestCase):
     def test_pp_pass(self):
         # pp2 1f1b training with fleet executor
         os.environ['FLAGS_new_executor_micro_batching'] = 'False'
-        engine_1f1b = self.get_engine(use_standalone_exe=False)
-        history_1f1b = engine_1f1b.fit(
+        engine_fleet_1f1b = self.get_engine(schedule_mode="1F1B")
+        history_fleet_1f1b = engine_fleet_1f1b.fit(
             self.dataset, 3, batch_size=self.batch_size, log_freq=1
         )
-        assert engine_1f1b._strategy.pipeline.schedule_mode == "1F1B"
+        assert engine_fleet_1f1b._strategy.pipeline.schedule_mode == "1F1B"
         assert os.environ.get('FLAGS_new_executor_micro_batching') == "False"

         # pp2 fthenb training with standalone executor
         os.environ['FLAGS_new_executor_micro_batching'] = 'True'
-        engine_fthenb = self.get_engine(use_standalone_exe=True)
+        engine_fthenb = self.get_engine(schedule_mode="FThenB")
         history_fthenb = engine_fthenb.fit(
             self.dataset, 3, batch_size=self.batch_size, log_freq=1
         )
         assert engine_fthenb._strategy.pipeline.schedule_mode == "FThenB"
         assert os.environ.get('FLAGS_new_executor_micro_batching') == "True"

+        # pp2 1f1b training with standalone executor
+        os.environ['FLAGS_new_executor_micro_batching'] = 'True'
+        engine_1f1b = self.get_engine(schedule_mode="1F1B")
+        history_1f1b = engine_1f1b.fit(
+            self.dataset, 3, batch_size=self.batch_size, log_freq=1
+        )
+        assert engine_1f1b._strategy.pipeline.schedule_mode == "1F1B"
+        assert os.environ.get('FLAGS_new_executor_micro_batching') == "True"
+
         # NOTE: every sample data from dataset is all the same
         if paddle.distributed.get_rank() == 1:
-            losses_1f1b = np.array(history_1f1b.history["loss"])
+            losses_fleet_1f1b = np.array(history_fleet_1f1b.history["loss"])
             losses_fthenb = np.array(history_fthenb.history["loss"])
+            losses_1f1b = np.array(history_1f1b.history["loss"])
             # accumulate_steps is 2
             assert losses_fthenb[0].shape[0] == 2
-            # losses_1f1b is the last loss of accumulate_steps
+            assert losses_1f1b[0].shape[0] == 2
+            # losses_fleet_1f1b is the last loss of accumulate_steps
             # losses_fthenb is all the losses of accumulate_steps
-            self.check_results(losses_1f1b[0], losses_fthenb[0][-1])
+            # losses_1f1b is all the losses of accumulate_steps
+            self.check_results(losses_fleet_1f1b[0], losses_fthenb[0][-1])
+            self.check_results(losses_fleet_1f1b[0], losses_1f1b[0][-1])


 if __name__ == "__main__":
...
@@ -23,7 +23,7 @@ class TestFThenBPass(unittest.TestCase):
     def test_pp2(self):
         file_dir = os.path.dirname(os.path.abspath(__file__))
         launch_model_path = os.path.join(
-            file_dir, "pipeline_scheduler_FThenB.py"
+            file_dir, "pipeline_scheduler_unittest.py"
        )

         if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
...