Unverified commit 90b31790 authored by zhaoyingli, committed by GitHub

[Cherry-Pick][AutoParallel] auto_parallel cherry-pick to release2.4 (#47145)

* [Auto Parallel] Make Engine class callable (#46416)

* [Auto Parallel] Improve the user-defined fetches and logging

* [Auto Parallel] Make Engine class callable

* [Auto Parallel] Update the data loading of tuner

* Print IPS in auto parallel Engine (#46554)

* [AutoParallel] fix dist_split (#46505)

* [AutoParallel] fix dist_split

* add unittest

* update cmakelist

* [AutoParallel] fix sharding (#46572)

* [AutoParallel] fix process_mesh (#46583)

* [AutoParallel] fix reshard when train with eval (#46605)

* [AutoParallel] fix reshard when train with eval

* fix mppp

* [AutoParallel] fix amp when predict (#46637)

* [Auto Parallel] Update comp cost and completion for gpt auto search (#46387)

* update comp cost and completion for gpt auto search

* add unittest

* [Auto Parallel] Fix bugs caused by the inconsistent outputs of Engine API (#46633)

* [Auto Parallel] Unify the logger and outputs of Engine API

* [Auto Parallel] Fix the bugs of to_static

* [Auto Parallel] Adjust the test_to_static.py

* [Auto Parallel] Improve the fine-grained APIs (#46552)

* [Auto Parallel] Support different dataloaders

* [Auto Parallel] Add num_shards config for dataset

* [Auto Parallel] Unify the logger and outputs of Engine API

* [Auto Parallel] Fix the bugs of to_static

* [Auto Parallel] Adjust the test_to_static.py

* [Auto Parallel] Add the prepare API and replace __call__ with run

* [Auto Parallel] Improve the private implementations of Engine

* [Auto Parallel] Set capacity of dataloader for opt tuning

* [Auto Parallel] [WIP] Change the fine-grained API

* [Auto Parallel] Improve APIs to support different user cases

* [Auto Parallel] Add removed config

* [Auto Parallel] Add imports

* [Auto Parallel] Fix bugs for to_static

* [Auto Parallel] Remove unnecessary imports

* bugfix (#46921)

* [Auto Parallel] Fix the bug for None labels (#46987)

* [AutoParallel] adapt for gpt-gen (#46771)

* for gpt-gen

* fix reshard

* adapt assign and shape op

* add dist_assign & unittest

* add conditional block unittest

* rename unittest

* [Auto Parallel] Fix the bug of completion (#47056)

* [Auto Parallel] Fix the bug for None labels

* [Auto Parallel] Fix the completion bug

* [AutoParallel] add callbacks (#47014)

* [AutoParallel] add callbacks

* fix unittest

* fix dist_context

* fix engine

* fix cmakelist

* fix unittest's returns

* fix cmakelist

* [Auto Parallel] Add cost interface (#47043)

* add cost interface

* update interface and add unittest

* update unittest

* update interface

* [Auto Parallel] Add parallel tuner (#46189)

* add parallel tuner

* add unittest

* fix unittest

* set timeout of unittest

* set unittest timeout

* fix auto_mode setting

* update unittest

* sync from develop and update unittest

* remove unused import

* update unittest

* update cmakelist

* add unittests
Co-authored-by: Yulong Ao <aoyulong@baidu.com>
Co-authored-by: Ruibiao Chen <chenruibiao@baidu.com>
Co-authored-by: caozhou <48191911+Caozhou1995@users.noreply.github.com>
Co-authored-by: JZ-LIANG <jianzhongliang10@gmail.com>
Parent 23f2a4ea
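For orientation, the end-to-end workflow these patches touch looks roughly like this. A minimal sketch, assuming the release-2.4 `paddle.distributed.fleet.auto` entry point exposes the Engine and Strategy symbols changed in this diff; `MLP` and `train_dataset` are hypothetical placeholders.

import paddle
from paddle.distributed.fleet import auto

strategy = auto.Strategy()          # config object extended below (amp, tuning, dataset, ...)
model = MLP()                       # hypothetical paddle.nn.Layer
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

engine = auto.Engine(model, loss, optimizer, strategy=strategy)
# fit() drives the Auto callbacks added in this diff: progress bar with IPS,
# LR stepping per accumulation step, checkpointing, optional profiler.
engine.fit(train_dataset, epochs=2, batch_size=8)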
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import paddle
from paddle.hapi.callbacks import ProgBarLogger, ModelCheckpoint, LRScheduler, CallbackList, Callback
from .interface import CollectionNames, get_collection
def config_callbacks(callbacks=None,
engine=None,
batch_size=None,
epochs=None,
steps=None,
log_freq=2,
verbose=2,
save_freq=1,
save_dir=None,
metrics=None,
acc_step=1,
mode='train'):
cbks = callbacks or []
cbks = cbks if isinstance(cbks, (list, tuple)) else [cbks]
if not any(isinstance(k, ProgBarLogger) for k in cbks) and verbose:
cbks = [ProgBarLoggerAuto(log_freq, verbose=verbose)] + cbks
if not any(isinstance(k, LRScheduler) for k in cbks):
cbks = [LRSchedulerAuto()] + cbks
if not any(isinstance(k, ModelCheckpoint) for k in cbks):
cbks = cbks + [ModelCheckpointAuto(save_freq, save_dir)]
if not any(isinstance(k, Profiler) for k in cbks) and verbose == 3:
cbks = cbks + [Profiler(timer_only=True)]
if not any(isinstance(k, History) for k in cbks):
cbks = cbks + [History()]
for i, k in enumerate(cbks):
if isinstance(k, ProgBarLogger):
cbks[i] = ProgBarLoggerAuto(k.log_freq, k.verbose)
if isinstance(k, LRScheduler):
cbks[i] = LRSchedulerAuto(k.by_step, k.by_epoch)
if isinstance(k, ModelCheckpoint):
cbks[i] = ModelCheckpointAuto(k.save_freq, k.save_dir)
cbk_list = CallbackList(cbks)
cbk_list.set_model(engine)
metrics = metrics or [] if mode != 'test' else []
params = {
'batch_size': batch_size,
'epochs': epochs,
'steps': steps,
'verbose': verbose,
'metrics': metrics,
'acc_step': acc_step,
}
cbk_list.set_params(params)
return cbk_list
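For reference, a hedged sketch of how an Engine would assemble its callback list with the helper above; `engine` is assumed to be an auto-parallel Engine instance and the keyword values are illustrative only.

cbk_list = config_callbacks(
    callbacks=None,      # user callbacks; Auto variants are added around them
    engine=engine,       # becomes `self.model` inside each callback
    batch_size=8,
    epochs=2,
    log_freq=10,
    verbose=3,           # verbose == 3 additionally enables Profiler(timer_only=True)
    save_dir="./ckpt",
    mode='train')
# The returned CallbackList dispatches hooks such as on_epoch_begin and
# on_train_batch_end to every configured callback.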
class ProgBarLoggerAuto(ProgBarLogger):
def __init__(self, log_freq=1, verbose=2):
super(ProgBarLoggerAuto, self).__init__(log_freq, verbose)
def _is_print(self):
return True
def _updates(self, logs, mode):
values = []
metrics = getattr(self, '%s_metrics' % (mode))
progbar = getattr(self, '%s_progbar' % (mode))
steps = getattr(self, '%s_step' % (mode))
for k in metrics:
if k in logs:
values.append((k, logs[k]))
if 'lr' in logs:
values.append(('lr', logs['lr']))
fetches_logs = logs.get('fetches', {})
collect_logging = get_collection(CollectionNames.LOGGING)
for name, var in collect_logging:
k = name or var.name
if k in fetches_logs:
values.append((k, fetches_logs[k]))
out_logs = logs.get('outputs', {})
for k in out_logs:
values.append((k, out_logs[k]))
if self.verbose == 3 and hasattr(self, '_%s_timer' % (mode)):
timer = getattr(self, '_%s_timer' % (mode))
cnt = timer['count'] if timer['count'] > 0 else 1.0
samples = timer['samples'] if timer['samples'] > 0 else 1.0
values.append(
('avg_reader_cost', "%.5f sec" % (timer['data_time'] / cnt)))
values.append(
('avg_batch_cost', "%.5f sec" % (timer['batch_time'] / cnt)))
values.append(
('ips', "%.5f samples/sec" %
(samples / (timer['data_time'] + timer['batch_time']))))
timer['count'] = 0
timer['samples'] = 0
timer['data_time'] = 0.
timer['batch_time'] = 0.
progbar.update(steps, values)
def on_eval_batch_end(self, step, logs=None):
logs = logs or {}
self.eval_step += 1
samples = self.params['batch_size']
self.evaled_samples += samples
self._eval_timer['batch_time'] += (
time.time() - self._eval_timer['batch_data_end_time'])
self._eval_timer['count'] += 1
samples = self.params['batch_size']
self._eval_timer['samples'] += samples
if self._is_print() and self.eval_step % self.log_freq == 0:
if self.eval_steps is None or self.eval_step < self.eval_steps:
self._updates(logs, 'eval')
self._eval_timer['batch_start_time'] = time.time()
class LRSchedulerAuto(LRScheduler):
def __init__(self, by_step=True, by_epoch=False):
super(LRSchedulerAuto, self).__init__(by_step, by_epoch)
def on_epoch_begin(self, epoch=None, logs=None):
self.acc_step = self.params["acc_step"]
self.epoch = epoch
self.train_step = 0
def on_train_batch_end(self, step, logs=None):
self.train_step += 1
if self.by_step and self.train_step % self.acc_step == 0:
if self.model._optimizer and \
hasattr(self.model._optimizer, '_learning_rate') and \
isinstance(self.model._optimizer._learning_rate,
paddle.optimizer.lr.LRScheduler):
self.model._optimizer._learning_rate.step()
class History(Callback):
def __init__(self):
self.history = {}
def on_train_begin(self, logs=None):
self.epoch = []
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
self.epoch.append(epoch)
for k, v in logs.items():
self.history.setdefault(k, []).append(v)
self.model.history = self
class Profiler(Callback):
def __init__(self, *args, **kwargs):
self.prof = paddle.profiler.Profiler(*args, **kwargs)
def on_epoch_begin(self, epoch=None, logs=None):
self.epoch = epoch
self.train_step = 0
self.batch_size = self.params["batch_size"]
self.steps = self.params['steps']
def on_train_begin(self, logs=None):
self.prof.start()
def on_train_batch_end(self, step, logs=None):
self.train_step += 1
self.prof.step(num_samples=self.batch_size)
print("step {}:{}".format(self.train_step,
self.prof.step_info(unit='samples')))
def on_train_end(self, logs=None):
self.prof.stop()
self.prof.summary()
class ModelCheckpointAuto(ModelCheckpoint):
def __init__(self, *args, **kwargs):
super(ModelCheckpointAuto, self).__init__(*args, **kwargs)
def _is_save(self):
return self.model and self.save_dir
def on_epoch_end(self, epoch, logs=None):
if self._is_save() and (self.epoch + 1) % self.save_freq == 0:
path = '{}/epoch{}'.format(self.save_dir, epoch)
print('save checkpoint at {}'.format(os.path.abspath(path)))
self.model.save(path)
def on_train_end(self, logs=None):
if self._is_save():
path = '{}/final'.format(self.save_dir)
print('save checkpoint at {}'.format(os.path.abspath(path)))
self.model.save(path)
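User-defined callbacks compose with the Auto variants above: config_callbacks() keeps them in the list and only swaps plain ProgBarLogger/LRScheduler/ModelCheckpoint instances for their Auto counterparts. A minimal sketch of such a callback, assuming the standard paddle.hapi hook names:

from paddle.hapi.callbacks import Callback

class LossRecorder(Callback):
    """Collect the per-step loss reported through `logs`."""

    def __init__(self):
        self.losses = []

    def on_train_batch_end(self, step, logs=None):
        logs = logs or {}
        if 'loss' in logs:
            self.losses.append(logs['loss'])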
@@ -19,7 +19,7 @@ import time
from paddle.fluid import core
from paddle.fluid import framework
from .utils import print_program_with_dist_attr, is_gradient_clip_op
from .utils import is_gradient_clip_op, __not_shape_var_type__
from .operators import find_compatible_distributed_operator_impls
from .dist_context import get_default_distributed_context, _node_id
from .dist_tensor import DistributedTensor
@@ -142,6 +142,7 @@ class Completer:
def __init__(self, dist_context):
assert dist_context is not None
self._dist_context = dist_context
self._has_prepared = False
def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
changed = False
@@ -366,7 +367,14 @@ class Completer:
def _update_dims_mapping_for_special(self):
# Set the dims_mapping of a tensor to the dims_mapping inside the op which produces it
op_nodes = self._dist_context._serial_ordered_op_nodes
# NOTE: this list may be changed if Paddle changes the existing rules.
related_reader_ops = [
"create_py_reader", "create_double_buffer_reader", "read"
]
for op_node in op_nodes:
if op_node.op() is not None \
and op_node.op().type() in related_reader_ops:
continue
op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node)
for tensor_node in op_node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
@@ -406,6 +414,7 @@ class Completer:
reach_fix_point = False
else:
reach_fix_point = True
# NOTE: this will be removed after changing the reshard rule
self._update_dims_mapping_for_special()
def _update_process_mesh_by_nearest(self, op_node, nearest_op_node):
@@ -494,14 +503,14 @@ class Completer:
for tensor_node in node.inputs:
if tensor_node.is_var() and tensor_node.var(
) is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER \
if tensor_node.var().type() in __not_shape_var_type__ \
or len(tensor_node.var().shape()) != 1:
flag = False
break
for tensor_node in node.outputs:
if tensor_node.is_var() and tensor_node.var(
) is not None:
if tensor_node.var().type() == core.VarDesc.VarType.READER \
if tensor_node.var().type() in __not_shape_var_type__ \
or len(tensor_node.var().shape()) != 1:
flag = False
break
@@ -719,6 +728,8 @@ class Completer:
self._update_process_mesh_between_graphs()
def _prepare(self):
if self._has_prepared:
return
self._while_op_nodes = {}
self._array_nodes = {}
self._node_pairs_between_graphs = []
@@ -732,6 +743,8 @@ class Completer:
if self._array_nodes.get(array_var_name, None) is None:
self._array_nodes[array_var_name] = []
self._array_nodes[array_var_name].append(node)
# Add the array input node
self._array_nodes[array_var_name].append(node.inputs[0])
if node.op().type() == "write_to_array":
array_var_name = node.op().output("Out")[0]
if self._array_nodes.get(array_var_name, None) is None:
@@ -752,6 +765,7 @@ class Completer:
and after_node.var().name() == node.var().name():
self._node_pairs_between_graphs.append(
(after_node, node))
self._has_prepared = True
def complete_forward_annotation(self, serial_main_program=None):
""" Complete annotation for the partial annotated serial_main_program.
@@ -899,6 +913,72 @@ class Completer:
else:
dist_op.dist_attr = original_op_dist_attr
def _complete_tensor_dist_attr_by_op(self, serial_main_program=None):
if serial_main_program is None:
serial_main_program = self._dist_context.serial_main_program
else:
self._dist_context._serial_main_program = serial_main_program
self._dist_context.initialize()
self._prepare()
has_set_dist_attr = set()
all_nodes = self._dist_context.serial_ordered_nodes
for node in all_nodes:
if node.is_op():
if node.op().type() in ["while"]:
continue
dist_op = self._dist_context.get_dist_op_for_graph(node)
op_dist_attr = dist_op.dist_attr
for tensor_node in node.inputs:
if tensor_node.is_var() and tensor_node.var() is not None:
# Skip the non-leaf var node
if len(tensor_node.inputs) != 0:
continue
tensor_desc = tensor_node.var()
tensor_name = tensor_desc.name()
tensor = dist_op.get_serial_input(tensor_name)
# Use the first op to set the tensor dist attr
if tensor_name in has_set_dist_attr:
continue
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping(
tensor_name) if tensor.is_parameter else [
-1 for i in tensor_desc.shape()
]
has_set_dist_attr.add(tensor_name)
for tensor_node in node.outputs:
if tensor_node.is_var() and tensor_node.var() is not None:
tensor_name = tensor_node.var().name()
if tensor_name in has_set_dist_attr:
continue
tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
tensor_node)
tensor_dist_attr.process_mesh = op_dist_attr.process_mesh
tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping(
tensor_name)
has_set_dist_attr.add(tensor_name)
self._update_process_mesh_for_specials()
self._update_process_mesh_between_graphs()
self._update_dims_mapping_for_special()
self._update_dims_mapping_between_graphs()
# Copy the corresponding distributed attribute from graph to serial_main_program
self._dist_context.copy_dist_attr_from_graph_to_program()
# Do the validation check and amend some completion
self._dist_context.amend_dist_attr_for_program()
self._dist_context.validate_dist_attr_for_program()
def _complete_high_order_grad_annotation(self, serial_main_program=None):
"""
NOTE:
@@ -116,3 +116,10 @@ set_field_default_config(TUNING, "profile_start_step", 1)
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)
#########################################
# dataset configuration
#########################################
DATASET = "dataset"
set_field_default_config(DATASET, "enable", False)
set_field_default_config(DATASET, "num_shards", 1)
@@ -167,6 +167,25 @@ class DropoutOpCost(CompOpCost):
return 0
@register_op_cost
class DropoutGradOpCost(CompOpCost):
OP_TYPE = "dropout_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(DropoutGradOpCost, self).__init__(op=op,
op_desc=op_desc,
cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class ElementwiseAddOpCost(CompOpCost):
OP_TYPE = "elementwise_add"
@@ -395,6 +414,42 @@ class FillConstantBatchSizeLikeOpCost(CompOpCost):
return 0
@register_op_cost
class FusedSoftmaxMaskUpperTriangleOpCost(CompOpCost):
OP_TYPE = "fused_softmax_mask_upper_triangle"
def __init__(self, op=None, op_desc=None, cluster=None):
super(FusedSoftmaxMaskUpperTriangleOpCost,
self).__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class FusedSoftmaxMaskUpperTriangleGradOpCost(CompOpCost):
OP_TYPE = "fused_softmax_mask_upper_triangle_grad"
def __init__(self, op=None, op_desc=None, cluster=None):
super(FusedSoftmaxMaskUpperTriangleGradOpCost,
self).__init__(op=op, op_desc=op_desc, cluster=cluster)
# For a concrete COMP OP, the calc_time and calc_flops functions need to be overridden
def calc_flops(self):
# NOTE: The actual formula will be filled in the future
return 0
def calc_time(self):
# NOTE: The actual formula will be filled in the future
return 0
@register_op_cost
class GatherOpCost(CompOpCost):
OP_TYPE = "gather"
@@ -45,6 +45,8 @@ class CostEstimator:
) # {`op_id`: {"reshard": [], "dist_op": [], "local_cost": local_cost}}}
self._bubble_time_mapping = {}
self._ordered_ops = []
self.max_memories = {}
self.max_memory = None
@property
def loop_count(self):
@@ -123,7 +125,7 @@ class CostEstimator:
for i in range(loop_count):
for op in ops:
self._detailed_cost[op.desc.id()] = OrderedDict()
# if in the while sub block, the detail of cost is the last cost
# If in the while sub block, the detail of cost is the last cost
detail = self._detailed_cost[op.desc.id()]
detail["reshard_cost"] = OrderedDict() #
detail["dist_op_cost"] = []
@@ -147,15 +149,15 @@ class CostEstimator:
var = get_var_with_recursion(var_name, block, self.program)
reshard_cost = resharder.get_cost(op, var, self.cluster)
# calc reshard cost
# Calc reshard cost
if reshard_cost is not None:
detail["reshard_cost"][var_name] = reshard_cost
comm_costs = reshard_cost[0]
local_comp_cost = reshard_cost[1]
for comm_cost in comm_costs:
# time is cumulative in global cost and local cost, but memory and flops just are cumulative in global cost.
# comm sync
# Time is cumulative in global cost and local cost, but memory and flops just are cumulative in global cost.
# Comm sync
for item in comm_cost:
group_ranks, cost = item
max_time = None
@@ -183,7 +185,7 @@ class CostEstimator:
for comp_cost in local_comp_cost[rank]:
self.local_cost(rank).time += comp_cost.time
# calc dist op cost
# Calc dist op cost
dist_op = dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
processes = op_dist_attr.process_mesh.processes
@@ -201,7 +203,7 @@ class CostEstimator:
continue
for item in dist_op_cost:
if isinstance(item, list):
# comm sync
# Comm sync
for comm_op_cost in item:
max_time = None
cost_time = {}
@@ -222,9 +224,9 @@ class CostEstimator:
self._bubble_time_mapping[rank] += (
max_time - cost_time[rank])
elif isinstance(item, dict):
# op just one
# Op just one
for rank in processes:
# dp+pp+mp
# DP+PP+MP
if rank not in item:
continue
self.local_cost(rank).time += item[rank].time
@@ -267,7 +269,7 @@ class CostEstimator:
return result
memories = {}
max_memories = {}
self.max_memories = {}
var_info = {
} # var_name: [[process_mesh, dims_mapping], [id]], [[process_mesh, dims_mapping], [id]]}
@@ -277,6 +279,10 @@
self._ordered_ops.sort(key=lambda x: x[0])
for op_id, op in self._ordered_ops:
if op.type in [
"create_py_reader", "create_double_buffer_reader", "read"
]:
continue
dist_op = dist_context.get_dist_op_for_program(op)
process_mesh = dist_op.dist_attr.process_mesh
for var_name in op.input_arg_names:
@@ -288,7 +294,7 @@
input_dims_mapping)
if key not in var_info[var_name]:
var_info[var_name][key] = {}
# it is even partition now
# It is even partition now
if "memory" not in var_info[var_name][key]:
var = dist_op.get_serial_input(var_name)
global_sizes = var.shape
@@ -326,6 +332,10 @@
has_used_vars = set()
for op_id, op in self._ordered_ops:
if op.type in [
"create_py_reader", "create_double_buffer_reader", "read"
]:
continue
can_free_memories = {}
can_free_vars = set()
dist_op = dist_context.get_dist_op_for_program(op)
@@ -337,14 +347,14 @@
input_dims_mapping)
has_used_var = var_name + key
var = dist_op.get_serial_input(var_name)
# not used
# Not used
if var_name + key not in has_used_vars:
has_used_vars.add(has_used_var)
for process in process_mesh.processes:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key]["memory"]
# used
# Used
else:
if op_id == var_info[var_name][key]["position"][-1]:
if has_used_var not in can_free_vars:
@@ -363,14 +373,14 @@
output_dims_mapping)
has_used_var = var_name + key
var = dist_op.get_serial_output(var_name)
# not used
# Not used
if var_name + key not in has_used_vars:
has_used_vars.add(has_used_var)
for process in process_mesh.processes:
if process not in memories:
memories[process] = 0
memories[process] += var_info[var_name][key]["memory"]
# used
# Used
else:
if op_id == var_info[var_name][key]["position"][-1]:
if has_used_var not in can_free_vars:
@@ -382,21 +392,22 @@
can_free_memories[process] += var_info[
var_name][key]["memory"]
# calc peak memory
# Calc peak memory
for process in memories:
if process not in max_memories:
max_memories[process] = memories[process]
if process not in self.max_memories:
self.max_memories[process] = memories[process]
else:
if memories[process] > max_memories[process]:
max_memories[process] = memories[process]
if memories[process] > self.max_memories[process]:
self.max_memories[process] = memories[process]
# free memory
# Free memory
for process in can_free_memories:
if process in memories:
memories[process] -= can_free_memories[process]
# Calculate the max memory in all ranks
max_memory = max(max_memories.values())
max_memory = max(self.max_memories.values())
self.max_memory = max_memory
return max_memory
@@ -410,3 +421,143 @@ class CostEstimator:
self._estimate_core(dist_context, resharder, block)
return self.global_cost
def _print_tag(self, max_len, length):
tag = "+" + "-" * max_len
for i in range(length):
print(tag, end="")
if i == length - 1:
print("+")
def _print_vals(self, vals, max_len):
for idx, val in enumerate(vals):
s = "|" + str(val).center(max_len)
print(s, end="")
if idx == len(vals) - 1:
print("|")
def _pretty_print_memory_cost(self):
"""Print memory of every rank prettily."""
if not self.max_memories or not self.max_memory:
raise ValueError("Please calculate memory cost before print.")
# Padding automatically
max_len = 0
header = ["Rank", "Memory(MiB)"]
memories = [
int(item // 1e6) for item in list(self.max_memories.values())
]
for memory in (memories + header):
if len(str(memory)) > max_len:
max_len = len(str(memory))
max_len += 4 # for pretty print of center
# Print tag
self._print_tag(max_len, len(header))
# Print header
self._print_vals(header, max_len)
# Print tag
self._print_tag(max_len, len(header))
# Print rank and its memory
for i in range(len(self.max_memories)):
memory = memories[i]
vals = [i, memory]
self._print_vals(vals, max_len)
self._print_tag(max_len, len(header))
def _pretty_print_global(self):
"""Print global execution time and max memory prettily."""
if not self.max_memories or not self.max_memory:
raise ValueError("Please calculate cost before print.")
# Padding automatically
max_len = 0
header = ["Execution Time(ms)", "Max Memory(MiB)"]
vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)]
for memory in (vals + header):
if len(str(memory)) > max_len:
max_len = len(str(memory))
max_len += 4 # for pretty print of center
# Print tag
self._print_tag(max_len, len(header))
# Print header
self._print_vals(header, max_len)
# Print tag
self._print_tag(max_len, len(header))
# Print exec time and max memory
self._print_vals(vals, max_len)
# Print tag
self._print_tag(max_len, len(header))
def pretty_print_cost(self):
"""Print cost prettily."""
print("The global execution time and max memory are as follows:")
self._pretty_print_global()
print("The memory of every rank is as follows:")
self._pretty_print_memory_cost()
def get_cost_from_engine(engine, mode):
from ..utils import to_list
# Construct cost estimator by original main program
serial_main_prog = engine._serial_main_progs[mode].clone(
) if mode in engine._serial_main_progs else engine._orig_main_prog.clone()
serial_startup_prog = engine._serial_startup_progs[mode].clone(
) if mode in engine._serial_startup_progs else engine._orig_startup_prog.clone(
)
losses = to_list(
engine._loss) if (not isinstance(engine._loss, paddle.nn.Layer)
and not callable(engine._loss)) else engine._losses
if mode in engine._dist_contexts:
dist_context = engine._dist_contexts[mode]
completer = engine._planners[mode].completer
else:
from ..completion import Completer
from ..dist_context import DistributedContext
dist_context = DistributedContext(serial_main_prog, serial_startup_prog,
engine._optimizer, losses, {},
{"loss": losses}, engine._cluster,
engine._strategy)
completer = Completer(dist_context)
completer.complete_forward_annotation()
dist_context.block_state.parse_forward_blocks(
dist_context.serial_main_program)
if mode == "eval" or mode == "predict":
cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
elif mode == "train":
from ..parallelizer_v2 import Parallelizer
# Get serial main program with backward
serial_optimizer = engine._optimizer
parallelizer = Parallelizer(mode, completer, dist_context)
# Generate backward
loss_name = dist_context.serial_loss.name
serial_loss = serial_main_prog.global_block()._var_recursive(loss_name)
params_grads = parallelizer._generate_backward(serial_main_prog,
serial_startup_prog,
serial_loss)
# Generate optimizer
optimizer_ops = parallelizer._generate_optimizer(
serial_main_prog, serial_startup_prog, serial_optimizer,
params_grads)
cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
# Estimate global_cost and max memory
global_cost = cost_estimator.estimate(dist_context)
max_memory = cost_estimator._estimate_max_memory_by_dist_op(dist_context)
# Print the cost
cost_estimator.pretty_print_cost()
return global_cost, max_memory
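Putting the helper above to work is straightforward; the only assumption is an already constructed auto-parallel engine:

# global_cost.time is in ms and max_memory in bytes, matching the
# pretty-print headers ("Execution Time(ms)", "Max Memory(MB)").
global_cost, max_memory = get_cost_from_engine(engine, mode="train")
print("estimated time: %.3f ms" % global_cost.time)
print("estimated peak memory: %d MB" % int(max_memory // 1e6))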
@@ -77,7 +77,6 @@ class DistributedContext:
self._serial_optimizer = None
self._serial_feed_vars = {}
self._serial_fetch_vars = {}
self._lr_optimizer = None # record the optimizer holding lr_scheduler
# Data members related to the program
self._dist_tensors_for_program = {}
@@ -268,12 +267,24 @@ class DistributedContext:
def _restore_serial_fetch_vars(self):
for key, var_list in self._original_serial_fetch_vars.items():
new_var_list = []
for var in var_list:
block_idx = var.block.idx
var_name = var.name
var = self._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_var_list.append(var)
# metrics is a list of list
if key == "metrics":
for inner_var_list in var_list:
new_inner_var_list = []
for var in inner_var_list:
block_idx = var.block.idx
var_name = var.name
var = self._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_inner_var_list.append(var)
new_var_list.append(new_inner_var_list)
else:
for var in var_list:
block_idx = var.block.idx
var_name = var.name
var = self._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_var_list.append(var)
self._serial_fetch_vars[key] = new_var_list
def _restore_serial_info(self, mode="to_backup"):
......@@ -861,7 +872,7 @@ class DistributedContext:
"_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \
"_serial_ordered_op_nodes", "_original_serial_loss", \
"_original_serial_feed_vars", "_original_serial_fetch_vars", \
"_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_lr_optimizer", \
"_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_serial_optimizer", \
"_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", \
"_pass_context"]:
setattr(result, k, v)
@@ -14,44 +14,14 @@
import abc
import numpy as np
from functools import wraps
import paddle
from .utils import to_list
from paddle.fluid.layers.utils import flatten
from paddle.io import DataLoader, BatchSampler, IterableDataset
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler
from paddle.io import BatchSampler, IterableDataset
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn
class DistributedDataLoader(metaclass=abc.ABCMeta):
def __init__(self, dataset, batch_size=1, epochs=1, drop_last=False):
if isinstance(dataset, IterableDataset):
self.dataset_kind = _DatasetKind.ITER
else:
self.dataset_kind = _DatasetKind.MAP
self.dataset = dataset
self.epochs = epochs
self.drop_lost = drop_last
if batch_size is None:
self.batch_size = None
self.batch_sampler = None
else:
self.batch_size = batch_size
if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler(
dataset, batch_size)
else:
self.batch_sampler = BatchSampler(dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)
self.auto_collate_batch = self.batch_sampler is not None
self.sampler_iter = iter(self.index_sampler)
class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __iter__(self):
@@ -72,40 +42,72 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
return _InfiniteIterableSampler(self.dataset, 1)
class NonIterableGeneratorLoader(DistributedDataLoader):
class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
def __init__(self,
dataset,
feed_list,
places,
feed_list=None,
capacity=None,
use_double_buffer=True,
iterable=True,
return_list=False,
use_multiprocess=False,
drop_last=True,
places=None,
batch_size=1,
epochs=1,
steps_per_epoch=None,
collate_fn=None,
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[],
drop_last=False,
split_data=True):
data_parallel_rank=[]):
self.dataset = dataset
self.feed_list = feed_list
self.capacity = capacity
self.use_double_buffer = use_double_buffer
self.iterable = iterable
self.return_list = return_list
self.use_multiprocess = use_multiprocess
self.drop_last = drop_last
self.places = places
self.batch_size = batch_size
self.epochs = epochs
self.steps_per_epoch = steps_per_epoch
self.collate_fn = collate_fn
self.split_data = split_data
assert len(data_parallel_world_size) == len(feed_list)
assert len(data_parallel_rank) == len(feed_list)
self.dp_world_sizes = data_parallel_world_size
self.dp_ranks = data_parallel_rank
self.split_data = split_data
super(NonIterableGeneratorLoader,
self).__init__(dataset, batch_size, epochs, drop_last)
if isinstance(dataset, IterableDataset):
self.dataset_kind = _DatasetKind.ITER
else:
self.dataset_kind = _DatasetKind.MAP
if self.batch_size is None:
self.batch_sampler = None
else:
if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler(
dataset, batch_size)
else:
self.batch_sampler = BatchSampler(dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)
self.auto_collate_batch = self.batch_sampler is not None
self.sampler_iter = iter(self.index_sampler)
if self.auto_collate_batch:
self.collate_fn = collate_fn or default_collate_fn
else:
self.collate_fn = collate_fn or default_convert_fn
self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.auto_collate_batch,
self.collate_fn, self.drop_lost)
self.collate_fn, self.drop_last)
self._steps = self._infer_steps()
self._inner_dataloader = self._create_inner_dataloader()
@@ -118,8 +120,10 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
def __next__(self):
if not self._steps:
self._cur_step += 1
return None
elif self._cur_step < self._steps:
self._cur_step += 1
return None
else:
self._inner_dataloader.reset()
self.sampler_iter = iter(self.index_sampler)
@@ -141,6 +145,16 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
)
return steps_per_epoch
@property
def index_sampler(self):
if self.auto_collate_batch:
return self.batch_sampler
else:
if self.dataset_kind == _DatasetKind.MAP:
return list(range(len(self.dataset)))
else:
return _InfiniteIterableSampler(self.dataset, 1)
def _create_inner_dataloader(self):
def data_generator():
@@ -153,7 +167,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset,
self.auto_collate_batch, self.collate_fn,
self.drop_lost)
self.drop_last)
break
partial_data = []
@@ -173,7 +187,83 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
yield partial_data
dataloader = paddle.fluid.io.DataLoader.from_generator(
feed_list=self.feed_list, capacity=70, iterable=False)
feed_list=self.feed_list,
capacity=self.capacity,
use_double_buffer=self.use_double_buffer,
# iterable=self.iterable,
iterable=False,
return_list=self.return_list,
use_multiprocess=self.use_multiprocess,
drop_last=self.drop_last)
dataloader.set_batch_generator(data_generator, self.places)
return dataloader
class DistributedDataLoader(DistributedDataLoaderBase):
def __init__(self,
dataset,
feed_list=None,
places=None,
return_list=True,
batch_size=1,
shuffle=False,
drop_last=False,
collate_fn=None,
num_workers=0,
use_buffer_reader=True,
use_shared_memory=True,
timeout=0,
worker_init_fn=None,
epochs=1,
steps_per_epoch=None,
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[]):
self.dataset = dataset
self.feed_list = feed_list
self.return_list = return_list
self.places = places
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.collate_fn = collate_fn
self.num_workers = num_workers
self.use_buffer_reader = use_buffer_reader
self.use_shared_memory = use_shared_memory
self.timeout = timeout
self.worker_init_fn = worker_init_fn
self.epochs = epochs
self.steps_per_epoch = steps_per_epoch
self.dp_world_sizes = data_parallel_world_size
self.dp_ranks = data_parallel_rank
self.split_data = split_data
# TODO: rank info
self.batch_sampler = DistributedBatchSampler(
self.dataset, self.batch_size, self.dp_world_sizes[0],
self.dp_ranks[0], self.shuffle, self.drop_last)
self._inner_dataloader = self._create_inner_dataloader()
def __iter__(self):
return self
def __next__(self):
return next(self.data)
def _create_inner_dataloader(self):
dataloader = paddle.fluid.io.DataLoader(
self.dataset,
feed_list=self.feed_list,
places=self.places,
return_list=self.return_list,
batch_sampler=self.batch_sampler,
collate_fn=self.collate_fn,
num_workers=self.num_workers,
use_buffer_reader=self.use_buffer_reader,
use_shared_memory=self.use_shared_memory,
timeout=self.timeout,
worker_init_fn=self.worker_init_fn)
self.data = (x for x in dataloader)
return dataloader
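A hedged sketch contrasting the two loaders in this file: DistributedDataLoaderFromGenerator feeds a non-iterable static-graph reader through set_batch_generator, while DistributedDataLoader above wraps paddle.fluid.io.DataLoader with a DistributedBatchSampler. The dataset and rank info are placeholders.

loader = DistributedDataLoader(
    dataset,                        # a map-style paddle.io.Dataset
    batch_size=8,
    places=paddle.static.cuda_places(),
    data_parallel_world_size=[2],   # consumed as dp_world_sizes[0]
    data_parallel_rank=[0])         # consumed as dp_ranks[0]
for batch in loader:                # __next__ pulls from the inner DataLoader
    pass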
@@ -139,7 +139,7 @@ class ProxyLayer(Layer):
"""
outs = []
for metric in self.metrics:
outs.extend(metric.compute(*inputs))
outs.append(to_list(metric.compute(*inputs)))
return outs
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import paddle
from paddle.fluid import core
from .process_mesh import ProcessMesh
@@ -196,15 +198,36 @@ def recompute(op):
return RecomputeOperator(op)
_g_fetched_tensors = {}
_g_collections = {}
class CollectionNames(object):
FETCHES = "fetches"
LOGGING = "logging"
def get_collection(name):
collection = _g_collections.get(name, None)
if collection is None:
collection = []
_g_collections[name] = collection
return _g_collections[name]
def fetch(tensor, name=None):
if name is None:
_g_fetched_tensors[tensor.name] = tensor
def add_to_collection(collection_name, value, name=None):
if collection_name not in _g_collections:
_g_collections[collection_name] = []
if name is not None:
for _, v in _g_collections[collection_name]:
if v == value: return
_g_collections[collection_name].append((name, value))
else:
_g_fetched_tensors[name] = tensor
for _, v in _g_collections[collection_name]:
if v == value: return
_g_collections[collection_name].append((None, value))
def _get_fetches():
return _g_fetched_tensors
def fetch(tensor, name=None, logging=False):
add_to_collection(CollectionNames.FETCHES, tensor, name)
if logging:
add_to_collection(CollectionNames.LOGGING, tensor, name)
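Usage of the collection API above; `hidden` stands for any intermediate tensor the user wants returned (and logged) every step:

hidden = ...                                    # hypothetical tensor to watch
fetch(hidden, name="hidden_out", logging=True)  # fetched and shown in the logs
for name, var in get_collection(CollectionNames.LOGGING):
    print(name)                                 # -> "hidden_out"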
@@ -33,3 +33,5 @@ from . import dist_slice
from . import dist_fused_feedforward
from . import dist_fused_attention
from . import dist_reduce_sum_p
from . import dist_shape
from . import dist_assign
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .dist_default import DistributedDefaultImpl0
from ..utils import compute_compatible_and_update_dim_mapping
class DistributedAssign(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedAssign, self).__init__(op_type)
register_distributed_operator_impl_container(DistributedAssign("assign"))
class DistributedAssignImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedAssignImpl, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
return True
def is_output_compatible(self, dist_op):
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
if x_dims_mapping != out_dims_mapping:
return False
return True
def update_dims_mapping(self, dist_op):
changed = False
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
x_name = op_desc.input('X')[0]
out_name = op_desc.output('Out')[0]
x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
for i in range(len(x_dims_mapping)):
dim_changed = compute_compatible_and_update_dim_mapping(
[x_dims_mapping, out_dims_mapping], [i, i])
if dim_changed:
changed = True
return changed
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl("assign", DistributedAssignImpl("assign"))
@@ -1308,6 +1308,8 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
process_mesh = dist_attr.process_mesh
processes = process_mesh.processes
# col parallel: matmul + allreduce
if backward_op.attr("trans_y"):
Y_var_dim_mapping.reverse()
assert Y_var_dim_mapping[0] < 0
parallel_axis = Y_var_dim_mapping[1]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .common import DistributedOperatorImplContainer
from .common import DistributedOperatorImpl
from .common import register_distributed_operator_impl_container
from .common import register_distributed_operator_impl
from .dist_default import DistributedDefaultImpl0
from ..utils import is_dim_shard
class DistributedShape(DistributedOperatorImplContainer):
def __init__(self, op_type):
super(DistributedShape, self).__init__(op_type)
register_distributed_operator_impl_container(DistributedShape("shape"))
class DistributedShapeImpl(DistributedOperatorImpl):
def __init__(self, name):
super(DistributedShapeImpl, self).__init__(name)
self._forward_implemented = True
self._backward_implemented = True
def is_input_compatible(self, dist_op):
return True
def is_output_compatible(self, dist_op):
op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr
out_name = op_desc.output('Out')[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
assert len(out_dims_mapping) == 1
if is_dim_shard(out_dims_mapping[0]):
return False
return True
def is_auto_compatible(self, dist_op):
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)):
return False
return True
def update_dims_mapping(self, dist_op):
return False
@staticmethod
def forward(ctx, *args, **kwargs):
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl("shape", DistributedShapeImpl("shape"))
@@ -101,8 +101,12 @@ class DistributedSplitImpl(DistributedOperatorImpl):
return changed
def is_auto_compatible(self, dist_op):
raise NotImplementedError(
"Auto Search is not supported by dist split yet.")
if (not self.is_input_compatible(dist_op)) or \
(not self.is_output_compatible(dist_op)) or \
(not self.is_compatible(dist_op)):
return False
return True
@staticmethod
def forward(ctx, *args, **kwargs):
@@ -23,14 +23,12 @@ import logging
import pickle
import time
import paddle
from paddle.fluid.backward import append_backward
from paddle.distributed.utils.log_utils import get_logger
from paddle.distributed.fleet import cloud_utils
import paddle.fluid.core as core
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.distributed.utils.log_utils import get_logger
from paddle.distributed.passes import new_pass, PassContext
from .dist_context import DistributedContext
from .dist_context import get_default_distributed_context
from .dist_context import set_default_distributed_context
from .completion import Completer
from .partitioner import Partitioner
@@ -40,9 +38,7 @@ from .process_group import get_world_process_group
from .process_group import _g_process_group_map, ProcessGroup
from .utils import make_data_unshard
from .utils import set_grad_var_shape
from .utils import print_program_with_dist_attr
from .utils import SerialProgramInfo
from .utils import get_logger
from .reshard import Resharder
from .cluster import Cluster
from .mapper import mapping
@@ -148,7 +144,7 @@ class AutoParallelizer:
with program_guard(main_program, startup_program):
optimize_ops = optimizer.apply_gradients(params_grads)
self._dist_context._lr_optimizer = optimizer
self._dist_context._serial_optimizer = optimizer
# update completion
self._completer = Completer(self._dist_context)
self._completer.complete_update_annotation(main_program)
@@ -15,24 +15,17 @@
import copy
import time
import logging
from collections import defaultdict
import paddle
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _non_static_mode, unique_name
from paddle.fluid.framework import unique_name
from paddle.distributed.passes import new_pass
from .reshard import Resharder
from .partitioner import Partitioner
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import make_data_unshard, set_grad_var_shape
from .utils import print_program_with_dist_attr, to_list
from .utils import get_logger
from .process_group import get_all_process_groups, get_world_process_group
from .dist_context import DistributedContext, get_default_distributed_context
from .utils import set_grad_var_shape
from .process_group import get_world_process_group
from ..utils.log_utils import get_logger
class Parallelizer:
@@ -69,7 +62,7 @@ class Parallelizer:
serial_main_program, serial_startup_program, params_grads = self._apply_pre_optimization(
serial_main_program, serial_startup_program, serial_loss,
serial_optimizer, params_grads)
self._logger.info(
self._logger.debug(
"within parallel apply_pre_optimization time: {}, mode {}".
format(time.time() - time0, self._mode))
# Do logical partition
@@ -77,14 +70,14 @@
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
serial_main_program, serial_startup_program, params_grads)
self._logger.info(
self._logger.debug(
"within parallel partitioner time: {}, mode {}".format(
time.time() - time0, self._mode))
# Generate optimizer
time0 = time.time()
self._generate_optimizer(dist_main_prog, dist_startup_prog,
serial_optimizer, dist_params_grads)
self._logger.info(
self._logger.debug(
"within parallel optimizer time: {}, mode {}".format(
time.time() - time0, self._mode))
# Do reshard process
@@ -93,14 +86,14 @@
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, dist_params_grads)
resharder.reshard()
self._logger.info(
self._logger.debug(
"within parallel reshard time: {}, mode {}".format(
time.time() - time0, self._mode))
# Apply post optimization passes
time0 = time.time()
self._apply_post_optimization(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
self._logger.info(
self._logger.debug(
"within parallel apply_post_optimization time: {}, mode {}".
format(time.time() - time0, self._mode))
else:
@@ -109,7 +102,7 @@
self._apply_pre_optimization(serial_main_program,
serial_startup_program, None, None,
None)
self._logger.info(
self._logger.debug(
"within parallel apply_pre_optimization time: {}, mode {}".
format(time.time() - time0, self._mode))
# Do logical partition
@@ -118,14 +111,14 @@
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
serial_main_program, serial_startup_program, [])
# Do reshard process
self._logger.info(
self._logger.debug(
"within parallel partitioner time: {}, mode {}".format(
time.time() - time0, self._mode))
time0 = time.time()
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, [], 1)
resharder.reshard()
self._logger.info(
self._logger.debug(
"within parallel reshard time: {}, mode {}".format(
time.time() - time0, self._mode))
# Clone program for test
@@ -150,7 +143,7 @@
# NOTE: `apply_gradients` will add an Accumulator for a parameter only once,
# but optimizer will be called repeatedly in re-launch, so optimizer need to be copied.
optimizer = copy.deepcopy(optimizer)
self._dist_context._lr_optimizer = optimizer
self._dist_context._serial_optimizer = optimizer
with program_guard(main_program, startup_program):
with unique_name.guard("opt_"):
optimizer_ops = optimizer.apply_gradients(params_grads)
@@ -177,9 +170,7 @@
startup_program = self._pass_context.get_attr("startup_program")
params_grads = self._pass_context.get_attr("params_grads")
# apply amp pass
# FIXME we disable amp for eval since it has a little bug with
# the eval program, which will be fixed in the future
# apply amp pass on train/eval/predict
if self._strategy.amp.enable:
config = copy.deepcopy(self._strategy.amp.to_dict())
config["dist_context"] = self._dist_context
......
@@ -28,7 +28,7 @@ from .utils import set_dist_op_desc_original_id
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op, is_optimize_op
from .operators.common import BACKWARD_ONLY_DIST_OPS
__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
__varname_not_in_block__ = ["lod_tensor_blocking_queue"]
__not_shape_var_type__ = [
core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES
]
@@ -243,7 +243,9 @@ class Partitioner(object):
target_block, serial_input_varname,
new_varname)
else:
assert serial_input_varname in __varname_not_in_block__
for varname_not_in_block in __varname_not_in_block__:
assert varname_not_in_block in serial_input_varname, \
"{} is not found".format(serial_input_varname)
self._serial2dist_varname_mapping[
serial_input_varname] = new_varname
@@ -14,9 +14,7 @@
from .completion import Completer
from .dist_context import get_default_distributed_context
from .utils import print_program_with_dist_attr
# from .tuner.parallel_tuner import ParallelTuner
from .tuner.parallel_tuner import ParallelTuner
class Planner:
@@ -39,20 +37,20 @@ class Planner:
self._completer = Completer(self._dist_context)
self._strategy = dist_context.strategy
# if self._strategy.auto_search:
# self._parallel_tuner = ParallelTuner(
# self._dist_context, mode=self._mode)
# set parallel tuner for auto search
if self._strategy.auto_mode == "full":
self._parallel_tuner = ParallelTuner(self._dist_context,
mode=self._mode)
@property
def completer(self):
return self._completer
def plan(self):
self._completer.complete_forward_annotation()
# if self._strategy.auto_search:
# self._parallel_tuner.tune()
# else:
# self._completer.complete_forward_annotation()
if self._strategy.auto_mode == "full":
self._parallel_tuner.tune()
else:
self._completer.complete_forward_annotation()
# parse forward sub block
self._dist_context.block_state.parse_forward_blocks(
self._dist_context.serial_main_program)
@@ -168,7 +168,10 @@ class ProcessMesh(object):
else:
new_mesh = self._mesh[index]
new_dim_names = self._dim_names[1:]
return ProcessMesh(new_mesh, new_dim_names)
if new_mesh.shape:
return ProcessMesh(new_mesh, new_dim_names)
else:
return ProcessMesh([new_mesh])
def __enter__(self):
set_current_process_mesh(self)
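The branch added to __getitem__ above covers indexing away the last dimension: slicing a 1-D mesh yields an ndim-0 value whose shape is empty, which is now re-wrapped instead of being passed to the constructor raw. Illustration:

mesh = ProcessMesh([0, 1])   # 1-D mesh of two ranks
sub = mesh[0]                # new_mesh == 0, an ndim-0 value
# sub is ProcessMesh([0]) rather than a constructor error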
@@ -37,6 +37,7 @@ _g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
_g_gradient_clip_ops = [
"sum", "sqrt", "fill_constant", "elementwise_max", "elementwise_div"
]
_g_subblock_ops = ["while", "conditional_block"]
def get_var_with_recursion(var_name, block, program):
@@ -45,10 +46,11 @@ def get_var_with_recursion(var_name, block, program):
if var_name in block.vars:
var = block.vars[var_name]
else:
parent_block = program.blocks[block.parent_idx]
if var_name in parent_block.vars:
var = parent_block.vars[var_name]
assert var is not None
var = block._var_recursive(var_name)
# parent_block = program.blocks[block.parent_idx]
# if var_name in parent_block.vars:
# var = parent_block.vars[var_name]
assert var is not None, "{} is not found".format(var.name)
return var
@@ -1077,7 +1079,9 @@ class Resharder:
new_Out = []
for var_name in while_op.output("Out"):
for output_name in sub_block_op_outputs[::-1]:
if output_name.find(var_name) != -1:
if output_name.find(var_name) != -1 and (
len(var_name) == len(output_name)
or "@RESHARD" in output_name):
if output_name not in new_Out:
new_Out.append(output_name)
assert new_Out
@@ -1106,13 +1110,15 @@ class Resharder:
return False
def is_condition_replicative(self, op):
assert op.type == "while"
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
dist_op = self.dist_context.get_dist_op_for_program(op)
op_dist_attr = dist_op.dist_attr
if op.type == "while":
input_cond = op.input("Condition")
elif op.type == "conditional_block":
input_cond = op.input("Cond")
# the dims mapping of condition tensor should be replicative
for var_name in op.input("Condition"):
for var_name in input_cond:
var = get_var_with_recursion(var_name, sub_block,
self.auto_parallel_main_prog)
dist_tensor = self.dist_context.get_dist_tensor_for_program(var)
@@ -1662,9 +1668,9 @@ class Resharder:
op.desc.set_input(proto.inputs[0].name,
op.input("X") + while_op_X_append)
def _get_while_op_input_attrs(self, op, var_name):
def _get_subblock_input_attrs(self, op, var_name):
# NOTE: Multi while loop is not supported
assert op.type == "while"
assert op.type in _g_subblock_ops
sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id]
ops = sub_block.ops
input_attrs = []
@@ -1715,8 +1721,8 @@ class Resharder:
def get_op_input_attrs(self, op, var_name):
op_input_attrs = []
if op.type == "while":
op_input_attrs = self._get_while_op_input_attrs(op, var_name)
if op.type in _g_subblock_ops:
op_input_attrs = self._get_subblock_input_attrs(op, var_name)
else:
op_input_attrs = self._get_common_op_input_attrs(op, var_name)
@@ -1738,8 +1744,18 @@ class Resharder:
if len(set(process_mesh.processes)) == len(processes):
global_process_mesh_idx = idx
break
if global_process_mesh_idx is not None:
self.dist_context.process_meshes.pop(idx)
is_removed = False
global_mesh = self.dist_context.process_meshes[idx]
for i, mesh in enumerate(self.dist_context.process_meshes):
if i == idx:
continue
if set(mesh.processes) < set(global_mesh.processes):
is_removed = True
if is_removed:
self.dist_context.process_meshes.pop(idx)
def _change_subblock_op_input_and_output(self, block_idx, block):
if "var_reshard_mapping" in Resharder.while_block_info[block_idx]:
@@ -1810,7 +1826,7 @@ class Resharder:
if dist_op is not None:
op_input_dist_attrs = [
] # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)]
if op.type == "while":
if op.type in _g_subblock_ops:
if not self.is_condition_replicative(op):
raise ValueError(
"Please check the condition due to the dims mapping is not replicative."
@@ -1824,6 +1840,8 @@ class Resharder:
if op.type == "while":
# condition var process mesh is the same as the op's and dims_mapping is replicative, so it does not need reshard
input_var_names = op.input("X")
elif op.type == "conditional_block":
input_var_names = op.input("Input")
else:
input_var_names = op.input_arg_names
# to avoid while op X order different
@@ -1831,8 +1849,8 @@ class Resharder:
idx_offset = 0
for var_name in input_var_names:
# skip lod_tensor_blocking_queue_0
if var_name == "lod_tensor_blocking_queue_0":
# skip lod_tensor_blocking_queue_? name
if "lod_tensor_blocking_queue" in var_name:
continue
var = get_var_with_recursion(var_name, block,
self.auto_parallel_main_prog)
@@ -1976,11 +1994,12 @@ class Resharder:
idx = 0
# skip reader and ops whose process mesh is union
skip_ops = [
"create_py_reader", "create_double_buffer_reader", "read", "while",
"create_py_reader", "create_double_buffer_reader", "read",
"write_to_array", "read_from_array"
]
global _g_special_ops
skip_ops += _g_special_ops
skip_ops += _g_subblock_ops
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
@@ -116,6 +116,13 @@ class TuningConfig(BaseConfig):
super(TuningConfig, self).__init__(category, config_dict)
class DatasetConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.DATASET
super(DatasetConfig, self).__init__(category, config_dict)
class Strategy(BaseConfig):
"""
The `Strategy` object is used to configure the parallelization and optimization behaviors.
@@ -180,3 +187,6 @@ class Strategy(BaseConfig):
config_dict = self._config_dict.get(constants.TUNING, None)
self.tuning = TuningConfig(config_dict)
config_dict = self._config_dict.get(constants.DATASET, None)
self.dataset = DatasetConfig(config_dict)
@@ -136,12 +136,24 @@ def _copy_context(ref_dist_context):
for key, var_list in ref_dist_context._serial_fetch_vars.items():
new_var_list = []
for var in var_list:
block_idx = var.block.idx
var_name = var.name
var = new_dist_context._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_var_list.append(var)
# metrics is a list of list
if key == "metrics":
for inner_var_list in var_list:
new_inner_var_list = []
for var in inner_var_list:
block_idx = var.block.idx
var_name = var.name
var = new_dist_context._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_inner_var_list.append(var)
new_var_list.append(new_inner_var_list)
else:
for var in var_list:
block_idx = var.block.idx
var_name = var.name
var = new_dist_context._serial_main_program.blocks[
block_idx]._var_recursive(var_name)
new_var_list.append(var)
new_dist_context._serial_fetch_vars[key] = new_var_list
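Both branches above re-resolve each fetch var inside the copied program by (block index, name); only "metrics" nests the lookup one list level deeper. A sketch of the shared lookup, with a hypothetical helper name:

# Rebind a var recorded in the reference context to its counterpart in
# the newly copied serial main program.
def _rebind(var, new_main_program):
    return new_main_program.blocks[var.block.idx]._var_recursive(var.name)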
# copy information in forward and backward
......
......@@ -13,20 +13,17 @@
# limitations under the License.
import os
import sys
import argparse
import traceback
import pickle
import json
import time
import numpy as np
from functools import partial
import paddle
from paddle.fluid.framework import Program, _current_expected_place
from paddle.fluid.framework import Operator, Parameter
from paddle.distributed.auto_parallel.process_group import clear_all_process_groups, get_all_process_groups, new_process_group
from paddle.distributed.auto_parallel.dist_loader import NonIterableGeneratorLoader
from paddle.fluid.framework import Operator
from paddle.distributed.auto_parallel.process_group import get_all_process_groups, new_process_group
from paddle.distributed.auto_parallel.dist_loader import DistributedDataLoaderFromGenerator
from paddle.distributed.collective import _get_global_env
paddle.enable_static()
......@@ -135,13 +132,14 @@ def create_dataloader(main_program,
# insert read op at the end of program
places = paddle.static.cuda_places()
with paddle.static.program_guard(main_program, startup_program):
dataloader = NonIterableGeneratorLoader(
dataset,
feed_list,
places,
dataset.batch_size,
epochs,
steps_per_epoch,
dataloader = DistributedDataLoaderFromGenerator(
dataset=dataset,
feed_list=feed_list,
capacity=70,
places=places,
batch_size=dataset.batch_size,
epochs=epochs,
steps_per_epoch=steps_per_epoch,
data_parallel_world_size=dataset.dp_world_size,
data_parallel_rank=dataset.dp_rank)
......
......@@ -44,10 +44,18 @@ class TunableSpace(object):
def variables(self):
return self._variables
@variables.setter
def variables(self, variables):
self._variables = variables
@property
def values(self):
return self._values
@values.setter
def values(self, values):
self._values = values
def get_value(self, name):
if name in self.values:
return self.values[name]
......
......@@ -90,6 +90,7 @@ class Choice(TunableVariable):
raise TypeError(
"Choice can contain only one type of value, but found values: {} with types: {}."
.format(str(values), str(types)))
self._is_unknown_type = False
if isinstance(values[0], str):
values = [str(v) for v in values]
......@@ -108,9 +109,8 @@ class Choice(TunableVariable):
if default is not None:
default = bool(default)
else:
raise TypeError(
"Choice can only contain str, int, float, or boll, but found: {} "
.format(str(values)))
self._is_unknown_type = True
self._indices = [i for i in range(len(values))]
self.values = values
if default is not None and default not in values:
......@@ -129,7 +129,11 @@ class Choice(TunableVariable):
def random(self, seed=None):
rng = np.random.default_rng(seed)
return rng.choice(self.values)
if self._is_unknown_type:
index = rng.choice(self._indices)
return self.values[index]
else:
return rng.choice(self.values)
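The index-based branch exists because numpy must first pack the candidates into an array, which can fail or silently coerce for heterogeneous Python objects; sampling an integer index sidesteps that. A standalone sketch with hypothetical mixed values:

import numpy as np

# For value types numpy cannot hold in a homogeneous array (the
# "unknown type" case), draw an index and look the value up in the
# original Python list.
values = [None, ("mp", 2), ("dp", 4)]
rng = np.random.default_rng(seed=2022)
index = int(rng.choice(len(values)))  # choice over arange(len(values))
sampled = values[index]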
def get_state(self):
state = super(Choice, self).get_state()
......
......@@ -27,6 +27,10 @@ from paddle.distributed.auto_parallel.process_group import get_all_process_group
from paddle.fluid.io import is_parameter, is_belong_to_optimizer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
__not_shape_var_type__ = [
core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES
]
def get_logger(log_level, name="auto_parallel"):
logger = logging.getLogger(name)
......@@ -1583,3 +1587,80 @@ def find_higher_order_backward_op(program):
return True
return False
def get_lr(optimizer):
if isinstance(optimizer, paddle.optimizer.Optimizer):
return optimizer.get_lr()
elif isinstance(optimizer, paddle.fluid.optimizer.Optimizer):
if isinstance(optimizer._learning_rate, float):
return optimizer._learning_rate
else:
return optimizer._learning_rate()
else:
raise TypeError(
"'optimizer' must be object of class `paddle.optimizer.Optimizer`" \
" or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(type(optimizer))
)
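A quick usage sketch for get_lr (the optimizer construction is illustrative only):

import paddle

# With a float learning rate both optimizer families report the scalar
# directly; with a scheduler, the current value is computed instead.
opt = paddle.optimizer.Adam(learning_rate=0.001)
assert abs(get_lr(opt) - 0.001) < 1e-12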
def initialize_pg_in_full_mode(all_process_groups, cur_rank):
import socket
from ..collective import _get_global_env
has_recv_by_socket = []
# Magic offset added to each endpoint port for the handshake sockets
magic_num = 500
genv = _get_global_env()
cur_rank_ip, cur_rank_port = genv.current_endpoint.split(":")
cur_rank_recv_port = int(cur_rank_port) + magic_num
server_socket = None
# Large enough for recv rank
buff_size = 1024
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.bind((cur_rank_ip, cur_rank_recv_port))
# A listen backlog of 10 is an empirical value
server_socket.listen(10)
client_sockets = {}
for process_group in all_process_groups:
if cur_rank not in process_group.ranks:
continue
if len(process_group.ranks) == 2:
index = process_group.ranks.index(cur_rank)
is_send = index == 0
if is_send:
recv_rank = process_group.ranks[1]
recv_rank_ip, recv_rank_port = genv.trainer_endpoints[
recv_rank].split(":")
connect_port = int(recv_rank_port) + magic_num
client_socket = socket.socket(socket.AF_INET,
socket.SOCK_STREAM)
client_socket.connect((recv_rank_ip, connect_port))
client_socket.send(str(cur_rank).encode('utf-8'))
rank = client_socket.recv(buff_size).decode('utf-8')
rank = int(rank)
if rank != recv_rank:
raise ValueError(
"Please check comm pair, the recv rank should be {} but got {}."
.format(recv_rank, rank))
else:
print("It is able to instantiate {} as sender now.".format(
process_group.ranks))
client_socket.close()
else:
send_rank = process_group.ranks[0]
while True:
if send_rank not in has_recv_by_socket:
client_socket, recv_addr = server_socket.accept()
rank = int(client_socket.recv(buff_size).decode())
client_sockets[rank] = client_socket
has_recv_by_socket.append(rank)
else:
client_sockets[send_rank].send(
str(cur_rank).encode("utf-8"))
client_sockets[send_rank].close()
print("It is able to instantiate {} as recver now.".
format(process_group.ranks))
break
process_group.instantiate()
server_socket.close()
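In short, the handshake pairs every 2-rank group before instantiation: the lower-index rank dials its peer on the peer's endpoint port plus the magic offset, sends its own rank id, and expects the peer's rank echoed back as validation. A sketch of just the port convention (endpoint value hypothetical):

# Every rank listens for handshakes on its trainer endpoint port plus
# the fixed offset used above.
MAGIC_OFFSET = 500
endpoint = "127.0.0.1:6070"
ip, port = endpoint.split(":")
handshake_port = int(port) + MAGIC_OFFSET  # 6570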
......@@ -517,9 +517,11 @@ class AMPPass(PassBase):
self.set_attr("use_dynamic_loss_scaling", False)
self.set_attr("input_data", [])
self.set_attr("params_grads", [])
self._loss = None
self._loss_scaling = None
self._num_good_steps = None
self._num_bad_steps = None
self._loss = None
def _check_self(self):
if self.get_attr("init_loss_scaling") < 0:
......
......@@ -82,9 +82,11 @@ class DataParallelOptimizationPass(PassBase):
with paddle.static.program_guard(main_program, startup_program):
self._analyze_program()
self._prune_grad_scaling()
self._calc_comm_overlap()
grad_group = self._fuse_allreduce()
if self.is_data_parallel_applied():
self._prune_grad_scaling()
self._calc_comm_overlap()
grad_group = self._fuse_allreduce()
# self.summary(grad_group)
......@@ -167,6 +169,9 @@ class DataParallelOptimizationPass(PassBase):
) == 0, "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
not_synchronized_grads)
def is_data_parallel_applied(self):
return len(self._group_to_grad_name_map) > 0
def _could_be_prune(self):
return self.dist_context.gradient_scale and (
......
......@@ -213,7 +213,7 @@ class ClipGradByGloblNormPass(PassBase):
if self.get_attr("dist_context") is None:
return False
dist_context = self.get_attr("dist_context")
if dist_context._lr_optimizer._grad_clip is None:
if dist_context._serial_optimizer._grad_clip is None:
return False
if self.get_attr("params_grads") is None:
return False
......
......@@ -396,7 +396,7 @@ class ShardingPass(PassBase):
dp_ring_ids = [group.id for group in self.dp_groups]
for idx, op in reversed(list(enumerate(main_block.ops))):
if is_data_parallel_reduce_op(op):
if _is_param_grad_allreduce_op(op, main_block):
input_name = op.input_arg_names[0]
base_name = _get_base_name_from_grad_name(input_name)
sharding_info = self.varname_to_sharding_info[base_name]
......@@ -653,6 +653,20 @@ def _get_base_name_from_grad_name(grad_name):
return base_name
def _is_param_grad_allreduce_op(op, block):
if not is_data_parallel_reduce_op(op):
return False
output_name = op.output_arg_names[0]
base_name = _get_base_name_from_grad_name(output_name)
if not block.has_var(base_name):
return False
return block.var(base_name).is_parameter
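For orientation, the helper above strips the gradient suffix from the allreduce's output name and then checks the parameter flag of the base var; a sketch assuming the usual "@GRAD" naming:

# "linear_0.w_0@GRAD" -> "linear_0.w_0"; the op is a parameter-gradient
# allreduce only if that base var exists in the block and is a parameter.
grad_name = "linear_0.w_0@GRAD"
base_name = grad_name[:grad_name.find("@GRAD")]
assert base_name == "linear_0.w_0"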
def _is_param_grad_sum_op(op, block):
if not is_backward_op(op):
......
......@@ -60,6 +60,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_pass_amp MODULES test_pass_amp ENVS ${dist_ENVS})
set_tests_properties(test_pass_amp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
set_tests_properties(test_engine_callbacks
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_while_op_completion MODULES test_while_op_completion
ENVS ${dist_ENVS})
......@@ -78,6 +81,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_dist_embedding MODULES test_dist_embedding ENVS
${dist_ENVS})
py_test_modules(test_dist_slice MODULES test_dist_slice ENVS ${dist_ENVS})
py_test_modules(test_dist_split MODULES test_dist_split ENVS ${dist_ENVS})
py_test_modules(test_cluster MODULES test_cluster ENVS ${dist_ENVS})
py_test_modules(test_comm_cost MODULES test_comm_cost ENVS ${dist_ENVS})
py_test_modules(test_comp_cost MODULES test_comp_cost ENVS ${dist_ENVS})
......@@ -96,4 +100,19 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_strategy MODULES test_strategy)
py_test_modules(test_pass_quantization MODULES test_pass_quantization)
py_test_modules(test_dist_shape MODULES test_dist_shape)
py_test_modules(test_dist_assign MODULES test_dist_assign)
py_test_modules(test_conditional_block_reshard MODULES
test_conditional_block_reshard)
py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
${dist_ENVS})
set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120)
py_test_modules(test_parallel_tuner_full MODULES test_parallel_tuner_full
ENVS ${dist_ENVS})
set_tests_properties(test_parallel_tuner_full PROPERTIES TIMEOUT 120)
py_test_modules(test_parallel_tuner_predict MODULES
test_parallel_tuner_predict ENVS ${dist_ENVS})
set_tests_properties(test_parallel_tuner_predict PROPERTIES TIMEOUT 120)
endif()
......@@ -88,33 +88,27 @@ class TestAMPPass(unittest.TestCase):
def test_amp_pass(self):
# mp2 training
mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"])
history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(history.history["loss"])
# mp2 amp-o1 training
amp_o1_engine = self.get_engine(True, "o1")
amp_o1_losses = amp_o1_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o1_losses = np.array(amp_o1_losses["loss"])
history = amp_o1_engine.fit(self.dataset, 3, batch_size=self.batch_size)
amp_o1_losses = np.array(history.history["loss"])
amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o1_losses)
# mp2 amp-o2 training
amp_o2_engine = self.get_engine(True, "o2")
amp_o2_losses = amp_o2_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o2_losses = np.array(amp_o2_losses["loss"])
history = amp_o2_engine.fit(self.dataset, 3, batch_size=self.batch_size)
amp_o2_losses = np.array(history.history["loss"])
amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o2_losses)
# mp2 amp-o3 training
amp_o3_engine = self.get_engine(True, "o3")
amp_o3_losses = amp_o3_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o3_losses = np.array(amp_o3_losses["loss"])
history = amp_o3_engine.fit(self.dataset, 3, batch_size=self.batch_size)
amp_o3_losses = np.array(history.history["loss"])
amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o3_losses)
......
......@@ -20,6 +20,8 @@ import os
import numpy as np
import subprocess
import paddle
import paddle.static as static
import paddle.utils as utils
import paddle.nn as nn
import paddle.fluid as fluid
import paddle.static as static
......@@ -29,14 +31,17 @@ from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.distributed.fleet import auto
from paddle.distributed.auto_parallel.interface import get_collection, CollectionNames
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn
paddle.enable_static()
global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
PP_MESH_0 = auto.ProcessMesh([0])
PP_MESH_1 = auto.ProcessMesh([1])
batch_size = 1
epoch_num = 1
batch_size = 2
batch_num = 10
hidden_size = 1024
sequence_len = 512
......@@ -46,6 +51,8 @@ class_num = 10
paddle.seed(44)
is_fetch = True
is_feed = True
my_feed_vars = []
class MyDataset(Dataset):
......@@ -63,6 +70,23 @@ class MyDataset(Dataset):
return self.num_samples
def get_random_inputs_and_labels(image_shape, label_shape):
input = np.random.random(size=image_shape).astype('float32')
label = np.random.random(size=label_shape).astype('int64')
return input, label
def batch_generator_creator():
def __reader__():
for _ in range(batch_num):
batch_input, batch_label = get_random_inputs_and_labels(
[batch_size, image_size], [batch_size, 1])
yield batch_input, batch_label
return __reader__
class MLPLayer(nn.Layer):
def __init__(self,
......@@ -92,16 +116,20 @@ class MLPLayer(nn.Layer):
def forward(self, input):
out = auto.shard_op(self.norm, PP_MESH_0)(input)
out = self.linear0(out)
if is_feed:
my_feed_vars.append((out, out.shape))
out = F.gelu(out, approximate=True)
out = auto.shard_op(self.linear1, PP_MESH_1)(out)
out = self.dropout(out)
out = self.linear2(out)
if is_feed:
my_feed_vars.append((out, out.shape))
if is_fetch:
auto.fetch(out, "out")
auto.fetch(out, "my_fetch", logging=True)
return out
def train(fetch):
def train_high_level(fetch):
global is_fetch
is_fetch = fetch
mlp = MLPLayer(hidden_size=hidden_size,
......@@ -124,10 +152,12 @@ def train(fetch):
# train
train_dataset = MyDataset(batch_num * batch_size)
eval_dataset1 = MyDataset(5 * batch_size)
engine.fit(train_data=train_dataset,
epochs=2,
batch_size=batch_size,
valid_data=eval_dataset1)
history = engine.fit(train_data=train_dataset,
epochs=2,
batch_size=batch_size,
valid_data=eval_dataset1,
log_freq=1)
# eval
eval_dataset2 = MyDataset(batch_size)
......@@ -135,7 +165,7 @@ def train(fetch):
# predict
test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size=batch_size)
outputs = engine.predict(test_dataset, batch_size=batch_size)
# save
temp_dir = tempfile.TemporaryDirectory()
......@@ -145,6 +175,265 @@ def train(fetch):
temp_dir.cleanup()
def train_low_level():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(mlp, loss, optimizer, metrics=None, strategy=strategy)
feed_dict = {}
for feed_var, shape in my_feed_vars:
feed_dict[feed_var.name] = np.zeros(shape, dtype="float32")
# Build a normal dataloader
# train
train_dataset = MyDataset(batch_num * batch_size)
train_dataloader = engine.dataloader(train_dataset,
batch_size=batch_size,
mode="train")
engine.prepare(mode="train")
for data in train_dataloader:
outs = engine.run(data, feed=feed_dict, mode="train")
# eval
eval_dataset2 = MyDataset(batch_size)
eval_dataloader = engine.dataloader(eval_dataset2,
batch_size=batch_size,
mode="eval")
engine.prepare(mode="eval")
for data in eval_dataloader:
outs = engine.run(data, feed=feed_dict, mode="eval")
# predict
engine.to_mode("predict")
test_dataset = MyDataset(batch_size)
predict_dataloader = engine.dataloader(test_dataset, batch_size=batch_size)
engine.prepare()
for data in predict_dataloader:
outs = engine.run(data, feed=feed_dict)
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp')
engine.save(model_filename, training=True)
engine.load(model_filename)
temp_dir.cleanup()
# Build dataloader from generator
# train
train_dataset = MyDataset(batch_num * batch_size)
train_dataloader = engine.dataloader_from_generator(train_dataset,
batch_size=batch_size,
mode="train")
engine.prepare(mode="train")
for data in train_dataloader:
outs = engine.run(data, feed=feed_dict, mode="train")
# eval
engine.to_mode("eval")
eval_dataset2 = MyDataset(batch_size)
eval_dataloader = engine.dataloader_from_generator(eval_dataset2,
batch_size=batch_size)
engine.prepare()
for data in eval_dataloader:
outs = engine.run(data, feed=feed_dict)
# predict
test_dataset = MyDataset(batch_size)
predict_dataloader = engine.dataloader_from_generator(test_dataset,
batch_size=batch_size,
mode="predict")
engine.prepare(mode="predict")
for data in predict_dataloader:
outs = engine.run(data, feed=feed_dict, mode="predict")
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp')
engine.save(model_filename, training=True)
engine.load(model_filename)
temp_dir.cleanup()
def train_builtin_data_vars():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
# train
engine.to_mode("train")
input_spec = static.InputSpec([batch_size, image_size], 'float32', 'input')
label_spec = static.InputSpec([batch_size, 1], 'int64', 'label')
engine.prepare(inputs_spec=[input_spec], labels_spec=[label_spec])
with static.program_guard(engine.main_program, engine.startup_program):
feed_list = engine.inputs + engine.labels
print(feed_list)
loader = paddle.io.DataLoader.from_generator(feed_list=feed_list,
capacity=4 * batch_size,
iterable=False)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
for _ in range(epoch_num):
loader.start() # call DataLoader.start() before each epoch starts
try:
while True:
engine.run()
except paddle.fluid.core.EOFException:
loader.reset(
) # call DataLoader.reset() after catching EOFException
def train_non_builtin_data_vars():
main_program = static.Program()
startup_program = static.Program()
with static.program_guard(main_program,
startup_program), utils.unique_name.guard():
input = static.data(name="input",
shape=[batch_size, image_size],
dtype='float32')
label = static.data(name="label", shape=[batch_size, 1], dtype='int64')
loader = paddle.io.DataLoader.from_generator(feed_list=[input, label],
capacity=4 * batch_size,
iterable=False)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
predict = mlp(input)
loss_var = loss(predict, label)
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(loss=loss_var,
optimizer=optimizer,
metrics=metric,
strategy=strategy)
# train
engine.to_mode("train")
engine.prepare(inputs=[input],
labels=[label],
main_program=main_program,
startup_program=startup_program)
for _ in range(epoch_num):
loader.start() # call DataLoader.start() before each epoch starts
try:
while True:
engine.run()
except paddle.fluid.core.EOFException:
loader.reset(
) # call DataLoader.reset() after catching EOFException
def get_cost():
main_program = static.default_main_program()
startup_program = static.default_startup_program()
with static.program_guard(main_program,
startup_program), utils.unique_name.guard():
input = static.data(name="input",
shape=[batch_size, image_size],
dtype='float32')
label = static.data(name="label", shape=[batch_size, 1], dtype='int64')
loader = paddle.io.DataLoader.from_generator(feed_list=[input, label],
capacity=4 * batch_size,
iterable=False)
places = static.cuda_places()
loader.set_batch_generator(batch_generator_creator(), places=places)
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
predict = mlp(input)
loss_var = loss(predict, label)
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(loss=loss_var,
optimizer=optimizer,
metrics=metric,
strategy=strategy)
engine.cost()
def get_cost_by_spec():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
input_spec = static.InputSpec([batch_size, image_size], 'float32', 'input')
label_spec = static.InputSpec([batch_size, 1], 'int64', 'label')
engine.cost(mode="eval", inputs_spec=[input_spec], labels_spec=[label_spec])
if __name__ == "__main__":
train(fetch=True)
train(fetch=False)
train_high_level(fetch=True)
train_high_level(fetch=False)
train_low_level()
train_builtin_data_vars()
train_non_builtin_data_vars()
get_cost()
get_cost_by_spec()
......@@ -84,25 +84,32 @@ class TestGradientMergePass(unittest.TestCase):
def test_gradient_merge_pass(self):
# dp2 training
dp_engine = self.get_engine()
dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(dp_losses["loss"])
history = dp_engine.fit(self.dataset,
3,
batch_size=self.batch_size,
log_freq=1)
dp_losses = np.array(history.history["loss"])
# dp2 gradient merge training
gm_engine = self.get_engine(True)
gm_losses = gm_engine.fit(self.dataset, 3, batch_size=self.batch_size)
gm_losses = np.array(gm_losses["loss"])
avg_loss = 0
pass_avg_ret_list = []
for i, pass_ret in enumerate(gm_losses):
if (i + 1) % 4 == 0:
avg_loss += pass_ret
pass_avg_ret_list.append(avg_loss / 4)
avg_loss = 0
else:
avg_loss += pass_ret
self.check_results(dp_losses, np.array(pass_avg_ret_list))
history = gm_engine.fit(self.dataset,
3,
batch_size=self.batch_size,
log_freq=1)
gm_losses = np.array(history.history["loss"])
# avg_loss = 0
# pass_avg_ret_list = []
# for i, pass_ret in enumerate(gm_losses):
# if (i + 1) % 4 == 0:
# avg_loss += pass_ret
# pass_avg_ret_list.append(avg_loss / 4)
# avg_loss = 0
# else:
# avg_loss += pass_ret
# NOTE: every sample drawn from the dataset is identical, so the dp and gm losses match without averaging
self.check_results(dp_losses, gm_losses)
if __name__ == "__main__":
......
......@@ -79,13 +79,13 @@ class TestRecomputePass(unittest.TestCase):
def test_recompute_pass(self):
# mp2 training
mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"])
history = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(history.history["loss"])
# mp2 recompute training
rc_engine = self.get_engine(True)
rc_losses = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc_losses = np.array(rc_losses["loss"])
history = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc_losses = np.array(history.history["loss"])
self.check_results(mp_losses, rc_losses)
......
......@@ -84,31 +84,31 @@ class TestShardingPass(unittest.TestCase):
def test_sharding_pass(self):
# dp2 training
dp_engine = self.get_engine()
dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(dp_losses["loss"])
history = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(history.history["loss"])
# sharding2 stage1 training
sharding1_engine = self.get_engine(True, 1)
sharding1_losses = sharding1_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding1_losses = np.array(sharding1_losses["loss"])
history = sharding1_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding1_losses = np.array(history.history["loss"])
self.check_results(dp_losses, sharding1_losses)
# sharding2 stage2 training
sharding2_engine = self.get_engine(True, 2)
sharding2_losses = sharding2_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding2_losses = np.array(sharding2_losses["loss"])
history = sharding2_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding2_losses = np.array(history.history["loss"])
self.check_results(dp_losses, sharding2_losses)
# sharding2 stage3 training
sharding3_engine = self.get_engine(True, 3)
sharding3_losses = sharding3_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding3_losses = np.array(sharding3_losses["loss"])
history = sharding3_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding3_losses = np.array(history.history["loss"])
self.check_results(dp_losses, sharding3_losses)
......
......@@ -82,6 +82,9 @@ from paddle.distributed.auto_parallel.cost.comp_op_cost import Transpose2OpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import Transpose2GradOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import Unsqueeze2OpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import WriteToArrayOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import DropoutGradOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import FusedSoftmaxMaskUpperTriangleOpCost
from paddle.distributed.auto_parallel.cost.comp_op_cost import FusedSoftmaxMaskUpperTriangleGradOpCost
from test_cluster import cluster_json
......@@ -417,6 +420,22 @@ class TestCompOpCost(unittest.TestCase):
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)
op_cost = DropoutGradOpCost(cluster=cluster)
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)
op_cost = FusedSoftmaxMaskUpperTriangleOpCost(cluster=cluster)
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)
op_cost = FusedSoftmaxMaskUpperTriangleGradOpCost(cluster=cluster)
self.assertTrue(op_cost.flops >= 0)
self.assertTrue(op_cost.time >= 0)
self.assertTrue(op_cost.memory >= 0)
# Remove unnecessary files
if os.path.exists(cluster_json_path):
os.remove(cluster_json_path)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.static import InputSpec
from paddle.distributed.fleet import auto
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=64,
intermediate_size=4 * 64,
initializer_range=0.02):
super(MLPLayer, self).__init__()
self.norm = nn.LayerNorm(hidden_size, epsilon=1e-5)
self.linear0 = nn.Linear(
hidden_size,
intermediate_size,
paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)),
bias_attr=None)
self.linear1 = nn.Linear(
intermediate_size,
hidden_size,
paddle.ParamAttr(initializer=nn.initializer.Normal(
mean=0.0, std=initializer_range)),
bias_attr=None)
def forward(self, input):
out = self.norm(input)
auto.shard_tensor(self.linear0.weight, auto.ProcessMesh([0, 1], "x"),
[None, "x"])
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(self.linear1.weight, auto.ProcessMesh([0, 1], "x"),
["x", None])
out = self.linear1(out)
if paddle.mean(out) < 2:
out = self.norm(out)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
else:
out = self.norm(out)
out = self.linear0(out)
out = self.linear1(out)
return out
def loss_fn(predict, label):
error_cost = paddle.nn.functional.square_error_cost(predict, label)
loss = paddle.mean(error_cost)
return loss
class TestSubblock(unittest.TestCase):
def test_subblock(self):
mlp = MLPLayer()
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(model=mlp, loss=loss_fn, strategy=strategy)
input_spec = InputSpec([4, 64], 'float32', 'input')
label_spec = InputSpec([4, 1], 'float32', 'label')
engine.prepare(inputs_spec=[input_spec],
labels_spec=[label_spec],
mode="predict")
if __name__ == "__main__":
unittest.main()
......@@ -199,7 +199,7 @@ class TestDistributedContext(unittest.TestCase):
"_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \
"_serial_ordered_op_nodes", "_original_serial_loss", \
"_original_serial_feed_vars", "_original_serial_fetch_vars", \
"_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_lr_optimizer", \
"_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_serial_optimizer", \
"_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", \
"_pass_context"]
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.