未验证 提交 d9aa85de 编写于 作者: L LiYuRio 提交者: GitHub

add fthenb pass (#54409)

上级 856c54a8
...@@ -19,7 +19,9 @@ ...@@ -19,7 +19,9 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
const std::vector<Job*>& Plan::GetJobList() const { return job_list_; } const std::vector<std::shared_ptr<Job>>& Plan::GetJobList() const {
return job_list_;
}
const std::unordered_map<std::string, ProgramDesc*>& Plan::GetTypeToProgram() const std::unordered_map<std::string, ProgramDesc*>& Plan::GetTypeToProgram()
const { const {
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -26,18 +27,18 @@ class Job; ...@@ -26,18 +27,18 @@ class Job;
class Plan final { class Plan final {
public: public:
Plan(const std::vector<Job*>& job_list, Plan(const std::vector<std::shared_ptr<Job>>& job_list,
const std::unordered_map<std::string, ProgramDesc*>& type_to_program) const std::unordered_map<std::string, ProgramDesc*>& type_to_program)
: job_list_(job_list), type_to_program_(type_to_program) {} : job_list_(job_list), type_to_program_(type_to_program) {}
~Plan() = default; ~Plan() = default;
const std::vector<Job*>& GetJobList() const; const std::vector<std::shared_ptr<Job>>& GetJobList() const;
const std::unordered_map<std::string, ProgramDesc*>& GetTypeToProgram() const; const std::unordered_map<std::string, ProgramDesc*>& GetTypeToProgram() const;
private: private:
DISABLE_COPY_AND_ASSIGN(Plan); DISABLE_COPY_AND_ASSIGN(Plan);
std::vector<Job*> job_list_; std::vector<std::shared_ptr<Job>> job_list_;
std::unordered_map<std::string, ProgramDesc*> type_to_program_; std::unordered_map<std::string, ProgramDesc*> type_to_program_;
}; };
......
...@@ -1857,14 +1857,14 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1857,14 +1857,14 @@ All parameter, weight, gradient are variables in Paddle.
return py::cast(std::move(ret)); return py::cast(std::move(ret));
}); });
py::class_<framework::Job>(m, "job") py::class_<framework::Job, std::shared_ptr<framework::Job>>(m, "job")
.def(py::init<const std::string &>(), py::arg("type")) .def(py::init<const std::string &>(), py::arg("type"))
.def("type", &framework::Job::GetJobType) .def("type", &framework::Job::GetJobType)
.def("micro_batch_id", &framework::Job::GetMicroBatchId) .def("micro_batch_id", &framework::Job::GetMicroBatchId)
.def("set_micro_batch_id", &framework::Job::SetMicroBatchId); .def("set_micro_batch_id", &framework::Job::SetMicroBatchId);
py::class_<framework::Plan>(m, "plan") py::class_<framework::Plan>(m, "plan")
.def(py::init<const std::vector<Job *> &, .def(py::init<const std::vector<std::shared_ptr<framework::Job>> &,
const std::unordered_map<std::string, const std::unordered_map<std::string,
framework::ProgramDesc *> &>(), framework::ProgramDesc *> &>(),
py::arg("job_list"), py::arg("job_list"),
......
...@@ -24,6 +24,7 @@ from .auto_parallel_data_parallel_optimization import * # noqa: F403 ...@@ -24,6 +24,7 @@ from .auto_parallel_data_parallel_optimization import * # noqa: F403
from .auto_parallel_grad_clip import * # noqa: F403 from .auto_parallel_grad_clip import * # noqa: F403
from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403 from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403
from .auto_parallel_pipeline import * # noqa: F403 from .auto_parallel_pipeline import * # noqa: F403
from .pipeline_scheduler_pass import * # noqa: F403
from .cpp_pass import * # noqa: F403 from .cpp_pass import * # noqa: F403
from .ps_trainer_pass import * # noqa: F403 from .ps_trainer_pass import * # noqa: F403
from .ps_server_pass import * # noqa: F403 from .ps_server_pass import * # noqa: F403
......
...@@ -24,11 +24,14 @@ from paddle.distributed.auto_parallel.static.utils import ( ...@@ -24,11 +24,14 @@ from paddle.distributed.auto_parallel.static.utils import (
is_optimize_op, is_optimize_op,
) )
from paddle.distributed.fleet.fleet_executor_utils import TaskNode from paddle.distributed.fleet.fleet_executor_utils import TaskNode
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.framework import Parameter, Program from paddle.fluid.framework import Program
from .pass_base import PassBase, register_pass from .pass_base import PassBase, register_pass
from .pipeline_scheduler_pass import (
_create_program,
_insert_sync_for_fthenb_1f1b,
)
__not_shape_var_type__ = [ __not_shape_var_type__ = [
core.VarDesc.VarType.READER, core.VarDesc.VarType.READER,
...@@ -75,7 +78,7 @@ class PipelinePass(PassBase): ...@@ -75,7 +78,7 @@ class PipelinePass(PassBase):
self._cur_pp_stage = self._get_pp_stage(self._cur_rank) self._cur_pp_stage = self._get_pp_stage(self._cur_rank)
if self._mode == "1F1B": if self._mode == "1F1B":
self._insert_sync_ops_for_1f1b() _insert_sync_for_fthenb_1f1b(self._program)
self._task_1f1b() self._task_1f1b()
elif self._mode == "F-Then-B": elif self._mode == "F-Then-B":
raise NotImplementedError("F-Then-B has not been implemented") raise NotImplementedError("F-Then-B has not been implemented")
...@@ -123,177 +126,6 @@ class PipelinePass(PassBase): ...@@ -123,177 +126,6 @@ class PipelinePass(PassBase):
block._sync_with_cpp() block._sync_with_cpp()
def _insert_sync_ops_for_1f1b(self):
"""
This implementation refers to lots of Paddle/python/paddle/fluid/optimizer.py.
The difference between this function with 'PipelineOptimizer' is that
'send_v2' op and 'recv_v2' op have been inserted in program by 'reshard'.
"""
for block in self._program.blocks:
offset = 0
first_optimize_index = None
for index, op in enumerate(list(block.ops)):
if is_optimize_op(op):
first_optimize_index = index
break
# insert sync ops
for index, op in enumerate(list(block.ops)):
# NOTE: pipeline might hang when dynamic_shape is True
if op.type in ['send_v2', 'recv_v2']:
op._set_attr("dynamic_shape", False)
# set send op on comm stream
if op.type == 'send_v2':
# step1: set 'use_calc_stream' False
op._set_attr("use_calc_stream", False)
op_role = op.attr('op_role')
ring_id = op.attr('ring_id')
# step2: insert 'c_sync_calc_stream' op before 'send_v2' op
var_name = op.input_arg_names[0]
var = block.var(var_name)
block._insert_op_without_sync(
index=index + offset,
type="c_sync_calc_stream",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={'op_role': op_role},
)
offset += 1
# step3: insert 'c_sync_comm_stream' op after 'send_v2' op or
# before the first optimize op
if int(op_role) == int(OpRole.Backward):
index = first_optimize_index + offset
new_op_role = OpRole.Optimize
else:
index = index + offset + 1
new_op_role = OpRole.Backward
sync_comm_op = block._insert_op_without_sync(
index=index,
type="c_sync_comm_stream",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
'op_role': new_op_role,
'ring_id': ring_id,
},
)
# step4: If 'send_v2' op in forward parse, set 'pipeline_flag' to distinguish
# whether the 'c_sync_comm_stream' op is inserted for pipeline.
if int(op_role) == int(OpRole.Forward):
sync_comm_op._set_attr('pipeline_flag', '')
offset += 1
block._sync_with_cpp()
offset = 0
backward_recv_index = None
for index, op in enumerate(block.ops):
if op.type == "recv_v2" and is_backward_op(op):
backward_recv_index = index
break
if backward_recv_index is None:
continue
# replace 'c_sync_comm_stream' op with 'nop' op
# use nop op for gc
for index, op in enumerate(list(block.ops)):
if index >= backward_recv_index:
break
if op.type == 'c_sync_comm_stream' and op.has_attr(
'pipeline_flag'
):
var_name = op.output_arg_names[0]
var = block.var(var_name)
block._remove_op(index + offset, sync=False)
offset -= 1
block._insert_op_without_sync(
index=backward_recv_index,
type="nop",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={'op_role': OpRole.Backward},
)
block._sync_with_cpp()
def _create_param(self, dst_block, src_var):
copied_kwargs = {}
copied_kwargs['trainable'] = src_var.trainable
copied_kwargs['optimize_attr'] = src_var.optimize_attr
copied_kwargs['regularizer'] = src_var.regularizer
copied_kwargs['do_model_average'] = src_var.do_model_average
copied_kwargs['need_clip'] = src_var.need_clip
Parameter(
block=dst_block,
type=src_var.type,
name=src_var.name,
shape=src_var.shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
**copied_kwargs
)
def _create_inter(self, dst_block, src_var):
dst_block.create_var(
type=src_var.type,
name=src_var.name,
shape=src_var.shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
persistable=src_var.persistable,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
)
def _create_var(
self, src_block, dst_block, src_varname, force_create=False
):
if not force_create:
src_var = src_block.var(src_varname)
else:
src_var = src_block._var_recursive(src_varname)
if src_var.type in __not_shape_var_type__:
persist = getattr(src_var, 'persistable', False)
dst_block.create_var(
type=src_var.type,
name=src_var.name,
persistable=persist,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
)
else:
if isinstance(src_var, Parameter):
self._create_param(dst_block, src_var)
else:
self._create_inter(dst_block, src_var)
def _create_program(self, src_block, dst_block, src_op, force_create=False):
dst_op_desc = dst_block.desc.append_op()
dst_op_desc.copy_from(src_op.desc)
for input_varname in src_op.input_arg_names:
if src_block.has_var(input_varname) or (
force_create and src_block._find_var_recursive(input_varname)
):
self._create_var(
src_block, dst_block, input_varname, force_create
)
for output_varname in src_op.output_arg_names:
if src_block.has_var(output_varname) or (
force_create and src_block._find_var_recursive(output_varname)
):
self._create_var(
src_block, dst_block, output_varname, force_create
)
def _get_pp_stage(self, rank): def _get_pp_stage(self, rank):
pp_idx = None pp_idx = None
for idx, process_mesh in enumerate(self._dist_context.process_meshes): for idx, process_mesh in enumerate(self._dist_context.process_meshes):
...@@ -337,13 +169,13 @@ class PipelinePass(PassBase): ...@@ -337,13 +169,13 @@ class PipelinePass(PassBase):
# split the program based on the op_role # split the program based on the op_role
for op in src_block.ops: for op in src_block.ops:
if is_lr_sched_op(op): if is_lr_sched_op(op):
self._create_program(src_block, lr_block, op) _create_program(src_block, lr_block, op)
if is_forward_op(op): if is_forward_op(op):
self._create_program(src_block, fwd_block, op) _create_program(src_block, fwd_block, op)
elif is_backward_op(op): elif is_backward_op(op):
self._create_program(src_block, bwd_block, op) _create_program(src_block, bwd_block, op)
elif is_optimize_op(op): elif is_optimize_op(op):
self._create_program(src_block, opt_block, op) _create_program(src_block, opt_block, op)
else: else:
raise ValueError( raise ValueError(
"The op role: " "The op role: "
...@@ -505,11 +337,11 @@ class PipelinePass(PassBase): ...@@ -505,11 +337,11 @@ class PipelinePass(PassBase):
continue continue
if not is_after_while_op: if not is_after_while_op:
self._create_program( _create_program(
src_block, strat_block, op, force_create=True src_block, strat_block, op, force_create=True
) )
else: else:
self._create_program( _create_program(
src_block, end_block, op, force_create=True src_block, end_block, op, force_create=True
) )
elif ib == 1: elif ib == 1:
...@@ -572,7 +404,7 @@ class PipelinePass(PassBase): ...@@ -572,7 +404,7 @@ class PipelinePass(PassBase):
if op.type == "send_v2": if op.type == "send_v2":
remove_process_group(op.attr("ring_id")) remove_process_group(op.attr("ring_id"))
continue continue
self._create_program( _create_program(
src_block, send_block, op, force_create=True src_block, send_block, op, force_create=True
) )
continue continue
...@@ -615,7 +447,7 @@ class PipelinePass(PassBase): ...@@ -615,7 +447,7 @@ class PipelinePass(PassBase):
op.desc._rename_input( op.desc._rename_input(
in_name, recv_vars_name[in_name] in_name, recv_vars_name[in_name]
) )
self._create_program( _create_program(
src_block, recv_block, op, force_create=True src_block, recv_block, op, force_create=True
) )
continue continue
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.auto_parallel.static.utils import (
is_backward_op,
is_forward_op,
is_lr_sched_op,
is_optimize_op,
)
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.fluid import core
from paddle.fluid.framework import Parameter, Program
from .pass_base import PassBase, register_pass
__not_shape_var_type__ = [
core.VarDesc.VarType.READER,
core.VarDesc.VarType.STEP_SCOPES,
core.VarDesc.VarType.LOD_TENSOR_ARRAY,
core.VarDesc.VarType.FEED_MINIBATCH,
core.VarDesc.VarType.FETCH_LIST,
]
def _create_param(dst_block, src_var):
copied_kwargs = {}
copied_kwargs['trainable'] = src_var.trainable
copied_kwargs['optimize_attr'] = src_var.optimize_attr
copied_kwargs['regularizer'] = src_var.regularizer
copied_kwargs['do_model_average'] = src_var.do_model_average
copied_kwargs['need_clip'] = src_var.need_clip
Parameter(
block=dst_block,
type=src_var.type,
name=src_var.name,
shape=src_var.shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
**copied_kwargs
)
def _create_inter(dst_block, src_var):
dst_block.create_var(
type=src_var.type,
name=src_var.name,
shape=src_var.shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
persistable=src_var.persistable,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
)
def _create_var(src_block, dst_block, src_varname, force_create=False):
if not force_create:
src_var = src_block.var(src_varname)
else:
src_var = src_block._var_recursive(src_varname)
if src_var.type in __not_shape_var_type__:
persist = getattr(src_var, 'persistable', False)
dst_block.create_var(
type=src_var.type,
name=src_var.name,
persistable=persist,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
)
else:
if isinstance(src_var, Parameter):
_create_param(dst_block, src_var)
else:
_create_inter(dst_block, src_var)
def _create_program(src_block, dst_block, src_op, force_create=False):
dst_op_desc = dst_block.desc.append_op()
dst_op_desc.copy_from(src_op.desc)
for input_varname in src_op.input_arg_names:
if src_block.has_var(input_varname) or (
force_create and src_block._find_var_recursive(input_varname)
):
_create_var(src_block, dst_block, input_varname, force_create)
for output_varname in src_op.output_arg_names:
if src_block.has_var(output_varname) or (
force_create and src_block._find_var_recursive(output_varname)
):
_create_var(src_block, dst_block, output_varname, force_create)
def _insert_sync_for_fthenb_1f1b(program):
"""
This implementation refers to lots of Paddle/python/paddle/fluid/optimizer.py.
The difference between this function with 'PipelineOptimizer' is that
'send_v2' op and 'recv_v2' op have been inserted in program by 'reshard'.
"""
for block in program.blocks:
offset = 0
first_optimize_index = None
for index, op in enumerate(list(block.ops)):
if is_optimize_op(op):
first_optimize_index = index
break
# insert sync ops
for index, op in enumerate(list(block.ops)):
# NOTE: pipeline might hang when dynamic_shape is True
if op.type in ['send_v2', 'recv_v2']:
op._set_attr("dynamic_shape", False)
# set send op on comm stream
if op.type == 'send_v2':
# step1: set 'use_calc_stream' False
op._set_attr("use_calc_stream", False)
op_role = op.attr('op_role')
ring_id = op.attr('ring_id')
# step2: insert 'c_sync_calc_stream' op before 'send_v2' op
var_name = op.input_arg_names[0]
var = block.var(var_name)
block._insert_op_without_sync(
index=index + offset,
type="c_sync_calc_stream",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={'op_role': op_role},
)
offset += 1
# step3: insert 'c_sync_comm_stream' op after 'send_v2' op or
# before the first optimize op
if int(op_role) == int(OpRole.Backward):
index = first_optimize_index + offset
new_op_role = OpRole.Optimize
else:
index = index + offset + 1
new_op_role = OpRole.Backward
sync_comm_op = block._insert_op_without_sync(
index=index,
type="c_sync_comm_stream",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
'op_role': new_op_role,
'ring_id': ring_id,
},
)
# step4: If 'send_v2' op in forward parse, set 'pipeline_flag' to distinguish
# whether the 'c_sync_comm_stream' op is inserted for pipeline.
if int(op_role) == int(OpRole.Forward):
sync_comm_op._set_attr('pipeline_flag', '')
offset += 1
block._sync_with_cpp()
offset = 0
backward_recv_index = None
for index, op in enumerate(block.ops):
if op.type == "recv_v2" and is_backward_op(op):
backward_recv_index = index
break
if backward_recv_index is None:
continue
# replace 'c_sync_comm_stream' op with 'nop' op
# use nop op for gc
for index, op in enumerate(list(block.ops)):
if index >= backward_recv_index:
break
if op.type == 'c_sync_comm_stream' and op.has_attr('pipeline_flag'):
var_name = op.output_arg_names[0]
var = block.var(var_name)
block._remove_op(index + offset, sync=False)
offset -= 1
block._insert_op_without_sync(
index=backward_recv_index,
type="nop",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={'op_role': OpRole.Backward},
)
block._sync_with_cpp()
def _program_for_fthenb_and_1f1b(program):
lr_prog = Program()
fwd_prog = Program()
bwd_prog = Program()
opt_prog = Program()
for idx, src_block in enumerate(program.blocks):
if idx == 0:
lr_block = lr_prog.block(0)
fwd_block = fwd_prog.block(0)
bwd_block = bwd_prog.block(0)
opt_block = opt_prog.block(0)
else:
lr_block = lr_prog._create_block(parent_idx=src_block.parent_idx)
fwd_block = fwd_prog._create_block(parent_idx=src_block.parent_idx)
bwd_block = bwd_prog._create_block(parent_idx=src_block.parent_idx)
opt_block = opt_prog._create_block(parent_idx=src_block.parent_idx)
lr_block._set_forward_block_idx(src_block.forward_block_idx)
fwd_block._set_forward_block_idx(src_block.forward_block_idx)
bwd_block._set_forward_block_idx(src_block.forward_block_idx)
opt_block._set_forward_block_idx(src_block.forward_block_idx)
# split the program based on the op_role
for op in src_block.ops:
if is_lr_sched_op(op):
_create_program(src_block, lr_block, op)
if is_forward_op(op):
_create_program(src_block, fwd_block, op)
elif is_backward_op(op):
_create_program(src_block, bwd_block, op)
elif is_optimize_op(op):
_create_program(src_block, opt_block, op)
else:
raise ValueError(
"The op role: "
+ str(op.attr('op_role'))
+ " isn't one of LRSched, Forward, Backward or Optimizer."
)
lr_prog._sync_with_cpp()
fwd_prog._sync_with_cpp()
bwd_prog._sync_with_cpp()
opt_prog._sync_with_cpp()
lr_prog._rollback()
fwd_prog._rollback()
bwd_prog._rollback()
opt_prog._rollback()
return {
"lr": lr_prog.desc,
"forward": fwd_prog.desc,
"backward": bwd_prog.desc,
"optimizer": opt_prog.desc,
}
@register_pass("pipeline_fthenb_scheduler")
class PipelineFThenBPass(PassBase):
def __init__(self):
super().__init__()
def _check_self(self):
return True
def _check_conflict(self, other_pass):
return True
def _create_job_list(self):
job_list = []
lr_job = core.job("lr")
job_list.append(lr_job)
for i in range(self._micro_batch_size):
forward_job = core.job("forward")
forward_job.set_micro_batch_id(i)
job_list.append(forward_job)
for i in range(self._micro_batch_size):
backward_job = core.job("backward")
backward_job.set_micro_batch_id(i)
job_list.append(backward_job)
opt_job = core.job("optimizer")
job_list = []
lr_job = core.job("lr")
job_list.append(lr_job)
for i in range(self._micro_batch_size):
forward_job = core.job("forward")
forward_job.set_micro_batch_id(i)
job_list.append(forward_job)
for i in range(self._micro_batch_size):
backward_job = core.job("backward")
backward_job.set_micro_batch_id(i)
job_list.append(backward_job)
opt_job = core.job("optimizer")
job_list.append(opt_job)
return job_list
def _apply_single_impl(self, main_program, startup_program, context):
self._micro_batch_size = self.get_attr("micro_batch_size")
self._program = main_program
_insert_sync_for_fthenb_1f1b(self._program)
type_to_program = _program_for_fthenb_and_1f1b(self._program)
job_list = self._create_job_list()
plan = core.plan(job_list, type_to_program)
context.set_attr("plan", plan)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from paddle import static
from paddle.distributed.passes import PassContext, new_pass
class TestStandaloneExecutorFThenBPlan(unittest.TestCase):
def test_standalone_executor_fthenb_plan(self):
config = {}
config["micro_batch_size"] = 4
pass_context = PassContext()
startup_program = static.Program()
main_program = static.Program()
pipeline_fthenb_pass = new_pass("pipeline_fthenb_scheduler", config)
pipeline_fthenb_pass.apply(
[main_program], [startup_program], pass_context
)
plan = pass_context.get_attr("plan")
job_type_list = []
for job in plan.job_list():
job_type_list.append(job.type())
expect_job_type_list = [
"lr",
"forward",
"forward",
"forward",
"forward",
"backward",
"backward",
"backward",
"backward",
"optimizer",
]
self.assertEqual(job_type_list, expect_job_type_list)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册