diff --git a/python/paddle/distributed/auto_parallel/cost/estimate_cost.py b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py index 3fbb107db803ed0db153083b6dfd3c7748958de5..fba9d7919c406acd655f0840cb5eac75fb984617 100644 --- a/python/paddle/distributed/auto_parallel/cost/estimate_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py @@ -44,6 +44,8 @@ class CostEstimator: ) # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}} self._bubble_time_mapping = {} self._ordered_ops = [] + self.max_memories = {} + self.max_memory = None @property def loop_count(self): @@ -122,7 +124,7 @@ class CostEstimator: for i in range(loop_count): for op in ops: self._detailed_cost[op.desc.id()] = OrderedDict() - # if in the while sub block, the detail of cost is the last cost + # If in the while sub block, the detail of cost is the last cost detail = self._detailed_cost[op.desc.id()] detail["reshard_cost"] = OrderedDict() # detail["dist_op_cost"] = [] @@ -146,15 +148,15 @@ class CostEstimator: var = get_var_with_recursion(var_name, block, self.program) reshard_cost = resharder.get_cost(op, var, self.cluster) - # calc reshard cost + # Calc reshard cost if reshard_cost is not None: detail["reshard_cost"][var_name] = reshard_cost comm_costs = reshard_cost[0] local_comp_cost = reshard_cost[1] for comm_cost in comm_costs: - # time is cumulative in global cost and local cost, but memory and flops just are cumulative in global cost. - # comm sync + # Time is cumulative in global cost and local cost, but memory and flops just are cumulative in global cost. + # Comm sync for item in comm_cost: group_ranks, cost = item max_time = None @@ -182,7 +184,7 @@ class CostEstimator: for comp_cost in local_comp_cost[rank]: self.local_cost(rank).time += comp_cost.time - # calc dist op cost + # Calc dist op cost dist_op = dist_context.get_dist_op_for_program(op) op_dist_attr = dist_op.dist_attr processes = op_dist_attr.process_mesh.processes @@ -200,7 +202,7 @@ class CostEstimator: continue for item in dist_op_cost: if isinstance(item, list): - # comm sync + # Comm sync for comm_op_cost in item: max_time = None cost_time = {} @@ -221,9 +223,9 @@ class CostEstimator: self._bubble_time_mapping[rank] += ( max_time - cost_time[rank]) elif isinstance(item, dict): - # op just one + # Op just one for rank in processes: - # dp+pp+mp + # DP+PP+MP if rank not in item: continue self.local_cost(rank).time += item[rank].time @@ -266,7 +268,7 @@ class CostEstimator: return result memories = {} - max_memories = {} + self.max_memories = {} var_info = { } # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]} @@ -276,6 +278,10 @@ class CostEstimator: self._ordered_ops.sort(key=lambda x: x[0]) for op_id, op in self._ordered_ops: + if op.type in [ + "create_py_reader", "create_double_buffer_reader", "read" + ]: + continue dist_op = dist_context.get_dist_op_for_program(op) process_mesh = dist_op.dist_attr.process_mesh for var_name in op.input_arg_names: @@ -287,7 +293,7 @@ class CostEstimator: input_dims_mapping) if key not in var_info[var_name]: var_info[var_name][key] = {} - # it is even partition now + # It is even partition now if "memory" not in var_info[var_name][key]: var = dist_op.get_serial_input(var_name) global_sizes = var.shape @@ -325,6 +331,10 @@ class CostEstimator: has_used_vars = set() for op_id, op in self._ordered_ops: + if op.type in [ + "create_py_reader", "create_double_buffer_reader", "read" + ]: + continue can_free_memories = {} 
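# NOTE (editor sketch): a minimal, self-contained illustration of the peak-memory
# bookkeeping that _estimate_max_memory_by_dist_op performs above: a variable's
# size is charged to every rank in its process mesh on first use, released after
# its last use, and a running per-rank maximum is kept (now stored on
# self.max_memories); reader ops are skipped. The `ops`, `sizes` and `last_use`
# structures are hypothetical stand-ins for the dist_op/var_info bookkeeping.
from collections import defaultdict

READER_OPS = {"create_py_reader", "create_double_buffer_reader", "read"}


def estimate_peak_memory(ops, sizes, last_use):
    memories = defaultdict(int)      # rank -> currently allocated bytes
    max_memories = defaultdict(int)  # rank -> peak allocated bytes
    seen = set()                     # variables already charged
    for op_id, op in enumerate(ops):
        if op["type"] in READER_OPS:
            continue
        freeable = defaultdict(int)
        for var in op["vars"]:
            if var not in seen:
                # First use: charge the variable to every rank of its mesh.
                seen.add(var)
                for rank in op["ranks"]:
                    memories[rank] += sizes[var]
            elif op_id == last_use[var]:
                # Last use: mark the variable as releasable after this op.
                for rank in op["ranks"]:
                    freeable[rank] += sizes[var]
        # Peak is taken before freeing, as in the loop above.
        for rank, used in memories.items():
            max_memories[rank] = max(max_memories[rank], used)
        for rank, freed in freeable.items():
            memories[rank] -= freed
    return dict(max_memories)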
can_free_vars = set() dist_op = dist_context.get_dist_op_for_program(op) @@ -336,14 +346,14 @@ class CostEstimator: input_dims_mapping) has_used_var = var_name + key var = dist_op.get_serial_input(var_name) - # not used + # Not used if var_name + key not in has_used_vars: has_used_vars.add(has_used_var) for process in process_mesh.processes: if process not in memories: memories[process] = 0 memories[process] += var_info[var_name][key]["memory"] - # used + # Used else: if op_id == var_info[var_name][key]["position"][-1]: if has_used_var not in can_free_vars: @@ -362,14 +372,14 @@ class CostEstimator: output_dims_mapping) has_used_var = var_name + key var = dist_op.get_serial_output(var_name) - # not used + # Not used if var_name + key not in has_used_vars: has_used_vars.add(has_used_var) for process in process_mesh.processes: if process not in memories: memories[process] = 0 memories[process] += var_info[var_name][key]["memory"] - # used + # Used else: if op_id == var_info[var_name][key]["position"][-1]: if has_used_var not in can_free_vars: @@ -381,21 +391,22 @@ class CostEstimator: can_free_memories[process] += var_info[ var_name][key]["memory"] - # calc peak memory + # Calc peak memory for process in memories: - if process not in max_memories: - max_memories[process] = memories[process] + if process not in self.max_memories: + self.max_memories[process] = memories[process] else: - if memories[process] > max_memories[process]: - max_memories[process] = memories[process] + if memories[process] > self.max_memories[process]: + self.max_memories[process] = memories[process] - # free memory + # Free memory for process in can_free_memories: if process in memories: memories[process] -= can_free_memories[process] # Calculate the max memory in all ranks - max_memory = max(max_memories.values()) + max_memory = max(self.max_memories.values()) + self.max_memory = max_memory return max_memory @@ -409,3 +420,143 @@ class CostEstimator: self._estimate_core(dist_context, resharder, block) return self.global_cost + + def _print_tag(self, max_len, length): + tag = "+" + "-" * max_len + for i in range(length): + print(tag, end="") + if i == length - 1: + print("+") + + def _print_vals(self, vals, max_len): + for idx, val in enumerate(vals): + s = "|" + str(val).center(max_len) + print(s, end="") + if idx == len(vals) - 1: + print("|") + + def _pretty_print_memory_cost(self): + """Print memory of every rank prettily.""" + if not self.max_memories or not self.max_memory: + raise ValueError("Please calculate memory cost before print.") + + # Padding automatically + max_len = 0 + header = ["Rank", "Memory(MiB)"] + memories = [ + int(item // 1e6) for item in list(self.max_memories.values()) + ] + for memory in (memories + header): + if len(str(memory)) > max_len: + max_len = len(str(memory)) + max_len += 4 # for pretty print of center + + # Print tag + self._print_tag(max_len, len(header)) + + # Print header + self._print_vals(header, max_len) + + # Print tag + self._print_tag(max_len, len(header)) + + # Print rank and its memory + for i in range(len(self.max_memories)): + memory = memories[i] + vals = [i, memory] + self._print_vals(vals, max_len) + self._print_tag(max_len, len(header)) + + def _pretty_print_global(self): + """Print global execution time and max memory prettily.""" + if not self.max_memories or not self.max_memory: + raise ValueError("Please calculate cost before print.") + + # Padding automatically + max_len = 0 + header = ["Execution Time(ms)", "Max Memory(MiB)"] + vals = 
[round(self.global_cost.time, 3), int(self.max_memory // 1e6)] + for memory in (vals + header): + if len(str(memory)) > max_len: + max_len = len(str(memory)) + max_len += 4 # for pretty print of center + + # Print tag + self._print_tag(max_len, len(header)) + + # Print header + self._print_vals(header, max_len) + + # Print tag + self._print_tag(max_len, len(header)) + + # Print exec time and max memory + self._print_vals(vals, max_len) + + # Print tag + self._print_tag(max_len, len(header)) + + def pretty_print_cost(self): + """Print cost prettily.""" + print("The global execution time and max memory are as follows:") + self._pretty_print_global() + print("The memory of every rank is as follows:") + self._pretty_print_memory_cost() + + +def get_cost_from_engine(engine, mode): + from ..utils import to_list + # Construct cost estimator by original main program + serial_main_prog = engine._serial_main_progs[mode].clone( + ) if mode in engine._serial_main_progs else engine._orig_main_prog.clone() + + serial_startup_prog = engine._serial_startup_progs[mode].clone( + ) if mode in engine._serial_startup_progs else engine._orig_startup_prog.clone( + ) + losses = to_list( + engine._loss) if (not isinstance(engine._loss, paddle.nn.Layer) + and not callable(engine._loss)) else engine._losses + + if mode in engine._dist_contexts: + dist_context = engine._dist_contexts[mode] + completer = engine._planners[mode].completer + else: + from ..completion import Completer + from ..dist_context import DistributedContext + dist_context = DistributedContext(serial_main_prog, serial_startup_prog, + engine._optimizer, losses, {}, + {"loss": losses}, engine._cluster, + engine._strategy) + completer = Completer(dist_context) + completer.complete_forward_annotation() + dist_context.block_state.parse_forward_blocks( + dist_context.serial_main_program) + + if mode == "eval" or mode == "predict": + cost_estimator = CostEstimator(serial_main_prog, engine._cluster) + elif mode == "train": + from ..parallelizer_v2 import Parallelizer + # Get serial main program with backward + serial_optimizer = engine._optimizer + parallelizer = Parallelizer(mode, completer, dist_context) + # Generate backward + loss_name = dist_context.serial_loss.name + serial_loss = serial_main_prog.global_block()._var_recursive(loss_name) + params_grads = parallelizer._generate_backward(serial_main_prog, + serial_startup_prog, + serial_loss) + + # Generate optimizer + optimizer_ops = parallelizer._generate_optimizer( + serial_main_prog, serial_startup_prog, serial_optimizer, + params_grads) + cost_estimator = CostEstimator(serial_main_prog, engine._cluster) + + # Estimate global_cost and max memory + global_cost = cost_estimator.estimate(dist_context) + max_memory = cost_estimator._estimate_max_memory_by_dist_op(dist_context) + + # Print the cost + cost_estimator.pretty_print_cost() + + return global_cost, max_memory diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 6ba4a7ad1a3ed073a0a9ae734edadb11264ef6ba..a2e1477f8873c9ec69e72cd3cfccede580695386 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -48,6 +48,8 @@ from .dist_context import DistributedContext, get_default_distributed_context from .strategy import Strategy from .interface import CollectionNames, get_collection from ..utils.log_utils import get_logger +from .utils import initialize_pg_in_full_mode +from .cost.estimate_cost import get_cost_from_engine 
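# NOTE (editor sketch): the _print_tag/_print_vals helpers added to CostEstimator
# above render plain ASCII tables. This standalone sketch reproduces the same
# centering scheme (widest cell plus 4 spaces of padding, one shared column
# width) to show roughly what pretty_print_cost() emits; the sample numbers are
# made up for illustration.
def print_table(header, rows):
    cells = [str(c) for row in ([header] + rows) for c in row]
    max_len = max(len(c) for c in cells) + 4
    tag = ("+" + "-" * max_len) * len(header) + "+"
    print(tag)
    print("".join("|" + str(c).center(max_len) for c in header) + "|")
    print(tag)
    for row in rows:
        print("".join("|" + str(c).center(max_len) for c in row) + "|")
        print(tag)


# print_table(["Rank", "Memory(MiB)"], [[0, 1024], [1, 1024]]) prints:
#   +---------------+---------------+
#   |     Rank      |  Memory(MiB)  |
#   +---------------+---------------+
#   |       0       |     1024      |
#   +---------------+---------------+
#   |       1       |     1024      |
#   +---------------+---------------+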
class Engine: @@ -127,12 +129,6 @@ class Engine: "'model must be sub classes of `paddle.nn.Layer` or any callable function." ) self._model = model - - # if loss and not isinstance(loss, - # paddle.nn.Layer) and not callable(loss): - # raise TypeError( - # "'loss' must be sub classes of `paddle.nn.Layer` or any callable function." - # ) self._loss = loss if optimizer and not isinstance( @@ -201,6 +197,7 @@ class Engine: self._planned_mode = None self._dygraph_mode = False self._tuning = self._strategy.tuning + self._losses = None self.history = None @@ -487,6 +484,7 @@ class Engine: outputs = self.program_helper.output_vars labels = self.program_helper.label_vars losses = self.program_helper.loss_vars + self._losses = losses metrics = self.program_helper.metric_vars self._inputs = inputs @@ -512,6 +510,7 @@ class Engine: outputs = to_list(self._model(*inputs)) if mode != "predict" and self._loss: losses = to_list(self._loss(*(outputs + labels))) + self._losses = losses if mode != "predict" and (outputs or labels): for metric in self._metrics: @@ -519,6 +518,7 @@ class Engine: to_list(metric.compute(*(outputs + labels)))) else: losses = to_list(self._loss) + self.losses = losses default_ctx = get_default_distributed_context() if not default_ctx.has_annotation: @@ -648,11 +648,13 @@ class Engine: # instantiate communication by process_mapping. all_process_groups = get_all_process_groups() - # NOTE: add the comm init control in the future for auto search - for process_group in all_process_groups: - if self._cur_rank not in process_group.ranks: - continue - process_group.instantiate() + if self._strategy.auto_mode == "full": + initialize_pg_in_full_mode(all_process_groups, cur_rank) + else: + for process_group in all_process_groups: + if self._cur_rank not in process_group.ranks: + continue + process_group.instantiate() place = _get_device() if isinstance(place, fluid.CUDAPlace): @@ -1022,16 +1024,9 @@ class Engine: test_dataloader = self._prepare_dataloader_from_generator( dataset=test_data, - # feed_list=feed_list, capacity=70, - # use_double_buffer=use_double_buffer, iterable=False, - # return_list=return_list, - # use_multiprocess=use_multiprocess, - # drop_last=drop_last, - # places=places, batch_size=batch_size, - # epochs=epochs, steps_per_epoch=steps, collate_fn=collate_fn) @@ -1100,21 +1095,19 @@ class Engine: steps_per_epoch=steps_per_epoch) return dataloader - def dataloader_from_generator( - self, - dataset, - capacity=70, - use_double_buffer=True, - iterable=True, - # return_list=False, - use_multiprocess=False, - drop_last=True, - batch_size=1, - epochs=1, - steps_per_epoch=None, - collate_fn=None, - sample_split=1, - mode=None): + def dataloader_from_generator(self, + dataset, + capacity=70, + use_double_buffer=True, + iterable=True, + use_multiprocess=False, + drop_last=True, + batch_size=1, + epochs=1, + steps_per_epoch=None, + collate_fn=None, + sample_split=1, + mode=None): if mode is not None: self.to_mode(mode) self._inputs_spec, self._labels_spec = self._prepare_data_spec( @@ -1127,14 +1120,12 @@ class Engine: self._switch_mode(self._mode) dataloader = self._prepare_dataloader_from_generator( dataset=dataset, - # feed_list=feed_list, capacity=capacity, use_double_buffer=use_double_buffer, iterable=iterable, return_list=False, use_multiprocess=use_multiprocess, drop_last=drop_last, - # places=places, batch_size=batch_size, epochs=epochs, steps_per_epoch=steps_per_epoch, @@ -1187,20 +1178,7 @@ class Engine: assert self._inputs_spec and self._labels_spec, \ "Please call the 
dataloader(...) before calling prepare(...)" - def run( - self, - data=None, - # program=None, - feed=None, - fetch_list=None, - # feed_var_name='feed', - # fetch_var_name='fetch', - # scope=None, - # return_numpy=True, - # use_program_cache=False, - # return_merged=True, - # use_prune=False, - mode=None): + def run(self, data=None, feed=None, fetch_list=None, mode=None): if mode is not None: self.to_mode(mode) feed_dict = self._prepare_feed(data, feed, self._mode) @@ -1571,6 +1549,54 @@ class Engine: path, load_optimizer) return self._state_dict, self._dist_attr + def cost(self, inputs_spec=None, labels_spec=None, mode="train"): + """ + Get and Print cost, including memory of every rank, + max memory among all ranks, and the global cost of one step based on + communication cost(computation cost is 0 by default). + In the future, the flops information of every rank and global cost including + computation cost will be added. + + Args: + inputs_spec(InputSpec): The specification of inputs. Default: None. + labels_spec(InputSpec): The specification of labels. Default: None. + mode (str): The engine mode must be in ["train", "predict", "eval"]. Default: "train". + + Returns: + Return the global execution time (ms) and max memory (B). + + """ + # Check parallel mode + if self._strategy.auto_mode == "full": + print( + "The cost will be calcudated in the search process when the auto mode is full." + ) + return + + # Check mode + accepted_modes = ["train", "predict", "eval"] + if mode not in accepted_modes: + raise ValueError("The mode {} is not in accepted modes {}".format( + mode, accepted_modes)) + self.to_mode(mode) + + if inputs_spec is not None: + self._inputs_spec, self._labels_spec = inputs_spec, labels_spec + self._inputs, self._labels = self._prepare_data_tensor( + self._inputs_spec, self._labels_spec) + self._build(mode) + self._plan(mode) + else: + if _non_static_mode() or self._dygraph_mode: + raise ValueError( + "Please call `engine._prepare_program('mode')` firstly when in the static graph mode." 
+ ) + + # Estimate the exec cost and max memory + global_cost, max_memory = get_cost_from_engine(self, mode) + + return global_cost.time, max_memory + @property def main_program(self): return self._dist_main_progs[self._mode][self._cur_rank] diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 00f86ca4d0c1e98fd346b8e07bb21c19e65aee34..cf6f506f8c5632c5c308d8b96ceee90765cca8bc 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1602,3 +1602,65 @@ def get_lr(optimizer): "'optimizer' must be object of class `paddle.optimizer.Optimizer`" \ " or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(type(optimizer)) ) + + +def initialize_pg_in_full_mode(all_process_groups, cur_rank): + import socket + from ..collective import _get_global_env + + has_recv_by_socket = [] + # This is a magic number + magic_num = 500 + genv = _get_global_env() + cur_rank_ip, cur_rank_port = genv.current_endpoint.split(":") + cur_rank_recv_port = int(cur_rank_port) + magic_num + server_socket = None + # Large enough for recv rank + buff_size = 1024 + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.bind((cur_rank_ip, cur_rank_recv_port)) + # The 10 is an empirical value + server_socket.listen(10) + client_sockets = {} + for process_group in all_process_groups: + if cur_rank not in process_group.ranks: + continue + if len(process_group.ranks) == 2: + index = process_group.ranks.index(cur_rank) + is_send = True if index == 0 else False + if is_send: + recv_rank = process_group.ranks[1] + recv_rank_ip, recv_rank_port = genv.trainer_endpoints[ + recv_rank].split(":") + connect_port = int(recv_rank_port) + magic_num + client_socket = socket.socket(socket.AF_INET, + socket.SOCK_STREAM) + client_socket.connect((recv_rank_ip, connect_port)) + client_socket.send(str(cur_rank).encode('utf-8')) + rank = client_socket.recv(buff_size).decode('utf-8') + rank = int(rank) + if rank != recv_rank: + raise ValueError( + "Please check comm pair, the recv rank should be {} but got {}." + .format(recv_rank, rank)) + else: + print("It is able to instantiate {} as sender now.".format( + process_group.ranks)) + client_socket.close() + else: + send_rank = process_group.ranks[0] + while True: + if send_rank not in has_recv_by_socket: + client_socket, recv_addr = server_socket.accept() + rank = int(client_socket.recv(buff_size).decode()) + client_sockets[rank] = client_socket + has_recv_by_socket.append(rank) + else: + client_sockets[send_rank].send( + str(cur_rank).encode("utf-8")) + client_sockets[send_rank].close() + print("It is able to instantiate {} as recver now.". 
+ format(process_group.ranks)) + break + process_group.instantiate() + server_socket.close() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py index 07eaac027e1dcfa7444a2b8ba16f859745721289..38287e98a219a672fa7ba4429c2bdae31a9e800b 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py @@ -141,6 +141,7 @@ def train_high_level(fetch): # train train_dataset = MyDataset(batch_num * batch_size) eval_dataset1 = MyDataset(5 * batch_size) + history = engine.fit(train_data=train_dataset, epochs=2, batch_size=batch_size, @@ -354,9 +355,74 @@ def train_non_builtin_data_vars(): ) # call DataLoader.reset() after catching EOFException +def get_cost(): + main_program = static.default_main_program() + startup_program = static.default_startup_program() + with static.program_guard(main_program, + startup_program), utils.unique_name.guard(): + input = static.data(name="input", + shape=[batch_size, image_size], + dtype='float32') + label = static.data(name="label", shape=[batch_size, 1], dtype='int64') + + loader = paddle.io.DataLoader.from_generator(feed_list=[input, label], + capacity=4 * batch_size, + iterable=False) + places = static.cuda_places() + loader.set_batch_generator(batch_generator_creator(), places=places) + + mlp = MLPLayer(hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.optimizer.Adam(learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + metric = paddle.metric.Accuracy() + predict = mlp(input) + loss_var = loss(predict, label) + + strategy = auto.Strategy() + strategy.auto_mode = "semi" + + engine = auto.Engine(loss=loss_var, + optimizer=optimizer, + metrics=metric, + strategy=strategy) + engine.cost() + + +def get_cost_by_spec(): + mlp = MLPLayer(hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.optimizer.Adam(learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + metric = paddle.metric.Accuracy() + + strategy = auto.Strategy() + strategy.auto_mode = "semi" + + engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy) + + input_spec = static.InputSpec([batch_size, image_size], 'float32', 'input') + label_spec = static.InputSpec([batch_size, 1], 'int64', 'label') + engine.cost(mode="eval", inputs_spec=[input_spec], labels_spec=[label_spec]) + + if __name__ == "__main__": train_high_level(fetch=True) train_high_level(fetch=False) train_low_level() train_builtin_data_vars() train_non_builtin_data_vars() + get_cost() + get_cost_by_spec()
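# NOTE (editor sketch): typical use of the new Engine.cost() API, mirroring
# get_cost_by_spec() above. When inputs_spec/labels_spec are given, the program
# is built and planned first; the call prints the per-rank memory table plus the
# global summary and returns (execution time in ms, max memory in bytes). In
# "full" auto mode it only prints a notice, since cost is estimated during the
# search. The import path for `auto` and the mlp/loss/optimizer/batch_size/
# image_size definitions are assumed to match the test file above.
import paddle.static as static
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.auto_mode = "semi"

engine = auto.Engine(mlp, loss, optimizer, strategy=strategy)

input_spec = static.InputSpec([batch_size, image_size], 'float32', 'input')
label_spec = static.InputSpec([batch_size, 1], 'int64', 'label')
exec_time_ms, max_memory_bytes = engine.cost(mode="eval",
                                             inputs_spec=[input_spec],
                                             labels_spec=[label_spec])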
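# NOTE (editor sketch): the gist of initialize_pg_in_full_mode() above, reduced
# to a single sender/receiver pair. Each rank listens on its endpoint port plus
# a fixed offset (500 in the patch); for every two-rank group the first rank
# connects to the second, the two sides exchange rank ids, and only then does
# each side call process_group.instantiate(), so communicator creation in full
# auto mode stays ordered instead of racing. All parameters below are
# hypothetical stand-ins; the real function also multiplexes many groups over
# one listening socket.
import socket

MAGIC_OFFSET = 500  # same fixed port offset used in the patch


def handshake_then_instantiate(my_rank, my_ip, my_port, peer_rank, peer_ip,
                               peer_port, is_sender, instantiate):
    if is_sender:
        # Sender: connect to the receiver's handshake port, announce our rank,
        # and wait for the receiver to echo its own rank back.
        with socket.create_connection((peer_ip, peer_port + MAGIC_OFFSET)) as s:
            s.send(str(my_rank).encode("utf-8"))
            assert int(s.recv(1024).decode("utf-8")) == peer_rank
    else:
        # Receiver: accept the sender's connection, read its rank, then reply
        # with our own rank to release it.
        server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server.bind((my_ip, my_port + MAGIC_OFFSET))
        server.listen(1)
        conn, _ = server.accept()
        assert int(conn.recv(1024).decode("utf-8")) == peer_rank
        conn.send(str(my_rank).encode("utf-8"))
        conn.close()
        server.close()
    # Both sides reach this point in a known order, so the collective
    # communicator for this group can now be created safely.
    instantiate()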