Unverified commit 4b3589fb, authored by zhaoyingli, committed by GitHub

2.4/fix engine build (#47462)

* update codestyle

* [AutoParallel] fix fp16 for subblock (#47189)

* [AutoParallel] fix fp16 for subblock

* fix engine

* fix comment

* [AutoParallel] fix engine _build and cost method (#47263)

* fix engine build method

* fix import

* update engine cost

* update raise error

* update cmakelist

* revert optimizer

* revert optimizer

* fix unittest

* fix unittest
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Parent f93e9a58
@@ -27,12 +27,9 @@ from ..dist_tensor import DistributedTensor
 class CostEstimator:
     _sepical_op_type = ["fused_attention", "fused_feedforward"]
 
-    def __init__(self,
-                 program,
-                 cluster,
-                 mode="modeling",
-                 rank=None,
-                 loop_count=10):
+    def __init__(
+        self, program, cluster, mode="modeling", rank=None, loop_count=10
+    ):
         self._program = program
         self._cluster = cluster
         self._check_mode(mode)
@@ -41,7 +38,8 @@ class CostEstimator:
         self._loop_count = loop_count
         self._global_cost = Cost()
         self._local_cost_mapping = {}
-        self._detailed_cost = OrderedDict(
-        )  # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}}
+        self._detailed_cost = (
+            OrderedDict()
+        )  # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}}
         self._bubble_time_mapping = {}
         self._ordered_ops = []
@@ -106,7 +104,8 @@ class CostEstimator:
     def _check_mode(self, mode):
         if mode not in ["modeling", "profiling"]:
             raise ValueError(
-                "Just support modeling and profiling, but got {}".format(mode))
+                "Just support modeling and profiling, but got {}".format(mode)
+            )
 
     def _is_special_var_name(self, var_name):
         special_var_name = ["lod_tensor_blocking_queue_0"]
@@ -116,6 +115,7 @@ class CostEstimator:
 
     def _estimate_core(self, dist_context, resharder, block):
         from ..reshard import get_var_with_recursion
+
         ops = block.ops
         loop_count = None
         if block.desc.id != self.program.global_block().desc.id:
@@ -132,8 +132,9 @@
                 if int(op.attr('op_role')) == int(OpRole.Optimize):
                     continue
                 if op.type in [
-                        "create_py_reader", "create_double_buffer_reader",
-                        "read"
+                    "create_py_reader",
+                    "create_double_buffer_reader",
+                    "read",
                 ]:
                     continue
@@ -172,14 +173,16 @@
                                     max_time = rank_cost.time
                             for rank in group_ranks:
-                                self.local_cost(
-                                    rank).time = max_time + cost.time
+                                self.local_cost(rank).time = (
+                                    max_time + cost.time
+                                )
 
                                 if rank not in self._bubble_time_mapping:
                                     self._bubble_time_mapping[rank] = 0
-                                self._bubble_time_mapping[rank] += (
-                                    max_time - cost_time[rank])
+                                self._bubble_time_mapping[rank] += (
+                                    max_time - cost_time[rank]
+                                )
 
                     for rank in local_comp_cost:
                         for comp_cost in local_comp_cost[rank]:
@@ -191,15 +194,19 @@
                 processes = op_dist_attr.process_mesh.processes
                 container = get_distributed_operator_impl_container(
-                    op_dist_attr.impl_type)
+                    op_dist_attr.impl_type
+                )
                 dist_impl = container.impls[op_dist_attr.impl_idx]
 
-                dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op,
-                                                   dist_context, self.cluster)
+                dist_op_cost = dist_impl.calc_cost(
+                    op.attr('op_role'), dist_op, dist_context, self.cluster
+                )
                 detail["dist_op_cost"] = dist_op_cost
                 if dist_op_cost is None:
-                    assert dist_op.serial_op.type in CostEstimator._sepical_op_type
+                    assert (
+                        dist_op.serial_op.type in CostEstimator._sepical_op_type
+                    )
                     continue
                 for item in dist_op_cost:
                     if isinstance(item, list):
@@ -217,12 +224,14 @@
                                 if max_time < rank_cost.time:
                                     max_time = rank_cost.time
                             for rank in group_ranks:
-                                self.local_cost(
-                                    rank).time = max_time + comm_op_cost.time
+                                self.local_cost(rank).time = (
+                                    max_time + comm_op_cost.time
+                                )
                                 if rank not in self._bubble_time_mapping:
                                     self._bubble_time_mapping[rank] = 0
-                                self._bubble_time_mapping[rank] += (
-                                    max_time - cost_time[rank])
+                                self._bubble_time_mapping[rank] += (
+                                    max_time - cost_time[rank]
+                                )
                     elif isinstance(item, dict):
                         # Op just one
                         for rank in processes:
@@ -247,8 +256,11 @@
             dtype_factor = 8
         elif dtype == paddle.float32 or dtype == paddle.int32:
             dtype_factor = 4
-        elif dtype == paddle.float16 or dtype == paddle.bfloat16 \
-            or dtype == paddle.int16:
+        elif (
+            dtype == paddle.float16
+            or dtype == paddle.bfloat16
+            or dtype == paddle.int16
+        ):
             dtype_factor = 2
         elif dtype == paddle.int8 or dtype == paddle.uint8:
             dtype_factor = 1
@@ -270,8 +282,9 @@
         memories = {}
         self.max_memories = {}
-        var_info = {
-        }  # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]}
+        var_info = (
+            {}
+        )  # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]}
 
         for block in self.program.blocks:
             for op in block.ops:
@@ -280,18 +293,22 @@
         for op_id, op in self._ordered_ops:
             if op.type in [
-                    "create_py_reader", "create_double_buffer_reader", "read"
+                "create_py_reader",
+                "create_double_buffer_reader",
+                "read",
             ]:
                 continue
             dist_op = dist_context.get_dist_op_for_program(op)
             process_mesh = dist_op.dist_attr.process_mesh
             for var_name in op.input_arg_names:
                 input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
-                    var_name)
+                    var_name
+                )
                 if var_name not in var_info:
                     var_info[var_name] = {}
-                key = _convert_pm_and_dm_to_str(process_mesh,
-                                                input_dims_mapping)
+                key = _convert_pm_and_dm_to_str(
+                    process_mesh, input_dims_mapping
+                )
                 if key not in var_info[var_name]:
                     var_info[var_name][key] = {}
                 # It is even partition now
@@ -300,21 +317,27 @@
                     global_sizes = var.shape
                     dtype = var.dtype
                     sizes = DistributedTensor.get_local_sizes(
-                        global_sizes, input_dims_mapping, process_mesh.topology,
-                        process_mesh.processes)
+                        global_sizes,
+                        input_dims_mapping,
+                        process_mesh.topology,
+                        process_mesh.processes,
+                    )
                     var_info[var_name][key]["memory"] = self._calculate_bytes(
-                        sizes, dtype)
+                        sizes, dtype
+                    )
                 if "position" not in var_info[var_name][key]:
                     var_info[var_name][key]["position"] = []
                 var_info[var_name][key]["position"].append(op_id)
 
             for var_name in op.output_arg_names:
                 output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
-                    var_name)
+                    var_name
+                )
                 if var_name not in var_info:
                     var_info[var_name] = {}
-                key = _convert_pm_and_dm_to_str(process_mesh,
-                                                output_dims_mapping)
+                key = _convert_pm_and_dm_to_str(
+                    process_mesh, output_dims_mapping
+                )
                 if key not in var_info[var_name]:
                     var_info[var_name][key] = {}
                 if "memory" not in var_info[var_name][key]:
@@ -322,10 +345,14 @@ class CostEstimator:
                     global_sizes = var.shape
                     dtype = var.dtype
                     sizes = DistributedTensor.get_local_sizes(
-                        global_sizes, output_dims_mapping,
-                        process_mesh.topology, process_mesh.processes)
+                        global_sizes,
+                        output_dims_mapping,
+                        process_mesh.topology,
+                        process_mesh.processes,
+                    )
                     var_info[var_name][key]["memory"] = self._calculate_bytes(
-                        sizes, dtype)
+                        sizes, dtype
+                    )
                 if "position" not in var_info[var_name][key]:
                     var_info[var_name][key]["position"] = []
                 var_info[var_name][key]["position"].append(op_id)
@@ -333,7 +360,9 @@
         has_used_vars = set()
         for op_id, op in self._ordered_ops:
             if op.type in [
-                    "create_py_reader", "create_double_buffer_reader", "read"
+                "create_py_reader",
+                "create_double_buffer_reader",
+                "read",
             ]:
                 continue
             can_free_memories = {}
@@ -342,9 +371,11 @@
             process_mesh = dist_op.dist_attr.process_mesh
             for var_name in op.input_arg_names:
                 input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
-                    var_name)
-                key = _convert_pm_and_dm_to_str(process_mesh,
-                                                input_dims_mapping)
+                    var_name
+                )
+                key = _convert_pm_and_dm_to_str(
+                    process_mesh, input_dims_mapping
+                )
                 has_used_var = var_name + key
                 var = dist_op.get_serial_input(var_name)
                 # Not used
@@ -364,13 +395,16 @@
                             if process not in can_free_memories:
                                 can_free_memories[process] = 0
                             can_free_memories[process] += var_info[
-                                var_name][key]["memory"]
+                                var_name
+                            ][key]["memory"]
 
             for var_name in op.output_arg_names:
                 output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
-                    var_name)
-                key = _convert_pm_and_dm_to_str(process_mesh,
-                                                output_dims_mapping)
+                    var_name
+                )
+                key = _convert_pm_and_dm_to_str(
+                    process_mesh, output_dims_mapping
+                )
                 has_used_var = var_name + key
                 var = dist_op.get_serial_output(var_name)
                 # Not used
@@ -390,7 +424,8 @@
                             if process not in can_free_memories:
                                 can_free_memories[process] = 0
                             can_free_memories[process] += var_info[
-                                var_name][key]["memory"]
+                                var_name
+                            ][key]["memory"]
 
             # Calc peak memory
             for process in memories:
@@ -414,8 +449,12 @@
     def estimate(self, dist_context, resharder=None):
         self.prepare()
         from ..reshard import Resharder
-        resharder = Resharder(self.program, None, self.rank, dist_context,
-                              []) if resharder is None else resharder
+
+        resharder = (
+            Resharder(self.program, None, self.rank, dist_context, [])
+            if resharder is None
+            else resharder
+        )
 
         block = self.program.global_block()
         self._estimate_core(dist_context, resharder, block)
@@ -447,7 +486,7 @@
         memories = [
             int(item // 1e6) for item in list(self.max_memories.values())
         ]
-        for memory in (memories + header):
+        for memory in memories + header:
            if len(str(memory)) > max_len:
                max_len = len(str(memory))
         max_len += 4  # for pretty print of center
@@ -477,7 +516,7 @@
         max_len = 0
         header = ["Execution Time(ms)", "Max Memory(MiB)"]
         vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)]
-        for memory in (vals + header):
+        for memory in vals + header:
             if len(str(memory)) > max_len:
                 max_len = len(str(memory))
         max_len += 4  # for pretty print of center
@@ -507,50 +546,73 @@ class CostEstimator:
 def get_cost_from_engine(engine, mode):
     from ..utils import to_list
 
+    import copy
+
     # Construct cost estimator by original main program
-    serial_main_prog = engine._serial_main_progs[mode].clone(
-    ) if mode in engine._serial_main_progs else engine._orig_main_prog.clone()
-    serial_startup_prog = engine._serial_startup_progs[mode].clone(
-    ) if mode in engine._serial_startup_progs else engine._orig_startup_prog.clone(
-    )
-    losses = to_list(
-        engine._loss) if (not isinstance(engine._loss, paddle.nn.Layer)
-                          and not callable(engine._loss)) else engine._losses
-
-    if mode in engine._dist_contexts:
-        dist_context = engine._dist_contexts[mode]
-        completer = engine._planners[mode].completer
+    serial_main_prog = (
+        engine._fwd_main_progs[mode].clone()
+        if mode in engine._fwd_main_progs
+        else engine._orig_main_prog.clone()
+    )
+
+    serial_startup_prog = (
+        engine._serial_startup_progs[mode].clone()
+        if mode in engine._serial_startup_progs
+        else engine._orig_startup_prog.clone()
+    )
+    losses = (
+        to_list(engine._loss)
+        if (
+            not isinstance(engine._loss, paddle.nn.Layer)
+            and not callable(engine._loss)
+        )
+        else engine._losses
+    )
+    serial_optimizer = copy.deepcopy(engine._orig_optimizer)
+    if mode in engine._fwd_dist_contexts:
+        dist_context = copy.deepcopy(engine._fwd_dist_contexts[mode])
     else:
-        from ..completion import Completer
         from ..dist_context import DistributedContext
-        dist_context = DistributedContext(serial_main_prog, serial_startup_prog,
-                                          engine._optimizer, losses, {},
-                                          {"loss": losses}, engine._cluster,
-                                          engine._strategy)
-        completer = Completer(dist_context)
-        completer.complete_forward_annotation()
-        dist_context.block_state.parse_forward_blocks(
-            dist_context.serial_main_program)
+
+        dist_context = DistributedContext(
+            serial_main_prog,
+            serial_startup_prog,
+            serial_optimizer,
+            losses,
+            {},
+            {"loss": losses},
+            engine._cluster,
+            engine._strategy,
+        )
+    from ..completion import Completer
+
+    completer = Completer(dist_context)
+    completer.complete_forward_annotation()
+    dist_context.block_state.parse_forward_blocks(
+        dist_context.serial_main_program
+    )
 
     if mode == "eval" or mode == "predict":
         cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
     elif mode == "train":
         from ..parallelizer_v2 import Parallelizer
+
         # Get serial main program with backward
-        serial_optimizer = engine._optimizer
         parallelizer = Parallelizer(mode, completer, dist_context)
         # Generate backward
         loss_name = dist_context.serial_loss.name
         serial_loss = serial_main_prog.global_block()._var_recursive(loss_name)
-        params_grads = parallelizer._generate_backward(serial_main_prog,
-                                                       serial_startup_prog,
-                                                       serial_loss)
+        params_grads = parallelizer._generate_backward(
+            serial_main_prog, serial_startup_prog, serial_loss
+        )
         # Generate optimizer
         optimizer_ops = parallelizer._generate_optimizer(
-            serial_main_prog, serial_startup_prog, serial_optimizer,
-            params_grads)
+            serial_main_prog,
+            serial_startup_prog,
+            serial_optimizer,
+            params_grads,
+        )
         cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
 
     # Estimate global_cost and max memory
...
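For reference, the estimator touched by the hunks above is driven roughly as sketched below. This is a minimal usage sketch, not part of the commit: the module path is an assumption, and the serial main program, Cluster, and completed DistributedContext are assumed to be prepared elsewhere, while the constructor arguments, estimate(), global_cost, and max_memory mirror the code shown in the diff.

    # Minimal sketch (assumed import path); the caller supplies a serial main
    # program, a Cluster and a completed DistributedContext.
    from paddle.distributed.auto_parallel.cost.estimate_cost import CostEstimator

    def estimate_program_cost(serial_main_prog, cluster, dist_context):
        # Constructor arguments, estimate(), global_cost and max_memory all
        # mirror the code in the diff above.
        estimator = CostEstimator(
            serial_main_prog, cluster, mode="modeling", loop_count=10
        )
        estimator.estimate(dist_context)  # walks the global block via _estimate_core
        time_ms = estimator.global_cost.time  # estimated execution time (ms)
        peak_mib = int(estimator.max_memory // 1e6)  # as in the "Max Memory(MiB)" column
        return time_ms, peak_mib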
@@ -63,6 +63,15 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
   set_tests_properties(test_engine_callbacks
                        PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
+  py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
+                  ${dist_ENVS})
+  set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120)
+  py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full
+                  ENVS ${dist_ENVS})
+  set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120)
+  py_test_modules(test_parallel_tuner_predict MODULES
+                  test_parallel_tuner_predict ENVS ${dist_ENVS})
+  set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
   py_test_modules(test_while_op_completion MODULES test_while_op_completion
                   ENVS ${dist_ENVS})
@@ -90,6 +99,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_prim_dist_op MODULES test_prim_dist_op ENVS ${dist_ENVS})
   py_test_modules(test_to_static MODULES test_to_static ENVS ${dist_ENVS})
   py_test_modules(test_dist_op_cost MODULES test_dist_op_cost ENVS ${dist_ENVS})
   py_test_modules(test_cluster_v2 MODULES test_cluster_v2)
   py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
   py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
@@ -99,20 +109,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_interface MODULES test_interface)
   py_test_modules(test_strategy MODULES test_strategy)
   py_test_modules(test_pass_quantization MODULES test_pass_quantization)
   py_test_modules(test_dist_shape MODULES test_dist_shape)
   py_test_modules(test_dist_assign MODULES test_dist_assign)
   py_test_modules(test_conditional_block_reshard MODULES
                   test_conditional_block_reshard)
-  py_test_modules(test_engine_api_error MODULES test_engine_api_error)
-  py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
-                  ${dist_ENVS})
-  set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120)
-  py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full
-                  ENVS ${dist_ENVS})
-  set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120)
-  py_test_modules(test_parallel_tuner_predict MODULES
-                  test_parallel_tuner_predict ENVS ${dist_ENVS})
-  set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
 endif()