Unverified commit 4b3589fb authored by zhaoyingli and committed by GitHub

2.4/fix engine build (#47462)

* update codestyle

* [AutoParallel] fix fp16 for subblock (#47189)

* [AutoParallel] fix fp16 for subblock

* fix engine

* fix comment

* [AutoParallel] fix engine _build and cost method (#47263)

* fix engine build method

* fix import

* update engine cost

* update raise error

* update cmakelist

* revert optimizer

* revert optimizer

* fix unittest

* fix unittest
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Parent f93e9a58
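For orientation, a minimal usage sketch of the cost path touched by this change. The model, loss, optimizer, and import path below are assumptions for illustration; only get_cost_from_engine(engine, mode) and CostEstimator are defined in the diff that follows.

    import paddle
    from paddle.distributed.fleet import auto

    # Hypothetical setup; any nn.Layer / loss / optimizer combination works here.
    model = paddle.nn.Sequential(paddle.nn.Linear(64, 64), paddle.nn.ReLU())
    loss = paddle.nn.CrossEntropyLoss()
    optimizer = paddle.optimizer.Adam(learning_rate=1e-3, parameters=model.parameters())
    engine = auto.Engine(model, loss, optimizer)

    # Assumed import location of the helper reworked in this commit; the diff below
    # changes how it pulls programs and the optimizer out of the engine.
    from paddle.distributed.auto_parallel.cost import get_cost_from_engine
    cost = get_cost_from_engine(engine, "train")  # exact return value is collapsed in the diff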
@@ -27,12 +27,9 @@ from ..dist_tensor import DistributedTensor
class CostEstimator:
_sepical_op_type = ["fused_attention", "fused_feedforward"]
def __init__(self,
program,
cluster,
mode="modeling",
rank=None,
loop_count=10):
def __init__(
self, program, cluster, mode="modeling", rank=None, loop_count=10
):
self._program = program
self._cluster = cluster
self._check_mode(mode)
@@ -41,7 +38,8 @@ class CostEstimator:
self._loop_count = loop_count
self._global_cost = Cost()
self._local_cost_mapping = {}
self._detailed_cost = OrderedDict(
self._detailed_cost = (
OrderedDict()
) # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}}
self._bubble_time_mapping = {}
self._ordered_ops = []
@@ -106,7 +104,8 @@ class CostEstimator:
def _check_mode(self, mode):
if mode not in ["modeling", "profiling"]:
raise ValueError(
"Just support modeling and profiling, but got {}".format(mode))
"Just support modeling and profiling, but got {}".format(mode)
)
def _is_special_var_name(self, var_name):
special_var_name = ["lod_tensor_blocking_queue_0"]
@@ -116,6 +115,7 @@ class CostEstimator:
def _estimate_core(self, dist_context, resharder, block):
from ..reshard import get_var_with_recursion
ops = block.ops
loop_count = None
if block.desc.id != self.program.global_block().desc.id:
@@ -132,8 +132,9 @@ class CostEstimator:
if int(op.attr('op_role')) == int(OpRole.Optimize):
continue
if op.type in [
"create_py_reader", "create_double_buffer_reader",
"read"
"create_py_reader",
"create_double_buffer_reader",
"read",
]:
continue
@@ -172,14 +173,16 @@ class CostEstimator:
max_time = rank_cost.time
for rank in group_ranks:
self.local_cost(
rank).time = max_time + cost.time
self.local_cost(rank).time = (
max_time + cost.time
)
if rank not in self._bubble_time_mapping:
self._bubble_time_mapping[rank] = 0
self._bubble_time_mapping[rank] += (
max_time - cost_time[rank])
max_time - cost_time[rank]
)
for rank in local_comp_cost:
for comp_cost in local_comp_cost[rank]:
@@ -191,15 +194,19 @@ class CostEstimator:
processes = op_dist_attr.process_mesh.processes
container = get_distributed_operator_impl_container(
op_dist_attr.impl_type)
op_dist_attr.impl_type
)
dist_impl = container.impls[op_dist_attr.impl_idx]
dist_op_cost = dist_impl.calc_cost(op.attr('op_role'), dist_op,
dist_context, self.cluster)
dist_op_cost = dist_impl.calc_cost(
op.attr('op_role'), dist_op, dist_context, self.cluster
)
detail["dist_op_cost"] = dist_op_cost
if dist_op_cost is None:
assert dist_op.serial_op.type in CostEstimator._sepical_op_type
assert (
dist_op.serial_op.type in CostEstimator._sepical_op_type
)
continue
for item in dist_op_cost:
if isinstance(item, list):
@@ -217,12 +224,14 @@ class CostEstimator:
if max_time < rank_cost.time:
max_time = rank_cost.time
for rank in group_ranks:
self.local_cost(
rank).time = max_time + comm_op_cost.time
self.local_cost(rank).time = (
max_time + comm_op_cost.time
)
if rank not in self._bubble_time_mapping:
self._bubble_time_mapping[rank] = 0
self._bubble_time_mapping[rank] += (
max_time - cost_time[rank])
max_time - cost_time[rank]
)
elif isinstance(item, dict):
# Op just one
for rank in processes:
@@ -247,8 +256,11 @@ class CostEstimator:
dtype_factor = 8
elif dtype == paddle.float32 or dtype == paddle.int32:
dtype_factor = 4
elif dtype == paddle.float16 or dtype == paddle.bfloat16 \
or dtype == paddle.int16:
elif (
dtype == paddle.float16
or dtype == paddle.bfloat16
or dtype == paddle.int16
):
dtype_factor = 2
elif dtype == paddle.int8 or dtype == paddle.uint8:
dtype_factor = 1
@@ -270,8 +282,9 @@ class CostEstimator:
memories = {}
self.max_memories = {}
var_info = {
} # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]}
var_info = (
{}
) # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]}
for block in self.program.blocks:
for op in block.ops:
@@ -280,18 +293,22 @@ class CostEstimator:
for op_id, op in self._ordered_ops:
if op.type in [
"create_py_reader", "create_double_buffer_reader", "read"
"create_py_reader",
"create_double_buffer_reader",
"read",
]:
continue
dist_op = dist_context.get_dist_op_for_program(op)
process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names:
input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
var_name)
var_name
)
if var_name not in var_info:
var_info[var_name] = {}
key = _convert_pm_and_dm_to_str(process_mesh,
input_dims_mapping)
key = _convert_pm_and_dm_to_str(
process_mesh, input_dims_mapping
)
if key not in var_info[var_name]:
var_info[var_name][key] = {}
# It is even partition now
@@ -300,21 +317,27 @@ class CostEstimator:
global_sizes = var.shape
dtype = var.dtype
sizes = DistributedTensor.get_local_sizes(
global_sizes, input_dims_mapping, process_mesh.topology,
process_mesh.processes)
global_sizes,
input_dims_mapping,
process_mesh.topology,
process_mesh.processes,
)
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype)
sizes, dtype
)
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
for var_name in op.output_arg_names:
output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
var_name)
var_name
)
if var_name not in var_info:
var_info[var_name] = {}
key = _convert_pm_and_dm_to_str(process_mesh,
output_dims_mapping)
key = _convert_pm_and_dm_to_str(
process_mesh, output_dims_mapping
)
if key not in var_info[var_name]:
var_info[var_name][key] = {}
if "memory" not in var_info[var_name][key]:
@@ -322,10 +345,14 @@ class CostEstimator:
global_sizes = var.shape
dtype = var.dtype
sizes = DistributedTensor.get_local_sizes(
global_sizes, output_dims_mapping,
process_mesh.topology, process_mesh.processes)
global_sizes,
output_dims_mapping,
process_mesh.topology,
process_mesh.processes,
)
var_info[var_name][key]["memory"] = self._calculate_bytes(
sizes, dtype)
sizes, dtype
)
if "position" not in var_info[var_name][key]:
var_info[var_name][key]["position"] = []
var_info[var_name][key]["position"].append(op_id)
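                # Illustration (structure inferred from the loops above; names hypothetical):
                # var_info is keyed by variable name and then by a key string built from
                # (process_mesh, dims_mapping), e.g.
                #   var_info["linear_0.w_0"][key] = {"memory": 4194304, "position": [12, 57]}
                # "memory" holds the local shard size in bytes and "position" the op ids that
                # touch the shard, which drives the liveness/peak-memory pass below.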
@@ -333,7 +360,9 @@ class CostEstimator:
has_used_vars = set()
for op_id, op in self._ordered_ops:
if op.type in [
"create_py_reader", "create_double_buffer_reader", "read"
"create_py_reader",
"create_double_buffer_reader",
"read",
]:
continue
can_free_memories = {}
@@ -342,9 +371,11 @@ class CostEstimator:
process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names:
input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
var_name)
key = _convert_pm_and_dm_to_str(process_mesh,
input_dims_mapping)
var_name
)
key = _convert_pm_and_dm_to_str(
process_mesh, input_dims_mapping
)
has_used_var = var_name + key
var = dist_op.get_serial_input(var_name)
# Not used
@@ -364,13 +395,16 @@ class CostEstimator:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
var_name][key]["memory"]
var_name
][key]["memory"]
for var_name in op.output_arg_names:
output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
var_name)
key = _convert_pm_and_dm_to_str(process_mesh,
output_dims_mapping)
var_name
)
key = _convert_pm_and_dm_to_str(
process_mesh, output_dims_mapping
)
has_used_var = var_name + key
var = dist_op.get_serial_output(var_name)
# Not used
@@ -390,7 +424,8 @@ class CostEstimator:
if process not in can_free_memories:
can_free_memories[process] = 0
can_free_memories[process] += var_info[
var_name][key]["memory"]
var_name
][key]["memory"]
# Calc peak memory
for process in memories:
......@@ -414,8 +449,12 @@ class CostEstimator:
def estimate(self, dist_context, resharder=None):
self.prepare()
from ..reshard import Resharder
resharder = Resharder(self.program, None, self.rank, dist_context,
[]) if resharder is None else resharder
resharder = (
Resharder(self.program, None, self.rank, dist_context, [])
if resharder is None
else resharder
)
block = self.program.global_block()
self._estimate_core(dist_context, resharder, block)
@@ -447,7 +486,7 @@ class CostEstimator:
memories = [
int(item // 1e6) for item in list(self.max_memories.values())
]
for memory in (memories + header):
for memory in memories + header:
if len(str(memory)) > max_len:
max_len = len(str(memory))
max_len += 4 # for pretty print of center
@@ -477,7 +516,7 @@ class CostEstimator:
max_len = 0
header = ["Execution Time(ms)", "Max Memory(MiB)"]
vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)]
for memory in (vals + header):
for memory in vals + header:
if len(str(memory)) > max_len:
max_len = len(str(memory))
max_len += 4 # for pretty print of center
@@ -507,50 +546,73 @@ class CostEstimator:
def get_cost_from_engine(engine, mode):
from ..utils import to_list
# Construct cost estimator by original main program
serial_main_prog = engine._serial_main_progs[mode].clone(
) if mode in engine._serial_main_progs else engine._orig_main_prog.clone()
import copy
serial_startup_prog = engine._serial_startup_progs[mode].clone(
) if mode in engine._serial_startup_progs else engine._orig_startup_prog.clone(
# Construct cost estimator by original main program
serial_main_prog = (
engine._fwd_main_progs[mode].clone()
if mode in engine._fwd_main_progs
else engine._orig_main_prog.clone()
)
losses = to_list(
engine._loss) if (not isinstance(engine._loss, paddle.nn.Layer)
and not callable(engine._loss)) else engine._losses
if mode in engine._dist_contexts:
dist_context = engine._dist_contexts[mode]
completer = engine._planners[mode].completer
serial_startup_prog = (
engine._serial_startup_progs[mode].clone()
if mode in engine._serial_startup_progs
else engine._orig_startup_prog.clone()
)
losses = (
to_list(engine._loss)
if (
not isinstance(engine._loss, paddle.nn.Layer)
and not callable(engine._loss)
)
else engine._losses
)
serial_optimizer = copy.deepcopy(engine._orig_optimizer)
if mode in engine._fwd_dist_contexts:
dist_context = copy.deepcopy(engine._fwd_dist_contexts[mode])
else:
from ..completion import Completer
from ..dist_context import DistributedContext
dist_context = DistributedContext(serial_main_prog, serial_startup_prog,
engine._optimizer, losses, {},
{"loss": losses}, engine._cluster,
engine._strategy)
completer = Completer(dist_context)
completer.complete_forward_annotation()
dist_context.block_state.parse_forward_blocks(
dist_context.serial_main_program)
dist_context = DistributedContext(
serial_main_prog,
serial_startup_prog,
serial_optimizer,
losses,
{},
{"loss": losses},
engine._cluster,
engine._strategy,
)
from ..completion import Completer
completer = Completer(dist_context)
completer.complete_forward_annotation()
dist_context.block_state.parse_forward_blocks(
dist_context.serial_main_program
)
if mode == "eval" or mode == "predict":
cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
elif mode == "train":
from ..parallelizer_v2 import Parallelizer
# Get serial main program with backward
serial_optimizer = engine._optimizer
parallelizer = Parallelizer(mode, completer, dist_context)
# Generate backward
loss_name = dist_context.serial_loss.name
serial_loss = serial_main_prog.global_block()._var_recursive(loss_name)
params_grads = parallelizer._generate_backward(serial_main_prog,
serial_startup_prog,
serial_loss)
params_grads = parallelizer._generate_backward(
serial_main_prog, serial_startup_prog, serial_loss
)
# Generate optimizer
optimizer_ops = parallelizer._generate_optimizer(
serial_main_prog, serial_startup_prog, serial_optimizer,
params_grads)
serial_main_prog,
serial_startup_prog,
serial_optimizer,
params_grads,
)
cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
# Estimate global_cost and max memory
......
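    # Illustration (assumption; the tail of get_cost_from_engine is collapsed above):
    # after the "train" branch builds cost_estimator, the remaining lines presumably
    # run the estimation and read back the results, roughly:
    #   cost_estimator.estimate(dist_context)
    #   exec_time = cost_estimator.global_cost.time   # ms, as in the summary table above
    #   peak_mem = cost_estimator.max_memory           # bytes, per the class diff above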
@@ -63,6 +63,15 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
set_tests_properties(test_engine_callbacks
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
${dist_ENVS})
set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120)
py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full
ENVS ${dist_ENVS})
set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120)
py_test_modules(test_parallel_tuner_predict MODULES
test_parallel_tuner_predict ENVS ${dist_ENVS})
set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
py_test_modules(test_while_op_completion MODULES test_while_op_completion
ENVS ${dist_ENVS})
@@ -90,6 +99,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_prim_dist_op MODULES test_prim_dist_op ENVS ${dist_ENVS})
py_test_modules(test_to_static MODULES test_to_static ENVS ${dist_ENVS})
py_test_modules(test_dist_op_cost MODULES test_dist_op_cost ENVS ${dist_ENVS})
py_test_modules(test_cluster_v2 MODULES test_cluster_v2)
py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
@@ -99,20 +109,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_interface MODULES test_interface)
py_test_modules(test_strategy MODULES test_strategy)
py_test_modules(test_pass_quantization MODULES test_pass_quantization)
py_test_modules(test_dist_shape MODULES test_dist_shape)
py_test_modules(test_dist_assign MODULES test_dist_assign)
py_test_modules(test_conditional_block_reshard MODULES
test_conditional_block_reshard)
py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
${dist_ENVS})
set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120)
py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full
ENVS ${dist_ENVS})
set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120)
py_test_modules(test_parallel_tuner_predict MODULES
test_parallel_tuner_predict ENVS ${dist_ENVS})
set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
py_test_modules(test_engine_api_error MODULES test_engine_api_error)
endif()