Unverified commit 90b31790, authored by zhaoyingli, committed by GitHub

[Cherry-Pick][AutoParallel] auto_parallel cherry-pick to release2.4 (#47145)

* [Auto Parallel] Make Engine class callable (#46416)

* [Auto Parallel] Improve the user-defined fetches and logging

* [Auto Parallel] Make Engine class callable

* [Auto Parallel] Update the data loading of tuner

* Print IPS in auto parallel Engine (#46554)

* [AutoParallel] fix dist_split (#46505)

* [AutoParallel] fix dist_split

* add unittest

* update cmakelist

* [AutoParallel] fix sharding (#46572)

* [AutoParallel] fix process_mesh (#46583)

* [AutoParallel] fix reshard when train with eval (#46605)

* [AutoParallel] fix reshard when train with eval

* fix mppp

* [AutoParallel] fix amp when predict (#46637)

* [Auto Parallel]Update comp cost and completion for gpt auto search (#46387)

* update comp cost and completion for gpt auto search

* add unittest

* [Auto Parallel] Fix bugs caused by the inconsistent outputs of Engine API (#46633)

* [Auto Parallel] Unify the logger and outputs of Engine API

* [Auto Parallel] Fix the bugs of to_static

* [Auto Parallel] Adjust the test_to_static.py

* [Auto Parallel] Improve the fine-grained APIs (#46552)

* [Auto Parallel] Support different dataloaders

* [Auto Parallel] Add num_shards config for dataset

* [Auto Parallel] Unify the logger and outputs of Engine API

* [Auto Parallel] Fix the bugs of to_static

* [Auto Parallel] Adjust the test_to_static.py

* [Auto Parallel] Add the prepare API and replace __call__ with run

* [Auto Parallel] Improve the private implementations of Engine

* [Auto Parallel] Set capacity of dataloader for opt tuning

* [Auto Parallel] [WIP] Change the fine-grained API

* [Auto Parallel] Improve APIs to support different user cases

* [Auto Parallel] Add removed config

* [Auto Parallel] Add imports

* [Auto Parallel] Fix bugs for to_static

* [Auto Parallel] Remove unnecessary imports

* bugfix (#46921)

* [Auto Parallel] Fix the bug for None labels (#46987)

* [AutoParallel] adapt for gpt-gen (#46771)

* for gpt-gen

* fix reshard

* adapt assign and shape op

* add dist_assign & unittest

* add conditional block unittest

* rename unittest

* [Auto Parallel] Fix the bug of completion (#47056)

* [Auto Parallel] Fix the bug for None labels

* [Auto Parallel] Fix the completion bug

* [AutoParallel] add callbacks (#47014)

* [AutoParallel] add callbacks

* fix unittest

* fix dist_context

* fix engine

* fix cmakelist

* fix unittest's returns

* fix cmakelist

* [Auto Parallel] Add cost interface (#47043)

* add cost interface

* update interface and add unittest

* update unittest

* update interface

* [Auto Parallel]Add parallel tuner (#46189)

* add parallel tuner

* add unittest

* fix unittest

* set timeout of unittest

* set unittest timeout

* fix auto_mode setting

* update unittest

* sync from develop and update unittest

* remove unused import

* update unittest

* update cmakelist

* add unittests
Co-authored-by: Yulong Ao <aoyulong@baidu.com>
Co-authored-by: Ruibiao Chen <chenruibiao@baidu.com>
Co-authored-by: caozhou <48191911+Caozhou1995@users.noreply.github.com>
Co-authored-by: JZ-LIANG <jianzhongliang10@gmail.com>
Parent commit: 23f2a4ea
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import paddle
from paddle.hapi.callbacks import ProgBarLogger, ModelCheckpoint, LRScheduler, CallbackList, Callback
from .interface import CollectionNames, get_collection
def config_callbacks(callbacks=None,
engine=None,
batch_size=None,
epochs=None,
steps=None,
log_freq=2,
verbose=2,
save_freq=1,
save_dir=None,
metrics=None,
acc_step=1,
mode='train'):
cbks = callbacks or []
cbks = cbks if isinstance(cbks, (list, tuple)) else [cbks]
if not any(isinstance(k, ProgBarLogger) for k in cbks) and verbose:
cbks = [ProgBarLoggerAuto(log_freq, verbose=verbose)] + cbks
if not any(isinstance(k, LRScheduler) for k in cbks):
cbks = [LRSchedulerAuto()] + cbks
if not any(isinstance(k, ModelCheckpoint) for k in cbks):
cbks = cbks + [ModelCheckpointAuto(save_freq, save_dir)]
if not any(isinstance(k, Profiler) for k in cbks) and verbose == 3:
cbks = cbks + [Profiler(timer_only=True)]
if not any(isinstance(k, History) for k in cbks):
cbks = cbks + [History()]
for i, k in enumerate(cbks):
if isinstance(k, ProgBarLogger):
cbks[i] = ProgBarLoggerAuto(k.log_freq, k.verbose)
if isinstance(k, LRScheduler):
cbks[i] = LRSchedulerAuto(k.by_step, k.by_epoch)
if isinstance(k, ModelCheckpoint):
cbks[i] = ModelCheckpointAuto(k.save_freq, k.save_dir)
cbk_list = CallbackList(cbks)
cbk_list.set_model(engine)
metrics = metrics or [] if mode != 'test' else []
params = {
'batch_size': batch_size,
'epochs': epochs,
'steps': steps,
'verbose': verbose,
'metrics': metrics,
'acc_step': acc_step,
}
cbk_list.set_params(params)
return cbk_list
class ProgBarLoggerAuto(ProgBarLogger):
def __init__(self, log_freq=1, verbose=2):
super(ProgBarLoggerAuto, self).__init__(log_freq, verbose)
def _is_print(self):
return True
def _updates(self, logs, mode):
values = []
metrics = getattr(self, '%s_metrics' % (mode))
progbar = getattr(self, '%s_progbar' % (mode))
steps = getattr(self, '%s_step' % (mode))
for k in metrics:
if k in logs:
values.append((k, logs[k]))
if 'lr' in logs:
values.append(('lr', logs['lr']))
fetches_logs = logs.get('fetches', {})
collect_logging = get_collection(CollectionNames.LOGGING)
for name, var in collect_logging:
k = name or var.name
if k in fetches_logs:
values.append((k, fetches_logs[k]))
out_logs = logs.get('outputs', {})
for k in out_logs:
values.append((k, out_logs[k]))
if self.verbose == 3 and hasattr(self, '_%s_timer' % (mode)):
timer = getattr(self, '_%s_timer' % (mode))
cnt = timer['count'] if timer['count'] > 0 else 1.0
samples = timer['samples'] if timer['samples'] > 0 else 1.0
values.append(
('avg_reader_cost', "%.5f sec" % (timer['data_time'] / cnt)))
values.append(
('avg_batch_cost', "%.5f sec" % (timer['batch_time'] / cnt)))
values.append(
('ips', "%.5f samples/sec" %
(samples / (timer['data_time'] + timer['batch_time']))))
timer['count'] = 0
timer['samples'] = 0
timer['data_time'] = 0.
timer['batch_time'] = 0.
progbar.update(steps, values)
def on_eval_batch_end(self, step, logs=None):
logs = logs or {}
self.eval_step += 1
samples = self.params['batch_size']
self.evaled_samples += samples
self._eval_timer['batch_time'] += (
time.time() - self._eval_timer['batch_data_end_time'])
self._eval_timer['count'] += 1
samples = self.params['batch_size']
self._eval_timer['samples'] += samples
if self._is_print() and self.eval_step % self.log_freq == 0:
if self.eval_steps is None or self.eval_step < self.eval_steps:
self._updates(logs, 'eval')
self._eval_timer['batch_start_time'] = time.time()
class LRSchedulerAuto(LRScheduler):
def __init__(self, by_step=True, by_epoch=False):
super(LRSchedulerAuto, self).__init__(by_step, by_epoch)
def on_epoch_begin(self, epoch=None, logs=None):
self.acc_step = self.params["acc_step"]
self.epoch = epoch
self.train_step = 0
def on_train_batch_end(self, step, logs=None):
self.train_step += 1
if self.by_step and self.train_step % self.acc_step == 0:
if self.model._optimizer and \
hasattr(self.model._optimizer, '_learning_rate') and \
isinstance(self.model._optimizer._learning_rate,
paddle.optimizer.lr.LRScheduler):
self.model._optimizer._learning_rate.step()
class History(Callback):
def __init__(self):
self.history = {}
def on_train_begin(self, logs=None):
self.epoch = []
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epoch.append(epoch)
for k, v in logs.items():
self.history.setdefault(k, []).append(v)
self.model.history = self
class Profiler(Callback):
def __init__(self, *args, **kwargs):
self.prof = paddle.profiler.Profiler(*args, **kwargs)
def on_epoch_begin(self, epoch=None, logs=None):
self.epoch = epoch
self.train_step = 0
self.batch_size = self.params["batch_size"]
self.steps = self.params['steps']
def on_train_begin(self, logs=None):
self.prof.start()
def on_train_batch_end(self, step, logs=None):
self.train_step += 1
self.prof.step(num_samples=self.batch_size)
print("step {}:{}".format(self.train_step,
self.prof.step_info(unit='samples')))
def on_train_end(self, logs=None):
self.prof.stop()
self.prof.summary()
class ModelCheckpointAuto(ModelCheckpoint):
def __init__(self, *args, **kwargs):
super(ModelCheckpointAuto, self).__init__(*args, **kwargs)
def _is_save(self):
return self.model and self.save_dir
def on_epoch_end(self, epoch, logs=None):
if self._is_save() and (self.epoch + 1) % self.save_freq == 0:
path = '{}/epoch{}'.format(self.save_dir, epoch)
print('save checkpoint at {}'.format(os.path.abspath(path)))
self.model.save(path)
def on_train_end(self, logs=None):
if self._is_save():
path = '{}/final'.format(self.save_dir)
print('save checkpoint at {}'.format(os.path.abspath(path)))
self.model.save(path)
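The sketch below shows, for orientation only, how config_callbacks is typically wired up; `engine` stands for an auto_parallel Engine instance that is not defined in this file, and the training loop that fires the remaining hooks is omitted.
# Illustrative usage, not part of the commit: the Engine drives the returned
# hapi CallbackList around its training loop.
cbk_list = config_callbacks(callbacks=None,
                            engine=engine,
                            batch_size=64,
                            epochs=10,
                            log_freq=10,
                            verbose=2,
                            save_dir="./ckpt",
                            mode='train')
cbk_list.on_begin('train')  # e.g. fired once before the first epoch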
@@ -19,7 +19,7 @@ import time
from paddle.fluid import core
from paddle.fluid import framework
-from .utils import print_program_with_dist_attr, is_gradient_clip_op
+from .utils import is_gradient_clip_op, __not_shape_var_type__
from .operators import find_compatible_distributed_operator_impls
from .dist_context import get_default_distributed_context, _node_id
from .dist_tensor import DistributedTensor
@@ -142,6 +142,7 @@ class Completer:
def __init__(self, dist_context):
assert dist_context is not None
self._dist_context = dist_context
+self._has_prepared = False
def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
changed = False
@@ -366,7 +367,14 @@ class Completer:
def _update_dims_mapping_for_special(self):
# Set the dims_mapping of a tensor to the dims_mapping inside the op which produces it
op_nodes = self._dist_context._serial_ordered_op_nodes
+# NOTE: this list may be changed if Paddle changes the existing rules.
+related_reader_ops = [
+"create_py_reader", "create_double_buffer_reader", "read"
+]
for op_node in op_nodes:
+if op_node.op() is not None \
+and op_node.op().type() in related_reader_ops:
+continue
op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node)
for tensor_node in op_node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
@@ -406,6 +414,7 @@ class Completer:
reach_fix_point = False
else:
reach_fix_point = True
+# NOTE: this will be removed after changing the reshard rule
self._update_dims_mapping_for_special()
def _update_process_mesh_by_nearest(self, op_node, nearest_op_node):
@@ -494,14 +503,14 @@ class Completer:
for tensor_node in node.inputs:
if tensor_node.is_var() and tensor_node.var(
) is not None:
-if tensor_node.var().type() == core.VarDesc.VarType.READER \
+if tensor_node.var().type() in __not_shape_var_type__ \
or len(tensor_node.var().shape()) != 1:
flag = False
break
for tensor_node in node.outputs:
if tensor_node.is_var() and tensor_node.var(
) is not None:
-if tensor_node.var().type() == core.VarDesc.VarType.READER \
+if tensor_node.var().type() in __not_shape_var_type__ \
or len(tensor_node.var().shape()) != 1:
flag = False
break
@@ -719,6 +728,8 @@ class Completer:
self._update_process_mesh_between_graphs()
def _prepare(self):
+if self._has_prepared:
+return
self._while_op_nodes = {}
self._array_nodes = {}
self._node_pairs_between_graphs = []
@@ -732,6 +743,8 @@ class Completer:
if self._array_nodes.get(array_var_name, None) is None:
self._array_nodes[array_var_name] = []
self._array_nodes[array_var_name].append(node)
+# Add the array input node
+self._array_nodes[array_var_name].append(node.inputs[0])
if node.op().type() == "write_to_array":
array_var_name = node.op().output("Out")[0]
if self._array_nodes.get(array_var_name, None) is None:
@@ -752,6 +765,7 @@ class Completer:
and after_node.var().name() == node.var().name():
self._node_pairs_between_graphs.append(
(after_node, node))
+self._has_prepared = True
def complete_forward_annotation(self, serial_main_program=None):
""" Complete annotation for the partial annotated serial_main_program.
@@ -899,6 +913,72 @@ class Completer:
else:
dist_op.dist_attr = original_op_dist_attr
def _complete_tensor_dist_attr_by_op(self, serial_main_program=None):
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context._serial_main_program = serial_main_program
self._dist_context.initialize()
self._prepare()
has_set_dist_attr = set()
all_nodes = self._dist_context.serial_ordered_nodes
for node in all_nodes:
if node.is_op():
if node.op().type() in ["while"]:
continue
dist_op = self._dist_context.get_dist_op_for_graph(node)
op_dist_attr = dist_op.dist_attr
for tensor_node in node.inputs:
if tensor_node.is_var() and tensor_node.var() is not None:
# Skip the non-leaf var node
if len(tensor_node.inputs) != 0:
continue
tensor_desc = tensor_node.var()
tensor_name = tensor_desc.name()
tensor = dist_op.get_serial_input(tensor_name)
# Use the first op to set the tensor dist attr
if tensor_name in has_set_dist_attr:
continue
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_name) if tensor.is_parameter else [
-1 for i in tensor_desc.shape()
]
has_set_dist_attr.add(tensor_name)
for tensor_node in node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
tensor_name = tensor_node.var().name()
if tensor_name in has_set_dist_attr:
continue
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_name)
has_set_dist_attr.add(tensor_name)
self._update_process_mesh_for_specials()
self._update_process_mesh_between_graphs()
self._update_dims_mapping_for_special()
self._update_dims_mapping_between_graphs()
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
# Do the validation check and amend some completion
self._dist_context.amend_dist_attr_for_program()
self._dist_context.validate_dist_attr_for_program()
def _complete_high_order_grad_annotation(self, serial_main_program=None):
"""
NOTE:
...
@@ -116,3 +116,10 @@ set_field_default_config(TUNING, "profile_start_step", 1)
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)
#########################################
# dataset configuration
#########################################
DATASET = "dataset"
set_field_default_config(DATASET, "enable", False)
set_field_default_config(DATASET, "num_shards", 1)
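The new section is consumed the same way as the existing ones (e.g. TUNING). Assuming the usual one-attribute-group-per-config-section mapping on the user-facing Strategy object, which this diff does not show, enabling it might look like the sketch below; the attribute names are inferred from the keys above and are not confirmed here.
# Hypothetical usage sketch; one attribute group per config section is assumed.
from paddle.distributed.fleet import auto
strategy = auto.Strategy()
strategy.dataset.enable = True      # assumed to mirror the "enable" key above
strategy.dataset.num_shards = 4     # assumed to mirror the "num_shards" key above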
@@ -167,6 +167,25 @@ class DropoutOpCost(CompOpCost):
return 0
@register_op_cost
class DropoutGradOpCost(CompOpCost):
OP_TYPE = "dropout_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(DropoutGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ElementwiseAddOpCost(CompOpCost):
OP_TYPE = "elementwise_add"
@@ -395,6 +414,42 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost):
return 0
@register_op_cost
class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost):
OP_TYPE = "fused_softmax_mask_upper_triangle"
def __init__(self, op=None, op_desc=None, cluster=None):
super(FusedSoftmaxMaskUpperTriangleOpCost,
self).__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost):
OP_TYPE = "fused_softmax_mask_upper_triangle_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(FusedSoftmaxMaskUpperTriangleGradOpCost,
self).__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class GatherOpCost(CompOpCost):
OP_TYPE = "gather"
...
@@ -45,6 +45,8 @@ class CostEstimator:
) # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}}
self._bubble_time_mapping = {}
self._ordered_ops = []
+self.max_memories = {}
+self.max_memory = None
@property
def loop_count(self):
@@ -123,7 +125,7 @@ class CostEstimator:
for i in range(loop_count):
for op in ops:
self._detailed_cost[op.desc.id()] = OrderedDict()
-# if in the while sub block, the detail of cost is the last cost
+# If in the while sub block, the detail of cost is the last cost
detail = self._detailed_cost[op.desc.id()]
detail["reshard_cost"] = OrderedDict()  #
detail["dist_op_cost"] = []
...@@ -147,15 +149,15 @@ class CostEstimator: ...@@ -147,15 +149,15 @@ class CostEstimator:
var = get_var_with_recursion(var_name, block, self.program) var = get_var_with_recursion(var_name, block, self.program)
reshard_cost = resharder.get_cost(op, var, self.cluster) reshard_cost = resharder.get_cost(op, var, self.cluster)
# calc reshard cost # Calc reshard cost
if reshard_cost is not None: if reshard_cost is not None:
detail["reshard_cost"][var_name] = reshard_cost detail["reshard_cost"][var_name] = reshard_cost
comm_costs = reshard_cost[0] comm_costs = reshard_cost[0]
local_comp_cost = reshard_cost[1] local_comp_cost = reshard_cost[1]
for comm_cost in comm_costs: for comm_cost in comm_costs:
# time is cumulative in global cost and local cost, but memory and flops just are cumulative in global cost. # Time is cumulative in global cost and local cost, but memory and flops just are cumulative in global cost.
# comm sync # Comm sync
for item in comm_cost: for item in comm_cost:
group_ranks, cost = item group_ranks, cost = item
max_time = None max_time = None
...@@ -183,7 +185,7 @@ class CostEstimator: ...@@ -183,7 +185,7 @@ class CostEstimator:
for comp_cost in local_comp_cost[rank]: for comp_cost in local_comp_cost[rank]:
self.local_cost(rank).time += comp_cost.time self.local_cost(rank).time += comp_cost.time
# calc dist op cost # Calc dist op cost
dist_op = dist_context.get_dist_op_for_program(op) dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes processes = op_dist_attr.process_mesh.processes
...@@ -201,7 +203,7 @@ class CostEstimator: ...@@ -201,7 +203,7 @@ class CostEstimator:
continue continue
for item in dist_op_cost: for item in dist_op_cost:
if isinstance(item, list): if isinstance(item, list):
# comm sync # Comm sync
for comm_op_cost in item: for comm_op_cost in item:
max_time = None max_time = None
cost_time = {} cost_time = {}
...@@ -222,9 +224,9 @@ class CostEstimator: ...@@ -222,9 +224,9 @@ class CostEstimator:
self._bubble_time_mapping[rank] += ( self._bubble_time_mapping[rank] += (
max_time - cost_time[rank]) max_time - cost_time[rank])
elif isinstance(item, dict): elif isinstance(item, dict):
# op just one # Op just one
for rank in processes: for rank in processes:
# dp+pp+mp # DP+PP+MP
if rank not in item: if rank not in item:
continue continue
self.local_cost(rank).time += item[rank].time self.local_cost(rank).time += item[rank].time
@@ -267,7 +269,7 @@ class CostEstimator:
return result
memories = {}
-max_memories = {}
+self.max_memories = {}
var_info = {
} # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]}
@@ -277,6 +279,10 @@ class CostEstimator:
self._ordered_ops.sort(key=lambda x: x[0])
for op_id, op in self._ordered_ops:
+if op.type in [
+"create_py_reader", "create_double_buffer_reader", "read"
+]:
+continue
dist_op = dist_context.get_dist_op_for_program(op)
process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names:
@@ -288,7 +294,7 @@ class CostEstimator:
input_dims_mapping)
if key not in var_info[var_name]:
var_info[var_name][key] = {}
-# it is even partition now
+# It is even partition now
if "memory" not in var_info[var_name][key]:
var = dist_op.get_serial_input(var_name)
global_sizes = var.shape
@@ -326,6 +332,10 @@ class CostEstimator:
has_used_vars = set()
for op_id, op in self._ordered_ops:
+if op.type in [
+"create_py_reader", "create_double_buffer_reader", "read"
+]:
+continue
can_free_memories = {}
can_free_vars = set()
dist_op = dist_context.get_dist_op_for_program(op)
@@ -337,14 +347,14 @@ class CostEstimator:
input_dims_mapping) input_dims_mapping)
has_used_var = var_name + key has_used_var = var_name + key
var = dist_op.get_serial_input(var_name) var = dist_op.get_serial_input(var_name)
# not used # Not used
if var_name + key not in has_used_vars: if var_name + key not in has_used_vars:
has_used_vars.add(has_used_var) has_used_vars.add(has_used_var)
for process in process_mesh.processes: for process in process_mesh.processes:
if process not in memories: if process not in memories:
memories[process] = 0 memories[process] = 0
memories[process] += var_info[var_name][key]["memory"] memories[process] += var_info[var_name][key]["memory"]
# used # Used
else: else:
if op_id == var_info[var_name][key]["position"][-1]: if op_id == var_info[var_name][key]["position"][-1]:
if has_used_var not in can_free_vars: if has_used_var not in can_free_vars:
...@@ -363,14 +373,14 @@ class CostEstimator: ...@@ -363,14 +373,14 @@ class CostEstimator:
output_dims_mapping) output_dims_mapping)
has_used_var = var_name + key has_used_var = var_name + key
var = dist_op.get_serial_output(var_name) var = dist_op.get_serial_output(var_name)
# not used # Not used
if var_name + key not in has_used_vars: if var_name + key not in has_used_vars:
has_used_vars.add(has_used_var) has_used_vars.add(has_used_var)
for process in process_mesh.processes: for process in process_mesh.processes:
if process not in memories: if process not in memories:
memories[process] = 0 memories[process] = 0
memories[process] += var_info[var_name][key]["memory"] memories[process] += var_info[var_name][key]["memory"]
# used # Used
else: else:
if op_id == var_info[var_name][key]["position"][-1]: if op_id == var_info[var_name][key]["position"][-1]:
if has_used_var not in can_free_vars: if has_used_var not in can_free_vars:
...@@ -382,21 +392,22 @@ class CostEstimator: ...@@ -382,21 +392,22 @@ class CostEstimator:
can_free_memories[process] += var_info[ can_free_memories[process] += var_info[
var_name][key]["memory"] var_name][key]["memory"]
-# calc peak memory
+# Calc peak memory
for process in memories:
-if process not in max_memories:
+if process not in self.max_memories:
-max_memories[process] = memories[process]
+self.max_memories[process] = memories[process]
else:
-if memories[process] > max_memories[process]:
+if memories[process] > self.max_memories[process]:
-max_memories[process] = memories[process]
+self.max_memories[process] = memories[process]
-# free memory
+# Free memory
for process in can_free_memories:
if process in memories:
memories[process] -= can_free_memories[process]
# Calculate the max memory in all ranks
-max_memory = max(max_memories.values())
+max_memory = max(self.max_memories.values())
+self.max_memory = max_memory
return max_memory
@@ -410,3 +421,143 @@ class CostEstimator:
self._estimate_core(dist_context, resharder, block)
return self.global_cost
def _print_tag(self, max_len, length):
tag = "+" + "-" * max_len
for i in range(length):
print(tag, end="")
if i == length - 1:
print("+")
def _print_vals(self, vals, max_len):
for idx, val in enumerate(vals):
s = "|" + str(val).center(max_len)
print(s, end="")
if idx == len(vals) - 1:
print("|")
def _pretty_print_memory_cost(self):
"""Print memory of every rank prettily."""
if not self.max_memories or not self.max_memory:
raise ValueError("Please calculate memory cost before print.")
# Padding automatically
max_len = 0
header = ["Rank", "Memory(MiB)"]
memories = [
int(item // 1e6) for item in list(self.max_memories.values())
]
for memory in (memories + header):
if len(str(memory)) > max_len:
max_len = len(str(memory))
max_len += 4 # for pretty print of center
# Print tag
self._print_tag(max_len, len(header))
# Print header
self._print_vals(header, max_len)
# Print tag
self._print_tag(max_len, len(header))
# Print rank and its memory
for i in range(len(self.max_memories)):
memory = memories[i]
vals = [i, memory]
self._print_vals(vals, max_len)
self._print_tag(max_len, len(header))
def _pretty_print_global(self):
"""Print global execution time and max memory prettily."""
if not self.max_memories or not self.max_memory:
raise ValueError("Please calculate cost before print.")
# Padding automatically
max_len = 0
header = ["Execution Time(ms)", "Max Memory(MiB)"]
vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)]
for memory in (vals + header):
if len(str(memory)) > max_len:
max_len = len(str(memory))
max_len += 4 # for pretty print of center
# Print tag
self._print_tag(max_len, len(header))
# Print header
self._print_vals(header, max_len)
# Print tag
self._print_tag(max_len, len(header))
# Print exec time and max memory
self._print_vals(vals, max_len)
# Print tag
self._print_tag(max_len, len(header))
def pretty_print_cost(self):
"""Print cost prettily."""
print("The global execution time and max memory are as follows:")
self._pretty_print_global()
print("The memory of every rank is as follows:")
self._pretty_print_memory_cost()
def get_cost_from_engine(engine, mode):
from ..utils import to_list
# Construct cost estimator by original main program
serial_main_prog = engine._serial_main_progs[mode].clone(
) if mode in engine._serial_main_progs else engine._orig_main_prog.clone()
serial_startup_prog = engine._serial_startup_progs[mode].clone(
) if mode in engine._serial_startup_progs else engine._orig_startup_prog.clone(
)
losses = to_list(
engine._loss) if (not isinstance(engine._loss, paddle.nn.Layer)
and not callable(engine._loss)) else engine._losses
if mode in engine._dist_contexts:
dist_context = engine._dist_contexts[mode]
completer = engine._planners[mode].completer
else:
from ..completion import Completer
from ..dist_context import DistributedContext
dist_context = DistributedContext(serial_main_prog, serial_startup_prog,
engine._optimizer, losses, {},
{"loss": losses}, engine._cluster,
engine._strategy)
completer = Completer(dist_context)
completer.complete_forward_annotation()
dist_context.block_state.parse_forward_blocks(
dist_context.serial_main_program)
if mode == "eval" or mode == "predict":
cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
elif mode == "train":
from ..parallelizer_v2 import Parallelizer
# Get serial main program with backward
serial_optimizer = engine._optimizer
parallelizer = Parallelizer(mode, completer, dist_context)
# Generate backward
loss_name = dist_context.serial_loss.name
serial_loss = serial_main_prog.global_block()._var_recursive(loss_name)
params_grads = parallelizer._generate_backward(serial_main_prog,
serial_startup_prog,
serial_loss)
# Generate optimizer
optimizer_ops = parallelizer._generate_optimizer(
serial_main_prog, serial_startup_prog, serial_optimizer,
params_grads)
cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
# Estimate global_cost and max memory
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(dist_context)
# Print the cost
cost_estimator.pretty_print_cost()
return global_cost, max_memory
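A minimal sketch of how the helper above can be invoked, assuming `engine` is an auto_parallel Engine whose loss, optimizer and cluster are already configured; pretty_print_cost() is also called inside the helper, so the cost tables are printed as a side effect.
# Illustrative only: estimate execution time and peak memory for the "train" program.
global_cost, max_memory = get_cost_from_engine(engine, mode="train")
print("estimated execution time (ms):", round(global_cost.time, 3))
print("estimated peak memory per rank (bytes):", max_memory)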
@@ -77,7 +77,6 @@ class DistributedContext:
self._serial_optimizer = None
self._serial_feed_vars = {}
self._serial_fetch_vars = {}
-self._lr_optimizer = None  # record the optimzier holding lr_scheduler
# Data members related to the program
self._dist_tensors_for_program = {}
@@ -268,12 +267,24 @@
def _restore_serial_fetch_vars(self):
for key, var_list in self._original_serial_fetch_vars.items():
new_var_list = []
-for var in var_list:
-block_idx = var.block.idx
-var_name = var.name
-var = self._serial_main_program.blocks[
-block_idx]._var_recursive(var_name)
-new_var_list.append(var)
+# metrics is a list of list
+if key == "metrics":
+for inner_var_list in var_list:
+new_inner_var_list = []
+for var in inner_var_list:
+block_idx = var.block.idx
+var_name = var.name
+var = self._serial_main_program.blocks[
+block_idx]._var_recursive(var_name)
+new_inner_var_list.append(var)
+new_var_list.append(new_inner_var_list)
+else:
+for var in var_list:
+block_idx = var.block.idx
+var_name = var.name
+var = self._serial_main_program.blocks[
+block_idx]._var_recursive(var_name)
+new_var_list.append(var)
self._serial_fetch_vars[key] = new_var_list
def _restore_serial_info(self, mode="to_backup"):
@@ -861,7 +872,7 @@
"_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \
"_serial_ordered_op_nodes", "_original_serial_loss", \
"_original_serial_feed_vars", "_original_serial_fetch_vars", \
-"_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_lr_optimizer", \
+"_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_serial_optimizer", \
"_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", \
"_pass_context"]:
setattr(result, k, v)
...
...@@ -14,44 +14,14 @@ ...@@ -14,44 +14,14 @@
import abc import abc
import numpy as np import numpy as np
from functools import wraps
import paddle import paddle
from .utils import to_list from paddle.io import BatchSampler, IterableDataset
from paddle.fluid.layers.utils import flatten from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler
from paddle.io import DataLoader, BatchSampler, IterableDataset
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn
class DistributedDataLoader(metaclass=abc.ABCMeta): class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
def __init__(self, dataset, batch_size=1, epochs=1, drop_last=False):
if isinstance(dataset, IterableDataset):
self.dataset_kind = _DatasetKind.ITER
else:
self.dataset_kind = _DatasetKind.MAP
self.dataset = dataset
self.epochs = epochs
self.drop_lost = drop_last
if batch_size is None:
self.batch_size = None
self.batch_sampler = None
else:
self.batch_size = batch_size
if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler(
dataset, batch_size)
else:
self.batch_sampler = BatchSampler(dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)
self.auto_collate_batch = self.batch_sampler is not None
self.sampler_iter = iter(self.index_sampler)
@abc.abstractmethod @abc.abstractmethod
def __iter__(self): def __iter__(self):
...@@ -72,40 +42,72 @@ class DistributedDataLoader(metaclass=abc.ABCMeta): ...@@ -72,40 +42,72 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
return _InfiniteIterableSampler(self.dataset, 1) return _InfiniteIterableSampler(self.dataset, 1)
class NonIterableGeneratorLoader(DistributedDataLoader): class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
def __init__(self, def __init__(self,
dataset, dataset,
feed_list, feed_list=None,
places, capacity=None,
use_double_buffer=True,
iterable=True,
return_list=False,
use_multiprocess=False,
drop_last=True,
places=None,
batch_size=1, batch_size=1,
epochs=1, epochs=1,
steps_per_epoch=None, steps_per_epoch=None,
collate_fn=None, collate_fn=None,
split_data=True,
data_parallel_world_size=[], data_parallel_world_size=[],
data_parallel_rank=[], data_parallel_rank=[]):
drop_last=False, self.dataset = dataset
split_data=True):
self.feed_list = feed_list self.feed_list = feed_list
self.capacity = capacity
self.use_double_buffer = use_double_buffer
self.iterable = iterable
self.return_list = return_list
self.use_multiprocess = use_multiprocess
self.drop_last = drop_last
self.places = places self.places = places
self.batch_size = batch_size
self.epochs = epochs
self.steps_per_epoch = steps_per_epoch self.steps_per_epoch = steps_per_epoch
self.collate_fn = collate_fn
self.split_data = split_data
assert len(data_parallel_world_size) == len(feed_list) assert len(data_parallel_world_size) == len(feed_list)
assert len(data_parallel_rank) == len(feed_list) assert len(data_parallel_rank) == len(feed_list)
self.dp_world_sizes = data_parallel_world_size self.dp_world_sizes = data_parallel_world_size
self.dp_ranks = data_parallel_rank self.dp_ranks = data_parallel_rank
self.split_data = split_data
super(NonIterableGeneratorLoader, if isinstance(dataset, IterableDataset):
self).__init__(dataset, batch_size, epochs, drop_last) self.dataset_kind = _DatasetKind.ITER
else:
self.dataset_kind = _DatasetKind.MAP
if self.batch_size is None:
self.batch_sampler = None
else:
if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler(
dataset, batch_size)
else:
self.batch_sampler = BatchSampler(dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)
self.auto_collate_batch = self.batch_sampler is not None
self.sampler_iter = iter(self.index_sampler)
if self.auto_collate_batch: if self.auto_collate_batch:
self.collate_fn = collate_fn or default_collate_fn self.collate_fn = collate_fn or default_collate_fn
else: else:
self.collate_fn = collate_fn or default_convert_fn self.collate_fn = collate_fn or default_convert_fn
self.dataset_fetcher = _DatasetKind.create_fetcher( self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.auto_collate_batch, self.dataset_kind, self.dataset, self.auto_collate_batch,
self.collate_fn, self.drop_lost) self.collate_fn, self.drop_last)
self._steps = self._infer_steps() self._steps = self._infer_steps()
self._inner_dataloader = self._create_inner_dataloader() self._inner_dataloader = self._create_inner_dataloader()
...@@ -118,8 +120,10 @@ class NonIterableGeneratorLoader(DistributedDataLoader): ...@@ -118,8 +120,10 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
def __next__(self): def __next__(self):
if not self._steps: if not self._steps:
self._cur_step += 1 self._cur_step += 1
return None
elif self._cur_step < self._steps: elif self._cur_step < self._steps:
self._cur_step += 1 self._cur_step += 1
return None
else: else:
self._inner_dataloader.reset() self._inner_dataloader.reset()
self.sampler_iter = iter(self.index_sampler) self.sampler_iter = iter(self.index_sampler)
...@@ -141,6 +145,16 @@ class NonIterableGeneratorLoader(DistributedDataLoader): ...@@ -141,6 +145,16 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
) )
return steps_per_epoch return steps_per_epoch
@property
def index_sampler(self):
if self.auto_collate_batch:
return self.batch_sampler
else:
if self.dataset_kind == _DatasetKind.MAP:
return list(range(len(self.dataset)))
else:
return _InfiniteIterableSampler(self.dataset, 1)
def _create_inner_dataloader(self): def _create_inner_dataloader(self):
def data_generator(): def data_generator():
...@@ -153,7 +167,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader): ...@@ -153,7 +167,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
self.dataset_fetcher = _DatasetKind.create_fetcher( self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.dataset_kind, self.dataset,
self.auto_collate_batch, self.collate_fn, self.auto_collate_batch, self.collate_fn,
self.drop_lost) self.drop_last)
break break
partial_data = [] partial_data = []
...@@ -173,7 +187,83 @@ class NonIterableGeneratorLoader(DistributedDataLoader): ...@@ -173,7 +187,83 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
yield partial_data yield partial_data
dataloader = paddle.fluid.io.DataLoader.from_generator( dataloader = paddle.fluid.io.DataLoader.from_generator(
feed_list=self.feed_list, capacity=70, iterable=False) feed_list=self.feed_list,
capacity=self.capacity,
use_double_buffer=self.use_double_buffer,
# iterable=self.iterable,
iterable=False,
return_list=self.return_list,
use_multiprocess=self.use_multiprocess,
drop_last=self.drop_last)
dataloader.set_batch_generator(data_generator, self.places) dataloader.set_batch_generator(data_generator, self.places)
return dataloader return dataloader
class DistributedDataLoader(DistributedDataLoaderBase):
def __init__(self,
dataset,
feed_list=None,
places=None,
return_list=True,
batch_size=1,
shuffle=False,
drop_last=False,
collate_fn=None,
num_workers=0,
use_buffer_reader=True,
use_shared_memory=True,
timeout=0,
worker_init_fn=None,
epochs=1,
steps_per_epoch=None,
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[]):
self.dataset = dataset
self.feed_list = feed_list
self.return_list = return_list
self.places = places
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.collate_fn = collate_fn
self.num_workers = num_workers
self.use_buffer_reader = use_buffer_reader
self.use_shared_memory = use_shared_memory
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.epochs = epochs
self.steps_per_epoch = steps_per_epoch
self.dp_world_sizes = data_parallel_world_size
self.dp_ranks = data_parallel_rank
self.split_data = split_data
# TODO: rank info
self.batch_sampler = DistributedBatchSampler(
self.dataset, self.batch_size, self.dp_world_sizes[0],
self.dp_ranks[0], self.shuffle, self.drop_last)
self._inner_dataloader = self._create_inner_dataloader()
def __iter__(self):
return self
def __next__(self):
return next(self.data)
def _create_inner_dataloader(self):
dataloader = paddle.fluid.io.DataLoader(
self.dataset,
feed_list=self.feed_list,
places=self.places,
return_list=self.return_list,
batch_sampler=self.batch_sampler,
collate_fn=self.collate_fn,
num_workers=self.num_workers,
use_buffer_reader=self.use_buffer_reader,
use_shared_memory=self.use_shared_memory,
timeout=self.timeout,
worker_init_fn=self.worker_init_fn)
self.data = (x for x in dataloader)
return dataloader
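A construction sketch for the new dataset-style loader, assuming `train_dataset` is a map-style paddle.io.Dataset and a single data-parallel group of size 2 with local rank 0; feed_list and places are left at their defaults here.
# Illustrative only: the wrapper shards the dataset across data-parallel ranks
# with DistributedBatchSampler and drives a paddle.fluid.io.DataLoader underneath.
loader = DistributedDataLoader(train_dataset,
                               batch_size=8,
                               shuffle=True,
                               drop_last=True,
                               data_parallel_world_size=[2],
                               data_parallel_rank=[0])
for batch in loader:
    pass  # each rank iterates over its own shard of the dataset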
@@ -139,7 +139,7 @@ class ProxyLayer(Layer):
"""
outs = []
for metric in self.metrics:
-outs.extend(metric.compute(*inputs))
+outs.append(to_list(metric.compute(*inputs)))
return outs
...
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from collections import defaultdict
import paddle
from paddle.fluid import core
from .process_mesh import ProcessMesh
@@ -196,15 +198,36 @@ def recompute(op):
return RecomputeOperator(op)
-_g_fetched_tensors = {}
+_g_collections = {}
+class CollectionNames(object):
+FETCHES = "fetches"
+LOGGING = "logging"
+def get_collection(name):
+collection = _g_collections.get(name, None)
+if collection is None:
+collection = []
+_g_collections[name] = collection
+return _g_collections[name]
-def fetch(tensor, name=None):
-if name is None:
-_g_fetched_tensors[tensor.name] = tensor
-else:
-_g_fetched_tensors[name] = tensor
-def _get_fetches():
-return _g_fetched_tensors
+def add_to_collection(collection_name, value, name=None):
+if collection_name not in _g_collections:
+_g_collections[collection_name] = []
+if name is not None:
+for _, v in _g_collections[collection_name]:
+if v == value: return
+_g_collections[collection_name].append((name, value))
+else:
+for _, v in _g_collections[collection_name]:
+if v == value: return
+_g_collections[collection_name].append((None, value))
+def fetch(tensor, name=None, logging=False):
+add_to_collection(CollectionNames.FETCHES, tensor, name)
+if logging:
+add_to_collection(CollectionNames.LOGGING, tensor, name)
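A usage sketch of the reworked collection API, assuming `loss` is any Variable created while building the model; registering it with logging=True makes ProgBarLoggerAuto print it at every log step.
# Illustrative only: put the tensor in both the FETCHES and LOGGING collections.
fetch(loss, name="loss_value", logging=True)
for name, var in get_collection(CollectionNames.LOGGING):
    print(name or var.name)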
@@ -33,3 +33,5 @@ from . import dist_slice
from . import dist_fused_feedforward
from . import dist_fused_attention
from . import dist_reduce_sum_p
+from . import dist_shape
+from . import dist_assign
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .dist_default import DistributedDefaultImpl0
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedAssign(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedAssign, self).__init__(op_type)
register_distributed_operator_impl_container(DistributedAssign("assign"))
class DistributedAssignImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedAssignImpl, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
return True
def is_output_compatible(self, dist_op):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if x_dims_mapping != out_dims_mapping:
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
for i in range(len(x_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl("assign", DistributedAssignImpl("assign"))
@@ -1308,6 +1308,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
# col parallel: matmul + allreduce
+if backward_op.attr("trans_y"):
+Y_var_dim_mapping.reverse()
assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1]
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .dist_default import DistributedDefaultImpl0
from ..utils import is_dim_shard
class DistributedShape(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedShape, self).__init__(op_type)
register_distributed_operator_impl_container(DistributedShape("shape"))
class DistributedShapeImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedShapeImpl, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
return True
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
assert len(out_dims_mapping) == 1
if is_dim_shard(out_dims_mapping[0]):
return False
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
return True
def update_dims_mapping(self, dist_op):
return False
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl("shape", DistributedShapeImpl("shape"))
@@ -101,8 +101,12 @@ class DistributedSplitImpl(DistributedOperatorImpl):
return changed
def is_auto_compatible(self, dist_op):
-raise NotImplementedError(
-"Auto Search is not supported by dist split yet.")
+if (not self.is_input_compatible(dist_op)) or \
+(not self.is_output_compatible(dist_op)) or \
+(not self.is_compatible(dist_op)):
+return False
+return True
@staticmethod
def forward(ctx, *args, **kwargs):
...
@@ -23,14 +23,12 @@ import logging
import pickle
import time
import paddle
-from paddle.fluid.backward import append_backward
-from paddle.distributed.utils.log_utils import get_logger
-from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core
from paddle.fluid import program_guard
+from paddle.fluid.backward import append_backward
+from paddle.distributed.utils.log_utils import get_logger
from paddle.distributed.passes import new_pass, PassContext
from .dist_context import DistributedContext
-from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context
from .completion import Completer
from .partitioner import Partitioner
@@ -40,9 +38,7 @@ from .process_group import get_world_process_group
from .process_group import _g_process_group_map, ProcessGroup
from .utils import make_data_unshard
from .utils import set_grad_var_shape
-from .utils import print_program_with_dist_attr
from .utils import SerialProgramInfo
-from .utils import get_logger
from .reshard import Resharder
from .cluster import Cluster
from .mapper import mapping
@@ -148,7 +144,7 @@ class AutoParallelizer:
with program_guard(main_program, startup_program):
optimize_ops = optimizer.apply_gradients(params_grads)
-self._dist_context._lr_optimizer = optimizer
+self._dist_context._serial_optimizer = optimizer
# update completion
self._completer = Completer(self._dist_context)
self._completer.complete_update_annotation(main_program)
...
@@ -15,24 +15,17 @@
import copy
import time
import logging
-from collections import defaultdict
-import paddle
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
-from paddle.fluid.framework import _non_static_mode, unique_name
+from paddle.fluid.framework import unique_name
from paddle.distributed.passes import new_pass
from .reshard import Resharder
from .partitioner import Partitioner
-from .dist_op import DistributedOperator
-from .dist_saver import DistributedSaver
-from .dist_loader import NonIterableGeneratorLoader
-from .utils import make_data_unshard, set_grad_var_shape
-from .utils import print_program_with_dist_attr, to_list
-from .utils import get_logger
-from .process_group import get_all_process_groups, get_world_process_group
-from .dist_context import DistributedContext, get_default_distributed_context
+from .utils import set_grad_var_shape
+from .process_group import get_world_process_group
+from ..utils.log_utils import get_logger
class Parallelizer:
@@ -69,7 +62,7 @@ class Parallelizer:
serial_main_program, serial_startup_program, params_grads = self._apply_pre_optimization(
serial_main_program, serial_startup_program, serial_loss,
serial_optimizer, params_grads)
-self._logger.info(
+self._logger.debug(
"within parallel apply_pre_optimization time: {}, mode {}".
format(time.time() - time0, self._mode))
# Do logical partition
@@ -77,14 +70,14 @@ class Parallelizer:
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
serial_main_program, serial_startup_program, params_grads)
-self._logger.info(
+self._logger.debug(
"within parallel partitioner time: {}, mode {}".format(
time.time() - time0, self._mode))
# Generate optimizer
time0 = time.time()
self._generate_optimizer(dist_main_prog, dist_startup_prog,
serial_optimizer, dist_params_grads)
-self._logger.info(
+self._logger.debug(
"within parallel optimizer time: {}, mode {}".format(
time.time() - time0, self._mode))
# Do reshard process
@@ -93,14 +86,14 @@ class Parallelizer:
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, dist_params_grads)
resharder.reshard()
-self._logger.info(
+self._logger.debug(
"within parallel reshard time: {}, mode {}".format(
time.time() - time0, self._mode))
# Apply post optimization passes
time0 = time.time()
self._apply_post_optimization(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
-self._logger.info(
+self._logger.debug(
"within parallel apply_post_optimization time: {}, mode {}".
format(time.time() - time0, self._mode))
else:
@@ -109,7 +102,7 @@ class Parallelizer:
self._apply_pre_optimization(serial_main_program,
serial_startup_program, None, None,
None)
-self._logger.info(
+self._logger.debug(
"within parallel apply_pre_optimization time: {}, mode {}".
format(time.time() - time0, self._mode))
# Do logical partition
@@ -118,14 +111,14 @@ class Parallelizer:
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
serial_main_program, serial_startup_program, [])
# Do reshard process
self._logger.info( self._logger.debug(
"within parallel partitioner time: {}, mode {}".format( "within parallel partitioner time: {}, mode {}".format(
time.time() - time0, self._mode)) time.time() - time0, self._mode))
time0 = time.time() time0 = time.time()
resharder = Resharder(dist_main_prog, dist_startup_prog, rank, resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, [], 1) self._dist_context, [], 1)
resharder.reshard() resharder.reshard()
self._logger.info( self._logger.debug(
"within parallel reshard time: {}, mode {}".format( "within parallel reshard time: {}, mode {}".format(
time.time() - time0, self._mode)) time.time() - time0, self._mode))
# Clone program for test # Clone program for test
...@@ -150,7 +143,7 @@ class Parallelizer: ...@@ -150,7 +143,7 @@ class Parallelizer:
# NOTE: `apply_gradients` will add an Accumulator for a parameter only once,
# but the optimizer will be called repeatedly on re-launch, so it needs to be copied.
optimizer = copy.deepcopy(optimizer) optimizer = copy.deepcopy(optimizer)
self._dist_context._lr_optimizer = optimizer self._dist_context._serial_optimizer = optimizer
with program_guard(main_program, startup_program): with program_guard(main_program, startup_program):
with unique_name.guard("opt_"): with unique_name.guard("opt_"):
optimizer_ops = optimizer.apply_gradients(params_grads) optimizer_ops = optimizer.apply_gradients(params_grads)
...@@ -177,9 +170,7 @@ class Parallelizer: ...@@ -177,9 +170,7 @@ class Parallelizer:
startup_program = self._pass_context.get_attr("startup_program") startup_program = self._pass_context.get_attr("startup_program")
params_grads = self._pass_context.get_attr("params_grads") params_grads = self._pass_context.get_attr("params_grads")
# apply amp pass # apply amp pass on train/eval/predict
# FIXME we disenable amp for eval since it has a little bug with
# eval program and which will be fixed in future
if self._strategy.amp.enable: if self._strategy.amp.enable:
config = copy.deepcopy(self._strategy.amp.to_dict()) config = copy.deepcopy(self._strategy.amp.to_dict())
config["dist_context"] = self._dist_context config["dist_context"] = self._dist_context
......
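# The per-stage timing messages in the Parallelizer above were demoted from info to
# debug. A minimal sketch of how they could be surfaced again, assuming the
# parallelizer's logger is registered under the name "auto_parallel" (the name used
# by get_logger elsewhere in this change); adjust the name if it differs.
import logging

def enable_auto_parallel_debug_logs(name="auto_parallel"):
    # Lower the threshold so the logger.debug(...) timing calls are emitted.
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
        logger.addHandler(handler)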
...@@ -28,7 +28,7 @@ from .utils import set_dist_op_desc_original_id ...@@ -28,7 +28,7 @@ from .utils import set_dist_op_desc_original_id
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op, is_optimize_op from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op, is_optimize_op
from .operators.common import BACKWARD_ONLY_DIST_OPS from .operators.common import BACKWARD_ONLY_DIST_OPS
__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"] __varname_not_in_block__ = ["lod_tensor_blocking_queue"]
__not_shape_var_type__ = [ __not_shape_var_type__ = [
core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES
] ]
...@@ -243,7 +243,9 @@ class Partitioner(object): ...@@ -243,7 +243,9 @@ class Partitioner(object):
target_block, serial_input_varname, target_block, serial_input_varname,
new_varname) new_varname)
else: else:
assert serial_input_varname in __varname_not_in_block__ for varname_not_in_block in __varname_not_in_block__:
assert varname_not_in_block in serial_input_varname, \
"{} is not found".format(serial_input_varname)
self._serial2dist_varname_mapping[ self._serial2dist_varname_mapping[
serial_input_varname] = new_varname serial_input_varname] = new_varname
......
...@@ -14,9 +14,7 @@ ...@@ -14,9 +14,7 @@
from .completion import Completer from .completion import Completer
from .dist_context import get_default_distributed_context from .dist_context import get_default_distributed_context
from .utils import print_program_with_dist_attr from .tuner.parallel_tuner import ParallelTuner
# from .tuner.parallel_tuner import ParallelTuner
class Planner: class Planner:
...@@ -39,20 +37,20 @@ class Planner: ...@@ -39,20 +37,20 @@ class Planner:
self._completer = Completer(self._dist_context) self._completer = Completer(self._dist_context)
self._strategy = dist_context.strategy self._strategy = dist_context.strategy
# if self._strategy.auto_search: # set parallel tuner for auto search
# self._parallel_tuner = ParallelTuner( if self._strategy.auto_mode == "full":
# self._dist_context, mode=self._mode) self._parallel_tuner = ParallelTuner(self._dist_context,
mode=self._mode)
@property @property
def completer(self): def completer(self):
return self._completer return self._completer
def plan(self): def plan(self):
self._completer.complete_forward_annotation() if self._strategy.auto_mode == "full":
# if self._strategy.auto_search: self._parallel_tuner.tune()
# self._parallel_tuner.tune() else:
# else: self._completer.complete_forward_annotation()
# self._completer.complete_forward_annotation()
# parse forward sub block # parse forward sub block
self._dist_context.block_state.parse_forward_blocks( self._dist_context.block_state.parse_forward_blocks(
self._dist_context.serial_main_program) self._dist_context.serial_main_program)
...@@ -168,7 +168,10 @@ class ProcessMesh(object): ...@@ -168,7 +168,10 @@ class ProcessMesh(object):
else: else:
new_mesh = self._mesh[index] new_mesh = self._mesh[index]
new_dim_names = self._dim_names[1:] new_dim_names = self._dim_names[1:]
return ProcessMesh(new_mesh, new_dim_names) if new_mesh.shape:
return ProcessMesh(new_mesh, new_dim_names)
else:
return ProcessMesh([new_mesh])
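# A detached, pure-Python sketch of the fallback above, assuming self._mesh behaves
# like a numpy array: when indexing collapses the sub-mesh to a scalar process id,
# the id is re-wrapped in a list so the ProcessMesh constructor still sees a valid
# non-empty shape.
import numpy as np

def _slice_mesh(mesh, index):
    sub = np.array(mesh)[index]
    # sub.shape == () means the slice collapsed to a single process id.
    return sub.tolist() if sub.shape else [sub.tolist()]

# _slice_mesh([[0, 1], [2, 3]], 0) -> [0, 1]
# _slice_mesh([0, 1], 1)           -> [1]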
def __enter__(self): def __enter__(self):
set_current_process_mesh(self) set_current_process_mesh(self)
......
...@@ -37,6 +37,7 @@ _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling'] ...@@ -37,6 +37,7 @@ _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
_g_gradient_clip_ops = [ _g_gradient_clip_ops = [
"sum", "sqrt", "fill_constant", "elementwise_max", "elementwise_div" "sum", "sqrt", "fill_constant", "elementwise_max", "elementwise_div"
] ]
_g_subblock_ops = ["while", "conditional_block"]
def get_var_with_recursion(var_name, block, program): def get_var_with_recursion(var_name, block, program):
...@@ -45,10 +46,11 @@ def get_var_with_recursion(var_name, block, program): ...@@ -45,10 +46,11 @@ def get_var_with_recursion(var_name, block, program):
if var_name in block.vars: if var_name in block.vars:
var = block.vars[var_name] var = block.vars[var_name]
else: else:
parent_block = program.blocks[block.parent_idx] var = block._var_recursive(var_name)
if var_name in parent_block.vars: # parent_block = program.blocks[block.parent_idx]
var = parent_block.vars[var_name] # if var_name in parent_block.vars:
assert var is not None # var = parent_block.vars[var_name]
assert var is not None, "{} is not found".format(var_name)
return var return var
...@@ -1077,7 +1079,9 @@ class Resharder: ...@@ -1077,7 +1079,9 @@ class Resharder:
new_Out = [] new_Out = []
for var_name in while_op.output("Out"): for var_name in while_op.output("Out"):
for output_name in sub_block_op_outputs[::-1]: for output_name in sub_block_op_outputs[::-1]:
if output_name.find(var_name) != -1: if output_name.find(var_name) != -1 and (
len(var_name) == len(output_name)
or "@RESHARD" in output_name):
if output_name not in new_Out: if output_name not in new_Out:
new_Out.append(output_name) new_Out.append(output_name)
assert new_Out assert new_Out
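# A standalone sketch of the tightened matching rule above: an output of the while
# sub-block is kept for "Out" only if it is exactly the loop variable name or one of
# its resharded aliases (the "@RESHARD" suffix follows this module's naming).
def _matches_while_out(var_name, output_name):
    return output_name.find(var_name) != -1 and (
        len(var_name) == len(output_name) or "@RESHARD" in output_name)

# _matches_while_out("hidden_0", "hidden_0")            -> True
# _matches_while_out("hidden_0", "hidden_0@RESHARD_0")  -> True
# _matches_while_out("hidden_0", "hidden_01")           -> False (the old substring check accepted this)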
...@@ -1106,13 +1110,15 @@ class Resharder: ...@@ -1106,13 +1110,15 @@ class Resharder:
return False return False
def is_condition_replicative(self, op): def is_condition_replicative(self, op):
assert op.type == "while"
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id] sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
dist_op = self.dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr if op.type == "while":
input_cond = op.input("Condition")
elif op.type == "conditional_block":
input_cond = op.input("Cond")
# the dims mapping of condition tensor should be replicative # the dims mapping of condition tensor should be replicative
for var_name in op.input("Condition"): for var_name in input_cond:
var = get_var_with_recursion(var_name, sub_block, var = get_var_with_recursion(var_name, sub_block,
self.auto_parallel_main_prog) self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(var) dist_tensor = self.dist_context.get_dist_tensor_for_program(var)
...@@ -1662,9 +1668,9 @@ class Resharder: ...@@ -1662,9 +1668,9 @@ class Resharder:
op.desc.set_input(proto.inputs[0].name, op.desc.set_input(proto.inputs[0].name,
op.input("X") + while_op_X_append) op.input("X") + while_op_X_append)
def _get_while_op_input_attrs(self, op, var_name): def _get_subblock_input_attrs(self, op, var_name):
# NOTE: Multiple/nested while loops are not supported
assert op.type == "while" assert op.type in _g_subblock_ops
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id] sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops ops = sub_block.ops
input_attrs = [] input_attrs = []
...@@ -1715,8 +1721,8 @@ class Resharder: ...@@ -1715,8 +1721,8 @@ class Resharder:
def get_op_input_attrs(self, op, var_name): def get_op_input_attrs(self, op, var_name):
op_input_attrs = [] op_input_attrs = []
if op.type == "while": if op.type in _g_subblock_ops:
op_input_attrs = self._get_while_op_input_attrs(op, var_name) op_input_attrs = self._get_subblock_input_attrs(op, var_name)
else: else:
op_input_attrs = self._get_common_op_input_attrs(op, var_name) op_input_attrs = self._get_common_op_input_attrs(op, var_name)
...@@ -1738,8 +1744,18 @@ class Resharder: ...@@ -1738,8 +1744,18 @@ class Resharder:
if len(set(process_mesh.processes)) == len(processes): if len(set(process_mesh.processes)) == len(processes):
global_process_mesh_idx = idx global_process_mesh_idx = idx
break break
if global_process_mesh_idx is not None: if global_process_mesh_idx is not None:
self.dist_context.process_meshes.pop(idx) is_removed = False
global_mesh = self.dist_context.process_meshes[idx]
for i, mesh in enumerate(self.dist_context.process_meshes):
if i == idx:
continue
if set(mesh.processes) < set(global_mesh.processes):
is_removed = True
if is_removed:
self.dist_context.process_meshes.pop(idx)
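# A simplified sketch of the pruning rule above: the process mesh that covers all
# ranks is removed only when at least one other mesh uses a strict subset of its
# processes, i.e. the program is genuinely split across smaller meshes.
def _should_remove_global_mesh(process_lists, global_idx):
    global_set = set(process_lists[global_idx])
    return any(set(p) < global_set
               for i, p in enumerate(process_lists) if i != global_idx)

# _should_remove_global_mesh([[0, 1, 2, 3], [0, 1], [2, 3]], 0) -> True
# _should_remove_global_mesh([[0, 1], [0, 1]], 0)               -> False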
def _change_subblock_op_input_and_output(self, block_idx, block): def _change_subblock_op_input_and_output(self, block_idx, block):
if "var_reshard_mapping" in Resharder.while_block_info[block_idx]: if "var_reshard_mapping" in Resharder.while_block_info[block_idx]:
...@@ -1810,7 +1826,7 @@ class Resharder: ...@@ -1810,7 +1826,7 @@ class Resharder:
if dist_op is not None: if dist_op is not None:
op_input_dist_attrs = [ op_input_dist_attrs = [
] # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)] ] # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)]
if op.type == "while": if op.type in _g_subblock_ops:
if not self.is_condition_replicative(op): if not self.is_condition_replicative(op):
raise ValueError(
"Please check the condition: its dims_mapping must be replicative."
...@@ -1824,6 +1840,8 @@ class Resharder: ...@@ -1824,6 +1840,8 @@ class Resharder:
if op.type == "while": if op.type == "while":
# the condition var's process mesh is the same as the op's and its dims_mapping is replicative, so it does not need reshard
input_var_names = op.input("X") input_var_names = op.input("X")
elif op.type == "conditional_block":
input_var_names = op.input("Input")
else: else:
input_var_names = op.input_arg_names input_var_names = op.input_arg_names
# to avoid while op X order different # to avoid while op X order different
...@@ -1831,8 +1849,8 @@ class Resharder: ...@@ -1831,8 +1849,8 @@ class Resharder:
idx_offset = 0 idx_offset = 0
for var_name in input_var_names: for var_name in input_var_names:
# skip lod_tensor_blocking_queue_0 # skip lod_tensor_blocking_queue_? name
if var_name == "lod_tensor_blocking_queue_0": if "lod_tensor_blocking_queue" in var_name:
continue continue
var = get_var_with_recursion(var_name, block, var = get_var_with_recursion(var_name, block,
self.auto_parallel_main_prog) self.auto_parallel_main_prog)
...@@ -1976,11 +1994,12 @@ class Resharder: ...@@ -1976,11 +1994,12 @@ class Resharder:
idx = 0 idx = 0
# skip reader and ops whose process mesh is union # skip reader and ops whose process mesh is union
skip_ops = [ skip_ops = [
"create_py_reader", "create_double_buffer_reader", "read", "while", "create_py_reader", "create_double_buffer_reader", "read",
"write_to_array", "read_from_array" "write_to_array", "read_from_array"
] ]
global _g_special_ops global _g_special_ops
skip_ops += _g_special_ops skip_ops += _g_special_ops
skip_ops += _g_subblock_ops
while idx < len(block.ops): while idx < len(block.ops):
pre_op_count = len(block.ops) pre_op_count = len(block.ops)
op = block.ops[idx] op = block.ops[idx]
......
...@@ -116,6 +116,13 @@ class TuningConfig(BaseConfig): ...@@ -116,6 +116,13 @@ class TuningConfig(BaseConfig):
super(TuningConfig, self).__init__(category, config_dict) super(TuningConfig, self).__init__(category, config_dict)
class DatasetConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.DATASET
super(DatasetConfig, self).__init__(category, config_dict)
class Strategy(BaseConfig): class Strategy(BaseConfig):
""" """
The `Strategy` object is used to configure the parallelization and optimization behaviors.
...@@ -180,3 +187,6 @@ class Strategy(BaseConfig): ...@@ -180,3 +187,6 @@ class Strategy(BaseConfig):
config_dict = self._config_dict.get(constants.TUNING, None) config_dict = self._config_dict.get(constants.TUNING, None)
self.tuning = TuningConfig(config_dict) self.tuning = TuningConfig(config_dict)
config_dict = self._config_dict.get(constants.DATASET, None)
self.dataset = DatasetConfig(config_dict)
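# A hedged usage sketch for the new dataset section of Strategy. A num_shards field
# is referenced elsewhere in this change; other keys are not shown in this hunk, so
# the field name used here should be treated as an assumption.
from paddle.distributed.fleet import auto

def build_engine_with_sharded_dataset(model, loss, optimizer, num_shards=2):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    # Dataset-related options now live under strategy.dataset.
    strategy.dataset.num_shards = num_shards
    return auto.Engine(model, loss, optimizer, strategy=strategy)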
...@@ -136,12 +136,24 @@ def _copy_context(ref_dist_context): ...@@ -136,12 +136,24 @@ def _copy_context(ref_dist_context):
for key, var_list in ref_dist_context._serial_fetch_vars.items(): for key, var_list in ref_dist_context._serial_fetch_vars.items():
new_var_list = [] new_var_list = []
for var in var_list: # metrics is a list of list
block_idx = var.block.idx if key == "metrics":
var_name = var.name for inner_var_list in var_list:
var = new_dist_context._serial_main_program.blocks[ new_inner_var_list = []
block_idx]._var_recursive(var_name) for var in inner_var_list:
new_var_list.append(var) block_idx = var.block.idx
var_name = var.name
var = new_dist_context._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_inner_var_list.append(var)
new_var_list.append(new_inner_var_list)
else:
for var in var_list:
block_idx = var.block.idx
var_name = var.name
var = new_dist_context._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_var_list.append(var)
new_dist_context._serial_fetch_vars[key] = new_var_list new_dist_context._serial_fetch_vars[key] = new_var_list
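# A minimal sketch of the copy rule above, detached from Program/Block objects:
# fetch vars are stored flat per key except "metrics", whose entries are themselves
# lists (one list per metric), so the rebuild recurses one extra level to keep the
# nesting intact.
def _copy_fetch_vars(serial_fetch_vars, lookup):
    # `lookup` stands in for resolving a var by block index and name in the new program.
    new_fetch_vars = {}
    for key, var_list in serial_fetch_vars.items():
        if key == "metrics":
            new_fetch_vars[key] = [[lookup(v) for v in inner] for inner in var_list]
        else:
            new_fetch_vars[key] = [lookup(v) for v in var_list]
    return new_fetch_vars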
# copy information in forward and backward # copy information in forward and backward
......
...@@ -13,20 +13,17 @@ ...@@ -13,20 +13,17 @@
# limitations under the License. # limitations under the License.
import os import os
import sys
import argparse import argparse
import traceback import traceback
import pickle import pickle
import json import json
import time import time
import numpy as np
from functools import partial
import paddle import paddle
from paddle.fluid.framework import Program, _current_expected_place from paddle.fluid.framework import Program, _current_expected_place
from paddle.fluid.framework import Operator, Parameter from paddle.fluid.framework import Operator
from paddle.distributed.auto_parallel.process_group import clear_all_process_groups, get_all_process_groups, new_process_group from paddle.distributed.auto_parallel.process_group import get_all_process_groups, new_process_group
from paddle.distributed.auto_parallel.dist_loader import NonIterableGeneratorLoader from paddle.distributed.auto_parallel.dist_loader import DistributedDataLoaderFromGenerator
from paddle.distributed.collective import _get_global_env from paddle.distributed.collective import _get_global_env
paddle.enable_static() paddle.enable_static()
...@@ -135,13 +132,14 @@ def create_dataloader(main_program, ...@@ -135,13 +132,14 @@ def create_dataloader(main_program,
# insert read op at the end of program # insert read op at the end of program
places = paddle.static.cuda_places() places = paddle.static.cuda_places()
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
dataloader = NonIterableGeneratorLoader( dataloader = DistributedDataLoaderFromGenerator(
dataset, dataset=dataset,
feed_list, feed_list=feed_list,
places, capacity=70,
dataset.batch_size, places=places,
epochs, batch_size=dataset.batch_size,
steps_per_epoch, epochs=epochs,
steps_per_epoch=steps_per_epoch,
data_parallel_world_size=dataset.dp_world_size, data_parallel_world_size=dataset.dp_world_size,
data_parallel_rank=dataset.dp_rank) data_parallel_rank=dataset.dp_rank)
......
...@@ -44,10 +44,18 @@ class TunableSpace(object): ...@@ -44,10 +44,18 @@ class TunableSpace(object):
def variables(self): def variables(self):
return self._variables return self._variables
@variables.setter
def variables(self, variables):
self._variables = variables
@property @property
def values(self): def values(self):
return self._values return self._values
@values.setter
def values(self, values):
self._values = values
def get_value(self, name): def get_value(self, name):
if name in self.values: if name in self.values:
return self.values[name] return self.values[name]
......
...@@ -90,6 +90,7 @@ class Choice(TunableVariable): ...@@ -90,6 +90,7 @@ class Choice(TunableVariable):
raise TypeError( raise TypeError(
"Choice can contain only one type of value, but found values: {} with types: {}." "Choice can contain only one type of value, but found values: {} with types: {}."
.format(str(values), str(types))) .format(str(values), str(types)))
self._is_unknown_type = False
if isinstance(values[0], str): if isinstance(values[0], str):
values = [str(v) for v in values] values = [str(v) for v in values]
...@@ -108,9 +109,8 @@ class Choice(TunableVariable): ...@@ -108,9 +109,8 @@ class Choice(TunableVariable):
if default is not None: if default is not None:
default = bool(default) default = bool(default)
else: else:
raise TypeError( self._is_unknown_type = True
"Choice can only contain str, int, float, or boll, but found: {} " self._indices = [i for i in range(len(values))]
.format(str(values)))
self.values = values self.values = values
if default is not None and default not in values: if default is not None and default not in values:
...@@ -129,7 +129,11 @@ class Choice(TunableVariable): ...@@ -129,7 +129,11 @@ class Choice(TunableVariable):
def random(self, seed=None): def random(self, seed=None):
rng = np.random.default_rng(seed) rng = np.random.default_rng(seed)
return rng.choice(self.values) if self._is_unknown_type:
indice = rng.choice(self._indices)
return self.values[indice]
else:
return rng.choice(self.values)
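# A self-contained sketch of the sampling fallback above: numpy cannot build a
# uniform 1-D choice over arbitrary objects (e.g. lists of dims_mapping candidates),
# so values of an unknown type are sampled through their indices instead.
import numpy as np

def sample_choice(values, seed=None):
    rng = np.random.default_rng(seed)
    if all(isinstance(v, (str, int, float, bool)) for v in values):
        return rng.choice(values)
    indice = rng.choice(np.arange(len(values)))
    return values[indice]

# sample_choice([[0, -1], [-1, 0], [None, None]], seed=0) picks one candidate by index.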
def get_state(self): def get_state(self):
state = super(Choice, self).get_state() state = super(Choice, self).get_state()
......
...@@ -27,6 +27,10 @@ from paddle.distributed.auto_parallel.process_group import get_all_process_group ...@@ -27,6 +27,10 @@ from paddle.distributed.auto_parallel.process_group import get_all_process_group
from paddle.fluid.io import is_parameter, is_belong_to_optimizer from paddle.fluid.io import is_parameter, is_belong_to_optimizer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
__not_shape_var_type__ = [
core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES
]
def get_logger(log_level, name="auto_parallel"): def get_logger(log_level, name="auto_parallel"):
logger = logging.getLogger(name) logger = logging.getLogger(name)
...@@ -1583,3 +1587,80 @@ def find_higher_order_backward_op(program): ...@@ -1583,3 +1587,80 @@ def find_higher_order_backward_op(program):
return True return True
return False return False
def get_lr(optimizer):
if isinstance(optimizer, paddle.optimizer.Optimizer):
return optimizer.get_lr()
elif isinstance(optimizer, paddle.fluid.optimizer.Optimizer):
if isinstance(optimizer._learning_rate, float):
return optimizer._learning_rate
else:
return optimizer._learning_rate()
else:
raise TypeError(
"'optimizer' must be object of class `paddle.optimizer.Optimizer`" \
" or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(type(optimizer))
)
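# A small usage sketch for get_lr above (assuming static mode, as used elsewhere in
# this change, so the optimizer can be constructed without parameters).
import paddle

def _demo_get_lr():
    paddle.enable_static()
    opt = paddle.optimizer.Adam(learning_rate=0.001)
    return get_lr(opt)  # -> 0.001 for a constant learning rate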
def initialize_pg_in_full_mode(all_process_groups, cur_rank):
import socket
from ..collective import _get_global_env
has_recv_by_socket = []
# Fixed offset added to each trainer port to derive the handshake socket port
magic_num = 500
genv = _get_global_env()
cur_rank_ip, cur_rank_port = genv.current_endpoint.split(":")
cur_rank_recv_port = int(cur_rank_port) + magic_num
server_socket = None
# Buffer size large enough to receive a rank id
buff_size = 1024
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.bind((cur_rank_ip, cur_rank_recv_port))
# A backlog of 10 pending connections is an empirical value
server_socket.listen(10)
client_sockets = {}
for process_group in all_process_groups:
if cur_rank not in process_group.ranks:
continue
if len(process_group.ranks) == 2:
index = process_group.ranks.index(cur_rank)
is_send = True if index == 0 else False
if is_send:
recv_rank = process_group.ranks[1]
recv_rank_ip, recv_rank_port = genv.trainer_endpoints[
recv_rank].split(":")
connect_port = int(recv_rank_port) + magic_num
client_socket = socket.socket(socket.AF_INET,
socket.SOCK_STREAM)
client_socket.connect((recv_rank_ip, connect_port))
client_socket.send(str(cur_rank).encode('utf-8'))
rank = client_socket.recv(buff_size).decode('utf-8')
rank = int(rank)
if rank != recv_rank:
raise ValueError(
"Please check comm pair, the recv rank should be {} but got {}."
.format(recv_rank, rank))
else:
print("It is able to instantiate {} as sender now.".format(
process_group.ranks))
client_socket.close()
else:
send_rank = process_group.ranks[0]
while True:
if send_rank not in has_recv_by_socket:
client_socket, recv_addr = server_socket.accept()
rank = int(client_socket.recv(buff_size).decode())
client_sockets[rank] = client_socket
has_recv_by_socket.append(rank)
else:
client_sockets[send_rank].send(
str(cur_rank).encode("utf-8"))
client_sockets[send_rank].close()
print("It is able to instantiate {} as recver now.".
format(process_group.ranks))
break
process_group.instantiate()
server_socket.close()
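# A tiny sketch of the endpoint convention used by initialize_pg_in_full_mode above:
# each rank listens for the pairwise handshake on its trainer port shifted by the
# fixed offset (magic_num = 500), so no extra ports need to be configured. The
# endpoint value below is illustrative only.
def handshake_endpoint(trainer_endpoint, offset=500):
    ip, port = trainer_endpoint.split(":")
    return ip, int(port) + offset

# handshake_endpoint("127.0.0.1:6070") -> ("127.0.0.1", 6570)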
...@@ -517,9 +517,11 @@ class AMPPass(PassBase): ...@@ -517,9 +517,11 @@ class AMPPass(PassBase):
self.set_attr("use_dynamic_loss_scaling", False) self.set_attr("use_dynamic_loss_scaling", False)
self.set_attr("input_data", []) self.set_attr("input_data", [])
self.set_attr("params_grads", []) self.set_attr("params_grads", [])
self._loss = None
self._loss_scaling = None self._loss_scaling = None
self._num_good_steps = None self._num_good_steps = None
self._num_bad_steps = None self._num_bad_steps = None
self._loss = None
def _check_self(self): def _check_self(self):
if self.get_attr("init_loss_scaling") < 0: if self.get_attr("init_loss_scaling") < 0:
......
...@@ -82,9 +82,11 @@ class DataParallelOptimizationPass(PassBase): ...@@ -82,9 +82,11 @@ class DataParallelOptimizationPass(PassBase):
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
self._analyze_program() self._analyze_program()
self._prune_grad_scaling()
self._calc_comm_overlap() if self.is_data_parallel_applied():
grad_group = self._fuse_allreduce() self._prune_grad_scaling()
self._calc_comm_overlap()
grad_group = self._fuse_allreduce()
# self.summary(grad_group) # self.summary(grad_group)
...@@ -167,6 +169,9 @@ class DataParallelOptimizationPass(PassBase): ...@@ -167,6 +169,9 @@ class DataParallelOptimizationPass(PassBase):
) == 0, "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format( ) == 0, "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
not_synchronized_grads) not_synchronized_grads)
def is_data_parallel_applied(self):
return len(self._group_to_grad_name_map) > 0
def _could_be_prune(self): def _could_be_prune(self):
return self.dist_context.gradient_scale and ( return self.dist_context.gradient_scale and (
......
...@@ -213,7 +213,7 @@ class ClipGradByGloblNormPass(PassBase): ...@@ -213,7 +213,7 @@ class ClipGradByGloblNormPass(PassBase):
if self.get_attr("dist_context") is None: if self.get_attr("dist_context") is None:
return False return False
dist_context = self.get_attr("dist_context") dist_context = self.get_attr("dist_context")
if dist_context._lr_optimizer._grad_clip is None: if dist_context._serial_optimizer._grad_clip is None:
return False return False
if self.get_attr("params_grads") is None: if self.get_attr("params_grads") is None:
return False return False
......
...@@ -396,7 +396,7 @@ class ShardingPass(PassBase): ...@@ -396,7 +396,7 @@ class ShardingPass(PassBase):
dp_ring_ids = [group.id for group in self.dp_groups] dp_ring_ids = [group.id for group in self.dp_groups]
for idx, op in reversed(list(enumerate(main_block.ops))): for idx, op in reversed(list(enumerate(main_block.ops))):
if is_data_parallel_reduce_op(op): if _is_param_grad_allreduce_op(op, main_block):
input_name = op.input_arg_names[0] input_name = op.input_arg_names[0]
base_name = _get_base_name_from_grad_name(input_name) base_name = _get_base_name_from_grad_name(input_name)
sharding_info = self.varname_to_sharding_info[base_name] sharding_info = self.varname_to_sharding_info[base_name]
...@@ -653,6 +653,20 @@ def _get_base_name_from_grad_name(grad_name): ...@@ -653,6 +653,20 @@ def _get_base_name_from_grad_name(grad_name):
return base_name return base_name
def _is_param_grad_allreduce_op(op, block):
if not is_data_parallel_reduce_op(op):
return False
output_name = op.output_arg_names[0]
base_name = _get_base_name_from_grad_name(output_name)
if not block.has_var(base_name):
return False
return block.var(base_name).is_parameter
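# A detached sketch of the filter above with hypothetical names: only allreduce ops
# whose gradient maps back to a parameter in the block are re-routed to the sharding
# group. The "<name>@GRAD" suffix convention is assumed here for illustration; the
# real mapping is done by _get_base_name_from_grad_name.
def _is_param_grad(output_name, param_names):
    base_name = output_name.split("@GRAD")[0]
    return base_name in param_names

# _is_param_grad("linear_0.w_0@GRAD", {"linear_0.w_0"})   -> True
# _is_param_grad("tmp_3@GRAD", {"linear_0.w_0"})          -> False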
def _is_param_grad_sum_op(op, block): def _is_param_grad_sum_op(op, block):
if not is_backward_op(op): if not is_backward_op(op):
......
...@@ -60,6 +60,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -60,6 +60,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_pass_amp MODULES test_pass_amp ENVS ${dist_ENVS}) py_test_modules(test_pass_amp MODULES test_pass_amp ENVS ${dist_ENVS})
set_tests_properties(test_pass_amp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" set_tests_properties(test_pass_amp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50) TIMEOUT 50)
py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
set_tests_properties(test_engine_callbacks
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_while_op_completion MODULES test_while_op_completion py_test_modules(test_while_op_completion MODULES test_while_op_completion
ENVS ${dist_ENVS}) ENVS ${dist_ENVS})
...@@ -78,6 +81,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -78,6 +81,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_dist_embedding MODULES test_dist_embedding ENVS py_test_modules(test_dist_embedding MODULES test_dist_embedding ENVS
${dist_ENVS}) ${dist_ENVS})
py_test_modules(test_dist_slice MODULES test_dist_slice ENVS ${dist_ENVS}) py_test_modules(test_dist_slice MODULES test_dist_slice ENVS ${dist_ENVS})
py_test_modules(test_dist_split MODULES test_dist_split ENVS ${dist_ENVS})
py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS}) py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS})
py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS}) py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS})
py_test_modules(test_comp_cost MODULES test_comp_cost ENVS ${dist_ENVS}) py_test_modules(test_comp_cost MODULES test_comp_cost ENVS ${dist_ENVS})
...@@ -96,4 +100,19 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ...@@ -96,4 +100,19 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_strategy MODULES test_strategy) py_test_modules(test_strategy MODULES test_strategy)
py_test_modules(test_pass_quantization MODULES test_pass_quantization) py_test_modules(test_pass_quantization MODULES test_pass_quantization)
py_test_modules(test_dist_shape MODULES test_dist_shape)
py_test_modules(test_dist_assign MODULES test_dist_assign)
py_test_modules(test_conditional_block_reshard MODULES
test_conditional_block_reshard)
py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
${dist_ENVS})
set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120)
py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full
ENVS ${dist_ENVS})
set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120)
py_test_modules(test_parallel_tuner_predict MODULES
test_parallel_tuner_predict ENVS ${dist_ENVS})
set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
endif() endif()
...@@ -88,33 +88,27 @@ class TestAMPPass(unittest.TestCase): ...@@ -88,33 +88,27 @@ class TestAMPPass(unittest.TestCase):
def test_amp_pass(self): def test_amp_pass(self):
# mp2 training # mp2 training
mp_engine = self.get_engine() mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"]) mp_losses = np.array(history.history["loss"])
# mp2 amp-o1 training # mp2 amp-o1 training
amp_o1_engine = self.get_engine(True, "o1") amp_o1_engine = self.get_engine(True, "o1")
amp_o1_losses = amp_o1_engine.fit(self.dataset, history = amp_o1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
3, amp_o1_losses = np.array(history.history["loss"])
batch_size=self.batch_size)
amp_o1_losses = np.array(amp_o1_losses["loss"])
amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o1_losses) # self.check_results(mp_losses, amp_o1_losses)
# mp2 amp-o2 training # mp2 amp-o2 training
amp_o2_engine = self.get_engine(True, "o2") amp_o2_engine = self.get_engine(True, "o2")
amp_o2_losses = amp_o2_engine.fit(self.dataset, history = amp_o2_engine.fit(self.dataset, 3, batch_size=self.batch_size)
3, amp_o2_losses = np.array(history.history["loss"])
batch_size=self.batch_size)
amp_o2_losses = np.array(amp_o2_losses["loss"])
amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o2_losses) # self.check_results(mp_losses, amp_o2_losses)
# mp2 amp-o3 training # mp2 amp-o3 training
amp_o3_engine = self.get_engine(True, "o3") amp_o3_engine = self.get_engine(True, "o3")
amp_o3_losses = amp_o3_engine.fit(self.dataset, history = amp_o3_engine.fit(self.dataset, 3, batch_size=self.batch_size)
3, amp_o3_losses = np.array(history.history["loss"])
batch_size=self.batch_size)
amp_o3_losses = np.array(amp_o3_losses["loss"])
amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size) amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o3_losses) # self.check_results(mp_losses, amp_o3_losses)
......
...@@ -20,6 +20,8 @@ import os ...@@ -20,6 +20,8 @@ import os
import numpy as np import numpy as np
import subprocess import subprocess
import paddle import paddle
import paddle.static as static
import paddle.utils as utils
import paddle.nn as nn import paddle.nn as nn
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.static as static import paddle.static as static
...@@ -29,14 +31,17 @@ from paddle.fluid import layers ...@@ -29,14 +31,17 @@ from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.distributed.fleet import auto from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.interface import get_collection, CollectionNames
from paddle.optimizer.lr import CosineAnnealingDecay from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn from paddle.fluid.dataloader.collate import default_collate_fn
paddle.enable_static() paddle.enable_static()
global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
PP_MESH_0 = auto.ProcessMesh([0]) PP_MESH_0 = auto.ProcessMesh([0])
PP_MESH_1 = auto.ProcessMesh([1]) PP_MESH_1 = auto.ProcessMesh([1])
batch_size = 1 epoch_num = 1
batch_size = 2
batch_num = 10 batch_num = 10
hidden_size = 1024 hidden_size = 1024
sequence_len = 512 sequence_len = 512
...@@ -46,6 +51,8 @@ class_num = 10 ...@@ -46,6 +51,8 @@ class_num = 10
paddle.seed(44) paddle.seed(44)
is_fetch = True is_fetch = True
is_feed = True
my_feed_vars = []
class MyDataset(Dataset): class MyDataset(Dataset):
...@@ -63,6 +70,23 @@ class MyDataset(Dataset): ...@@ -63,6 +70,23 @@ class MyDataset(Dataset):
return self.num_samples return self.num_samples
def get_random_inputs_and_labels(image_shape, label_shape):
input = np.random.random(size=image_shape).astype('float32')
label = np.random.random(size=label_shape).astype('int64')
return input, label
def batch_generator_creator():
def __reader__():
for _ in range(batch_num):
batch_input, batch_label = get_random_inputs_and_labels(
[batch_size, image_size], [batch_size, 1])
yield batch_input, batch_label
return __reader__
class MLPLayer(nn.Layer): class MLPLayer(nn.Layer):
def __init__(self, def __init__(self,
...@@ -92,16 +116,20 @@ class MLPLayer(nn.Layer): ...@@ -92,16 +116,20 @@ class MLPLayer(nn.Layer):
def forward(self, input): def forward(self, input):
out = auto.shard_op(self.norm, PP_MESH_0)(input) out = auto.shard_op(self.norm, PP_MESH_0)(input)
out = self.linear0(out) out = self.linear0(out)
if is_feed:
my_feed_vars.append((out, out.shape))
out = F.gelu(out, approximate=True) out = F.gelu(out, approximate=True)
out = auto.shard_op(self.linear1, PP_MESH_1)(out) out = auto.shard_op(self.linear1, PP_MESH_1)(out)
out = self.dropout(out) out = self.dropout(out)
out = self.linear2(out) out = self.linear2(out)
if is_feed:
my_feed_vars.append((out, out.shape))
if is_fetch: if is_fetch:
auto.fetch(out, "out") auto.fetch(out, "my_fetch", logging=True)
return out return out
def train(fetch): def train_high_level(fetch):
global is_fetch global is_fetch
is_fetch = fetch is_fetch = fetch
mlp = MLPLayer(hidden_size=hidden_size, mlp = MLPLayer(hidden_size=hidden_size,
...@@ -124,10 +152,12 @@ def train(fetch): ...@@ -124,10 +152,12 @@ def train(fetch):
# train # train
train_dataset = MyDataset(batch_num * batch_size) train_dataset = MyDataset(batch_num * batch_size)
eval_dataset1 = MyDataset(5 * batch_size) eval_dataset1 = MyDataset(5 * batch_size)
engine.fit(train_data=train_dataset,
epochs=2, history = engine.fit(train_data=train_dataset,
batch_size=batch_size, epochs=2,
valid_data=eval_dataset1) batch_size=batch_size,
valid_data=eval_dataset1,
log_freq=1)
# eval # eval
eval_dataset2 = MyDataset(batch_size) eval_dataset2 = MyDataset(batch_size)
...@@ -135,7 +165,7 @@ def train(fetch): ...@@ -135,7 +165,7 @@ def train(fetch):
# predict # predict
test_dataset = MyDataset(batch_size) test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size=batch_size) outputs = engine.predict(test_dataset, batch_size=batch_size)
# save # save
temp_dir = tempfile.TemporaryDirectory() temp_dir = tempfile.TemporaryDirectory()
...@@ -145,6 +175,265 @@ def train(fetch): ...@@ -145,6 +175,265 @@ def train(fetch):
temp_dir.cleanup() temp_dir.cleanup()
def train_low_level():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(mlp, loss, optimizer, metrics=None, strategy=strategy)
feed_dict = {}
for feed_var, shape in my_feed_vars:
feed_dict[feed_var.name] = np.zeros(shape, dtype="float32")
# Build a normal dataloader
# train
train_dataset = MyDataset(batch_num * batch_size)
train_dataloader = engine.dataloader(train_dataset,
batch_size=batch_size,
mode="train")
engine.prepare(mode="train")
for data in train_dataloader:
outs = engine.run(data, feed=feed_dict, mode="train")
# eval
eval_dataset2 = MyDataset(batch_size)
eval_dataloader = engine.dataloader(eval_dataset2,
batch_size=batch_size,
mode="eval")
engine.prepare(mode="eval")
for data in eval_dataloader:
outs = engine.run(data, feed=feed_dict, mode="eval")
# predict
engine.to_mode("predict")
test_dataset = MyDataset(batch_size)
predict_dataloader = engine.dataloader(test_dataset, batch_size=batch_size)
engine.prepare()
for data in predict_dataloader:
outs = engine.run(data, feed=feed_dict)
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp')
engine.save(model_filename, training=True)
engine.load(model_filename)
temp_dir.cleanup()
# Build dataloader from generator
# train
train_dataset = MyDataset(batch_num * batch_size)
train_dataloader = engine.dataloader_from_generator(train_dataset,
batch_size=batch_size,
mode="train")
engine.prepare(mode="train")
for data in train_dataloader:
outs = engine.run(data, feed=feed_dict, mode="train")
# eval
engine.to_mode("eval")
eval_dataset2 = MyDataset(batch_size)
eval_dataloader = engine.dataloader_from_generator(eval_dataset2,
batch_size=batch_size)
engine.prepare()
for data in eval_dataloader:
outs = engine.run(data, feed=feed_dict)
# predict
test_dataset = MyDataset(batch_size)
predict_dataloader = engine.dataloader_from_generator(test_dataset,
batch_size=batch_size,
mode="predict")
engine.prepare(mode="predict")
for data in predict_dataloader:
outs = engine.run(data, feed=feed_dict, mode="predict")
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp')
engine.save(model_filename, training=True)
engine.load(model_filename)
temp_dir.cleanup()
def train_builtin_data_vars():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
# train
engine.to_mode("train")
input_spec = static.InputSpec([batch_size, image_size], 'float32', 'input')
label_spec = static.InputSpec([batch_size, 1], 'int64', 'label')
engine.prepare(inputs_spec=[input_spec], labels_spec=[label_spec])
with static.program_guard(engine.main_program, engine.startup_program):
feed_list = engine.inputs + engine.labels
print(feed_list)
loader = paddle.io.DataLoader.from_generator(feed_list=feed_list,
capacity=4 * batch_size,
iterable=False)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
for _ in range(epoch_num):
loader.start() # call DataLoader.start() before each epoch starts
try:
while True:
engine.run()
except paddle.fluid.core.EOFException:
loader.reset(
) # call DataLoader.reset() after catching EOFException
def train_non_builtin_data_vars():
main_program = static.Program()
startup_program = static.Program()
with static.program_guard(main_program,
startup_program), utils.unique_name.guard():
input = static.data(name="input",
shape=[batch_size, image_size],
dtype='float32')
label = static.data(name="label", shape=[batch_size, 1], dtype='int64')
loader = paddle.io.DataLoader.from_generator(feed_list=[input, label],
capacity=4 * batch_size,
iterable=False)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
predict = mlp(input)
loss_var = loss(predict, label)
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(loss=loss_var,
optimizer=optimizer,
metrics=metric,
strategy=strategy)
# train
engine.to_mode("train")
engine.prepare(inputs=[input],
labels=[label],
main_program=main_program,
startup_program=startup_program)
for _ in range(epoch_num):
loader.start() # call DataLoader.start() before each epoch starts
try:
while True:
engine.run()
except paddle.fluid.core.EOFException:
loader.reset(
) # call DataLoader.reset() after catching EOFException
def get_cost():
main_program = static.default_main_program()
startup_program = static.default_startup_program()
with static.program_guard(main_program,
startup_program), utils.unique_name.guard():
input = static.data(name="input",
shape=[batch_size, image_size],
dtype='float32')
label = static.data(name="label", shape=[batch_size, 1], dtype='int64')
loader = paddle.io.DataLoader.from_generator(feed_list=[input, label],
capacity=4 * batch_size,
iterable=False)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
predict = mlp(input)
loss_var = loss(predict, label)
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(loss=loss_var,
optimizer=optimizer,
metrics=metric,
strategy=strategy)
engine.cost()
def get_cost_by_spec():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
input_spec = static.InputSpec([batch_size, image_size], 'float32', 'input')
label_spec = static.InputSpec([batch_size, 1], 'int64', 'label')
engine.cost(mode="eval", inputs_spec=[input_spec], labels_spec=[label_spec])
if __name__ == "__main__": if __name__ == "__main__":
train(fetch=True) train_high_level(fetch=True)
train(fetch=False) train_high_level(fetch=False)
train_low_level()
train_builtin_data_vars()
train_non_builtin_data_vars()
get_cost()
get_cost_by_spec()
...@@ -84,25 +84,32 @@ class TestGradientMergePass(unittest.TestCase): ...@@ -84,25 +84,32 @@ class TestGradientMergePass(unittest.TestCase):
def test_gradient_merge_pass(self): def test_gradient_merge_pass(self):
# dp2 training # dp2 training
dp_engine = self.get_engine() dp_engine = self.get_engine()
dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size) history = dp_engine.fit(self.dataset,
dp_losses = np.array(dp_losses["loss"]) 3,
batch_size=self.batch_size,
log_freq=1)
dp_losses = np.array(history.history["loss"])
# dp2 gradient merge training # dp2 gradient merge training
gm_engine = self.get_engine(True) gm_engine = self.get_engine(True)
gm_losses = gm_engine.fit(self.dataset, 3, batch_size=self.batch_size) history = gm_engine.fit(self.dataset,
gm_losses = np.array(gm_losses["loss"]) 3,
batch_size=self.batch_size,
avg_loss = 0 log_freq=1)
pass_avg_ret_list = [] gm_losses = np.array(history.history["loss"])
for i, pass_ret in enumerate(gm_losses):
if (i + 1) % 4 == 0: # avg_loss = 0
avg_loss += pass_ret # pass_avg_ret_list = []
pass_avg_ret_list.append(avg_loss / 4) # for i, pass_ret in enumerate(gm_losses):
avg_loss = 0 # if (i + 1) % 4 == 0:
else: # avg_loss += pass_ret
avg_loss += pass_ret # pass_avg_ret_list.append(avg_loss / 4)
# avg_loss = 0
self.check_results(dp_losses, np.array(pass_avg_ret_list)) # else:
# avg_loss += pass_ret
# NOTE: every sample in the dataset is identical, so the losses can be compared directly
self.check_results(dp_losses, gm_losses)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -79,13 +79,13 @@ class TestRecomputePass(unittest.TestCase): ...@@ -79,13 +79,13 @@ class TestRecomputePass(unittest.TestCase):
def test_recompute_pass(self): def test_recompute_pass(self):
# mp2 training # mp2 training
mp_engine = self.get_engine() mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"]) mp_losses = np.array(history.history["loss"])
# mp2 recompute training # mp2 recompute training
rc_engine = self.get_engine(True) rc_engine = self.get_engine(True)
rc_losses = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size) history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc_losses = np.array(rc_losses["loss"]) rc_losses = np.array(history.history["loss"])
self.check_results(mp_losses, rc_losses) self.check_results(mp_losses, rc_losses)
......
...@@ -84,31 +84,31 @@ class TestShardingPass(unittest.TestCase): ...@@ -84,31 +84,31 @@ class TestShardingPass(unittest.TestCase):
def test_sharding_pass(self): def test_sharding_pass(self):
# dp2 training # dp2 training
dp_engine = self.get_engine() dp_engine = self.get_engine()
dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size) history = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(dp_losses["loss"]) dp_losses = np.array(history.history["loss"])
# sharding2 stage1 training # sharding2 stage1 training
sharding1_engine = self.get_engine(True, 1) sharding1_engine = self.get_engine(True, 1)
sharding1_losses = sharding1_engine.fit(self.dataset, history = sharding1_engine.fit(self.dataset,
3, 3,
batch_size=self.batch_size) batch_size=self.batch_size)
sharding1_losses = np.array(sharding1_losses["loss"]) sharding1_losses = np.array(history.history["loss"])
self.check_results(dp_losses, sharding1_losses) self.check_results(dp_losses, sharding1_losses)
# sharding2 stage2 training # sharding2 stage2 training
sharding2_engine = self.get_engine(True, 2) sharding2_engine = self.get_engine(True, 2)
sharding2_losses = sharding2_engine.fit(self.dataset, history = sharding2_engine.fit(self.dataset,
3, 3,
batch_size=self.batch_size) batch_size=self.batch_size)
sharding2_losses = np.array(sharding2_losses["loss"]) sharding2_losses = np.array(history.history["loss"])
self.check_results(dp_losses, sharding2_losses) self.check_results(dp_losses, sharding2_losses)
# sharding2 stage3 training # sharding2 stage3 training
sharding3_engine = self.get_engine(True, 3) sharding3_engine = self.get_engine(True, 3)
sharding3_losses = sharding3_engine.fit(self.dataset, history = sharding3_engine.fit(self.dataset,
3, 3,
batch_size=self.batch_size) batch_size=self.batch_size)
sharding3_losses = np.array(sharding3_losses["loss"]) sharding3_losses = np.array(history.history["loss"])
self.check_results(dp_losses, sharding3_losses) self.check_results(dp_losses, sharding3_losses)
......
...@@ -82,6 +82,9 @@ from paddle.distributed.auto_parallel.cost.comp_op_cost import Transpose2OpCost ...@@ -82,6 +82,9 @@ from paddle.distributed.auto_parallel.cost.comp_op_cost import Transpose2OpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import Transpose2GradOpCost from paddle.distributed.auto_parallel.cost.comp_op_cost import Transpose2GradOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import Unsqueeze2OpCost from paddle.distributed.auto_parallel.cost.comp_op_cost import Unsqueeze2OpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import WriteToArrayOpCost from paddle.distributed.auto_parallel.cost.comp_op_cost import WriteToArrayOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import DropoutGradOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import FusedSoftmaxMaskUpperTriangleOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import FusedSoftmaxMaskUpperTriangleGradOpCost
from test_cluster import cluster_json from test_cluster import cluster_json
...@@ -417,6 +420,22 @@ class TestCompOpCost(unittest.TestCase): ...@@ -417,6 +420,22 @@ class TestCompOpCost(unittest.TestCase):
self.assertTrue(op_cost.flops >= 0) self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0) self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0) self.assertTrue(op_cost.memory >= 0)
op_cost = DropoutGradOpCost(cluster=cluster)
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)
op_cost = FusedSoftmaxMaskUpperTriangleOpCost(cluster=cluster)
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)
op_cost = FusedSoftmaxMaskUpperTriangleGradOpCost(cluster=cluster)
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)
# Remove unnecessary files # Remove unnecessary files
if os.path.exists(cluster_json_path): if os.path.exists(cluster_json_path):
os.remove(cluster_json_path) os.remove(cluster_json_path)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.static import InputSpec
from paddle.distributed.fleet import auto
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=64,
intermediate_size=4 * 64,
initializer_range=0.02):
super(MLPLayer, self).__init__()
self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
self.linear0 = nn.Linear(
hidden_size,
intermediate_size,
paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)),
bias_attr=None)
self.linear1 = nn.Linear(
intermediate_size,
hidden_size,
paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)),
bias_attr=None)
def forward(self, input):
out = self.norm(input)
auto.shard_tensor(self.linear0.weight, auto.ProcessMesh([0, 1], "x"),
[None, "x"])
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(self.linear1.weight, auto.ProcessMesh([0, 1], "x"),
["x", None])
out = self.linear1(out)
if paddle.mean(out) < 2:
out = self.norm(out)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
else:
out = self.norm(out)
out = self.linear0(out)
out = self.linear1(out)
return out
def loss_fn(predict, label):
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
return loss
class TestSubblock(unittest.TestCase):
def test_subblock(self):
mlp = MLPLayer()
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(model=mlp, loss=loss_fn, strategy=strategy)
input_spec = InputSpec([4, 64], 'float32', 'input')
label_spec = InputSpec([4, 1], 'float32', 'label')
engine.prepare(inputs_spec=[input_spec],
labels_spec=[label_spec],
mode="predict")
if __name__ == "__main__":
unittest.main()
...@@ -199,7 +199,7 @@ class TestDistributedContext(unittest.TestCase): ...@@ -199,7 +199,7 @@ class TestDistributedContext(unittest.TestCase):
"_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \ "_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \
"_serial_ordered_op_nodes", "_original_serial_loss", \ "_serial_ordered_op_nodes", "_original_serial_loss", \
"_original_serial_feed_vars", "_original_serial_fetch_vars", \ "_original_serial_feed_vars", "_original_serial_fetch_vars", \
"_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_lr_optimizer", \ "_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_serial_optimizer", \
"_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", \ "_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", \
"_pass_context"] "_pass_context"]
......