Unverified commit a702e170, authored by zhaoyingli, committed by GitHub

auto parallel support pipeline scheduler with standalone executor (#54727)

* auto parallel support pipeline scheduler with standalone executor

* rm check_fetch

* update cmakelist and flags env

* rm set micro batch id

* rm import

* update utils func

* raise error when merge tensor for return_numpy is False

* fix _pipeline_opt

* fix unittest
Parent 3e0f0a00
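For orientation, below is a minimal sketch of how this scheduler would be enabled from user code, pieced together from the strategy options and the `FLAGS_new_executor_micro_batching` flag exercised by the new unit test at the end of this diff. The model/loss/optimizer/dataset objects are placeholders and not part of this change.

```python
import os

import paddle
from paddle.distributed.fleet import auto

# The standalone-executor pipeline path is gated by this flag
# (checked by the use_new_executor() helper added in this PR).
os.environ['FLAGS_new_executor_micro_batching'] = 'True'

paddle.enable_static()

strategy = auto.Strategy()
strategy.auto_mode = "semi"
pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "FThenB"  # the only schedule supported by the new pass
pipeline.accumulate_steps = 2      # number of micro batches per step

# model, loss, opt and dataset are placeholders for a real pipeline-parallel
# setup (the unit test below uses the GPT helpers from get_gpt_model):
# engine = auto.Engine(model, loss, opt, strategy=strategy)
# engine.fit(dataset, batch_size=2)
```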
@@ -140,6 +140,8 @@ class FetchV2Kernel {
             "operator 'Fetch') of current fetching variable to be "
             "no less than 0. But received column index = %d.",
             col));
+    VLOG(3) << "Fetch variable " << fetch_var_name << "'s " << col
+            << " column.";
     auto *fetch_list = out_var->GetMutable<framework::FetchList>();
...
@@ -400,8 +400,11 @@ class Engine:
        dist_main_block._sync_with_cpp()
        self._has_prepared_reader[self._mode] = True

-       # Insert read op to forward TaskNode if 1F1B pass is setted
-       if self.main_program._pipeline_opt:
+       # Insert read op to forward TaskNode for fleet executor if 1F1B pass is setted
+       if (
+           self.main_program._pipeline_opt
+           and not auto_utils.use_new_executor()
+       ):
            assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
            fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
            fwd_task = None
@@ -471,8 +474,6 @@ class Engine:
                if var_name not in fetch_names:
                    fetch_names.append(var_name)
                group_indices.append(fetch_names.index(var_name))
-           if not group_indices:
-               fetch_names.append([])
            fetch_indices.append(group_indices)

        dist_context = self._dist_contexts[mode]
...
@@ -25,7 +25,7 @@ from ..random import init_auto_parallel_rng
 from .partitioner import Partitioner
 from .process_group import get_world_process_group
 from .reshard import Resharder
-from .utils import set_grad_var_shape
+from .utils import set_grad_var_shape, use_new_executor


 class Parallelizer:
@@ -38,6 +38,14 @@ class Parallelizer:
        self._strategy = self._dist_context.strategy
        self._logger = get_logger(logging.INFO)

+   @property
+   def is_train(self):
+       return self._mode == "train"
+
+   @property
+   def is_test(self):
+       return self._mode in ["eval", "predict"]
+
    def parallel_all(self):
        world_process_group = get_world_process_group()
        all_ranks = world_process_group.ranks
@@ -50,7 +58,7 @@ class Parallelizer:
        serial_main_program = self._dist_context.serial_main_program
        serial_startup_program = self._dist_context.serial_startup_program
        serial_optimizer = self._dist_context.serial_optimizer
-       if self._mode == "train" and serial_optimizer:
+       if self.is_train and serial_optimizer:
            # Generate backward
            serial_loss = self._dist_context.serial_loss
            params_grads = self._generate_backward(
@@ -191,8 +199,9 @@ class Parallelizer:
                time.time() - time0, self._mode
            )
        )

        # Clone program for test
-       if self._mode != 'train':
+       if self.is_test:
            pipeline_opt = dist_main_prog._pipeline_opt
            dist_main_prog = dist_main_prog.clone(for_test=True)
            dist_startup_prog = dist_startup_prog.clone(for_test=True)
@@ -263,7 +272,7 @@ class Parallelizer:
        # apply quantization pass
        # The pass can be applied when mode must be 'train'
-       if self._mode == 'train' and self._strategy.qat.enable:
+       if self.is_train and self._strategy.qat.enable:
            config = copy.deepcopy(self._strategy.qat.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
@@ -282,7 +291,7 @@ class Parallelizer:
        # apply recompute pass
        # recompute is then train-only optimization
-       if self._mode == "train" and self._strategy.recompute.enable:
+       if self.is_train and self._strategy.recompute.enable:
            config = copy.deepcopy(self._strategy.recompute.to_dict())
            config["dist_context"] = self._dist_context
            config["no_grad_set"] = None
@@ -326,7 +335,7 @@ class Parallelizer:
            )
            params_grads = self._pass_context.get_attr("params_grads")

-       if self._mode == "train":
+       if self.is_train:
            # GradClip is train-only optimization
            config = copy.deepcopy(self._strategy.sharding.to_dict())
            config["dist_context"] = self._dist_context
@@ -349,7 +358,7 @@ class Parallelizer:
                [main_program], [startup_program], self._pass_context
            )

-       if self._strategy.pipeline.enable:
+       if self.is_train and self._strategy.pipeline.enable:
            self._strategy.gradient_merge.enable = True
            self._strategy.gradient_merge.k_steps = (
                self._strategy.pipeline.accumulate_steps
@@ -357,7 +366,7 @@ class Parallelizer:
            self._strategy.gradient_merge.avg = True

        # gradient_merge is then train-only optimization
-       if self._mode == "train" and self._strategy.gradient_merge.enable:
+       if self.is_train and self._strategy.gradient_merge.enable:
            config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
@@ -368,7 +377,7 @@ class Parallelizer:
                [main_program], [startup_program], self._pass_context
            )

-       if self._strategy.pipeline.enable:
+       if self._strategy.pipeline.enable and not use_new_executor():
            config = copy.deepcopy(self._strategy.pipeline.to_dict())
            config["dist_context"] = self._dist_context
            auto_parallel_pipeline_pass = new_pass(
@@ -378,10 +387,17 @@ class Parallelizer:
                [main_program], [startup_program], self._pass_context
            )

-       if self._mode == "train" and self._strategy.fused_passes.enable:
+       if self.is_train and self._strategy.fused_passes.enable:
            if len(self._strategy.fused_passes.fused_passes_list) > 0:
                new_pass_list = []
                for op in self._strategy.fused_passes.fused_passes_list:
                    new_pass_list.append(new_pass(op))
                pass_manager = PassManager(new_pass_list)
                pass_manager.apply([main_program], [startup_program])

+       if self._strategy.pipeline.enable and use_new_executor():
+           main_program._pipeline_opt = {}
+           main_program._pipeline_opt["standalone_opt"] = {
+               "schedule_mode": self._strategy.pipeline.schedule_mode,
+               "num_micro_batches": self._strategy.pipeline.accumulate_steps,
+           }
@@ -2367,3 +2367,16 @@ def _dygraph_guard_(func):


 dygraph_guard = wrap_decorator(_dygraph_guard_)
+
+
+def use_new_executor():
+    new_executor_micro_batching = os.environ.get(
+        'FLAGS_new_executor_micro_batching', None
+    )
+    return new_executor_micro_batching in [
+        1,
+        '1',
+        True,
+        'True',
+        'true',
+    ]
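A quick illustration of how this helper behaves (a sketch; the accepted "on" values are exactly those in the list above, and anything else, including an unset flag, is treated as off):

```python
import os

from paddle.distributed.auto_parallel.static.utils import use_new_executor

os.environ['FLAGS_new_executor_micro_batching'] = 'True'
assert use_new_executor()

# 'False' is not in [1, '1', True, 'True', 'true'], so the flag is off.
os.environ['FLAGS_new_executor_micro_batching'] = 'False'
assert not use_new_executor()
```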
@@ -94,8 +94,13 @@ def _get_gm_cond_var(main_program, k_steps, dist_context):
    )
    set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks)

-   cond_var = main_block.create_var(
-       name="gradient_merge_cond", shape=[1], dtype='bool'
+   cond_var = paddle.static.create_global_var(
+       name="gradient_merge_cond",
+       shape=[1],
+       value=bool(0),
+       dtype='bool',
+       persistable=True,
+       force_cpu=True,
    )
    set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks)
...
@@ -22,7 +22,7 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole
 from paddle.fluid import core
 from paddle.fluid.framework import Parameter, Program

-from .pass_base import PassBase, register_pass
+from .pass_base import PassBase, PassContext, new_pass, register_pass

 __not_shape_var_type__ = [
     core.VarDesc.VarType.READER,
@@ -257,7 +257,7 @@ def _program_for_fthenb_and_1f1b(program):
    }


-@register_pass("pipeline_fthenb_scheduler")
+@register_pass("pipeline_scheduler_FThenB")
 class PipelineFThenBPass(PassBase):
    def __init__(self):
        super().__init__()
@@ -272,12 +272,12 @@ class PipelineFThenBPass(PassBase):
        job_list = []
        lr_job = core.Job("lr")
        job_list.append(lr_job)
-       for i in range(self._micro_batch_size):
+       for i in range(self._num_micro_batches):
            forward_job = core.Job("forward")
            forward_job.set_micro_batch_id(i)
            job_list.append(forward_job)

-       for i in range(self._micro_batch_size):
+       for i in range(self._num_micro_batches):
            backward_job = core.Job("backward")
            backward_job.set_micro_batch_id(i)
            job_list.append(backward_job)
@@ -287,7 +287,7 @@ class PipelineFThenBPass(PassBase):
        return job_list

    def _apply_single_impl(self, main_program, startup_program, context):
-       self._micro_batch_size = self.get_attr("micro_batch_size")
+       self._num_micro_batches = self.get_attr("num_micro_batches")
        self._program = main_program

        _insert_sync_for_fthenb_1f1b(self._program)
@@ -296,3 +296,16 @@ class PipelineFThenBPass(PassBase):
        plan = core.Plan(job_list, type_to_program)
        context.set_attr("plan", plan)
+
+
+def apply_pass(main_program, startup_program, pass_name, pass_attr={}):
+    assert pass_name in [
+        "FThenB"
+    ], "pipeline scheduler only support FThenB, but recieve {}".format(
+        pass_name
+    )
+    pipeline_pass = new_pass("pipeline_scheduler_" + pass_name, pass_attr)
+    pass_context = PassContext()
+    pipeline_pass.apply([main_program], [startup_program], pass_context)
+    plan = pass_context.get_attr("plan")
+    return plan
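The new `apply_pass` helper is what the standalone executor cache calls later in this diff. A small sketch of using it directly, mirroring the updated plan unit test at the bottom of this commit (empty programs are enough to obtain the job list):

```python
import paddle
from paddle.distributed.passes.pipeline_scheduler_pass import apply_pass

paddle.enable_static()

main_program = paddle.static.Program()
startup_program = paddle.static.Program()

# Returns a core.Plan containing one "lr" job followed by per-micro-batch
# forward and backward jobs, as built by PipelineFThenBPass above.
plan = apply_pass(
    main_program, startup_program, "FThenB", {"num_micro_batches": 4}
)
print(plan.micro_batch_num())
```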
@@ -414,6 +414,36 @@ def _add_feed_fetch_ops(
    return tmp_program


+def _set_micro_batch_fetch(plan):
+    if plan.micro_batch_num() <= 1:
+        return
+
+    valid_fetch_types = ["fetch", "fetch_v2"]
+    for job in plan.job_list():
+        idx_to_col_attr = {}
+        prog = plan.program(job.type())
+        for i in range(prog.block(0).op_size()):
+            op = prog.block(0).op(i)
+            if op.type() in valid_fetch_types:
+                idx_to_col_attr[i] = op.attr('col')
+
+        for idx, col in idx_to_col_attr.items():
+            job.set_col_attr_for_fetch_op(
+                idx, col * plan.micro_batch_num() + job.micro_batch_id()
+            )
+
+
+def _merge_tensors(tensor, micro_batch_num):
+    if micro_batch_num <= 1:
+        return tensor
+    assert len(tensor) % micro_batch_num == 0
+    chunk_tensor = [
+        tensor[i : i + micro_batch_num]
+        for i in range(0, len(tensor), micro_batch_num)
+    ]
+    return [np.array(chunk) for chunk in chunk_tensor]


 def _apply_inplace_addto_pass(
     program, enable_inplace, enable_addto, skip_var_names
 ):
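To make the fetch layout concrete: `_set_micro_batch_fetch` remaps each fetch op's column to `col * micro_batch_num + micro_batch_id`, so the flat result list holds all micro-batch copies of one fetch variable contiguously, and `_merge_tensors` then folds each group into a single array. A tiny sketch with made-up values, calling the private helper defined just above only to illustrate the layout:

```python
import numpy as np

micro_batch_num = 2
# Flat fetch results after the column remapping, for two fetch variables:
# [var0_mb0, var0_mb1, var1_mb0, var1_mb1]
flat = [np.array([0.10]), np.array([0.12]), np.array([7.0]), np.array([9.0])]

merged = _merge_tensors(flat, micro_batch_num)
assert len(merged) == 2
# merged[0] stacks both micro-batch values of the first variable: shape (2, 1)
assert merged[0].shape == (2, 1)
```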
@@ -653,8 +683,13 @@ class _StandaloneExecutor:
        """
        tensors = self._new_exe.run(feed_names)._move_to_list()
        if return_numpy:
-           return as_numpy(tensors, copy=True)
+           tensors = as_numpy(tensors, copy=True)
+           return _merge_tensors(tensors, self._plan.micro_batch_num())
        else:
+           if self._plan.micro_batch_num() > 1:
+               raise RuntimeError(
+                   "`merge_tensor` does not support when return_numpy is False."
+               )
            return tensors

    def _create_new_executor(self):
@@ -831,12 +866,30 @@ class _ExecutorCache:
            _apply_inplace_addto_pass(
                program, enable_inplace, enable_addto, skip_var_names
            )

            new_program = program.clone()
-           new_exe = _StandaloneExecutor(
-               place,
-               core.Plan([core.Job("default")], {"default": new_program.desc}),
-               scope,
-           )
+           if (
+               new_program._pipeline_opt
+               and "standalone_opt" in new_program._pipeline_opt
+           ):
+               from paddle.distributed.passes.pipeline_scheduler_pass import (
+                   apply_pass,
+               )
+
+               standalone_opt = new_program._pipeline_opt["standalone_opt"]
+               pass_name = standalone_opt["schedule_mode"]
+               pass_attr = {
+                   "num_micro_batches": standalone_opt["num_micro_batches"]
+               }
+               plan = apply_pass(new_program, new_program, pass_name, pass_attr)
+           else:
+               default_job = core.Job("default")
+               type_to_program = {"default": new_program.desc}
+               plan = core.Plan([default_job], type_to_program)
+
+           _set_micro_batch_fetch(plan)
+
+           new_exe = _StandaloneExecutor(place, plan, scope)
            return new_program, new_exe
@@ -1408,7 +1461,15 @@ class Executor:
        fetch_list = self._check_fetch_list(fetch_list)

-       if isinstance(program, Program) and program._pipeline_opt:
+       from paddle.distributed.auto_parallel.static.utils import (
+           use_new_executor,
+       )
+
+       if (
+           isinstance(program, Program)
+           and program._pipeline_opt
+           and not use_new_executor()
+       ):
            if "fleet_opt" in program._pipeline_opt:
                # Move prepare here for port conflict with nccl in startup program
                if self._fleet_executor is None:
...
@@ -5921,6 +5921,8 @@ class Program:
            p._appending_grad_times = self._appending_grad_times
            if hasattr(self, 'lr_scheduler'):
                p.lr_scheduler = self.lr_scheduler
+           if hasattr(self, '_pipeline_opt'):
+               p._pipeline_opt = self._pipeline_opt

            # NOTE(zhiqiu): we sync the cloned program, to update its program by
            # its desc.
...
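This clone change presumably matters because the program is cloned in places such as the Parallelizer's eval/predict path shown earlier; carrying `_pipeline_opt` across the clone keeps options like the `standalone_opt` dict visible afterwards. A minimal sketch with a hypothetical dict:

```python
import paddle

paddle.enable_static()

prog = paddle.static.Program()
prog._pipeline_opt = {
    "standalone_opt": {"schedule_mode": "FThenB", "num_micro_batches": 2}
}

test_prog = prog.clone(for_test=True)
# With this change, the pipeline options survive the clone.
assert test_prog._pipeline_opt == prog._pipeline_opt
```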
@@ -68,6 +68,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_auto_tuner MODULES test_auto_tuner)
  set_tests_properties(test_auto_tuner PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                  TIMEOUT 100)
+  py_test_modules(test_pipeline_scheduler_FThenB MODULES
+                  test_pipeline_scheduler_FThenB)
+  set_tests_properties(test_pipeline_scheduler_FThenB
+                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
  py_test_modules(test_auto_tuner_compare MODULES test_auto_tuner_compare)
  set_tests_properties(test_auto_tuner_compare
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
from paddle.distributed import ParallelEnv
from paddle.distributed.fleet import auto
paddle.enable_static()
def apply_pass(use_standalone_exe=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "1F1B" if not use_standalone_exe else "FThenB"
pipeline.accumulate_steps = 2
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class Test1F1BPass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
paddle.distributed.fleet.init(is_collective=True)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_standalone_exe=False):
reset_prog()
strategy = apply_pass(use_standalone_exe)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("pp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses
),
)
def test_pp_pass(self):
# pp2 1f1b training with fleet executor
os.environ['FLAGS_new_executor_micro_batching'] = 'False'
engine_1f1b = self.get_engine(use_standalone_exe=False)
history_1f1b = engine_1f1b.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert engine_1f1b._strategy.pipeline.schedule_mode == "1F1B"
assert os.environ.get('FLAGS_new_executor_micro_batching') == "False"
# pp2 fthenb training with standalone executor
os.environ['FLAGS_new_executor_micro_batching'] = 'True'
engine_fthenb = self.get_engine(use_standalone_exe=True)
history_fthenb = engine_fthenb.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert engine_fthenb._strategy.pipeline.schedule_mode == "FThenB"
assert os.environ.get('FLAGS_new_executor_micro_batching') == "True"
# NOTE: every sample data from dataset is all the same
if paddle.distributed.get_rank() == 1:
losses_1f1b = np.array(history_1f1b.history["loss"])
losses_fthenb = np.array(history_fthenb.history["loss"])
# accumulate_steps is 2
assert losses_fthenb[0].shape[0] == 2
# losses_1f1b is the last loss of accumulate_steps
# losses_fthenb is all the losses of accumulate_steps
self.check_results(losses_1f1b[0], losses_fthenb[0][-1])
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
class TestFThenBPass(unittest.TestCase):
def test_pp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(
file_dir, "pipeline_scheduler_FThenB.py"
)
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
@@ -21,13 +21,13 @@ from paddle.distributed.passes import PassContext, new_pass
 class TestStandaloneExecutorFThenBPlan(unittest.TestCase):
    def test_standalone_executor_fthenb_plan(self):
        config = {}
-       config["micro_batch_size"] = 4
+       config["num_micro_batches"] = 4
        pass_context = PassContext()
        startup_program = static.Program()
        main_program = static.Program()
-       pipeline_fthenb_pass = new_pass("pipeline_fthenb_scheduler", config)
+       pipeline_fthenb_pass = new_pass("pipeline_scheduler_FThenB", config)
        pipeline_fthenb_pass.apply(
            [main_program], [startup_program], pass_context
        )
...