Unverified commit c5cc4278 authored by Yulong Ao, committed by GitHub

[Cherry-pick][Auto Parallel] Improve the APIs (#46164)

* [AutoParallel] adapt gradient merge pass (#45915)

* adapt gradient merge

* fix op_role

* fix strategy

* [Auto Parallel] Gradient Fuse Allreduce (#45643)

* bugfix (#45332)

* dist embedding support lookup table v1

* add unittest

* customize wait_comm

* group gradients

* bugfix

* update program

* [Auto Parallel] Improve the APIs (#45776)

* [Auto Parallel] Use c++ dist attr in the completion process

* [Auto Parallel] Add minor changes

* [Auto Parallel] Use c++ dist attr in the completion process

* [Auto Parallel] Add minor changes

* [Auto Parallel] Add the serialization process for dist attrs

* [Auto Parallel] Remove unnecessary comments

* [Auto Parallel] Fix some bugs

* [Auto Parallel] Fix the code style

* [Auto Parallel] Remove unnecessary impls

* [Auto Parallel] Fix the importing error

* [Auto Parallel] Fix the copy from bugs of op dist attr

* [Auto Parallel] Replace the use of constexpr if

* [Auto Parallel] Redesign the shard_tensor, shard_op and ProcessMesh

* [Auto Parallel] Change API of the completion unittest

* [Auto Parallel] Fix the bug when set_attr an int

* [Auto Parallel] Add the unittest for the serialization

* [Auto Parallel] Add some unit tests

* [Auto Parallel] Unify the strategy

* [Auto Parallel] Improve the engine api

* [Auto Parallel] Reset the changes made to the framework

* [Auto Parallel] Change the engine unittest

* [Auto Parallel] Update API of the completion and partitioner

* [Auto Parallel] Update unit tests using engine api

* update shard annotation

* [Auto Parallel] Remove the modifications of other modules

* [Auto Parallel] Add docs for APIs

* add new strategy

* [Auto Parallel] Replace the logger

* [Auto Parallel] Restore the test_program.py

* [Auto Parallel] Change the import rules

* [Auto Parallel] Add the examples for Engine

* [Auto Parallel] Do some minor changes

* [Auto Parallel] Remove yaml dependency

* [Auto Parallel] Fix the unittests

* add valid after train

* bug fix
Co-authored-by: zhaoyingli <zhaoyingli@baidu.com>
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Co-authored-by: caozhou <48191911+Caozhou1995@users.noreply.github.com>

* [Auto Parallel] Bugfix allreduce fuse for MP (#46086)

* bugfix

* bugfix

* typos fixed

* update strategy (#46138)
Co-authored-by: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com>
Co-authored-by: JZ-LIANG <jianzhongliang10@gmail.com>
Co-authored-by: zhaoyingli <zhaoyingli@baidu.com>
Co-authored-by: caozhou <caozhou@radi.ac.cn>
Co-authored-by: caozhou <48191911+Caozhou1995@users.noreply.github.com>
Parent 860f6077
......@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .interface import shard_tensor # noqa: F401
from .interface import shard_op # noqa: F401
from .strategy import Strategy
from .process_mesh import ProcessMesh
from .reshard import Resharder # noqa: F401
from .cost_model import estimate_cost
from .engine import Engine
from .interface import shard_tensor
from .interface import shard_op
from .interface import recompute
from .interface import fetch
__all__ = []
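
The reworked `__init__.py` exposes the redesigned primitives directly from `paddle.distributed.auto_parallel`. A minimal sketch of the new entry points, based on the `ProcessMesh` and `shard_tensor` signatures changed in this PR (the mesh and tensor shapes are illustrative):

import paddle
import paddle.distributed.auto_parallel as auto

paddle.enable_static()

# A 2x2 logical process mesh with named dimensions (illustrative).
mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])

# Shard dim 0 of the tensor across mesh dim "x" and replicate dim 1.
x = paddle.ones([4, 6])
x = auto.shard_tensor(x, process_mesh=mesh, shard_spec=["x", None])
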
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from collections import defaultdict
# _g_default_config[category][field] = default_value
_g_default_config = defaultdict(dict)
def get_category_default_config(category):
return _g_default_config[category]
def set_category_default_config(category, default_value):
_g_default_config[category] = default_value
def get_field_default_config(category, field):
return _g_default_config[category][field]
def set_field_default_config(category, field, default_value):
_g_default_config[category][field] = default_value
NOT_FOUND = "not_found"
#########################################
# base configuration
#########################################
BASE = "base"
set_field_default_config(BASE, "auto_mode", "semi")
set_field_default_config(BASE, "gradient_scale", True)
set_field_default_config(BASE, "use_cache", True)
set_field_default_config(BASE, "return_numpy", True)
set_field_default_config(BASE, "all_ranks", False)
set_field_default_config(BASE, "split_data", False)
set_field_default_config(BASE, "seed", None)
set_field_default_config(BASE, "reinit", False) # Only for debug
#########################################
# recompute configuration
#########################################
RECOMPUTE = "recompute"
set_field_default_config(RECOMPUTE, "enable", False)
set_field_default_config(RECOMPUTE, "checkpoints", None)
set_field_default_config(RECOMPUTE, "enable_tuning", False)
#########################################
# AMP configuration
#########################################
AMP = "amp"
set_field_default_config(AMP, "enable", False)
set_field_default_config(AMP, "init_loss_scaling", 32768.0)
set_field_default_config(AMP, "incr_every_n_steps", 1000)
set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2)
set_field_default_config(AMP, "incr_ratio", 2.0)
set_field_default_config(AMP, "decr_ratio", 0.8)
set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
set_field_default_config(AMP, "custom_white_list", [])
set_field_default_config(AMP, "custom_black_list", [])
set_field_default_config(AMP, "custom_black_varnames", [])
set_field_default_config(AMP, "use_pure_fp16", False)
set_field_default_config(AMP, "use_fp16_guard", True)
set_field_default_config(AMP, "use_optimizer_fp16", False)
#########################################
# sharding configuration
#########################################
SHARDING = "sharding"
set_field_default_config(SHARDING, "enable", False)
set_field_default_config(SHARDING, "stage", 1)
set_field_default_config(SHARDING, "degree", 8)
set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0)
set_field_default_config(SHARDING, "enable_tuning", False)
set_field_default_config(SHARDING, "tuning_range", [])
#########################################
# gradient merge configuration
#########################################
GRADIENT_MERGE = "gradient_merge"
set_field_default_config(GRADIENT_MERGE, "enable", False)
set_field_default_config(GRADIENT_MERGE, "k_steps", 1)
set_field_default_config(GRADIENT_MERGE, "avg", True)
#########################################
# quantization configuration
#########################################
QAT = "qat"
set_field_default_config(QAT, "enable", False)
set_field_default_config(QAT, "channel_wise_abs_max", True)
set_field_default_config(QAT, "weight_bits", 8)
set_field_default_config(QAT, "activation_bits", 8)
set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
set_field_default_config(QAT, "algo", None)
# #########################################
# auto tuning configuration
# #########################################
TUNING = "tuning"
set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "batch_size", 1)
set_field_default_config(TUNING, "dataset", None)
set_field_default_config(TUNING, "profile_start_step", 1)
set_field_default_config(TUNING, "profile_end_step", 1)
set_field_default_config(TUNING, "run_after_tuning", True)
set_field_default_config(TUNING, "verbose", True)
......@@ -16,7 +16,7 @@ import paddle
import warnings
import logging
import numpy as np
from ..utils import get_logger
from .utils import get_logger
class Converter(object):
......
......@@ -173,6 +173,17 @@ class TensorDistributedAttribute:
def clear_annotated(self):
self._is_annotated.clear()
def __eq__(self, other):
if not isinstance(other, TensorDistributedAttribute):
return False
if self.process_mesh != other.process_mesh:
return False
if self.dims_mapping != other.dims_mapping:
return False
if self._is_annotated != other._is_annotated:
return False
return True
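
The new `__eq__` makes two tensor dist attrs equal only when the process mesh, the dims mapping, and the annotation marks all match. A rough sketch of the intended semantics (default construction and direct field assignment are assumed here):

from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute

a = TensorDistributedAttribute()
b = TensorDistributedAttribute()
a.process_mesh = [[0, 1], [2, 3]]
b.process_mesh = [[0, 1], [2, 3]]
a.dims_mapping = [0, -1]
b.dims_mapping = [0, -1]
assert a == b

a.mark_annotated("dims_mapping")   # annotation marks participate in the comparison
assert a != b
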
def __str__(self):
str = "\n\ttensor_dist_attr = {"
if self.is_annotated("process_mesh"):
......@@ -486,6 +497,27 @@ class OperatorDistributedAttribute:
else:
return False
def __eq__(self, other):
if not isinstance(other, OperatorDistributedAttribute):
return False
if self.process_mesh != other.process_mesh:
return False
if self.op_type != other.op_type:
return False
if self.impl_type != other.impl_type:
return False
if self.impl_idx != other.impl_idx:
return False
if self._is_annotated != other._is_annotated:
return False
if self._is_recompute != other._is_recompute:
return False
if self.inputs_dist_attrs != other.inputs_dist_attrs:
return False
if self.outputs_dist_attrs != other.outputs_dist_attrs:
return False
return True
def __str__(self):
str = "\n\top_dist_attr = {"
if self.is_annotated("process_mesh"):
......
......@@ -126,9 +126,6 @@ class DistributedContext:
# A flag indicates whether the used parallelism is data parallel
self._data_parallel = False
# flag whether using `to_static`
self._dygraph_mode = False
@property
def serial_main_program(self):
return self._serial_main_program
......
......@@ -23,6 +23,7 @@ from .dist_attribute import append_op_input_suffix
from .dist_attribute import append_op_output_suffix
from .dist_attribute import get_tensor_dist_attr_field_keys
from .dist_attribute import get_op_dist_attr_field_keys
from .utils import convert_to_shard_spec, verify_shard_spec
class DistributedOperator:
......@@ -248,23 +249,106 @@ class DistributedOperator:
return result
class DistributedModule:
class DistributedOperatorHelper:
def __init__(self, serial_module, dist_attr=None):
self._serial_module = serial_module
self._dist_attr = dist_attr
def __init__(self, serial_op, process_mesh, in_dims_mappings,
out_dims_mappings):
self._serial_op = serial_op
self._process_mesh = process_mesh
self._in_dims_mappings = in_dims_mappings
self._out_dims_mappings = out_dims_mappings
def __call__(self, *args, **kwargs):
from .dist_context import get_default_distributed_context
tensor_to_dims_mapping = {}
index = 0
if self._in_dims_mappings:
assert len(args) + len(kwargs) == len(self._in_dims_mappings), \
"The length of dims_mapping {} does not matching the length output {}.".format(len(self._in_dims_mappings), len(args) + len(kwargs))
for arg in args:
if isinstance(arg, Variable) and self._in_dims_mappings:
tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
index += 1
for arg in kwargs.values():
if isinstance(arg, Variable) and self._in_dims_mappings:
tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index]
index += 1
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
op_size = len(cur_block.ops)
output = self._serial_module(*args, **kwargs)
output = self._serial_op(*args, **kwargs)
new_op_size = len(cur_block.ops)
if isinstance(output, tuple) or isinstance(output, list):
new_output = list(output)
elif isinstance(output, Variable):
new_output = [output]
else:
raise ValueError("Unrecognized outpout.")
if self._out_dims_mappings:
assert len(new_output) == len(self._out_dims_mappings), \
"The length of dims_mapping {} does not matching the length output {}.".format(len(self._out_dims_mappings), len(new_output))
for i, item in enumerate(new_output):
if isinstance(item, Variable) and self._out_dims_mappings:
tensor_to_dims_mapping[item.name] = self._out_dims_mappings[i]
from .dist_context import get_default_distributed_context
default_dist_ctx = get_default_distributed_context()
for idx in range(op_size, new_op_size):
op = cur_block.ops[idx]
dist_op = DistributedOperator(op, self._dist_attr)
dist_op.dist_attr.mark_annotated_as(self._dist_attr)
dist_op = DistributedOperator(op)
for name in dist_op.serial_op.input_arg_names:
if name in tensor_to_dims_mapping.keys():
tensor = dist_op.get_serial_input(name)
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(
name)
dims_mapping = tensor_to_dims_mapping[name]
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = tensor.shape
if dims_mapping is not None:
dims_mapping = tensor_to_dims_mapping[name]
shard_spec = convert_to_shard_spec(
dims_mapping, self._process_mesh)
assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh)
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.mark_annotated("dims_mapping")
for name in dist_op.serial_op.output_arg_names:
if name in tensor_to_dims_mapping.keys():
tensor = dist_op.get_serial_output(name)
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
name)
dims_mapping = tensor_to_dims_mapping[name]
if tensor is None:
tensor_shape = []
else:
if tensor.type == core.VarDesc.VarType.READER \
or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = tensor.shape
if dims_mapping is not None:
dims_mapping = tensor_to_dims_mapping[name]
shard_spec = convert_to_shard_spec(
dims_mapping, self._process_mesh)
assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
name, shard_spec, tensor_shape, self._process_mesh)
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.mark_annotated("dims_mapping")
dist_op.dist_attr.process_mesh = self._process_mesh
if self._process_mesh is not None:
dist_op.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_op_for_program(dist_op)
return output
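
`DistributedOperatorHelper` replaces the former `DistributedModule`: it wraps a serial callable together with a process mesh and per-input/output dims mappings, and annotates every op recorded while the callable runs. A hedged usage sketch (in practice the helper is produced by the redesigned `shard_op` interface rather than built by hand; the callable, mesh and mappings below are illustrative):

import paddle
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.dist_op import DistributedOperatorHelper

paddle.enable_static()
mesh = ProcessMesh([[0, 1], [2, 3]])

x = paddle.static.data(name="x", shape=[4, 6], dtype="float32")
y = paddle.static.data(name="y", shape=[6, 8], dtype="float32")

# Split x's first dim across mesh dim 0, replicate y, and shard the output the same way.
helper = DistributedOperatorHelper(paddle.matmul, mesh,
                                   in_dims_mappings=[[0, -1], [-1, -1]],
                                   out_dims_mappings=[[0, -1]])
out = helper(x, y)   # runs matmul and records an annotated DistributedOperator per new op
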
......@@ -27,7 +27,7 @@ from paddle.fluid.framework import static_only
from .utils import get_dist_attr
from .converter import Converter
from .process_group import _g_process_group_map
from ..utils import get_logger
from .utils import get_logger
def check_filename(re_exp, filename):
......@@ -59,6 +59,14 @@ class DistributedSaver:
def save(self, path, serial_program, dist_main_program, dist_context):
def _save_state(program, path, mode="param"):
state = {
k: np.array(v)
for k, v in program.state_dict(mode).items()
}
with open(path, "wb") as f:
pickle.dump(state, f)
dirname, filename = _process_path(path)
rank_id = paddle.distributed.get_rank()
......@@ -76,16 +84,6 @@ class DistributedSaver:
with open(dist_model_path, "wb") as f:
f.write(dist_main_program.desc.serialize_to_string())
# save distributed params
dist_param_filename = filename + "_dist" + str(rank_id) + ".pdparams"
dist_param_path = os.path.join(dirname, dist_param_filename)
dist_param = {
k: np.array(v)
for k, v in dist_main_program.state_dict().items()
}
with open(dist_param_path, "wb") as f:
pickle.dump(dist_param, f)
# save distributed attribute
dist_attr_filename = filename + "_dist" + str(rank_id) + ".pdattr"
dist_attr_path = os.path.join(dirname, dist_attr_filename)
......@@ -93,65 +91,69 @@ class DistributedSaver:
with open(dist_attr_path, "wb") as f:
pickle.dump(dist_attrs, f)
# save distributed params
dist_param_filename = filename + "_dist" + str(rank_id) + ".pdparams"
dist_param_path = os.path.join(dirname, dist_param_filename)
_save_state(dist_main_program, dist_param_path)
# save distributed opt states
dist_opt_filename = filename + "_dist" + str(rank_id) + ".pdopt"
dist_opt_path = os.path.join(dirname, dist_opt_filename)
_save_state(dist_main_program, dist_opt_path, "opt")
# TODO:save cluster.json
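
After the refactor, every rank writes its own parameter, optimizer-state and dist-attr files next to the serialized program, all derived from the same "dirname/filename" prefix. A small sketch of the resulting names for one rank (the prefix and rank below are illustrative):

import os

path = "./my_model/mlp"      # user-supplied "dirname/filename" prefix
dirname, filename = os.path.split(path)
rank_id = 1                  # illustrative rank

dist_param_path = os.path.join(dirname, filename + "_dist" + str(rank_id) + ".pdparams")
dist_opt_path = os.path.join(dirname, filename + "_dist" + str(rank_id) + ".pdopt")
dist_attr_path = os.path.join(dirname, filename + "_dist" + str(rank_id) + ".pdattr")
# e.g. dist_param_path == "./my_model/mlp_dist1.pdparams"
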
def load(self,
path,
program,
dist_context,
strict=True,
load_optimizer=True):
def load(self, path, load_optimizer=True):
# TODO: if `program` is None, load `path.pdmodel`.
def _load_file(filename, dirname, suffix="pdparams"):
file_list = []
for file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).{}'.format(filename, suffix),
file):
file_list.append(os.path.join(dirname, file))
file_list.sort()
return file_list
def _load_state(filename, dirname, suffix="pdparams"):
file_list = _load_file(filename, dirname, suffix)
state_dict = {}
for file in file_list:
with open(file, 'rb') as f:
state_dict_info = pickle.load(f, encoding='latin1')
for name, value in state_dict_info.items():
if name in state_dict:
state_dict[name].append(np.array(value))
else:
state_dict[name] = [np.array(value)]
self._logger.info("Load param file: {}".format(file_list))
return state_dict
filename = os.path.basename(path)
if filename == "":
raise ValueError(
"path should be of 'dirname/filename' format, but received filename is empty string"
)
dirname = os.path.dirname(path)
# load path.pdparam
param_file_list = []
for param_file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).pdparams'.format(filename),
param_file):
param_file_list.append(os.path.join(dirname, param_file))
param_file_list.sort()
self._logger.info(
"Load distributed attribute file: {}".format(param_file_list))
param_dict = {}
for param_file in param_file_list:
with open(param_file, 'rb') as f:
state_dict_info = pickle.load(f, encoding='latin1')
for name, value in state_dict_info.items():
if name in param_dict:
param_dict[name].append(np.array(value))
else:
param_dict[name] = [np.array(value)]
# load path.pdparam and path.pdopt
param_state_dict = _load_state(filename, dirname)
opt_state_dict = _load_state(filename, dirname,
"pdopt") if load_optimizer else {}
state_dict = dict(param_state_dict, **opt_state_dict)
# load path.pdattr
dist_attr_file_list = []
for dist_attr_file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).pdattr'.format(filename),
dist_attr_file):
dist_attr_file_list.append(os.path.join(dirname,
dist_attr_file))
dist_attr_file_list.sort()
dist_attr_file_list = _load_file(filename, dirname, "pdattr")
self._logger.info(
"Load distributed attribute file: {}".format(dist_attr_file_list))
pre_dist_attr = {}
dist_attr = {}
for dist_attr_file in dist_attr_file_list:
with open(dist_attr_file, 'rb') as f:
dist_attr = pickle.load(f, encoding='latin1')
for name, attr in dist_attr.items():
if name not in pre_dist_attr:
pre_dist_attr[name] = attr
# get current dist_attr
cur_dist_attr = get_dist_attr(program, dist_context)
# param convert
converter = Converter(param_dict, pre_dist_attr, cur_dist_attr)
param_dict = converter.convert(strict=strict)
program.set_state_dict(param_dict)
dist_attr_info = pickle.load(f, encoding='latin1')
for name, attr in dist_attr_info.items():
if name not in dist_attr:
dist_attr[name] = attr
return state_dict, dist_attr
def save_inference_model(self, path, feed_vars, fetch_vars, exe, **kwargs):
......
......@@ -12,80 +12,169 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
import copy
import logging
import random
import numpy as np
from collections import defaultdict
import paddle
import paddle.utils as utils
from paddle import fluid, static
from paddle.io import Dataset
from paddle.jit import to_static
from paddle.metric import Metric
from paddle.static import InputSpec
from paddle.fluid import core
from paddle.fluid import program_guard
from paddle.fluid import Variable
from paddle.fluid.layers.utils import flatten
from paddle.fluid.executor import global_scope, _to_name_str
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Operator, Parameter, _non_static_mode
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import fleet
from paddle.distributed.passes import new_pass, PassContext
from .converter import Converter
from .helper import ProgramHelper
from ..collective import _get_global_env
from .cluster import Cluster, get_default_cluster
from .planner_v2 import Planner
from .parallelizer_v2 import Parallelizer
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import make_data_unshard, set_grad_var_shape
from .utils import print_program_with_dist_attr, to_list
from .process_group import new_process_group, get_all_process_groups, get_world_process_group
from .utils import get_logger, get_dist_attr
from .process_group import new_process_group, get_all_process_groups
from .dist_context import DistributedContext, get_default_distributed_context
from .strategy import Strategy
from .interface import _get_fetches
class Engine:
"""
An Engine object can provide the full power of auto parallel to users.
With its help, users can easily obtain the abilities of
distributed training and inference. It also supports the dynamic graph and
static graph at the same time.
Args:
model (paddle.nn.Layer, optional): The model is an instance of
paddle.nn.Layer.
loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer`
instance or any callable function taking the predicted values and
ground truth values as input. It can be None when there is no loss.
Default: None.
optimizer (Optimizer|None, optional): The optimizer needs to be set in training
and should be None in eval and predict mode. Default: None.
metrics (Metric|list[Metric]|None, optional): If metrics is set, all
metrics will be calculated and output in train/eval mode. Default: None.
cluster (Cluster|None, optional): The cluster represents the topology information
about the used physical devices. Default: None. (Unused for now)
strategy (Strategy|None, optional): The strategy is used to configure the
parallelization and optimization behaviors. Default: None.
Examples:
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.vision.datasets import MNIST
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = MNIST(mode='train', transform=transform)
valid_dataset = MNIST(mode='test', transform=transform)
model = paddle.vision.models.LeNet()
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
metrics = paddle.metric.Accuracy(topk=(1, 2))
engine = auto.Engine(model, loss, optimizer, metrics)
# fit
engine.fit(train_dataset,
epochs=2,
batch_size=64)
# evaluate
engine.evaluate(valid_dataset,
batch_size=64)
# predict
engine.predict(valid_dataset,
batch_size=64)
# save
engine.save("./my_model")
# load
engine.load("./my_model")
"""
def __init__(self,
model=None,
inputs_spec=None,
labels_spec=None,
loss=None,
optimizer=None,
metrics=None,
cluster=None,
strategy=None,
user_tuning_config=None):
self.model = model
self.inputs_spec = self._validate_spec(inputs_spec)
self.labels_spec = self._validate_spec(labels_spec)
self.cluster = cluster
if self.cluster is None:
self.cluster = get_default_cluster()
self.strategy = strategy
if self.strategy is None:
self.strategy = fleet.DistributedStrategy()
self._user_tuning_config = user_tuning_config
strategy=None):
if model and not isinstance(model,
paddle.nn.Layer) and not callable(model):
raise TypeError(
"'model must be sub classes of `paddle.nn.Layer` or any callable function."
)
self._model = model
if loss and not isinstance(loss,
paddle.nn.Layer) and not callable(loss):
raise TypeError(
"'loss' must be sub classes of `paddle.nn.Layer` or any callable function."
)
self._loss = loss
if optimizer and not isinstance(
optimizer,
(paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
raise TypeError(
"'optimizer' must be object of class `paddle.optimizer.Optimizer`"
" or `paddle.fluid.optimizer.Optimizer`.")
self._optimizer = self._validate_opt(optimizer)
metrics = metrics or []
for metric in to_list(metrics):
assert isinstance(metric, Metric), \
"{} is not sub class of Metric".format(
metric.__class__.__name__)
self._metrics = to_list(metrics)
if cluster and not isinstance(cluster, Cluster):
raise TypeError(
"'cluster' must be the object or class `paddle.distributed.auto_parallel.Cluster`"
)
self._cluster = cluster or get_default_cluster()
if strategy and not isinstance(strategy, Strategy):
raise TypeError(
"'strategy' must be object of class `paddle.distributed.auto_parallel.Strategy`"
)
self._strategy = strategy or Strategy()
if os.getenv("POD_NAME"):
print("Distribute training by paddle.distributed.launch",
flush=True)
fleet.init(is_collective=True)
self._executor = None
self._cur_rank = paddle.distributed.get_rank()
self._nranks = paddle.distributed.get_world_size()
self._saver = DistributedSaver()
# TODO: add logger module
self._logger = logging.getLogger()
self._logger.propagate = False
if not self._logger.handlers:
self._logger.setLevel(logging.INFO)
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
)
log_handler.setFormatter(log_format)
self._logger.addHandler(log_handler)
self._logger = get_logger(logging.INFO)
self._orig_main_prog = static.default_main_program()
self._orig_startup_prog = static.default_startup_program()
......@@ -103,54 +192,18 @@ class Engine:
"eval": False,
"predict": False
}
self._dygraph_mode = False
def prepare(self,
optimizer=None,
loss=None,
gradient_scale=True,
metrics=None,
all_ranks=False):
if optimizer and not isinstance(
optimizer,
(paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
raise TypeError(
"'optimizer' must be object of class `paddle.optimizer.Optimizer`" \
" or `paddle.fluid.optimizer.Optimizer`."
)
self._optimizer = self._validate_opt(optimizer)
if loss and not isinstance(loss,
paddle.nn.Layer) and not callable(loss):
raise TypeError(
"'loss' must be sub classes of `paddle.nn.Layer` or any callable function."
)
self._loss = loss
metrics = metrics or []
for metric in to_list(metrics):
assert isinstance(metric, Metric), \
"{} is not sub class of Metric".format(
metric.__class__.__name__)
self._metrics = to_list(metrics)
self._gradient_scale = gradient_scale
self._planned_mode = None
self._all_ranks = all_ranks
self._prepare_single_mode("train")
self._dygraph_mode = False
self._tuning = self._strategy.tuning
def _prepare_single_mode(self, mode):
# Do the build process
self._build(mode)
# Do the planning process
self._plan(mode)
# Do the Optimization tuning
if self._user_tuning_config and mode == "train":
self._optimization_tuning(mode)
# Do the parallel process
self._parallel(mode, self._all_ranks)
self._parallel(mode)
# Init comm and startup program
self._initialize(mode)
self._mode_init_states[mode] = True
......@@ -163,7 +216,7 @@ class Engine:
inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else []
self.program_helper = ProgramHelper(self.model, self._loss,
self.program_helper = ProgramHelper(self._model, self._loss,
self._metrics, inputs_spec,
labels_spec)
# build forward main program
......@@ -190,14 +243,13 @@ class Engine:
metrics = []
serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone()
# FIXME to support grad clip
with static.program_guard(serial_main_prog, serial_startup_prog), \
utils.unique_name.guard():
inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else []
inputs = [s._create_feed_layer() for s in inputs_spec]
labels = [s._create_feed_layer() for s in labels_spec]
outputs = to_list(self.model(*inputs))
outputs = to_list(self._model(*inputs))
if mode != "predict" and self._loss:
losses = to_list(self._loss(*(outputs + labels)))
......@@ -221,25 +273,30 @@ class Engine:
"metrics": metrics
}
if mode != "train":
serial_main_prog = serial_main_prog.clone(for_test=True)
self._set_recompute_ckpts()
self._dist_contexts[mode] = DistributedContext(
serial_main_prog, serial_startup_prog, self._optimizer, losses,
feed_vars, fetch_vars, self.cluster, self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale
self._dist_contexts[mode]._dygraph_mode = self._dygraph_mode
feed_vars, fetch_vars, self._cluster, self._strategy)
self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale
def _optimization_tuning(self, mode):
def _optimization_tuning(self, mode, dataset, batch_size):
if not self._tuning.enable:
raise ValueError("Please set `tuning.enable=True`.")
self.mode = mode
assert "batch_size" in self._user_tuning_config, "Optimization Tuning should provide with batch size."
assert "dataset" in self._user_tuning_config, "Optimization Tuning should provide with dataset."
batch_size = self._user_tuning_config["batch_size"]
dataset = self._user_tuning_config["dataset"]
dataset.dp_world_size = self.dp_world_sizes
dataset.dp_rank = self.dp_ranks
assert mode == "train"
# Do the build process
self._build(mode)
# Do the planning process
self._plan(mode)
dataset.dp_world_size = self._dp_world_sizes
dataset.dp_rank = self._dp_ranks
from .tuner.optimization_tuner import OptimizationTuner
self._optimization_tuner = OptimizationTuner(self._user_tuning_config,
self._optimization_tuner = OptimizationTuner(self._tuning.to_dict(),
self._dist_contexts[mode],
dataset,
self.inputs_spec,
......@@ -249,12 +306,10 @@ class Engine:
self._optimization_tuner.tune()
if self._user_tuning_config["run_after_tuning"]:
if self._tuning.run_after_tuning:
# update the strategy
self._dist_contexts[
mode]._strategy = self._optimization_tuner.get_best_config()
else:
return
def _plan(self, mode):
if self._planned_mode is None:
......@@ -274,15 +329,15 @@ class Engine:
if var.name in block.vars:
feed_list.append(block.vars[var.name])
self.dp_world_sizes = []
self.dp_ranks = []
self._dp_world_sizes = []
self._dp_ranks = []
for feed_var in feed_list:
dp_world_size, dp_rank = self._get_input_split_info(
feed_var, self._dist_contexts[mode])
self.dp_world_sizes.append(dp_world_size)
self.dp_ranks.append(dp_rank)
self._dp_world_sizes.append(dp_world_size)
self._dp_ranks.append(dp_rank)
def _parallel(self, mode, all_ranks):
def _parallel(self, mode, all_ranks=False):
# Parallelize program based on the planner's results
# For now, the completer has to be passed to the planner,
# because we may use it to complete the annotation of the backward and update.
......@@ -340,6 +395,11 @@ class Engine:
if isinstance(place, fluid.CUDAPlace):
place = fluid.CUDAPlace(ParallelEnv().dev_id)
if self._strategy.seed:
paddle.seed(self._strategy.seed + self._dp_ranks[0])
np.random.seed(self._strategy.seed + self._dp_ranks[0])
random.seed(self._strategy.seed + self._dp_ranks[0])
if self._dygraph_mode:
dist_context = self._dist_contexts[mode]
dist_main_program = self._dist_main_progs[mode][self._cur_rank]
......@@ -358,136 +418,303 @@ class Engine:
prune_startup_prog = dist_startup_prog._prune(uninitialized)
self._executor.run(prune_startup_prog)
if self.strategy.amp and self.strategy.amp_configs['use_pure_fp16']:
# from paddle.fluid.contrib.mixed_precision.fp16_utils import cast_parameters_to_fp16
def cast_parameters_to_fp16(place,
program,
scope=None,
to_fp16_var_names=None):
"""
Traverse all parameters in the whole model and set them to the FP16 data type.
However, this function will keep the parameters of batchnorms in FP32.
Args:
place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the FP16 weight tensors.
program (Program): The used program.
scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values.
Default is None.
to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names`
will be set to FP16. Usually, it is the returned
value of `cast_model_to_fp16` API.
"""
from paddle.framework import core
import numpy as np
all_parameters = []
for block in program.blocks:
all_parameters.extend(block.all_parameters())
var_scope = scope if scope else paddle.static.global_scope()
for param in all_parameters:
if param.dtype == core.VarDesc.VarType.FP16:
param_t = var_scope.find_var(
param.name).get_tensor()
data = np.array(param_t)
param_t.set(np.float16(data), place)
cast_parameters_to_fp16(place, prune_startup_prog)
if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"):
self._set_state_dict(mode, self._strict, self._state_dict,
self._dist_attr)
if self._strategy.reinit:
self._logger.info("NOTE: parameters wiil be re-initialized.")
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
self._executor.run(dist_startup_prog)
def _infer_sample_spec(self, data, batch_size, split):
if isinstance(data, paddle.io.IterableDataset):
if split is None:
input, label = next(iter(data))
else:
sample = next(iter(data))
input = sample[:split]
label = sample[split:]
elif isinstance(data, paddle.io.Dataset):
if split is None:
input, label = data[0]
else:
sample = data[0]
input = sample[:split]
label = sample[split:]
else:
raise ValueError(
"Data should be a Dataset or IterableDatset, but received {}.".
format(type(data).__name__))
self.inputs_spec = []
self.labels_spec = []
input_list = to_list(input)
label_list = to_list(label)
def _infer_item_spec(item, name, batch_size, specs):
if isinstance(item, np.ndarray):
spec = InputSpec.from_numpy(item, name)
if batch_size is None:
specs.append(spec)
else:
specs.append(spec.batch(batch_size))
elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)):
spec = InputSpec.from_tensor(item, name)
if batch_size is None:
specs.append(spec)
else:
specs.append(spec.batch(batch_size))
else:
specs.append(InputSpec([batch_size], type(item), name))
if input_list is not None:
for i, item in enumerate(input_list):
assert item is not None, "Receive None input."
name = "input" + str(i)
_infer_item_spec(item, name, batch_size, self.inputs_spec)
if label_list is not None:
for i, item in enumerate(label_list):
assert item is not None, "Receive None input."
name = "label" + str(i)
_infer_item_spec(item, name, batch_size, self.labels_spec)
self.inputs_spec = self._validate_spec(self.inputs_spec)
self.labels_spec = self._validate_spec(self.labels_spec)
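
`_infer_sample_spec` derives the static `InputSpec`s from the first sample of the dataset and batches them with the user's batch size. A hedged by-hand sketch of what it produces for a dataset whose samples are (image, label) numpy pairs (shapes are illustrative):

import numpy as np
from paddle.static import InputSpec

image = np.zeros([784], dtype="float32")   # illustrative flattened image
label = np.zeros([1], dtype="int64")
batch_size = 64

inputs_spec = [InputSpec.from_numpy(image, "input0").batch(batch_size)]
labels_spec = [InputSpec.from_numpy(label, "label0").batch(batch_size)]
# inputs_spec[0].shape is [64, 784]; labels_spec[0].shape is [64, 1]
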
def fit(self,
train_data,
train_sample_split=None,
batch_size=1,
epochs=1,
fetches=None,
steps_per_epoch=None,
valid_data=None,
valid_sample_split=None,
valid_freq=1,
valid_steps=None,
collate_fn=None,
use_cache=False,
return_numpy=True):
# TODO: callbacks
# TODO: evaluate after training
if not self._mode_init_states['train']:
raise Exception(
"train program is not initialized yet, please call engine.prepare() before calling fit() funtion."
)
callbacks=None):
"""
Trains the model for a fixed number of epochs. If `valid_data` is set,
evaluation will be done at the end of each epoch.
Args:
train_data (Dataset): An instance of paddle.io.Dataset. Default: None.
train_sample_split (int, optional): Each sample of the train dataset is assumed
to be an (input, label) pair by default and has two items. If each sample has
more than two items, train_sample_split specifies how to split these items into
input and label. The items before it are the input and the rest are the label. Default: None.
batch_size (int, optional): The batch size of train_data and valid_data if provided.
The user's data will be used directly without batching if set to None. Default: 1.
epochs (int, optional): The number of epochs to train the model. Default: 1.
steps_per_epoch (int, optional): The total number of steps (batches of samples)
executed in one epoch before starting the next one. If None, it is equal to
the number of samples in your dataset divided by the batch size. Default: None.
valid_data (Dataset, optional): An instance of paddle.io.Dataset used for
evaluation at the end of each epoch. No evaluation will be done if set to None.
Default: None. (Unsupported for now)
valid_freq (int, optional): Only relevant if valid_data is provided. This specifies
how many training epochs to run before a new evaluation is performed. Default: 1.
valid_sample_split (int, optional): Only relevant if valid_data is provided.
Each sample of the valid dataset is assumed to be an (input, label) pair
by default and has two items. If each sample has more than two items,
valid_sample_split specifies how to split these items into input and label.
The items before it are the input and the rest are the label. Default: None.
valid_steps (int, optional): Only relevant if valid_data is provided.
It is the total number of steps (batches of samples) to draw before
stopping validation at the end of every epoch. If None, validation will run until the
`valid_data` dataset is exhausted. The validation will start from the
beginning of the dataset at each epoch. Default: None.
collate_fn(callable, optional): function to generate mini-batch data by merging
the sample list; None to only stack each field of the samples along axis 0.
Default: None.
callbacks (Callback|None, optional): A list of `Callback` instances to apply
during training. Default: None. (Unused for now)
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.vision.datasets import MNIST
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = MNIST(mode='train', transform=transform)
model = paddle.vision.models.LeNet()
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
metrics = paddle.metric.Accuracy(topk=(1, 2))
engine = auto.Engine(model, loss, optimizer, metrics)
engine.fit(train_dataset,
epochs=2,
batch_size=64)
"""
self.mode = 'train'
self._infer_sample_spec(train_data, batch_size, train_sample_split)
if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode)
else:
self._switch_mode("train")
assert self.mode in self._dist_main_progs, \
"train model is not ready, please call `engine.prepare()` first."
"train model is not ready, please call `engine._prepare_single_mode('train')` first."
train_dataloader = self._create_dataloader(train_data, batch_size,
epochs, steps_per_epoch,
collate_fn)
usr_fetch = self._validate_fetches(fetches)
fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)
lr_scheduler = self.get_lr_scheduler(self.main_program)
fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"])
inner_fetch = dict(fetch_loss, **fetch_metrics)
usr_fetch = self._validate_fetches(_get_fetches())
fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
lr_scheduler = self._get_lr_scheduler(self.main_program)
outputs = defaultdict(list)
for epoch in range(epochs):
train_logs = {"epoch: {:d} ": epoch}
for step, _ in enumerate(train_dataloader):
try:
outs = self._executor.run(self.main_program,
outs = self._executor.run(
self.main_program,
fetch_list=fetch_list,
use_program_cache=use_cache,
return_numpy=return_numpy)
except fluid.core.EOFException:
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
except core.EOFException:
break
train_logs["step: {:d} "] = step
if lr_scheduler is not None:
# update lr
if lr_scheduler and step % self._k_steps == 0:
lr_scheduler.step()
try:
train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr()
except:
train_logs[
"lr: {:5e} "] = self._lr_optimizer._learning_rate.get_lr(
)
train_logs["lr: {:5e} "] = self._get_lr(self._lr_optimizer)
# inner fetches
if fetch_loss:
train_logs["loss: {:9f} "] = outs[0][0]
train_logs["loss: {:8f} "] = outs[0][0]
outputs["loss"].append(outs[0][0])
# Metric
if fetch_metrics:
metric_out = outs[len(fetch_loss):len(inner_fetch)]
for metric in self._metrics:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
train_logs[metric.name()[i] + ": {:8f} "] = res
outputs[metric.name()[i]].append(res)
# user fetches
user_outs = outs[len(fetch_loss):]
user_fetch_list = fetch_list[len(fetch_loss):]
user_outs = outs[len(inner_fetch):]
user_fetch_list = fetch_list[len(inner_fetch):]
for i, out in enumerate(user_outs):
train_logs[fetch_map[user_fetch_list[i]] + ": {}"] = out
# logger
string = '[train] ' + ''.join(list(train_logs.keys()))
self._logger.info(string.format(*list(train_logs.values())))
if valid_data and epoch % valid_freq == 0:
self.evaluate(valid_data, valid_sample_split, batch_size,
valid_steps, collate_fn, callbacks)
self._switch_mode("train")
else:
self._reset_metrics()
return outputs
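
`fit` now collects the fetched loss and metric values per step into a dict and returns it, so the training history can be inspected after the call. A short usage sketch reusing the dataset and model from the class docstring example (the metric key is assumed to follow `metric.name()`):

import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.vision.datasets import MNIST

transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = MNIST(mode='train', transform=transform)

model = paddle.vision.models.LeNet()
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.001,
                                  parameters=model.parameters())
metrics = paddle.metric.Accuracy(topk=(1, 2))

engine = auto.Engine(model, loss, optimizer, metrics)
history = engine.fit(train_dataset, epochs=2, batch_size=64)

print("last step loss:", history["loss"][-1])
if "acc_top1" in history:   # metric keys come from metric.name()
    print("last step top-1 accuracy:", history["acc_top1"][-1])
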
def evaluate(self,
eval_data,
valid_data,
valid_sample_split=None,
batch_size=1,
fetches=None,
steps=None,
collate_fn=None,
use_cache=False,
return_numpy=True):
callbacks=None):
"""
Evaluate the loss and metrics of the model on evaluation data.
Args:
valid_data (Dataset): An instance of paddle.io.Dataset. Default: None.
valid_sample_split (int, optional): Each sample of the eval dataset is assumed
to be an (input, label) pair by default and has two items. If each sample has
more than two items, valid_sample_split specifies how to split these items into
input and label. The items before it are the input and the rest are the label. Default: None.
batch_size (int, optional): The batch size of valid_data. The user's data will
be used directly without batching if set to None. Default: 1.
steps (int, optional): It is the total number of steps (batches of samples) to draw before
stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted.
The evaluation will start from the beginning of the dataset in each run. Default: None.
collate_fn(callable, optional): function to generate mini-batch data by merging
the sample list; None to only stack each field of the samples along axis 0.
Default: None.
callbacks (Callback|None, optional): A list of `Callback` instances to apply
during evaluation. Default: None. (Unused for now)
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.vision.datasets import MNIST
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
valid_dataset = MNIST(mode='test', transform=transform)
model = paddle.vision.models.LeNet()
loss = paddle.nn.CrossEntropyLoss()
metrics = paddle.metric.Accuracy(topk=(1, 2))
engine = auto.Engine(model, loss, metrics=metrics)
engine.evaluate(valid_dataset, batch_size=64)
"""
self.mode = 'eval'
self._infer_sample_spec(valid_data, batch_size, valid_sample_split)
if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode)
else:
self._switch_mode("eval")
assert self.mode in self._dist_main_progs, \
"eval model is not ready, please call `engine.prepare()` first."
eval_dataloader = self._create_dataloader(eval_data,
"eval model is not ready, please call `engine._prepare_single_mode('eval')` first."
valid_dataloader = self._create_dataloader(valid_data,
batch_size,
steps_per_epoch=steps,
collate_fn=collate_fn)
usr_fetch = self._validate_fetches(fetches)
fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"])
inner_fetch = dict(fetch_loss, **fetch_metrics)
usr_fetch = self._validate_fetches(_get_fetches())
fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
for step, _ in enumerate(eval_dataloader):
eval_logs = {"step: {:d} ": step}
outputs = defaultdict(list)
for step, _ in enumerate(valid_dataloader):
try:
outs = self._executor.run(self.main_program,
outs = self._executor.run(
self.main_program,
fetch_list=fetch_list,
use_program_cache=use_cache,
return_numpy=return_numpy)
except fluid.core.EOFException:
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
except core.EOFException:
break
eval_logs = {"step: {:d} ": step}
# inner fetches
if fetch_loss:
eval_logs["loss: {:9f} "] = outs[0][0]
eval_logs["loss: {:8f} "] = outs[0][0]
outputs["eval_loss"].append(outs[0][0])
# Metric
if fetch_metrics:
metric_out = outs[len(fetch_loss):len(inner_fetch)]
......@@ -495,8 +722,9 @@ class Engine:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
eval_logs[metric.name()[i] + ": {:9f} "] = res
# usr fetches
eval_logs[metric.name()[i] + ": {:8f} "] = res
outputs["eval_" + metric.name()[i]].append(res)
# user fetches
usr_outs = outs[len(inner_fetch):]
usr_fetch_list = fetch_list[len(inner_fetch):]
for i, out in enumerate(usr_outs):
......@@ -504,38 +732,88 @@ class Engine:
# logger
string = '[eval] ' + ''.join(list(eval_logs.keys()))
self._logger.info(string.format(*list(eval_logs.values())))
self._reset_metrics()
return outputs
def predict(self,
test_data,
test_sample_split=None,
batch_size=1,
fetches=None,
steps=None,
collate_fn=None,
use_cache=False,
return_numpy=True):
callbacks=None):
"""
Compute the output predictions on testing data.
Args:
test_data (Dataset): An instance of paddle.io.Dataset. Default: None.
test_sample_split (int, optional): Each sample of the test dataset is assumed
to be an (input, label) pair by default and has two items. If each sample has
more than two items, test_sample_split specifies how to split these items into
input and label. The items before it are the input and the rest are the label. Default: None.
batch_size (int, optional): The batch size of test_data. The user's data will
be used directly without batching if set to None. Default: 1.
steps (int, optional): It is the total number of steps (batches of samples) to draw before
stopping prediction. If None, prediction will run until the `test_data` dataset is exhausted.
The prediction will start from the beginning of the dataset in each run. Default: None.
collate_fn(callable, optional): function to generate mini-batch data by merging
the sample list; None to only stack each field of the samples along axis 0.
Default: None.
callbacks (Callback|None, optional): A list of `Callback` instances to apply
during testing. Default: None. (Unused for now)
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.vision.datasets import MNIST
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
valid_dataset = MNIST(mode='test', transform=transform)
model = paddle.vision.models.LeNet()
engine = auto.Engine(model)
engine.predict(valid_dataset, batch_size=64)
"""
self.mode = 'predict'
self._infer_sample_spec(test_data, batch_size, test_sample_split)
if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode)
else:
self._switch_mode("predict")
assert self.mode in self._dist_main_progs, \
"predict model is not ready, please call `engine.prepare()` first."
"predict model is not ready, please call `engine._prepare_single_mode('predict')` first."
test_dataloader = self._create_dataloader(test_data,
batch_size,
steps_per_epoch=steps,
collate_fn=collate_fn)
usr_fetch = self._validate_fetches(fetches)
fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"])
usr_fetch = self._validate_fetches(_get_fetches())
fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch)
outputs = []
for step, _ in enumerate(test_dataloader):
predict_logs = {"step: {:d} ": step}
try:
outs = self._executor.run(self.main_program,
outs = self._executor.run(
self.main_program,
fetch_list=fetch_list,
use_program_cache=use_cache,
return_numpy=return_numpy)
except fluid.core.EOFException:
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
except core.EOFException:
break
predict_logs = {"step: {:d} ": step}
outputs.append(outs[:len(fetch_outputs)])
for i, out in enumerate(outs):
predict_logs[fetch_map[fetch_list[i]] + ": {}"] = out
......@@ -545,12 +823,23 @@ class Engine:
return outputs
def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
self.mode = 'train'
self._infer_sample_spec(tune_data, batch_size, tune_sample_split)
self._optimization_tuning(self.mode, tune_data, batch_size)
def _create_dataloader(self,
dataset,
batch_size,
epochs=1,
steps_per_epoch=None,
collate_fn=None):
if self._strategy.gradient_merge and batch_size is not None:
assert batch_size % self._k_steps == 0, \
"Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps)
batch_size //= self._k_steps
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
dist_context = self._dist_contexts[self.mode]
......@@ -589,9 +878,9 @@ class Engine:
epochs,
steps_per_epoch,
collate_fn,
data_parallel_world_size=self.dp_world_sizes,
data_parallel_rank=self.dp_ranks,
split_data=self.strategy.split_data)
data_parallel_world_size=self._dp_world_sizes,
data_parallel_rank=self._dp_ranks,
split_data=self._strategy.split_data)
# move read op from the end of program to the start of program
new_op_size = len(dist_main_block.ops)
......@@ -612,6 +901,7 @@ class Engine:
def _validate_spec(self, specs):
specs = to_list(specs)
self._k_steps = self._strategy.gradient_merge.k_steps
if specs is not None:
for i, spec in enumerate(specs):
assert isinstance(spec, InputSpec)
......@@ -619,6 +909,12 @@ class Engine:
raise ValueError(
"Requires Input[{}].name != None, but receive `None` with {}."
.format(i, spec))
if self._k_steps > 1:
shape = list(spec.shape)
assert shape[0] % self._k_steps == 0, \
"Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self._k_steps)
shape[0] //= self._k_steps
spec.shape = shape
return specs
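
With gradient merge enabled, both the user batch size and the leading dimension of every `InputSpec` must be divisible by `k_steps`: the Engine feeds micro-batches of `batch_size // k_steps` and accumulates gradients across `k_steps` steps before the update. A small worked sketch of that arithmetic (values are illustrative):

batch_size = 64      # user-facing (accumulated) batch size
k_steps = 4          # strategy.gradient_merge.k_steps

assert batch_size % k_steps == 0
micro_batch_size = batch_size // k_steps   # 16 samples fed per executor step

spec_shape = [batch_size, 784]             # illustrative InputSpec shape
spec_shape[0] //= k_steps                  # leading dim becomes 16, matching the dataloader
# The lr scheduler steps only when step % k_steps == 0 (see the fit loop above).
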
def _is_local_var(self, var):
......@@ -678,41 +974,98 @@ class Engine:
# NOTE hack to enable recompute in engine api for GPT-3
# TODO support more PaddleNLP/CV models here
config = self.strategy.recompute_configs
recompute = self._strategy.recompute
# extract ckpts by specific model
if isinstance(self.model, paddle.nn.Layer):
if isinstance(self._model, paddle.nn.Layer):
if hasattr(
self.model, "gpt"
) and self.model.__class__.__name__ == 'GPTForPretraining':
exact_ckpts = self.model.gpt.checkpoints
self._model, "gpt"
) and self._model.__class__.__name__ == 'GPTForPretraining':
exact_ckpts = self._model.gpt.checkpoints
else:
exact_ckpts = config["checkpoints"]
exact_ckpts = recompute.checkpoints
else:
exact_ckpts = config["checkpoints"]
exact_ckpts = recompute.checkpoints
# modify strategy
if self.strategy.recompute:
config["checkpoints"] = exact_ckpts[:]
self.strategy.recompute_configs = config
if recompute.enable:
recompute.checkpoints = exact_ckpts[:]
logs = {
'Model Class': self.model.__class__.__name__,
'Model Class': self._model.__class__.__name__,
'Applied Recompute ckpts': exact_ckpts
}
self._logger.info(logs)
def _validate_opt(self, optimizer):
if optimizer is not None:
optimizer._parameter_list = None
optimizer._param_groups = None
return optimizer
def save(self, path, training=True, mode=None):
if not mode:
mode = self.mode
def _reset_metrics(self):
for metric in self._metrics:
metric.reset()
def _switch_mode(self, mode):
self.mode = mode
self._initialize(mode)
def _set_state_dict(self, mode, strict, state_dict, dist_attr):
program = self._dist_main_progs[mode][self._cur_rank]
dist_context = self._dist_contexts[mode]
cur_dist_attr = get_dist_attr(program, dist_context)
converter = Converter(state_dict, dist_attr, cur_dist_attr)
state_dict = converter.convert(strict=strict)
program.set_state_dict(state_dict)
def save(self, path, training=True):
"""
Saves the model, parameters, optimizer state to path.
If `training` is set to False, only inference model will be saved.
Args:
path (str): The file prefix to save the model. The format
is 'dirname/file_prefix' or 'file_prefix'. An exception
will be raised if it is an empty string.
training (bool, optional): Whether to save for training. If not, save
for inference only. If `training` is set to True, the optimizer state
will be saved. Otherwise, only the model and parameters are saved.
This function will silently overwrite an existing file at the target
location. Default: True.
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.vision.datasets import MNIST
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = MNIST(mode='train', transform=transform)
model = paddle.vision.models.LeNet()
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
metrics = paddle.metric.Accuracy(topk=(1, 2))
engine = auto.Engine(model, loss, optimizer, metrics)
engine.fit(train_dataset,
epochs=1,
batch_size=64)
engine.save("./my_model")
"""
if training:
assert 'train' in self._serial_main_progs, \
"training model is not ready, please call `engine.prepare()` first."
"training model is not ready, please call `engine._prepare_single_mode('train')` first."
serial_program = self._serial_main_progs["train"]
dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
dist_context = self._dist_contexts["train"]
......@@ -721,7 +1074,7 @@ class Engine:
dist_main_program=dist_main_prog,
dist_context=dist_context)
else:
assert mode, "Please set the 'mode' you want to save."
mode = "predict"
feed_vars = self._feed_vars[mode]['inputs']
fetch_vars = self._fetch_vars[mode]['outputs']
dist_main_prog = self._dist_main_progs[mode][self._cur_rank]
......@@ -731,18 +1084,59 @@ class Engine:
self._executor,
program=dist_main_prog)
def load(self, path, strict=True, load_optimizer=True, mode=None):
if not mode:
mode = self.mode
assert mode, "Please set the 'mode' you want to load."
def load(self, path, strict=True, load_optimizer=True):
"""
Load the stored model, parameters and optimizer states.
dist_main_prog = self._dist_main_progs[mode][self._cur_rank]
dist_context = self._dist_contexts[mode]
self._saver.load(path, dist_main_prog, dist_context, strict,
load_optimizer)
Args:
path (str): The prefix of files storing the model states and
optimizer states.
strict (bool, optional): Whether to skip loading mismatched
parameters or raise an error when a mismatch happens (the
parameter is not found in the file storing the model states,
or its shape does not match). Default: True.
load_optimizer (bool, optional): If True, the stored optimizer
states are restored. Otherwise, the optimizer states are initialized
from scratch. Default: True.
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.vision.transforms as T
import paddle.distributed.auto_parallel as auto
from paddle.vision.datasets import MNIST
transform = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = MNIST(mode='train', transform=transform)
model = paddle.vision.models.LeNet()
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(
learning_rate=0.001, parameters=model.parameters())
metrics = paddle.metric.Accuracy(topk=(1, 2))
engine = auto.Engine(model, loss, optimizer, metrics)
engine.fit(train_dataset,
epochs=1,
batch_size=64)
engine.save("./my_model")
engine.load("./my_model")
"""
self._strict = strict
self._state_dict, self._dist_attr = self._saver.load(
path, load_optimizer)
return self._state_dict, self._dist_attr
@staticmethod
def get_lr_scheduler(program):
def _get_lr_scheduler(program):
lr_sheduler = None
if hasattr(program, 'lr_sheduler'):
from paddle.optimizer.lr import LRScheduler
......@@ -750,6 +1144,20 @@ class Engine:
assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler"
return lr_sheduler
def _get_lr(self, optimizer):
if isinstance(optimizer, paddle.optimizer.Optimizer):
return optimizer.get_lr()
elif isinstance(optimizer, paddle.fluid.optimizer.Optimizer):
if isinstance(optimizer._learning_rate, float):
return optimizer._learning_rate
else:
return optimizer._learning_rate()
else:
raise TypeError(
"'optimizer' must be object of class `paddle.optimizer.Optimizer`" \
" or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(type(optimizer))
)
@property
def mode(self):
return self._mode
......@@ -781,3 +1189,11 @@ class Engine:
@property
def fetch_vars(self):
return self._fetch_vars[self.mode]
@property
def inputs(self):
return self.inputs_spec
@property
def labels(self):
return self.labels_spec
......@@ -19,13 +19,13 @@ import paddle
from paddle.nn import Layer
from paddle.jit import to_static, not_to_static
from paddle.distributed.utils import get_logger
from paddle.fluid.framework import Operator, Parameter, _non_static_mode
from paddle.fluid.framework import program_guard
from paddle.fluid.executor import global_scope
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from .utils import to_list
from .utils import get_logger
from .converter import Converter
......
......@@ -12,101 +12,199 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import copy
import paddle
import paddle.fluid.core as core
from paddle.fluid.framework import Variable
from paddle.fluid.framework import _non_static_mode
from paddle.fluid import core
from .process_mesh import ProcessMesh
from .process_mesh import get_current_process_mesh
from .process_mesh import set_current_process_mesh
from .process_mesh import reset_current_process_mesh
from .dist_context import get_default_distributed_context
from .dist_tensor import DistributedTensor
from .dist_op import DistributedModule
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .dist_op import DistributedOperatorHelper
from .utils import verify_shard_spec, convert_to_dims_mapping
def _static_mode_check():
if _non_static_mode():
raise RuntimeError("Auto-parallel only supports static mode for now, "
"please use paddle.enable_static() first.")
def shard_tensor(x, dist_attr=None):
def shard_tensor(x, process_mesh=None, shard_spec=None):
"""
Add distributed attributes for a tensors.
Shard a tensor on a process mesh according to the shard specification.
Args:
x (Tensor): the tensor to be sharded.
dist_attr (dict): the tensor distributed attributes. The accepted attributes are as follow:
"process_mesh": a nested list an to describe the mesh topology of logical processes.
"dims_mapping": a list to describe the mapping between `x` and `process_mesh`, the dimension
`i` of `x` is split across the dimension `dims_mapping[i]` of `process_mesh`,
where -1 means that tensor dimension is not split.
Both process_mesh and dims_mapping are optional and users can specify as need.
process_mesh (ProcessMesh, optional): An instance of ProcessMesh that describes the mesh
topology of the logical processes on which the tensor is sharded. If it is None,
the current process mesh in the context will be used, and an error will be raised if
no current process mesh can be found. Default: None.
shard_spec (list, optional): a list to describe the sharding mapping between `x` and `process_mesh`,
which means the dimension `i` of `x` is split across the dimension `shard_spec[i]` of `process_mesh`,
where `None` means that the tensor dimension is not split. For example, given a tensor with
the shape [6, 12] and a process mesh with the shape [2, 3] and the dimension names ["x", "y"]:
If `shard_spec=["x", "y"]`, each shard of the tensor will have a shape [3, 4];
If `shard_spec=["y", "x"]`, each shard of the tensor will have a shape [2, 6];
If `shard_spec=["x", None]`, each shard of the tensor will have a shape [3, 12];
If `shard_spec=[None, "x"]`, each shard of the tensor will have a shape [6, 6];
If `shard_spec=["y", None]`, each shard of the tensor will have a shape [2, 12];
If `shard_spec=[None, "y"]`, each shard of the tensor will have a shape [6, 4];
If `shard_spec=[None, None]`, each shard of the tensor will have a shape [6, 12];
If `shard_spec` is None, the tensor will be replicated across all the processes of `process_mesh`.
In the above example, `shard_spec=None` is the same as `shard_spec=[None, None]`. Default: None.
Returns:
Tensor: the tensor `x` annotated with distributed attributes.
Tensor: the tensor `x` annotated with sharding information.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
import paddle.distributed.auto_parallel as auto
mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
x = paddle.ones([4, 6])
dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]],
"dims_mapping": [0, -1]})
shard_spec = ["x", "y"]
auto.shard_tensor(x, mesh, shard_spec)
"""
_static_mode_check()
assert dist_attr is None or isinstance(dist_attr, (dict, TensorDistributedAttribute)), \
"The type of dist_attr must be None, dict or TensorDistributedAttribute."
dist_tensor = DistributedTensor(x, dist_attr)
dist_tensor.dist_attr.mark_annotated_as(dist_attr)
if process_mesh is not None:
assert isinstance(process_mesh, ProcessMesh), \
"Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh)
else:
process_mesh = get_current_process_mesh()
assert process_mesh is not None, \
"Specify the process mesh argument or use ProcessMesh context manager first."
assert isinstance(shard_spec, list), \
"Argument shard_spec {} is not an instance of list".format(shard_spec)
dist_tensor = DistributedTensor(x)
serial_tensor = dist_tensor.serial_tensor
dist_tensor.dist_attr.process_mesh = process_mesh
if serial_tensor.type == core.VarDesc.VarType.READER \
or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
tensor_shape = []
else:
tensor_shape = serial_tensor.shape
if shard_spec is not None:
assert verify_shard_spec(shard_spec, tensor_shape, process_mesh), \
"For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format(
serial_tensor.name, shard_spec, tensor_shape, process_mesh)
dist_tensor.dist_attr.dims_mapping = convert_to_dims_mapping(
shard_spec, process_mesh)
if process_mesh is not None:
dist_tensor.dist_attr.mark_annotated("process_mesh")
if shard_spec is not None:
dist_tensor.dist_attr.mark_annotated("dims_mapping")
default_dist_ctx = get_default_distributed_context()
default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
dist_tensor = default_dist_ctx.get_dist_tensor_for_program(x)
return x
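# A minimal sketch (not part of the API) of how a shard_spec is translated into a
# dims_mapping by convert_to_dims_mapping: an entry naming a mesh dimension becomes
# that dimension's index, and None becomes -1. Assuming a 2x3 mesh named ["x", "y"]:
#
#   mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"])
#   convert_to_dims_mapping(["x", None], mesh)   # -> [0, -1], dim 0 split over "x"
#   convert_to_dims_mapping([None, "y"], mesh)   # -> [-1, 1], dim 1 split over "y"
#   convert_to_dims_mapping([None, None], mesh)  # -> [-1, -1], fully replicated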
def shard_op(op_fn, dist_attr=None):
def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
"""
Call a functioin and add distributed attributes for ops added by the function.
Shard an operation on a process mesh according to its input and output shard specifications.
Args:
op_fn (callable): a callable operator or module to be sharded.
dist_attr (dict): the operator distributed attributes. The accepted attributes are classified into
two categories. The first category decsribes the distributed attributes shared by all inputs and
outputs, and only `process_mesh` can be specified now. The second category describes distributed
attributes for inputs or outputs same as the `dist_attr` of `shard_tensor`. All of them are
optional and users can specify them as need. Note that `process_mesh` for operators must be the
same as these process_meshes for inputs and outputs.
op (Callable): a callable operator or module to be sharded.
process_mesh (ProcessMesh, optional): An instance of ProcessMesh that describes the mesh
topology of the logical processes on which the op is sharded. All of its inputs and
outputs are sharded by this process mesh. If it is None, the current process mesh in
the context will be used, and an error will be raised if no current process mesh can
be found. Default: None.
in_shard_specs (list of list, optional): a list of lists to describe the sharding specifications
for the inputs. Each item of `in_shard_specs` is a `shard_spec` between the corresponding input
and `process_mesh`. If one item is None, the corresponding input is replicated across all processes.
If the whole argument is None, all inputs are replicated across all processes. Note that the length
of `in_shard_specs` should be equal to the actual number of inputs when calling this operation.
Default: None.
out_shard_specs (list of list, optional): a list of lists to describe the sharding specifications
for the outputs. Each item of `out_shard_specs` is a `shard_spec` between the corresponding output
and `process_mesh`. If one item is None, the corresponding output is replicated across all processes.
If the whole argument is None, all outputs are replicated across all processes. Note that the length
of `out_shard_specs` should be equal to the actual number of outputs when calling this operation.
Default: None.
Returns:
list: the outputs of the function `op_fn`, which are annotated with distributed attributes.
Outputs of `op`, each of which is annotated with sharding information.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
import paddle.distributed.auto_parallel as auto
x = paddle.ones([4, 6])
y = paddle.zeros([4, 6])
dist_add = dist.shard_op(paddle.add,
dist_attr={
"process_mesh": [[2, 3, 1], [0, 4, 5]],
x: {"dims_mapping": [-1, 0]},
y: {"dims_mapping": [0, -1]}
})
mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
dist_add = auto.shard_op(paddle.add,
in_shard_specs=[["x", "y"], ["y", None]],
out_shard_specs=[[None, "x"]])
dist_add(x, y)
"""
_static_mode_check()
assert dist_attr is None or isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \
"The type of dist_attr must be dict or OperatorDistributedAttribute."
dist_module = DistributedModule(op_fn, dist_attr)
return dist_module
if process_mesh is not None:
assert isinstance(process_mesh, ProcessMesh), \
"Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh)
else:
process_mesh = get_current_process_mesh()
assert process_mesh is not None, \
"Specify the process mesh argument or use ProcessMesh context manager first."
in_dims_mappings = []
if in_shard_specs is not None:
assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in in_shard_specs), \
"in_shard_spec {} is not a list of list or None".format(in_shard_specs)
for shard_spec in in_shard_specs:
if shard_spec is not None:
in_dims_mappings.append(
convert_to_dims_mapping(shard_spec, process_mesh))
else:
in_dims_mappings.append(None)
out_dims_mappings = []
if out_shard_specs is not None:
assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in out_shard_specs), \
"out_shard_spec {} is not a list of list or None".format(out_shard_specs)
for shard_spec in out_shard_specs:
if shard_spec is not None:
out_dims_mappings.append(
convert_to_dims_mapping(shard_spec, process_mesh))
else:
out_dims_mappings.append(None)
op = DistributedOperatorHelper(op, process_mesh, in_dims_mappings,
out_dims_mappings)
return op
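# A short usage sketch (illustrative, tensors and specs assumed): when no process_mesh
# argument is passed, shard_op picks up the mesh from an enclosing ProcessMesh context,
# so the following two spellings should be equivalent:
#
#   mesh = ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
#   dist_matmul = shard_op(paddle.matmul, process_mesh=mesh,
#                          in_shard_specs=[["x", None], [None, None]])
#   with mesh:
#       dist_matmul = shard_op(paddle.matmul,
#                              in_shard_specs=[["x", None], [None, None]])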
def recompute(op):
class RecomputeOperator:
def __init__(self, op):
self._op = op
def __call__(self, *args, **kwargs):
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
op_size = len(cur_block.ops)
output = self._op(*args, **kwargs)
new_op_size = len(cur_block.ops)
for idx in range(op_size, new_op_size):
op = cur_block.ops[idx]
op._set_attr("is_recompute@auto_parallel", True)
return output
return RecomputeOperator(op)
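# A hedged usage sketch: wrapping a callable with recompute marks every op it adds to the
# current block with the "is_recompute@auto_parallel" attribute, which the recompute pass
# can pick up later. For example (layer name assumed):
#
#   mlp = paddle.nn.Linear(1024, 1024)
#   out = recompute(mlp)(x)   # ops created by this call carry the recompute mark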
_g_fetched_tensors = {}
def fetch(tensor, name=None):
if name is None:
_g_fetched_tensors[tensor.name] = tensor
else:
_g_fetched_tensors[name] = tensor
def _get_fetches():
return _g_fetched_tensors
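# Illustrative use of fetch (tensor names assumed): register an intermediate tensor under
# an optional alias so that it is included in the fetch results when the engine runs.
#
#   fetch(hidden_states, name="hidden")   # retrievable later via _get_fetches()["hidden"]
#   fetch(loss)                           # falls back to the tensor's own name as the key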
......@@ -42,6 +42,7 @@ from .utils import make_data_unshard
from .utils import set_grad_var_shape
from .utils import print_program_with_dist_attr
from .utils import SerialProgramInfo
from .utils import get_logger
from .reshard import Resharder
from .cluster import Cluster
from .mapper import mapping
......@@ -84,7 +85,7 @@ class AutoParallelizer:
self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
self._need_rank_mapping = True if self._need_rank_mapping and \
self._need_rank_mapping.lower() == 'true' else False
self._pass_context = None
# self._pass_context = None
def _remove_distributed_attrs(self, main_program):
suffix = core.kAutoParallelSuffix()
......@@ -143,10 +144,11 @@ class AutoParallelizer:
def _apply_optimize(self, main_program, startup_program, params_grads):
optimizer = copy.deepcopy(self._optimizer)
with program_guard(main_program, startup_program):
optimize_ops = copy.deepcopy(
self._optimizer).apply_gradients(params_grads)
optimize_ops = optimizer.apply_gradients(params_grads)
self._dist_context._lr_optimizer = optimizer
# update completion
self._completer = Completer(self._dist_context)
self._completer.complete_update_annotation(main_program)
......@@ -165,6 +167,15 @@ class AutoParallelizer:
config)
auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context)
params_grads = self._pass_context.get_attr("params_grads")
config = copy.deepcopy(self._dist_strategy.sharding_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["rank_id"] = rank
auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config)
auto_parallel_clip_pass.apply([main_program], [startup_program],
self._pass_context)
if self._dist_strategy.gradient_merge:
config = copy.deepcopy(self._dist_strategy.gradient_merge_configs)
......
......@@ -22,7 +22,6 @@ from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _non_static_mode, unique_name
from paddle.distributed.passes import new_pass
from paddle.distributed.utils import get_logger
from .reshard import Resharder
from .partitioner import Partitioner
......@@ -31,6 +30,7 @@ from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import make_data_unshard, set_grad_var_shape
from .utils import print_program_with_dist_attr, to_list
from .utils import get_logger
from .process_group import get_all_process_groups, get_world_process_group
from .dist_context import DistributedContext, get_default_distributed_context
......@@ -160,8 +160,8 @@ class Parallelizer:
# apply quantization pass
# The pass can be applied when mode must be 'train'
if self._mode == 'train' and self._strategy.qat:
config = copy.deepcopy(self._strategy.qat_configs)
if self._mode == 'train' and self._strategy.qat.enable:
config = copy.deepcopy(self._strategy.qat.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
auto_parallel_quantization_pass = new_pass(
......@@ -176,8 +176,8 @@ class Parallelizer:
# apply amp pass
# FIXME we disenable amp for eval since it has a little bug with
# eval program and which will be fixed in future
if self._mode == 'train' and self._strategy.amp:
config = copy.deepcopy(self._strategy.amp_configs)
if self._mode == 'train' and self._strategy.amp.enable:
config = copy.deepcopy(self._strategy.amp.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["loss"] = loss
......@@ -195,8 +195,8 @@ class Parallelizer:
# apply recompute pass
# recompute is then train-only optimization
if self._mode == "train" and self._strategy.recompute:
config = copy.deepcopy(self._strategy.recompute_configs)
if self._mode == "train" and self._strategy.recompute.enable:
config = copy.deepcopy(self._strategy.recompute.to_dict())
config["dist_context"] = self._dist_context
config["no_grad_set"] = None
config["loss"] = loss
......@@ -217,12 +217,12 @@ class Parallelizer:
config = {}
config["dist_context"] = self._dist_context
config["global_rank"] = rank
config["use_sharding"] = self._strategy.sharding
config["use_sharding"] = self._strategy.sharding.enable
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass.apply([main_program], [startup_program], self._pass_context)
if self._strategy.sharding:
config = copy.deepcopy(self._strategy.sharding_configs)
if self._strategy.sharding.enable:
config = copy.deepcopy(self._strategy.sharding.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["global_rank"] = rank
......@@ -230,11 +230,11 @@ class Parallelizer:
config)
auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context)
params_grads = self._pass_context.get_attr("params_grads")
# GradClip is train-only optimization
if self._mode == "train":
config = copy.deepcopy(self._strategy.sharding_configs)
config = copy.deepcopy(self._strategy.sharding.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["rank_id"] = rank
......@@ -244,8 +244,8 @@ class Parallelizer:
self._pass_context)
# gradient_merge is then train-only optimization
if self._mode == "train" and self._strategy.gradient_merge:
config = copy.deepcopy(self._strategy.gradient_merge_configs)
if self._mode == "train" and self._strategy.gradient_merge.enable:
config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
auto_parallel_gradient_merge_pass = new_pass(
......
......@@ -12,86 +12,90 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy
import numpy as np
import copy
import paddle
# Use to store the previous and current process mesh
_g_previous_process_mesh = None
_g_current_process_mesh = None
def _get_nested_list_shape(nested_list):
"""
Get the shape of a nested_list.
"""
result = []
while isinstance(nested_list, list):
result.append(len(nested_list))
nested_list = nested_list[0]
return result
def get_current_process_mesh():
global _g_current_process_mesh
return _g_current_process_mesh
def _flatten_nested_list(nested_list):
"""
Get a list of all items in a nested_list.
Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists
"""
result = numpy.array(nested_list).flatten().tolist()
return result
def set_current_process_mesh(process_mesh):
global _g_previous_process_mesh
global _g_current_process_mesh
_g_previous_process_mesh = _g_current_process_mesh
_g_current_process_mesh = process_mesh
class ProcessMesh(object):
r"""
The class `Processmesh` describes the topology of logical processes.
A mesh is an N-dimensional array. The shape of the N-dimensional
array represents the topology of logical processes and every
element of the N-dimensional array represent a logical process. For
example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]]
illustrates six logical processes organized as the topology [2, 3],
i.e., the shape of the 2-dimensional array. With the above topology,
there are two parallel groups, where the first parallel group has a
parallel degree of 2 and the second one has a parallel degree of 3.
And the first logical process is the one with id=2.
Args:
mesh (list): an N-dimensional array (nested list) describes the toplogy
of logical processes. The shape of the N-dimensional array
represents the topology of logical processes and every
element of the N-dimensional array represents a logical process.
def reset_current_process_mesh():
global _g_previous_process_mesh
global _g_current_process_mesh
_g_current_process_mesh = _g_previous_process_mesh
Returns:
None
class ProcessMesh(object):
"""
The `ProcessMesh` object describes the topology of the processes used.
Raises:
ValueError: If `mesh` is not an instance of list.
Args:
mesh (list|numpy.array): an n-dimensional array that describes the topology
of the processes.
dim_names (list, optional): the i-th element of this list gives the name of the
i-th dimension of the mesh.
Examples:
.. code-block:: python
import paddle
import paddle.distributed as dist
paddle.enable_static()
mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]])
assert mesh.topology == [2, 3]
assert mesh.processes == [2, 4, 5, 0, 1, 3]
mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
assert mesh.shape == [2, 3]
assert mesh.process_ids == [2, 4, 5, 0, 1, 3]
"""
def __init__(self, mesh):
if mesh is None or not isinstance(mesh, list):
raise ValueError('mesh must be an instance of list.')
processes = _flatten_nested_list(mesh)
assert all(isinstance(p, int) for p in processes), \
("All elements of mesh must be integer")
assert min(processes) >= 0, ('All elements of mesh must be >= 0.')
unique_processes = set(processes)
assert len(unique_processes) == len(processes), (
'All elements of mesh must be unique.')
self._topology = _get_nested_list_shape(mesh)
self._processes = processes
def __init__(self, mesh=None, dim_names=None, shape=None, process_ids=None):
# Use shape and process_ids just for compatibility
# Users should not use these directly
if mesh is None:
assert shape is not None
assert process_ids is not None
mesh = np.array(process_ids).reshape(shape)
if not isinstance(mesh, list) and \
not isinstance(mesh, np.ndarray):
raise ValueError(
'The mesh must be an instance of list or np.ndarray.')
if isinstance(mesh, list):
mesh = np.array(mesh)
self._mesh = mesh
self._shape = list(self._mesh.shape)
self._process_ids = self._mesh.flatten().tolist()
assert all(isinstance(p, int) for p in self._process_ids), \
("All elements of the mesh must be integer")
assert min(
self._process_ids) >= 0, ('All elements of the mesh must be >= 0.')
unique_process_ids = set(self._process_ids)
assert len(unique_process_ids) == len(
self._process_ids), ('All elements of the mesh must be unique.')
if dim_names is not None:
assert len(dim_names) == len(self._shape), \
("The length of dims_names must be same as the shape of the mesh.")
self._dim_names = copy.deepcopy(dim_names)
else:
self._dim_names = ["d" + str(i) for i in range(len(self._shape))]
unique_dim_names = set(self._dim_names)
assert len(unique_dim_names) == len(self._dim_names), (
'All dim_names {} must be unique.'.format(dim_names))
# Store all process meshes
from .dist_context import get_default_distributed_context
......@@ -103,31 +107,117 @@ class ProcessMesh(object):
pg0.add_ranks(self.processes)
@property
def topology(self):
r"""
Get the topology of logical processes belonging to this ProcessMesh.
This is the shape of `mesh` used to initialized this ProcessMesh.
def shape(self):
"""
Get the shape of this ProcessMesh.
"""
return self._topology
return self._shape
@property
def processes(self):
r"""
Get a list of all processes belonging to this ProcessMesh.
def process_ids(self):
"""
Get the process ids belonging to this ProcessMesh.
"""
return self._processes
return self._process_ids
@property
def dim_names(self):
"""
Get the dimension names of this ProcessMesh.
"""
return self._dim_names
@property
def ndim(self):
r"""
Get the number of dimension of ProcessMesh.
"""
return len(self._topology)
Get the number of dimensions of this ProcessMesh.
"""
return len(self._shape)
@property
def mesh(self):
"""
Get the underlying mesh of ProcessMesh.
"""
return self._mesh
@property
def topology(self):
return self._shape
@property
def processes(self):
return self._process_ids
def __getitem__(self, index):
if isinstance(index, tuple):
new_dim_names = []
for i, item in enumerate(index):
if isinstance(item, slice):
new_dim_names.append(self._dim_names[i])
new_mesh = self._mesh[index]
if new_mesh.shape:
return ProcessMesh(new_mesh, new_dim_names)
else:
# Wrap a scalar into a list but without dim_names
return ProcessMesh([new_mesh])
elif isinstance(index, slice):
new_mesh = self._mesh[index]
new_dim_names = self._dim_names
return ProcessMesh(new_mesh, new_dim_names)
else:
new_mesh = self._mesh[index]
new_dim_names = self._dim_names[1:]
return ProcessMesh(new_mesh, new_dim_names)
def __enter__(self):
set_current_process_mesh(self)
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
self._old_var_names = list(cur_block.vars.keys())
self._old_op_size = len(cur_block.ops)
def __exit__(self, exc_type, exc_value, exc_traceback):
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
default_prog = paddle.fluid.default_main_program()
cur_block = default_prog.current_block()
new_var_names = list(cur_block.vars.keys())
new_op_size = len(cur_block.ops)
from .dist_context import get_default_distributed_context
default_dist_ctx = get_default_distributed_context()
for name in new_var_names:
if name not in self._old_var_names:
tensor = cur_block.vars[name]
dist_tensor = default_dist_ctx.get_dist_tensor_for_program(
tensor)
if dist_tensor is None:
dist_tensor = DistributedTensor(cur_block.vars[name],
{"process_mesh": self})
dist_tensor.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
else:
if dist_tensor.dist_attr.process_mesh is None:
dist_tensor.dist_attr.process_mesh = self
dist_tensor.dist_attr.mark_annotated("process_mesh")
for idx in range(self._old_op_size, new_op_size):
op = cur_block.ops[idx]
dist_op = default_dist_ctx.get_dist_op_for_program(op)
if dist_op is None:
dist_op = DistributedOperator(op, {"process_mesh": self})
dist_op.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_op_for_program(dist_op)
else:
if dist_op.dist_attr.process_mesh is None:
dist_op.dist_attr.process_mesh = self
dist_op.dist_attr.mark_annotated("process_mesh")
reset_current_process_mesh()
def __eq__(self, other):
if not isinstance(other, ProcessMesh):
return False
if self.topology != other.topology or self.processes != other.processes:
if self.shape != other.shape or self.process_ids != other.process_ids:
return False
return True
......@@ -135,6 +225,6 @@ class ProcessMesh(object):
return not self.__eq__(other)
def __str__(self):
str = "shape {} and process group {}".format(self.topology,
self.processes)
str = "shape {}, process_ids {}, dim_nams {}".format(
self.shape, self.process_ids, self.dim_names)
return str
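# A small sketch of the indexing behavior implemented above (values taken from the class
# docstring example): slicing keeps the matching dim_names, while integer indexing drops
# the leading dimension name.
#
#   mesh = ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
#   row = mesh[0]        # ProcessMesh([2, 4, 5]) with dim_names ["y"]
#   col = mesh[:, 1:2]   # ProcessMesh([[4], [1]]) with dim_names ["x", "y"]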
......@@ -81,54 +81,57 @@ class ProcessMesh(core.ProcessMesh):
return self._mesh
# def compute_compatible_process_meshes(process_meshes):
# """Compute the compatible process mesh given a list of process meshes."""
# if not process_meshes:
# return None
# def _compute_compatible_two_process_meshes(pm1, pm2):
# if pm1 is None:
# return True, pm2
# if pm2 is None:
# return True, pm1
# if pm1 == pm2:
# return True, pm1
# if pm1.device_mesh != pm2.device_mesh:
# return False, None
# if pm1.process_ids == pm2.process_ids:
# if len(pm1.shape) >= len(pm2.shape):
# return True, pm1
# else:
# return True, pm2
# process_set1 = set(pm1.process_ids)
# process_set2 = set(pm2.process_ids)
# if process_set1.issubset(process_set2):
# return True, pm2
# if process_set2.issubset(process_set1):
# return True, pm1
# return False, None
# compatible_result = None
# for process_mesh in process_meshes:
# compatible, compatible_result = _compute_compatible_two_process_meshes(
# compatible_result, process_mesh)
# if not compatible:
# return None
# return ProcessMesh(compatible_result.mesh, compatible_result.dim_names)
# def merge_process_meshes(process_meshes):
# """Merge a list of process meshes."""
# merged_process_mesh = None
# merged_process_ids = set()
# device_type = ""
# for process_mesh in process_meshes:
# if process_mesh is not None:
# process_ids = set(process_mesh.process_ids)
# if not device_type:
# device_type = process_mesh.device_type
# assert device_type != process_mesh.device_type, \
# "All process meshes must have the same device_type."
# merged_process_ids.union(process_ids)
# if len(merged_process_ids) != 0:
# merged_process_mesh = ProcessMesh(list(merged_process_ids))
# return merged_process_mesh
def compute_compatible_process_mesh(process_meshes):
"""Compute the compatible process mesh given a list of process meshes."""
if not process_meshes:
return None
def _compute_compatible_of_two_process_meshes(pm1, pm2):
if pm1 is None:
return True, pm2
if pm2 is None:
return True, pm1
if pm1 == pm2:
return True, pm1
if pm1.process_ids == pm2.process_ids:
if len(pm1.shape) >= len(pm2.shape):
return True, pm1
else:
return True, pm2
process_set1 = set(pm1.process_ids)
process_set2 = set(pm2.process_ids)
if process_set1.issubset(process_set2):
return True, pm2
if process_set2.issubset(process_set1):
return True, pm1
return False, None
compatible_result = None
for process_mesh in process_meshes:
compatible, compatible_result = _compute_compatible_of_two_process_meshes(
compatible_result, process_mesh)
if not compatible:
return None
if compatible_result.empty():
return None
if isinstance(compatible_result, core.ProcessMesh):
mesh = np.array(compatible_result.process_ids).reshape(
compatible_result.shape)
return ProcessMesh(mesh, compatible_result.dim_names)
elif isinstance(compatible_result, ProcessMesh):
return ProcessMesh(compatible_result.mesh, compatible_result.dim_names)
else:
raise ValueError("Unrecognized ProcessMesh.")
def merge_process_mesh(process_meshes):
"""Merge a list of process meshes."""
merged_process_mesh = None
merged_process_ids = set()
for process_mesh in process_meshes:
if process_mesh is not None:
process_ids = set(process_mesh.process_ids)
merged_process_ids = merged_process_ids.union(process_ids)
if len(merged_process_ids) != 0:
merged_process_mesh = ProcessMesh(list(merged_process_ids))
return merged_process_mesh
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import os
import copy
import argparse
from . import constants
class BaseConfig(object):
def __init__(self, category, config_dict=None):
self._category = category
self._config_dict = None
if config_dict is not None:
if isinstance(config_dict, dict):
self._config_dict = config_dict
else:
raise ValueError(
"Expected a dictionary. But received: {}".format(
config_dict))
# Initialize attributes by the default config
config = constants.get_category_default_config(self._category)
for field, default_value in config.items():
setattr(self, field, default_value)
# Override attributes with the values in config_dict
if self._config_dict:
self.from_dict(self._config_dict)
def from_dict(self, config_dict):
config = constants.get_category_default_config(self._category)
for field in config.keys():
value = config_dict.get(field, constants.NOT_FOUND)
# Keep the default value if the value cannot be found
if value != constants.NOT_FOUND:
setattr(self, field, value)
def to_dict(self):
result_dict = {}
config = constants.get_category_default_config(self._category)
for field in config.keys():
value = getattr(self, field)
result_dict[field] = value
for field, value in self.__dict__.items():
if isinstance(value, BaseConfig):
result_dict[field] = value.to_dict()
return result_dict
def __repr__(self):
return yaml.dump(self.to_dict(),
default_flow_style=False,
sort_keys=True,
indent=4)
def __deepcopy__(self, memo):
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
setattr(result, k, copy.deepcopy(v, memo))
return result
class RecomputeConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.RECOMPUTE
super(RecomputeConfig, self).__init__(category, config_dict)
class AMPConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.AMP
super(AMPConfig, self).__init__(category, config_dict)
class ShardingConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.SHARDING
super(ShardingConfig, self).__init__(category, config_dict)
class GradientMergeConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.GRADIENT_MERGE
super(GradientMergeConfig, self).__init__(category, config_dict)
class QATConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.QAT
super(QATConfig, self).__init__(category, config_dict)
class TuningConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.TUNING
super(TuningConfig, self).__init__(category, config_dict)
class Strategy(BaseConfig):
"""
The `Strategy` object is used to configure the parallelization and optimization behaviors.
Args:
config (dict|string, optional): If this is None, the default configurations will be used.
If this is a dictionary, its recognized key-value pairs will be used to override the default
configurations while other default configurations are left unchanged. If this is a string,
it is interpreted as the path to a YAML configuration and will be loaded to override the
corresponding default configurations.
Examples:
.. code-block:: python
import paddle
import paddle.distributed.auto_parallel as auto
strategy = auto.Strategy()
sharding = strategy.sharding
assert sharding.enable == False
assert sharding.stage == 1
assert sharding.degree == 8
sharding.enable = True
sharding.stage = 2
sharding.degree = 2
assert sharding.enable == True
assert sharding.stage == 2
assert sharding.degree == 2
"""
def __init__(self, config=None):
if config is not None:
if isinstance(config, dict):
self._config_dict = copy.deepcopy(config)
# elif os.path.exists(config):
# with open(config, "rb") as yaml_file:
# self._config_dict = yaml.load(yaml_file, Loader=yaml.Loader)
else:
raise ValueError(
"Expected a dictionary. But received: {}".format(config))
else:
self._config_dict = {}
category = constants.BASE
super(Strategy, self).__init__(category, self._config_dict)
config_dict = self._config_dict.get(constants.RECOMPUTE, None)
self.recompute = RecomputeConfig(config_dict)
config_dict = self._config_dict.get(constants.AMP, None)
self.amp = AMPConfig(config_dict)
config_dict = self._config_dict.get(constants.SHARDING, None)
self.sharding = ShardingConfig(config_dict)
config_dict = self._config_dict.get(constants.GRADIENT_MERGE, None)
self.gradient_merge = GradientMergeConfig(config_dict)
config_dict = self._config_dict.get(constants.QAT, None)
self.qat = QATConfig(config_dict)
config_dict = self._config_dict.get(constants.TUNING, None)
self.tuning = TuningConfig(config_dict)
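# An illustrative construction from a plain dict (assuming the category keys match the
# lowercase attribute names used above): recognized categories override the corresponding
# defaults, while unknown fields are ignored by from_dict.
#
#   strategy = Strategy({
#       "sharding": {"enable": True, "stage": 2, "degree": 2},
#       "recompute": {"enable": True},
#   })
#   assert strategy.sharding.stage == 2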
......@@ -16,7 +16,7 @@ import copy
from abc import ABC, abstractmethod
import logging
from paddle.distributed.utils import get_logger
from ..utils import get_logger
from .trial import TrialStatus
from .trial import OptimizationTunerTrial as Trial
......@@ -110,13 +110,13 @@ class ShardingStageAlgorithm(AlgorithmBase):
# TODO import trial class & copy strategy
def __init__(self, config):
super().__init__(config)
self._changed_configs = ["sharding_configs"]
self._changed_configs = ["sharding"]
def _init_spaces(self):
self._max_stage = 3
self._trial_idx = 0
stage_range = self._config.sharding_configs.get("stage_range", None)
stage_range = self._config.sharding.to_dict().get("tuning_range", None)
if stage_range:
assert set(stage_range).issubset(
set([0, 1, 2, 3])
......@@ -136,9 +136,8 @@ class ShardingStageAlgorithm(AlgorithmBase):
stage = self._stage_range[self._trial_idx]
new_strategy = copy.deepcopy(self._config.dist_strategy)
config_dict = new_strategy.sharding_configs
config_dict["stage"] = stage
new_strategy.sharding_configs = config_dict
sharding = new_strategy.sharding
sharding.stage = stage
name = "trial-sharding-stage{}".format(stage)
trial = Trial(new_strategy, name, self.changed_configs)
......
......@@ -17,15 +17,13 @@ import copy
import pathlib
import paddle
from paddle.distributed import fleet
from ..strategy import Strategy
_tuning_supported_passes = ["sharding", "recompute"]
_strategy_config_suffiex = "_configs"
def _get_pass_config(strategy, pass_name):
config_name = pass_name + _strategy_config_suffiex
config = getattr(strategy, config_name)
config = getattr(strategy, pass_name)
return config
......@@ -38,10 +36,8 @@ class TuningConfig(object):
def __init__(self, user_config, strategy):
if not isinstance(strategy, fleet.DistributedStrategy):
raise TypeError(
"'strategy' must be object of class `fleet.DistributedStrategy`."
)
if not isinstance(strategy, Strategy):
raise TypeError("'strategy' must be object of class `Strategy`.")
if not user_config:
user_config = {}
......@@ -116,11 +112,11 @@ class TuningConfig(object):
for p in _tuning_supported_passes:
if getattr(self._dist_strategy, p) and _get_pass_config(
self._dist_strategy, p)["enable_tuning"]:
self._dist_strategy, p).enable_tuning:
# TODO distinguish different args of each passes
self._tuning_passes_name.add(p)
config_name = p + _strategy_config_suffiex
config_name = p
p_dict = getattr(self._dist_strategy, config_name)
self.__dict__[config_name] = p_dict
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# import yaml
import os
import sys
import copy
......@@ -29,7 +30,6 @@ import paddle
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.distributed.passes import new_pass, PassContext
from paddle.distributed.utils import get_logger
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.completion import Completer
......@@ -39,6 +39,7 @@ from paddle.distributed.auto_parallel.process_group import clear_all_process_gro
from paddle.distributed.auto_parallel.utils import debug_program
from paddle.distributed.auto_parallel.utils import make_data_unshard, set_grad_var_shape
from ..utils import get_logger
from .config import TuningConfig
from .algorithms import new_algorithm
from .trial import TrialStatus
......@@ -256,8 +257,8 @@ class OptimizationTuner:
startup_program = dist_context.serial_startup_program
# applying optimization pass
if new_strategy.amp:
config = copy.deepcopy(new_strategy.amp_configs)
if new_strategy.amp.enable:
config = copy.deepcopy(new_strategy.amp.to_dict())
config["dist_context"] = dist_context
config["params_grads"] = dist_context._params_grads
......@@ -275,8 +276,8 @@ class OptimizationTuner:
auto_parallel_amp_pass.apply([main_program], [startup_program],
pass_context)
if new_strategy.recompute:
config = copy.deepcopy(new_strategy.recompute_configs)
if new_strategy.recompute.enable:
config = copy.deepcopy(new_strategy.recompute.to_dict())
config["dist_context"] = dist_context
config["no_grad_set"] = None
config["loss"] = dist_context.serial_loss
......@@ -303,8 +304,8 @@ class OptimizationTuner:
dist_context, dist_params_grads)
resharder.reshard()
if new_strategy.sharding:
config = copy.deepcopy(new_strategy.sharding_configs)
if new_strategy.sharding.enable:
config = copy.deepcopy(new_strategy.sharding.to_dict())
config["dist_context"] = dist_context
config["params_grads"] = dist_params_grads
config["global_rank"] = self.rank
......@@ -313,8 +314,8 @@ class OptimizationTuner:
auto_parallel_sharding_pass.apply([dist_main_prog],
[dist_startup_prog], pass_context)
if new_strategy.gradient_merge:
config = copy.deepcopy(new_strategy.gradient_merge_configs)
if new_strategy.gradient_merge.enable:
config = copy.deepcopy(new_strategy.gradient_merge.to_dict())
config["dist_context"] = dist_context
config["params_grads"] = dist_params_grads
auto_parallel_gradient_merge_pass = new_pass(
......@@ -492,9 +493,10 @@ The best trial is: [{}], whose configuration is following:
for line in summary_.split("\n"):
fw.write(line + "\n")
full_strategy = self.get_best_config()
full_strategy.save_to_prototxt(
os.path.join(self.project_dir, "tuned_dist_strategy.prototxt"))
# full_strategy = self.get_best_config()
# path = os.path.join(self.project_dir, "tuned_dist_strategy.yaml")
# with open(path, 'w') as outfile:
# yaml.dump(full_strategy, outfile, default_flow_style=False)
def clear(self):
"""
......
......@@ -156,9 +156,10 @@ class OptimizationTunerTrial(Trial):
draws += h1_format.format("{} auto=True <-> {}".format(name, name))
draws += line + "\n"
my_configs = getattr(self.space, name)
keys = my_configs.keys()
keys = my_configs.to_dict().keys()
for key in keys:
draws += h2_format.format(key, str(my_configs.get(key, None)))
draws += h2_format.format(
key, str(my_configs.to_dict().get(key, None)))
result_res = draws + border
return result_res
......
......@@ -28,6 +28,19 @@ from paddle.fluid.io import is_parameter, is_belong_to_optimizer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
def get_logger(log_level, name="auto_parallel"):
logger = logging.getLogger(name)
logger.propagate = False
if not logger.handlers:
logger.setLevel(log_level)
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s')
log_handler.setFormatter(log_format)
logger.addHandler(log_handler)
return logger
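# Typical use (illustrative): the logger is created once per name and reused, so repeated
# calls do not stack duplicate handlers.
#
#   logger = get_logger(logging.INFO)
#   logger.info("auto parallel pass applied")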
def is_valid_list_index(list, index):
if index >= -len(list) and index < len(list):
return True
......@@ -49,6 +62,58 @@ def is_dim_replicate(mapping):
return False
def verify_dims_mapping(dims_mapping, process_mesh):
if dims_mapping is None:
return False
if not all(isinstance(d, int) for d in dims_mapping):
return False
for i in range(len(dims_mapping)):
if dims_mapping[i] < -1 or dims_mapping[i] >= len(process_mesh.shape):
return False
for i in range(len(process_mesh.shape)):
if dims_mapping.count(i) > 1:
return False
return True
def convert_to_dims_mapping(shard_spec, process_mesh):
dims_mapping = []
for shard in shard_spec:
if shard is None:
dims_mapping.append(-1)
else:
dims_mapping.append(process_mesh.dim_names.index(shard))
return dims_mapping
def convert_to_shard_spec(dims_mapping, process_mesh):
shard_spec = []
for dim_mapping in dims_mapping:
if dim_mapping == -1:
shard_spec.append(None)
else:
shard_spec.append(process_mesh.dim_names[dim_mapping])
return shard_spec
def verify_shard_spec(shard_spec, tensor_shape, process_mesh):
if len(shard_spec) != len(tensor_shape):
return False
for shard in shard_spec:
if shard is not None and not isinstance(shard, str):
return False
if shard is not None and shard not in process_mesh.dim_names:
return False
dims_mapping = convert_to_dims_mapping(shard_spec, process_mesh)
if not verify_dims_mapping(dims_mapping, process_mesh):
return False
for i in range(len(tensor_shape)):
if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
and tensor_shape[i] % process_mesh.shape[dims_mapping[i]] != 0:
return False
return True
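# A small worked example (mesh values assumed) of the round trip between shard_spec and
# dims_mapping defined above:
#
#   mesh = ProcessMesh([[0, 1, 2], [3, 4, 5]], dim_names=["x", "y"])
#   dm = convert_to_dims_mapping(["y", None], mesh)   # -> [1, -1]
#   convert_to_shard_spec(dm, mesh)                   # -> ["y", None]
#   verify_shard_spec(["y", None], [6, 12], mesh)     # -> True, since 6 % 3 == 0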
def compute_compatible_dim_mapping(dim_mappings):
if not dim_mappings:
return None
......@@ -1040,7 +1105,7 @@ def set_grad_var_shape(program, dist_context):
if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast",
'fill_any_like'
"fill_any_like"
]:
forward_var_name = op.input_arg_names[0]
elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad":
......@@ -1504,3 +1569,15 @@ def ring_id_to_process_group(ring_id):
if g.id == ring_id:
return g
return None
def find_higher_order_backward_op(program):
higher_order_op_suffix = ['_grad_grad', 'triple_grad']
for block in program.blocks:
for op in block.ops:
for suffix in higher_order_op_suffix:
if suffix in op.type:
return True
return False
......@@ -314,7 +314,9 @@ class AMPState(object):
consume_op_attr.set_input_dist_attr(
cast_name, in_var_dist_attr)
else:
assert in_var.dtype == dst_dtype
assert in_var.dtype == dst_dtype, "op [{}] expect input [{}] to be dtype [{}] BUT got [{}]. {}".format(
grad_op.type, in_name, dst_dtype, in_var.dtype,
str(grad_op))
for out_name in grad_op.output_names:
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_output(
......
......@@ -13,12 +13,14 @@
# limitations under the License.
from collections import OrderedDict
import numpy as np
import paddle
from paddle.fluid import core, unique_name
from paddle.fluid.framework import default_main_program
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, ring_id_to_process_group
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, is_backward_op, ring_id_to_process_group, find_higher_order_backward_op
from .pass_base import PassBase, PassType, register_pass
# add new optimizers supporting rescale_grad here
......@@ -31,6 +33,10 @@ __rescale_grad_supported_opts__ = [
__max_stream_num_allow__ = 16
def numel(var):
return np.prod(list(var.shape))
@register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase):
"""
......@@ -78,7 +84,9 @@ class DataParallelOptimizationPass(PassBase):
self._analyze_program()
self._prune_grad_scaling()
self._calc_comm_overlap()
self._fuse_allreduce()
grad_group = self._fuse_allreduce()
# self.summary(grad_group)
def _prune_grad_scaling(self):
......@@ -99,7 +107,14 @@ class DataParallelOptimizationPass(PassBase):
self._calc_wait_comms()
def _fuse_allreduce(self):
pass
if not self._could_be_fuse():
return []
grad_group = self._group_grads()
self._update_program(grad_group)
return grad_group
def _analyze_program(self):
"""
......@@ -154,7 +169,7 @@ class DataParallelOptimizationPass(PassBase):
def _could_be_prune(self):
return self.dist_context._gradient_scale and (
return self.dist_context.gradient_scale and (
self._support_rescale_grad or self._all_dp_groups_same_degree())
def _all_dp_groups_same_degree(self):
......@@ -316,3 +331,252 @@ class DataParallelOptimizationPass(PassBase):
'op_role': OpRole.Backward,
'ring_id': ring_id
})
def _could_be_fuse(self):
# TODO support gradient fuse higher order gradient.
# should analyse the dependencies of gradient in backward.
if find_higher_order_backward_op(default_main_program()):
return False
if self.use_sharding:
return False
return True
def _group_grads(self):
"""
conditions for gradients to be grouped:
1. group size < max_fuse_numel
2. same dp group
3. same dtype
4. dependency: grad would NOT be used by other ops within group segment
gradients inside the same group will be fused into one coalesced tensor
"""
block = default_main_program().global_block()
ops = block.ops
# group individual grad vars
# TODO consider fuse gradient for sharding reduce
# TODO let user to set fuse_grad_size
# emb = 50000 * h, ffn = 8 * h * h, mha = 4 * h * h
h = 2048
ffn_numel = 2 * (4 * h) * h
mha_numel = 3 * h * h + h * h
max_fuse_numel = ffn_numel + mha_numel
grad_groups = []
cur_group = GradientsGroup(ops, max_fuse_numel)
grouped_grad_names = set()
def collect_group(cur_group, grad_var, ring_id, i):
if len(cur_group.gradients) == 0:
cur_group = None
elif len(cur_group.gradients) == 1:
grouped_grad_names.remove(cur_group.gradients[0].name)
else:
cur_group.finalize()
grad_groups.append(cur_group)
new_group = GradientsGroup(ops, max_fuse_numel)
if grad_var:
new_group.add(grad_var, ring_id, i)
grouped_grad_names.add(grad_var.name)
return new_group
def op_depend_on_group(op, group):
vars_ = set(op.input_arg_names + op.output_arg_names)
grad_names = set([grad.name for grad in group.gradients])
return len(vars_.intersection(grad_names)) > 0
for i, op in enumerate(ops):
if is_data_parallel_reduce_op(op):
ring_id = op.attr("ring_id")
grad_name = op.output_arg_names[0]
grad_var = block.var(grad_name)
grad_numel = numel(grad_var)
if cur_group.acceptable(grad_var, ring_id):
assert grad_name not in grouped_grad_names
grouped_grad_names.add(grad_name)
cur_group.add(grad_var, ring_id, i)
else:
cur_group = collect_group(cur_group, grad_var, ring_id, i)
else:
if op_depend_on_group(op, cur_group):
cur_group = collect_group(cur_group, None, None, None)
# collect last group
collect_group(cur_group, None, None, None)
return grad_groups
def _update_program(self, grad_groups):
block = default_main_program().global_block()
remove_op_types = ['scale', 'c_allreduce_sum', 'c_wait_compute']
for i, group in enumerate(grad_groups[::-1]):
# create coalesced tensor
group.coalesce_var = block.create_var(name=unique_name.generate(
'coalecse_grad_{}'.format(i)),
dtype=group.dtype,
persistable=False,
stop_gradient=True)
# update allreduce & scale op
if group.scale_op_idx != -1:
scale_op = block.ops[group.scale_op_idx]
assert scale_op.type == 'scale', "expected a scale op but found {}".format(
str(scale_op))
scale_op._rename_input(scale_op.input_arg_names[0],
group.coalesce_var.name)
scale_op._rename_output(scale_op.output_arg_names[0],
group.coalesce_var.name)
allreduce_op = block.ops[group.allreduce_op_idx]
assert allreduce_op.type == 'c_allreduce_sum', "expected a c_allreduce_sum op but found {}".format(
str(allreduce_op))
allreduce_op._rename_input(allreduce_op.input_arg_names[0],
group.coalesce_var.name)
allreduce_op._rename_output(allreduce_op.output_arg_names[0],
group.coalesce_var.name)
# remove unused ops
remove_op_indices = group.remove_wait_op_indices + group.remove_allreduce_op_indices + group.remove_scale_op_indices
for idx in sorted(remove_op_indices, reverse=True):
assert block.ops[
idx].type in remove_op_types, "Unexpected: trying to remove op {}".format(
str(block.ops[idx]))
block._remove_op(idx)
# insert coalesce_tensor op
concated_shapes = []
concated_ranks = []
for grad_ in group.gradients:
shape = grad_.shape
concated_shapes.extend(shape)
concated_ranks.append(len(shape))
grad_names = [grad.name for grad in group.gradients]
block._insert_op_without_sync(group.coalesce_op_idx,
type="coalesce_tensor",
inputs={"Input": grad_names},
outputs={
"Output": grad_names,
"FusedOutput": group.coalesce_var
},
attrs={
"copy_data": False,
"use_align": True,
"dtype": group.dtype,
"concated_shapes":
concated_shapes,
"concated_ranks": concated_ranks,
OP_ROLE_KEY: OpRole.Backward
})
block._sync_with_cpp()
# TODO update dist attr
def summary(self, grad_groups=[]):
# TODO: add logger module
import logging
self._logger = logging.getLogger()
self._logger.propagate = False
if not self._logger.handlers:
self._logger.setLevel(logging.INFO)
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
'[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s'
)
log_handler.setFormatter(log_format)
self._logger.addHandler(log_handler)
if len(grad_groups) > 0:
self._logger.info(
"origin {} allreduce ops are fused into {} coalecse allreduce ops."
.format(len(self._grad_name_to_group_map.keys()),
len(grad_groups)))
self._logger.info("gradient fusing group are following: ")
fused_grads = set()
for i, group in enumerate(grad_groups):
self._logger.info(
"coalecse gradient [{}] is composed by: {}".format(
i, [grad.name for grad in group.gradients]))
fused_grads.update([grad.name for grad in group.gradients])
individual_grads = set(
self._grad_name_to_group_map.keys()) - set(fused_grads)
self._logger.info(
"the following [{}] gradients are not fused: ".format(
len(individual_grads)))
self._logger.info("individual gradient {}".format(individual_grads))
class GradientsGroup(object):
def __init__(self, ops, max_group_size):
self.max_group_size = max_group_size
self.ops = ops
self.gradients = []
self.numel = 0
self.dtype = None
self.ring_id = None
self.coalesce_var = None
self.coalesce_op_idx = -1
self.allreduce_op_idx = -1
self.scale_op_idx = -1
self.remove_wait_op_indices = []
self.remove_allreduce_op_indices = []
self.remove_scale_op_indices = []
def acceptable(self, grad_var, ring_id):
if len(self.gradients) == 0:
return True
if ring_id != self.ring_id:
return False
if numel(grad_var) + self.numel > self.max_group_size:
return False
if grad_var.dtype != self.dtype:
return False
return True
def add(self, grad_var, ring_id, i):
self.gradients.append(grad_var)
self.ring_id = ring_id
self.dtype = grad_var.dtype
self.numel += numel(grad_var)
# remove auxiliary ops in non-fuse dp allreduce
self.remove_allreduce_op_indices.append(i)
# NOTE this pass relies on the original synchronization added in previous passes
# (same stream or calc_wait_comm & comm_wait_calc)
# to guarantee the correctness of the comm-calc execution order,
# so the calc_wait_comm should be kept.
grad_op_idx = i - 1
if i > 0 and self.ops[i - 1].type == 'c_wait_compute':
self.remove_wait_op_indices.append(i - 1)
grad_op_idx -= 1
if i + 1 < len(self.ops) and is_data_parallel_scale_op(self.ops[i + 1]):
self.remove_scale_op_indices.append(i + 1)
if len(self.gradients) == 1:
# TODO Remove this is a temporary hack for Tensor Parallel. the logic
# for find grad_op should be more general.
if self.ops[grad_op_idx].type == "c_allreduce_sum":
grad_op_idx -= 1
grad_op = self.ops[grad_op_idx]
assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format(
grad_var.name, str(grad_op))
self.coalesce_op_idx = grad_op_idx
def finalize(self):
self.allreduce_op_idx = self.remove_allreduce_op_indices.pop()
if len(self.remove_wait_op_indices) > 1:
self.remove_wait_op_indices.pop()
if len(self.remove_scale_op_indices) > 1:
self.scale_op_idx = self.remove_scale_op_indices.pop()
......@@ -16,6 +16,7 @@ from collections import defaultdict
import paddle
from paddle.framework import core
from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid import unique_name
from .pass_base import register_pass
from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
......@@ -379,6 +380,10 @@ class FP16State(object):
# create cast grad
grad_slot_name = slot_name + "@GRAD"
assert grad_slot_name in op.output_names
if len(op.output(grad_slot_name)) == 0:
var = block.var(src_name)
assert var.stop_gradient is True
continue
assert len(op.output(grad_slot_name)) == 1
grad_name = op.output(grad_slot_name)[0]
grad = block.var(grad_name)
......@@ -442,7 +447,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):
inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Backward}
attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs,
......@@ -536,6 +541,39 @@ def _insert_memcopy(block, idx, src_var, dist_context, direction="D2H"):
return output_var
def cast_startup_program():
main_program = default_main_program()
startup_program = default_startup_program()
param_to_dtype = {}
for block in main_program.blocks:
for p in block.all_parameters():
param_to_dtype[p.name] = p.dtype
def is_initialization_op(op):
comm_op_prefix = "c_"
op_type = op.type
if op_type.startswith(comm_op_prefix):
return False
if len(op.output_arg_names) != 1 and len(op.input_arg_names) != 0:
return False
return True
for op in startup_program.global_block().ops:
if is_initialization_op(op):
output_name = op.output_arg_names[0]
if param_to_dtype.get(output_name,
None) == core.VarDesc.VarType.FP16:
assert op.has_attr(
'dtype'
), "initialization op is supported to has dtype attribute but got {}.".format(
str(op))
if op.attr('dtype') == core.VarDesc.VarType.FP32:
op._set_attr('dtype', core.VarDesc.VarType.FP16)
@register_pass("auto_parallel_fp16")
class FP16Pass(AMPPass):
......@@ -563,6 +601,8 @@ class FP16Pass(AMPPass):
input_data_var_names)
is_train = fp16_state._build_state()
cast_startup_program()
if is_train:
with paddle.static.program_guard(main_program, startup_program):
# TODO (JZ-LIANG)support cast forward program only when inference
......@@ -575,18 +615,18 @@ class FP16Pass(AMPPass):
) or self.get_attr("init_loss_scaling") != 1.0:
found_infs = []
if fp32_grads:
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
_, found_inf_fp32 = _check_and_update_gradient(
fp32_grads, self._loss_scaling, "@fp32",
self.dist_context)
found_infs.append(found_inf_fp32)
if fp16_grads:
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
_, found_inf_fp16 = _check_and_update_gradient(
fp16_grads, self._loss_scaling, "@fp16",
self.dist_context)
found_infs.append(found_inf_fp16)
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
block = main_program.global_block()
all_infs = paddle.fluid.layers.concat(found_infs)
......@@ -608,7 +648,7 @@ class FP16Pass(AMPPass):
block, self.dist_context)
if self.get_attr("use_dynamic_loss_scaling"):
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
if fp32_grads:
self._update_loss_scaling(fp32_grads, found_inf)
if fp16_grads:
......
......@@ -207,6 +207,7 @@ class ClipGradByGloblNormPass(PassBase):
super(ClipGradByGloblNormPass, self).__init__()
self.set_attr("rank_id", None)
self.set_attr("dist_context", None)
self.set_attr("params_grads", None)
def _check_self(self):
if self.get_attr("dist_context") is None:
......@@ -214,6 +215,8 @@ class ClipGradByGloblNormPass(PassBase):
dist_context = self.get_attr("dist_context")
if dist_context._lr_optimizer._grad_clip is None:
return False
if self.get_attr("params_grads") is None:
return False
return True
def _check_conflict(self, other_pass):
......@@ -223,7 +226,8 @@ class ClipGradByGloblNormPass(PassBase):
dist_context = self.get_attr("dist_context", None)
rank_id = self.get_attr("rank_id", None)
block = main_program.global_block()
dist_params_grads = _get_params_grads(block)
dist_params_grads = self.get_attr("params_grads", None)
# dist_params_grads = _get_params_grads(block)
self.clip_helper = ClipHelper(dist_params_grads, rank_id, block,
dist_context)
......
......@@ -55,13 +55,6 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
return optimize_ops_desc
def _remove_op_role_var(param, grad):
op_maker = core.op_proto_and_checker_maker
op = grad.op
if op and op.has_attr(op_maker.kOpRoleVarAttrName()):
op._remove_attr(op_maker.kOpRoleVarAttrName())
def _get_gm_cond_var(main_program, k_steps, dist_context):
main_block = main_program.global_block()
# Add const var
......@@ -147,8 +140,6 @@ def _append_gradient_merge_backward_op(
param.type != core.VarDesc.VarType.SELECTED_ROWS
), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"
_remove_op_role_var(param, grad)
# {grad.name: gradient_merge_var.name} to rename opt inputs
grad_to_gradient_merge = {}
# {param: gradient_merge_var} to insert scale op and fill_constant op
......
......@@ -51,7 +51,8 @@ class ShardingPass(PassBase):
super(ShardingPass, self).__init__()
self.set_attr("dist_context", None)
self.set_attr("stage", None)
self.set_attr("sharding_degree", None)
self.set_attr("sharding_degree", None) # for parallelizer
self.set_attr("degree", None) # for parallelizer_v2
self.set_attr("params_grads", [])
self.set_attr("global_rank", -1)
self.dp_groups = set()
......@@ -59,6 +60,7 @@ class ShardingPass(PassBase):
self.varname_to_sharding_info = {}
self.partial_sharding = False
self.outer_dp_group = None
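# (param, grad) pairs whose parameter is kept in the local shard; exposed to
# later passes through the pass context.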
self.shared_params_grads = []
def _check_self(self):
if self.get_attr("dist_context") is None:
......@@ -66,8 +68,15 @@ class ShardingPass(PassBase):
if self.get_attr("stage") not in [1, 2, 3]:
return False
if (not isinstance(self.get_attr("sharding_degree"),
int)) or self.get_attr("sharding_degree") <= 1:
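# Backward compatibility: the legacy parallelizer passes "sharding_degree"
# while parallelizer_v2 passes "degree"; exactly one of them must be an
# integer greater than 1.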
if self.get_attr("sharding_degree") is not None:
if (not isinstance(self.get_attr("sharding_degree"), int)) \
or self.get_attr("sharding_degree") <= 1:
return False
elif self.get_attr("degree") is not None:
if (not isinstance(self.get_attr("degree"), int)) \
or self.get_attr("degree") <= 1:
return False
else:
return False
if len(self.get_attr("params_grads")) <= 0:
return False
......@@ -82,7 +91,8 @@ class ShardingPass(PassBase):
def _apply_single_impl(self, main_program, startup_program, context):
self._dist_context = self.get_attr("dist_context")
self.sharding_world_size = int(self.get_attr("sharding_degree"))
self.sharding_world_size = int(
self.get_attr("sharding_degree") or self.get_attr("degree"))
self.stage = int(self.get_attr("stage"))
self.global_rank = int(self.get_attr("global_rank"))
params_grads = self.get_attr("params_grads")
......@@ -94,6 +104,8 @@ class ShardingPass(PassBase):
self._shard_gradient_synchronization(main_block)
self._shard_parameter(main_block, startup_block)
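# Expose the (param, grad) pairs kept on this rank so that downstream passes
# (e.g. grad clip) can read them from the pass context.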
context.set_attr("params_grads", self.shared_params_grads)
def _build_sharding_groups(self, main_block, params_grads):
self._collective_data_parallel_groups(main_block)
self._build_sharding_infos(params_grads)
......@@ -148,13 +160,10 @@ class ShardingPass(PassBase):
self._dist_context._sharding_group = sharding_group
# TODO(JZ-LIANG) when multiple dp groups are supported in the future, group the params and bind them to their corresponding dp groups
params_in_group = [p for p, g in params_grads]
assert len(params_in_group) == len(
set(params_in_group)), "found duplicated param in params_grads"
sharding_info = ShardingInfo(sharding_group, self.global_rank,
params_in_group)
params_grads)
self.sharding_infos.append(sharding_info)
for param in params_in_group:
for param in sharding_info.params:
self.varname_to_sharding_info[param.name] = sharding_info
def _shard_optimizer(self, main_block, startup_block, params_grads,
......@@ -201,6 +210,7 @@ class ShardingPass(PassBase):
op.desc.set_output('Out', reversed_x)
else:
if op.type == "check_finite_and_unscale":
op_role = op.attr('op_role')
out_name = op.output_arg_names[0]
out_var = main_block.vars[out_name]
main_block._remove_op(idx, sync=False)
......@@ -212,6 +222,7 @@ class ShardingPass(PassBase):
"shape": out_var.shape,
"dtype": out_var.dtype,
"value": 0,
OP_ROLE_KEY: op_role,
})
else:
main_block._remove_op(idx, sync=False)
......@@ -313,6 +324,9 @@ class ShardingPass(PassBase):
if varname != param_name
])
main_block._remove_op(idx, sync=False)
else:
self.shared_params_grads.append(
self._get_param_grad(param_name))
for idx, op in reversed(list(enumerate(startup_block.ops))):
if len(op.output_arg_names) == 1 and op.output_arg_names[
......@@ -365,6 +379,13 @@ class ShardingPass(PassBase):
sharding_info = self.varname_to_sharding_info[param_name]
return sharding_info.is_in_local_shard(param_name)
def _get_param_grad(self, param_name):
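# Return the (param, grad) pair of a parameter kept in the local shard.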
assert param_name in self.varname_to_sharding_info
sharding_info = self.varname_to_sharding_info[param_name]
p_g = sharding_info.get_param_grad(param_name)
assert p_g is not None
return p_g
def _shard_gradient_synchronization(self, main_block):
if self.stage < 2:
......@@ -705,9 +726,13 @@ def shard_parameters(params, group_size):
class ShardingInfo(object):
def __init__(self, group, rank, params):
def __init__(self, group, rank, params_grads):
self.group = group
self.params = params
self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads])
assert len(self.params_grads) == len(
params_grads), "found duplicated param in params_grads"
self.params = [p for p, _ in params_grads]
self.param_names = [p.name for p in self.params]
self.group_size = group.nranks
self.global_rank = rank
......@@ -762,3 +787,11 @@ class ShardingInfo(object):
if usage > 0:
broadcast_vars.add(param)
return broadcast_vars, param_usage
def get_param_grad(self, param_name):
if not self.is_in_local_shard(param_name):
raise ValueError(
"param[{}] not in current rank.".format(param_name))
if param_name not in self.params_grads:
raise ValueError('param[{}] not in params_grads'.format(param_name))
return self.params_grads.get(param_name, None)
......@@ -37,8 +37,28 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
${dist_ENVS})
set_tests_properties(test_high_order_grad
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_grad_clip MODULES test_grad_clip ENVS ${dist_ENVS})
set_tests_properties(test_grad_clip PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
${dist_ENVS})
set_tests_properties(test_iterable_dataset
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
py_test_modules(test_pass_grad_clip MODULES test_pass_grad_clip ENVS
${dist_ENVS})
set_tests_properties(test_pass_grad_clip
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_gradient_merge MODULES test_pass_gradient_merge
ENVS ${dist_ENVS})
set_tests_properties(test_pass_gradient_merge
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_recompute MODULES test_pass_recompute ENVS
${dist_ENVS})
set_tests_properties(test_pass_recompute
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_sharding MODULES test_pass_sharding ENVS
${dist_ENVS})
set_tests_properties(test_pass_sharding
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_amp MODULES test_pass_amp ENVS ${dist_ENVS})
set_tests_properties(test_pass_amp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_while_op_completion MODULES test_while_op_completion
......@@ -70,11 +90,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip)
py_test_modules(test_quantization MODULES test_quantization)
py_test_modules(test_dist_matmul MODULES test_dist_matmul)
py_test_modules(test_process_mesh MODULES test_process_mesh)
py_test_modules(test_interface MODULES test_interface)
py_test_modules(test_strategy MODULES test_strategy)
py_test_modules(test_pass_quantization MODULES test_pass_quantization)
py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
${dist_ENVS})
set_tests_properties(test_iterable_dataset
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80)
endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
def apply_pass(use_amp=False, level=None):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_amp:
amp = strategy.amp
amp.enable = True
amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
amp.custom_black_list = [
'c_softmax_with_cross_entropy', 'elementwise_div', 'reduce_sum'
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = level in ["o2", "o3"]
amp.use_optimizer_fp16 = level == "o3"
print("amp level: ", level)
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestAMPPass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_amp=False, level=None):
reset_prog()
strategy = apply_pass(use_amp, level)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses, rtol=None, atol=None):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=rtol or self.rtol,
atol=atol or self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_amp_pass(self):
# mp2 training
mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"])
# mp2 amp-o1 training
amp_o1_engine = self.get_engine(True, "o1")
amp_o1_losses = amp_o1_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o1_losses = np.array(amp_o1_losses["loss"])
# self.check_results(mp_losses, amp_o1_losses)
# mp2 amp-o2 training
amp_o2_engine = self.get_engine(True, "o2")
amp_o2_losses = amp_o2_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o2_losses = np.array(amp_o2_losses["loss"])
# self.check_results(mp_losses, amp_o2_losses)
# mp2 amp-o3 training
amp_o3_engine = self.get_engine(True, "o3")
amp_o3_losses = amp_o3_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
amp_o3_losses = np.array(amp_o3_losses["loss"])
# self.check_results(mp_losses, amp_o3_losses)
if __name__ == "__main__":
unittest.main()
......@@ -32,7 +32,7 @@ import paddle.distributed.auto_parallel as auto
paddle.enable_static()
_global_parallel_strategy = None
_global_process_mesh = None
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
batch_size = 4
hidden_size = 1024
sequence_len = 512
......@@ -103,11 +103,7 @@ def mlp_pretrain_forward(train_program, start_program):
shape=[batch_size, sequence_len, 1],
dtype='float32')
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mappig": [-1, -1, -1]
})
auto.shard_tensor(input, _global_process_mesh, [None, None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......@@ -126,9 +122,6 @@ def mlp_pretrain_forward(train_program, start_program):
def train():
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
......
......@@ -18,24 +18,21 @@ import random
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
def apply_pass(use_sharding=False):
strategy = fleet.DistributedStrategy()
strategy.semi_auto = True
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_sharding:
strategy.sharding = True
strategy.sharding_configs = {
"sharding_degree": 2,
"stage": 2,
}
sharding = strategy.sharding
sharding.degree = 2
sharding.stage = 2
return strategy
......@@ -76,34 +73,17 @@ class TestGradientClipByGlobalNorm(unittest.TestCase):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
engine.mode = "train"
engine._executor.run(engine.startup_program)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_dp2_engine(self):
def get_engine(self, use_sharding=False):
reset_prog()
strategy = apply_pass()
strategy = apply_pass(use_sharding)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
inputs_spec, labels_spec = create_data_holder(self.batch_size)
engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
engine.prepare(optimizer=opt, loss=loss)
self.init(engine)
return engine
def get_dp2sharding2_engine(self):
reset_prog()
strategy = apply_pass(True)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
inputs_spec, labels_spec = create_data_holder(self.batch_size)
engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
engine.prepare(optimizer=opt, loss=loss)
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
......@@ -121,15 +101,13 @@ class TestGradientClipByGlobalNorm(unittest.TestCase):
def test_grad_clip(self):
# dp2 training
dp_engine = self.get_dp2_engine()
dp_engine.fit(self.dataset, batch_size=self.batch_size, use_cache=True)
dp_engine = self.get_engine()
dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_param_values = get_parameter_value(dp_engine.main_program)
# dp2sharding2 training
sharding_engine = self.get_dp2sharding2_engine()
sharding_engine.fit(self.dataset,
batch_size=self.batch_size,
use_cache=True)
sharding_engine = self.get_engine(True)
sharding_engine.fit(self.dataset, 3, batch_size=self.batch_size)
sharding_param_values = get_parameter_value(
sharding_engine.main_program)
......
......@@ -27,10 +27,8 @@ import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn
......@@ -47,6 +45,8 @@ class_num = 10
paddle.seed(44)
is_fetch = True
class MyDataset(Dataset):
......@@ -90,19 +90,20 @@ class MLPLayer(nn.Layer):
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = auto.shard_op(self.norm, dist_attr={"process_mesh":
PP_MESH_0})(input)
out = auto.shard_op(self.norm, PP_MESH_0)(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = auto.shard_op(self.linear1, dist_attr={"process_mesh":
PP_MESH_1})(out)
out = auto.shard_op(self.linear1, PP_MESH_1)(out)
out = self.dropout(out)
out = self.linear2(out)
self.out = out
if is_fetch:
auto.fetch(out, "out")
return out
def train(fetch):
global is_fetch
is_fetch = fetch
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
......@@ -113,46 +114,34 @@ def train(fetch):
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels_spec = InputSpec([batch_size], 'int64', 'label')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine
engine = Engine(mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
strategy=dist_strategy)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
strategy = auto.Strategy()
strategy.auto_mode = "semi"
# fetch
if fetch:
fetches = {'out': mlp.out}
else:
fetches = None
engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
# train
train_dataset = MyDataset(batch_num * batch_size)
engine.fit(train_dataset,
eval_dataset1 = MyDataset(5 * batch_size)
engine.fit(train_data=train_dataset,
epochs=2,
batch_size=batch_size,
steps_per_epoch=batch_num * batch_size,
fetches=fetches)
valid_data=eval_dataset1)
# eval
eval_dataset = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size, fetches=fetches)
eval_dataset2 = MyDataset(batch_size)
engine.evaluate(eval_dataset2, batch_size=batch_size)
# predict
test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size, fetches=fetches)
engine.predict(test_dataset, batch_size=batch_size)
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp_inf')
engine.save(model_filename, training=False, mode='predict')
model_filename = os.path.join(temp_dir.name, 'mlp')
engine.save(model_filename, training=True)
engine.load(model_filename)
temp_dir.cleanup()
......
......@@ -26,11 +26,9 @@ import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
from paddle.io import Dataset, DataLoader
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
paddle.enable_static()
batch_size = 2
......@@ -91,6 +89,7 @@ class MLPLayer(nn.Layer):
out = self.linear1(out)
out = self.dropout(out)
out = self.linear2(out)
auto.fetch(out, "out")
self.out = out
return out
......@@ -107,46 +106,32 @@ def train(fetch):
epsilon=1e-08,
grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels_spec = InputSpec([batch_size], 'int64', 'label')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
dist_strategy.recompute = False
# init parallel optimizer
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
# init engine
engine = Engine(mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
engine = auto.Engine(mlp,
loss,
optimizer,
paddle.metric.Accuracy(),
strategy=dist_strategy)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# fetch
if fetch:
fetches = {'out': mlp.out}
else:
fetches = None
# train
train_dataset = MyDataset(batch_num * batch_size)
engine.fit(train_dataset, batch_size=batch_size, fetches=fetches)
engine.fit(train_dataset, batch_size=batch_size)
# eval
eval_dataset = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size, fetches=fetches)
engine.evaluate(eval_dataset, batch_size=batch_size)
# predict
test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size, fetches=fetches)
engine.predict(test_dataset, batch_size=batch_size)
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp_inf')
engine.save(model_filename, training=False, mode='predict')
engine.save(model_filename, training=False)
temp_dir.cleanup()
......
......@@ -14,8 +14,10 @@
import sys
import numpy as np
import random
import paddle
import paddle.distributed.auto_parallel as auto
sys.path.append("..")
import auto_parallel_gpt_model as modeling
......@@ -25,7 +27,7 @@ sequence_len = 512
vocab_size = 1000
class FakeDataset:
class FakeDataset(paddle.io.Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
......@@ -33,6 +35,9 @@ class FakeDataset:
self.vocab_size = vocab_size
def __getitem__(self, idx):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
tokens = np.random.randint(self.vocab_size, size=self.sequence_len)
position_ids = np.arange(self.sequence_len)
attention_mask = np.tril(np.ones(self.sequence_len)).reshape(
......@@ -67,8 +72,9 @@ def create_data_holder(batch_size):
def generate_model(strategy):
modeling.init_global()
modeling._global_process_mesh = list(
range(paddle.distributed.get_world_size()))
ranks = list(range(paddle.distributed.get_world_size()))
modeling._global_process_mesh = auto.ProcessMesh(mesh=ranks,
dim_names=["x"])
if strategy == "serial":
modeling._global_parallel_strategy = "serial"
elif strategy == "mp":
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
def apply_pass(use_gradient_merge=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_gradient_merge:
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 4
gradient_merge.avg = True
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestGradientMergePass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 8
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_gradient_merge=False):
reset_prog()
strategy = apply_pass(use_gradient_merge)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_gradient_merge_pass(self):
# dp2 training
dp_engine = self.get_engine()
dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(dp_losses["loss"])
# dp2 gradient merge training
gm_engine = self.get_engine(True)
gm_losses = gm_engine.fit(self.dataset, 3, batch_size=self.batch_size)
gm_losses = np.array(gm_losses["loss"])
avg_loss = 0
pass_avg_ret_list = []
for i, pass_ret in enumerate(gm_losses):
if (i + 1) % 4 == 0:
avg_loss += pass_ret
pass_avg_ret_list.append(avg_loss / 4)
avg_loss = 0
else:
avg_loss += pass_ret
self.check_results(dp_losses, np.array(pass_avg_ret_list))
if __name__ == "__main__":
unittest.main()
......@@ -17,11 +17,7 @@ import paddle
import unittest
import numpy as np
import paddle.distributed.auto_parallel as auto
from paddle.static import InputSpec
from paddle.distributed import fleet
from paddle.incubate.autograd import Hessian
from paddle.distributed.auto_parallel.engine import Engine
np.random.seed(1234)
paddle.seed(1234)
......@@ -87,7 +83,7 @@ class LaplaceModel(paddle.nn.Layer):
return eq_loss, bc_u
class LaplaceDataset:
class LaplaceDataset(paddle.io.Dataset):
def __init__(self, num_sample):
self.num_sample = num_sample
......@@ -129,23 +125,14 @@ def main():
# model
laplace = LaplaceModel()
# spec
inputs_spec = [
InputSpec([100, 2], 'float32', 'x'),
InputSpec([36], 'int64', 'bc_idx')
]
labels_spec = InputSpec([36, 1], 'float32', 'bc_v')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
engine = Engine(laplace,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
engine = auto.Engine(laplace,
loss=loss_func,
optimizer=optimizer,
strategy=dist_strategy)
engine.prepare(optimizer=optimizer, loss=loss_func)
engine.fit(train_dataset, batch_size=None)
engine.fit(train_dataset, train_sample_split=2, batch_size=None)
dist_context = engine.dist_context
block = engine.main_program.global_block()
......
......@@ -28,9 +28,8 @@ import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from paddle.optimizer.lr import CosineAnnealingDecay
from paddle.fluid.dataloader.collate import default_collate_fn
......@@ -48,10 +47,9 @@ class_num = 10
paddle.seed(44)
class MyDataset(IterableDataset):
class MyDataset(paddle.io.IterableDataset):
def __init__(self, num_samples):
super(MyDataset, self).__init__()
self.num_samples = num_samples
def __iter__(self):
......@@ -61,10 +59,9 @@ class MyDataset(IterableDataset):
yield input, label
class MyDataset1(Dataset):
class MyDataset1(paddle.io.Dataset):
def __init__(self, num_samples):
super(MyDataset1, self).__init__()
self.num_samples = num_samples
self.data = []
for i in range(self.num_samples):
......@@ -112,12 +109,10 @@ class MLPLayer(nn.Layer):
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = auto.shard_op(self.norm, dist_attr={"process_mesh":
PP_MESH_0})(input)
out = auto.shard_op(self.norm, PP_MESH_0)(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = auto.shard_op(self.linear1, dist_attr={"process_mesh":
PP_MESH_1})(out)
out = auto.shard_op(self.linear1, PP_MESH_1)(out)
out = self.dropout(out)
out = self.linear2(out)
self.out = out
......@@ -136,54 +131,36 @@ def train(fetch):
epsilon=1e-08,
grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels_spec = InputSpec([batch_size], 'int64', 'label')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
dist_strategy.split_data = True
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine
engine = Engine(mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
engine = auto.Engine(mlp,
loss,
optimizer,
paddle.metric.Accuracy(),
strategy=dist_strategy)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# fetch
if fetch:
fetches = {'out': mlp.out}
else:
fetches = None
# train
train_dataset = MyDataset(batch_num * batch_size)
train_dataset1 = MyDataset1(batch_num)
engine.fit(train_dataset,
epochs=2,
batch_size=batch_size,
steps_per_epoch=batch_num,
fetches=fetches)
engine.fit(train_dataset1,
epochs=2,
batch_size=None,
steps_per_epoch=batch_num,
fetches=fetches)
engine.fit(train_dataset, epochs=2, batch_size=batch_size)
train_dataset1 = MyDataset1(batch_size * batch_num)
engine.fit(train_dataset1, epochs=2, batch_size=None)
# eval
eval_dataset = MyDataset(batch_size)
engine.evaluate(eval_dataset, batch_size, fetches=fetches)
engine.evaluate(eval_dataset, batch_size=batch_size)
# predict
test_dataset = MyDataset(batch_size)
engine.predict(test_dataset, batch_size, fetches=fetches)
engine.predict(test_dataset, batch_size=batch_size)
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp_inf')
engine.save(model_filename, training=False, mode='predict')
engine.save(model_filename, training=False)
temp_dir.cleanup()
......
......@@ -27,10 +27,8 @@ import paddle.nn.functional as F
import paddle.utils as utils
from paddle.fluid import layers
from paddle.io import Dataset, IterableDataset, DataLoader
from paddle.static import InputSpec
from paddle.distributed import fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from engine_api_dp import MyDataset
paddle.enable_static()
......@@ -43,20 +41,6 @@ class_num = 10
paddle.seed(44)
# class MyDataset(Dataset):
# def __init__(self, num_samples):
# super(MyDataset, self).__init__()
# self.num_samples = num_samples
# def __getitem__(self, index):
# input = np.random.uniform(size=image_size).astype("float32")
# label = np.random.randint(0, class_num - 1, dtype="int64")
# return input, label
# def __len__(self):
# return self.num_samples
class MLPLayer(nn.Layer):
......@@ -107,50 +91,33 @@ def train(fetch):
epsilon=1e-08,
grad_clip=None)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels_spec = InputSpec([batch_size], 'int64', 'label')
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
dist_strategy.pipeline = False
dist_strategy.recompute = False
# init parallel optimizer
dist_strategy.semi_auto = True
dist_strategy.sharding = True
dist_strategy.sharding_configs = {
"sharding_degree": 2,
"stage": 3,
"enable_tuning": True,
}
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine
import tempfile
tmp_dir = tempfile.TemporaryDirectory()
dataset = MyDataset(batch_num * batch_size)
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
# sharding config
sharding = dist_strategy.sharding
sharding.enable = True
sharding.degree = 2
sharding.stage = 3
sharding.enable_tuning = True
sharding.tuning_range = [0, 1, 2, 3]
# Tuning configuration
tuning_config = {
"batch_size": batch_size,
"dataset": dataset,
"profile_start_step": 1,
"profile_end_step": 5,
"run_after_tuning": True,
"sharding": {
"stage_range": [0, 1, 2, 3]
},
"verbose": True,
}
engine = Engine(mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
strategy=dist_strategy,
user_tuning_config=tuning_config)
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
tuning = dist_strategy.tuning
tuning.enable = True
tuning.profile_start_step = 1
tuning.profile_end_step = 5
tuning.run_after_tuning = True
tuning.verbose = True
dataset = MyDataset(batch_num * batch_size)
engine = auto.Engine(mlp,
loss,
optimizer,
paddle.metric.Accuracy(),
strategy=dist_strategy)
engine._tune(dataset, batch_size=batch_size)
# check tuned
assert (engine._dist_contexts['train'].strategy.sharding_configs['stage'] !=
3)
assert (engine._dist_contexts['train'].strategy.sharding.stage != 3)
if __name__ == "__main__":
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
def apply_pass(use_recompute=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_recompute:
recompute = strategy.recompute
recompute.enable = True
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestRecomputePass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_recompute=False):
reset_prog()
strategy = apply_pass(use_recompute)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_recompute_pass(self):
# mp2 training
mp_engine = self.get_engine()
mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
mp_losses = np.array(mp_losses["loss"])
# mp2 recompute training
rc_engine = self.get_engine(True)
rc_losses = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size)
rc_losses = np.array(rc_losses["loss"])
self.check_results(mp_losses, rc_losses)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.auto_parallel as auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
def apply_pass(use_sharding=False, stage=None):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_sharding:
sharding = strategy.sharding
sharding.enable = True
sharding.degree = 2
sharding.stage = stage
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestShardingPass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_sharding=False, stage=None):
reset_prog()
strategy = apply_pass(use_sharding, stage)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("dp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses))
def test_sharding_pass(self):
# dp2 training
dp_engine = self.get_engine()
dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size)
dp_losses = np.array(dp_losses["loss"])
# sharding2 stage1 training
sharding1_engine = self.get_engine(True, 1)
sharding1_losses = sharding1_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding1_losses = np.array(sharding1_losses["loss"])
self.check_results(dp_losses, sharding1_losses)
# sharding2 stage2 training
sharding2_engine = self.get_engine(True, 2)
sharding2_losses = sharding2_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding2_losses = np.array(sharding2_losses["loss"])
self.check_results(dp_losses, sharding2_losses)
# sharding2 stage3 training
sharding3_engine = self.get_engine(True, 3)
sharding3_losses = sharding3_engine.fit(self.dataset,
3,
batch_size=self.batch_size)
sharding3_losses = np.array(sharding3_losses["loss"])
self.check_results(dp_losses, sharding3_losses)
if __name__ == "__main__":
unittest.main()
......@@ -45,9 +45,10 @@ from test_cluster import cluster_json
paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]])
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]])
_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]],
dim_names=["x", "y", "z"])
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])
class MLPLayer(nn.Layer):
......@@ -74,16 +75,8 @@ class MLPLayer(nn.Layer):
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input):
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "y"])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None])
out = self.norm(input)
out = self.linear0(out)
......@@ -111,16 +104,8 @@ def mlp_forward(train_program, start_program):
embedding = paddle.nn.Embedding(10, hidden_size, sparse=True)
embedding_out = embedding(fill_constant_out)
auto.shard_tensor(input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]
})
auto.shard_tensor(input, PP_MESH_0, ["x", None])
auto.shard_tensor(label, PP_MESH_1, ["x", None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......
......@@ -34,7 +34,10 @@ paddle.enable_static()
batch_size = 4
hidden_size = 1024
sequence_len = 512
_g_process_mesh = [[0, 1], [2, 3]]
_g_process_mesh = [
auto.ProcessMesh([0, 1], dim_names=["x"]),
auto.ProcessMesh([2, 3], dim_names=["x"])
]
def get_random_inputs_and_labels(input_shape, label_shape):
......@@ -82,18 +85,10 @@ class MLPLayer(nn.Layer):
def forward(self, input):
out = self.norm(input)
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear0.weight, _g_process_mesh[0], [None, "x"])
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _g_process_mesh[1],
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear1.weight, _g_process_mesh[1], ["x", None])
out = self.linear1(out)
return out
......@@ -123,16 +118,8 @@ def get_program():
dataloader.set_batch_generator(batch_generator_creator(),
places=paddle.static.cuda_places())
# data dist_attr
auto.shard_tensor(input,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(input, _g_process_mesh[0], ["x", None, None])
auto.shard_tensor(label, _g_process_mesh[0], ["x", None, None])
mlp_start = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......
......@@ -42,19 +42,13 @@ def make_program_lookup_table_v1_mp_dp():
is_sparse=False)
loss = paddle.fluid.layers.reduce_mean(emb_out)
auto.shard_tensor(src_ids,
dist_attr={
"process_mesh": auto.ProcessMesh([[0, 1], [2,
3]]),
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(
src_ids, auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]),
["x", None, None])
emb_weight = block.vars["emb_weight"]
auto.shard_tensor(emb_weight,
dist_attr={
"process_mesh": auto.ProcessMesh([[0, 1], [2,
3]]),
"dims_mapping": [1, -1]
})
auto.shard_tensor(
emb_weight, auto.ProcessMesh([[0, 1], [2, 3]],
dim_names=["x", "y"]), ["y", None])
return main_program, start_program, loss
......
......@@ -22,82 +22,58 @@ from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
mesh = [[0, 1], [2, 3]]
mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
def init_x_row(trans_x):
if trans_x:
x = paddle.static.data(name='x', shape=[10, 6, 8], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, 1, -1]
})
auto.shard_tensor(x, mesh, ["x", "y", None])
return x
else:
x = paddle.static.data(name='x', shape=[10, 8, 6], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, -1, 1]
})
auto.shard_tensor(x, mesh, ["x", None, "y"])
return x
def init_x_col(trans_x):
if trans_x:
x = paddle.static.data(name='x', shape=[6, 8], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(x, mesh, [None, "x"])
return x
else:
x = paddle.static.data(name='x', shape=[8, 6], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(x, mesh, ["x", None])
return x
def init_y_row(trans_y):
if trans_y:
y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(y, mesh, [None, "y"])
return y
else:
y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [1, -1]
})
auto.shard_tensor(y, mesh, ["y", None])
return y
def init_y_col(trans_y):
if trans_y:
y = paddle.static.data(name='y', shape=[4, 6], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [1, -1]
})
auto.shard_tensor(y, mesh, ["y", None])
return y
else:
y = paddle.static.data(name='y', shape=[6, 4], dtype='float32')
auto.shard_tensor(y,
dist_attr={
"process_mesh": mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(y, mesh, [None, "y"])
return y
......
......@@ -71,11 +71,8 @@ class TestDistOpCost(unittest.TestCase):
shape=[4, 1],
dtype='float32')
label.stop_gradient = True
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[2, 8], value=1, dtype='float32')
weight_attr = paddle.ParamAttr()
......@@ -121,17 +118,12 @@ class TestDistOpCost(unittest.TestCase):
shape=[8, 1],
dtype='float32')
label.stop_gradient = True
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x"])
auto.shard_tensor(label,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
# embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32')
......@@ -141,12 +133,9 @@ class TestDistOpCost(unittest.TestCase):
for op in main_program.global_block().ops:
if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W,
dist_attr={
"process_mesh":
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.shard_tensor(
W, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0]
......@@ -154,26 +143,20 @@ class TestDistOpCost(unittest.TestCase):
param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, 0]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
[None, "x"])
out1 = paddle.fluid.layers.matmul(out,
param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
[None, None])
tmp_out = paddle.fluid.layers.matmul(out1, tmp_param)
out2 = paddle.fluid.layers.matmul(tmp_out,
param2) # [8, 4] [-1, 0]
......@@ -227,17 +210,12 @@ class TestDistOpCost(unittest.TestCase):
shape=[8, 1],
dtype='float32')
label.stop_gradient = True
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x"])
auto.shard_tensor(label,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
# embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32')
......@@ -247,12 +225,9 @@ class TestDistOpCost(unittest.TestCase):
for op in main_program.global_block().ops:
if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W,
dist_attr={
"process_mesh":
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.shard_tensor(
W, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0]
......@@ -260,25 +235,20 @@ class TestDistOpCost(unittest.TestCase):
param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, 0]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
[None, "x"])
out1 = paddle.matmul(out, param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
[None, None])
tmp_out = paddle.matmul(out1, tmp_param)
out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0]
......@@ -331,17 +301,11 @@ class TestDistOpCost(unittest.TestCase):
shape=[8, 1],
dtype='float32')
label.stop_gradient = True
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0]
})
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x"])
auto.shard_tensor(label,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
# embedding
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=x, shape=[4], value=1, dtype='int32')
......@@ -351,12 +315,9 @@ class TestDistOpCost(unittest.TestCase):
for op in main_program.global_block().ops:
if op.type == "lookup_table_v2":
W = main_program.global_block().vars[op.input("W")[0]]
auto.shard_tensor(W,
dist_attr={
"process_mesh":
auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.shard_tensor(
W, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
out = paddle.fluid.layers.transpose(out,
[1, 0]) # [8, 2] [-1, 0]
......@@ -364,25 +325,21 @@ class TestDistOpCost(unittest.TestCase):
param1 = paddle.fluid.layers.create_parameter(
[4, 8], paddle.float32) # [2, 8] [0, -1]
auto.shard_tensor(param1,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None])
param2 = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 4] [-1, 0]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, 0]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
[None, "x"])
out1 = paddle.fluid.layers.mul(out, param1) # [8, 8] [-1, -1]
tmp_param = paddle.fluid.layers.create_parameter(
[8, 8], paddle.float32) # [8, 8] [-1, -1]
auto.shard_tensor(param2,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [-1, -1]
})
auto.ProcessMesh([0, 1], dim_names=["x"]),
[None, None])
tmp_out = paddle.fluid.layers.mul(out1, tmp_param)
out2 = paddle.fluid.layers.mul(tmp_out,
param2) # [8, 4] [-1, 0]
......
......@@ -29,11 +29,8 @@ def make_program_dp2():
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None, None])
tmp_0 = paddle.norm(x, p=2)
return main_program, start_program, tmp_0
......@@ -44,11 +41,8 @@ def make_program_serial():
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
x.stop_gradient = False
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0]),
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(x, auto.ProcessMesh([0], dim_names=["x"]),
[None, None, None])
tmp_0 = paddle.norm(x, p=2)
return main_program, start_program, tmp_0
......
......@@ -29,11 +29,9 @@ def make_program_dp2():
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 4, 8], dtype='float32')
x.stop_gradient = False
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None, None])
tmp_0 = paddle.reshape(x, shape=[0, 0, 4, 2])
tmp_1 = paddle.reshape(tmp_0, shape=[0, 0, 8])
tmp_2 = tmp_1.reshape((tmp_1.shape[0], tmp_1.shape[1], -1))
......
......@@ -25,11 +25,9 @@ def make_program_dp2():
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0, 1]),
"dims_mapping": [0, -1, -1]
})
auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]),
["x", None, None])
tmp_0 = x[0]
tmp_1 = x[:, 0, :]
tmp_2 = x[:, :, 1]
......@@ -42,11 +40,9 @@ def make_program_serial():
start_program = paddle.fluid.Program()
with paddle.static.program_guard(main_program, start_program):
x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32')
auto.shard_tensor(x,
dist_attr={
"process_mesh": auto.ProcessMesh([0]),
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(x, auto.ProcessMesh([0], dim_names=["x"]),
[None, None, None])
tmp_0 = x[0]
tmp_1 = x[:, 0, :]
tmp_2 = x[:, :, 1]
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.fluid as fluid
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.static as static
import paddle.distributed as dist
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
process_mesh1 = ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]],
dim_names=["x", "y"])
process_mesh2 = ProcessMesh(mesh=[0, 1, 2, 3], dim_names=["x"])
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
param_initializer = nn.initializer.Normal(mean=0.0,
std=initializer_range)
self.linear0 = nn.Linear(
d_model,
dim_feedforward,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
self.linear1 = nn.Linear(
dim_feedforward,
d_model,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
def forward(self, input):
auto.shard_tensor(self.linear0.weight, process_mesh1[0], [None, "y"])
linear0 = auto.shard_op(self.linear0, process_mesh1,
[["y", None, None]], [[None, "x", None]])
linear0_out = linear0(input)
gelu = auto.shard_op(F.gelu, process_mesh1, [["y", "x", None], None])
gelu_out = gelu(linear0_out, approximate=True)
auto.shard_tensor(self.linear1.weight, shard_spec=["y", None])
linear1 = auto.shard_op(self.linear1,
process_mesh1[1],
out_shard_specs=[["y", None, None]])
linear1_out = linear1(gelu_out)
return self.linear0, self.linear1, linear0_out, gelu_out, linear1_out
class TestAutoParallelAPI(unittest.TestCase):
def test_api(self):
# input
input = static.data(name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
label = static.data(name="label",
shape=[batch_size, sequence_len, 1],
dtype='float32')
auto.shard_tensor(input, process_mesh1, ["x", None, None])
auto.shard_tensor(label, process_mesh1, ["y", None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
with ProcessMesh(process_mesh1.mesh, process_mesh1.dim_names):
linear0, linear1, linear0_out, gelu_out, linear1_out = mlp(input)
default_program = paddle.fluid.default_main_program()
default_dist_context = get_default_distributed_context()
self.assertEqual(len(default_program.blocks[0].ops), 5)
matmul0 = default_program.blocks[0].ops[0]
self.assertEqual(matmul0.type, "matmul_v2")
ewise_add0 = default_program.blocks[0].ops[1]
self.assertEqual(ewise_add0.type, "elementwise_add")
gelu = default_program.blocks[0].ops[2]
self.assertEqual(gelu.type, "gelu")
matmul1 = default_program.blocks[0].ops[3]
self.assertEqual(matmul1.type, "matmul_v2")
ewise_add1 = default_program.blocks[0].ops[4]
self.assertEqual(ewise_add1.type, "elementwise_add")
dist_input = default_dist_context.get_dist_tensor_for_program(input)
self.assertEqual(dist_input.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1])
self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))
dist_input = default_dist_context.get_dist_tensor_for_program(label)
self.assertEqual(dist_input.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_input.dist_attr.dims_mapping, [1, -1, -1])
self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh"))
self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping"))
dist_linear0_weight = default_dist_context.get_dist_tensor_for_program(
linear0.weight)
self.assertEqual(dist_linear0_weight.dist_attr.process_mesh,
process_mesh1[0])
self.assertEqual(dist_linear0_weight.dist_attr.dims_mapping, [-1, 0])
self.assertTrue(
dist_linear0_weight.dist_attr.is_annotated("process_mesh"))
self.assertTrue(
dist_linear0_weight.dist_attr.is_annotated("dims_mapping"))
dist_linear1_weight = default_dist_context.get_dist_tensor_for_program(
linear1.weight)
self.assertEqual(dist_linear1_weight.dist_attr.process_mesh,
process_mesh1)
self.assertEqual(dist_linear1_weight.dist_attr.dims_mapping, [1, -1])
self.assertTrue(
dist_linear1_weight.dist_attr.is_annotated("process_mesh"))
self.assertTrue(
dist_linear1_weight.dist_attr.is_annotated("dims_mapping"))
dist_linear1_out = default_dist_context.get_dist_tensor_for_program(
linear1_out)
self.assertEqual(dist_linear1_out.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_linear1_out.dist_attr.dims_mapping, [-1, -1, -1])
self.assertTrue(dist_linear1_out.dist_attr.is_annotated("process_mesh"))
self.assertFalse(
dist_linear1_out.dist_attr.is_annotated("dims_mapping"))
dist_op = default_dist_context.get_dist_op_for_program(matmul0)
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(input.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1)
self.assertEqual(tensor_dist_attr.dims_mapping, [1, -1, -1])
self.assertTrue(tensor_dist_attr.is_annotated("process_mesh"))
self.assertTrue(tensor_dist_attr.is_annotated("dims_mapping"))
dist_op = default_dist_context.get_dist_op_for_program(ewise_add0)
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
linear0_out.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1)
self.assertEqual(tensor_dist_attr.dims_mapping, [-1, 0, -1])
self.assertTrue(tensor_dist_attr.is_annotated("process_mesh"))
self.assertTrue(tensor_dist_attr.is_annotated("dims_mapping"))
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
dist_op = default_dist_context.get_dist_op_for_program(gelu)
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(
linear0_out.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1)
self.assertEqual(tensor_dist_attr.dims_mapping, [1, 0, -1])
self.assertTrue(tensor_dist_attr.is_annotated("process_mesh"))
self.assertTrue(tensor_dist_attr.is_annotated("dims_mapping"))
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(gelu_out.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1)
self.assertEqual(tensor_dist_attr.dims_mapping, [-1, -1, -1])
self.assertTrue(tensor_dist_attr.is_annotated("process_mesh"))
self.assertFalse(tensor_dist_attr.is_annotated("dims_mapping"))
dist_op = default_dist_context.get_dist_op_for_program(matmul1)
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1[1])
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(gelu_out.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1[1])
self.assertEqual(tensor_dist_attr.dims_mapping, [-1, -1, -1])
self.assertTrue(tensor_dist_attr.is_annotated("process_mesh"))
self.assertFalse(tensor_dist_attr.is_annotated("dims_mapping"))
dist_op = default_dist_context.get_dist_op_for_program(ewise_add1)
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1[1])
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
linear1_out.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1[1])
self.assertEqual(tensor_dist_attr.dims_mapping, [0, -1, -1])
self.assertTrue(tensor_dist_attr.is_annotated("process_mesh"))
self.assertTrue(tensor_dist_attr.is_annotated("dims_mapping"))
if __name__ == '__main__':
unittest.main()
......@@ -26,7 +26,6 @@ import paddle.distributed.fleet as fleet
from paddle.io import Dataset
from paddle.static import InputSpec
from paddle.fluid.framework import _non_static_mode
from paddle.distributed.auto_parallel.engine import Engine
from test_to_static import MLPLayer, MyDataset
......@@ -60,14 +59,12 @@ class TestEngineBase(unittest.TestCase):
self.dataset = MyDataset(self.batch_num * self.batch_size)
def init_engine(self):
inputs = InputSpec([self.batch_size, self.hidden_size], 'float32', 'x')
labels = InputSpec([self.batch_size], 'int64', 'label')
# inputs = InputSpec([self.batch_size, self.hidden_size], 'float32', 'x')
# labels = InputSpec([self.batch_size], 'int64', 'label')
self.engine = Engine(model=self.mlp,
inputs_spec=inputs,
labels_spec=labels)
self.engine.prepare(optimizer=self.optimizer,
self.engine = auto.Engine(model=self.mlp,
loss=self.loss,
optimizer=self.optimizer,
metrics=paddle.metric.Accuracy())
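# Note: auto.Engine takes the model, loss, optimizer and metrics at
# construction time; there is no separate prepare() call, and
# fit()/evaluate()/predict() are invoked directly with a dataset.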
......@@ -80,9 +77,9 @@ class TestLRScheduler(TestEngineBase):
def test_lr_scheduler(self):
self.init_engine()
lr = self.engine._optimizer._learning_rate
assert isinstance(lr, paddle.optimizer.lr.LRScheduler)
self.engine.fit(self.dataset, batch_size=self.batch_size)
lr = self.engine._lr_optimizer._learning_rate
assert isinstance(lr, paddle.optimizer.lr.LRScheduler)
class TestGradClipByGlobalNorm(TestEngineBase):
......@@ -94,7 +91,6 @@ class TestGradClipByGlobalNorm(TestEngineBase):
def test_grad_clip(self):
clip = self.engine._optimizer._grad_clip
self.engine.fit(self.dataset, batch_size=self.batch_size)
self.check_program()
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestAMPPass(unittest.TestCase):
def test_mp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "amp_pass_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
tmp_dir.name, launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
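# This test and the similar pass tests below follow the same pattern: run the
# companion *_pass_unittest.py script on two devices via
# paddle.distributed.launch (optionally under coverage) and assert that the
# subprocess exits with code 0.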
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestGradientMergePass(unittest.TestCase):
def test_dp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir,
"gradient_merge_pass_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
tmp_dir.name, launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
......@@ -14,123 +14,41 @@
import unittest
import sys
import random
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
from get_gpt_model import generate_model, create_data_holder, FakeDataset
paddle.enable_static()
class FakeDataset:
def __init__(self, num_samples, sequence_len, vocab_size):
self.num_samples = num_samples
self.sequence_len = sequence_len
self.vocab_size = vocab_size
def __getitem__(self, idx):
tokens = np.random.randint(self.vocab_size, size=self.sequence_len)
position_ids = np.arange(self.sequence_len)
attention_mask = np.tril(np.ones(self.sequence_len)).reshape(
(1, self.sequence_len, self.sequence_len)).astype(np.float32)
labels = np.random.randint(self.vocab_size, size=self.sequence_len)
loss_mask = np.ones(self.sequence_len).astype(np.float32)
return tokens, position_ids, attention_mask, labels, loss_mask
def __len__(self):
return self.num_samples
def apply_pass():
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
dist_strategy.qat = True
dist_strategy.qat_configs = {
'channel_wise_abs_max': True,
'weight_bits': 8,
'activation_bits': 8,
'not_quant_pattern': ['skip_quant'],
}
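# With auto.Strategy, the QAT pass is configured through the strategy.qat
# attribute group: its fields (enable, channel_wise_abs_max, weight_bits, ...)
# are set directly on the object instead of through a config dict.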
dist_strategy = auto.Strategy()
dist_strategy.auto_mode = "semi"
qat = dist_strategy.qat
qat.enable = True
qat.channel_wise_abs_max = True
qat.weight_bits = 8
qat.activation_bits = 8
qat.not_quant_pattern = ['skip_quant']
return dist_strategy
def create_data_holder(batch_size, sequence_len):
tokens = paddle.static.InputSpec(name="tokens",
shape=[batch_size, sequence_len],
dtype='int64')
position_ids = paddle.static.InputSpec(name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
attention_mask = paddle.static.InputSpec(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32')
labels = paddle.static.InputSpec(name="labels",
shape=[batch_size, sequence_len],
dtype='int64')
loss_mask = paddle.static.InputSpec(name="loss_mask",
shape=[batch_size, sequence_len],
dtype='float32')
return [tokens, position_ids, attention_mask], [labels, loss_mask]
def get_gpt_model():
modeling.init_global()
modeling._global_parallel_strategy = "serial"
modeling._global_process_mesh = auto.ProcessMesh(mesh=[0])
gpt = GPTModel(vocab_size=1000,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=256,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3)
model = GPTForPretraining(gpt,
vocab_size=1000,
hidden_size=64,
initializer_range=0.02)
criterion = GPTPretrainingCriterion()
return model, criterion
class TestQuantizationPass(unittest.TestCase):
def test_qat_pass(self):
batch_size = 8
batch_num = 10
sequence_len = 512
vocab_size = 1000
strategy = apply_pass()
model, loss = get_gpt_model()
model, loss = generate_model("serial")
opt = paddle.optimizer.AdamW(learning_rate=0.00001)
inputs_spec, labels_spec = create_data_holder(batch_size=batch_size,
sequence_len=sequence_len)
engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
engine.prepare(optimizer=opt, loss=loss)
dataset = FakeDataset(batch_size * batch_num, sequence_len, vocab_size)
engine.fit(train_data=dataset, batch_size=batch_size)
engine = auto.Engine(model, loss, opt, strategy=strategy)
dataset = FakeDataset(batch_size * batch_num)
engine.fit(dataset, 3, batch_size=batch_size)
self.check_program(engine.main_program)
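# Rough flow of this QAT test: enable qat on an auto.Strategy, build
# auto.Engine(model, loss, opt, strategy=strategy), fit on a fake dataset and
# then verify engine.main_program (check_program is not shown in this hunk).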
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestRecomputePass(unittest.TestCase):
def test_mp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "recompute_pass_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
tmp_dir.name, launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class TestShardingPass(unittest.TestCase):
def test_dp2sharding2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "sharding_pass_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = [sys.executable, "-u"] + coverage_args + [
"-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir",
tmp_dir.name, launch_model_path
]
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.static as static
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
param_initializer = nn.initializer.Normal(mean=0.0,
std=initializer_range)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.linear0 = nn.Linear(
d_model,
dim_feedforward,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
self.linear1 = nn.Linear(
dim_feedforward,
d_model,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
def forward(self, input):
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
return out
class TestProcessMesh(unittest.TestCase):
def test_construction(self):
mesh = [[0, 1, 2], [3, 4, 5]]
process_mesh = ProcessMesh(mesh, dim_names=["x", "y"])
self.assertEqual(process_mesh.shape, [2, 3])
self.assertEqual(process_mesh.process_ids, [0, 1, 2, 3, 4, 5])
self.assertEqual(process_mesh.dim_names, ["x", "y"])
self.assertEqual(process_mesh.ndim, 2)
self.assertEqual(process_mesh, process_mesh)
self.assertEqual(str(process_mesh), str(process_mesh))
sub_process_mesh1 = process_mesh[0]
self.assertEqual(sub_process_mesh1.shape, [3])
self.assertEqual(sub_process_mesh1.process_ids, [0, 1, 2])
self.assertEqual(sub_process_mesh1.dim_names, ["y"])
self.assertEqual(sub_process_mesh1.ndim, 1)
sub_process_mesh2 = process_mesh[:, 1]
self.assertEqual(sub_process_mesh2.shape, [2])
self.assertEqual(sub_process_mesh2.process_ids, [1, 4])
self.assertEqual(sub_process_mesh2.dim_names, ["x"])
self.assertEqual(sub_process_mesh2.ndim, 1)
sub_process_mesh3 = sub_process_mesh2[:]
self.assertEqual(sub_process_mesh3.shape, [2])
self.assertEqual(sub_process_mesh3.process_ids, [1, 4])
self.assertEqual(sub_process_mesh3.dim_names, ["x"])
self.assertEqual(sub_process_mesh3.ndim, 1)
sub_process_mesh4 = process_mesh[1, 1]
self.assertEqual(sub_process_mesh4.shape, [1])
self.assertEqual(sub_process_mesh4.process_ids, [4])
self.assertEqual(sub_process_mesh4.dim_names, ["d0"])
self.assertEqual(sub_process_mesh4.ndim, 1)
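# Indexing behaviour checked above: integer indexing drops that mesh dim (and
# its name), slicing keeps the remaining dim names, and reducing the mesh to a
# single process still yields a 1-D mesh with a default dim name ("d0").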
def test_context_manager(self):
mesh = np.array([1, 2, 3, 4])
input = static.data(name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
label = static.data(name="label",
shape=[batch_size, sequence_len, 1],
dtype='float32')
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
with ProcessMesh(mesh, "d"):
out = mlp(input)
default_program = paddle.fluid.default_main_program()
default_dist_context = get_default_distributed_context()
for block in default_program.blocks:
for tensor in block.vars.values():
dist_tensor = default_dist_context.get_dist_tensor_for_program(
tensor)
if dist_tensor is not None:
self.assertEqual(dist_tensor.dist_attr.process_mesh,
ProcessMesh(mesh))
for op in block.ops:
dist_op = default_dist_context.get_dist_op_for_program(op)
if dist_op is not None:
self.assertEqual(dist_op.dist_attr.process_mesh,
ProcessMesh(mesh))
if __name__ == "__main__":
unittest.main()
......@@ -13,7 +13,8 @@
# limitations under the License
import unittest
from paddle.distributed.auto_parallel.process_mesh_v2 import ProcessMesh
from paddle.distributed.auto_parallel.process_mesh_v2 import (
ProcessMesh, compute_compatible_process_mesh, merge_process_mesh)
class TestProcessMesh(unittest.TestCase):
......@@ -39,6 +40,54 @@ class TestProcessMesh(unittest.TestCase):
self.assertNotEqual(process_mesh, process_mesh2)
self.assertEqual(str(process_mesh), str(process_mesh))
def test_compute_compatible_process_mesh(self):
process_mesh1 = ProcessMesh([[0, 1, 2], [3, 4, 5]],
dim_names=["x", "y"])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, None])
self.assertEqual(compatible_process_mesh, process_mesh1)
compatible_process_mesh = compute_compatible_process_mesh(
[None, process_mesh1])
self.assertEqual(compatible_process_mesh, process_mesh1)
process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, process_mesh2])
self.assertEqual(compatible_process_mesh, process_mesh1)
self.assertEqual(compatible_process_mesh, process_mesh2)
process_mesh2 = ProcessMesh([[0, 1, 2, 3, 4, 5]])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, process_mesh2])
self.assertEqual(compatible_process_mesh, process_mesh1)
process_mesh2 = ProcessMesh([[0, 1, 2]])
compatible_process_mesh = compute_compatible_process_mesh(
[process_mesh1, process_mesh2])
self.assertEqual(compatible_process_mesh, process_mesh1)
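# As exercised above, compute_compatible_process_mesh appears to skip None
# entries and return a mesh that covers all of the given meshes; in every case
# here that is process_mesh1.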
def test_merge_process_mesh(self):
process_mesh1 = ProcessMesh([[0, 1, 2], [3, 4, 5]],
dim_names=["x", "y"])
merged_process_mesh = merge_process_mesh([process_mesh1, None])
print(merged_process_mesh)
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
merged_process_mesh = merge_process_mesh([None, process_mesh1])
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]])
merged_process_mesh = merge_process_mesh([process_mesh1, process_mesh2])
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
process_mesh2 = ProcessMesh([[0, 1, 2]])
merged_process_mesh = merge_process_mesh([process_mesh1, process_mesh2])
self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5]))
process_mesh2 = ProcessMesh([[6, 7]])
merged_process_mesh = merge_process_mesh([process_mesh1, process_mesh2])
self.assertEqual(merged_process_mesh,
ProcessMesh([0, 1, 2, 3, 4, 5, 6, 7]))
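# merge_process_mesh, as checked above, appears to union the process ids of all
# non-None meshes into a single flat 1-D ProcessMesh.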
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
# import yaml
import unittest
import paddle.distributed.auto_parallel as auto
class TestStrategy(unittest.TestCase):
def test_default_config(self):
strategy = auto.Strategy()
recompute = strategy.recompute
self.assertEqual(recompute.enable, False)
self.assertEqual(recompute.checkpoints, None)
amp = strategy.amp
self.assertEqual(amp.enable, False)
self.assertAlmostEqual(amp.init_loss_scaling, 32768.0)
self.assertEqual(amp.incr_every_n_steps, 1000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 2)
self.assertAlmostEqual(amp.incr_ratio, 2.0)
self.assertAlmostEqual(amp.decr_ratio, 0.8)
self.assertEqual(amp.use_dynamic_loss_scaling, True)
self.assertEqual(amp.custom_black_list, [])
self.assertEqual(amp.custom_white_list, [])
self.assertEqual(amp.custom_black_varnames, [])
self.assertEqual(amp.use_pure_fp16, False)
self.assertEqual(amp.use_fp16_guard, True)
self.assertEqual(amp.use_optimizer_fp16, False)
sharding = strategy.sharding
self.assertEqual(sharding.enable, False)
self.assertEqual(sharding.stage, 1)
self.assertEqual(sharding.degree, 8)
self.assertAlmostEqual(sharding.segment_broadcast_MB, 32.0)
self.assertEqual(sharding.enable_tuning, False)
self.assertEqual(sharding.tuning_range, [])
gradient_merge = strategy.gradient_merge
self.assertEqual(gradient_merge.enable, False)
self.assertEqual(gradient_merge.k_steps, 1)
self.assertEqual(gradient_merge.avg, True)
qat = strategy.qat
self.assertEqual(qat.enable, False)
self.assertEqual(qat.channel_wise_abs_max, True)
self.assertEqual(qat.weight_bits, 8)
self.assertEqual(qat.activation_bits, 8)
self.assertEqual(qat.not_quant_pattern, ['skip_quant'])
self.assertEqual(qat.algo, None)
tuning = strategy.tuning
self.assertEqual(tuning.enable, False)
self.assertEqual(tuning.batch_size, 1)
self.assertEqual(tuning.dataset, None)
self.assertEqual(tuning.profile_start_step, 1)
self.assertEqual(tuning.profile_end_step, 1)
self.assertEqual(tuning.run_after_tuning, True)
self.assertEqual(tuning.verbose, True)
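# These are the out-of-the-box defaults of auto.Strategy: every config group
# (recompute, amp, sharding, gradient_merge, qat, tuning) starts disabled until
# its enable flag is set, as the next test does.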
def test_modify_config(self):
strategy = auto.Strategy()
recompute = strategy.recompute
recompute.enable = True
recompute.checkpoints = ["x"]
self.assertEqual(recompute.enable, True)
self.assertEqual(recompute.checkpoints, ["x"])
amp = strategy.amp
amp.enable = True
amp.init_loss_scaling = 16384.0
amp.incr_every_n_steps = 2000
amp.decr_every_n_nan_or_inf = 4
amp.incr_ratio = 4.0
amp.decr_ratio = 0.4
amp.use_dynamic_loss_scaling = False
amp.custom_white_list = ["x"]
amp.custom_black_list = ["y"]
amp.custom_black_varnames = ["z"]
amp.use_pure_fp16 = True
amp.use_fp16_guard = False
amp.use_optimizer_fp16 = True
self.assertEqual(amp.enable, True)
self.assertAlmostEqual(amp.init_loss_scaling, 16384.0)
self.assertEqual(amp.incr_every_n_steps, 2000)
self.assertEqual(amp.decr_every_n_nan_or_inf, 4)
self.assertAlmostEqual(amp.incr_ratio, 4.0)
self.assertAlmostEqual(amp.decr_ratio, 0.4)
self.assertEqual(amp.use_dynamic_loss_scaling, False)
self.assertEqual(amp.custom_white_list, ["x"])
self.assertEqual(amp.custom_black_list, ["y"])
self.assertEqual(amp.custom_black_varnames, ["z"])
self.assertEqual(amp.use_pure_fp16, True)
self.assertEqual(amp.use_fp16_guard, False)
self.assertEqual(amp.use_optimizer_fp16, True)
sharding = strategy.sharding
sharding.enable = True
sharding.stage = 2
sharding.degree = 2
sharding.segment_broadcast_MB = 64.0
sharding.enable_tuning = True
sharding.tuning_range = [1, 2, 3]
self.assertEqual(sharding.enable, True)
self.assertEqual(sharding.stage, 2)
self.assertEqual(sharding.degree, 2)
self.assertAlmostEqual(sharding.segment_broadcast_MB, 64.0)
self.assertEqual(sharding.enable_tuning, True)
self.assertEqual(sharding.tuning_range, [1, 2, 3])
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 4
gradient_merge.avg = False
self.assertEqual(gradient_merge.enable, True)
self.assertEqual(gradient_merge.k_steps, 4)
self.assertEqual(gradient_merge.avg, False)
# def test_file_config(self):
# yaml_data = """
# all_ranks: false
# amp:
# custom_black_list:
# - y
# custom_black_varnames:
# - z
# custom_white_list:
# - x
# decr_every_n_nan_or_inf: 4
# decr_ratio: 0.4
# enable: false
# incr_every_n_steps: 2000
# incr_ratio: 4.0
# init_loss_scaling: 16384.0
# use_dynamic_loss_scaling: false
# use_fp16_guard: false
# use_optimizer_fp16: true
# use_pure_fp16: true
# auto_mode: semi
# gradient_merge:
# avg: false
# enable: false
# k_steps: 4
# gradient_scale: true
# qat:
# activation_bits: 8
# algo: null
# channel_wise_abs_max: true
# enable: false
# not_quant_pattern:
# - skip_quant
# weight_bits: 8
# recompute:
# checkpoints: null
# enable: false
# enable_tuning: false
# return_numpy: true
# seed: null
# sharding:
# enable: false
# enable_tuning: true
# segment_broadcast_MB: 64.0
# degree: 8
# stage: 2
# tuning_range: None
# split_data: false
# tuning:
# batch_size: 1
# dataset: null
# enable: false
# profile_end_step: 1
# profile_start_step: 1
# run_after_tuning: true
# verbose: true
# use_cache: true
# """
# yaml_path = "./strategy.yml"
# yaml_dict = yaml.load(yaml_data, Loader=yaml.Loader)
# with open(yaml_path, 'w') as outfile:
# yaml.dump(yaml_dict, outfile, default_flow_style=False)
# strategy = auto.Strategy(yaml_path)
# self.assertEqual(yaml_dict, strategy.to_dict())
# # Remove the created file
# if os.path.exists(yaml_path):
# os.remove(yaml_path)
if __name__ == '__main__':
unittest.main()
......@@ -27,7 +27,6 @@ from paddle import LazyGuard
from paddle.io import Dataset
from paddle.static import InputSpec
from paddle.fluid.framework import _non_static_mode
from paddle.distributed.auto_parallel.engine import Engine
from paddle.distributed.auto_parallel.helper import ProgramHelper
batch_size = 4
......@@ -140,23 +139,19 @@ class TestToStatic(unittest.TestCase):
dataset = MyDataset(batch_num * batch_size)
inputs = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels = InputSpec([batch_size], 'int64', 'label')
# inputs = InputSpec([batch_size, hidden_size], 'float32', 'x')
# labels = InputSpec([batch_size], 'int64', 'label')
engine = Engine(model=mlp,
inputs_spec=inputs,
labels_spec=labels,
strategy=None)
assert _non_static_mode() == True
engine.prepare(optimizer=optimizer,
engine = auto.Engine(model=mlp,
loss=loss,
metrics=paddle.metric.Accuracy())
assert _non_static_mode() == False
optimizer=optimizer,
metrics=paddle.metric.Accuracy(),
strategy=None)
engine.fit(dataset, batch_size=batch_size)
engine.evaluate(dataset, batch_size=batch_size)
engine.predict(dataset, batch_size=batch_size)
assert _non_static_mode() == False
class TestLazyInit(unittest.TestCase):
......
......@@ -36,7 +36,7 @@ batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
_g_process_mesh = [[0, 1], [2, 3]]
_g_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y'])
def get_random_inputs_and_labels(input_shape, label_shape):
......@@ -84,18 +84,12 @@ class MLPLayer(nn.Layer):
def forward(self, input):
out = self.norm(input)
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear0.weight, _g_process_mesh[:, 0],
[None, 'x'])
out = self.linear0(out)
out = F.gelu(out, approximate=True)
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _g_process_mesh[1],
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear1.weight, _g_process_mesh[:, 1],
['x', None])
out = self.linear1(out)
return out
......@@ -155,16 +149,8 @@ def get_program():
dataloader.set_batch_generator(batch_generator_creator(),
places=paddle.static.cuda_places())
# data dist_attr
auto.shard_tensor(input,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": _g_process_mesh[0],
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(input, _g_process_mesh[:, 0], [None, None, None])
auto.shard_tensor(label, _g_process_mesh[:, 0], [None, None, None])
mlp_start = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......
......@@ -37,7 +37,7 @@ batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
_g_process_mesh = auto.ProcessMesh([0, 1])
_g_process_mesh = auto.ProcessMesh([0, 1], dim_names=['x'])
def get_random_inputs_and_labels(input_shape, label_shape):
......@@ -85,61 +85,21 @@ class MLPLayer(nn.Layer):
def forward(self, input):
auto.shard_tensor(self.norm.weight,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1]
})
auto.shard_tensor(self.norm.bias,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1]
})
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear0.bias,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [0]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear1.bias,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1]
})
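# shard_tensor(tensor, process_mesh, shard_spec): each shard_spec entry names
# the mesh dim the corresponding tensor dim is split along, and None marks a
# replicated (unsharded) dim, as in the annotations below.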
auto.shard_tensor(self.norm.weight, _g_process_mesh, [None])
auto.shard_tensor(self.norm.bias, _g_process_mesh, [None])
auto.shard_tensor(self.linear0.weight, _g_process_mesh, [None, 'x'])
auto.shard_tensor(self.linear0.bias, _g_process_mesh, ['x'])
auto.shard_tensor(self.linear1.weight, _g_process_mesh, ['x', None])
auto.shard_tensor(self.linear1.bias, _g_process_mesh, [None])
out = self.norm(input)
auto.shard_tensor(out,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(out, _g_process_mesh, [None, None, None])
out = self.linear0(out)
auto.shard_tensor(out,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, 0]
})
auto.shard_tensor(out, _g_process_mesh, [None, None, 'x'])
out = F.gelu(out, approximate=True)
auto.shard_tensor(out,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, 0]
})
auto.shard_tensor(out, _g_process_mesh, [None, None, 'x'])
out = self.linear1(out)
auto.shard_tensor(out,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(out, _g_process_mesh, [None, None, None])
return out
......@@ -155,21 +115,13 @@ def get_program():
# loop counter
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
auto.shard_tensor(i,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1]
})
auto.shard_tensor(i, _g_process_mesh, [None])
# number of loop iterations
loop_len = fluid.layers.fill_constant(shape=[1],
dtype='int64',
value=epoch_num)
auto.shard_tensor(loop_len,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1]
})
auto.shard_tensor(loop_len, _g_process_mesh, [None])
# input
input = static.data(name="input",
......@@ -188,25 +140,13 @@ def get_program():
dataloader.set_batch_generator(batch_generator_creator(),
places=paddle.static.cuda_places())
# data dist_attr
auto.shard_tensor(input,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(input, _g_process_mesh, [None, None, None])
auto.shard_tensor(label, _g_process_mesh, [None, None, None])
# fill constant bsz like
tmp = paddle.fluid.layers.fill_constant_batch_size_like(
input=input, shape=[-1, 16, 0, 48], dtype='float32', value=0)
auto.shard_tensor(tmp,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, 0, -1, -1]
})
auto.shard_tensor(tmp, _g_process_mesh, [None, 'x', None, None])
# model
mlp_start = MLPLayer(hidden_size=hidden_size,
......@@ -216,28 +156,21 @@ def get_program():
pred = mlp_start(input)
input_array = fluid.layers.array_write(pred, i)
auto.shard_tensor(input_array,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
# TODO: check whether this annotation is needed
# auto.shard_tensor(input_array,
# dist_attr={
# "process_mesh": _g_process_mesh,
# "dims_mapping": [-1, -1, -1]
# })
cond = fluid.layers.less_than(x=i, y=loop_len)
auto.shard_tensor(cond,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1]
})
auto.shard_tensor(cond, _g_process_mesh, [None])
while_op = fluid.layers.While(cond=cond)
with while_op.block():
pre_input = fluid.layers.array_read(array=input_array, i=i)
auto.shard_tensor(pre_input,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(pre_input, _g_process_mesh, [None, None, None])
mlp_while = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......@@ -251,11 +184,7 @@ def get_program():
fluid.layers.less_than(x=i, y=loop_len, cond=cond)
end_pred = fluid.layers.array_read(array=input_array, i=i)
auto.shard_tensor(end_pred,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(end_pred, _g_process_mesh, [None, None, None])
mlp_end = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......@@ -264,18 +193,10 @@ def get_program():
pred = mlp_end(end_pred)
error_cost = paddle.nn.functional.square_error_cost(pred, label)
auto.shard_tensor(error_cost,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(error_cost, _g_process_mesh, [None, None, None])
loss = paddle.mean(error_cost)
auto.shard_tensor(loss,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1]
})
auto.shard_tensor(loss, _g_process_mesh, [None])
return train_program, start_program, dataloader, i, loss
......
......@@ -67,38 +67,18 @@ class MLPLayer(nn.Layer):
def forward(self, input):
if _global_parallel_strategy == "pp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None])
elif _global_parallel_strategy == "mp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear0.weight, _global_process_mesh,
[None, "x"])
auto.shard_tensor(self.linear1.weight, _global_process_mesh,
["x", None])
elif _global_parallel_strategy == "dp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear0.weight, _global_process_mesh,
[None, None])
auto.shard_tensor(self.linear1.weight, _global_process_mesh,
[None, None])
out = self.norm(input)
out = self.linear0(out)
......@@ -120,28 +100,12 @@ def mlp_forward(train_program, start_program):
dtype='float32')
if _global_parallel_strategy == "pp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(input, PP_MESH_0, [None, None])
auto.shard_tensor(label, PP_MESH_1, [None, None])
elif _global_parallel_strategy == "dp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(input, _global_process_mesh, ["x", None])
elif _global_parallel_strategy == "mp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(input, _global_process_mesh, [None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......@@ -186,7 +150,7 @@ class TestMLPAutoConvert(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh([0, 1])
_global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
input = np.random.random(size=(80, 64)).astype('float32')
label = np.random.random(size=(80, 1)).astype('float32')
......@@ -212,11 +176,11 @@ class TestMLPAutoConvert(unittest.TestCase):
set_default_distributed_context(None)
_global_parallel_strategy = "pp"
_global_process_mesh = auto.ProcessMesh([0, 1])
_global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
global PP_MESH_0
PP_MESH_0 = auto.ProcessMesh(mesh=[0])
PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["pp0"])
global PP_MESH_1
PP_MESH_1 = auto.ProcessMesh(mesh=[1])
PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["pp1"])
dist_main_prog_load, dist_start_prog_load, loss_load = get_distributed_program(
)
......@@ -268,7 +232,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "pp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh([0, 1])
_global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
global PP_MESH_0
PP_MESH_0 = auto.ProcessMesh(mesh=[0])
global PP_MESH_1
......@@ -303,7 +267,7 @@ class TestMLPAutoConvert2(unittest.TestCase):
set_default_distributed_context(None)
_global_parallel_strategy = "mp"
_global_process_mesh = auto.ProcessMesh([0, 1])
_global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
dist_main_prog_load, dist_start_prog_load, loss_load = get_distributed_program(
)
......@@ -350,7 +314,7 @@ class TestMLPAutoConvertInvalid(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh([0, 1])
_global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
dist_main_prog, _, _ = get_distributed_program()
with self.assertRaises(TypeError):
save_distributed_checkpoint(dist_main_prog, [""], [""],
......
......@@ -38,7 +38,7 @@ class TestDataUnshard(unittest.TestCase):
def create_model(train_program, start_program):
with paddle.static.program_guard(train_program, start_program):
MESH_0 = auto.ProcessMesh([0, 1])
MESH_0 = auto.ProcessMesh([0, 1], dim_names=["x"])
input = paddle.static.data(name='input', shape=[2, 8])
label = paddle.static.data(name='label', shape=[2, 8])
......@@ -47,26 +47,10 @@ class TestDataUnshard(unittest.TestCase):
linear0 = nn.Linear(8, 8, weight_attr)
linear1 = nn.Linear(8, 8, weight_attr)
auto.shard_tensor(input,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(linear0.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(linear1.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(input, MESH_0, ["x", None])
auto.shard_tensor(label, MESH_0, ["x", None])
auto.shard_tensor(linear0.weight, MESH_0, [None, None])
auto.shard_tensor(linear1.weight, MESH_0, [None, None])
linear0_out = linear0(input)
gelu_out = F.gelu(linear0_out)
......@@ -124,7 +108,7 @@ class TestDataUnshard(unittest.TestCase):
def create_model(train_program, start_program):
with paddle.static.program_guard(train_program, start_program):
MESH_0 = auto.ProcessMesh([0, 1])
MESH_0 = auto.ProcessMesh([0, 1], dim_names=["x"])
input = paddle.static.data(name='input', shape=[8, 8])
label = paddle.static.data(name='label', shape=[8, 8])
......@@ -133,27 +117,10 @@ class TestDataUnshard(unittest.TestCase):
linear0 = nn.Linear(8, 8, weight_attr)
linear1 = nn.Linear(8, 8, weight_attr)
auto.shard_tensor(input,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(linear0.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(linear1.weight,
dist_attr={
"process_mesh": MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(input, MESH_0, [None, None])
auto.shard_tensor(label, MESH_0, [None, None])
auto.shard_tensor(linear0.weight, MESH_0, [None, "x"])
auto.shard_tensor(linear1.weight, MESH_0, ["x", None])
linear0_out = linear0(input)
gelu_out = F.gelu(linear0_out)
......
......@@ -114,30 +114,18 @@ class MultiHeadAttention(nn.Layer):
"""
q = self.q_proj(query)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.q_proj.weight, _global_process_mesh,
[None, "x"])
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.q_proj.weight, _global_process_mesh,
[None, "y"])
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.q_proj.weight, MPPP_MESH_LIST[self.mesh_idx],
[None, "x"])
elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"])
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
if isinstance(cache, self.StaticCache):
......@@ -165,56 +153,30 @@ class MultiHeadAttention(nn.Layer):
"""
k = self.k_proj(key)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.k_proj.weight, _global_process_mesh,
[None, "x"])
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.k_proj.weight, _global_process_mesh,
[None, "y"])
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.k_proj.weight, MPPP_MESH_LIST[self.mesh_idx],
[None, "x"])
elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"])
v = self.v_proj(value)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.v_proj.weight, _global_process_mesh,
[None, "x"])
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.v_proj.weight, _global_process_mesh,
[None, "y"])
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.v_proj.weight, MPPP_MESH_LIST[self.mesh_idx],
[None, "x"])
elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
......@@ -287,30 +249,18 @@ class MultiHeadAttention(nn.Layer):
# project to output
out = self.out_proj(out)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.out_proj.weight, _global_process_mesh,
["x", None])
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.out_proj.weight, _global_process_mesh,
["y", None])
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [0, -1]
})
MPPP_MESH_LIST[self.mesh_idx], ["x", None])
elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [1, -1]
})
DPMPPP_MESH_LIST[self.mesh_idx], ["y", None])
outs = [out]
if self.need_weights:
outs.append(weights)
......@@ -352,96 +302,53 @@ class TransformerDecoder(nn.Layer):
new_caches = []
self.checkpoints = []
if _global_parallel_strategy == "pp":
auto.shard_tensor(output,
dist_attr={
"process_mesh":
PP_MESH_LIST[0],
"dims_mapping":
[-1 for i in range(len(output.shape))]
})
auto.shard_tensor(output, PP_MESH_LIST[0],
[None for i in range(len(output.shape))])
if _global_parallel_strategy == "dp_pp":
auto.shard_tensor(output,
dist_attr={
"process_mesh":
DPPP_MESH_LIST[0],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
auto.shard_tensor(output, DPPP_MESH_LIST[0], ["x"] +
[None for i in range(len(output.shape) - 1)])
if _global_parallel_strategy == "mp_pp":
auto.shard_tensor(output,
dist_attr={
"process_mesh":
MPPP_MESH_LIST[0],
"dims_mapping": [-1] +
[-1 for i in range(len(output.shape) - 1)]
})
auto.shard_tensor(output, MPPP_MESH_LIST[0],
[None for i in range(len(output.shape))])
if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(output,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[0],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
auto.shard_tensor(output, DPMPPP_MESH_LIST[0], ["x"] +
[None for i in range(len(output.shape) - 1)])
for i, mod in enumerate(self.layers):
if cache is None:
if use_cache:
if _global_parallel_strategy == "pp":
output, new_cache = auto.shard_op(
mod,
dist_attr={
"process_mesh": PP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)
mod, PP_MESH_LIST[mod.mesh_idx])(output, memory,
tgt_mask,
use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh":
PP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[-1 for i in range(len(output.shape))]
})
output, PP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))])
elif _global_parallel_strategy == "dp_pp":
output, new_cache = auto.shard_op(
mod,
dist_attr={
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)
mod, DPPP_MESH_LIST[mod.mesh_idx])(output, memory,
tgt_mask,
use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh":
DPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
output, DPPP_MESH_LIST[mod.mesh_idx], ["x"] +
[None for i in range(len(output.shape) - 1)])
elif _global_parallel_strategy == "mp_pp":
output, new_cache = auto.shard_op(
mod,
dist_attr={
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)
mod, MPPP_MESH_LIST[mod.mesh_idx])(output, memory,
tgt_mask,
use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh":
MPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [-1] +
[-1 for i in range(len(output.shape) - 1)]
})
output, MPPP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))])
elif _global_parallel_strategy == "dp_mp_pp":
output, new_cache = auto.shard_op(
mod,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)
DPMPPP_MESH_LIST[mod.mesh_idx])(output, memory,
tgt_mask, use_cache,
cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
output, DPMPPP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))])
else:
output, new_cache = mod(output,
memory,
......@@ -451,64 +358,36 @@ class TransformerDecoder(nn.Layer):
new_caches.append(new_cache)
else:
if _global_parallel_strategy == "pp":
output = auto.shard_op(mod,
dist_attr={
"process_mesh":
PP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
use_cache, cache)
output = auto.shard_op(mod, PP_MESH_LIST[mod.mesh_idx])(
output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh":
PP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[-1 for i in range(len(output.shape))]
})
output, PP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))])
elif _global_parallel_strategy == "dp_pp":
output = auto.shard_op(mod,
dist_attr={
"process_mesh":
DPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
output = auto.shard_op(
mod, DPPP_MESH_LIST[mod.mesh_idx])(output, memory,
tgt_mask,
use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh":
DPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
output, DPPP_MESH_LIST[mod.mesh_idx], ["x"] +
[None for i in range(len(output.shape) - 1)])
elif _global_parallel_strategy == "mp_pp":
output = auto.shard_op(mod,
dist_attr={
"process_mesh":
MPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask,
output = auto.shard_op(
mod, MPPP_MESH_LIST[mod.mesh_idx])(output, memory,
tgt_mask,
use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh":
MPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [-1] +
[-1 for i in range(len(output.shape) - 1)]
})
output, MPPP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))])
elif _global_parallel_strategy == "dp_mp_pp":
output = auto.shard_op(
mod,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)
output = auto.shard_op(mod,
DPMPPP_MESH_LIST[mod.mesh_idx])(
output, memory, tgt_mask,
use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping": [0] +
[-1 for i in range(len(output.shape) - 1)]
})
output, DPMPPP_MESH_LIST[mod.mesh_idx],
["x"].extends(
[None for i in range(len(output.shape) - 1)]))
else:
output = mod(output,
memory,
......@@ -519,58 +398,33 @@ class TransformerDecoder(nn.Layer):
if _global_parallel_strategy == "pp":
output, new_cache = auto.shard_op(
mod,
dist_attr={"process_mesh": PP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache,
cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh": PP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[-1 for i in range(len(output.shape))]
})
PP_MESH_LIST[mod.mesh_idx])(output, memory, tgt_mask,
use_cache, cache)
auto.shard_tensor(output, PP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))])
elif _global_parallel_strategy == "dp_pp":
output, new_cache = auto.shard_op(
mod,
dist_attr={
"process_mesh": DPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh":
DPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[0] + [-1 for i in range(len(output.shape) - 1)]
})
DPPP_MESH_LIST[mod.mesh_idx])(output, memory, tgt_mask,
use_cache, cache)
auto.shard_tensor(output, DPPP_MESH_LIST[mod.mesh_idx],
["x"] + [None for i in range(len(output.shape) - 1)])
elif _global_parallel_strategy == "mp_pp":
output, new_cache = auto.shard_op(
mod,
dist_attr={
"process_mesh": MPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh":
MPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[-1] + [-1 for i in range(len(output.shape) - 1)]
})
MPPP_MESH_LIST[mod.mesh_idx])(output, memory, tgt_mask,
use_cache, cache)
auto.shard_tensor(output, MPPP_MESH_LIST[mod.mesh_idx],
[None for i in range(len(output.shape))])
elif _global_parallel_strategy == "dp_mp_pp":
output, new_cache = auto.shard_op(
mod,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx]
})(output, memory, tgt_mask, use_cache, cache)
auto.shard_tensor(
output,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[mod.mesh_idx],
"dims_mapping":
[0] + [-1 for i in range(len(output.shape) - 1)]
})
mod, DPMPPP_MESH_LIST[mod.mesh_idx])(output, memory,
tgt_mask,
use_cache, cache)
auto.shard_tensor(output, DPMPPP_MESH_LIST[mod.mesh_idx],
["x"] + [None for i in range(len(output.shape) - 1)])
else:
output, new_cache = mod(output,
memory,
......@@ -661,55 +515,30 @@ class TransformerDecoderLayer(nn.Layer):
if self.normalize_before:
tgt = self.norm2(tgt)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear1.weight, _global_process_mesh,
[None, "x"])
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight, _global_process_mesh,
[None, "y"])
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 0]
})
MPPP_MESH_LIST[self.mesh_idx], [None, "x"])
if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [-1, 1]
})
DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"])
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear2.weight, _global_process_mesh,
["x", None])
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.linear2.weight, _global_process_mesh,
["y", None])
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [0, -1]
})
MPPP_MESH_LIST[self.mesh_idx], ["x", None])
elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[self.mesh_idx],
"dims_mapping": [1, -1]
})
DPMPPP_MESH_LIST[self.mesh_idx], ["y", None])
tgt = self.dropout2(
self.linear2(F.gelu(self.linear1(tgt), approximate=True)))
tgt = residual + tgt
......@@ -757,29 +586,18 @@ class GPTEmbeddings(nn.Layer):
position_ids = seq_length - ones
input_embedings = self.word_embeddings(input_ids)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.word_embeddings.weight, _global_process_mesh,
["x", None])
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.word_embeddings.weight, _global_process_mesh,
["y", None])
elif _global_parallel_strategy == "mp_pp":
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": MPPP_MESH_LIST[0],
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.word_embeddings.weight, MPPP_MESH_LIST[0],
["x", None])
elif _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": DPMPPP_MESH_LIST[0],
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.word_embeddings.weight, DPMPPP_MESH_LIST[0],
["y", None])
position_embeddings = self.position_embeddings(position_ids)
embeddings = input_embedings + position_embeddings
embeddings = self.dropout(embeddings)
......@@ -868,29 +686,14 @@ class GPTModel(nn.Layer):
embedding_output = self.embeddings(input_ids=input_ids,
position_ids=position_ids)
if _global_parallel_strategy == "pp":
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh":
PP_MESH_LIST[0],
"dims_mapping":
[-1 for i in range(len(input_ids.shape))]
})
auto.shard_tensor(input_ids, PP_MESH_LIST[0],
[None for i in range(len(input_ids.shape))])
if _global_parallel_strategy == "dp_pp":
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh":
DPPP_MESH_LIST[0],
"dims_mapping": [0] +
[-1 for i in range(len(input_ids.shape) - 1)]
})
auto.shard_tensor(input_ids, DPPP_MESH_LIST[0],
                  ["x"] + [None for i in range(len(input_ids.shape) - 1)])
if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh":
DPMPPP_MESH_LIST[0],
"dims_mapping": [0] +
[-1 for i in range(len(input_ids.shape) - 1)]
})
auto.shard_tensor(input_ids, DPMPPP_MESH_LIST[0],
                  ["x"] + [None for i in range(len(input_ids.shape) - 1)])
encoder_outputs = self.decoder(embedding_output,
memory=None,
tgt_mask=attention_mask,
......@@ -923,6 +726,10 @@ class GPTForPretraining(nn.Layer):
masked_positions=None,
use_cache=False,
cache=None):
input_ids.stop_gradient = True
position_ids.stop_gradient = True
attention_mask.stop_gradient = True
outputs = self.gpt(input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
......@@ -936,40 +743,42 @@ class GPTForPretraining(nn.Layer):
x = encoder_outputs
w = self.gpt.embeddings.word_embeddings.weight
mesh = _global_process_mesh
x_dims_mapping = [-1 for i in range(len(x.shape))]
w_dims_mapping = [-1 for i in range(len(w.shape))]
mesh = None
if _global_parallel_strategy == "pp":
mesh = PP_MESH_LIST[-1]
x_dims_mapping = [None for i in range(len(x.shape))]
w_dims_mapping = [None for i in range(len(w.shape))]
elif _global_parallel_strategy == "dp":
x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)]
mesh = _global_process_mesh
x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
w_dims_mapping = [None for i in range(len(w.shape))]
elif _global_parallel_strategy == "mp":
w_dims_mapping = [0] + [-1 for i in range(len(w.shape) - 1)]
mesh = _global_process_mesh
x_dims_mapping = [None for i in range(len(x.shape))]
w_dims_mapping = ["x"] + [None for i in range(len(w.shape) - 1)]
elif _global_parallel_strategy == "dp_mp":
x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)]
w_dims_mapping = [1] + [-1 for i in range(len(w.shape) - 1)]
mesh = _global_process_mesh
x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)]
elif _global_parallel_strategy == "dp_pp":
mesh = DPPP_MESH_LIST[-1]
x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)]
x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
w_dims_mapping = [None for i in range(len(w.shape))]
elif _global_parallel_strategy == "mp_pp":
mesh = MPPP_MESH_LIST[-1]
w_dims_mapping = [0] + [-1 for i in range(len(w.shape) - 1)]
x_dims_mapping = [None for i in range(len(x.shape))]
w_dims_mapping = ["x"] + [-1 for i in range(len(w.shape) - 1)]
elif _global_parallel_strategy == "dp_mp_pp":
mesh = DPMPPP_MESH_LIST[-1]
x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)]
w_dims_mapping = [1] + [-1 for i in range(len(w.shape) - 1)]
matmul = auto.shard_op(paddle.matmul,
dist_attr={
'process_mesh': mesh,
x: {
"dims_mapping": x_dims_mapping
},
w: {
"dims_mapping": w_dims_mapping
}
})
x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)]
w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)]
if mesh:
matmul = auto.shard_op(paddle.matmul, mesh,
[x_dims_mapping, w_dims_mapping, None])
logits = matmul(x, w, transpose_y=True)
else:
logits = paddle.matmul(x, w, transpose_y=True)
if use_cache:
return logits, cached_kvs
......@@ -988,25 +797,29 @@ class GPTPretrainingCriterion(nn.Layer):
self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none")
def forward(self, prediction_scores, masked_lm_labels, loss_mask):
masked_lm_labels.stop_gradient = True
loss_mask.stop_gradient = True
mesh = _global_process_mesh
dims_mapping = [-1 for i in range(len(loss_mask.shape))]
mesh = None
if _global_parallel_strategy == "dp":
dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)]
mesh = _global_process_mesh
dims_mapping = ["x"
] + [None for i in range(len(loss_mask.shape) - 1)]
elif _global_parallel_strategy == "dp_mp":
dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)]
mesh = _global_process_mesh
dims_mapping = ["x"
] + [None for i in range(len(loss_mask.shape) - 1)]
elif _global_parallel_strategy == "dp_pp":
mesh = DPPP_MESH_LIST[-1]
dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)]
dims_mapping = ["x"
] + [None for i in range(len(loss_mask.shape) - 1)]
elif _global_parallel_strategy == "dp_mp_pp":
mesh = DPMPPP_MESH_LIST[-1]
dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)]
dims_mapping = ["x"
] + [None for i in range(len(loss_mask.shape) - 1)]
auto.shard_tensor(loss_mask,
dist_attr={
"process_mesh": mesh,
"dims_mapping": dims_mapping
})
if mesh:
auto.shard_tensor(loss_mask, mesh, dims_mapping)
masked_lm_loss = self.loss_func(prediction_scores,
masked_lm_labels.unsqueeze(2))
......
......@@ -64,38 +64,18 @@ class MLPLayer(nn.Layer):
def forward(self, input):
if _global_parallel_strategy == "pp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None])
elif _global_parallel_strategy == "mp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear0.weight, _global_process_mesh,
[None, "x"])
auto.shard_tensor(self.linear1.weight, _global_process_mesh,
["x", None])
elif _global_parallel_strategy == "dp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear0.weight, _global_process_mesh,
[None, None])
auto.shard_tensor(self.linear1.weight, _global_process_mesh,
[None, None])
out = self.norm(input)
out = self.linear0(out)
......@@ -119,28 +99,12 @@ def mlp_forward(train_program, start_program):
dtype='float32')
if _global_parallel_strategy == "pp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(input, PP_MESH_0, [None, None])
auto.shard_tensor(label, PP_MESH_1, [None, None])
elif _global_parallel_strategy == "dp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(input, _global_process_mesh, ["x", None])
elif _global_parallel_strategy == "mp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(input, _global_process_mesh, [None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......@@ -183,7 +147,7 @@ class TestMLPSaveLoad(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh([0, 1])
_global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
dist_main_prog, dist_start_prog, loss = get_distributed_program()
place = paddle.set_device("gpu")
......@@ -230,7 +194,7 @@ class TestMLPSaveLoad(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh([0, 1])
_global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
dist_main_prog, dist_start_prog, loss = get_distributed_program()
......@@ -278,11 +242,11 @@ class TestMLPSaveLoad(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "pp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh([0, 1])
_global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
global PP_MESH_0
PP_MESH_0 = auto.ProcessMesh(mesh=[0])
PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"])
global PP_MESH_1
PP_MESH_1 = auto.ProcessMesh(mesh=[1])
PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"])
dist_main_prog, dist_start_prog, loss = get_distributed_program()
......
......@@ -82,11 +82,7 @@ def mlp_pretrain_forward(train_program, start_program):
shape=[batch_size, sequence_len, 1],
dtype='float32')
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mappig": [-1, -1, -1]
})
auto.shard_tensor(input, _global_process_mesh, [None, None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......@@ -106,7 +102,7 @@ class TestMLPAutoParallelizer(unittest.TestCase):
def test_mlp_serial(self):
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = False
......
......@@ -86,7 +86,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
paddle.static.Program()):
with paddle.static.scope_guard(scope):
with paddle.fluid.unique_name.guard():
main_prog, startup_prog, inputs, outputs, reader = self.get_model(
main_prog, startup_prog, inputs, outputs, data_loader = self.get_model(
place, **kwargs)
inputs = self._to_var_names(inputs)
outputs = self._to_var_names(outputs)
......@@ -95,27 +95,57 @@ class AutoPallelPassTestBase(DistPassTestBase):
exe = paddle.static.Executor(place)
with paddle.static.scope_guard(scope):
exe.run(startup_prog)
for batch_id, input_data in enumerate(reader()):
assert len(input_data) == len(inputs), "{} vs {}".format(
len(input_data), len(inputs))
feed = dict(zip(inputs, input_data))
fetch_values = exe.run(main_prog, feed=feed, fetch_list=outputs)
data_loader.start()
batch_id = 0
while True:
try:
fetch_values = exe.run(main_prog, fetch_list=outputs)
if paddle.distributed.get_rank() == 0:
output_dict = OrderedDict(zip(outputs, fetch_values))
print('batch {}, outputs {}'.format(batch_id, output_dict))
print('batch {}, outputs {}'.format(
batch_id, output_dict))
all_fetch_values.append(fetch_values)
batch_id += 1
except paddle.fluid.core.EOFException:
data_loader.reset()
break
with open(dump_file, "wb") as f:
pickle.dump(all_fetch_values, f)
def get_gpt_model(self, strategy, place, batch_size, sequence_len,
vocab_size, **kwargs):
def gen_data():
np.random.seed(2021)
for _ in range(10):
tokens = []
position_ids = []
attention_mask = []
labels = []
loss_mask = []
for _ in range(batch_size):
tokens.append(
np.random.randint(vocab_size,
size=sequence_len).astype("int64"))
position_ids.append(np.arange(sequence_len).astype("int64"))
attention_mask.append(
[np.tril(np.ones(sequence_len)).astype("float32")])
labels.append(
np.random.randint(vocab_size,
size=sequence_len).astype("int64"))
loss_mask.append(np.ones(sequence_len).astype("float32"))
yield tokens, position_ids, attention_mask, labels, loss_mask
modeling.init_global()
if strategy == "dp":
modeling._global_parallel_strategy = "dp"
modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1],
dim_names=["x"])
elif strategy == "mp":
modeling._global_parallel_strategy = "mp"
modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1],
dim_names=["x"])
else:
raise ValueError("'get_gpt_model' only support dp and mp.")
......@@ -137,23 +167,17 @@ class AutoPallelPassTestBase(DistPassTestBase):
dtype='float32')
data_holder = [tokens, position_ids, attention_mask, labels, loss_mask]
data_loader = paddle.fluid.io.DataLoader.from_generator(
feed_list=data_holder, capacity=70, iterable=False)
data_loader.set_batch_generator(gen_data, paddle.static.cuda_places())
if modeling._global_parallel_strategy == "dp":
auto.shard_tensor(tokens,
dist_attr={
"process_mesh": modeling._global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(tokens, modeling._global_process_mesh,
["x", None])
elif modeling._global_parallel_strategy == "pp":
auto.shard_tensor(tokens,
dist_attr={
"process_mesh": modeling.PP_MESH_LIST[0],
"dims_mapping": [-1, -1]
})
auto.shard_tensor(attention_mask,
dist_attr={
"process_mesh": modeling.PP_MESH_LIST[0],
"dims_mapping": [-1, -1, -1, -1]
})
auto.shard_tensor(tokens, modeling.PP_MESH_LIST[0], [None, None])
auto.shard_tensor(attention_mask, modeling.PP_MESH_LIST[0],
[None, None, None, None])
gpt = GPTModel(vocab_size=1000,
hidden_size=64,
......@@ -179,38 +203,20 @@ class AutoPallelPassTestBase(DistPassTestBase):
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
if kwargs.get('optimizer', None) == "LarsMomentum":
optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
learning_rate=0.001, momentum=0.9)
else:
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
grad_clip=clip)
optimizer = fleet.distributed_optimizer(optimizer)
startup_program = paddle.static.default_startup_program()
_, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
loss, startup_program)
def gen_data():
np.random.seed(2021)
for _ in range(10):
tokens = []
position_ids = []
attention_mask = []
labels = []
loss_mask = []
for _ in range(batch_size):
tokens.append(
np.random.randint(vocab_size, size=sequence_len))
position_ids.append(np.arange(sequence_len))
attention_mask.append([np.tril(np.ones(sequence_len))])
labels.append(
np.random.randint(vocab_size, size=sequence_len))
loss_mask.append(np.ones(sequence_len))
yield tokens, position_ids, attention_mask, labels, loss_mask
return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data
return dist_main_prog, dist_startup_prog, data_holder, [loss], data_loader
......@@ -20,10 +20,19 @@ import unittest
import paddle
import paddle.distributed.fleet as fleet
from auto_parallel_pass_test_base import AutoPallelPassTestBase
from test_auto_parallel_amp_pass import TestAMPPass
class TestPF16Pass(TestAMPPass):
class TestPF16Pass(AutoPallelPassTestBase):
def init(self):
if paddle.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
self.rtol = 1e-5
self.atol = 1e-8
paddle.seed(2021)
random.seed(2021)
np.random.seed(2021)
def apply_passes(self):
dist_strategy = fleet.DistributedStrategy()
......@@ -34,14 +43,30 @@ class TestPF16Pass(TestAMPPass):
'layer_norm',
'gelu',
],
"custom_black_list": ['c_softmax_with_cross_entropy'],
"init_loss_scaling": 32768,
"use_dynamic_loss_scaling": True,
"use_pure_fp16": True
"custom_black_list":
['c_softmax_with_cross_entropy', 'elementwise_div', 'reduce_sum'],
"init_loss_scaling":
32768,
"use_dynamic_loss_scaling":
True,
"use_pure_fp16":
True,
"use_fp16_guard":
False
}
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
def test_bs_8(self):
self.check_main(gpus=[0, 1],
batch_size=8,
sequence_len=512,
vocab_size=1000)
def get_model(self, place, batch_size, sequence_len, vocab_size):
return self.get_gpt_model("mp", place, batch_size, sequence_len,
vocab_size)
if __name__ == "__main__":
unittest.main()
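# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of this diff. It condenses the
# annotation style the updated tests above switch to: ProcessMesh axes carry
# dim_names, and shard_tensor takes the mesh plus a shard_spec of dim names,
# with None replacing the old -1 entries of dims_mapping. Mesh layout, tensor
# shapes and the `auto` import path are assumptions made only for illustration.
# ---------------------------------------------------------------------------
import paddle
import paddle.static as static
from paddle.distributed.fleet import auto  # assumed import path

paddle.enable_static()
main_prog, startup_prog = static.Program(), static.Program()
with static.program_guard(main_prog, startup_prog):
    # A 2D mesh whose first axis is named "dp" and second axis "mp".
    mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
    x = static.data(name="x", shape=[8, 64], dtype="float32")
    w = static.create_parameter(shape=[64, 64], dtype="float32")
    auto.shard_tensor(x, mesh, ["dp", None])   # old form: "dims_mapping": [0, -1]
    auto.shard_tensor(w, mesh, [None, "mp"])   # old form: "dims_mapping": [-1, 1]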
......@@ -97,11 +97,8 @@ class MLPLayer(nn.Layer):
def mlp_forward(input, label, hidden_size):
auto.shard_tensor(input,
dist_attr={
"process_mesh": [0],
"dims_mapping": [-1, -1]
})
auto.shard_tensor(input, auto.ProcessMesh([0], dim_names=["x"]),
[None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
......@@ -160,6 +157,12 @@ class TestGradientMergePass(AutoPallelPassTestBase):
def get_model(self, place, batch_size, hidden_size, max_step):
def gen_data():
for i in range(max_step):
x_data = input_data[i * batch_size:(i + 1) * batch_size, :]
y_data = label_data[i * batch_size:(i + 1) * batch_size, :]
yield x_data, y_data
train_program = static.Program()
startup_program = static.Program()
with static.program_guard(train_program, startup_program), \
......@@ -171,6 +174,12 @@ class TestGradientMergePass(AutoPallelPassTestBase):
shape=[batch_size, 1],
dtype='float32')
input.stop_gradient = False
data_holder = [input, label]
data_loader = paddle.fluid.io.DataLoader.from_generator(
feed_list=data_holder, capacity=70, iterable=False)
data_loader.set_batch_generator(gen_data,
paddle.static.cuda_places())
loss = mlp_forward(input, label, hidden_size)
optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.01)
......@@ -181,13 +190,8 @@ class TestGradientMergePass(AutoPallelPassTestBase):
input_data = np.random.random(size=(128, hidden_size)).astype('float32')
label_data = np.random.random(size=(128, 1)).astype('float32')
def reader():
for i in range(max_step):
x_data = input_data[i * batch_size:(i + 1) * batch_size, :]
y_data = label_data[i * batch_size:(i + 1) * batch_size, :]
yield x_data, y_data
return dist_main_prog, dist_startup_prog, [input, label], [loss], reader
return dist_main_prog, dist_startup_prog, [input, label], [loss], data_loader
if __name__ == "__main__":
......
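# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of this diff. It shows, in
# isolation, the non-iterable DataLoader pattern these tests adopt in place of
# a plain Python reader: build the loader from a batch generator, start() it,
# run until EOFException, then reset(). Shapes, capacity and the dummy
# generator are assumptions made only for illustration.
# ---------------------------------------------------------------------------
import numpy as np
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[4, 16], dtype="float32")
    loss = paddle.mean(x)

    def gen_batches():
        # Each yield provides one batch per variable in feed_list.
        for _ in range(10):
            yield np.random.random([4, 16]).astype("float32"),

    data_loader = paddle.fluid.io.DataLoader.from_generator(feed_list=[x],
                                                            capacity=70,
                                                            iterable=False)
    data_loader.set_batch_generator(gen_batches, paddle.static.cpu_places())

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
data_loader.start()
while True:
    try:
        exe.run(main_prog, fetch_list=[loss])
    except paddle.fluid.core.EOFException:
        data_loader.reset()
        break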
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.fluid as fluid
import paddle.nn as nn
import paddle.distributed as dist
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
paddle.enable_static()
process_mesh1 = [0, 1, 2, 3]
process_mesh2 = [[0, 1, 2], [3, 4, 5]]
class SimpleNet(nn.Layer):
def __init__(self, vocab_size=128, hidden_size=4):
super(SimpleNet, self).__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.dense1 = nn.Linear(hidden_size, hidden_size)
self.dense2 = nn.Linear(hidden_size, hidden_size // 2)
def forward(self, x, y):
# Test shard_tensor interface with dist_attr arg
x = dist.shard_tensor(x,
dist_attr={
"process_mesh": process_mesh1,
"dims_mapping": [0, -1]
})
emb_out = self.word_embeddings(x)
# Test shard_tensor interface with no dist_attr arg
y = dist.shard_tensor(y)
linear1 = self.dense1(y)
out = self.dense2(linear1)
return x, y
class TestAutoParallelAPI(unittest.TestCase):
def test_api(self):
dist_context = get_default_distributed_context()
net = SimpleNet()
data1 = fluid.layers.fill_constant(shape=[2, 4], value=1, dtype="int64")
data2 = fluid.layers.fill_constant(shape=[2, 4],
value=2,
dtype="float32")
data3 = fluid.layers.fill_constant(shape=[2, 4],
value=4,
dtype="float32")
x, y = net.forward(data1, data2)
dist_x = dist_context.get_dist_tensor_for_program(x)
self.assertEqual(dist_x.dist_attr.process_mesh.processes, process_mesh1)
self.assertEqual(dist_x.dist_attr.dims_mapping, [0, -1])
self.assertEqual(dist_x.dist_attr.shard_sizes, None)
self.assertEqual(dist_x.dist_attr.device_placement, None)
self.assertTrue(dist_x.dist_attr.is_annotated("process_mesh"))
self.assertTrue(dist_x.dist_attr.is_annotated("dims_mapping"))
self.assertFalse(dist_x.dist_attr.is_annotated("shard_sizes"))
self.assertFalse(dist_x.dist_attr.is_annotated("device_placement"))
dist_y = dist_context.get_dist_tensor_for_program(y)
self.assertEqual(dist_y.dist_attr.process_mesh, None)
self.assertEqual(dist_y.dist_attr.dims_mapping, [-1, -1])
self.assertEqual(dist_y.dist_attr.shard_sizes, None)
self.assertEqual(dist_y.dist_attr.device_placement, None)
self.assertFalse(dist_y.dist_attr.is_annotated("process_mesh"))
self.assertFalse(dist_y.dist_attr.is_annotated("dims_mapping"))
self.assertFalse(dist_y.dist_attr.is_annotated("shard_sizes"))
self.assertFalse(dist_y.dist_attr.is_annotated("device_placement"))
# Test shard_op interface with dist_attr
dims_mapping1 = [0, 1]
dims_mapping2 = [-1, 0]
dist_add = dist.shard_op(paddle.add,
dist_attr={
data2: {
"process_mesh": process_mesh2,
"dims_mapping": dims_mapping1
},
data3: {
"dims_mapping": dims_mapping2
}
})
results = dist_add(data2, data3)
ops = paddle.static.default_main_program().block(0).ops
last_op = ops[-1]
dist_op = dist_context.get_dist_op_for_program(last_op)
self.assertEqual(dist_op.dist_attr.process_mesh,
ProcessMesh(process_mesh2))
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
self.assertEqual(data2_dist_attr.process_mesh,
dist_op.dist_attr.process_mesh)
self.assertEqual(data2_dist_attr.dims_mapping, dims_mapping1)
self.assertEqual(data2_dist_attr.shard_sizes, None)
self.assertEqual(data2_dist_attr.device_placement, None)
self.assertTrue(data2_dist_attr.is_annotated("process_mesh"))
self.assertTrue(data2_dist_attr.is_annotated("dims_mapping"))
self.assertFalse(data2_dist_attr.is_annotated("shard_sizes"))
self.assertFalse(data2_dist_attr.is_annotated("device_placement"))
data3_dist_attr = dist_op.dist_attr.get_input_dist_attr(data3.name)
self.assertEqual(data3_dist_attr.process_mesh,
dist_op.dist_attr.process_mesh)
self.assertEqual(data3_dist_attr.dims_mapping, dims_mapping2)
self.assertEqual(data3_dist_attr.shard_sizes, None)
self.assertEqual(data3_dist_attr.device_placement, None)
self.assertTrue(data3_dist_attr.is_annotated("process_mesh"))
self.assertTrue(data3_dist_attr.is_annotated("dims_mapping"))
self.assertFalse(data3_dist_attr.is_annotated("shard_sizes"))
self.assertFalse(data3_dist_attr.is_annotated("device_placement"))
# Test shard_op interface with no dist_attr arg
dist_add = dist.shard_op(paddle.add)
results = dist_add(data2, data3)
ops = paddle.static.default_main_program().block(0).ops
last_op = ops[-1]
dist_op = dist_context.get_dist_op_for_program(last_op)
self.assertEqual(dist_op.dist_attr.process_mesh, None)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertFalse(dist_op.dist_attr.is_annotated("process_mesh"))
data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name)
self.assertEqual(data2_dist_attr.process_mesh,
dist_op.dist_attr.process_mesh)
self.assertEqual(data2_dist_attr.dims_mapping, [-1, -1])
self.assertEqual(data2_dist_attr.shard_sizes, None)
self.assertEqual(data2_dist_attr.device_placement, None)
self.assertFalse(data2_dist_attr.is_annotated("process_mesh"))
self.assertFalse(data2_dist_attr.is_annotated("dims_mapping"))
self.assertFalse(data2_dist_attr.is_annotated("shard_sizes"))
self.assertFalse(data2_dist_attr.is_annotated("device_placement"))
data3_dist_attr = dist_op.dist_attr.get_input_dist_attr(data3.name)
self.assertEqual(data3_dist_attr.process_mesh,
dist_op.dist_attr.process_mesh)
self.assertEqual(data3_dist_attr.dims_mapping, [-1, -1])
self.assertEqual(data3_dist_attr.shard_sizes, None)
self.assertEqual(data3_dist_attr.device_placement, None)
self.assertFalse(data3_dist_attr.is_annotated("process_mesh"))
self.assertFalse(data3_dist_attr.is_annotated("dims_mapping"))
self.assertFalse(data3_dist_attr.is_annotated("shard_sizes"))
self.assertFalse(data3_dist_attr.is_annotated("device_placement"))
if __name__ == '__main__':
unittest.main()
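# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of this diff. It shows the
# positional shard_op form that supersedes the dist_attr dict exercised by the
# test above: the third argument lists one shard_spec per input, with None
# leaving an input unannotated, mirroring the matmul annotation used in
# GPTForPretraining earlier in this diff. Mesh, shapes and specs are
# assumptions made only for illustration; the `auto` import path is assumed.
# ---------------------------------------------------------------------------
import paddle
from paddle.distributed.fleet import auto  # assumed import path

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
    a = paddle.static.data(name="a", shape=[4, 8], dtype="float32")
    b = paddle.static.data(name="b", shape=[4, 8], dtype="float32")
    sharded_matmul = auto.shard_op(paddle.matmul, mesh,
                                   [["x", None], [None, None], None])
    out = sharded_matmul(a, b, transpose_y=True)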
......@@ -66,39 +66,13 @@ class MLPLayer(nn.Layer):
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
if _global_parallel_strategy == "mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
elif _global_parallel_strategy == "pp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh2,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
out = self.norm(input)
out = self.linear0(out)
......@@ -119,18 +93,10 @@ def mlp_pretrain_forward(train_program, start_program):
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
if _global_parallel_strategy == "dp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["dp", "dp_mp"]:
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["dp", None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......@@ -146,7 +112,8 @@ class TestMLPAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
dim_names=["dp"])
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
......@@ -161,7 +128,8 @@ class TestMLPAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
dim_names=["mp"])
train_program = static.Program()
start_program = static.Program()
......@@ -177,8 +145,9 @@ class TestMLPAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
_global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
[4, 5, 6, 7]],
dim_names=["dp", "mp"])
train_program = static.Program()
start_program = static.Program()
......@@ -286,18 +255,10 @@ class AttentionLayer(nn.Layer):
bias_attr=bias_attr)
def forward(self, input):
if _global_parallel_strategy == "dp":
if _global_parallel_strategy in ["dp", "dp_mp"]:
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["dp", None, None])
q = self.q_proj(input)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
......@@ -306,38 +267,16 @@ class AttentionLayer(nn.Layer):
k = self.k_proj(input)
v = self.v_proj(input)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
......@@ -369,18 +308,10 @@ class AttentionLayer(nn.Layer):
# project to output
out = self.out_proj(out)
if _global_parallel_strategy == "mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
return out
......@@ -411,7 +342,8 @@ class TestAttentionAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
dim_names=["dp"])
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
......@@ -420,15 +352,14 @@ class TestAttentionAutoCompletion(unittest.TestCase):
completer = Completer(dist_context)
complete_train_program = completer.complete_forward_annotation(
train_program)
# print_program_with_dist_attr(complete_train_program,
# dist_context)
self.assertTrue(dist_context.validate_dist_attr_for_program())
def test_attn_mp(self):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
dim_names=["mp"])
train_program = static.Program()
start_program = static.Program()
......@@ -444,8 +375,9 @@ class TestAttentionAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
_global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
[4, 5, 6, 7]],
dim_names=["dp", "mp"])
train_program = static.Program()
start_program = static.Program()
......@@ -542,34 +474,18 @@ class DecoderLayer(nn.Layer):
self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
def forward(self, input_ids, position_ids):
if _global_parallel_strategy == "dp":
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["dp", "dp_mp"]:
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["dp", None])
input_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
embeddings = input_embeddings + position_embeddings
embeddings = self.dropout1(embeddings)
......@@ -585,38 +501,16 @@ class DecoderLayer(nn.Layer):
k = self.k_proj(target)
v = self.v_proj(target)
if _global_parallel_strategy == "mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
......@@ -649,18 +543,10 @@ class DecoderLayer(nn.Layer):
# project to output
out = self.out_proj(out)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
# Add residual
residual = embeddings + self.dropout2(out)
......@@ -673,28 +559,13 @@ class DecoderLayer(nn.Layer):
out2 = F.gelu(out1, approximate=True)
out3 = self.linear1(out2)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
# Add residual
final = residual + self.dropout3(out3)
......@@ -732,7 +603,8 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
dim_names=["dp"])
train_program = static.Program()
start_program = static.Program()
dist_context = DistributedContext()
......@@ -747,7 +619,8 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
dim_names=["mp"])
train_program = static.Program()
start_program = static.Program()
......@@ -763,8 +636,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
_global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
[4, 5, 6, 7]],
dim_names=["dp", "mp"])
train_program = static.Program()
start_program = static.Program()
......
......@@ -116,18 +116,10 @@ class MultiHeadAttention(nn.Layer):
"""
q = self.q_proj(query)
if _global_parallel_strategy == "mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
......@@ -158,34 +150,15 @@ class MultiHeadAttention(nn.Layer):
to construct cache for inference.
"""
k = self.k_proj(key)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
v = self.v_proj(value)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.k_proj.weight,
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
......@@ -265,18 +238,10 @@ class MultiHeadAttention(nn.Layer):
# project to output
out = self.out_proj(out)
if _global_parallel_strategy == "mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
outs = [out]
if self.need_weights:
......@@ -439,31 +404,13 @@ class TransformerDecoderLayer(nn.Layer):
if self.normalize_before:
tgt = self.norm2(tgt)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
if _global_parallel_strategy == "mp":
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
# tgt = self.dropout2(
# self.linear2(F.gelu(
......@@ -523,18 +470,10 @@ class GPTEmbeddings(nn.Layer):
input_embedings = self.word_embeddings(input_ids)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
position_embeddings = self.position_embeddings(position_ids)
embeddings = input_embedings + position_embeddings
......@@ -757,18 +696,10 @@ def gpt_pretrain_forward(train_program, start_program):
shape=[batch_size, sequence_len],
dtype='float64')
if _global_parallel_strategy == "dp":
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["dp", "dp_mp"]:
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["dp", None])
gpt = GPTModel(vocab_size=32768,
hidden_size=1024,
......@@ -801,7 +732,8 @@ class TestGPTAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
dim_names=["dp"])
train_program = static.Program()
start_program = static.Program()
......@@ -817,7 +749,8 @@ class TestGPTAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
dim_names=["mp"])
train_program = static.Program()
start_program = static.Program()
......@@ -833,8 +766,9 @@ class TestGPTAutoCompletion(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
_global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
[4, 5, 6, 7]],
dim_names=["dp", "mp"])
train_program = static.Program()
start_program = static.Program()
......
......@@ -35,8 +35,8 @@ from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]])
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])
NUM_RANKS = 8
STAGE_0_CNT = 5
STAGE_1_CNT = 10
......@@ -73,16 +73,8 @@ class MLPLayer(nn.Layer):
def forward(self, input):
if self.is_distributed:
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None])
out = self.norm(input)
out = self.linear0(out)
......@@ -135,16 +127,8 @@ def mlp_forward(train_program, start_program, is_distributed=True):
dtype='float32')
if is_distributed:
auto.shard_tensor(input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]
})
auto.shard_tensor(input, PP_MESH_0, ["x", None])
auto.shard_tensor(label, PP_MESH_1, ["x", None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......
......@@ -71,7 +71,7 @@ class TestDistributedTensor(unittest.TestCase):
def test_new_local_tensor(self):
test_auto_parallel_reshard._global_process_mesh = auto.ProcessMesh(
mesh=[0, 1])
mesh=[0, 1], dim_names=["x"])
test_auto_parallel_reshard._global_parallel_strategy = "dp"
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
......
......@@ -414,37 +414,25 @@ class MLPLayer(nn.Layer):
def forward(self, input):
if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh[0],
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh[0],
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh[1],
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear3.weight,
dist_attr={
"process_mesh": _global_process_mesh[1],
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.linear0.weight, _global_process_mesh[0],
[None, "y"])
auto.shard_tensor(self.linear1.weight, _global_process_mesh[0],
["y", None])
auto.shard_tensor(self.linear2.weight, _global_process_mesh[1],
[None, "y"])
auto.shard_tensor(self.linear3.weight, _global_process_mesh[1],
["y", None])
out = self.norm(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = self.linear1(out)
auto.shard_tensor(out,
dist_attr={
"process_mesh": _global_process_mesh[1],
"dims_mapping": [0, -1]
})
auto.shard_tensor(out, _global_process_mesh[1], ["x", None])
out = self.linear2(out)
out = F.gelu(out, approximate=True)
out = self.linear3(out)
......@@ -464,11 +452,7 @@ def mlp_forward(train_program, start_program):
dtype='float32')
if _global_parallel_strategy == "dp_mp_pp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh[0],
"dims_mapping": [0, -1]
})
auto.shard_tensor(input, _global_process_mesh[0], ["x", None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
initializer_range=0.02)
......@@ -548,7 +532,10 @@ class TestAutoParallelMapper(unittest.TestCase):
global _global_num_stages
_global_num_stages = 2
global _global_process_mesh
_global_process_mesh = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
_global_process_mesh = [
auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]),
auto.ProcessMesh([[4, 5], [6, 7]], dim_names=["x", "y"])
]
processes = [0, 1, 2, 3, 4, 5, 6, 7]
dist_programs = {}
......
......@@ -276,39 +276,20 @@ class MLPLayer(nn.Layer):
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
if _global_parallel_strategy == "mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
else:
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, None])
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, None])
out = self.norm(input)
out = self.linear0(out)
......@@ -329,18 +310,10 @@ def mlp_pretrain_forward(train_program, start_program):
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
if _global_parallel_strategy == "dp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["dp", "dp_mp"]:
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["dp", None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......@@ -356,7 +329,8 @@ class TestMLPAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
dim_names=["dp"])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
mlp_pretrain_forward)
......@@ -391,7 +365,8 @@ class TestMLPAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
dim_names=["mp"])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
mlp_pretrain_forward)
......@@ -453,8 +428,9 @@ class TestMLPAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
_global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
[4, 5, 6, 7]],
dim_names=["dp", "mp"])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
mlp_pretrain_forward)
......@@ -558,18 +534,10 @@ class AttentionLayer(nn.Layer):
bias_attr=bias_attr)
def forward(self, input):
if _global_parallel_strategy == "dp":
if _global_parallel_strategy in ["dp", "dp_mp"]:
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["dp", None, None])
q = self.q_proj(input)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
......@@ -578,38 +546,16 @@ class AttentionLayer(nn.Layer):
k = self.k_proj(input)
v = self.v_proj(input)
if _global_parallel_strategy == "mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
......@@ -641,18 +587,11 @@ class AttentionLayer(nn.Layer):
# project to output
out = self.out_proj(out)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
return out
......@@ -683,7 +622,8 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
dim_names=["dp"])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
attn_pretrain_forward)
......@@ -717,7 +657,8 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3],
dim_names=["mp"])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
attn_pretrain_forward)
......@@ -783,8 +724,9 @@ class TestAttentionAutoPartitioner(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
_global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
[4, 5, 6, 7]],
dim_names=["dp", "mp"])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
attn_pretrain_forward)
......@@ -930,34 +872,18 @@ class DecoderLayer(nn.Layer):
self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train")
def forward(self, input_ids, position_ids):
if _global_parallel_strategy == "dp":
if _global_parallel_strategy in ["dp", "dp_mp"]:
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["dp", None])
input_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
embeddings = input_embeddings + position_embeddings
embeddings = self.dropout1(embeddings)
......@@ -973,38 +899,16 @@ class DecoderLayer(nn.Layer):
k = self.k_proj(target)
v = self.v_proj(target)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
......@@ -1037,24 +941,14 @@ class DecoderLayer(nn.Layer):
# project to output
out = self.out_proj(out)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
else:
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, None])
# Add residual
residual = embeddings + self.dropout2(out)
......@@ -1067,28 +961,13 @@ class DecoderLayer(nn.Layer):
out2 = F.gelu(out1, approximate=True)
out3 = self.linear1(out2)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
# Add residual
final = residual + self.dropout3(out3)
......@@ -1126,8 +1005,9 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
_global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
[4, 5, 6, 7]],
dim_names=["dp", "mp"])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
decoder_pretrain_forward)
......@@ -1208,8 +1088,9 @@ class TestDecoderLayerPartitioner(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "None"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
_global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
[4, 5, 6, 7]],
dim_names=["x", "y"])
serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs(
decoder_pretrain_forward)
......
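The hunks above also show the reworked shard_tensor call: the dist_attr dict is replaced by a process_mesh plus a shard_spec whose entries name the mesh dimension that shards each tensor dimension (None keeps a dimension replicated). Because the spec refers to the dimension name "mp", the same annotation works for both the pure "mp" mesh and the "dp_mp" mesh, which is why the separate elif branches collapse into a single `in ["mp", "dp_mp"]` check. A minimal runnable sketch under the same import assumption:

import paddle
import paddle.static as static
from paddle.distributed.fleet import auto  # assumed import path

paddle.enable_static()
mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]],
                        dim_names=["dp", "mp"])
with static.program_guard(static.Program(), static.Program()):
    x = paddle.static.data(name="x", shape=[8, 1024], dtype="float32")
    # keyword form: batch dim sharded along "dp", feature dim replicated
    auto.shard_tensor(x, process_mesh=mesh, shard_spec=["dp", None])
    w = paddle.static.data(name="w", shape=[1024, 4096], dtype="float32")
    # positional form: column-parallel weight, dim 1 sharded along "mp"
    auto.shard_tensor(w, mesh, [None, "mp"])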
......@@ -163,18 +163,10 @@ class MultiHeadAttention(nn.Layer):
"""
q = self.q_proj(query)
if _global_parallel_strategy == "mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.q_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q = tensor.transpose(x=q, perm=[0, 2, 1, 3])
......@@ -205,34 +197,15 @@ class MultiHeadAttention(nn.Layer):
to construct cache for inference.
"""
k = self.k_proj(key)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.k_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
v = self.v_proj(value)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.k_proj.weight,
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.v_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k = tensor.transpose(x=k, perm=[0, 2, 1, 3])
......@@ -312,18 +285,10 @@ class MultiHeadAttention(nn.Layer):
# project to output
out = self.out_proj(out)
if _global_parallel_strategy == "mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.out_proj.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
outs = [out]
if self.need_weights:
......@@ -486,31 +451,13 @@ class TransformerDecoderLayer(nn.Layer):
if self.normalize_before:
tgt = self.norm2(tgt)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 0]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, 1]
})
if _global_parallel_strategy == "mp":
process_mesh=_global_process_mesh,
shard_spec=[None, "mp"])
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
# tgt = self.dropout2(
# self.linear2(F.gelu(
......@@ -570,18 +517,10 @@ class GPTEmbeddings(nn.Layer):
input_embedings = self.word_embeddings(input_ids)
if _global_parallel_strategy == "mp":
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["mp", "dp_mp"]:
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [1, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["mp", None])
position_embeddings = self.position_embeddings(position_ids)
embeddings = input_embedings + position_embeddings
......@@ -804,18 +743,10 @@ def gpt_pretrain_forward(train_program, startup_program):
shape=[batch_size, sequence_len],
dtype='float64')
if _global_parallel_strategy == "dp":
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
elif _global_parallel_strategy == "dp_mp":
if _global_parallel_strategy in ["dp", "dp_mp"]:
auto.shard_tensor(input_ids,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
process_mesh=_global_process_mesh,
shard_spec=["dp", None])
gpt = GPTModel(vocab_size=32768,
hidden_size=768,
......@@ -863,8 +794,9 @@ class TestGPTPartitioner(unittest.TestCase):
_global_parallel_strategy = "dp_mp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(
mesh=[[0, 1, 2, 3], [4, 5, 6, 7]])
_global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3],
[4, 5, 6, 7]],
dim_names=["dp", "mp"])
train_program = static.Program()
startup_program = static.Program()
......
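For reference, the relation between the removed dims_mapping notation and the new shard_spec is mechanical: entry i of dims_mapping gave the index of the mesh dimension sharding tensor dimension i (-1 for replicated), while entry i of shard_spec gives that mesh dimension's name (None for replicated). A few correspondences taken from the hunks above, for a mesh with dim_names=["dp", "mp"]:

    dims_mapping [0, -1]   ->  shard_spec ["dp", None]
    dims_mapping [-1, 1]   ->  shard_spec [None, "mp"]
    dims_mapping [1, -1]   ->  shard_spec ["mp", None]
    dims_mapping [-1, -1]  ->  shard_spec [None, None]   (fully replicated)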
......@@ -63,27 +63,13 @@ class MLPLayer(nn.Layer):
def forward(self, input):
if _global_parallel_strategy == "pp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None])
else:
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear0.weight, _global_process_mesh,
[None, None])
auto.shard_tensor(self.linear1.weight, _global_process_mesh,
[None, None])
out = self.norm(input)
out = self.linear0(out)
......@@ -107,28 +93,12 @@ def mlp_forward(train_program, start_program):
dtype='float32')
if _global_parallel_strategy == "pp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(input, PP_MESH_0, [None, None])
auto.shard_tensor(label, PP_MESH_1, [None, None])
elif _global_parallel_strategy == "dp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(input, _global_process_mesh, ["x", None])
else:
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(input, _global_process_mesh, [None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......@@ -296,11 +266,11 @@ class TestMLPReshard(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "pp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
global PP_MESH_0
PP_MESH_0 = auto.ProcessMesh(mesh=[0])
PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"])
global PP_MESH_1
PP_MESH_1 = auto.ProcessMesh(mesh=[1])
PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"])
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
......@@ -325,11 +295,11 @@ class TestMLPReshard(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "pp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
global PP_MESH_0
PP_MESH_0 = auto.ProcessMesh(mesh=[0])
PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"])
global PP_MESH_1
PP_MESH_1 = auto.ProcessMesh(mesh=[1])
PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"])
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
......@@ -352,7 +322,7 @@ class TestMLPReshard(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = "dp"
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
......
......@@ -34,9 +34,10 @@ from paddle.distributed.auto_parallel.cluster import Cluster
paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]])
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]])
_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]],
dim_names=["x", "y", "z"])
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])
class MLPLayer(nn.Layer):
......@@ -63,16 +64,8 @@ class MLPLayer(nn.Layer):
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, input):
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [1, -1]
})
auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "y"])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None])
out = self.norm(input)
out = self.linear0(out)
......@@ -80,11 +73,7 @@ class MLPLayer(nn.Layer):
out = self.linear1(out)
param = paddle.fluid.layers.create_parameter([1024, 4096],
paddle.float32)
auto.shard_tensor(param,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, 1]
})
auto.shard_tensor(param, PP_MESH_1, [None, "y"])
out = paddle.fluid.layers.mul(out, param)
return out
......@@ -103,16 +92,8 @@ def mlp_forward(train_program, start_program):
shape=[batch_size, 1],
dtype='float32')
auto.shard_tensor(input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]
})
auto.shard_tensor(input, PP_MESH_0, ["x", None])
auto.shard_tensor(label, PP_MESH_1, ["x", None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......
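The reshard tests above pair the renamed global mesh with per-stage sub-meshes: each pipeline stage annotates its tensors against its own ProcessMesh, and the names in a shard_spec are resolved against the mesh passed in the same call. A small sketch mirroring the meshes above (import path assumed; the shard specs in comments are the ones from the hunks):

from paddle.distributed.fleet import auto  # assumed import path

# a 2 x 2 x 2 arrangement of 8 ranks, then one 2-D sub-mesh per pipeline stage
global_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]],
                               dim_names=["x", "y", "z"])
stage0_mesh = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
stage1_mesh = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])
# e.g. a stage-0 weight column-parallel along "y", a stage-1 weight row-parallel:
#   auto.shard_tensor(linear0.weight, stage0_mesh, [None, "y"])
#   auto.shard_tensor(linear1.weight, stage1_mesh, ["y", None])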
......@@ -34,9 +34,9 @@ from paddle.distributed.auto_parallel.cluster import Cluster
paddle.enable_static()
_global_parallel_strategy = "mp_pp"
_global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]])
PP_MESH_0 = auto.ProcessMesh([0, 1])
PP_MESH_1 = auto.ProcessMesh([2, 3])
_global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
PP_MESH_0 = auto.ProcessMesh([0, 1], dim_names=["x"])
PP_MESH_1 = auto.ProcessMesh([2, 3], dim_names=["x"])
class MLPLayer(nn.Layer):
......@@ -73,35 +73,15 @@ class MLPLayer(nn.Layer):
bias_attr=bias_attr)
def forward(self, input):
auto.shard_tensor(self.word_embeddings.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.linear2.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [0, -1]
})
auto.shard_tensor(self.word_embeddings.weight, PP_MESH_0, ["x", None])
auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "x"])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["x", None])
auto.shard_tensor(self.linear2.weight, PP_MESH_1, ["x", None])
w_out = self.word_embeddings(input)
out = self.linear0(w_out)
param = paddle.fluid.layers.create_parameter([4096, 4096],
paddle.float32)
auto.shard_tensor(param,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [0, -1]
})
auto.shard_tensor(param, PP_MESH_0, ["x", None])
out = paddle.fluid.layers.mul(out, param)
gelu_out = F.gelu(out, approximate=True)
out = self.linear1(gelu_out)
......@@ -122,16 +102,8 @@ def mlp_forward(train_program, start_program):
shape=[batch_size, 1],
dtype='float32')
auto.shard_tensor(input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(input, PP_MESH_0, [None])
auto.shard_tensor(label, PP_MESH_1, [None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......@@ -238,7 +210,6 @@ class TestMLPReshard(unittest.TestCase):
resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
dist_context, dist_params_grads)
resharder.reshard()
print_program_with_dist_attr(dist_main_prog, dist_context)
# check send and recv result
self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
......@@ -249,32 +220,15 @@ class TestMLPReshard(unittest.TestCase):
def test_allgather(self):
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
process_mesh = auto.ProcessMesh(mesh=[0, 1])
process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
with static.program_guard(train_program, startup_program):
x = paddle.static.data(name="x", shape=[4, 4], dtype='float32')
x = auto.shard_tensor(x,
dist_attr={
"process_mesh": process_mesh,
"dims_mapping": [0, -1]
})
x = auto.shard_tensor(x, process_mesh, ["x", None])
w = paddle.static.data(name="w", shape=[4, 4], dtype='float32')
w = auto.shard_tensor(w,
dist_attr={
"process_mesh": process_mesh,
"dims_mapping": [-1, -1]
})
y = paddle.distributed.shard_op(paddle.matmul,
dist_attr={
"process_mesh": process_mesh,
x: {
"dims_mapping": [-1, -1]
},
w: {
"dims_mapping": [-1, -1]
}
})(x, w)
w = auto.shard_tensor(w, process_mesh, [None, None])
y = paddle.distributed.shard_op(paddle.matmul, process_mesh,
[[None, None], [None, None]])(x, w)
rank_id = 0
dist_context = DistributedContext()
......
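shard_op changes the same way: instead of a dist_attr dict keyed by the input tensors, it takes the op callable, a process mesh, and one shard_spec per input, and returns a wrapper that is then called on the inputs. A runnable sketch of the allgather test above (import path assumed; the shard_op namespace follows the hunk):

import paddle
import paddle.static as static
from paddle.distributed.fleet import auto  # assumed import path

paddle.enable_static()
process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"])
with static.program_guard(static.Program(), static.Program()):
    x = paddle.static.data(name="x", shape=[4, 4], dtype="float32")
    x = auto.shard_tensor(x, process_mesh, ["x", None])   # row-sharded input
    w = paddle.static.data(name="w", shape=[4, 4], dtype="float32")
    w = auto.shard_tensor(w, process_mesh, [None, None])  # replicated weight
    # the matmul is annotated with replicated inputs, so resharding has to
    # allgather the row-sharded x before the op runs
    y = paddle.distributed.shard_op(paddle.matmul, process_mesh,
                                    [[None, None], [None, None]])(x, w)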
......@@ -62,27 +62,13 @@ class MLPLayer(nn.Layer):
def forward(self, input):
if _global_parallel_strategy == "pp":
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None])
else:
auto.shard_tensor(self.linear0.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear1.weight,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(self.linear0.weight, _global_process_mesh,
[None, None])
auto.shard_tensor(self.linear1.weight, _global_process_mesh,
[None, None])
out = self.norm(input)
out = self.linear0(out)
......@@ -106,28 +92,12 @@ def mlp_forward(train_program, start_program):
dtype='float32')
if _global_parallel_strategy == "pp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": PP_MESH_0,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(label,
dist_attr={
"process_mesh": PP_MESH_1,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(input, PP_MESH_0, [None, None])
auto.shard_tensor(label, PP_MESH_1, [None, None])
elif _global_parallel_strategy == "dp":
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(input, _global_process_mesh, ["x", None])
else:
auto.shard_tensor(input,
dist_attr={
"process_mesh": _global_process_mesh,
"dims_mapping": [-1, -1]
})
auto.shard_tensor(input, _global_process_mesh, [None, None])
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
......@@ -196,7 +166,7 @@ class TestMLPReshard(unittest.TestCase):
global _global_parallel_strategy
_global_parallel_strategy = None
global _global_process_mesh
_global_process_mesh = auto.ProcessMesh(mesh=[0])
_global_process_mesh = auto.ProcessMesh(mesh=[0], dim_names=["x"])
train_program = paddle.static.Program()
startup_program = paddle.static.Program()
......