diff --git a/python/paddle/distributed/auto_parallel/__init__.py b/python/paddle/distributed/auto_parallel/__init__.py index 4dc68edfe2d55383793ddc48aa2021d75321d7c2..269a0ec644dbd2e54ab255fa88f2e80ae744985f 100644 --- a/python/paddle/distributed/auto_parallel/__init__.py +++ b/python/paddle/distributed/auto_parallel/__init__.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .interface import shard_tensor # noqa: F401 -from .interface import shard_op # noqa: F401 +from .strategy import Strategy from .process_mesh import ProcessMesh -from .reshard import Resharder # noqa: F401 -from .cost_model import estimate_cost +from .engine import Engine +from .interface import shard_tensor +from .interface import shard_op +from .interface import recompute +from .interface import fetch __all__ = [] diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..82b3d4554b76a6a731bc3c38b3226c65e9392f2d --- /dev/null +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -0,0 +1,118 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +from collections import defaultdict + +# _g_default_config[category][field] = default_value +_g_default_config = defaultdict(dict) + + +def get_category_default_config(category): + return _g_default_config[category] + + +def set_category_default_config(category, default_value): + _g_default_config[category] = default_value + + +def get_field_default_config(category, field): + return _g_default_config[category][field] + + +def set_field_default_config(category, field, default_value): + _g_default_config[category][field] = default_value + + +NOT_FOUND = "not_found" + +######################################### +# base configuration +######################################### +BASE = "base" +set_field_default_config(BASE, "auto_mode", "semi") +set_field_default_config(BASE, "gradient_scale", True) +set_field_default_config(BASE, "use_cache", True) +set_field_default_config(BASE, "return_numpy", True) +set_field_default_config(BASE, "all_ranks", False) +set_field_default_config(BASE, "split_data", False) +set_field_default_config(BASE, "seed", None) +set_field_default_config(BASE, "reinit", False) # Only for debug + +######################################### +# recompute configuration +######################################### +RECOMPUTE = "recompute" +set_field_default_config(RECOMPUTE, "enable", False) +set_field_default_config(RECOMPUTE, "checkpoints", None) +set_field_default_config(RECOMPUTE, "enable_tuning", False) + +######################################### +# AMP configuration +######################################### +AMP = "amp" +set_field_default_config(AMP, "enable", False) +set_field_default_config(AMP, "init_loss_scaling", 32768.0) +set_field_default_config(AMP, "incr_every_n_steps", 1000) +set_field_default_config(AMP, "decr_every_n_nan_or_inf", 2) +set_field_default_config(AMP, "incr_ratio", 2.0) +set_field_default_config(AMP, "decr_ratio", 0.8) +set_field_default_config(AMP, "use_dynamic_loss_scaling", True) +set_field_default_config(AMP, "custom_white_list", []) +set_field_default_config(AMP, "custom_black_list", []) +set_field_default_config(AMP, "custom_black_varnames", []) +set_field_default_config(AMP, "use_pure_fp16", False) +set_field_default_config(AMP, "use_fp16_guard", True) +set_field_default_config(AMP, "use_optimizer_fp16", False) + +######################################### +# sharding configuration +######################################### +SHARDING = "sharding" +set_field_default_config(SHARDING, "enable", False) +set_field_default_config(SHARDING, "stage", 1) +set_field_default_config(SHARDING, "sharding_degree", 8) +set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0) +set_field_default_config(SHARDING, "enable_tuning", False) +set_field_default_config(SHARDING, "tuning_range", []) + +######################################### +# gradient merge configuration +######################################### +GRADIENT_MERGE = "gradient_merge" +set_field_default_config(GRADIENT_MERGE, "enable", False) +set_field_default_config(GRADIENT_MERGE, "k_steps", 1) +set_field_default_config(GRADIENT_MERGE, "avg", True) + +######################################### +# quantization configuration +######################################### +QAT = "qat" +set_field_default_config(QAT, "enable", False) +set_field_default_config(QAT, "channel_wise_abs_max", True) +set_field_default_config(QAT, "weight_bits", 8) +set_field_default_config(QAT, "activation_bits", 8) +set_field_default_config(QAT, "not_quant_pattern", ['skip_quant']) +set_field_default_config(QAT, "algo", None) + +# ######################################### +# auto tuning configuration +# ######################################### +TUNING = "tuning" +set_field_default_config(TUNING, "enable", False) +set_field_default_config(TUNING, "batch_size", 1) +set_field_default_config(TUNING, "dataset", None) +set_field_default_config(TUNING, "profile_start_step", 1) +set_field_default_config(TUNING, "profile_end_step", 1) +set_field_default_config(TUNING, "run_after_tuning", True) +set_field_default_config(TUNING, "verbose", True) diff --git a/python/paddle/distributed/auto_parallel/converter.py b/python/paddle/distributed/auto_parallel/converter.py index 162ca135f37972cd4a27ec0919f97086188288ac..95f8bad828b098daa0030288ce3984d698bce6a7 100644 --- a/python/paddle/distributed/auto_parallel/converter.py +++ b/python/paddle/distributed/auto_parallel/converter.py @@ -16,7 +16,7 @@ import paddle import warnings import logging import numpy as np -from ..utils import get_logger +from .utils import get_logger class Converter(object): diff --git a/python/paddle/distributed/auto_parallel/dist_attribute.py b/python/paddle/distributed/auto_parallel/dist_attribute.py index ff07deb42aad3fcb87de8bcaab88032081087bec..92d0304eaf6138371261fe4781d31656e3c1185d 100644 --- a/python/paddle/distributed/auto_parallel/dist_attribute.py +++ b/python/paddle/distributed/auto_parallel/dist_attribute.py @@ -173,6 +173,17 @@ class TensorDistributedAttribute: def clear_annotated(self): self._is_annotated.clear() + def __eq__(self, other): + if not isinstance(other, TensorDistributedAttribute): + return False + if self.process_mesh != other.process_mesh: + return False + if self.dims_mapping != other.dims_mapping: + return False + if self._is_annotated != other._is_annotated: + return False + return True + def __str__(self): str = "\n\ttensor_dist_attr = {" if self.is_annotated("process_mesh"): @@ -486,6 +497,27 @@ class OperatorDistributedAttribute: else: return False + def __eq__(self, other): + if not isinstance(other, OperatorDistributedAttribute): + return False + if self.process_mesh != other.process_mesh: + return False + if self.op_type != other.op_type: + return False + if self.impl_type != other.impl_type: + return False + if self.impl_idx != other.impl_idx: + return False + if self._is_annotated != other._is_annotated: + return False + if self._is_recompute != other._is_recompute: + return False + if self.inputs_dist_attrs != other.inputs_dist_attrs: + return False + if self.outputs_dist_attrs != other.outputs_dist_attrs: + return False + return True + def __str__(self): str = "\n\top_dist_attr = {" if self.is_annotated("process_mesh"): diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index 92a503659041ebc199400528bb68ab4956e76d3b..d1f00e8a7ba4fc89995ec7382dafa076a63e822a 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -126,9 +126,6 @@ class DistributedContext: # A flag indicates whether the used parallelism is data parallel self._data_parallel = False - # flag whether using `to_static` - self._dygraph_mode = False - @property def serial_main_program(self): return self._serial_main_program diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py index b6a77b778885f552438fb654ab3c4a88ebd79a05..300c80ec71878b4ab8e00cf822e739801f54f243 100644 --- a/python/paddle/distributed/auto_parallel/dist_op.py +++ b/python/paddle/distributed/auto_parallel/dist_op.py @@ -23,6 +23,7 @@ from .dist_attribute import append_op_input_suffix from .dist_attribute import append_op_output_suffix from .dist_attribute import get_tensor_dist_attr_field_keys from .dist_attribute import get_op_dist_attr_field_keys +from .utils import convert_to_shard_spec, verify_shard_spec class DistributedOperator: @@ -248,23 +249,106 @@ class DistributedOperator: return result -class DistributedModule: +class DistributedOperatorHelper: - def __init__(self, serial_module, dist_attr=None): - self._serial_module = serial_module - self._dist_attr = dist_attr + def __init__(self, serial_op, process_mesh, in_dims_mappings, + out_dims_mappings): + self._serial_op = serial_op + self._process_mesh = process_mesh + self._in_dims_mappings = in_dims_mappings + self._out_dims_mappings = out_dims_mappings def __call__(self, *args, **kwargs): - from .dist_context import get_default_distributed_context + tensor_to_dims_mapping = {} + index = 0 + if self._in_dims_mappings: + assert len(args) + len(kwargs) == len(self._in_dims_mappings), \ + "The length of dims_mapping {} does not matching the length output {}.".format(len(self._in_dims_mappings), len(args) + len(kwargs)) + for arg in args: + if isinstance(arg, Variable) and self._in_dims_mappings: + tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index] + index += 1 + for arg in kwargs.values() and self._in_dims_mappings: + if isinstance(arg, Variable): + tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index] + index += 1 + default_prog = paddle.fluid.default_main_program() cur_block = default_prog.current_block() op_size = len(cur_block.ops) - output = self._serial_module(*args, **kwargs) + output = self._serial_op(*args, **kwargs) new_op_size = len(cur_block.ops) + + if isinstance(output, tuple) or isinstance(output, list): + new_output = list(output) + elif isinstance(output, Variable): + new_output = [output] + else: + raise ValueError("Unrecognized outpout.") + + if self._out_dims_mappings: + assert len(new_output) == len(self._out_dims_mappings), \ + "The length of dims_mapping {} does not matching the length output {}.".format(len(self._out_dims_mappings), len(new_output)) + for i, item in enumerate(new_output): + if isinstance(item, Variable) and self._out_dims_mappings: + tensor_to_dims_mapping[item.name] = self._out_dims_mappings[i] + + from .dist_context import get_default_distributed_context default_dist_ctx = get_default_distributed_context() for idx in range(op_size, new_op_size): op = cur_block.ops[idx] - dist_op = DistributedOperator(op, self._dist_attr) - dist_op.dist_attr.mark_annotated_as(self._dist_attr) + dist_op = DistributedOperator(op) + for name in dist_op.serial_op.input_arg_names: + if name in tensor_to_dims_mapping.keys(): + tensor = dist_op.get_serial_input(name) + tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr( + name) + dims_mapping = tensor_to_dims_mapping[name] + if tensor is None: + tensor_shape = [] + else: + if tensor.type == core.VarDesc.VarType.READER \ + or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ + or tensor.type == core.VarDesc.VarType.STEP_SCOPES: + tensor_shape = [] + else: + tensor_shape = tensor.shape + if dims_mapping is not None: + dims_mapping = tensor_to_dims_mapping[name] + shard_spec = convert_to_shard_spec( + dims_mapping, self._process_mesh) + assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \ + "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format( + name, shard_spec, tensor_shape, self._process_mesh) + tensor_dist_attr.dims_mapping = dims_mapping + tensor_dist_attr.mark_annotated("dims_mapping") + for name in dist_op.serial_op.output_arg_names: + if name in tensor_to_dims_mapping.keys(): + tensor = dist_op.get_serial_output(name) + tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr( + name) + dims_mapping = tensor_to_dims_mapping[name] + if tensor is None: + tensor_shape = [] + else: + if tensor.type == core.VarDesc.VarType.READER \ + or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ + or tensor.type == core.VarDesc.VarType.STEP_SCOPES: + tensor_shape = [] + else: + tensor_shape = tensor.shape + if dims_mapping is not None: + dims_mapping = tensor_to_dims_mapping[name] + shard_spec = convert_to_shard_spec( + dims_mapping, self._process_mesh) + assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \ + "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format( + name, shard_spec, tensor_shape, self._process_mesh) + tensor_dist_attr.dims_mapping = dims_mapping + tensor_dist_attr.mark_annotated("dims_mapping") + dist_op.dist_attr.process_mesh = self._process_mesh + if self._process_mesh is not None: + dist_op.dist_attr.mark_annotated("process_mesh") default_dist_ctx.add_dist_op_for_program(dist_op) + return output diff --git a/python/paddle/distributed/auto_parallel/dist_saver.py b/python/paddle/distributed/auto_parallel/dist_saver.py index c3dad9e2873866bff561d3c3e52e89341f194beb..aef2dcc6b7ee73b3a149ba28bc26496d3d2587c3 100644 --- a/python/paddle/distributed/auto_parallel/dist_saver.py +++ b/python/paddle/distributed/auto_parallel/dist_saver.py @@ -27,7 +27,7 @@ from paddle.fluid.framework import static_only from .utils import get_dist_attr from .converter import Converter from .process_group import _g_process_group_map -from ..utils import get_logger +from .utils import get_logger def check_filename(re_exp, filename): @@ -59,6 +59,14 @@ class DistributedSaver: def save(self, path, serial_program, dist_main_program, dist_context): + def _save_state(program, path, mode="param"): + state = { + k: np.array(v) + for k, v in program.state_dict(mode).items() + } + with open(path, "wb") as f: + pickle.dump(state, f) + dirname, filename = _process_path(path) rank_id = paddle.distributed.get_rank() @@ -76,16 +84,6 @@ class DistributedSaver: with open(dist_model_path, "wb") as f: f.write(dist_main_program.desc.serialize_to_string()) - # save distributed params - dist_param_filename = filename + "_dist" + str(rank_id) + ".pdparams" - dist_param_path = os.path.join(dirname, dist_param_filename) - dist_param = { - k: np.array(v) - for k, v in dist_main_program.state_dict().items() - } - with open(dist_param_path, "wb") as f: - pickle.dump(dist_param, f) - # save distributed attribute dist_attr_filename = filename + "_dist" + str(rank_id) + ".pdattr" dist_attr_path = os.path.join(dirname, dist_attr_filename) @@ -93,65 +91,69 @@ class DistributedSaver: with open(dist_attr_path, "wb") as f: pickle.dump(dist_attrs, f) + # save distributed params + dist_param_filename = filename + "_dist" + str(rank_id) + ".pdparams" + dist_param_path = os.path.join(dirname, dist_param_filename) + _save_state(dist_main_program, dist_param_path) + + # save distributed opt states + dist_opt_filename = filename + "_dist" + str(rank_id) + ".pdopt" + dist_opt_path = os.path.join(dirname, dist_opt_filename) + _save_state(dist_main_program, dist_opt_path, "opt") + # TODO:save cluster.json - def load(self, - path, - program, - dist_context, - strict=True, - load_optimizer=True): + def load(self, path, load_optimizer=True): # TODO: if `program` is None, load `path.pdmodel`. + def _load_file(filename, dirname, suffix="pdparams"): + file_list = [] + for file in os.listdir(dirname): + if check_filename('{}(.*)_dist(.*).{}'.format(filename, suffix), + file): + file_list.append(os.path.join(dirname, file)) + file_list.sort() + return file_list + + def _load_state(filename, dirname, suffix="pdparams"): + file_list = _load_file(filename, dirname, suffix) + state_dict = {} + for file in file_list: + with open(file, 'rb') as f: + state_dict_info = pickle.load(f, encoding='latin1') + for name, value in state_dict_info.items(): + if name in state_dict: + state_dict[name].append(np.array(value)) + else: + state_dict[name] = [np.array(value)] + self._logger.info("Load param file: {}".format(file_list)) + return state_dict + filename = os.path.basename(path) if filename == "": raise ValueError( "path should be of 'dirname/filename' format, but received filename is empty string" ) dirname = os.path.dirname(path) - # load path.pdparam - param_file_list = [] - for param_file in os.listdir(dirname): - if check_filename('{}(.*)_dist(.*).pdparams'.format(filename), - param_file): - param_file_list.append(os.path.join(dirname, param_file)) - param_file_list.sort() - self._logger.info( - "Load distributed attribute file: {}".format(param_file_list)) - param_dict = {} - for param_file in param_file_list: - with open(param_file, 'rb') as f: - state_dict_info = pickle.load(f, encoding='latin1') - for name, value in state_dict_info.items(): - if name in param_dict: - param_dict[name].append(np.array(value)) - else: - param_dict[name] = [np.array(value)] + + # load path.pdparam and path.pdopt + param_state_dict = _load_state(filename, dirname) + opt_state_dict = _load_state(filename, dirname, + "pdopt") if load_optimizer else {} + state_dict = dict(param_state_dict, **opt_state_dict) # load path.pdattr - dist_attr_file_list = [] - for dist_attr_file in os.listdir(dirname): - if check_filename('{}(.*)_dist(.*).pdattr'.format(filename), - dist_attr_file): - dist_attr_file_list.append(os.path.join(dirname, - dist_attr_file)) - dist_attr_file_list.sort() + dist_attr_file_list = _load_file(filename, dirname, "pdattr") self._logger.info( "Load distributed attribute file: {}".format(dist_attr_file_list)) - pre_dist_attr = {} + dist_attr = {} for dist_attr_file in dist_attr_file_list: with open(dist_attr_file, 'rb') as f: - dist_attr = pickle.load(f, encoding='latin1') - for name, attr in dist_attr.items(): - if name not in pre_dist_attr: - pre_dist_attr[name] = attr - - # get current dist_attr - cur_dist_attr = get_dist_attr(program, dist_context) - - # param convert - converter = Converter(param_dict, pre_dist_attr, cur_dist_attr) - param_dict = converter.convert(strict=strict) - program.set_state_dict(param_dict) + dist_attr_info = pickle.load(f, encoding='latin1') + for name, attr in dist_attr_info.items(): + if name not in dist_attr: + dist_attr[name] = attr + + return state_dict, dist_attr def save_inference_model(self, path, feed_vars, fetch_vars, exe, **kwargs): diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 5389438d388a585d3082ca12dcc1161524a7f242..ee6bee45fd7fe4aae3d95fe162049bcb2a692d89 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -12,76 +12,169 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import time import copy import logging +import random +import numpy as np from collections import defaultdict import paddle import paddle.utils as utils from paddle import fluid, static -from paddle.io import Dataset from paddle.jit import to_static from paddle.metric import Metric from paddle.static import InputSpec from paddle.fluid import core -from paddle.fluid import program_guard +from paddle.fluid import Variable from paddle.fluid.layers.utils import flatten from paddle.fluid.executor import global_scope, _to_name_str -from paddle.fluid.backward import append_backward from paddle.fluid.framework import Operator, Parameter, _non_static_mode from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.distributed import fleet -from paddle.distributed.passes import new_pass, PassContext +from .converter import Converter from .helper import ProgramHelper -from ..collective import _get_global_env from .cluster import Cluster, get_default_cluster from .planner_v2 import Planner from .parallelizer_v2 import Parallelizer from .dist_op import DistributedOperator from .dist_saver import DistributedSaver from .dist_loader import NonIterableGeneratorLoader -from .utils import make_data_unshard, set_grad_var_shape from .utils import print_program_with_dist_attr, to_list -from .process_group import new_process_group, get_all_process_groups, get_world_process_group +from .utils import get_logger, get_dist_attr +from .process_group import new_process_group, get_all_process_groups from .dist_context import DistributedContext, get_default_distributed_context +from .strategy import Strategy +from .interface import _get_fetches class Engine: + """ + An Engine object can provide the full power of auto parallel to users. + With the help of it, users can easily obtain the abilities of the + distributed training and inference. It also support the dynamic graph and + static graph at the same time. + + Args: + model (paddle.nn.Layer, optional): The model is an instance of + paddle.nn.Layer. + loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer` + instance or any callable function taken the predicted values and + ground truth values as input. It can be None when there is no loss. + Default: None. + optimizer (Optimizer|None, optional): The optimizer need to be set in training + and should be None in eval and predict mode. Default: None. + metrics (Metric|list[Metric]|None, optional): If metrics is set, all + metrics will be calculated and output in train/eval mode. Default: None. + cluster (Cluster|None, optional): The cluster represents the topology information + about the used physical devices. Default: None. (Unused for now) + strategy (Strategy|None, optional): The strategy is used to configure the + parallelization and optimization behaviors. Default: None. + + Examples: + + .. code-block:: python + + import paddle + import paddle.vision.transforms as T + import paddle.distributed.auto_parallel as auto + from paddle.vision.datasets import MNIST + + transform = T.Compose([ + T.Transpose(), + T.Normalize([127.5], [127.5]) + ]) + train_dataset = MNIST(mode='train', transform=transform) + valid_dataset = MNIST(mode='test', transform=transform) + + model = paddle.vision.models.LeNet() + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.optimizer.Adam( + learning_rate=0.001, parameters=model.parameters()) + metrics = paddle.metric.Accuracy(topk=(1, 2)) + + engine = auto.Engine(model, loss, optimizer, metrics) + # fit + engine.fit(train_dataset, + epochs=2, + batch_size=64) + # evaluate + engine.evaluate(valid_dataset, + batch_size=64) + # predict + engine.predict(valid_dataset, + batch_size=64) + # save + engine.save("./my_model") + # load + engine.load("./my_model") + + """ def __init__(self, model=None, - inputs_spec=None, - labels_spec=None, + loss=None, + optimizer=None, + metrics=None, cluster=None, - strategy=None, - user_tuning_config=None): - self.model = model - self.strategy = strategy or fleet.DistributedStrategy() - self.inputs_spec = self._validate_spec(inputs_spec) - self.labels_spec = self._validate_spec(labels_spec) - self.cluster = cluster or get_default_cluster() - self._user_tuning_config = user_tuning_config + strategy=None): + + if model and not isinstance(model, + paddle.nn.Layer) and not callable(model): + raise TypeError( + "'model must be sub classes of `paddle.nn.Layer` or any callable function." + ) + self._model = model + + if loss and not isinstance(loss, + paddle.nn.Layer) and not callable(loss): + raise TypeError( + "'loss' must be sub classes of `paddle.nn.Layer` or any callable function." + ) + self._loss = loss + + if optimizer and not isinstance( + optimizer, + (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)): + raise TypeError( + "'optimizer' must be object of class `paddle.optimizer.Optimizer`" + " or `paddle.fluid.optimizer.Optimizer`.") + self._optimizer = self._validate_opt(optimizer) + + metrics = metrics or [] + for metric in to_list(metrics): + assert isinstance(metric, Metric), \ + "{} is not sub class of Metric".format( + metric.__class__.__name__) + self._metrics = to_list(metrics) + + if cluster and not isinstance(cluster, Cluster): + raise TypeError( + "'cluster' must be the object or class `paddle.distributed.auto_parallel.Cluster`" + ) + self._cluster = cluster or get_default_cluster() + + if strategy and not isinstance(strategy, Strategy): + raise TypeError( + "'strategy' must be object of class `paddle.distributed.auto_parallel.Strategy`" + ) + self._strategy = strategy or Strategy() + + if os.getenv("POD_NAME"): + print("Distribute training by paddle.distributed.launch", + flush=True) + fleet.init(is_collective=True) self._executor = None self._cur_rank = paddle.distributed.get_rank() self._nranks = paddle.distributed.get_world_size() self._saver = DistributedSaver() - # TODO: add logger module - self._logger = logging.getLogger() - self._logger.propagate = False - if not self._logger.handlers: - self._logger.setLevel(logging.INFO) - log_handler = logging.StreamHandler() - log_format = logging.Formatter( - '[%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s' - ) - log_handler.setFormatter(log_format) - self._logger.addHandler(log_handler) + self._logger = get_logger(logging.INFO) self._orig_main_prog = static.default_main_program() self._orig_startup_prog = static.default_startup_program() @@ -99,54 +192,18 @@ class Engine: "eval": False, "predict": False } - self._dygraph_mode = False - - def prepare(self, - optimizer=None, - loss=None, - gradient_scale=True, - metrics=None, - all_ranks=False): - if optimizer and not isinstance( - optimizer, - (paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)): - raise TypeError( - "'optimizer' must be object of class `paddle.optimizer.Optimizer`" \ - " or `paddle.fluid.optimizer.Optimizer`." - ) - self._optimizer = self._validate_opt(optimizer) - - if loss and not isinstance(loss, - paddle.nn.Layer) and not callable(loss): - raise TypeError( - "'loss' must be sub classes of `paddle.nn.Layer` or any callable function." - ) - self._loss = loss - metrics = metrics or [] - for metric in to_list(metrics): - assert isinstance(metric, Metric), \ - "{} is not sub class of Metric".format( - metric.__class__.__name__) - self._metrics = to_list(metrics) - self._gradient_scale = gradient_scale self._planned_mode = None - self._all_ranks = all_ranks - self._prepare_single_mode("train") + self._dygraph_mode = False + self._tuning = self._strategy.tuning def _prepare_single_mode(self, mode): - + # Do the build process self._build(mode) # Do the planning process self._plan(mode) - - # Do the Optimization tuning - if self._user_tuning_config and mode == "train": - self._optimization_tuning(mode) - # Do the parallel process - self._parallel(mode, self._all_ranks) - + self._parallel(mode) # Init comm and startup program self._initialize(mode) self._mode_init_states[mode] = True @@ -159,7 +216,7 @@ class Engine: inputs_spec = self.inputs_spec labels_spec = self.labels_spec if self.labels_spec else [] - self.program_helper = ProgramHelper(self.model, self._loss, + self.program_helper = ProgramHelper(self._model, self._loss, self._metrics, inputs_spec, labels_spec) # build forward main program @@ -186,14 +243,13 @@ class Engine: metrics = [] serial_main_prog = self._orig_main_prog.clone() serial_startup_prog = self._orig_startup_prog.clone() - # FIXME to support grad clip with static.program_guard(serial_main_prog, serial_startup_prog), \ utils.unique_name.guard(): inputs_spec = self.inputs_spec labels_spec = self.labels_spec if self.labels_spec else [] inputs = [s._create_feed_layer() for s in inputs_spec] labels = [s._create_feed_layer() for s in labels_spec] - outputs = to_list(self.model(*inputs)) + outputs = to_list(self._model(*inputs)) if mode != "predict" and self._loss: losses = to_list(self._loss(*(outputs + labels))) @@ -217,25 +273,30 @@ class Engine: "metrics": metrics } + if mode != "train": + serial_main_prog = serial_main_prog.clone(for_test=True) + self._set_recompute_ckpts() self._dist_contexts[mode] = DistributedContext( serial_main_prog, serial_startup_prog, self._optimizer, losses, - feed_vars, fetch_vars, self.cluster, self.strategy) - self._dist_contexts[mode].gradient_scale = self._gradient_scale - self._dist_contexts[mode]._dygraph_mode = self._dygraph_mode + feed_vars, fetch_vars, self._cluster, self._strategy) + self._dist_contexts[mode].gradient_scale = self._strategy.gradient_scale - def _optimization_tuning(self, mode): + def _optimization_tuning(self, mode, dataset, batch_size): + if not self._tuning.enable: + raise ValueError("Please set `tuning.enable=True`.") - self.mode = mode - assert "batch_size" in self._user_tuning_config, "Optimization Tuning should provide with batch size." - assert "dataset" in self._user_tuning_config, "Optimization Tuning should provide with dataset." - batch_size = self._user_tuning_config["batch_size"] - dataset = self._user_tuning_config["dataset"] - dataset.dp_world_size = self.dp_world_sizes - dataset.dp_rank = self.dp_ranks + assert mode == "train" + # Do the build process + self._build(mode) + # Do the planning process + self._plan(mode) + + dataset.dp_world_size = self._dp_world_sizes + dataset.dp_rank = self._dp_ranks from .tuner.optimization_tuner import OptimizationTuner - self._optimization_tuner = OptimizationTuner(self._user_tuning_config, + self._optimization_tuner = OptimizationTuner(self._tuning.to_dict(), self._dist_contexts[mode], dataset, self.inputs_spec, @@ -245,12 +306,10 @@ class Engine: self._optimization_tuner.tune() - if self._user_tuning_config["run_after_tuning"]: + if self._tuning.run_after_tuning: # update the strategy self._dist_contexts[ mode]._strategy = self._optimization_tuner.get_best_config() - else: - return def _plan(self, mode): if self._planned_mode is None: @@ -270,15 +329,15 @@ class Engine: if var.name in block.vars: feed_list.append(block.vars[var.name]) - self.dp_world_sizes = [] - self.dp_ranks = [] + self._dp_world_sizes = [] + self._dp_ranks = [] for feed_var in feed_list: dp_world_size, dp_rank = self._get_input_split_info( feed_var, self._dist_contexts[mode]) - self.dp_world_sizes.append(dp_world_size) - self.dp_ranks.append(dp_rank) + self._dp_world_sizes.append(dp_world_size) + self._dp_ranks.append(dp_rank) - def _parallel(self, mode, all_ranks): + def _parallel(self, mode, all_ranks=False): # Parallelize program based on the planner's results # For now, the completer has to be passed to the planner, # because we may use it to complete the annotation of the backwarkward and update. @@ -336,6 +395,11 @@ class Engine: if isinstance(place, fluid.CUDAPlace): place = fluid.CUDAPlace(ParallelEnv().dev_id) + if self._strategy.seed: + paddle.seed(self._strategy.seed + self._dp_ranks[0]) + np.random.seed(self._strategy.seed + self._dp_ranks[0]) + random.seed(self._strategy.seed + self._dp_ranks[0]) + if self._dygraph_mode: dist_context = self._dist_contexts[mode] dist_main_program = self._dist_main_progs[mode][self._cur_rank] @@ -354,102 +418,299 @@ class Engine: prune_startup_prog = dist_startup_prog._prune(uninitialized) self._executor.run(prune_startup_prog) + if hasattr(self, "_state_dict") and hasattr(self, "_dist_attr"): + self._set_state_dict(mode, self._strict, self._state_dict, + self._dist_attr) + + if self._strategy.reinit: + self._logger.info("NOTE: parameters wiil be re-initialized.") + dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] + self._executor.run(dist_startup_prog) + + def _infer_sample_spec(self, data, batch_size, split): + if isinstance(data, paddle.io.IterableDataset): + if split is None: + input, label = next(iter(data)) + else: + sample = next(iter(data)) + input = sample[:split] + label = sample[split:] + elif isinstance(data, paddle.io.Dataset): + if split is None: + input, label = data[0] + else: + sample = data[0] + input = sample[:split] + label = sample[split:] + else: + raise ValueError( + "Data should be a Dataset or IterableDatset, but received {}.". + format(type(data).__name__)) + + self.inputs_spec = [] + self.labels_spec = [] + input_list = to_list(input) + label_list = to_list(label) + + def _infer_item_spec(item, name, batch_size, specs): + if isinstance(item, np.ndarray): + spec = InputSpec.from_numpy(item, name) + if batch_size is None: + specs.append(spec) + else: + specs.append(spec.batch(batch_size)) + elif isinstance(item, (Variable, core.VarBase, core.eager.Tensor)): + spec = InputSpec.from_tensor(item, name) + if batch_size is None: + specs.append(spec) + else: + specs.append(spec.batch(batch_size)) + else: + specs.append(InputSpec([batch_size], type(item), name)) + + if input_list is not None: + for i, item in enumerate(input_list): + assert item is not None, "Receive None input." + name = "input" + str(i) + _infer_item_spec(item, name, batch_size, self.inputs_spec) + if label_list is not None: + for i, item in enumerate(label_list): + assert item is not None, "Receive None input." + name = "label" + str(i) + _infer_item_spec(item, name, batch_size, self.labels_spec) + + self.inputs_spec = self._validate_spec(self.inputs_spec) + self.labels_spec = self._validate_spec(self.labels_spec) + def fit(self, train_data, + train_sample_split=None, batch_size=1, epochs=1, - fetches=None, steps_per_epoch=None, + valid_data=None, + valid_sample_split=None, + valid_freq=1, + valid_steps=None, collate_fn=None, - use_cache=False, - return_numpy=True): - # TODO: callbacks - # TODO: evaluate after training - - if not self._mode_init_states['train']: - raise Exception( - "train program is not initialized yet, please call engine.prepare() before calling fit() funtion." - ) - + callbacks=None): + """ + Trains the model for a fixed number of epochs. If `valid_data` is set, + evaluation will be done at the end of each epoch. + + Args: + train_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None. + train_sample_split (int, optional): Each sample of the train dataset is assumed + to be a (input, label) pair by default and has two items. If each sample has + more than two items, train_sample_split specifies how to split these items into + input and label. The items before it are input and the left are label. Default: None. + batch_size (int, optional): The batch size of train_data and valid_data if provided. + The user's data will be used directly without batching if set to None. Default: 1. + epochs (int, optional): The number of epochs to train the model. Default: 1. + steps_per_epoch (int, optional): The total number of steps (batches of samples) + is executed in one epoch before stating the next one. If None, it is equal to + the number samples in your dataset divided by the batch size. Default: None. + valid_data (Dataset, optional): An instance of paddle paddle.io.Dataset used for + evaluation at the end of epoch. No evaluation will be done if set to None. + Default: None. (Unsupported for now) + valid_freq (int, optional): Only relevant if valid_data is provided. This specifies + how many training epochs before a new evaluation is performed. Default: 1. + valid_sample_split (int, optional): Only relevant if valid_data is provided. + Each sample of the valid dataset is assumed to be a (input, label) pair + by default and has two items. If each sample has more than two items, + valid_sample_split specifies how to split these items into input and label. + The items before it are input and the left are label. Default: None. + valid_steps (int, optional): Only relevant if valid_data is provided. + It is the total number of steps (batches of samples) to draw before + stopping validation at the end of every epoch. If None, validation will run until the + `valid_data` dataset is exhausted. The validation will start from the + beginning of the dataset at each epoch. Default: None. + collate_fn(callable, optional): function to generate mini-batch data by merging + the sample list, None for only stack each fields of sample in axis + 0. Default None. + callbacks (Callback|None, optional): A list of `Callback` instances to apply + during training. Default: None. (Unused for now) + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle + import paddle.vision.transforms as T + import paddle.distributed.auto_parallel as auto + from paddle.vision.datasets import MNIST + + transform = T.Compose([ + T.Transpose(), + T.Normalize([127.5], [127.5]) + ]) + train_dataset = MNIST(mode='train', transform=transform) + + model = paddle.vision.models.LeNet() + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.optimizer.Adam( + learning_rate=0.001, parameters=model.parameters()) + metrics = paddle.metric.Accuracy(topk=(1, 2)) + + engine = auto.Engine(model, loss, optimizer, metrics) + engine.fit(train_dataset, + epochs=2, + batch_size=64) + """ self.mode = 'train' + self._infer_sample_spec(train_data, batch_size, train_sample_split) + if not self._mode_init_states[self.mode]: + self._prepare_single_mode(self.mode) + assert self.mode in self._dist_main_progs, \ "train model is not ready, please call `engine.prepare()` first." train_dataloader = self._create_dataloader(train_data, batch_size, epochs, steps_per_epoch, collate_fn) - usr_fetch = self._validate_fetches(fetches) fetch_loss = self._validate_fetches(self.fetch_vars["loss"]) - fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch) - lr_scheduler = self.get_lr_scheduler(self.main_program) + fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"]) + inner_fetch = dict(fetch_loss, **fetch_metrics) + usr_fetch = self._validate_fetches(_get_fetches()) + fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch) + lr_scheduler = self._get_lr_scheduler(self.main_program) + outputs = defaultdict(list) for epoch in range(epochs): train_logs = {"epoch: {:d} ": epoch} for step, _ in enumerate(train_dataloader): try: - outs = self._executor.run(self.main_program, - fetch_list=fetch_list, - use_program_cache=use_cache, - return_numpy=return_numpy) - except fluid.core.EOFException: + outs = self._executor.run( + self.main_program, + fetch_list=fetch_list, + use_program_cache=self._strategy.use_cache, + return_numpy=self._strategy.return_numpy) + except core.EOFException: break - train_logs["step: {:d} "] = step - if lr_scheduler is not None and step % self.k_steps == 0: + # update lr + if lr_scheduler and step % self._k_steps == 0: lr_scheduler.step() - try: - train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr() - except: - train_logs[ - "lr: {:5e} "] = self._lr_optimizer._learning_rate.get_lr( - ) + train_logs["lr: {:5e} "] = self._get_lr(self._lr_optimizer) # inner fetches if fetch_loss: - train_logs["loss: {:9f} "] = outs[0][0] + train_logs["loss: {:8f} "] = outs[0][0] + outputs["loss"].append(outs[0][0]) + # Metric + if fetch_metrics: + metric_out = outs[len(fetch_loss):len(inner_fetch)] + for metric in self._metrics: + metric.update(*metric_out) + results = metric.accumulate() + for i, res in enumerate(to_list(results)): + train_logs[metric.name()[i] + ": {:8f} "] = res + outputs[metric.name()[i]].append(outs[0][0]) # user fetches - user_outs = outs[len(fetch_loss):] - user_fetch_list = fetch_list[len(fetch_loss):] + user_outs = outs[len(inner_fetch):] + user_fetch_list = fetch_list[len(inner_fetch):] for i, out in enumerate(user_outs): train_logs[fetch_map[user_fetch_list[i]] + ": {}"] = out # logger string = '[train] ' + ''.join(list(train_logs.keys())) self._logger.info(string.format(*list(train_logs.values()))) + if valid_data and epoch % valid_freq == 0: + self.evaluate(valid_data, valid_sample_split, batch_size, + valid_steps, collate_fn, callbacks) + self._switch_mode("train") + + self._reset_metrics() + return outputs + def evaluate(self, - eval_data, + valid_data, + valid_sample_split=None, batch_size=1, - fetches=None, + steps=None, collate_fn=None, - use_cache=False, - return_numpy=True): + callbacks=None): + """ + Evaluate the loss and metrics of the model on evaluation data. + + Args: + eval_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None. + eval_sample_split (int, optional): Each sample of the eval dataset is assumed + to be a (input, label) pair by default and has two items. If each sample has + more than two items, eval_sample_split specifies how to split these items into + input and label. The items before it are input and the left are label. Default: None. + batch_size (int, optional): The batch size of eval_data. The user's data will + be used directly without batching if set to None. Default: 1. + steps (int, optional): It is the total number of steps (batches of samples) to draw before + stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted. + The evaluation will start from the beginning of the dataset in each run. Default: None. + collate_fn(callable, optional): function to generate mini-batch data by merging + the sample list, None for only stack each fields of sample in axis + 0. Default None. + callbacks (Callback|None, optional): A list of `Callback` instances to apply + during evaling. Default: None. (Unused for now) + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle + import paddle.vision.transforms as T + import paddle.distributed.auto_parallel as auto + from paddle.vision.datasets import MNIST + + transform = T.Compose([ + T.Transpose(), + T.Normalize([127.5], [127.5]) + ]) + valid_dataset = MNIST(mode='test', transform=transform) + + model = paddle.vision.models.LeNet() + loss = paddle.nn.CrossEntropyLoss() + metrics = paddle.metric.Accuracy(topk=(1, 2)) + + engine = auto.Engine(model, loss, metrics=metrics) + engine.evaluate(valid_dataset, batch_size=64) + + """ self.mode = 'eval' + self._infer_sample_spec(valid_data, batch_size, valid_sample_split) if not self._mode_init_states[self.mode]: self._prepare_single_mode(self.mode) assert self.mode in self._dist_main_progs, \ "eval model is not ready, please call `engine.prepare()` first." - eval_dataloader = self._create_dataloader(eval_data, - batch_size, - collate_fn=collate_fn) + valid_dataloader = self._create_dataloader(valid_data, + batch_size, + steps_per_epoch=steps, + collate_fn=collate_fn) - usr_fetch = self._validate_fetches(fetches) fetch_loss = self._validate_fetches(self.fetch_vars["loss"]) fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"]) inner_fetch = dict(fetch_loss, **fetch_metrics) + usr_fetch = self._validate_fetches(_get_fetches()) fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch) - for step, _ in enumerate(eval_dataloader): - eval_logs = {"step: {:d} ": step} + outputs = defaultdict(list) + for step, _ in enumerate(valid_dataloader): try: - outs = self._executor.run(self.main_program, - fetch_list=fetch_list, - use_program_cache=use_cache, - return_numpy=return_numpy) - except fluid.core.EOFException: + outs = self._executor.run( + self.main_program, + fetch_list=fetch_list, + use_program_cache=self._strategy.use_cache, + return_numpy=self._strategy.return_numpy) + except core.EOFException: break + eval_logs = {"step: {:d} ": step} # inner fetches if fetch_loss: - eval_logs["loss: {:9f} "] = outs[0][0] + eval_logs["loss: {:8f} "] = outs[0][0] + outputs["eval_loss"].append(outs[0][0]) # Metric if fetch_metrics: metric_out = outs[len(fetch_loss):len(inner_fetch)] @@ -457,8 +718,9 @@ class Engine: metric.update(*metric_out) results = metric.accumulate() for i, res in enumerate(to_list(results)): - eval_logs[metric.name()[i] + ": {:9f} "] = res - # usr fetches + eval_logs[metric.name()[i] + ": {:8f} "] = res + outputs["eval_" + metric.name()[i]].append(res) + # user fetches usr_outs = outs[len(inner_fetch):] usr_fetch_list = fetch_list[len(inner_fetch):] for i, out in enumerate(usr_outs): @@ -466,15 +728,61 @@ class Engine: # logger string = '[eval] ' + ''.join(list(eval_logs.keys())) self._logger.info(string.format(*list(eval_logs.values()))) + self._reset_metrics() + return outputs def predict(self, test_data, + test_sample_split=None, batch_size=1, - fetches=None, + steps=None, collate_fn=None, - use_cache=False, - return_numpy=True): + callbacks=None): + """ + Compute the output predictions on testing data. + + Args: + test_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None. + test_sample_split (int, optional): Each sample of the test dataset is assumed + to be a (input, label) pair by default and has two items. If each sample has + more than two items, test_sample_split specifies how to split these items into + input and label. The items before it are input and the left are label. Default: None. + batch_size (int, optional): The batch size of test_data. The user's data will + be used directly without batching if set to None. Default: 1. + steps (int, optional): It is the total number of steps (batches of samples) to draw before + stopping predict. If None, predict will run until the `test_data` dataset is exhausted. + The predict will start from the beginning of the dataset in each run. Default: None. + collate_fn(callable, optional): function to generate mini-batch data by merging + the sample list, None for only stack each fields of sample in axis + 0. Default None. + callbacks (Callback|None, optional): A list of `Callback` instances to apply + during testing. Default: None. (Unused for now) + + Returns: + None + + Examples: + + .. code-block:: python + + import paddle + import paddle.vision.transforms as T + import paddle.distributed.auto_parallel as auto + from paddle.vision.datasets import MNIST + + transform = T.Compose([ + T.Transpose(), + T.Normalize([127.5], [127.5]) + ]) + valid_dataset = MNIST(mode='test', transform=transform) + + model = paddle.vision.models.LeNet() + + engine = auto.Engine(model) + engine.predict(valid_dataset, batch_size=64) + """ self.mode = 'predict' + self._infer_sample_spec(test_data, batch_size, test_sample_split) if not self._mode_init_states[self.mode]: self._prepare_single_mode(self.mode) @@ -482,22 +790,24 @@ class Engine: "predict model is not ready, please call `engine.prepare()` first." test_dataloader = self._create_dataloader(test_data, batch_size, + steps_per_epoch=steps, collate_fn=collate_fn) - usr_fetch = self._validate_fetches(fetches) fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"]) + usr_fetch = self._validate_fetches(_get_fetches()) fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch) outputs = [] for step, _ in enumerate(test_dataloader): - predict_logs = {"step: {:d} ": step} try: - outs = self._executor.run(self.main_program, - fetch_list=fetch_list, - use_program_cache=use_cache, - return_numpy=return_numpy) - except fluid.core.EOFException: + outs = self._executor.run( + self.main_program, + fetch_list=fetch_list, + use_program_cache=self._strategy.use_cache, + return_numpy=self._strategy.return_numpy) + except core.EOFException: break + predict_logs = {"step: {:d} ": step} outputs.append(outs[:len(fetch_outputs)]) for i, out in enumerate(outs): predict_logs[fetch_map[fetch_list[i]] + ": {}"] = out @@ -507,6 +817,11 @@ class Engine: return outputs + def _tune(self, tune_data, tune_sample_split=None, batch_size=1): + self.mode = 'train' + self._infer_sample_spec(tune_data, batch_size, tune_sample_split) + self._optimization_tuning(self.mode, tune_data, batch_size) + def _create_dataloader(self, dataset, batch_size, @@ -514,10 +829,10 @@ class Engine: steps_per_epoch=None, collate_fn=None): - if self.strategy.gradient_merge and batch_size is not None: - assert batch_size % self.k_steps == 0, \ - "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self.k_steps) - batch_size //= self.k_steps + if self._strategy.gradient_merge and batch_size is not None: + assert batch_size % self._k_steps == 0, \ + "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self._k_steps) + batch_size //= self._k_steps dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank] dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank] @@ -557,9 +872,9 @@ class Engine: epochs, steps_per_epoch, collate_fn, - data_parallel_world_size=self.dp_world_sizes, - data_parallel_rank=self.dp_ranks, - split_data=self.strategy.split_data) + data_parallel_world_size=self._dp_world_sizes, + data_parallel_rank=self._dp_ranks, + split_data=self._strategy.split_data) # move read op from the end of program to the start of program new_op_size = len(dist_main_block.ops) @@ -580,9 +895,7 @@ class Engine: def _validate_spec(self, specs): specs = to_list(specs) - self.k_steps = 1 - if self.strategy.gradient_merge: - self.k_steps = self.strategy.gradient_merge_configs['k_steps'] + self._k_steps = self._strategy.gradient_merge.k_steps if specs is not None: for i, spec in enumerate(specs): assert isinstance(spec, InputSpec) @@ -590,11 +903,11 @@ class Engine: raise ValueError( "Requires Input[{}].name != None, but receive `None` with {}." .format(i, spec)) - if self.k_steps > 1: + if self._k_steps > 1: shape = list(spec.shape) - assert shape[0] % self.k_steps == 0, \ - "Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self.k_steps) - shape[0] //= self.k_steps + assert shape[0] % self._k_steps == 0, \ + "Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self._k_steps) + shape[0] //= self._k_steps spec.shape = shape return specs @@ -655,38 +968,95 @@ class Engine: # NOTE hack to enable recompute in engine api for GPT-3 # TODO support more PaddleNLP/CV models here - config = self.strategy.recompute_configs + recompute = self._strategy.recompute # extract ckpts by specific model - if isinstance(self.model, paddle.nn.Layer): + if isinstance(self._model, paddle.nn.Layer): if hasattr( - self.model, "gpt" - ) and self.model.__class__.__name__ == 'GPTForPretraining': - exact_ckpts = self.model.gpt.checkpoints + self._model, "gpt" + ) and self._model.__class__.__name__ == 'GPTForPretraining': + exact_ckpts = self._model.gpt.checkpoints else: - exact_ckpts = config["checkpoints"] + exact_ckpts = recompute.checkpoints else: - exact_ckpts = config["checkpoints"] + exact_ckpts = recompute.checkpoints # modify strategy - if self.strategy.recompute: - config["checkpoints"] = exact_ckpts[:] - self.strategy.recompute_configs = config + if recompute.enable: + recompute.checkpoints = exact_ckpts[:] logs = { - 'Model Class': self.model.__class__.__name__, + 'Model Class': self._model.__class__.__name__, 'Applied Recompute ckpts': exact_ckpts } self._logger.info(logs) def _validate_opt(self, optimizer): - optimizer._parameter_list = None - optimizer._param_groups = None + if optimizer is not None: + optimizer._parameter_list = None + optimizer._param_groups = None return optimizer - def save(self, path, training=True, mode=None): - if not mode: - mode = self.mode + def _reset_metrics(self): + for metric in self._metrics: + metric.reset() + + def _switch_mode(self, mode): + self.mode = mode + self._initialize(mode) + + def _set_state_dict(self, mode, strict, state_dict, dist_attr): + program = self._dist_main_progs[mode][self._cur_rank] + dist_context = self._dist_contexts[mode] + cur_dist_attr = get_dist_attr(program, dist_context) + converter = Converter(state_dict, dist_attr, cur_dist_attr) + state_dict = converter.convert(strict=strict) + program.set_state_dict(state_dict) + + def save(self, path, training=True): + """ + Saves the model, parameters, optimizer state to path. + If `training` is set to False, only inference model will be saved. + + Args: + path (str): The file prefix to save model. The format + is 'dirname/file_prefix' or 'file_prefix'. if empty str. + A exception will be raised. + training (bool, optional): Whether to save for training. If not, save + for inference only. If `training` is set to True, the optimzer state + will be saved. Otherwise, only the model and parameters are saved. + This function will silently overwrite existing file at the target + location. Default: True. + + Returns: + None + + Examples: + + .. code-block:: python + import paddle + import paddle.vision.transforms as T + import paddle.distributed.auto_parallel as auto + from paddle.vision.datasets import MNIST + + transform = T.Compose([ + T.Transpose(), + T.Normalize([127.5], [127.5]) + ]) + train_dataset = MNIST(mode='train', transform=transform) + + model = paddle.vision.models.LeNet() + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.optimizer.Adam( + learning_rate=0.001, parameters=model.parameters()) + metrics = paddle.metric.Accuracy(topk=(1, 2)) + + engine = auto.Engine(model, loss, optimizer, metrics) + engine.fit(train_dataset, + epochs=1, + batch_size=64) + engine.save("./my_model") + """ if training: assert 'train' in self._serial_main_progs, \ "training model is not ready, please call `engine.prepare()` first." @@ -698,7 +1068,7 @@ class Engine: dist_main_program=dist_main_prog, dist_context=dist_context) else: - assert mode, "Please set the 'mode' you want to save." + mode = "predict" feed_vars = self._feed_vars[mode]['inputs'] fetch_vars = self._fetch_vars[mode]['outputs'] dist_main_prog = self._dist_main_progs[mode][self._cur_rank] @@ -708,18 +1078,59 @@ class Engine: self._executor, program=dist_main_prog) - def load(self, path, strict=True, load_optimizer=True, mode=None): - if not mode: - mode = self.mode - assert mode, "Please set the 'mode' you want to load." + def load(self, path, strict=True, load_optimizer=True): + """ + Load the stored model, parameters and optimizer states. + + Args: + path (str): The prefix of files storing the model states and + optimizer states. + strict (bool, optional): Whether to skip the loading of mismatch + parameter or raise an error when mismatch happens (not found + the parameter in file storing model states of or receives a + mismatch shape). Default: False. + load_optimizer (bool, optional): If True, the stored optimizer + states is restored. Otherwise, the optimizer states is intialized + from scratch. Default: False. + + Returns: + None + + Examples: + + .. code-block:: python + import paddle + import paddle.vision.transforms as T + import paddle.distributed.auto_parallel as auto + from paddle.vision.datasets import MNIST + + transform = T.Compose([ + T.Transpose(), + T.Normalize([127.5], [127.5]) + ]) + train_dataset = MNIST(mode='train', transform=transform) + + model = paddle.vision.models.LeNet() + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.optimizer.Adam( + learning_rate=0.001, parameters=model.parameters()) + metrics = paddle.metric.Accuracy(topk=(1, 2)) + + engine = auto.Engine(model, loss, optimizer, metrics) + engine.fit(train_dataset, + epochs=1, + batch_size=64) + engine.save("./my_model") + engine.load("./my_model") - dist_main_prog = self._dist_main_progs[mode][self._cur_rank] - dist_context = self._dist_contexts[mode] - self._saver.load(path, dist_main_prog, dist_context, strict, - load_optimizer) + """ + self._strict = strict + self._state_dict, self._dist_attr = self._saver.load( + path, load_optimizer) + return self._state_dict, self._dist_attr @staticmethod - def get_lr_scheduler(program): + def _get_lr_scheduler(program): lr_sheduler = None if hasattr(program, 'lr_sheduler'): from paddle.optimizer.lr import LRScheduler @@ -727,6 +1138,20 @@ class Engine: assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler" return lr_sheduler + def _get_lr(self, optimizer): + if isinstance(optimizer, paddle.optimizer.Optimizer): + return optimizer.get_lr() + elif isinstance(optimizer, paddle.fluid.optimizer.Optimizer): + if isinstance(optimizer._learning_rate, float): + return optimizer._learning_rate + else: + return optimizer._learning_rate() + else: + raise TypeError( + "'optimizer' must be object of class `paddle.optimizer.Optimizer`" \ + " or `paddle.fluid.optimizer.Optimizer`, but got {}.".format(type(optimizer)) + ) + @property def mode(self): return self._mode @@ -758,3 +1183,11 @@ class Engine: @property def fetch_vars(self): return self._fetch_vars[self.mode] + + @property + def inputs(self): + return self.inputs_spec + + @property + def labels(self): + return self.labels_spec diff --git a/python/paddle/distributed/auto_parallel/helper.py b/python/paddle/distributed/auto_parallel/helper.py index 7a17ba65414cec96c4f5b2c33e7834b77bfbcbd7..6bc177efc9de99f0afa9012ed1a5802786a3b660 100644 --- a/python/paddle/distributed/auto_parallel/helper.py +++ b/python/paddle/distributed/auto_parallel/helper.py @@ -19,13 +19,13 @@ import paddle from paddle.nn import Layer from paddle.jit import to_static, not_to_static -from paddle.distributed.utils import get_logger from paddle.fluid.framework import Operator, Parameter, _non_static_mode from paddle.fluid.framework import program_guard from paddle.fluid.executor import global_scope from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction from .utils import to_list +from .utils import get_logger from .converter import Converter diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py index e06120a7e19d07e0cc51c8ff46e8d2cee733a687..ad3078c449048e17cb99b23ce9db93f440bc5d03 100644 --- a/python/paddle/distributed/auto_parallel/interface.py +++ b/python/paddle/distributed/auto_parallel/interface.py @@ -12,101 +12,198 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy -import copy -import paddle -import paddle.fluid.core as core -from paddle.fluid.framework import Variable -from paddle.fluid.framework import _non_static_mode +from paddle.fluid import core +from .process_mesh import ProcessMesh +from .process_mesh import get_current_process_mesh +from .process_mesh import set_current_process_mesh +from .process_mesh import reset_current_process_mesh from .dist_context import get_default_distributed_context from .dist_tensor import DistributedTensor -from .dist_op import DistributedModule -from .dist_attribute import TensorDistributedAttribute -from .dist_attribute import OperatorDistributedAttribute +from .dist_op import DistributedOperatorHelper +from .utils import verify_shard_spec, convert_to_dims_mapping -def _static_mode_check(): - if _non_static_mode(): - raise RuntimeError("Auto-parallel only supports static mode for now, " - "please use paddle.enable_static() first.") - - -def shard_tensor(x, dist_attr=None): +def shard_tensor(x, process_mesh=None, shard_spec=None): """ - Add distributed attributes for a tensors. + Shard a tensor on a process mesh according to the shard specification. Args: x (Tensor): the tensor to be sharded. - dist_attr (dict): the tensor distributed attributes. The accepted attributes are as follow: - "process_mesh": a nested list an to describe the mesh topology of logical processes. - "dims_mapping": a list to describe the mapping between `x` and `process_mesh`, the dimension - `i` of `x` is split across the dimension `dims_mapping[i]` of `process_mesh`, - where -1 means that tensor dimension is not split. - Both process_mesh and dims_mapping are optional and users can specify as need. + process_mesh (ProcessMesh, optional): An instance of ProcessMesh describes a mesh + topology of the used logical processes where the tensor is sharded. If it is None, + the found current process mesh will be used. And an error will be raised if the + current process mesh cannot be found. Default: None. + shard_spec (list, optional): a list to describe the sharding mapping between `x` and `process_mesh`, + which means the dimension `i` of `x` is split across the dimension `shard_spec[i]` of `process_mesh`, + where `None` means that tensor dimension is not split. For example, given a tensor wih + the shape [6, 12] and a process mesh with the shape [2, 3] and the dimension names ["x", "y"]: + If `shard_spec=["x", "y"]`, each shard of the tensor will have a shape [3, 4]; + If `shard_spec=["y", "x"]`, each shard of the tensor will have a shape [2, 6]; + If `shard_spec=["x", None]`, each shard of the tensor will have a shape [3, 12]; + If `shard_spec=[None, "x"]`, each shard of the tensor will have a shape [6, 4]; + If `shard_spec=["y", None]`, each shard of the tensor will have a shape [2, 12]; + If `shard_spec=[None, "y"]`, each shard of the tensor will have a shape [6, 4]; + If `shard_spec=[None, None]`, each shard of the tensor will have a shape [6, 12]; + If the `shard_spec` is None, the tensor will be replicated across all the processes of `process_mesh`. + In the above example, the `shard_spec=None` is same as 'shard_spec=[None, None]'. Defaults: None. Returns: - Tensor: the tensor `x` annotated with distributed attributes. + Tensor: the tensor `x` annotated with sharding information. Examples: .. code-block:: python import paddle - import paddle.distributed as dist - - paddle.enable_static() + import paddle.distributed.auto_parallel as auto + mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]) x = paddle.ones([4, 6]) - dist.shard_tensor(x, dist_attr={"process_mesh": [[0, 1], [2, 3]], - "dims_mapping": [0, -1]}) + shard_spec = ["x", "y"] + auto.shard_tensor(x, mesh, shard_spec) """ - _static_mode_check() - assert dist_attr is None or isinstance(dist_attr, (dict, TensorDistributedAttribute)), \ - "The type of dist_attr must be None, dict or TensorDistributedAttribute." - dist_tensor = DistributedTensor(x, dist_attr) - dist_tensor.dist_attr.mark_annotated_as(dist_attr) + + if process_mesh is not None: + assert isinstance(process_mesh, ProcessMesh), \ + "Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh) + else: + process_mesh = get_current_process_mesh() + assert process_mesh is not None, \ + "Specify the process mesh argument or use ProcessMesh context manager first." + assert isinstance(shard_spec, list), \ + "Argument shard_spec {} is not an instance of list".format(shard_spec) + dist_tensor = DistributedTensor(x) + serial_tensor = dist_tensor.serial_tensor + dist_tensor.dist_attr.process_mesh = process_mesh + if serial_tensor.type == core.VarDesc.VarType.READER \ + or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ + or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES: + tensor_shape = [] + else: + tensor_shape = serial_tensor.shape + if shard_spec is not None: + assert verify_shard_spec(shard_spec, tensor_shape, process_mesh), \ + "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format( + serial_tensor.name, shard_spec, tensor_shape, process_mesh) + dist_tensor.dist_attr.dims_mapping = convert_to_dims_mapping( + shard_spec, process_mesh) + if process_mesh is not None: + dist_tensor.dist_attr.mark_annotated("process_mesh") + if shard_spec is not None: + dist_tensor.dist_attr.mark_annotated("dims_mapping") default_dist_ctx = get_default_distributed_context() default_dist_ctx.add_dist_tensor_for_program(dist_tensor) + dist_tensor = default_dist_ctx.get_dist_tensor_for_program(x) return x -def shard_op(op_fn, dist_attr=None): +def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None): """ - Call a functioin and add distributed attributes for ops added by the function. + Shard an operation on a process mesh according to its input and output shard specification. Args: - op_fn (callable): a callable operator or module to be sharded. - dist_attr (dict): the operator distributed attributes. The accepted attributes are classified into - two categories. The first category decsribes the distributed attributes shared by all inputs and - outputs, and only `process_mesh` can be specified now. The second category describes distributed - attributes for inputs or outputs same as the `dist_attr` of `shard_tensor`. All of them are - optional and users can specify them as need. Note that `process_mesh` for operators must be the - same as these process_meshes for inputs and outputs. + op (Callable): a callable operator or module to be sharded. + process_mesh (ProcessMesh, optional): An instance of ProcessMesh describes a mesh + topology of the used logical processes where the op is sharded. All of its inputs and + outputs are sharded by this process mesh. If it is None, the found current process mesh + will be used. And an error will be raised if the current process mesh cannot be found. + Default: None. + in_shard_specs (list of list, optional): a list of list to describe the sharding specifications + for the inputs. Each item of `in_shard_specs` is a `shard_spec` between the correspoinding input + and `process_mesh`. If one item is None, the cooresponding input is replicated across all processes + If it is None, all inputs are replicated accross all processes. Note that the lenght of the + `in_shard_specs` should be equal to the actual number of inputs when calling this operation. + Default: None. + out_shard_specs (list of list, optional): a list of list to describe the sharding specifications + for the outputs. Each item of `out_shard_specs` is a `shard_spec` between the correspoinding output + and `process_mesh`. If one item is None, the cooresponding output is replicated across all processes + If it is None, all outputs are replicated accross all processes. Note that the lenght of the + `in_shard_specs` should be equal to the actual number of inputs when calling this operation. + Default: None. Default: None. Returns: - list: the outputs of the function `op_fn`, which are annotated with distributed attributes. + Outputs of `op`, each of which is annotated with sharding information. Examples: .. code-block:: python import paddle - import paddle.distributed as dist - - paddle.enable_static() - + import paddle.distributed.auto_parallel as auto + x = paddle.ones([4, 6]) y = paddle.zeros([4, 6]) - dist_add = dist.shard_op(paddle.add, - dist_attr={ - "process_mesh": [[2, 3, 1], [0, 4, 5]], - x: {"dims_mapping": [-1, 0]}, - y: {"dims_mapping": [0, -1]} - }) + mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]) + dist_add = auto.shard_op(paddle.add, + in_shard_specs=[["x", "y"], ["y", None]], + out_shard_specs=[[None, "x"]]) dist_add(x, y) """ - _static_mode_check() - assert dist_attr is None or isinstance(dist_attr, (dict, OperatorDistributedAttribute)), \ - "The type of dist_attr must be dict or OperatorDistributedAttribute." - dist_module = DistributedModule(op_fn, dist_attr) - return dist_module + + if process_mesh is not None: + assert isinstance(process_mesh, ProcessMesh), \ + "Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh) + else: + process_mesh = get_current_process_mesh() + assert process_mesh is not None, \ + "Specify the process mesh argument or use ProcessMesh context manager first." + in_dims_mappings = [] + if in_shard_specs is not None: + assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in in_shard_specs), \ + "in_shard_spec {} is not a list of list or None".format(in_shard_specs) + for shard_spec in in_shard_specs: + if shard_spec is not None: + in_dims_mappings.append( + convert_to_dims_mapping(shard_spec, process_mesh)) + else: + in_dims_mappings.append(None) + out_dims_mappings = [] + if out_shard_specs is not None: + assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in out_shard_specs), \ + "out_shard_spec {} is not a list of list or None".format(out_shard_specs) + for shard_spec in out_shard_specs: + if shard_spec is not None: + out_dims_mappings.append( + convert_to_dims_mapping(shard_spec, process_mesh)) + else: + out_dims_mappings.append(None) + op = DistributedOperatorHelper(op, process_mesh, in_dims_mappings, + out_dims_mappings) + return op + + +def recompute(op): + + class RecomputeOperator: + + def __init__(self, op): + self._op = op + + def __call__(self, *args, **kwargs): + default_prog = paddle.fluid.default_main_program() + cur_block = default_prog.current_block() + op_size = len(cur_block.ops) + output = self._op(*args, **kwargs) + new_op_size = len(cur_block.ops) + + for idx in range(op_size, new_op_size): + op = cur_block.ops[idx] + op._set_attr("is_recompute@auto_parallel", True) + + return output + + return RecomputeOperator(op) + + +_g_fetched_tensors = {} + + +def fetch(tensor, name=None): + if name is None: + _g_fetched_tensors[tensor.name] = tensor + else: + _g_fetched_tensors[name] = tensor + + +def _get_fetches(): + return _g_fetched_tensors diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py index 70d9a2d21b5e9625f67b5df244184134a9880996..2449a1a2c5e5c0aa0fadfe30ccc66e4d709e7ac6 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer.py +++ b/python/paddle/distributed/auto_parallel/parallelizer.py @@ -42,6 +42,7 @@ from .utils import make_data_unshard from .utils import set_grad_var_shape from .utils import print_program_with_dist_attr from .utils import SerialProgramInfo +from .utils import get_logger from .reshard import Resharder from .cluster import Cluster from .mapper import mapping diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index 7e43ee95266438b0689656d3916ea91d3008bf31..b83a19b512ef8be8c6fc13aede8056676de4483d 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -22,7 +22,6 @@ from paddle.fluid import program_guard from paddle.fluid.backward import append_backward from paddle.fluid.framework import _non_static_mode, unique_name from paddle.distributed.passes import new_pass -from paddle.distributed.utils import get_logger from .reshard import Resharder from .partitioner import Partitioner @@ -31,6 +30,7 @@ from .dist_saver import DistributedSaver from .dist_loader import NonIterableGeneratorLoader from .utils import make_data_unshard, set_grad_var_shape from .utils import print_program_with_dist_attr, to_list +from .utils import get_logger from .process_group import get_all_process_groups, get_world_process_group from .dist_context import DistributedContext, get_default_distributed_context @@ -160,8 +160,8 @@ class Parallelizer: # apply quantization pass # The pass can be applied when mode must be 'train' - if self._mode == 'train' and self._strategy.qat: - config = copy.deepcopy(self._strategy.qat_configs) + if self._mode == 'train' and self._strategy.qat.enable: + config = copy.deepcopy(self._strategy.qat.to_dict()) config["dist_context"] = self._dist_context config["params_grads"] = params_grads auto_parallel_quantization_pass = new_pass( @@ -176,8 +176,8 @@ class Parallelizer: # apply amp pass # FIXME we disenable amp for eval since it has a little bug with # eval program and which will be fixed in future - if self._mode == 'train' and self._strategy.amp: - config = copy.deepcopy(self._strategy.amp_configs) + if self._mode == 'train' and self._strategy.amp.enable: + config = copy.deepcopy(self._strategy.amp.to_dict()) config["dist_context"] = self._dist_context config["params_grads"] = params_grads config["loss"] = loss @@ -195,8 +195,8 @@ class Parallelizer: # apply recompute pass # recompute is then train-only optimization - if self._mode == "train" and self._strategy.recompute: - config = copy.deepcopy(self._strategy.recompute_configs) + if self._mode == "train" and self._strategy.recompute.enable: + config = copy.deepcopy(self._strategy.recompute.to_dict()) config["dist_context"] = self._dist_context config["no_grad_set"] = None config["loss"] = loss @@ -217,12 +217,12 @@ class Parallelizer: config = {} config["dist_context"] = self._dist_context config["global_rank"] = rank - config["use_sharding"] = self._strategy.sharding + config["use_sharding"] = self._strategy.sharding.enable dp_pass = new_pass("auto_parallel_data_parallel_optimization", config) dp_pass.apply([main_program], [startup_program], self._pass_context) - if self._strategy.sharding: - config = copy.deepcopy(self._strategy.sharding_configs) + if self._strategy.sharding.enable: + config = copy.deepcopy(self._strategy.sharding.to_dict()) config["dist_context"] = self._dist_context config["params_grads"] = params_grads config["global_rank"] = rank @@ -234,7 +234,7 @@ class Parallelizer: # GradClip is train-only optimization if self._mode == "train": - config = copy.deepcopy(self._strategy.sharding_configs) + config = copy.deepcopy(self._strategy.sharding.to_dict()) config["dist_context"] = self._dist_context config["params_grads"] = params_grads config["rank_id"] = rank @@ -244,8 +244,8 @@ class Parallelizer: self._pass_context) # gradient_merge is then train-only optimization - if self._mode == "train" and self._strategy.gradient_merge: - config = copy.deepcopy(self._strategy.gradient_merge_configs) + if self._mode == "train" and self._strategy.gradient_merge.enable: + config = copy.deepcopy(self._strategy.gradient_merge.to_dict()) config["dist_context"] = self._dist_context config["params_grads"] = params_grads auto_parallel_gradient_merge_pass = new_pass( diff --git a/python/paddle/distributed/auto_parallel/process_mesh.py b/python/paddle/distributed/auto_parallel/process_mesh.py index ab1d68bbf8ea067c927ab5caca710b4995e6376f..14ce5ea75b10cc0c5e2b1acf251c3859e6ec7f22 100644 --- a/python/paddle/distributed/auto_parallel/process_mesh.py +++ b/python/paddle/distributed/auto_parallel/process_mesh.py @@ -12,86 +12,90 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy +import numpy as np import copy +import paddle +# Use to store the previous and current process mesh +_g_previous_process_mesh = None +_g_current_process_mesh = None -def _get_nested_list_shape(nested_list): - """ - Get the shape of a nested_list. - """ - result = [] - while isinstance(nested_list, list): - result.append(len(nested_list)) - nested_list = nested_list[0] - return result +def get_current_process_mesh(): + global _g_current_process_mesh + return _g_current_process_mesh -def _flatten_nested_list(nested_list): - """ - Get a list of all items in a nested_list. - Ref: https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists - """ - result = numpy.array(nested_list).flatten().tolist() - return result +def set_current_process_mesh(process_mesh): + global _g_previous_process_mesh + global _g_current_process_mesh + _g_previous_process_mesh = _g_current_process_mesh + _g_current_process_mesh = process_mesh -class ProcessMesh(object): - r""" - The class `Processmesh` describes the topology of logical processes. - A mesh is an N-dimensional array. The shape of the N-dimensional - array represents the topology of logical processes and every - element of the N-dimensional array represent a logical process. For - example, the 2-dimensional array [[2, 4, 5], [0, 1, 3]] - illustrates six logical processes organized as the topology [2, 3], - i.e., the shape of the 2-dimensional array. With the above topology, - there are two parallel groups, where the first parallel group has a - parallel degree of 2 and the second one has a parallel degree of 3. - And the first logical process is the one with id=2. - Args: - mesh (list): an N-dimensional array (nested list) describes the toplogy - of logical processes. The shape of the N-dimensional array - represents the topology of logical processes and every - element of the N-dimensional array represents a logical process. +def reset_current_process_mesh(): + global _g_previous_process_mesh + global _g_current_process_mesh + _g_current_process_mesh = _g_previous_process_mesh - Returns: - None - Raises: - ValueError: If `mesh` is not an instance of list. +class ProcessMesh(object): + """ + The `Processmesh` object describes the topology of the used processes. + Args: + mesh (list|numpy.array): an n-dimensional array describes the toplogy + of the processes. + dim_names (list, optional): the i-th element of this list gives the name of the + i-th dimension of the mesh. + Examples: .. code-block:: python import paddle - import paddle.distributed as dist - - paddle.enable_static() - - mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]]) - assert mesh.topology == [2, 3] - assert mesh.processes == [2, 4, 5, 0, 1, 3] + + mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"]) + assert mesh.shape == [2, 3] + assert mesh.processe_ids == [2, 4, 5, 0, 1, 3] """ - def __init__(self, mesh): - if mesh is None or not isinstance(mesh, list): - raise ValueError('mesh must be an instance of list.') - - processes = _flatten_nested_list(mesh) - - assert all(isinstance(p, int) for p in processes), \ - ("All elements of mesh must be integer") - - assert min(processes) >= 0, ('All elements of mesh must be >= 0.') - - unique_processes = set(processes) - assert len(unique_processes) == len(processes), ( - 'All elements of mesh must be unique.') - - self._topology = _get_nested_list_shape(mesh) - self._processes = processes + def __init__(self, mesh=None, dim_names=None, shape=None, process_ids=None): + # Use shape and process_ids just for compatibility + # Users should not use these directly + if mesh is None: + assert shape is not None + assert process_ids is not None + mesh = np.array(process_ids).reshape(shape) + + if not isinstance(mesh, list) and \ + not isinstance(mesh, np.ndarray): + raise ValueError( + 'The mesh must be an instance of list or np.ndarray.') + if isinstance(mesh, list): + mesh = np.array(mesh) + + self._mesh = mesh + self._shape = list(self._mesh.shape) + self._process_ids = self._mesh.flatten().tolist() + + assert all(isinstance(p, int) for p in self._process_ids), \ + ("All elements of the mesh must be integer") + assert min( + self._process_ids) >= 0, ('All elements of the mesh must be >= 0.') + unique_process_ids = set(self._process_ids) + assert len(unique_process_ids) == len( + self._process_ids), ('All elements of the mesh must be unique.') + + if dim_names is not None: + assert len(dim_names) == len(self._shape), \ + ("The length of dims_names must be same as the shape of the mesh.") + self._dim_names = copy.deepcopy(dim_names) + else: + self._dim_names = ["d" + str(i) for i in range(len(self._shape))] + unique_dim_names = set(self._dim_names) + assert len(unique_dim_names) == len(self._dim_names), ( + 'All dim_names {} must be unique.'.format(dim_names)) # Store all process meshes from .dist_context import get_default_distributed_context @@ -103,31 +107,117 @@ class ProcessMesh(object): pg0.add_ranks(self.processes) @property - def topology(self): - r""" - Get the topology of logical processes belonging to this ProcessMesh. - This is the shape of `mesh` used to initialized this ProcessMesh. + def shape(self): + """ + Get the shape of this ProcessMesh. """ - return self._topology + return self._shape @property - def processes(self): - r""" - Get a list of all processes belonging to this ProcessMesh. + def process_ids(self): + """ + Get the process ids belonging to this ProcessMesh. """ - return self._processes + return self._process_ids + + @property + def dim_names(self): + """ + Get the dimension names of this ProcessMesh. + """ + return self._dim_names @property def ndim(self): - r""" - Get the number of dimension of ProcessMesh. """ - return len(self._topology) + Get the number of dimension of this ProcessMesh. + """ + return len(self._shape) + + @property + def mesh(self): + """ + Get the underlying mesh of ProcessMesh. + """ + return self._mesh + + @property + def topology(self): + return self._shape + + @property + def processes(self): + return self._process_ids + + def __getitem__(self, index): + if isinstance(index, tuple): + new_dim_names = [] + for i, item in enumerate(index): + if isinstance(item, slice): + new_dim_names.append(self._dim_names[i]) + new_mesh = self._mesh[index] + if new_mesh.shape: + return ProcessMesh(new_mesh, new_dim_names) + else: + # Wrap a scalar into a list but without dim_names + return ProcessMesh([new_mesh]) + elif isinstance(index, slice): + new_mesh = self._mesh[index] + new_dim_names = self._dim_names + return ProcessMesh(new_mesh, new_dim_names) + else: + new_mesh = self._mesh[index] + new_dim_names = self._dim_names[1:] + return ProcessMesh(new_mesh, new_dim_names) + + def __enter__(self): + set_current_process_mesh(self) + default_prog = paddle.fluid.default_main_program() + cur_block = default_prog.current_block() + self._old_var_names = list(cur_block.vars.keys()) + self._old_op_size = len(cur_block.ops) + + def __exit__(self, exc_type, exc_value, exc_traceback): + from .dist_tensor import DistributedTensor + from .dist_op import DistributedOperator + default_prog = paddle.fluid.default_main_program() + cur_block = default_prog.current_block() + new_var_names = list(cur_block.vars.keys()) + new_op_size = len(cur_block.ops) + from .dist_context import get_default_distributed_context + default_dist_ctx = get_default_distributed_context() + for name in new_var_names: + if name not in self._old_var_names: + tensor = cur_block.vars[name] + dist_tensor = default_dist_ctx.get_dist_tensor_for_program( + tensor) + if dist_tensor is None: + dist_tensor = DistributedTensor(cur_block.vars[name], + {"process_mesh": self}) + dist_tensor.dist_attr.mark_annotated("process_mesh") + default_dist_ctx.add_dist_tensor_for_program(dist_tensor) + else: + if dist_tensor.dist_attr.process_mesh is None: + dist_tensor.dist_attr.process_mesh = self + dist_tensor.dist_attr.mark_annotated("process_mesh") + + for idx in range(self._old_op_size, new_op_size): + op = cur_block.ops[idx] + dist_op = default_dist_ctx.get_dist_op_for_program(op) + if dist_op is None: + dist_op = DistributedOperator(op, {"process_mesh": self}) + dist_op.dist_attr.mark_annotated("process_mesh") + default_dist_ctx.add_dist_op_for_program(dist_op) + else: + if dist_op.dist_attr.process_mesh is None: + dist_op.dist_attr.process_mesh = self + dist_op.dist_attr.mark_annotated("process_mesh") + reset_current_process_mesh() def __eq__(self, other): if not isinstance(other, ProcessMesh): return False - if self.topology != other.topology or self.processes != other.processes: + if self.shape != other.shape or self.process_ids != other.process_ids: return False return True @@ -135,6 +225,6 @@ class ProcessMesh(object): return not self.__eq__(other) def __str__(self): - str = "shape {} and process group {}".format(self.topology, - self.processes) + str = "shape {}, process_ids {}, dim_nams {}".format( + self.shape, self.process_ids, self.dim_names) return str diff --git a/python/paddle/distributed/auto_parallel/process_mesh_v2.py b/python/paddle/distributed/auto_parallel/process_mesh_v2.py index b57cecf41e26227fd86987516d690007acf2d249..aa9401b5f50e8c992031c1624760c5157c4f03a0 100644 --- a/python/paddle/distributed/auto_parallel/process_mesh_v2.py +++ b/python/paddle/distributed/auto_parallel/process_mesh_v2.py @@ -81,54 +81,57 @@ class ProcessMesh(core.ProcessMesh): return self._mesh -# def compute_compatible_process_meshes(process_meshes): -# """Compute the compatible process mesh given a list of process meshes.""" -# if not process_meshes: -# return None - -# def _compute_compatible_two_process_meshes(pm1, pm2): -# if pm1 is None: -# return True, pm2 -# if pm2 is None: -# return True, pm1 -# if pm1 == pm2: -# return True, pm1 -# if pm1.device_mesh != pm2.device_mesh: -# return False, None -# if pm1.process_ids == pm2.process_ids: -# if len(pm1.shape) >= len(pm2.shape): -# return True, pm1 -# else: -# return True, pm2 -# process_set1 = set(pm1.process_ids) -# process_set2 = set(pm2.process_ids) -# if process_set1.issubset(process_set2): -# return True, pm2 -# if process_set2.issubset(process_set1): -# return True, pm1 -# return False, None - -# compatible_result = None -# for process_mesh in process_meshes: -# compatible, compatible_result = _compute_compatible_two_process_meshes( -# compatible_result, process_mesh) -# if not compatible: -# return None -# return ProcessMesh(compatible_result.mesh, compatible_result.dim_names) - -# def merge_process_meshes(process_meshes): -# """Merge a list of process meshes.""" -# merged_process_mesh = None -# merged_process_ids = set() -# device_type = "" -# for process_mesh in process_meshes: -# if process_mesh is not None: -# process_ids = set(process_mesh.process_ids) -# if not device_type: -# device_type = process_mesh.device_type -# assert device_type != process_mesh.device_type, \ -# "All process meshes must have the same device_type." -# merged_process_ids.union(process_ids) -# if len(merged_process_ids) != 0: -# merged_process_mesh = ProcessMesh(list(merged_process_ids)) -# return merged_process_mesh +def compute_compatible_process_mesh(process_meshes): + """Compute the compatible process mesh given a list of process meshes.""" + if not process_meshes: + return None + + def _compute_compatible_of_two_process_meshes(pm1, pm2): + if pm1 is None: + return True, pm2 + if pm2 is None: + return True, pm1 + if pm1 == pm2: + return True, pm1 + if pm1.process_ids == pm2.process_ids: + if len(pm1.shape) >= len(pm2.shape): + return True, pm1 + else: + return True, pm2 + process_set1 = set(pm1.process_ids) + process_set2 = set(pm2.process_ids) + if process_set1.issubset(process_set2): + return True, pm2 + if process_set2.issubset(process_set1): + return True, pm1 + return False, None + + compatible_result = None + for process_mesh in process_meshes: + compatible, compatible_result = _compute_compatible_of_two_process_meshes( + compatible_result, process_mesh) + if not compatible: + return None + if compatible_result.empty(): + return None + if isinstance(compatible_result, core.ProcessMesh): + mesh = np.array(compatible_result.process_ids).reshape( + compatible_result.shape) + return ProcessMesh(mesh, compatible_result.dim_names) + elif isinstance(compatible_result, ProcessMesh): + return ProcessMesh(compatible_result.mesh, compatible_result.dim_names) + else: + raise ValueError("Unrecognized ProcessMesh.") + + +def merge_process_mesh(process_meshes): + """Merge a list of process meshes.""" + merged_process_mesh = None + merged_process_ids = set() + for process_mesh in process_meshes: + if process_mesh is not None: + process_ids = set(process_mesh.process_ids) + merged_process_ids = merged_process_ids.union(process_ids) + if len(merged_process_ids) != 0: + merged_process_mesh = ProcessMesh(list(merged_process_ids)) + return merged_process_mesh diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..c196b321eafd0ad91d98f4f432eb41cab46af618 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -0,0 +1,181 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import os +import copy +import argparse +from . import constants + + +class BaseConfig(object): + + def __init__(self, category, config_dict=None): + self._category = category + self._config_dict = None + if config_dict is not None: + if isinstance(config_dict, dict): + self._config_dict = config_dict + else: + raise ValueError( + "Expected a dictionary. But received: {}".format( + config_dict)) + # Initialize attributes by the default config + config = constants.get_category_default_config(self._category) + for field, default_value in config.items(): + setattr(self, field, default_value) + + # Overide attributes by the config_dict + if self._config_dict: + self.from_dict(self._config_dict) + + def from_dict(self, config_dict): + config = constants.get_category_default_config(self._category) + for field in config.keys(): + value = config_dict.get(field, constants.NOT_FOUND) + # Use the default value if we cannot found the value + if value != constants.NOT_FOUND: + setattr(self, field, value) + + def to_dict(self): + result_dict = {} + config = constants.get_category_default_config(self._category) + for field in config.keys(): + value = getattr(self, field) + result_dict[field] = value + for field, value in self.__dict__.items(): + if isinstance(value, BaseConfig): + result_dict[field] = value.to_dict() + return result_dict + + def __repr__(self): + return yaml.dump(self.to_dict(), + default_flow_style=False, + sort_keys=True, + indent=4) + + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + setattr(result, k, copy.deepcopy(v, memo)) + return result + + +class RecomputeConfig(BaseConfig): + + def __init__(self, config_dict=None): + category = constants.RECOMPUTE + super(RecomputeConfig, self).__init__(category, config_dict) + + +class AMPConfig(BaseConfig): + + def __init__(self, config_dict=None): + category = constants.AMP + super(AMPConfig, self).__init__(category, config_dict) + + +class ShardingConfig(BaseConfig): + + def __init__(self, config_dict=None): + category = constants.SHARDING + super(ShardingConfig, self).__init__(category, config_dict) + + +class GradientMergeConfig(BaseConfig): + + def __init__(self, config_dict=None): + category = constants.GRADIENT_MERGE + super(GradientMergeConfig, self).__init__(category, config_dict) + + +class QATConfig(BaseConfig): + + def __init__(self, config_dict=None): + category = constants.QAT + super(QATConfig, self).__init__(category, config_dict) + + +class TuningConfig(BaseConfig): + + def __init__(self, config_dict=None): + category = constants.TUNING + super(TuningConfig, self).__init__(category, config_dict) + + +class Strategy(BaseConfig): + """ + The `Strategy` object is used to configure the paralleization and optimization beheviors. + + Args: + config (dict|string, optional): If this is None, the default configurations will used. + If this is a dictionary, the recognized key-value of it will be used to override the default + configurations while other default configurations are left unchanged. If this is a string, + it is interpreted as the path to a YAML configuration and will be loaded to override the + corresponding default configurations. + + Examples: + .. code-block:: python + + import paddle + import paddle.distributed.auto_parallel as auto + + strategy = auto.Strategy() + sharding = strategy.sharding + self.assertEqual(sharding.enabled, False) + self.assertEqual(sharding.stage, 1) + self.assertEqual(sharding.sharding_degree, 8) + sharding.enabled = True + sharding.stage = 2 + sharding.sharding_degree = 2 + self.assertEqual(sharding.enabled, True) + self.assertEqual(sharding.stage, 2) + self.assertEqual(sharding.sharding_degree, 2) + + """ + + def __init__(self, config=None): + if config is not None: + if isinstance(config, dict): + self._config_dict = copy.deepcopy(config) + # elif os.path.exists(config): + # with open(config, "rb") as yaml_file: + # self._config_dict = yaml.load(yaml_file, Loader=yaml.Loader) + else: + raise ValueError( + "Expected a dictionary. But received: {}".format(config)) + else: + self._config_dict = {} + + category = constants.BASE + super(Strategy, self).__init__(category, self._config_dict) + + config_dict = self._config_dict.get(constants.RECOMPUTE, None) + self.recompute = RecomputeConfig(config_dict) + + config_dict = self._config_dict.get(constants.AMP, None) + self.amp = AMPConfig(config_dict) + + config_dict = self._config_dict.get(constants.SHARDING, None) + self.sharding = ShardingConfig(config_dict) + + config_dict = self._config_dict.get(constants.GRADIENT_MERGE, None) + self.gradient_merge = GradientMergeConfig(config_dict) + + config_dict = self._config_dict.get(constants.QAT, None) + self.qat = QATConfig(config_dict) + + config_dict = self._config_dict.get(constants.TUNING, None) + self.tuning = TuningConfig(config_dict) diff --git a/python/paddle/distributed/auto_parallel/tuner/algorithms.py b/python/paddle/distributed/auto_parallel/tuner/algorithms.py index 63aa56f3e1faf2424449ec5ee9b4951b150a4213..f892a7838fe7a53b0af45f761ec013bab4637861 100644 --- a/python/paddle/distributed/auto_parallel/tuner/algorithms.py +++ b/python/paddle/distributed/auto_parallel/tuner/algorithms.py @@ -16,7 +16,7 @@ import copy from abc import ABC, abstractmethod import logging -from paddle.distributed.utils import get_logger +from ..utils import get_logger from .trial import TrialStatus from .trial import OptimizationTunerTrial as Trial @@ -110,13 +110,13 @@ class ShardingStageAlgorithm(AlgorithmBase): # TODO import trial class & copy strategy def __init__(self, config): super().__init__(config) - self._changed_configs = ["sharding_configs"] + self._changed_configs = ["sharding"] def _init_spaces(self): self._max_stage = 3 self._trial_idx = 0 - stage_range = self._config.sharding_configs.get("stage_range", None) + stage_range = self._config.sharding.to_dict().get("tuning_range", None) if stage_range: assert set(stage_range).issubset( set([0, 1, 2, 3]) @@ -136,9 +136,8 @@ class ShardingStageAlgorithm(AlgorithmBase): stage = self._stage_range[self._trial_idx] new_strategy = copy.deepcopy(self._config.dist_strategy) - config_dict = new_strategy.sharding_configs - config_dict["stage"] = stage - new_strategy.sharding_configs = config_dict + sharding = new_strategy.sharding + sharding.stage = stage name = "trial-sharding-stage{}".format(stage) trial = Trial(new_strategy, name, self.changed_configs) diff --git a/python/paddle/distributed/auto_parallel/tuner/config.py b/python/paddle/distributed/auto_parallel/tuner/config.py index 151a9a8bc76aa2693ae57a5522b30deb0bc26cfe..3083298eff87d731ef2dfc2f9e62b56aec241492 100644 --- a/python/paddle/distributed/auto_parallel/tuner/config.py +++ b/python/paddle/distributed/auto_parallel/tuner/config.py @@ -17,15 +17,13 @@ import copy import pathlib import paddle -from paddle.distributed import fleet +from ..strategy import Strategy _tuning_supported_passes = ["sharding", "recompute"] -_strategy_config_suffiex = "_configs" def _get_pass_config(strategy, pass_name): - config_name = pass_name + _strategy_config_suffiex - config = getattr(strategy, config_name) + config = getattr(strategy, pass_name) return config @@ -38,10 +36,8 @@ class TuningConfig(object): def __init__(self, user_config, strategy): - if not isinstance(strategy, fleet.DistributedStrategy): - raise TypeError( - "'strategy' must be object of class `fleet.DistributedStrategy`." - ) + if not isinstance(strategy, Strategy): + raise TypeError("'strategy' must be object of class `Strategy`.") if not user_config: user_config = {} @@ -116,11 +112,11 @@ class TuningConfig(object): for p in _tuning_supported_passes: if getattr(self._dist_strategy, p) and _get_pass_config( - self._dist_strategy, p)["enable_tuning"]: + self._dist_strategy, p).enable_tuning: # TODO distinguish different args of each passes self._tuning_passes_name.add(p) - config_name = p + _strategy_config_suffiex + config_name = p p_dict = getattr(self._dist_strategy, config_name) self.__dict__[config_name] = p_dict diff --git a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py index 261a382eb174a00c9db3c2b8ea556504aa8a230f..a2da7396ce83c6e67bcb2edc1bdaadb8c4a7970b 100644 --- a/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py +++ b/python/paddle/distributed/auto_parallel/tuner/optimization_tuner.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# import yaml import os import sys import copy @@ -29,7 +30,6 @@ import paddle from paddle.fluid import program_guard from paddle.fluid.backward import append_backward from paddle.distributed.passes import new_pass, PassContext -from paddle.distributed.utils import get_logger from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context from paddle.distributed.auto_parallel.completion import Completer @@ -39,6 +39,7 @@ from paddle.distributed.auto_parallel.process_group import clear_all_process_gro from paddle.distributed.auto_parallel.utils import debug_program from paddle.distributed.auto_parallel.utils import make_data_unshard, set_grad_var_shape +from ..utils import get_logger from .config import TuningConfig from .algorithms import new_algorithm from .trial import TrialStatus @@ -256,8 +257,8 @@ class OptimizationTuner: startup_program = dist_context.serial_startup_program # applying optimization pass - if new_strategy.amp: - config = copy.deepcopy(new_strategy.amp_configs) + if new_strategy.amp.enable: + config = copy.deepcopy(new_strategy.amp.to_dict()) config["dist_context"] = dist_context config["params_grads"] = dist_context._params_grads @@ -275,8 +276,8 @@ class OptimizationTuner: auto_parallel_amp_pass.apply([main_program], [startup_program], pass_context) - if new_strategy.recompute: - config = copy.deepcopy(new_strategy.recompute_configs) + if new_strategy.recompute.enable: + config = copy.deepcopy(new_strategy.recompute.to_dict()) config["dist_context"] = dist_context config["no_grad_set"] = None config["loss"] = dist_context.serial_loss @@ -303,8 +304,8 @@ class OptimizationTuner: dist_context, dist_params_grads) resharder.reshard() - if new_strategy.sharding: - config = copy.deepcopy(new_strategy.sharding_configs) + if new_strategy.sharding.enable: + config = copy.deepcopy(new_strategy.sharding.to_dict()) config["dist_context"] = dist_context config["params_grads"] = dist_params_grads config["global_rank"] = self.rank @@ -313,8 +314,8 @@ class OptimizationTuner: auto_parallel_sharding_pass.apply([dist_main_prog], [dist_startup_prog], pass_context) - if new_strategy.gradient_merge: - config = copy.deepcopy(new_strategy.gradient_merge_configs) + if new_strategy.gradient_merge.enable: + config = copy.deepcopy(new_strategy.gradient_merge.to_dict()) config["dist_context"] = dist_context config["params_grads"] = dist_params_grads auto_parallel_gradient_merge_pass = new_pass( @@ -492,9 +493,10 @@ The best trial is: [{}], whose configuration is following: for line in summary_.split("\n"): fw.write(line + "\n") - full_strategy = self.get_best_config() - full_strategy.save_to_prototxt( - os.path.join(self.project_dir, "tuned_dist_strategy.prototxt")) + # full_strategy = self.get_best_config() + # path = os.path.join(self.project_dir, "tuned_dist_strategy.yaml") + # with open(path, 'w') as outfile: + # yaml.dump(full_strategy, outfile, default_flow_style=False) def clear(self): """ diff --git a/python/paddle/distributed/auto_parallel/tuner/trial.py b/python/paddle/distributed/auto_parallel/tuner/trial.py index 3937ca9865181f066597d4afb48ac9832f1036bd..edc588b4c70fec3de995bae616961e6b1f87c81e 100644 --- a/python/paddle/distributed/auto_parallel/tuner/trial.py +++ b/python/paddle/distributed/auto_parallel/tuner/trial.py @@ -156,9 +156,10 @@ class OptimizationTunerTrial(Trial): draws += h1_format.format("{} auto=True <-> {}".format(name, name)) draws += line + "\n" my_configs = getattr(self.space, name) - keys = my_configs.keys() + keys = my_configs.to_dict().keys() for key in keys: - draws += h2_format.format(key, str(my_configs.get(key, None))) + draws += h2_format.format( + key, str(my_configs.to_dict().get(key, None))) result_res = draws + border return result_res diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index bc797530b75356212d7516a549532d515bba6062..ef165f5ff086d18bb10bf9360442ff70a91785d2 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -28,6 +28,19 @@ from paddle.fluid.io import is_parameter, is_belong_to_optimizer from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute +def get_logger(log_level, name="auto_parallel"): + logger = logging.getLogger(name) + logger.propagate = False + if not logger.handlers: + logger.setLevel(log_level) + log_handler = logging.StreamHandler() + log_format = logging.Formatter( + '%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s') + log_handler.setFormatter(log_format) + logger.addHandler(log_handler) + return logger + + def is_valid_list_index(list, index): if index >= -len(list) and index < len(list): return True @@ -49,6 +62,58 @@ def is_dim_replicate(mapping): return False +def verify_dims_mapping(dims_mapping, process_mesh): + if dims_mapping is None: + return False + if not all(isinstance(d, int) for d in dims_mapping): + return False + for i in range(len(dims_mapping)): + if dims_mapping[i] < -1 or dims_mapping[i] >= len(process_mesh.shape): + return False + for i in range(len(process_mesh.shape)): + if dims_mapping.count(i) > 1: + return False + return True + + +def convert_to_dims_mapping(shard_spec, process_mesh): + dims_mapping = [] + for shard in shard_spec: + if shard is None: + dims_mapping.append(-1) + else: + dims_mapping.append(process_mesh.dim_names.index(shard)) + return dims_mapping + + +def convert_to_shard_spec(dims_mapping, process_mesh): + shard_spec = [] + for dim_mapping in dims_mapping: + if dim_mapping == -1: + shard_spec.append(None) + else: + shard_spec.append(process_mesh.dim_names[dim_mapping]) + return shard_spec + + +def verify_shard_spec(shard_spec, tensor_shape, process_mesh): + if len(shard_spec) != len(tensor_shape): + return False + for shard in shard_spec: + if shard is not None and not isinstance(shard, str): + return False + if shard is not None and shard not in process_mesh.dim_names: + return False + dims_mapping = convert_to_dims_mapping(shard_spec, process_mesh) + if not verify_dims_mapping(dims_mapping, process_mesh): + return False + for i in range(len(tensor_shape)): + if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ + and tensor_shape[i] % process_mesh.shape[dims_mapping[i]] != 0: + return False + return True + + def compute_compatible_dim_mapping(dim_mappings): if not dim_mappings: return None @@ -1040,7 +1105,7 @@ def set_grad_var_shape(program, dist_context): if op.type in [ "c_allreduce_sum", "c_identity", "scale", "cast", - 'fill_any_like' + "fill_any_like" ]: forward_var_name = op.input_arg_names[0] elif op.type == "matmul_v2_grad" or op.type == "matmul_grad" or op.type == "mul_grad": diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py index 9495ffa22b0c60ad76e861d7840bc803cc2d2ada..44f504887cf165a73ad5510d2113d557f85fa5e0 100644 --- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py +++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py @@ -174,7 +174,7 @@ class DataParallelOptimizationPass(PassBase): def _could_be_prune(self): - return self.dist_context._gradient_scale and ( + return self.dist_context.gradient_scale and ( self._support_rescale_grad or self._all_dp_groups_same_degree()) def _all_dp_groups_same_degree(self): diff --git a/python/paddle/distributed/passes/auto_parallel_fp16.py b/python/paddle/distributed/passes/auto_parallel_fp16.py index 89ff2019d73920e16bda6449e401620350cbcd66..64562668a42ac7cf74d159582e2dd3e5493fbb3f 100644 --- a/python/paddle/distributed/passes/auto_parallel_fp16.py +++ b/python/paddle/distributed/passes/auto_parallel_fp16.py @@ -380,6 +380,10 @@ class FP16State(object): # create cast grad grad_slot_name = slot_name + "@GRAD" assert grad_slot_name in op.output_names + if len(op.output(grad_slot_name)) == 0: + var = block.var(src_name) + assert var.stop_gradient is True + continue assert len(op.output(grad_slot_name)) == 1 grad_name = op.output(grad_slot_name)[0] grad = block.var(grad_name) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index 27f86dc9f100a73f01eda6e10220cbec44b347c0..bbccf452742a3837549f436c79ebb0d67be4bf4e 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -37,9 +37,29 @@ if(WITH_DISTRIBUTE AND WITH_GPU) ${dist_ENVS}) set_tests_properties(test_high_order_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) - py_test_modules(test_grad_clip MODULES test_grad_clip ENVS ${dist_ENVS}) - set_tests_properties(test_grad_clip PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" - TIMEOUT 50) + py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS + ${dist_ENVS}) + set_tests_properties(test_iterable_dataset + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80) + py_test_modules(test_pass_grad_clip MODULES test_pass_grad_clip ENVS + ${dist_ENVS}) + set_tests_properties(test_pass_grad_clip + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_pass_gradient_merge MODULES test_pass_gradient_merge + ENVS ${dist_ENVS}) + set_tests_properties(test_pass_gradient_merge + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_pass_recompute MODULES test_pass_recompute ENVS + ${dist_ENVS}) + set_tests_properties(test_pass_recompute + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_pass_sharding MODULES test_pass_sharding ENVS + ${dist_ENVS}) + set_tests_properties(test_pass_sharding + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_pass_amp MODULES test_pass_amp ENVS ${dist_ENVS}) + set_tests_properties(test_pass_amp PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" + TIMEOUT 50) py_test_modules(test_while_op_completion MODULES test_while_op_completion ENVS ${dist_ENVS}) @@ -70,11 +90,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2) py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2) py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip) - py_test_modules(test_quantization MODULES test_quantization) py_test_modules(test_dist_matmul MODULES test_dist_matmul) + py_test_modules(test_process_mesh MODULES test_process_mesh) + py_test_modules(test_interface MODULES test_interface) + py_test_modules(test_strategy MODULES test_strategy) + py_test_modules(test_pass_quantization MODULES test_pass_quantization) - py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS - ${dist_ENVS}) - set_tests_properties(test_iterable_dataset - PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 80) endif() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..5ca2d8132e2947a3bd45ad22f4596c0c4b736023 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import sys +import random +import numpy as np +import paddle + +import paddle.distributed.auto_parallel as auto +from paddle.fluid.dygraph.parallel import ParallelEnv +from get_gpt_model import generate_model, create_data_holder, FakeDataset + + +def apply_pass(use_amp=False, level=None): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True + if use_amp: + amp = strategy.amp + amp.enable = True + amp.custom_white_list = ['softmax', 'layer_norm', 'gelu'] + amp.custom_black_list = [ + 'c_softmax_with_cross_entropy', 'elementwise_div', 'reduce_sum' + ] + amp.init_loss_scaling = 32768 + amp.use_fp16_guard = False + amp.use_pure_fp16 = level in ["o2", "o3"] + amp.use_optimizer_fp16 = level == "o3" + print("amp level: ", level) + return strategy + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class TestAMPPass(unittest.TestCase): + + def setUp(self): + self.rtol = 1e-5 + self.atol = 1e-8 + self.batch_size = 1 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2021) + np.random.seed(2021) + random.seed(2021) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, use_amp=False, level=None): + reset_prog() + + strategy = apply_pass(use_amp, level) + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("mp") + + engine = auto.Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_results(self, ref_losses, check_losses, rtol=None, atol=None): + np.testing.assert_allclose( + ref_losses, + check_losses, + rtol=rtol or self.rtol, + atol=atol or self.atol, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, ref_losses, check_losses, ref_losses - check_losses)) + + def test_amp_pass(self): + # mp2 training + mp_engine = self.get_engine() + mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) + mp_losses = np.array(mp_losses["loss"]) + + # mp2 amp-o1 training + amp_o1_engine = self.get_engine(True, "o1") + amp_o1_losses = amp_o1_engine.fit(self.dataset, + 3, + batch_size=self.batch_size) + amp_o1_losses = np.array(amp_o1_losses["loss"]) + # self.check_results(mp_losses, amp_o1_losses) + + # mp2 amp-o2 training + amp_o2_engine = self.get_engine(True, "o2") + amp_o2_losses = amp_o2_engine.fit(self.dataset, + 3, + batch_size=self.batch_size) + amp_o2_losses = np.array(amp_o2_losses["loss"]) + # self.check_results(mp_losses, amp_o2_losses) + + # mp2 amp-o3 training + amp_o3_engine = self.get_engine(True, "o3") + amp_o3_losses = amp_o3_engine.fit(self.dataset, + 3, + batch_size=self.batch_size) + amp_o3_losses = np.array(amp_o3_losses["loss"]) + # self.check_results(mp_losses, amp_o3_losses) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_model.py index d459ffd6d680d5893e908f4a833b75890e8a9f05..4639abf32554e045aa4b4a0eb6100c9c3c58f22b 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/auto_parallel_relaunch_model.py @@ -32,7 +32,7 @@ import paddle.distributed.auto_parallel as auto paddle.enable_static() _global_parallel_strategy = None -_global_process_mesh = None +_global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"]) batch_size = 4 hidden_size = 1024 sequence_len = 512 @@ -103,11 +103,7 @@ def mlp_pretrain_forward(train_program, start_program): shape=[batch_size, sequence_len, 1], dtype='float32') - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mappig": [-1, -1, -1] - }) + auto.shard_tensor(input, _global_process_mesh, [None, None, None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -126,9 +122,6 @@ def mlp_pretrain_forward(train_program, start_program): def train(): - global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) - dist_strategy = fleet.DistributedStrategy() dist_strategy.amp = False dist_strategy.pipeline = False diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py b/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py index 60a915c53cddfcfde386a23c38dcde17329b9096..1a8c5e6072cba29a1c2477ad07f210ebcf3a0f75 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py @@ -18,24 +18,21 @@ import random import numpy as np import paddle -import paddle.distributed.fleet as fleet import paddle.distributed.auto_parallel as auto - -from paddle.distributed.auto_parallel.engine import Engine +from paddle.fluid.dygraph.parallel import ParallelEnv from get_gpt_model import generate_model, create_data_holder, FakeDataset paddle.enable_static() def apply_pass(use_sharding=False): - strategy = fleet.DistributedStrategy() - strategy.semi_auto = True + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True if use_sharding: - strategy.sharding = True - strategy.sharding_configs = { - "sharding_degree": 2, - "stage": 2, - } + sharding = strategy.sharding + sharding.sharding_degree = 2 + sharding.stage = 2 return strategy @@ -76,34 +73,17 @@ class TestGradientClipByGlobalNorm(unittest.TestCase): paddle.seed(2022) np.random.seed(2022) random.seed(2022) - engine.mode = "train" - engine._executor.run(engine.startup_program) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) - def get_dp2_engine(self): + def get_engine(self, use_sharding=False): reset_prog() - strategy = apply_pass() + strategy = apply_pass(use_sharding) clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) model, loss = generate_model("dp") - inputs_spec, labels_spec = create_data_holder(self.batch_size) - - engine = Engine(model, inputs_spec, labels_spec, strategy=strategy) - engine.prepare(optimizer=opt, loss=loss) - self.init(engine) - return engine - - def get_dp2sharding2_engine(self): - reset_prog() - - strategy = apply_pass(True) - clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) - opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) - model, loss = generate_model("dp") - inputs_spec, labels_spec = create_data_holder(self.batch_size) - - engine = Engine(model, inputs_spec, labels_spec, strategy=strategy) - engine.prepare(optimizer=opt, loss=loss) + engine = auto.Engine(model, loss, opt, strategy=strategy) self.init(engine) return engine @@ -121,15 +101,13 @@ class TestGradientClipByGlobalNorm(unittest.TestCase): def test_grad_clip(self): # dp2 training - dp_engine = self.get_dp2_engine() - dp_engine.fit(self.dataset, batch_size=self.batch_size, use_cache=True) + dp_engine = self.get_engine() + dp_engine.fit(self.dataset, 3, batch_size=self.batch_size) dp_param_values = get_parameter_value(dp_engine.main_program) # dp2sharding2 training - sharding_engine = self.get_dp2sharding2_engine() - sharding_engine.fit(self.dataset, - batch_size=self.batch_size, - use_cache=True) + sharding_engine = self.get_engine(True) + sharding_engine.fit(self.dataset, 3, batch_size=self.batch_size) sharding_param_values = get_parameter_value( sharding_engine.main_program) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py index 104614e3e9d4edb7343666eb8624ce98f6b0b80e..94677645ad4e8d22dfd643052ddfb1c86648127e 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py @@ -27,10 +27,8 @@ import paddle.nn.functional as F import paddle.utils as utils from paddle.fluid import layers from paddle.io import Dataset, IterableDataset, DataLoader -from paddle.static import InputSpec -from paddle.distributed import fleet + import paddle.distributed.auto_parallel as auto -from paddle.distributed.auto_parallel.engine import Engine from paddle.optimizer.lr import CosineAnnealingDecay from paddle.fluid.dataloader.collate import default_collate_fn @@ -47,6 +45,8 @@ class_num = 10 paddle.seed(44) +is_fetch = True + class MyDataset(Dataset): @@ -90,19 +90,20 @@ class MLPLayer(nn.Layer): self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") def forward(self, input): - out = auto.shard_op(self.norm, dist_attr={"process_mesh": - PP_MESH_0})(input) + out = auto.shard_op(self.norm, PP_MESH_0)(input) out = self.linear0(out) out = F.gelu(out, approximate=True) - out = auto.shard_op(self.linear1, dist_attr={"process_mesh": - PP_MESH_1})(out) + out = auto.shard_op(self.linear1, PP_MESH_1)(out) out = self.dropout(out) out = self.linear2(out) - self.out = out + if is_fetch: + auto.fetch(out, "out") return out def train(fetch): + global is_fetch + is_fetch = fetch mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, dropout_ratio=0.1, @@ -113,46 +114,34 @@ def train(fetch): beta2=0.999, epsilon=1e-08, grad_clip=None) + metric = paddle.metric.Accuracy() - inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') - labels_spec = InputSpec([batch_size], 'int64', 'label') - - dist_strategy = fleet.DistributedStrategy() - dist_strategy.semi_auto = True - fleet.init(is_collective=True, strategy=dist_strategy) - - # init engine - engine = Engine(mlp, - inputs_spec=inputs_spec, - labels_spec=labels_spec, - strategy=dist_strategy) - engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy()) + strategy = auto.Strategy() + strategy.auto_mode = "semi" - # fetch - if fetch: - fetches = {'out': mlp.out} - else: - fetches = None + engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy) # train train_dataset = MyDataset(batch_num * batch_size) - engine.fit(train_dataset, + eval_dataset1 = MyDataset(5 * batch_size) + engine.fit(train_data=train_dataset, + epochs=2, batch_size=batch_size, - steps_per_epoch=batch_num * batch_size, - fetches=fetches) + valid_data=eval_dataset1) # eval - eval_dataset = MyDataset(batch_size) - engine.evaluate(eval_dataset, batch_size, fetches=fetches) + eval_dataset2 = MyDataset(batch_size) + engine.evaluate(eval_dataset2, batch_size=batch_size) # predict test_dataset = MyDataset(batch_size) - engine.predict(test_dataset, batch_size, fetches=fetches) + engine.predict(test_dataset, batch_size=batch_size) # save temp_dir = tempfile.TemporaryDirectory() - model_filename = os.path.join(temp_dir.name, 'mlp_inf') - engine.save(model_filename, training=False, mode='predict') + model_filename = os.path.join(temp_dir.name, 'mlp') + engine.save(model_filename, training=True) + engine.load(model_filename) temp_dir.cleanup() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py index 76a4772290db9dfb5e1b899c1a06d10cee32d3f1..8e863e1f532bf390bd4ca98a02af01ccd49c4b1a 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api_dp.py @@ -26,11 +26,9 @@ import paddle.static as static import paddle.nn.functional as F import paddle.utils as utils from paddle.fluid import layers -from paddle.io import Dataset, IterableDataset, DataLoader -from paddle.static import InputSpec -from paddle.distributed import fleet +from paddle.io import Dataset, DataLoader + import paddle.distributed.auto_parallel as auto -from paddle.distributed.auto_parallel.engine import Engine paddle.enable_static() batch_size = 2 @@ -91,6 +89,7 @@ class MLPLayer(nn.Layer): out = self.linear1(out) out = self.dropout(out) out = self.linear2(out) + auto.fetch(out, "out") self.out = out return out @@ -107,46 +106,32 @@ def train(fetch): epsilon=1e-08, grad_clip=None) - inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') - labels_spec = InputSpec([batch_size], 'int64', 'label') - - dist_strategy = fleet.DistributedStrategy() - dist_strategy.amp = False - dist_strategy.pipeline = False - dist_strategy.recompute = False - # init parallel optimizer - dist_strategy.semi_auto = True - fleet.init(is_collective=True, strategy=dist_strategy) + dist_strategy = auto.Strategy() + dist_strategy.auto_mode = "semi" # init engine - engine = Engine(mlp, - inputs_spec=inputs_spec, - labels_spec=labels_spec, - strategy=dist_strategy) - engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy()) - - # fetch - if fetch: - fetches = {'out': mlp.out} - else: - fetches = None + engine = auto.Engine(mlp, + loss, + optimizer, + paddle.metric.Accuracy(), + strategy=dist_strategy) # train train_dataset = MyDataset(batch_num * batch_size) - engine.fit(train_dataset, batch_size=batch_size, fetches=fetches) + engine.fit(train_dataset, batch_size=batch_size) # eval eval_dataset = MyDataset(batch_size) - engine.evaluate(eval_dataset, batch_size, fetches=fetches) + engine.evaluate(eval_dataset, batch_size=batch_size) # predict test_dataset = MyDataset(batch_size) - engine.predict(test_dataset, batch_size, fetches=fetches) + engine.predict(test_dataset, batch_size=batch_size) # save temp_dir = tempfile.TemporaryDirectory() model_filename = os.path.join(temp_dir.name, 'mlp_inf') - engine.save(model_filename, training=False, mode='predict') + engine.save(model_filename, training=False) temp_dir.cleanup() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py index 2884a03a023e541e40d04dfdfb9d3f377ca0f8a9..9e32bb1cee57110bfd2f1bb0c0891c431f32d4db 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/get_gpt_model.py @@ -14,8 +14,10 @@ import sys import numpy as np +import random import paddle +import paddle.distributed.auto_parallel as auto sys.path.append("..") import auto_parallel_gpt_model as modeling @@ -25,7 +27,7 @@ sequence_len = 512 vocab_size = 1000 -class FakeDataset: +class FakeDataset(paddle.io.Dataset): def __init__(self, num_samples): self.num_samples = num_samples @@ -33,6 +35,9 @@ class FakeDataset: self.vocab_size = vocab_size def __getitem__(self, idx): + paddle.seed(2021) + np.random.seed(2021) + random.seed(2021) tokens = np.random.randint(self.vocab_size, size=self.sequence_len) position_ids = np.arange(self.sequence_len) attention_mask = np.tril(np.ones(self.sequence_len)).reshape( @@ -67,8 +72,9 @@ def create_data_holder(batch_size): def generate_model(strategy): modeling.init_global() - modeling._global_process_mesh = list( - range(paddle.distributed.get_world_size())) + ranks = list(range(paddle.distributed.get_world_size())) + modeling._global_process_mesh = auto.ProcessMesh(mesh=ranks, + dim_names=["x"]) if strategy == "serial": modeling._global_parallel_strategy = "serial" elif strategy == "mp": diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..75aa7d9c1e05f8cdad26fc27ab78cbf7af204305 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py @@ -0,0 +1,109 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import sys +import random +import numpy as np +import paddle + +import paddle.distributed.auto_parallel as auto +from paddle.fluid.dygraph.parallel import ParallelEnv +from get_gpt_model import generate_model, create_data_holder, FakeDataset + +paddle.enable_static() + + +def apply_pass(use_gradient_merge=False): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True + if use_gradient_merge: + gradient_merge = strategy.gradient_merge + gradient_merge.enable = True + gradient_merge.k_steps = 4 + gradient_merge.avg = True + + return strategy + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class TestGradientMergePass(unittest.TestCase): + + def setUp(self): + self.rtol = 1e-5 + self.atol = 1e-8 + self.batch_size = 8 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2021) + np.random.seed(2021) + random.seed(2021) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, use_gradient_merge=False): + reset_prog() + + strategy = apply_pass(use_gradient_merge) + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("dp") + + engine = auto.Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_results(self, ref_losses, check_losses): + np.testing.assert_allclose( + ref_losses, + check_losses, + rtol=self.rtol, + atol=self.atol, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, ref_losses, check_losses, ref_losses - check_losses)) + + def test_gradient_merge_pass(self): + # dp2 training + dp_engine = self.get_engine() + dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size) + dp_losses = np.array(dp_losses["loss"]) + + # dp2 gradient merge training + gm_engine = self.get_engine(True) + gm_losses = gm_engine.fit(self.dataset, 3, batch_size=self.batch_size) + gm_losses = np.array(gm_losses["loss"]) + + avg_loss = 0 + pass_avg_ret_list = [] + for i, pass_ret in enumerate(gm_losses): + if (i + 1) % 4 == 0: + avg_loss += pass_ret + pass_avg_ret_list.append(avg_loss / 4) + avg_loss = 0 + else: + avg_loss += pass_ret + + self.check_results(dp_losses, np.array(pass_avg_ret_list)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py b/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py index 9ab49b30d9d6773feeb3f2beab3da6e365f5ff89..85a6189985136f8cc83e3515cc53f01d76c3c45c 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/high_order_grad.py @@ -17,11 +17,7 @@ import paddle import unittest import numpy as np import paddle.distributed.auto_parallel as auto - -from paddle.static import InputSpec -from paddle.distributed import fleet from paddle.incubate.autograd import Hessian -from paddle.distributed.auto_parallel.engine import Engine np.random.seed(1234) paddle.seed(1234) @@ -87,7 +83,7 @@ class LaplaceModel(paddle.nn.Layer): return eq_loss, bc_u -class LaplaceDataset: +class LaplaceDataset(paddle.io.Dataset): def __init__(self, num_sample): self.num_sample = num_sample @@ -129,23 +125,14 @@ def main(): # model laplace = LaplaceModel() - # spec - inputs_spec = [ - InputSpec([100, 2], 'float32', 'x'), - InputSpec([36], 'int64', 'bc_idx') - ] - labels_spec = InputSpec([36, 1], 'float32', 'bc_v') - - dist_strategy = fleet.DistributedStrategy() - dist_strategy.semi_auto = True - fleet.init(is_collective=True, strategy=dist_strategy) - - engine = Engine(laplace, - inputs_spec=inputs_spec, - labels_spec=labels_spec, - strategy=dist_strategy) - engine.prepare(optimizer=optimizer, loss=loss_func) - engine.fit(train_dataset, batch_size=None) + dist_strategy = auto.Strategy() + dist_strategy.auto_mode = "semi" + + engine = auto.Engine(laplace, + loss=loss_func, + optimizer=optimizer, + strategy=dist_strategy) + engine.fit(train_dataset, train_sample_split=2, batch_size=None) dist_context = engine.dist_context block = engine.main_program.global_block() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/iterable_dataset.py b/python/paddle/fluid/tests/unittests/auto_parallel/iterable_dataset.py index 4ca3d14f7165a26dc335f1c5e53f5d4448e7060e..7bb183c54c9383fedbf48407a7d1cb7d1be5dab7 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/iterable_dataset.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/iterable_dataset.py @@ -28,9 +28,8 @@ import paddle.utils as utils from paddle.fluid import layers from paddle.io import Dataset, IterableDataset, DataLoader from paddle.static import InputSpec -from paddle.distributed import fleet + import paddle.distributed.auto_parallel as auto -from paddle.distributed.auto_parallel.engine import Engine from paddle.optimizer.lr import CosineAnnealingDecay from paddle.fluid.dataloader.collate import default_collate_fn @@ -48,10 +47,9 @@ class_num = 10 paddle.seed(44) -class MyDataset(IterableDataset): +class MyDataset(paddle.io.IterableDataset): def __init__(self, num_samples): - super(MyDataset, self).__init__() self.num_samples = num_samples def __iter__(self): @@ -61,10 +59,9 @@ class MyDataset(IterableDataset): yield input, label -class MyDataset1(Dataset): +class MyDataset1(paddle.io.Dataset): def __init__(self, num_samples): - super(MyDataset1, self).__init__() self.num_samples = num_samples self.data = [] for i in range(self.num_samples): @@ -112,12 +109,10 @@ class MLPLayer(nn.Layer): self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") def forward(self, input): - out = auto.shard_op(self.norm, dist_attr={"process_mesh": - PP_MESH_0})(input) + out = auto.shard_op(self.norm, PP_MESH_0)(input) out = self.linear0(out) out = F.gelu(out, approximate=True) - out = auto.shard_op(self.linear1, dist_attr={"process_mesh": - PP_MESH_1})(out) + out = auto.shard_op(self.linear1, PP_MESH_1)(out) out = self.dropout(out) out = self.linear2(out) self.out = out @@ -136,54 +131,36 @@ def train(fetch): epsilon=1e-08, grad_clip=None) - inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') - labels_spec = InputSpec([batch_size], 'int64', 'label') - - dist_strategy = fleet.DistributedStrategy() - dist_strategy.semi_auto = True + dist_strategy = auto.Strategy() + dist_strategy.auto_mode = "semi" dist_strategy.split_data = True - fleet.init(is_collective=True, strategy=dist_strategy) # init engine - engine = Engine(mlp, - inputs_spec=inputs_spec, - labels_spec=labels_spec, - strategy=dist_strategy) - engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy()) - - # fetch - if fetch: - fetches = {'out': mlp.out} - else: - fetches = None + engine = auto.Engine(mlp, + loss, + optimizer, + paddle.metric.Accuracy(), + strategy=dist_strategy) # train train_dataset = MyDataset(batch_num * batch_size) - train_dataset1 = MyDataset1(batch_num) - engine.fit(train_dataset, - epochs=2, - batch_size=batch_size, - steps_per_epoch=batch_num, - fetches=fetches) - - engine.fit(train_dataset1, - epochs=2, - batch_size=None, - steps_per_epoch=batch_num, - fetches=fetches) + engine.fit(train_dataset, epochs=2, batch_size=batch_size) + + train_dataset1 = MyDataset1(batch_size * batch_num) + engine.fit(train_dataset1, epochs=2, batch_size=None) # eval eval_dataset = MyDataset(batch_size) - engine.evaluate(eval_dataset, batch_size, fetches=fetches) + engine.evaluate(eval_dataset, batch_size=batch_size) # predict test_dataset = MyDataset(batch_size) - engine.predict(test_dataset, batch_size, fetches=fetches) + engine.predict(test_dataset, batch_size=batch_size) # save temp_dir = tempfile.TemporaryDirectory() model_filename = os.path.join(temp_dir.name, 'mlp_inf') - engine.save(model_filename, training=False, mode='predict') + engine.save(model_filename, training=False) temp_dir.cleanup() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py index 8e058d16b87b369da8a90ef4a41c1159d1f64d0b..a245329a93a956ba9d5815ec1a51c17fd5d6025d 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/optimization_tuner_api.py @@ -27,10 +27,8 @@ import paddle.nn.functional as F import paddle.utils as utils from paddle.fluid import layers from paddle.io import Dataset, IterableDataset, DataLoader -from paddle.static import InputSpec -from paddle.distributed import fleet + import paddle.distributed.auto_parallel as auto -from paddle.distributed.auto_parallel.engine import Engine from engine_api_dp import MyDataset paddle.enable_static() @@ -43,20 +41,6 @@ class_num = 10 paddle.seed(44) -# class MyDataset(Dataset): - -# def __init__(self, num_samples): -# super(MyDataset, self).__init__() -# self.num_samples = num_samples - -# def __getitem__(self, index): -# input = np.random.uniform(size=image_size).astype("float32") -# label = np.random.randint(0, class_num - 1, dtype="int64") -# return input, label - -# def __len__(self): -# return self.num_samples - class MLPLayer(nn.Layer): @@ -107,50 +91,33 @@ def train(fetch): epsilon=1e-08, grad_clip=None) - inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x') - labels_spec = InputSpec([batch_size], 'int64', 'label') - - dist_strategy = fleet.DistributedStrategy() - dist_strategy.amp = False - dist_strategy.pipeline = False - dist_strategy.recompute = False - # init parallel optimizer - dist_strategy.semi_auto = True - dist_strategy.sharding = True - dist_strategy.sharding_configs = { - "sharding_degree": 2, - "stage": 3, - "enable_tuning": True, - } - fleet.init(is_collective=True, strategy=dist_strategy) - - # init engine - import tempfile - tmp_dir = tempfile.TemporaryDirectory() - dataset = MyDataset(batch_num * batch_size) - + dist_strategy = auto.Strategy() + dist_strategy.auto_mode = "semi" + # sharding config + sharding = dist_strategy.sharding + sharding.enable = True + sharding.sharding_degree = 2 + sharding.stage = 3 + sharding.enable_tuning = True + sharding.tuning_range = [0, 1, 2, 3] # Tuning configuration - tuning_config = { - "batch_size": batch_size, - "dataset": dataset, - "profile_start_step": 1, - "profile_end_step": 5, - "run_after_tuning": True, - "sharding": { - "stage_range": [0, 1, 2, 3] - }, - "verbose": True, - } - engine = Engine(mlp, - inputs_spec=inputs_spec, - labels_spec=labels_spec, - strategy=dist_strategy, - user_tuning_config=tuning_config) - engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy()) + tuning = dist_strategy.tuning + tuning.enable = True + tuning.profile_start_step = 1 + tuning.profile_end_step = 5 + tuning.run_after_tuning = True + tuning.verbose = True + + dataset = MyDataset(batch_num * batch_size) + engine = auto.Engine(mlp, + loss, + optimizer, + paddle.metric.Accuracy(), + strategy=dist_strategy) + engine._tune(dataset, batch_size=batch_size) # check tuned - assert (engine._dist_contexts['train'].strategy.sharding_configs['stage'] != - 3) + assert (engine._dist_contexts['train'].strategy.sharding.stage != 3) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..271752deca077099e7773cb6a200381685791e98 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import sys +import random +import numpy as np +import paddle + +import paddle.distributed.auto_parallel as auto +from paddle.fluid.dygraph.parallel import ParallelEnv +from get_gpt_model import generate_model, create_data_holder, FakeDataset + + +def apply_pass(use_recompute=False): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True + if use_recompute: + recompute = strategy.recompute + recompute.enable = True + return strategy + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class TestRecomputePass(unittest.TestCase): + + def setUp(self): + self.rtol = 1e-6 + self.atol = 1e-8 + self.batch_size = 1 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2022) + np.random.seed(2022) + random.seed(2022) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, use_recompute=False): + reset_prog() + + strategy = apply_pass(use_recompute) + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("mp") + + engine = auto.Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_results(self, ref_losses, check_losses): + np.testing.assert_allclose( + ref_losses, + check_losses, + rtol=self.rtol, + atol=self.atol, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, ref_losses, check_losses, ref_losses - check_losses)) + + def test_recompute_pass(self): + # mp2 training + mp_engine = self.get_engine() + mp_losses = mp_engine.fit(self.dataset, 3, batch_size=self.batch_size) + mp_losses = np.array(mp_losses["loss"]) + + # mp2 recompute training + rc_engine = self.get_engine(True) + rc_losses = rc_engine.fit(self.dataset, 3, batch_size=self.batch_size) + rc_losses = np.array(rc_losses["loss"]) + self.check_results(mp_losses, rc_losses) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py new file mode 100644 index 0000000000000000000000000000000000000000..70dfd5f87df9976790a40512f9b931f50f3ee06c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py @@ -0,0 +1,116 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import sys +import random +import numpy as np +import paddle + +import paddle.distributed.auto_parallel as auto +from paddle.fluid.dygraph.parallel import ParallelEnv +from get_gpt_model import generate_model, create_data_holder, FakeDataset + +paddle.enable_static() + + +def apply_pass(use_sharding=False, stage=None): + strategy = auto.Strategy() + strategy.auto_mode = "semi" + strategy.reinit = True + if use_sharding: + sharding = strategy.sharding + sharding.enable = True + sharding.sharding_degree = 2 + sharding.stage = 1 + + return strategy + + +def reset_prog(): + paddle.fluid.framework.switch_main_program(paddle.static.Program()) + paddle.fluid.framework.switch_startup_program(paddle.static.Program()) + + +class TestShardingPass(unittest.TestCase): + + def setUp(self): + self.rtol = 1e-6 + self.atol = 1e-8 + self.batch_size = 2 + self.batch_num = 10 + self.clip_norm = 0.2 + self.dataset = FakeDataset(self.batch_size * self.batch_num) + + def init(self, engine): + paddle.seed(2022) + np.random.seed(2022) + random.seed(2022) + place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) + engine._executor = paddle.static.Executor(place) + + def get_engine(self, use_sharding=False, stage=None): + reset_prog() + + strategy = apply_pass(use_sharding, stage) + clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm) + opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip) + model, loss = generate_model("dp") + + engine = auto.Engine(model, loss, opt, strategy=strategy) + self.init(engine) + return engine + + def check_results(self, ref_losses, check_losses): + np.testing.assert_allclose( + ref_losses, + check_losses, + rtol=self.rtol, + atol=self.atol, + err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( + __class__, ref_losses, check_losses, ref_losses - check_losses)) + + def test_sharding_pass(self): + # dp2 training + dp_engine = self.get_engine() + dp_losses = dp_engine.fit(self.dataset, 3, batch_size=self.batch_size) + dp_losses = np.array(dp_losses["loss"]) + + # sharding2 stage1 training + sharding1_engine = self.get_engine(True, 1) + sharding1_losses = sharding1_engine.fit(self.dataset, + 3, + batch_size=self.batch_size) + sharding1_losses = np.array(sharding1_losses["loss"]) + self.check_results(dp_losses, sharding1_losses) + + # sharding2 stage2 training + sharding2_engine = self.get_engine(True, 2) + sharding2_losses = sharding2_engine.fit(self.dataset, + 3, + batch_size=self.batch_size) + sharding2_losses = np.array(sharding2_losses["loss"]) + self.check_results(dp_losses, sharding2_losses) + + # sharding2 stage3 training + sharding3_engine = self.get_engine(True, 3) + sharding3_losses = sharding3_engine.fit(self.dataset, + 3, + batch_size=self.batch_size) + sharding3_losses = np.array(sharding3_losses["loss"]) + self.check_results(dp_losses, sharding3_losses) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py index 0fbe4f5bd3d0950c232702671b5cfe5f1963cbb8..d797df3b8ad156989354eefcccc9ec402c32295e 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_base_cost.py @@ -45,9 +45,10 @@ from test_cluster import cluster_json paddle.enable_static() _global_parallel_strategy = "dp_mp_pp" -_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]]) -PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]]) -PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]]) +_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]], + dim_names=["x", "y", "z"]) +PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"]) +PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"]) class MLPLayer(nn.Layer): @@ -74,16 +75,8 @@ class MLPLayer(nn.Layer): self.norm = nn.LayerNorm(d_model, epsilon=1e-5) def forward(self, input): - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [-1, 1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "y"]) + auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None]) out = self.norm(input) out = self.linear0(out) @@ -111,16 +104,8 @@ def mlp_forward(train_program, start_program): embedding = paddle.nn.Embedding(10, hidden_size, sparse=True) embedding_out = embedding(fill_constant_out) - auto.shard_tensor(input, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [0, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(input, PP_MESH_0, ["x", None]) + auto.shard_tensor(label, PP_MESH_1, ["x", None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py index 62d87fcc191ad81ecb54936adf379d84037b3d34..5a8e59b2969b08f3af7c2482bca98a79a876122b 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py @@ -34,7 +34,10 @@ paddle.enable_static() batch_size = 4 hidden_size = 1024 sequence_len = 512 -_g_process_mesh = [[0, 1], [2, 3]] +_g_process_mesh = [ + auto.ProcessMesh([0, 1], dim_names=["x"]), + auto.ProcessMesh([2, 3], dim_names=["x"]) +] def get_random_inputs_and_labels(input_shape, label_shape): @@ -82,18 +85,10 @@ class MLPLayer(nn.Layer): def forward(self, input): out = self.norm(input) - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _g_process_mesh[0], - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.linear0.weight, _g_process_mesh[0], [None, "x"]) out = self.linear0(out) out = F.gelu(out, approximate=True) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _g_process_mesh[1], - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.linear1.weight, _g_process_mesh[1], ["x", None]) out = self.linear1(out) return out @@ -123,16 +118,8 @@ def get_program(): dataloader.set_batch_generator(batch_generator_creator(), places=paddle.static.cuda_places()) # data dist_attr - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _g_process_mesh[0], - "dims_mapping": [0, -1, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": _g_process_mesh[0], - "dims_mapping": [0, -1, -1] - }) + auto.shard_tensor(input, _g_process_mesh[0], ["x", None, None]) + auto.shard_tensor(label, _g_process_mesh[0], ["x", None, None]) mlp_start = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_embedding.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_embedding.py index 0b81b5bd48ca5f304cee5b872eeb2d0e63aeb245..0cf5fca08acd8b105cf6263e8d38ed1dcee03cab 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_embedding.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_embedding.py @@ -42,19 +42,13 @@ def make_program_lookup_table_v1_mp_dp(): is_sparse=False) loss = paddle.fluid.layers.reduce_mean(emb_out) - auto.shard_tensor(src_ids, - dist_attr={ - "process_mesh": auto.ProcessMesh([[0, 1], [2, - 3]]), - "dims_mapping": [0, -1, -1] - }) + auto.shard_tensor( + src_ids, auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]), + ["x", None, None]) emb_weight = block.vars["emb_weight"] - auto.shard_tensor(emb_weight, - dist_attr={ - "process_mesh": auto.ProcessMesh([[0, 1], [2, - 3]]), - "dims_mapping": [1, -1] - }) + auto.shard_tensor( + emb_weight, auto.ProcessMesh([[0, 1], [2, 3]], + dim_names=["x", "y"]), ["y", None]) return main_program, start_program, loss diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_matmul.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_matmul.py index 8cf2b47660fe5cbdb44c280ab831099ead66e37a..77c6888d26e10a50e7b8449f8e890e7e1352e645 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_matmul.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_matmul.py @@ -22,82 +22,58 @@ from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr paddle.enable_static() -mesh = [[0, 1], [2, 3]] +mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]) def init_x_row(trans_x): if trans_x: x = paddle.static.data(name='x', shape=[10, 6, 8], dtype='float32') - auto.shard_tensor(x, - dist_attr={ - "process_mesh": mesh, - "dims_mapping": [0, 1, -1] - }) + auto.shard_tensor(x, mesh, ["x", "y", None]) + return x else: x = paddle.static.data(name='x', shape=[10, 8, 6], dtype='float32') - auto.shard_tensor(x, - dist_attr={ - "process_mesh": mesh, - "dims_mapping": [0, -1, 1] - }) + auto.shard_tensor(x, mesh, ["x", None, "y"]) + return x def init_x_col(trans_x): if trans_x: x = paddle.static.data(name='x', shape=[6, 8], dtype='float32') - auto.shard_tensor(x, - dist_attr={ - "process_mesh": mesh, - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(x, mesh, [None, "x"]) + return x else: x = paddle.static.data(name='x', shape=[8, 6], dtype='float32') - auto.shard_tensor(x, - dist_attr={ - "process_mesh": mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(x, mesh, ["x", None]) + return x def init_y_row(trans_y): if trans_y: y = paddle.static.data(name='y', shape=[4, 6], dtype='float32') - auto.shard_tensor(y, - dist_attr={ - "process_mesh": mesh, - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(y, mesh, [None, "y"]) + return y else: y = paddle.static.data(name='y', shape=[6, 4], dtype='float32') - auto.shard_tensor(y, - dist_attr={ - "process_mesh": mesh, - "dims_mapping": [1, -1] - }) + auto.shard_tensor(y, mesh, ["y", None]) + return y def init_y_col(trans_y): if trans_y: y = paddle.static.data(name='y', shape=[4, 6], dtype='float32') - auto.shard_tensor(y, - dist_attr={ - "process_mesh": mesh, - "dims_mapping": [1, -1] - }) + auto.shard_tensor(y, mesh, ["y", None]) + return y else: y = paddle.static.data(name='y', shape=[6, 4], dtype='float32') - auto.shard_tensor(y, - dist_attr={ - "process_mesh": mesh, - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(y, mesh, [None, "y"]) + return y diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py index 734bd7acf9dec1312ce3137cabb1586b97efa095..cf220a2049a31d1e96606f83ead9541aaed5d0cc 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_op_cost.py @@ -71,11 +71,8 @@ class TestDistOpCost(unittest.TestCase): shape=[4, 1], dtype='float32') label.stop_gradient = True - auto.shard_tensor(x, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [0, -1] - }) + auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x", None]) tmp = paddle.fluid.layers.fill_constant_batch_size_like( input=x, shape=[2, 8], value=1, dtype='float32') weight_attr = paddle.ParamAttr() @@ -121,17 +118,12 @@ class TestDistOpCost(unittest.TestCase): shape=[8, 1], dtype='float32') label.stop_gradient = True - auto.shard_tensor(x, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [0] - }) + auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x"]) auto.shard_tensor(label, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [0, -1] - }) + auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x", None]) # embedding tmp = paddle.fluid.layers.fill_constant_batch_size_like( input=x, shape=[4], value=1, dtype='int32') @@ -141,12 +133,9 @@ class TestDistOpCost(unittest.TestCase): for op in main_program.global_block().ops: if op.type == "lookup_table_v2": W = main_program.global_block().vars[op.input("W")[0]] - auto.shard_tensor(W, - dist_attr={ - "process_mesh": - auto.ProcessMesh([0, 1]), - "dims_mapping": [0, -1] - }) + auto.shard_tensor( + W, auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x", None]) out = paddle.fluid.layers.transpose(out, [1, 0]) # [8, 2] [-1, 0] @@ -154,26 +143,20 @@ class TestDistOpCost(unittest.TestCase): param1 = paddle.fluid.layers.create_parameter( [4, 8], paddle.float32) # [2, 8] [0, -1] auto.shard_tensor(param1, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [0, -1] - }) + auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x", None]) param2 = paddle.fluid.layers.create_parameter( [8, 8], paddle.float32) # [8, 4] [-1, 0] auto.shard_tensor(param2, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [-1, 0] - }) + auto.ProcessMesh([0, 1], dim_names=["x"]), + [None, "x"]) out1 = paddle.fluid.layers.matmul(out, param1) # [8, 8] [-1, -1] tmp_param = paddle.fluid.layers.create_parameter( [8, 8], paddle.float32) # [8, 8] [-1, -1] auto.shard_tensor(param2, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [-1, -1] - }) + auto.ProcessMesh([0, 1], dim_names=["x"]), + [None, None]) tmp_out = paddle.fluid.layers.matmul(out1, tmp_param) out2 = paddle.fluid.layers.matmul(tmp_out, param2) # [8, 4] [-1, 0] @@ -227,17 +210,12 @@ class TestDistOpCost(unittest.TestCase): shape=[8, 1], dtype='float32') label.stop_gradient = True - auto.shard_tensor(x, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [0] - }) + auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x"]) auto.shard_tensor(label, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [0, -1] - }) + auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x", None]) # embedding tmp = paddle.fluid.layers.fill_constant_batch_size_like( input=x, shape=[4], value=1, dtype='int32') @@ -247,12 +225,9 @@ class TestDistOpCost(unittest.TestCase): for op in main_program.global_block().ops: if op.type == "lookup_table_v2": W = main_program.global_block().vars[op.input("W")[0]] - auto.shard_tensor(W, - dist_attr={ - "process_mesh": - auto.ProcessMesh([0, 1]), - "dims_mapping": [0, -1] - }) + auto.shard_tensor( + W, auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x", None]) out = paddle.fluid.layers.transpose(out, [1, 0]) # [8, 2] [-1, 0] @@ -260,25 +235,20 @@ class TestDistOpCost(unittest.TestCase): param1 = paddle.fluid.layers.create_parameter( [4, 8], paddle.float32) # [2, 8] [0, -1] auto.shard_tensor(param1, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [0, -1] - }) + auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x", None]) param2 = paddle.fluid.layers.create_parameter( [8, 8], paddle.float32) # [8, 4] [-1, 0] auto.shard_tensor(param2, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [-1, 0] - }) + auto.ProcessMesh([0, 1], dim_names=["x"]), + [None, "x"]) out1 = paddle.matmul(out, param1) # [8, 8] [-1, -1] tmp_param = paddle.fluid.layers.create_parameter( [8, 8], paddle.float32) # [8, 8] [-1, -1] auto.shard_tensor(param2, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [-1, -1] - }) + auto.ProcessMesh([0, 1], dim_names=["x"]), + [None, None]) + tmp_out = paddle.matmul(out1, tmp_param) out2 = paddle.matmul(tmp_out, param2) # [8, 4] [-1, 0] @@ -331,17 +301,11 @@ class TestDistOpCost(unittest.TestCase): shape=[8, 1], dtype='float32') label.stop_gradient = True - auto.shard_tensor(x, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [0] - }) - + auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x"]) auto.shard_tensor(label, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [0, -1] - }) + auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x", None]) # embedding tmp = paddle.fluid.layers.fill_constant_batch_size_like( input=x, shape=[4], value=1, dtype='int32') @@ -351,12 +315,9 @@ class TestDistOpCost(unittest.TestCase): for op in main_program.global_block().ops: if op.type == "lookup_table_v2": W = main_program.global_block().vars[op.input("W")[0]] - auto.shard_tensor(W, - dist_attr={ - "process_mesh": - auto.ProcessMesh([0, 1]), - "dims_mapping": [0, -1] - }) + auto.shard_tensor( + W, auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x", None]) out = paddle.fluid.layers.transpose(out, [1, 0]) # [8, 2] [-1, 0] @@ -364,25 +325,21 @@ class TestDistOpCost(unittest.TestCase): param1 = paddle.fluid.layers.create_parameter( [4, 8], paddle.float32) # [2, 8] [0, -1] auto.shard_tensor(param1, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [0, -1] - }) + auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x", None]) param2 = paddle.fluid.layers.create_parameter( [8, 8], paddle.float32) # [8, 4] [-1, 0] auto.shard_tensor(param2, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [-1, 0] - }) + auto.ProcessMesh([0, 1], dim_names=["x"]), + [None, "x"]) + out1 = paddle.fluid.layers.mul(out, param1) # [8, 8] [-1, -1] tmp_param = paddle.fluid.layers.create_parameter( [8, 8], paddle.float32) # [8, 8] [-1, -1] auto.shard_tensor(param2, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [-1, -1] - }) + auto.ProcessMesh([0, 1], dim_names=["x"]), + [None, None]) + tmp_out = paddle.fluid.layers.mul(out1, tmp_param) out2 = paddle.fluid.layers.mul(tmp_out, param2) # [8, 4] [-1, 0] diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py index dfddba3dda1c96891a95774acd0adf28abcc9c68..14783dd89115220a1046e9271e8dee15624173af 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_pnorm.py @@ -29,11 +29,8 @@ def make_program_dp2(): with paddle.static.program_guard(main_program, start_program): x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32') x.stop_gradient = False - auto.shard_tensor(x, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [0, -1, -1] - }) + auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x", None, None]) tmp_0 = paddle.norm(x, p=2) return main_program, start_program, tmp_0 @@ -44,11 +41,8 @@ def make_program_serial(): with paddle.static.program_guard(main_program, start_program): x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32') x.stop_gradient = False - auto.shard_tensor(x, - dist_attr={ - "process_mesh": auto.ProcessMesh([0]), - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(x, auto.ProcessMesh([0], dim_names=["x"]), + [None, None, None]) tmp_0 = paddle.norm(x, p=2) return main_program, start_program, tmp_0 diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_reshape.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_reshape.py index 60b43ef9fe3bc28410691e16f6575f27faa4e9fb..e563e7554e905cada101b94ad15d6ff756d61cc4 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_reshape.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_reshape.py @@ -29,11 +29,9 @@ def make_program_dp2(): with paddle.static.program_guard(main_program, start_program): x = paddle.static.data(name='x', shape=[4, 4, 8], dtype='float32') x.stop_gradient = False - auto.shard_tensor(x, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [0, -1, -1] - }) + auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x", None, None]) + tmp_0 = paddle.reshape(x, shape=[0, 0, 4, 2]) tmp_1 = paddle.reshape(tmp_0, shape=[0, 0, 8]) tmp_2 = tmp_1.reshape((tmp_1.shape[0], tmp_1.shape[1], -1)) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py index e12fd0f922a5e85400e2c5069045ff3d1abf691d..a1098899e3c535a3e62e2b9c6d20a6aceca625b3 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_slice.py @@ -25,11 +25,9 @@ def make_program_dp2(): start_program = paddle.fluid.Program() with paddle.static.program_guard(main_program, start_program): x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32') - auto.shard_tensor(x, - dist_attr={ - "process_mesh": auto.ProcessMesh([0, 1]), - "dims_mapping": [0, -1, -1] - }) + auto.shard_tensor(x, auto.ProcessMesh([0, 1], dim_names=["x"]), + ["x", None, None]) + tmp_0 = x[0] tmp_1 = x[:, 0, :] tmp_2 = x[:, :, 1] @@ -42,11 +40,9 @@ def make_program_serial(): start_program = paddle.fluid.Program() with paddle.static.program_guard(main_program, start_program): x = paddle.static.data(name='x', shape=[4, 5, 6], dtype='float32') - auto.shard_tensor(x, - dist_attr={ - "process_mesh": auto.ProcessMesh([0]), - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(x, auto.ProcessMesh([0], dim_names=["x"]), + [None, None, None]) + tmp_0 = x[0] tmp_1 = x[:, 0, :] tmp_2 = x[:, :, 1] diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_interface.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..6f0b73d83a7447033aff515e83f02be52d45a388 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_interface.py @@ -0,0 +1,224 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle +import paddle.fluid as fluid +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.static as static +import paddle.distributed as dist +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context +from paddle.distributed.auto_parallel.process_mesh import ProcessMesh +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr + +paddle.enable_static() + +batch_size = 4 +epoch_num = 10 +hidden_size = 1024 +sequence_len = 512 +process_mesh1 = ProcessMesh(mesh=[[0, 1, 2, 3], [4, 5, 6, 7]], + dim_names=["x", "y"]) +process_mesh2 = ProcessMesh(mesh=[0, 1, 2, 3], dim_names=["x"]) + + +class MLPLayer(nn.Layer): + + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + param_initializer = nn.initializer.Normal(mean=0.0, + std=initializer_range) + + self.linear0 = nn.Linear( + d_model, + dim_feedforward, + weight_attr=paddle.ParamAttr(initializer=param_initializer), + bias_attr=None) + self.linear1 = nn.Linear( + dim_feedforward, + d_model, + weight_attr=paddle.ParamAttr(initializer=param_initializer), + bias_attr=None) + + def forward(self, input): + auto.shard_tensor(self.linear0.weight, process_mesh1[0], [None, "y"]) + linear0 = auto.shard_op(self.linear0, process_mesh1, + [["y", None, None]], [[None, "x", None]]) + linear0_out = linear0(input) + + gelu = auto.shard_op(F.gelu, process_mesh1, [["y", "x", None], None]) + gelu_out = gelu(linear0_out, approximate=True) + + auto.shard_tensor(self.linear1.weight, shard_spec=["y", None]) + linear1 = auto.shard_op(self.linear1, + process_mesh1[1], + out_shard_specs=[["y", None, None]]) + linear1_out = linear1(gelu_out) + + return self.linear0, self.linear1, linear0_out, gelu_out, linear1_out + + +class TestAutoParallelAPI(unittest.TestCase): + + def test_api(self): + # input + input = static.data(name="input", + shape=[batch_size, sequence_len, hidden_size], + dtype='float32') + label = static.data(name="label", + shape=[batch_size, sequence_len, 1], + dtype='float32') + + auto.shard_tensor(input, process_mesh1, ["x", None, None]) + auto.shard_tensor(label, process_mesh1, ["y", None, None]) + + mlp = MLPLayer(hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + + with ProcessMesh(process_mesh1.mesh, process_mesh1.dim_names): + linear0, linear1, linear0_out, gelu_out, linear1_out = mlp(input) + + default_program = paddle.fluid.default_main_program() + default_dist_context = get_default_distributed_context() + + self.assertEqual(len(default_program.blocks[0].ops), 5) + matmul0 = default_program.blocks[0].ops[0] + self.assertEqual(matmul0.type, "matmul_v2") + ewise_add0 = default_program.blocks[0].ops[1] + self.assertEqual(ewise_add0.type, "elementwise_add") + gelu = default_program.blocks[0].ops[2] + self.assertEqual(gelu.type, "gelu") + matmul1 = default_program.blocks[0].ops[3] + self.assertEqual(matmul1.type, "matmul_v2") + ewise_add1 = default_program.blocks[0].ops[4] + self.assertEqual(ewise_add1.type, "elementwise_add") + + dist_input = default_dist_context.get_dist_tensor_for_program(input) + self.assertEqual(dist_input.dist_attr.process_mesh, process_mesh1) + self.assertEqual(dist_input.dist_attr.dims_mapping, [0, -1, -1]) + self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh")) + self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping")) + + dist_input = default_dist_context.get_dist_tensor_for_program(label) + self.assertEqual(dist_input.dist_attr.process_mesh, process_mesh1) + self.assertEqual(dist_input.dist_attr.dims_mapping, [1, -1, -1]) + self.assertTrue(dist_input.dist_attr.is_annotated("process_mesh")) + self.assertTrue(dist_input.dist_attr.is_annotated("dims_mapping")) + + dist_linear0_weight = default_dist_context.get_dist_tensor_for_program( + linear0.weight) + self.assertEqual(dist_linear0_weight.dist_attr.process_mesh, + process_mesh1[0]) + self.assertEqual(dist_linear0_weight.dist_attr.dims_mapping, [-1, 0]) + self.assertTrue( + dist_linear0_weight.dist_attr.is_annotated("process_mesh")) + self.assertTrue( + dist_linear0_weight.dist_attr.is_annotated("dims_mapping")) + + dist_linear1_weight = default_dist_context.get_dist_tensor_for_program( + linear1.weight) + self.assertEqual(dist_linear1_weight.dist_attr.process_mesh, + process_mesh1) + self.assertEqual(dist_linear1_weight.dist_attr.dims_mapping, [1, -1]) + self.assertTrue( + dist_linear1_weight.dist_attr.is_annotated("process_mesh")) + self.assertTrue( + dist_linear1_weight.dist_attr.is_annotated("dims_mapping")) + + dist_linear1_out = default_dist_context.get_dist_tensor_for_program( + linear1_out) + self.assertEqual(dist_linear1_out.dist_attr.process_mesh, process_mesh1) + self.assertEqual(dist_linear1_out.dist_attr.dims_mapping, [-1, -1, -1]) + self.assertTrue(dist_linear1_out.dist_attr.is_annotated("process_mesh")) + self.assertFalse( + dist_linear1_out.dist_attr.is_annotated("dims_mapping")) + + dist_op = default_dist_context.get_dist_op_for_program(matmul0) + self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1) + self.assertEqual(dist_op.dist_attr.impl_type, "default") + self.assertEqual(dist_op.dist_attr.impl_idx, 0) + self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh")) + tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(input.name) + self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1) + self.assertEqual(tensor_dist_attr.dims_mapping, [1, -1, -1]) + self.assertTrue(tensor_dist_attr.is_annotated("process_mesh")) + self.assertTrue(tensor_dist_attr.is_annotated("dims_mapping")) + + dist_op = default_dist_context.get_dist_op_for_program(ewise_add0) + self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1) + self.assertEqual(dist_op.dist_attr.impl_type, "default") + self.assertEqual(dist_op.dist_attr.impl_idx, 0) + tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr( + linear0_out.name) + self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1) + self.assertEqual(tensor_dist_attr.dims_mapping, [-1, 0, -1]) + self.assertTrue(tensor_dist_attr.is_annotated("process_mesh")) + self.assertTrue(tensor_dist_attr.is_annotated("dims_mapping")) + self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh")) + + dist_op = default_dist_context.get_dist_op_for_program(gelu) + self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1) + self.assertEqual(dist_op.dist_attr.impl_type, "default") + self.assertEqual(dist_op.dist_attr.impl_idx, 0) + self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh")) + tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr( + linear0_out.name) + self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1) + self.assertEqual(tensor_dist_attr.dims_mapping, [1, 0, -1]) + self.assertTrue(tensor_dist_attr.is_annotated("process_mesh")) + self.assertTrue(tensor_dist_attr.is_annotated("dims_mapping")) + tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(gelu_out.name) + self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1) + self.assertEqual(tensor_dist_attr.dims_mapping, [-1, -1, -1]) + self.assertTrue(tensor_dist_attr.is_annotated("process_mesh")) + self.assertFalse(tensor_dist_attr.is_annotated("dims_mapping")) + + dist_op = default_dist_context.get_dist_op_for_program(matmul1) + self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1[1]) + self.assertEqual(dist_op.dist_attr.impl_type, "default") + self.assertEqual(dist_op.dist_attr.impl_idx, 0) + self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh")) + tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(gelu_out.name) + self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1[1]) + self.assertEqual(tensor_dist_attr.dims_mapping, [-1, -1, -1]) + self.assertTrue(tensor_dist_attr.is_annotated("process_mesh")) + self.assertFalse(tensor_dist_attr.is_annotated("dims_mapping")) + + dist_op = default_dist_context.get_dist_op_for_program(ewise_add1) + self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1[1]) + self.assertEqual(dist_op.dist_attr.impl_type, "default") + self.assertEqual(dist_op.dist_attr.impl_idx, 0) + self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh")) + tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr( + linear1_out.name) + self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1[1]) + self.assertEqual(tensor_dist_attr.dims_mapping, [0, -1, -1]) + self.assertTrue(tensor_dist_attr.is_annotated("process_mesh")) + self.assertTrue(tensor_dist_attr.is_annotated("dims_mapping")) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py index 35301e448959873c1c1c3cc3f59f6698560ec251..c0ff991ca52fe7f435eece48a832b4769fbc5caa 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_lr_grad_clip.py @@ -26,7 +26,6 @@ import paddle.distributed.fleet as fleet from paddle.io import Dataset from paddle.static import InputSpec from paddle.fluid.framework import _non_static_mode -from paddle.distributed.auto_parallel.engine import Engine from test_to_static import MLPLayer, MyDataset @@ -60,15 +59,13 @@ class TestEngineBase(unittest.TestCase): self.dataset = MyDataset(self.batch_num * self.batch_size) def init_engine(self): - inputs = InputSpec([self.batch_size, self.hidden_size], 'float32', 'x') - labels = InputSpec([self.batch_size], 'int64', 'label') + # inputs = InputSpec([self.batch_size, self.hidden_size], 'float32', 'x') + # labels = InputSpec([self.batch_size], 'int64', 'label') - self.engine = Engine(model=self.mlp, - inputs_spec=inputs, - labels_spec=labels) - self.engine.prepare(optimizer=self.optimizer, - loss=self.loss, - metrics=paddle.metric.Accuracy()) + self.engine = auto.Engine(model=self.mlp, + loss=self.loss, + optimizer=self.optimizer, + metrics=paddle.metric.Accuracy()) class TestLRScheduler(TestEngineBase): @@ -80,9 +77,9 @@ class TestLRScheduler(TestEngineBase): def test_lr_scheduler(self): self.init_engine() - lr = self.engine._optimizer._learning_rate - assert isinstance(lr, paddle.optimizer.lr.LRScheduler) self.engine.fit(self.dataset, batch_size=self.batch_size) + lr = self.engine._lr_optimizer._learning_rate + assert isinstance(lr, paddle.optimizer.lr.LRScheduler) class TestGradClipByGlobalNorm(TestEngineBase): @@ -94,7 +91,6 @@ class TestGradClipByGlobalNorm(TestEngineBase): def test_grad_clip(self): - clip = self.engine._optimizer._grad_clip self.engine.fit(self.dataset, batch_size=self.batch_size) self.check_program() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_amp.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_amp.py new file mode 100644 index 0000000000000000000000000000000000000000..ed2cf0328e85c5a12f4773ef938a31adf9fd1f4d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_amp.py @@ -0,0 +1,49 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +import os +import sys +import shutil +import subprocess +from paddle.distributed.fleet.launch_utils import run_with_coverage + + +class TestAMPPass(unittest.TestCase): + + def test_mp2(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join(file_dir, "amp_pass_unittest.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = [sys.executable, "-u"] + coverage_args + [ + "-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir", + tmp_dir.name, launch_model_path + ] + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_grad_clip.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_grad_clip.py similarity index 100% rename from python/paddle/fluid/tests/unittests/auto_parallel/test_grad_clip.py rename to python/paddle/fluid/tests/unittests/auto_parallel/test_pass_grad_clip.py diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_gradient_merge.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_gradient_merge.py new file mode 100644 index 0000000000000000000000000000000000000000..e55ddbea58336665bb39aaa96ab9426a84f54fd1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_gradient_merge.py @@ -0,0 +1,50 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +import os +import sys +import shutil +import subprocess +from paddle.distributed.fleet.launch_utils import run_with_coverage + + +class TestGradientMergePass(unittest.TestCase): + + def test_dp2(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join(file_dir, + "gradient_merge_pass_unittest.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = [sys.executable, "-u"] + coverage_args + [ + "-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir", + tmp_dir.name, launch_model_path + ] + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_quantization.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..ff96f43a928a9e155903d7fc20c3a77bdbde5be7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_quantization.py @@ -0,0 +1,98 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import sys +import random +import numpy as np +import paddle + +import paddle.distributed.auto_parallel as auto +from get_gpt_model import generate_model, create_data_holder, FakeDataset + +paddle.enable_static() + + +def apply_pass(): + dist_strategy = auto.Strategy() + dist_strategy.auto_mode = "semi" + qat = dist_strategy.qat + qat.enable = True + qat.channel_wise_abs_max = True + qat.weight_bits = 8 + qat.activation_bits = 8 + qat.not_quant_pattern = ['skip_quant'] + return dist_strategy + + +class TestQuantizationPass(unittest.TestCase): + + def test_qat_pass(self): + + batch_size = 8 + batch_num = 10 + + strategy = apply_pass() + model, loss = generate_model("serial") + opt = paddle.optimizer.AdamW(learning_rate=0.00001) + engine = auto.Engine(model, loss, opt, strategy=strategy) + dataset = FakeDataset(batch_size * batch_num) + engine.fit(dataset, 3, batch_size=batch_size) + + self.check_program(engine.main_program) + + def check_program(self, program): + + quantizable_op_and_inputs = {'matmul_v2': ['X', 'Y']} + quantizable_grad_op_inputs = {'matmul_v2_grad': ['X', 'Y']} + + quantized_ops = set() + for block in program.blocks: + for op in block.ops: + is_quntized = False + if op.type in quantizable_op_and_inputs: + for arg_name in op.input_arg_names: + if ".quantized" in arg_name: + is_quntized = True + + if not is_quntized: + continue + + # check forward + if op.type in quantizable_op_and_inputs: + for arg_name in op.input_arg_names: + assert arg_name.endswith('.quantized.dequantized') + quantized_ops.add(arg_name) + + for op in block.ops: + is_quntized = False + if op.type in quantizable_grad_op_inputs: + for pname in quantizable_grad_op_inputs[op.type]: + arg_name = op.input(pname)[0] + if ".quantized" in arg_name: + is_quntized = True + + if not is_quntized: + continue + + # check backward + if op.type in quantizable_grad_op_inputs: + for pname in quantizable_grad_op_inputs[op.type]: + arg_name = op.input(pname)[0] + assert arg_name.endswith('.quantized.dequantized') + assert arg_name in quantized_ops + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_recompute.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_recompute.py new file mode 100644 index 0000000000000000000000000000000000000000..e7eb7ddd2a604b4ff252dd82fdb6b2f798dd1e23 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_recompute.py @@ -0,0 +1,49 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +import os +import sys +import shutil +import subprocess +from paddle.distributed.fleet.launch_utils import run_with_coverage + + +class TestRecomputePass(unittest.TestCase): + + def test_mp2(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join(file_dir, "recompute_pass_unittest.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = [sys.executable, "-u"] + coverage_args + [ + "-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir", + tmp_dir.name, launch_model_path + ] + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_sharding.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..77e969c83bf812fffee7ac1d70dd543984495bfd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_sharding.py @@ -0,0 +1,49 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +import os +import sys +import shutil +import subprocess +from paddle.distributed.fleet.launch_utils import run_with_coverage + + +class TestShardingPass(unittest.TestCase): + + def test_dp2sharding2(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join(file_dir, "sharding_pass_unittest.py") + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = [sys.executable, "-u"] + coverage_args + [ + "-m", "paddle.distributed.launch", "--devices", "0,1", "--log_dir", + tmp_dir.name, launch_model_path + ] + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..4232d64071e14e5cb2640ac2ff7cc93ddd4f87a8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh.py @@ -0,0 +1,139 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.static as static +import paddle.distributed.auto_parallel as auto +from paddle.distributed.auto_parallel.process_mesh import ProcessMesh +from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context +from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr + +paddle.enable_static() + +batch_size = 4 +epoch_num = 10 +hidden_size = 1024 +sequence_len = 512 + + +class MLPLayer(nn.Layer): + + def __init__(self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + param_initializer = nn.initializer.Normal(mean=0.0, + std=initializer_range) + + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.linear0 = nn.Linear( + d_model, + dim_feedforward, + weight_attr=paddle.ParamAttr(initializer=param_initializer), + bias_attr=None) + self.linear1 = nn.Linear( + dim_feedforward, + d_model, + weight_attr=paddle.ParamAttr(initializer=param_initializer), + bias_attr=None) + + def forward(self, input): + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + return out + + +class TestProcessMesh(unittest.TestCase): + + def test_construction(self): + mesh = [[0, 1, 2], [3, 4, 5]] + process_mesh = ProcessMesh(mesh, dim_names=["x", "y"]) + self.assertEqual(process_mesh.shape, [2, 3]) + self.assertEqual(process_mesh.process_ids, [0, 1, 2, 3, 4, 5]) + self.assertEqual(process_mesh.dim_names, ["x", "y"]) + self.assertEqual(process_mesh.ndim, 2) + self.assertEqual(process_mesh, process_mesh) + self.assertEqual(str(process_mesh), str(process_mesh)) + + sub_process_mesh1 = process_mesh[0] + self.assertEqual(sub_process_mesh1.shape, [3]) + self.assertEqual(sub_process_mesh1.process_ids, [0, 1, 2]) + self.assertEqual(sub_process_mesh1.dim_names, ["y"]) + self.assertEqual(sub_process_mesh1.ndim, 1) + + sub_process_mesh2 = process_mesh[:, 1] + self.assertEqual(sub_process_mesh2.shape, [2]) + self.assertEqual(sub_process_mesh2.process_ids, [1, 4]) + self.assertEqual(sub_process_mesh2.dim_names, ["x"]) + self.assertEqual(sub_process_mesh2.ndim, 1) + + sub_process_mesh3 = sub_process_mesh2[:] + self.assertEqual(sub_process_mesh3.shape, [2]) + self.assertEqual(sub_process_mesh3.process_ids, [1, 4]) + self.assertEqual(sub_process_mesh3.dim_names, ["x"]) + self.assertEqual(sub_process_mesh3.ndim, 1) + + sub_process_mesh4 = process_mesh[1, 1] + self.assertEqual(sub_process_mesh4.shape, [1]) + self.assertEqual(sub_process_mesh4.process_ids, [4]) + self.assertEqual(sub_process_mesh4.dim_names, ["d0"]) + self.assertEqual(sub_process_mesh4.ndim, 1) + + def test_context_manager(self): + mesh = np.array([1, 2, 3, 4]) + input = static.data(name="input", + shape=[batch_size, sequence_len, hidden_size], + dtype='float32') + label = static.data(name="label", + shape=[batch_size, sequence_len, 1], + dtype='float32') + + mlp = MLPLayer(hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + + with ProcessMesh(mesh, "d"): + out = mlp(input) + + default_program = paddle.fluid.default_main_program() + default_dist_context = get_default_distributed_context() + + for block in default_program.blocks: + for tensor in block.vars.values(): + dist_tensor = default_dist_context.get_dist_tensor_for_program( + tensor) + if dist_tensor is not None: + self.assertEqual(dist_tensor.dist_attr.process_mesh, + ProcessMesh(mesh)) + for op in block.ops: + dist_op = default_dist_context.get_dist_op_for_program(op) + if dist_op is not None: + self.assertEqual(dist_op.dist_attr.process_mesh, + ProcessMesh(mesh)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh_v2.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh_v2.py index fcfafcb3e6d6d1a8607753dc74cf476bbe845140..3c58f9e8cd393a40fd62e152a71217464ed3165f 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh_v2.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_process_mesh_v2.py @@ -13,7 +13,8 @@ # limitations under the License import unittest -from paddle.distributed.auto_parallel.process_mesh_v2 import ProcessMesh +from paddle.distributed.auto_parallel.process_mesh_v2 import ( + ProcessMesh, compute_compatible_process_mesh, merge_process_mesh) class TestProcessMesh(unittest.TestCase): @@ -39,6 +40,54 @@ class TestProcessMesh(unittest.TestCase): self.assertNotEqual(process_mesh, process_mesh2) self.assertEqual(str(process_mesh), str(process_mesh)) + def test_compute_compatible_process_mesh(self): + process_mesh1 = ProcessMesh([[0, 1, 2], [3, 4, 5]], + dim_names=["x", "y"]) + compatible_process_mesh = compute_compatible_process_mesh( + [process_mesh1, None]) + self.assertEqual(compatible_process_mesh, process_mesh1) + compatible_process_mesh = compute_compatible_process_mesh( + [None, process_mesh1]) + self.assertEqual(compatible_process_mesh, process_mesh1) + + process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]]) + compatible_process_mesh = compute_compatible_process_mesh( + [process_mesh1, process_mesh2]) + self.assertEqual(compatible_process_mesh, process_mesh1) + self.assertEqual(compatible_process_mesh, process_mesh2) + + process_mesh2 = ProcessMesh([[0, 1, 2, 3, 4, 5]]) + compatible_process_mesh = compute_compatible_process_mesh( + [process_mesh1, process_mesh2]) + self.assertEqual(compatible_process_mesh, process_mesh1) + + process_mesh2 = ProcessMesh([[0, 1, 2]]) + compatible_process_mesh = compute_compatible_process_mesh( + [process_mesh1, process_mesh2]) + self.assertEqual(compatible_process_mesh, process_mesh1) + + def test_merge_process_mesh(self): + process_mesh1 = ProcessMesh([[0, 1, 2], [3, 4, 5]], + dim_names=["x", "y"]) + merged_process_mesh = merge_process_mesh([process_mesh1, None]) + print(merged_process_mesh) + self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])) + merged_process_mesh = merge_process_mesh([None, process_mesh1]) + self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])) + + process_mesh2 = ProcessMesh([[0, 1, 2], [3, 4, 5]]) + merged_process_mesh = merge_process_mesh([process_mesh1, process_mesh2]) + self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])) + + process_mesh2 = ProcessMesh([[0, 1, 2]]) + merged_process_mesh = merge_process_mesh([process_mesh1, process_mesh2]) + self.assertEqual(merged_process_mesh, ProcessMesh([0, 1, 2, 3, 4, 5])) + + process_mesh2 = ProcessMesh([[6, 7]]) + merged_process_mesh = merge_process_mesh([process_mesh1, process_mesh2]) + self.assertEqual(merged_process_mesh, + ProcessMesh([0, 1, 2, 3, 4, 5, 6, 7])) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_quantization.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_quantization.py deleted file mode 100644 index f84ee03e0c9401e6c5bb369b9b0e72749e7325d8..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_quantization.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -import sys -import numpy as np -import paddle - -import paddle.distributed.fleet as fleet -import paddle.distributed.auto_parallel as auto - -from paddle.distributed.auto_parallel.engine import Engine -from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr - -sys.path.append("..") -import auto_parallel_gpt_model as modeling -from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion - -paddle.enable_static() - - -class FakeDataset: - - def __init__(self, num_samples, sequence_len, vocab_size): - self.num_samples = num_samples - self.sequence_len = sequence_len - self.vocab_size = vocab_size - - def __getitem__(self, idx): - tokens = np.random.randint(self.vocab_size, size=self.sequence_len) - position_ids = np.arange(self.sequence_len) - attention_mask = np.tril(np.ones(self.sequence_len)).reshape( - (1, self.sequence_len, self.sequence_len)).astype(np.float32) - labels = np.random.randint(self.vocab_size, size=self.sequence_len) - loss_mask = np.ones(self.sequence_len).astype(np.float32) - return tokens, position_ids, attention_mask, labels, loss_mask - - def __len__(self): - return self.num_samples - - -def apply_pass(): - dist_strategy = fleet.DistributedStrategy() - dist_strategy.semi_auto = True - dist_strategy.qat = True - dist_strategy.qat_configs = { - 'channel_wise_abs_max': True, - 'weight_bits': 8, - 'activation_bits': 8, - 'not_quant_pattern': ['skip_quant'], - } - return dist_strategy - - -def create_data_holder(batch_size, sequence_len): - tokens = paddle.static.InputSpec(name="tokens", - shape=[batch_size, sequence_len], - dtype='int64') - position_ids = paddle.static.InputSpec(name="position_ids", - shape=[batch_size, sequence_len], - dtype='int64') - attention_mask = paddle.static.InputSpec( - name="attention_mask", - shape=[batch_size, 1, sequence_len, sequence_len], - dtype='float32') - labels = paddle.static.InputSpec(name="labels", - shape=[batch_size, sequence_len], - dtype='int64') - loss_mask = paddle.static.InputSpec(name="loss_mask", - shape=[batch_size, sequence_len], - dtype='float32') - return [tokens, position_ids, attention_mask], [labels, loss_mask] - - -def get_gpt_model(): - modeling.init_global() - modeling._global_parallel_strategy = "serial" - modeling._global_process_mesh = auto.ProcessMesh(mesh=[0]) - - gpt = GPTModel(vocab_size=1000, - hidden_size=64, - num_hidden_layers=2, - num_attention_heads=8, - intermediate_size=256, - hidden_act="gelu", - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0, - max_position_embeddings=1024, - type_vocab_size=1, - initializer_range=0.02, - pad_token_id=0, - eos_token_id=7, - bos_token_id=0, - eol_token_id=3) - model = GPTForPretraining(gpt, - vocab_size=1000, - hidden_size=64, - initializer_range=0.02) - criterion = GPTPretrainingCriterion() - return model, criterion - - -class TestQuantizationPass(unittest.TestCase): - - def test_qat_pass(self): - - batch_size = 8 - batch_num = 10 - sequence_len = 512 - vocab_size = 1000 - - strategy = apply_pass() - model, loss = get_gpt_model() - opt = paddle.optimizer.AdamW(learning_rate=0.00001) - inputs_spec, labels_spec = create_data_holder(batch_size=batch_size, - sequence_len=sequence_len) - - engine = Engine(model, inputs_spec, labels_spec, strategy=strategy) - engine.prepare(optimizer=opt, loss=loss) - - dataset = FakeDataset(batch_size * batch_num, sequence_len, vocab_size) - engine.fit(train_data=dataset, batch_size=batch_size) - - self.check_program(engine.main_program) - - def check_program(self, program): - - quantizable_op_and_inputs = {'matmul_v2': ['X', 'Y']} - quantizable_grad_op_inputs = {'matmul_v2_grad': ['X', 'Y']} - - quantized_ops = set() - for block in program.blocks: - for op in block.ops: - is_quntized = False - if op.type in quantizable_op_and_inputs: - for arg_name in op.input_arg_names: - if ".quantized" in arg_name: - is_quntized = True - - if not is_quntized: - continue - - # check forward - if op.type in quantizable_op_and_inputs: - for arg_name in op.input_arg_names: - assert arg_name.endswith('.quantized.dequantized') - quantized_ops.add(arg_name) - - for op in block.ops: - is_quntized = False - if op.type in quantizable_grad_op_inputs: - for pname in quantizable_grad_op_inputs[op.type]: - arg_name = op.input(pname)[0] - if ".quantized" in arg_name: - is_quntized = True - - if not is_quntized: - continue - - # check backward - if op.type in quantizable_grad_op_inputs: - for pname in quantizable_grad_op_inputs[op.type]: - arg_name = op.input(pname)[0] - assert arg_name.endswith('.quantized.dequantized') - assert arg_name in quantized_ops - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..9fae8d970b2bb372ffcc18dd17e76e6f270ef1cf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py @@ -0,0 +1,206 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +# import yaml +import unittest +import paddle.distributed.auto_parallel as auto + + +class TestStrategy(unittest.TestCase): + + def test_default_config(self): + strategy = auto.Strategy() + + recompute = strategy.recompute + self.assertEqual(recompute.enable, False) + self.assertEqual(recompute.checkpoints, None) + + amp = strategy.amp + self.assertEqual(amp.enable, False) + self.assertAlmostEqual(amp.init_loss_scaling, 32768.0) + self.assertEqual(amp.incr_every_n_steps, 1000) + self.assertEqual(amp.decr_every_n_nan_or_inf, 2) + self.assertAlmostEqual(amp.incr_ratio, 2.0) + self.assertAlmostEqual(amp.decr_ratio, 0.8) + self.assertEqual(amp.use_dynamic_loss_scaling, True) + self.assertEqual(amp.custom_black_list, []) + self.assertEqual(amp.custom_white_list, []) + self.assertEqual(amp.custom_black_varnames, []) + self.assertEqual(amp.use_pure_fp16, False) + self.assertEqual(amp.use_fp16_guard, True) + self.assertEqual(amp.use_optimizer_fp16, False) + + sharding = strategy.sharding + self.assertEqual(sharding.enable, False) + self.assertEqual(sharding.stage, 1) + self.assertEqual(sharding.sharding_degree, 8) + self.assertAlmostEqual(sharding.segment_broadcast_MB, 32.0) + self.assertEqual(sharding.enable_tuning, False) + self.assertEqual(sharding.tuning_range, []) + + gradient_merge = strategy.gradient_merge + self.assertEqual(gradient_merge.enable, False) + self.assertEqual(gradient_merge.k_steps, 1) + self.assertEqual(gradient_merge.avg, True) + + qat = strategy.qat + self.assertEqual(qat.enable, False) + self.assertEqual(qat.channel_wise_abs_max, True) + self.assertEqual(qat.weight_bits, 8) + self.assertEqual(qat.activation_bits, 8) + self.assertEqual(qat.not_quant_pattern, ['skip_quant']) + self.assertEqual(qat.algo, None) + + tuning = strategy.tuning + self.assertEqual(tuning.enable, False) + self.assertEqual(tuning.batch_size, 1) + self.assertEqual(tuning.dataset, None) + self.assertEqual(tuning.profile_start_step, 1) + self.assertEqual(tuning.profile_end_step, 1) + self.assertEqual(tuning.run_after_tuning, True) + self.assertEqual(tuning.verbose, True) + + def test_modify_config(self): + strategy = auto.Strategy() + + recompute = strategy.recompute + recompute.enable = True + recompute.checkpoints = ["x"] + self.assertEqual(recompute.enable, True) + self.assertEqual(recompute.checkpoints, ["x"]) + + amp = strategy.amp + amp.enable = True + amp.init_loss_scaling = 16384.0 + amp.incr_every_n_steps = 2000 + amp.decr_every_n_nan_or_inf = 4 + amp.incr_ratio = 4.0 + amp.decr_ratio = 0.4 + amp.use_dynamic_loss_scaling = False + amp.custom_white_list = ["x"] + amp.custom_black_list = ["y"] + amp.custom_black_varnames = ["z"] + amp.use_pure_fp16 = True + amp.use_fp16_guard = False + amp.use_optimizer_fp16 = True + self.assertEqual(amp.enable, True) + self.assertAlmostEqual(amp.init_loss_scaling, 16384.0) + self.assertEqual(amp.incr_every_n_steps, 2000) + self.assertEqual(amp.decr_every_n_nan_or_inf, 4) + self.assertAlmostEqual(amp.incr_ratio, 4.0) + self.assertAlmostEqual(amp.decr_ratio, 0.4) + self.assertEqual(amp.use_dynamic_loss_scaling, False) + self.assertEqual(amp.custom_white_list, ["x"]) + self.assertEqual(amp.custom_black_list, ["y"]) + self.assertEqual(amp.custom_black_varnames, ["z"]) + self.assertEqual(amp.use_pure_fp16, True) + self.assertEqual(amp.use_fp16_guard, False) + self.assertEqual(amp.use_optimizer_fp16, True) + + sharding = strategy.sharding + sharding.enable = True + sharding.stage = 2 + sharding.sharding_degree = 2 + sharding.segment_broadcast_MB = 64.0 + sharding.enable_tuning = True + sharding.tuning_range = [1, 2, 3] + self.assertEqual(sharding.enable, True) + self.assertEqual(sharding.stage, 2) + self.assertEqual(sharding.sharding_degree, 2) + self.assertAlmostEqual(sharding.segment_broadcast_MB, 64.0) + self.assertEqual(sharding.enable_tuning, True) + self.assertEqual(sharding.tuning_range, [1, 2, 3]) + + gradient_merge = strategy.gradient_merge + gradient_merge.enable = True + gradient_merge.k_steps = 4 + gradient_merge.avg = False + self.assertEqual(gradient_merge.enable, True) + self.assertEqual(gradient_merge.k_steps, 4) + self.assertEqual(gradient_merge.avg, False) + + # def test_file_config(self): + # yaml_data = """ + # all_ranks: false + # amp: + # custom_black_list: + # - y + # custom_black_varnames: + # - z + # custom_white_list: + # - x + # decr_every_n_nan_or_inf: 4 + # decr_ratio: 0.4 + # enable: false + # incr_every_n_steps: 2000 + # incr_ratio: 4.0 + # init_loss_scaling: 16384.0 + # use_dynamic_loss_scaling: false + # use_fp16_guard: false + # use_optimizer_fp16: true + # use_pure_fp16: true + # auto_mode: semi + # gradient_merge: + # avg: false + # enable: false + # k_steps: 4 + # gradient_scale: true + # qat: + # activation_bits: 8 + # algo: null + # channel_wise_abs_max: true + # enable: false + # not_quant_pattern: + # - skip_quant + # weight_bits: 8 + # recompute: + # checkpoints: null + # enable: false + # enable_tuning: false + # return_numpy: true + # seed: null + # sharding: + # enable: false + # enable_tuning: true + # segment_broadcast_MB: 64.0 + # sharding_degree: 8 + # stage: 2 + # tuning_range: None + # split_data: false + # tuning: + # batch_size: 1 + # dataset: null + # enable: false + # profile_end_step: 1 + # profile_start_step: 1 + # run_after_tuning: true + # verbose: true + # use_cache: true + # """ + # yaml_path = "./strategy.yml" + # yaml_dict = yaml.load(yaml_data, Loader=yaml.Loader) + # with open(yaml_path, 'w') as outfile: + # yaml.dump(yaml_dict, outfile, default_flow_style=False) + + # strategy = auto.Strategy(yaml_path) + # self.assertEqual(yaml_dict, strategy.to_dict()) + + # # Remove the created file + # if os.path.exists(yaml_path): + # os.remove(yaml_path) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_to_static.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_to_static.py index 86832f485c162a1cbb189e8cfdcbd64cb527e183..5e545a7a63a0e53229575041ac096e25bf4e7286 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_to_static.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_to_static.py @@ -27,7 +27,6 @@ from paddle import LazyGuard from paddle.io import Dataset from paddle.static import InputSpec from paddle.fluid.framework import _non_static_mode -from paddle.distributed.auto_parallel.engine import Engine from paddle.distributed.auto_parallel.helper import ProgramHelper batch_size = 4 @@ -140,23 +139,19 @@ class TestToStatic(unittest.TestCase): dataset = MyDataset(batch_num * batch_size) - inputs = InputSpec([batch_size, hidden_size], 'float32', 'x') - labels = InputSpec([batch_size], 'int64', 'label') + # inputs = InputSpec([batch_size, hidden_size], 'float32', 'x') + # labels = InputSpec([batch_size], 'int64', 'label') - engine = Engine(model=mlp, - inputs_spec=inputs, - labels_spec=labels, - strategy=None) assert _non_static_mode() == True - - engine.prepare(optimizer=optimizer, - loss=loss, - metrics=paddle.metric.Accuracy()) - - assert _non_static_mode() == False + engine = auto.Engine(model=mlp, + loss=loss, + optimizer=optimizer, + metrics=paddle.metric.Accuracy(), + strategy=None) engine.fit(dataset, batch_size=batch_size) engine.evaluate(dataset, batch_size=batch_size) engine.predict(dataset, batch_size=batch_size) + assert _non_static_mode() == False class TestLazyInit(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py index 3dabe38ff6e1d7b414f0562c7aaa28bee4aabf88..1c869813d319b439343fa78e0fd8fc59036b323e 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_completion.py @@ -36,7 +36,7 @@ batch_size = 4 epoch_num = 10 hidden_size = 1024 sequence_len = 512 -_g_process_mesh = [[0, 1], [2, 3]] +_g_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=['x', 'y']) def get_random_inputs_and_labels(input_shape, label_shape): @@ -84,18 +84,12 @@ class MLPLayer(nn.Layer): def forward(self, input): out = self.norm(input) - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _g_process_mesh[0], - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.linear0.weight, _g_process_mesh[:, 0], + [None, 'x']) out = self.linear0(out) out = F.gelu(out, approximate=True) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _g_process_mesh[1], - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.linear1.weight, _g_process_mesh[:, 1], + ['x', None]) out = self.linear1(out) return out @@ -155,16 +149,8 @@ def get_program(): dataloader.set_batch_generator(batch_generator_creator(), places=paddle.static.cuda_places()) # data dist_attr - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _g_process_mesh[0], - "dims_mapping": [-1, -1, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": _g_process_mesh[0], - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(input, _g_process_mesh[:, 0], [None, None, None]) + auto.shard_tensor(label, _g_process_mesh[:, 0], [None, None, None]) mlp_start = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py index 3c6e086ae7face1c1bdbe11dc7c099bd62a2173f..444e0df454d96430c6140582fd1c7df4ec8d6734 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_while_op_partition.py @@ -37,7 +37,7 @@ batch_size = 4 epoch_num = 10 hidden_size = 1024 sequence_len = 512 -_g_process_mesh = auto.ProcessMesh([0, 1]) +_g_process_mesh = auto.ProcessMesh([0, 1], dim_names=['x']) def get_random_inputs_and_labels(input_shape, label_shape): @@ -85,61 +85,21 @@ class MLPLayer(nn.Layer): def forward(self, input): - auto.shard_tensor(self.norm.weight, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) - auto.shard_tensor(self.norm.bias, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, 0] - }) - auto.shard_tensor(self.linear0.bias, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [0] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [0, -1] - }) - auto.shard_tensor(self.linear1.bias, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) + auto.shard_tensor(self.norm.weight, _g_process_mesh, [None]) + auto.shard_tensor(self.norm.bias, _g_process_mesh, [None]) + auto.shard_tensor(self.linear0.weight, _g_process_mesh, [None, 'x']) + auto.shard_tensor(self.linear0.bias, _g_process_mesh, ['x']) + auto.shard_tensor(self.linear1.weight, _g_process_mesh, ['x', None]) + auto.shard_tensor(self.linear1.bias, _g_process_mesh, [None]) out = self.norm(input) - auto.shard_tensor(out, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(out, _g_process_mesh, [None, None, None]) out = self.linear0(out) - auto.shard_tensor(out, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, 0] - }) + auto.shard_tensor(out, _g_process_mesh, [None, None, 'x']) out = F.gelu(out, approximate=True) - auto.shard_tensor(out, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, 0] - }) + auto.shard_tensor(out, _g_process_mesh, [None, None, 'x']) out = self.linear1(out) - auto.shard_tensor(out, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(out, _g_process_mesh, [None, None, None]) return out @@ -155,21 +115,13 @@ def get_program(): # 循环计数器 i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0) - auto.shard_tensor(i, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) + auto.shard_tensor(i, _g_process_mesh, [None]) # 循环次数 loop_len = fluid.layers.fill_constant(shape=[1], dtype='int64', value=epoch_num) - auto.shard_tensor(loop_len, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) + auto.shard_tensor(loop_len, _g_process_mesh, [None]) # input input = static.data(name="input", @@ -188,25 +140,13 @@ def get_program(): dataloader.set_batch_generator(batch_generator_creator(), places=paddle.static.cuda_places()) # data dist_attr - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(input, _g_process_mesh, [None, None, None]) + auto.shard_tensor(label, _g_process_mesh, [None, None, None]) # fill constant bsz like tmp = paddle.fluid.layers.fill_constant_batch_size_like( input=input, shape=[-1, 16, 0, 48], dtype='float32', value=0) - auto.shard_tensor(tmp, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, 0, -1, -1] - }) + auto.shard_tensor(tmp, _g_process_mesh, [None, 'x', None, None]) # model mlp_start = MLPLayer(hidden_size=hidden_size, @@ -216,28 +156,21 @@ def get_program(): pred = mlp_start(input) input_array = fluid.layers.array_write(pred, i) - auto.shard_tensor(input_array, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) + # TODO: check whether this annotation is needed + # auto.shard_tensor(input_array, + # dist_attr={ + # "process_mesh": _g_process_mesh, + # "dims_mapping": [-1, -1, -1] + # }) cond = fluid.layers.less_than(x=i, y=loop_len) - auto.shard_tensor(cond, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) + auto.shard_tensor(cond, _g_process_mesh, [None]) while_op = fluid.layers.While(cond=cond) with while_op.block(): pre_input = fluid.layers.array_read(array=input_array, i=i) - auto.shard_tensor(pre_input, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(pre_input, _g_process_mesh, [None, None, None]) mlp_while = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -251,11 +184,7 @@ def get_program(): fluid.layers.less_than(x=i, y=loop_len, cond=cond) end_pred = fluid.layers.array_read(array=input_array, i=i) - auto.shard_tensor(end_pred, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(end_pred, _g_process_mesh, [None, None, None]) mlp_end = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -264,18 +193,10 @@ def get_program(): pred = mlp_end(end_pred) error_cost = paddle.nn.functional.square_error_cost(pred, label) - auto.shard_tensor(error_cost, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1, -1, -1] - }) + auto.shard_tensor(error_cost, _g_process_mesh, [None, None, None]) loss = paddle.mean(error_cost) - auto.shard_tensor(loss, - dist_attr={ - "process_mesh": _g_process_mesh, - "dims_mapping": [-1] - }) + auto.shard_tensor(loss, _g_process_mesh, [None]) return train_program, start_program, dataloader, i, loss diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py b/python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py index c3f64e30fc5967de0a0ed968bc7100cc42037842..2e65c9bd467356537387ba4dbaaa69c0ca54fc64 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py @@ -67,38 +67,18 @@ class MLPLayer(nn.Layer): def forward(self, input): if _global_parallel_strategy == "pp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None]) + auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None]) elif _global_parallel_strategy == "mp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.linear0.weight, _global_process_mesh, + [None, "x"]) + auto.shard_tensor(self.linear1.weight, _global_process_mesh, + ["x", None]) elif _global_parallel_strategy == "dp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(self.linear0.weight, _global_process_mesh, + [None, None]) + auto.shard_tensor(self.linear1.weight, _global_process_mesh, + [None, None]) out = self.norm(input) out = self.linear0(out) @@ -120,28 +100,12 @@ def mlp_forward(train_program, start_program): dtype='float32') if _global_parallel_strategy == "pp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(input, PP_MESH_0, [None, None]) + auto.shard_tensor(label, PP_MESH_1, [None, None]) elif _global_parallel_strategy == "dp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(input, _global_process_mesh, ["x", None]) elif _global_parallel_strategy == "mp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(input, _global_process_mesh, [None, None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -186,7 +150,7 @@ class TestMLPAutoConvert(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh([0, 1]) + _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"]) input = np.random.random(size=(80, 64)).astype('float32') label = np.random.random(size=(80, 1)).astype('float32') @@ -212,11 +176,11 @@ class TestMLPAutoConvert(unittest.TestCase): set_default_distributed_context(None) _global_parallel_strategy = "pp" - _global_process_mesh = auto.ProcessMesh([0, 1]) + _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"]) global PP_MESH_0 - PP_MESH_0 = auto.ProcessMesh(mesh=[0]) + PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["pp0"]) global PP_MESH_1 - PP_MESH_1 = auto.ProcessMesh(mesh=[1]) + PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["pp1"]) dist_main_prog_load, dist_start_prog_load, loss_load = get_distributed_program( ) @@ -268,7 +232,7 @@ class TestMLPAutoConvert2(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "pp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh([0, 1]) + _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"]) global PP_MESH_0 PP_MESH_0 = auto.ProcessMesh(mesh=[0]) global PP_MESH_1 @@ -303,7 +267,7 @@ class TestMLPAutoConvert2(unittest.TestCase): set_default_distributed_context(None) _global_parallel_strategy = "mp" - _global_process_mesh = auto.ProcessMesh([0, 1]) + _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"]) dist_main_prog_load, dist_start_prog_load, loss_load = get_distributed_program( ) @@ -350,7 +314,7 @@ class TestMLPAutoConvertInvalid(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh([0, 1]) + _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"]) dist_main_prog, _, _ = get_distributed_program() with self.assertRaises(TypeError): save_distributed_checkpoint(dist_main_prog, [""], [""], diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py b/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py index 9d2b2739401214c560644a63c35d7e9724bca134..c7ce4c2326cf27c3b38e6b15ea1b285987ebba86 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_data_unshard.py @@ -38,7 +38,7 @@ class TestDataUnshard(unittest.TestCase): def create_model(train_program, start_program): with paddle.static.program_guard(train_program, start_program): - MESH_0 = auto.ProcessMesh([0, 1]) + MESH_0 = auto.ProcessMesh([0, 1], dim_names=["x"]) input = paddle.static.data(name='input', shape=[2, 8]) label = paddle.static.data(name='label', shape=[2, 8]) @@ -47,26 +47,10 @@ class TestDataUnshard(unittest.TestCase): linear0 = nn.Linear(8, 8, weight_attr) linear1 = nn.Linear(8, 8, weight_attr) - auto.shard_tensor(input, - dist_attr={ - "process_mesh": MESH_0, - "dims_mapping": [0, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": MESH_0, - "dims_mapping": [0, -1] - }) - auto.shard_tensor(linear0.weight, - dist_attr={ - "process_mesh": MESH_0, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(linear1.weight, - dist_attr={ - "process_mesh": MESH_0, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(input, MESH_0, ["x", None]) + auto.shard_tensor(label, MESH_0, ["x", None]) + auto.shard_tensor(linear0.weight, MESH_0, [None, None]) + auto.shard_tensor(linear1.weight, MESH_0, [None, None]) linear0_out = linear0(input) gelu_out = F.gelu(linear0_out) @@ -124,7 +108,7 @@ class TestDataUnshard(unittest.TestCase): def create_model(train_program, start_program): with paddle.static.program_guard(train_program, start_program): - MESH_0 = auto.ProcessMesh([0, 1]) + MESH_0 = auto.ProcessMesh([0, 1], dim_names=["x"]) input = paddle.static.data(name='input', shape=[8, 8]) label = paddle.static.data(name='label', shape=[8, 8]) @@ -133,27 +117,10 @@ class TestDataUnshard(unittest.TestCase): linear0 = nn.Linear(8, 8, weight_attr) linear1 = nn.Linear(8, 8, weight_attr) - auto.shard_tensor(input, - dist_attr={ - "process_mesh": MESH_0, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": MESH_0, - "dims_mapping": [-1, -1] - }) - - auto.shard_tensor(linear0.weight, - dist_attr={ - "process_mesh": MESH_0, - "dims_mapping": [-1, 0] - }) - auto.shard_tensor(linear1.weight, - dist_attr={ - "process_mesh": MESH_0, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(input, MESH_0, [None, None]) + auto.shard_tensor(label, MESH_0, [None, None]) + auto.shard_tensor(linear0.weight, MESH_0, [None, "x"]) + auto.shard_tensor(linear1.weight, MESH_0, ["x", None]) linear0_out = linear0(input) gelu_out = F.gelu(linear0_out) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py index 8aef4d1086066a7a23548e00bbef3a8168e322e3..e7f721dd422cf5fb3ff8bacedf651332a5308675 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_gpt_model.py @@ -114,30 +114,18 @@ class MultiHeadAttention(nn.Layer): """ q = self.q_proj(query) if _global_parallel_strategy == "mp": - auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.q_proj.weight, _global_process_mesh, + [None, "x"]) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(self.q_proj.weight, _global_process_mesh, + [None, "y"]) elif _global_parallel_strategy == "mp_pp": - auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.q_proj.weight, MPPP_MESH_LIST[self.mesh_idx], + [None, "x"]) elif _global_parallel_strategy == "dp_mp_pp": auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": - DPMPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 1] - }) + DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"]) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) if isinstance(cache, self.StaticCache): @@ -165,56 +153,30 @@ class MultiHeadAttention(nn.Layer): """ k = self.k_proj(key) if _global_parallel_strategy == "mp": - auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.k_proj.weight, _global_process_mesh, + [None, "x"]) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(self.k_proj.weight, _global_process_mesh, + [None, "y"]) elif _global_parallel_strategy == "mp_pp": - auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.k_proj.weight, MPPP_MESH_LIST[self.mesh_idx], + [None, "x"]) elif _global_parallel_strategy == "dp_mp_pp": auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": - DPMPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 1] - }) + DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"]) v = self.v_proj(value) if _global_parallel_strategy == "mp": - auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.v_proj.weight, _global_process_mesh, + [None, "x"]) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(self.v_proj.weight, _global_process_mesh, + [None, "y"]) elif _global_parallel_strategy == "mp_pp": - auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.v_proj.weight, MPPP_MESH_LIST[self.mesh_idx], + [None, "x"]) elif _global_parallel_strategy == "dp_mp_pp": auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": - DPMPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 1] - }) + DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) @@ -287,30 +249,18 @@ class MultiHeadAttention(nn.Layer): # project to output out = self.out_proj(out) if _global_parallel_strategy == "mp": - auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.out_proj.weight, _global_process_mesh, + ["x", None]) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.out_proj.weight, _global_process_mesh, + ["y", None]) elif _global_parallel_strategy == "mp_pp": auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [0, -1] - }) + MPPP_MESH_LIST[self.mesh_idx], ["x", None]) elif _global_parallel_strategy == "dp_mp_pp": auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": - DPMPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [1, -1] - }) + DPMPPP_MESH_LIST[self.mesh_idx], ["y", None]) + outs = [out] if self.need_weights: outs.append(weights) @@ -352,96 +302,53 @@ class TransformerDecoder(nn.Layer): new_caches = [] self.checkpoints = [] if _global_parallel_strategy == "pp": - auto.shard_tensor(output, - dist_attr={ - "process_mesh": - PP_MESH_LIST[0], - "dims_mapping": - [-1 for i in range(len(output.shape))] - }) + auto.shard_tensor(output, PP_MESH_LIST[0], + [None for i in range(len(output.shape))]) if _global_parallel_strategy == "dp_pp": - auto.shard_tensor(output, - dist_attr={ - "process_mesh": - DPPP_MESH_LIST[0], - "dims_mapping": [0] + - [-1 for i in range(len(output.shape) - 1)] - }) + auto.shard_tensor(output, DPPP_MESH_LIST[0], ["x"].extends( + [None for i in range(len(output.shape) - 1)])) if _global_parallel_strategy == "mp_pp": - auto.shard_tensor(output, - dist_attr={ - "process_mesh": - MPPP_MESH_LIST[0], - "dims_mapping": [-1] + - [-1 for i in range(len(output.shape) - 1)] - }) + auto.shard_tensor(output, MPPP_MESH_LIST[0], + [None for i in range(len(output.shape))]) if _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor(output, - dist_attr={ - "process_mesh": - DPMPPP_MESH_LIST[0], - "dims_mapping": [0] + - [-1 for i in range(len(output.shape) - 1)] - }) + auto.shard_tensor(output, DPMPPP_MESH_LIST[0], ["x"].extends( + [None for i in range(len(output.shape) - 1)])) for i, mod in enumerate(self.layers): if cache is None: if use_cache: if _global_parallel_strategy == "pp": output, new_cache = auto.shard_op( - mod, - dist_attr={ - "process_mesh": PP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache) + mod, PP_MESH_LIST[mod.mesh_idx])(output, memory, + tgt_mask, + use_cache, cache) auto.shard_tensor( - output, - dist_attr={ - "process_mesh": - PP_MESH_LIST[mod.mesh_idx], - "dims_mapping": - [-1 for i in range(len(output.shape))] - }) + output, PP_MESH_LIST[mod.mesh_idx], + [None for i in range(len(output.shape))]) elif _global_parallel_strategy == "dp_pp": output, new_cache = auto.shard_op( - mod, - dist_attr={ - "process_mesh": DPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache) + mod, DPPP_MESH_LIST[mod.mesh_idx])(output, memory, + tgt_mask, + use_cache, cache) auto.shard_tensor( - output, - dist_attr={ - "process_mesh": - DPPP_MESH_LIST[mod.mesh_idx], - "dims_mapping": [0] + - [-1 for i in range(len(output.shape) - 1)] - }) + output, DPPP_MESH_LIST[mod.mesh_idx], ["x"].extends( + [None for i in range(len(output.shape) - 1)])) elif _global_parallel_strategy == "mp_pp": output, new_cache = auto.shard_op( - mod, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache) + mod, MPPP_MESH_LIST[mod.mesh_idx])(output, memory, + tgt_mask, + use_cache, cache) auto.shard_tensor( - output, - dist_attr={ - "process_mesh": - MPPP_MESH_LIST[mod.mesh_idx], - "dims_mapping": [-1] + - [-1 for i in range(len(output.shape) - 1)] - }) + output, MPPP_MESH_LIST[mod.mesh_idx], + [None for i in range(len(output.shape))]) elif _global_parallel_strategy == "dp_mp_pp": output, new_cache = auto.shard_op( mod, - dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache) + DPMPPP_MESH_LIST[mod.mesh_idx])(output, memory, + tgt_mask, use_cache, + cache) auto.shard_tensor( - output, - dist_attr={ - "process_mesh": - DPMPPP_MESH_LIST[mod.mesh_idx], - "dims_mapping": [0] + - [-1 for i in range(len(output.shape) - 1)] - }) + output, DPMPPP_MESH_LIST[mod.mesh_idx], + [None for i in range(len(output.shape))]) else: output, new_cache = mod(output, memory, @@ -451,64 +358,36 @@ class TransformerDecoder(nn.Layer): new_caches.append(new_cache) else: if _global_parallel_strategy == "pp": - output = auto.shard_op(mod, - dist_attr={ - "process_mesh": - PP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, - use_cache, cache) + output = auto.shard_op(mod, PP_MESH_LIST[mod.mesh_idx])( + output, memory, tgt_mask, use_cache, cache) auto.shard_tensor( - output, - dist_attr={ - "process_mesh": - PP_MESH_LIST[mod.mesh_idx], - "dims_mapping": - [-1 for i in range(len(output.shape))] - }) + output, PP_MESH_LIST[mod.mesh_idx], + [None for i in range(len(output.shape))]) elif _global_parallel_strategy == "dp_pp": - output = auto.shard_op(mod, - dist_attr={ - "process_mesh": - DPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, - use_cache, cache) + output = auto.shard_op( + mod, DPPP_MESH_LIST[mod.mesh_idx])(output, memory, + tgt_mask, + use_cache, cache) auto.shard_tensor( - output, - dist_attr={ - "process_mesh": - DPPP_MESH_LIST[mod.mesh_idx], - "dims_mapping": [0] + - [-1 for i in range(len(output.shape) - 1)] - }) + output, DPPP_MESH_LIST[mod.mesh_idx], ["x"].extends( + [None for i in range(len(output.shape) - 1)])) elif _global_parallel_strategy == "mp_pp": - output = auto.shard_op(mod, - dist_attr={ - "process_mesh": - MPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, - use_cache, cache) + output = auto.shard_op( + mod, MPPP_MESH_LIST[mod.mesh_idx])(output, memory, + tgt_mask, + use_cache, cache) auto.shard_tensor( - output, - dist_attr={ - "process_mesh": - MPPP_MESH_LIST[mod.mesh_idx], - "dims_mapping": [-1] + - [-1 for i in range(len(output.shape) - 1)] - }) + output, MPPP_MESH_LIST[mod.mesh_idx], + [None for i in range(len(output.shape))]) elif _global_parallel_strategy == "dp_mp_pp": - output = auto.shard_op( - mod, - dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache) + output = auto.shard_op(mod, + DPMPPP_MESH_LIST[mod.mesh_idx])( + output, memory, tgt_mask, + use_cache, cache) auto.shard_tensor( - output, - dist_attr={ - "process_mesh": - DPMPPP_MESH_LIST[mod.mesh_idx], - "dims_mapping": [0] + - [-1 for i in range(len(output.shape) - 1)] - }) + output, DPMPPP_MESH_LIST[mod.mesh_idx], + ["x"].extends( + [None for i in range(len(output.shape) - 1)])) else: output = mod(output, memory, @@ -519,58 +398,33 @@ class TransformerDecoder(nn.Layer): if _global_parallel_strategy == "pp": output, new_cache = auto.shard_op( mod, - dist_attr={"process_mesh": PP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, - cache) - auto.shard_tensor( - output, - dist_attr={ - "process_mesh": PP_MESH_LIST[mod.mesh_idx], - "dims_mapping": - [-1 for i in range(len(output.shape))] - }) + PP_MESH_LIST[mod.mesh_idx])(output, memory, tgt_mask, + use_cache, cache) + auto.shard_tensor(output, PP_MESH_LIST[mod.mesh_idx], + [None for i in range(len(output.shape))]) elif _global_parallel_strategy == "dp_pp": output, new_cache = auto.shard_op( mod, - dist_attr={ - "process_mesh": DPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache) - auto.shard_tensor( - output, - dist_attr={ - "process_mesh": - DPPP_MESH_LIST[mod.mesh_idx], - "dims_mapping": - [0] + [-1 for i in range(len(output.shape) - 1)] - }) + DPPP_MESH_LIST[mod.mesh_idx])(output, memory, tgt_mask, + use_cache, cache) + auto.shard_tensor(output, DPPP_MESH_LIST[mod.mesh_idx], [ + "x" + ].extends([None for i in range(len(output.shape) - 1)])) elif _global_parallel_strategy == "mp_pp": output, new_cache = auto.shard_op( mod, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache) - auto.shard_tensor( - output, - dist_attr={ - "process_mesh": - MPPP_MESH_LIST[mod.mesh_idx], - "dims_mapping": - [-1] + [-1 for i in range(len(output.shape) - 1)] - }) + MPPP_MESH_LIST[mod.mesh_idx])(output, memory, tgt_mask, + use_cache, cache) + auto.shard_tensor(output, MPPP_MESH_LIST[mod.mesh_idx], + [None for i in range(len(output.shape))]) elif _global_parallel_strategy == "dp_mp_pp": output, new_cache = auto.shard_op( - mod, - dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[mod.mesh_idx] - })(output, memory, tgt_mask, use_cache, cache) - auto.shard_tensor( - output, - dist_attr={ - "process_mesh": - DPMPPP_MESH_LIST[mod.mesh_idx], - "dims_mapping": - [0] + [-1 for i in range(len(output.shape) - 1)] - }) + mod, DPMPPP_MESH_LIST[mod.mesh_idx])(output, memory, + tgt_mask, + use_cache, cache) + auto.shard_tensor(output, DPMPPP_MESH_LIST[mod.mesh_idx], [ + "x" + ].extends([None for i in range(len(output.shape) - 1)])) else: output, new_cache = mod(output, memory, @@ -661,55 +515,30 @@ class TransformerDecoderLayer(nn.Layer): if self.normalize_before: tgt = self.norm2(tgt) if _global_parallel_strategy == "mp": - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + auto.shard_tensor(self.linear1.weight, _global_process_mesh, + [None, "x"]) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(self.linear1.weight, _global_process_mesh, + [None, "y"]) elif _global_parallel_strategy == "mp_pp": auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 0] - }) + MPPP_MESH_LIST[self.mesh_idx], [None, "x"]) if _global_parallel_strategy == "dp_mp_pp": auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": - DPMPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [-1, 1] - }) + DPMPPP_MESH_LIST[self.mesh_idx], [None, "y"]) + if _global_parallel_strategy == "mp": - auto.shard_tensor(self.linear2.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.linear2.weight, _global_process_mesh, + ["x", None]) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.linear2.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.linear2.weight, _global_process_mesh, + ["y", None]) elif _global_parallel_strategy == "mp_pp": auto.shard_tensor(self.linear2.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [0, -1] - }) + MPPP_MESH_LIST[self.mesh_idx], ["x", None]) elif _global_parallel_strategy == "dp_mp_pp": auto.shard_tensor(self.linear2.weight, - dist_attr={ - "process_mesh": - DPMPPP_MESH_LIST[self.mesh_idx], - "dims_mapping": [1, -1] - }) + DPMPPP_MESH_LIST[self.mesh_idx], ["y", None]) tgt = self.dropout2( self.linear2(F.gelu(self.linear1(tgt), approximate=True))) tgt = residual + tgt @@ -757,29 +586,18 @@ class GPTEmbeddings(nn.Layer): position_ids = seq_length - ones input_embedings = self.word_embeddings(input_ids) if _global_parallel_strategy == "mp": - auto.shard_tensor(self.word_embeddings.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.word_embeddings.weight, _global_process_mesh, + ["x", None]) elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.word_embeddings.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.word_embeddings.weight, _global_process_mesh, + ["y", None]) elif _global_parallel_strategy == "mp_pp": - auto.shard_tensor(self.word_embeddings.weight, - dist_attr={ - "process_mesh": MPPP_MESH_LIST[0], - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.word_embeddings.weight, MPPP_MESH_LIST[0], + ["x", None]) elif _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor(self.word_embeddings.weight, - dist_attr={ - "process_mesh": DPMPPP_MESH_LIST[0], - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.word_embeddings.weight, DPMPPP_MESH_LIST[0], + ["y", None]) + position_embeddings = self.position_embeddings(position_ids) embeddings = input_embedings + position_embeddings embeddings = self.dropout(embeddings) @@ -868,29 +686,14 @@ class GPTModel(nn.Layer): embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids) if _global_parallel_strategy == "pp": - auto.shard_tensor(input_ids, - dist_attr={ - "process_mesh": - PP_MESH_LIST[0], - "dims_mapping": - [-1 for i in range(len(input_ids.shape))] - }) + auto.shard_tensor(input_ids, PP_MESH_LIST[0], + [None for i in range(len(input_ids.shape))]) if _global_parallel_strategy == "dp_pp": - auto.shard_tensor(input_ids, - dist_attr={ - "process_mesh": - DPPP_MESH_LIST[0], - "dims_mapping": [0] + - [-1 for i in range(len(input_ids.shape) - 1)] - }) + auto.shard_tensor(input_ids, DPPP_MESH_LIST[0], ["x"].extends( + [None for i in range(len(input_ids.shape) - 1)])) if _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor(input_ids, - dist_attr={ - "process_mesh": - DPMPPP_MESH_LIST[0], - "dims_mapping": [0] + - [-1 for i in range(len(input_ids.shape) - 1)] - }) + auto.shard_tensor(input_ids, DPMPPP_MESH_LIST[0], ["x"].extends( + [None for i in range(len(input_ids.shape) - 1)])) encoder_outputs = self.decoder(embedding_output, memory=None, tgt_mask=attention_mask, @@ -923,6 +726,10 @@ class GPTForPretraining(nn.Layer): masked_positions=None, use_cache=False, cache=None): + input_ids.stop_gradient = True + position_ids.stop_gradient = True + attention_mask.stop_gradient = True + outputs = self.gpt(input_ids, position_ids=position_ids, attention_mask=attention_mask, @@ -936,40 +743,42 @@ class GPTForPretraining(nn.Layer): x = encoder_outputs w = self.gpt.embeddings.word_embeddings.weight - mesh = _global_process_mesh - x_dims_mapping = [-1 for i in range(len(x.shape))] - w_dims_mapping = [-1 for i in range(len(w.shape))] + mesh = None if _global_parallel_strategy == "pp": mesh = PP_MESH_LIST[-1] + x_dims_mapping = [None for i in range(len(x.shape))] + w_dims_mapping = [None for i in range(len(w.shape))] elif _global_parallel_strategy == "dp": - x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)] + mesh = _global_process_mesh + x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)] + w_dims_mapping = [None for i in range(len(w.shape))] elif _global_parallel_strategy == "mp": - w_dims_mapping = [0] + [-1 for i in range(len(w.shape) - 1)] + mesh = _global_process_mesh + x_dims_mapping = [None for i in range(len(x.shape))] + w_dims_mapping = ["x"] + [None for i in range(len(w.shape) - 1)] elif _global_parallel_strategy == "dp_mp": - x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)] - w_dims_mapping = [1] + [-1 for i in range(len(w.shape) - 1)] + mesh = _global_process_mesh + x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)] + w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)] elif _global_parallel_strategy == "dp_pp": mesh = DPPP_MESH_LIST[-1] - x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)] + x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)] + w_dims_mapping = [None for i in range(len(w.shape))] elif _global_parallel_strategy == "mp_pp": mesh = MPPP_MESH_LIST[-1] - w_dims_mapping = [0] + [-1 for i in range(len(w.shape) - 1)] + x_dims_mapping = [None for i in range(len(x.shape))] + w_dims_mapping = ["x"] + [-1 for i in range(len(w.shape) - 1)] elif _global_parallel_strategy == "dp_mp_pp": mesh = DPMPPP_MESH_LIST[-1] - x_dims_mapping = [0] + [-1 for i in range(len(x.shape) - 1)] - w_dims_mapping = [1] + [-1 for i in range(len(w.shape) - 1)] - - matmul = auto.shard_op(paddle.matmul, - dist_attr={ - 'process_mesh': mesh, - x: { - "dims_mapping": x_dims_mapping - }, - w: { - "dims_mapping": w_dims_mapping - } - }) - logits = matmul(x, w, transpose_y=True) + x_dims_mapping = ["x"] + [None for i in range(len(x.shape) - 1)] + w_dims_mapping = ["y"] + [None for i in range(len(w.shape) - 1)] + + if mesh: + matmul = auto.shard_op(paddle.matmul, mesh, + [x_dims_mapping, w_dims_mapping, None]) + logits = matmul(x, w, transpose_y=True) + else: + logits = paddle.matmul(x, w, transpose_y=True) if use_cache: return logits, cached_kvs @@ -988,25 +797,29 @@ class GPTPretrainingCriterion(nn.Layer): self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none") def forward(self, prediction_scores, masked_lm_labels, loss_mask): + masked_lm_labels.stop_gradient = True + loss_mask.stop_gradient = True - mesh = _global_process_mesh - dims_mapping = [-1 for i in range(len(loss_mask.shape))] + mesh = None if _global_parallel_strategy == "dp": - dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)] + mesh = _global_process_mesh + dims_mapping = ["x" + ] + [None for i in range(len(loss_mask.shape) - 1)] elif _global_parallel_strategy == "dp_mp": - dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)] + mesh = _global_process_mesh + dims_mapping = ["x" + ] + [None for i in range(len(loss_mask.shape) - 1)] elif _global_parallel_strategy == "dp_pp": mesh = DPPP_MESH_LIST[-1] - dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)] + dims_mapping = ["x" + ] + [None for i in range(len(loss_mask.shape) - 1)] elif _global_parallel_strategy == "dp_mp_pp": mesh = DPMPPP_MESH_LIST[-1] - dims_mapping = [0] + [-1 for i in range(len(loss_mask.shape) - 1)] + dims_mapping = ["x" + ] + [None for i in range(len(loss_mask.shape) - 1)] - auto.shard_tensor(loss_mask, - dist_attr={ - "process_mesh": mesh, - "dims_mapping": dims_mapping - }) + if mesh: + auto.shard_tensor(loss_mask, mesh, dims_mapping) masked_lm_loss = self.loss_func(prediction_scores, masked_lm_labels.unsqueeze(2)) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_save_load.py b/python/paddle/fluid/tests/unittests/auto_parallel_save_load.py index 12f4cc08b0874e7811468050825ed4a2bf77b76b..e98577f8458b88b1c52541d7ca9cd81c5d07f755 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel_save_load.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_save_load.py @@ -64,38 +64,18 @@ class MLPLayer(nn.Layer): def forward(self, input): if _global_parallel_strategy == "pp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None]) + auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None]) elif _global_parallel_strategy == "mp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.linear0.weight, _global_process_mesh, + [None, "x"]) + auto.shard_tensor(self.linear1.weight, _global_process_mesh, + ["x", None]) elif _global_parallel_strategy == "dp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(self.linear0.weight, _global_process_mesh, + [None, None]) + auto.shard_tensor(self.linear1.weight, _global_process_mesh, + [None, None]) out = self.norm(input) out = self.linear0(out) @@ -119,28 +99,12 @@ def mlp_forward(train_program, start_program): dtype='float32') if _global_parallel_strategy == "pp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(input, PP_MESH_0, [None, None]) + auto.shard_tensor(label, PP_MESH_1, [None, None]) elif _global_parallel_strategy == "dp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(input, _global_process_mesh, ["x", None]) elif _global_parallel_strategy == "mp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(input, _global_process_mesh, [None, None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -183,7 +147,7 @@ class TestMLPSaveLoad(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh([0, 1]) + _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"]) dist_main_prog, dist_start_prog, loss = get_distributed_program() place = paddle.set_device("gpu") @@ -230,7 +194,7 @@ class TestMLPSaveLoad(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh([0, 1]) + _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"]) dist_main_prog, dist_start_prog, loss = get_distributed_program() @@ -278,11 +242,11 @@ class TestMLPSaveLoad(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "pp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh([0, 1]) + _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"]) global PP_MESH_0 - PP_MESH_0 = auto.ProcessMesh(mesh=[0]) + PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"]) global PP_MESH_1 - PP_MESH_1 = auto.ProcessMesh(mesh=[1]) + PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"]) dist_main_prog, dist_start_prog, loss = get_distributed_program() diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/auto_parallel_parallelizer.py b/python/paddle/fluid/tests/unittests/collective/fleet/auto_parallel_parallelizer.py index 688a31b78de002bd9493ec222edc2cbf38d13d9e..2aa113b55d5c9572f90d8ea0c23c6e7eb724e497 100755 --- a/python/paddle/fluid/tests/unittests/collective/fleet/auto_parallel_parallelizer.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/auto_parallel_parallelizer.py @@ -82,11 +82,7 @@ def mlp_pretrain_forward(train_program, start_program): shape=[batch_size, sequence_len, 1], dtype='float32') - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mappig": [-1, -1, -1] - }) + auto.shard_tensor(input, _global_process_mesh, [None, None, None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -106,7 +102,7 @@ class TestMLPAutoParallelizer(unittest.TestCase): def test_mlp_serial(self): global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"]) dist_strategy = fleet.DistributedStrategy() dist_strategy.amp = False diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py index ec879e77611cd491a7585d8b673d145577303b79..3091a927a82249743d8671fdaff4210b9e9f1fb7 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/auto_parallel_pass_test_base.py @@ -86,7 +86,7 @@ class AutoPallelPassTestBase(DistPassTestBase): paddle.static.Program()): with paddle.static.scope_guard(scope): with paddle.fluid.unique_name.guard(): - main_prog, startup_prog, inputs, outputs, reader = self.get_model( + main_prog, startup_prog, inputs, outputs, data_loader = self.get_model( place, **kwargs) inputs = self._to_var_names(inputs) outputs = self._to_var_names(outputs) @@ -95,27 +95,57 @@ class AutoPallelPassTestBase(DistPassTestBase): exe = paddle.static.Executor(place) with paddle.static.scope_guard(scope): exe.run(startup_prog) - for batch_id, input_data in enumerate(reader()): - assert len(input_data) == len(inputs), "{} vs {}".format( - len(input_data), len(inputs)) - feed = dict(zip(inputs, input_data)) - fetch_values = exe.run(main_prog, feed=feed, fetch_list=outputs) - if paddle.distributed.get_rank() == 0: - output_dict = OrderedDict(zip(outputs, fetch_values)) - print('batch {}, outputs {}'.format(batch_id, output_dict)) - all_fetch_values.append(fetch_values) + data_loader.start() + batch_id = 0 + while True: + try: + fetch_values = exe.run(main_prog, fetch_list=outputs) + if paddle.distributed.get_rank() == 0: + output_dict = OrderedDict(zip(outputs, fetch_values)) + print('batch {}, outputs {}'.format( + batch_id, output_dict)) + all_fetch_values.append(fetch_values) + batch_id += 1 + except paddle.fluid.core.EOFException: + data_loader.reset() + break with open(dump_file, "wb") as f: pickle.dump(all_fetch_values, f) def get_gpt_model(self, strategy, place, batch_size, sequence_len, vocab_size, **kwargs): + + def gen_data(): + np.random.seed(2021) + for _ in range(10): + tokens = [] + position_ids = [] + attention_mask = [] + labels = [] + loss_mask = [] + for _ in range(batch_size): + tokens.append( + np.random.randint(vocab_size, + size=sequence_len).astype("int64")) + position_ids.append(np.arange(sequence_len).astype("int64")) + attention_mask.append( + [np.tril(np.ones(sequence_len)).astype("float32")]) + labels.append( + np.random.randint(vocab_size, + size=sequence_len).astype("int64")) + loss_mask.append(np.ones(sequence_len).astype("float32")) + + yield tokens, position_ids, attention_mask, labels, loss_mask + modeling.init_global() if strategy == "dp": modeling._global_parallel_strategy = "dp" - modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) + modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1], + dim_names=["x"]) elif strategy == "mp": modeling._global_parallel_strategy = "mp" - modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) + modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1], + dim_names=["x"]) else: raise ValueError("'get_gpt_model' only support dp and mp.") @@ -137,23 +167,17 @@ class AutoPallelPassTestBase(DistPassTestBase): dtype='float32') data_holder = [tokens, position_ids, attention_mask, labels, loss_mask] + data_loader = paddle.fluid.io.DataLoader.from_generator( + feed_list=data_holder, capacity=70, iterable=False) + data_loader.set_batch_generator(gen_data, paddle.static.cuda_places()) + if modeling._global_parallel_strategy == "dp": - auto.shard_tensor(tokens, - dist_attr={ - "process_mesh": modeling._global_process_mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(tokens, modeling._global_process_mesh, + ["x", None]) elif modeling._global_parallel_strategy == "pp": - auto.shard_tensor(tokens, - dist_attr={ - "process_mesh": modeling.PP_MESH_LIST[0], - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(attention_mask, - dist_attr={ - "process_mesh": modeling.PP_MESH_LIST[0], - "dims_mapping": [-1, -1, -1, -1] - }) + auto.shard_tensor(tokens, modeling.PP_MESH_LIST[0], [None, None]) + auto.shard_tensor(attention_mask, modeling.PP_MESH_LIST[0], + [None, None, None, None]) gpt = GPTModel(vocab_size=1000, hidden_size=64, @@ -178,40 +202,21 @@ class AutoPallelPassTestBase(DistPassTestBase): preds = model(tokens, position_ids, attention_mask) criterion = GPTPretrainingCriterion() loss = criterion(preds, labels, loss_mask) - clip = paddle.nn.ClipGradByNorm(clip_norm=1.0) + clip = paddle.nn.ClipGradByNorm(clip_norm=1.0) if kwargs.get('optimizer', None) == "LarsMomentum": optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer( learning_rate=0.001, momentum=0.9) else: - optimizer = paddle.fluid.optimizer.AdamOptimizer( - learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=clip) + optimizer = paddle.optimizer.Adam(learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=clip) optimizer = fleet.distributed_optimizer(optimizer) startup_program = paddle.static.default_startup_program() _, _, dist_startup_prog, dist_main_prog = optimizer.minimize( loss, startup_program) - def gen_data(): - np.random.seed(2021) - for _ in range(10): - tokens = [] - position_ids = [] - attention_mask = [] - labels = [] - loss_mask = [] - for _ in range(batch_size): - tokens.append( - np.random.randint(vocab_size, size=sequence_len)) - position_ids.append(np.arange(sequence_len)) - attention_mask.append([np.tril(np.ones(sequence_len))]) - labels.append( - np.random.randint(vocab_size, size=sequence_len)) - loss_mask.append(np.ones(sequence_len)) - - yield tokens, position_ids, attention_mask, labels, loss_mask - - return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data + return dist_main_prog, dist_startup_prog, data_holder, [loss + ], data_loader diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_fp16_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_fp16_pass.py index 5ac78cc5fec4de46b7ad9cc5954936c5b89847c2..4c20153ccbfd99f59af34f6e80469cae0fb0283f 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_fp16_pass.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_fp16_pass.py @@ -20,10 +20,19 @@ import unittest import paddle import paddle.distributed.fleet as fleet from auto_parallel_pass_test_base import AutoPallelPassTestBase -from test_auto_parallel_amp_pass import TestAMPPass -class TestPF16Pass(TestAMPPass): +class TestPF16Pass(AutoPallelPassTestBase): + + def init(self): + if paddle.is_compiled_with_cuda(): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + self.rtol = 1e-5 + self.atol = 1e-8 + + paddle.seed(2021) + random.seed(2021) + np.random.seed(2021) def apply_passes(self): dist_strategy = fleet.DistributedStrategy() @@ -34,14 +43,30 @@ class TestPF16Pass(TestAMPPass): 'layer_norm', 'gelu', ], - "custom_black_list": ['c_softmax_with_cross_entropy'], - "init_loss_scaling": 32768, - "use_dynamic_loss_scaling": True, - "use_pure_fp16": True + "custom_black_list": + ['c_softmax_with_cross_entropy', 'elementwise_div', 'reduce_sum'], + "init_loss_scaling": + 32768, + "use_dynamic_loss_scaling": + True, + "use_pure_fp16": + True, + "use_fp16_guard": + False } dist_strategy.semi_auto = True fleet.init(is_collective=True, strategy=dist_strategy) + def test_bs_8(self): + self.check_main(gpus=[0, 1], + batch_size=8, + sequence_len=512, + vocab_size=1000) + + def get_model(self, place, batch_size, sequence_len, vocab_size): + return self.get_gpt_model("mp", place, batch_size, sequence_len, + vocab_size) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_gradient_merge_pass.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_gradient_merge_pass.py index 50e18718201865e42782eb45f1f2cca911ce6d9b..8f45b67090e934577895922dba495e769c388d5c 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_gradient_merge_pass.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_auto_parallel_gradient_merge_pass.py @@ -97,11 +97,8 @@ class MLPLayer(nn.Layer): def mlp_forward(input, label, hidden_size): - auto.shard_tensor(input, - dist_attr={ - "process_mesh": [0], - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(input, auto.ProcessMesh([0], dim_names=["x"]), + [None, None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, initializer_range=0.02) @@ -160,6 +157,12 @@ class TestGradientMergePass(AutoPallelPassTestBase): def get_model(self, place, batch_size, hidden_size, max_step): + def gen_data(): + for i in range(max_step): + x_data = input_data[i * batch_size:(i + 1) * batch_size, :] + y_data = label_data[i * batch_size:(i + 1) * batch_size, :] + yield x_data, y_data + train_program = static.Program() startup_program = static.Program() with static.program_guard(train_program, startup_program), \ @@ -171,6 +174,12 @@ class TestGradientMergePass(AutoPallelPassTestBase): shape=[batch_size, 1], dtype='float32') input.stop_gradient = False + data_holder = [input, label] + data_loader = paddle.fluid.io.DataLoader.from_generator( + feed_list=data_holder, capacity=70, iterable=False) + data_loader.set_batch_generator(gen_data, + paddle.static.cuda_places()) + loss = mlp_forward(input, label, hidden_size) optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.01) @@ -181,13 +190,8 @@ class TestGradientMergePass(AutoPallelPassTestBase): input_data = np.random.random(size=(128, hidden_size)).astype('float32') label_data = np.random.random(size=(128, 1)).astype('float32') - def reader(): - for i in range(max_step): - x_data = input_data[i * batch_size:(i + 1) * batch_size, :] - y_data = label_data[i * batch_size:(i + 1) * batch_size, :] - yield x_data, y_data - - return dist_main_prog, dist_startup_prog, [input, label], [loss], reader + return dist_main_prog, dist_startup_prog, [input, + label], [loss], data_loader if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py deleted file mode 100644 index f4a02679b322060212375f97f867f9c89b2b1328..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_api.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import unittest -import paddle -import paddle.fluid as fluid -import paddle.nn as nn -import paddle.distributed as dist -from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context -from paddle.distributed.auto_parallel.process_mesh import ProcessMesh - -paddle.enable_static() - -process_mesh1 = [0, 1, 2, 3] -process_mesh2 = [[0, 1, 2], [3, 4, 5]] - - -class SimpleNet(nn.Layer): - - def __init__(self, vocab_size=128, hidden_size=4): - super(SimpleNet, self).__init__() - self.word_embeddings = nn.Embedding(vocab_size, hidden_size) - self.dense1 = nn.Linear(hidden_size, hidden_size) - self.dense2 = nn.Linear(hidden_size, hidden_size // 2) - - def forward(self, x, y): - # Test shard_tensor interface with dist_attr arg - x = dist.shard_tensor(x, - dist_attr={ - "process_mesh": process_mesh1, - "dims_mapping": [0, -1] - }) - emb_out = self.word_embeddings(x) - # Test shard_tensor interface with no dist_attr arg - y = dist.shard_tensor(y) - linear1 = self.dense1(y) - out = self.dense2(linear1) - - return x, y - - -class TestAutoParallelAPI(unittest.TestCase): - - def test_api(self): - dist_context = get_default_distributed_context() - - net = SimpleNet() - data1 = fluid.layers.fill_constant(shape=[2, 4], value=1, dtype="int64") - data2 = fluid.layers.fill_constant(shape=[2, 4], - value=2, - dtype="float32") - data3 = fluid.layers.fill_constant(shape=[2, 4], - value=4, - dtype="float32") - - x, y = net.forward(data1, data2) - - dist_x = dist_context.get_dist_tensor_for_program(x) - self.assertEqual(dist_x.dist_attr.process_mesh.processes, process_mesh1) - self.assertEqual(dist_x.dist_attr.dims_mapping, [0, -1]) - self.assertEqual(dist_x.dist_attr.shard_sizes, None) - self.assertEqual(dist_x.dist_attr.device_placement, None) - self.assertTrue(dist_x.dist_attr.is_annotated("process_mesh")) - self.assertTrue(dist_x.dist_attr.is_annotated("dims_mapping")) - self.assertFalse(dist_x.dist_attr.is_annotated("shard_sizes")) - self.assertFalse(dist_x.dist_attr.is_annotated("device_placement")) - - dist_y = dist_context.get_dist_tensor_for_program(y) - self.assertEqual(dist_y.dist_attr.process_mesh, None) - self.assertEqual(dist_y.dist_attr.dims_mapping, [-1, -1]) - self.assertEqual(dist_y.dist_attr.shard_sizes, None) - self.assertEqual(dist_y.dist_attr.device_placement, None) - self.assertFalse(dist_y.dist_attr.is_annotated("process_mesh")) - self.assertFalse(dist_y.dist_attr.is_annotated("dims_mapping")) - self.assertFalse(dist_y.dist_attr.is_annotated("shard_sizes")) - self.assertFalse(dist_y.dist_attr.is_annotated("device_placement")) - - # Test shard_op interface with dist_attr - dims_mapping1 = [0, 1] - dims_mapping2 = [-1, 0] - dist_add = dist.shard_op(paddle.add, - dist_attr={ - data2: { - "process_mesh": process_mesh2, - "dims_mapping": dims_mapping1 - }, - data3: { - "dims_mapping": dims_mapping2 - } - }) - results = dist_add(data2, data3) - ops = paddle.static.default_main_program().block(0).ops - last_op = ops[-1] - - dist_op = dist_context.get_dist_op_for_program(last_op) - self.assertEqual(dist_op.dist_attr.process_mesh, - ProcessMesh(process_mesh2)) - self.assertEqual(dist_op.dist_attr.impl_type, "default") - self.assertEqual(dist_op.dist_attr.impl_idx, 0) - self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh")) - - data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name) - self.assertEqual(data2_dist_attr.process_mesh, - dist_op.dist_attr.process_mesh) - self.assertEqual(data2_dist_attr.dims_mapping, dims_mapping1) - self.assertEqual(data2_dist_attr.shard_sizes, None) - self.assertEqual(data2_dist_attr.device_placement, None) - self.assertTrue(data2_dist_attr.is_annotated("process_mesh")) - self.assertTrue(data2_dist_attr.is_annotated("dims_mapping")) - self.assertFalse(data2_dist_attr.is_annotated("shard_sizes")) - self.assertFalse(data2_dist_attr.is_annotated("device_placement")) - - data3_dist_attr = dist_op.dist_attr.get_input_dist_attr(data3.name) - self.assertEqual(data3_dist_attr.process_mesh, - dist_op.dist_attr.process_mesh) - self.assertEqual(data3_dist_attr.dims_mapping, dims_mapping2) - self.assertEqual(data3_dist_attr.shard_sizes, None) - self.assertEqual(data3_dist_attr.device_placement, None) - self.assertTrue(data3_dist_attr.is_annotated("process_mesh")) - self.assertTrue(data3_dist_attr.is_annotated("dims_mapping")) - self.assertFalse(data3_dist_attr.is_annotated("shard_sizes")) - self.assertFalse(data3_dist_attr.is_annotated("device_placement")) - - # Test shard_op interface with dist_attr - dist_add = dist.shard_op(paddle.add) - results = dist_add(data2, data3) - ops = paddle.static.default_main_program().block(0).ops - last_op = ops[-1] - dist_op = dist_context.get_dist_op_for_program(last_op) - self.assertEqual(dist_op.dist_attr.process_mesh, None) - self.assertEqual(dist_op.dist_attr.impl_type, "default") - self.assertEqual(dist_op.dist_attr.impl_idx, 0) - self.assertFalse(dist_op.dist_attr.is_annotated("process_mesh")) - - data2_dist_attr = dist_op.dist_attr.get_input_dist_attr(data2.name) - self.assertEqual(data2_dist_attr.process_mesh, - dist_op.dist_attr.process_mesh) - self.assertEqual(data2_dist_attr.dims_mapping, [-1, -1]) - self.assertEqual(data2_dist_attr.shard_sizes, None) - self.assertEqual(data2_dist_attr.device_placement, None) - self.assertFalse(data2_dist_attr.is_annotated("process_mesh")) - self.assertFalse(data2_dist_attr.is_annotated("dims_mapping")) - self.assertFalse(data2_dist_attr.is_annotated("shard_sizes")) - self.assertFalse(data2_dist_attr.is_annotated("device_placement")) - - data3_dist_attr = dist_op.dist_attr.get_input_dist_attr(data3.name) - self.assertEqual(data3_dist_attr.process_mesh, - dist_op.dist_attr.process_mesh) - self.assertEqual(data3_dist_attr.dims_mapping, [-1, -1]) - self.assertEqual(data3_dist_attr.shard_sizes, None) - self.assertEqual(data3_dist_attr.device_placement, None) - self.assertFalse(data3_dist_attr.is_annotated("process_mesh")) - self.assertFalse(data3_dist_attr.is_annotated("dims_mapping")) - self.assertFalse(data3_dist_attr.is_annotated("shard_sizes")) - self.assertFalse(data3_dist_attr.is_annotated("device_placement")) - - -if __name__ == '__main__': - unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py index 393d79557a927ad23844734f598901bf3dfe97c8..e07cc5cef93ade8e85bc4eeabafa1d026656d6c1 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion.py @@ -66,39 +66,13 @@ class MLPLayer(nn.Layer): self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") def forward(self, input): - if _global_parallel_strategy == "mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) - elif _global_parallel_strategy == "pp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh2, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) out = self.norm(input) out = self.linear0(out) @@ -119,18 +93,10 @@ def mlp_pretrain_forward(train_program, start_program): shape=[batch_size, sequence_len, hidden_size], dtype='float32') - if _global_parallel_strategy == "dp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1, -1] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["dp", "dp_mp"]: auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["dp", None, None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -146,7 +112,8 @@ class TestMLPAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3], + dim_names=["dp"]) train_program = static.Program() start_program = static.Program() dist_context = DistributedContext() @@ -161,7 +128,8 @@ class TestMLPAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3], + dim_names=["mp"]) train_program = static.Program() start_program = static.Program() @@ -177,8 +145,9 @@ class TestMLPAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], + [4, 5, 6, 7]], + dim_names=["dp", "mp"]) train_program = static.Program() start_program = static.Program() @@ -286,18 +255,10 @@ class AttentionLayer(nn.Layer): bias_attr=bias_attr) def forward(self, input): - if _global_parallel_strategy == "dp": + if _global_parallel_strategy in ["dp", "dp_mp"]: auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1, -1] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["dp", None, None]) q = self.q_proj(input) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) @@ -306,38 +267,16 @@ class AttentionLayer(nn.Layer): k = self.k_proj(input) v = self.v_proj(input) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) @@ -369,18 +308,10 @@ class AttentionLayer(nn.Layer): # project to output out = self.out_proj(out) - if _global_parallel_strategy == "mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) return out @@ -411,7 +342,8 @@ class TestAttentionAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3], + dim_names=["dp"]) train_program = static.Program() start_program = static.Program() dist_context = DistributedContext() @@ -420,15 +352,14 @@ class TestAttentionAutoCompletion(unittest.TestCase): completer = Completer(dist_context) complete_train_program = completer.complete_forward_annotation( train_program) - # print_program_with_dist_attr(complete_train_program, - # dist_context) self.assertTrue(dist_context.validate_dist_attr_for_program()) def test_attn_mp(self): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3], + dim_names=["mp"]) train_program = static.Program() start_program = static.Program() @@ -444,8 +375,9 @@ class TestAttentionAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], + [4, 5, 6, 7]], + dim_names=["dp", "mp"]) train_program = static.Program() start_program = static.Program() @@ -542,34 +474,18 @@ class DecoderLayer(nn.Layer): self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") def forward(self, input_ids, position_ids): - if _global_parallel_strategy == "dp": - auto.shard_tensor(input_ids, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["dp", "dp_mp"]: auto.shard_tensor(input_ids, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["dp", None]) input_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.word_embeddings.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.word_embeddings.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) embeddings = input_embeddings + position_embeddings embeddings = self.dropout1(embeddings) @@ -585,38 +501,16 @@ class DecoderLayer(nn.Layer): k = self.k_proj(target) v = self.v_proj(target) - if _global_parallel_strategy == "mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) - auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) - auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) @@ -649,18 +543,10 @@ class DecoderLayer(nn.Layer): # project to output out = self.out_proj(out) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) # Add residual residual = embeddings + self.dropout2(out) @@ -673,28 +559,13 @@ class DecoderLayer(nn.Layer): out2 = F.gelu(out1, approximate=True) out3 = self.linear1(out2) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) # Add residual final = residual + self.dropout3(out3) @@ -732,7 +603,8 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3], + dim_names=["dp"]) train_program = static.Program() start_program = static.Program() dist_context = DistributedContext() @@ -747,7 +619,8 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3], + dim_names=["mp"]) train_program = static.Program() start_program = static.Program() @@ -763,8 +636,9 @@ class TestDecoderLayerAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], + [4, 5, 6, 7]], + dim_names=["dp", "mp"]) train_program = static.Program() start_program = static.Program() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py index ab110c929f5c542ab6b8b2eba2c189c0106f660a..088b7b636c4184a4e3d75e024bb913505c7ecb6f 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_completion_gpt.py @@ -116,18 +116,10 @@ class MultiHeadAttention(nn.Layer): """ q = self.q_proj(query) - if _global_parallel_strategy == "mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) @@ -158,34 +150,15 @@ class MultiHeadAttention(nn.Layer): to construct cache for inference. """ k = self.k_proj(key) - - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) - v = self.v_proj(value) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: + auto.shard_tensor(self.k_proj.weight, + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) @@ -265,18 +238,10 @@ class MultiHeadAttention(nn.Layer): # project to output out = self.out_proj(out) - if _global_parallel_strategy == "mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) outs = [out] if self.need_weights: @@ -439,31 +404,13 @@ class TransformerDecoderLayer(nn.Layer): if self.normalize_before: tgt = self.norm2(tgt) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) - - if _global_parallel_strategy == "mp": + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.linear2.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.linear2.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) # tgt = self.dropout2( # self.linear2(F.gelu( @@ -523,18 +470,10 @@ class GPTEmbeddings(nn.Layer): input_embedings = self.word_embeddings(input_ids) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.word_embeddings.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.word_embeddings.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) position_embeddings = self.position_embeddings(position_ids) embeddings = input_embedings + position_embeddings @@ -757,18 +696,10 @@ def gpt_pretrain_forward(train_program, start_program): shape=[batch_size, sequence_len], dtype='float64') - if _global_parallel_strategy == "dp": - auto.shard_tensor(input_ids, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["dp", "dp_mp"]: auto.shard_tensor(input_ids, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["dp", None]) gpt = GPTModel(vocab_size=32768, hidden_size=1024, @@ -801,7 +732,8 @@ class TestGPTAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3], + dim_names=["dp"]) train_program = static.Program() start_program = static.Program() @@ -817,7 +749,8 @@ class TestGPTAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3], + dim_names=["mp"]) train_program = static.Program() start_program = static.Program() @@ -833,8 +766,9 @@ class TestGPTAutoCompletion(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], + [4, 5, 6, 7]], + dim_names=["dp", "mp"]) train_program = static.Program() start_program = static.Program() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py index bb8642d569e424269c4e65beb3a5ed117e3a08cd..7b48b921d5cece4e71ac66e107b760664ed3546c 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_cost_model.py @@ -35,8 +35,8 @@ from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr paddle.enable_static() _global_parallel_strategy = "dp_mp_pp" -PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]]) -PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]]) +PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"]) +PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"]) NUM_RANKS = 8 STAGE_0_CNT = 5 STAGE_1_CNT = 10 @@ -73,16 +73,8 @@ class MLPLayer(nn.Layer): def forward(self, input): if self.is_distributed: - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [-1, 1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None]) + auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None]) out = self.norm(input) out = self.linear0(out) @@ -135,16 +127,8 @@ def mlp_forward(train_program, start_program, is_distributed=True): dtype='float32') if is_distributed: - auto.shard_tensor(input, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [0, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(input, PP_MESH_0, ["x", None]) + auto.shard_tensor(label, PP_MESH_1, ["x", None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py index ca69535049c3bc64b8109612c9636599ed456804..63586c234b3558af55d940a59efb7d9dad1886a9 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py @@ -71,7 +71,7 @@ class TestDistributedTensor(unittest.TestCase): def test_new_local_tensor(self): test_auto_parallel_reshard._global_process_mesh = auto.ProcessMesh( - mesh=[0, 1]) + mesh=[0, 1], dim_names=["x"]) test_auto_parallel_reshard._global_parallel_strategy = "dp" train_program = paddle.static.Program() startup_program = paddle.static.Program() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py index 97855c8a8f156c46e4263b8357b1ec34d20eb304..7cc6b64894ebc43e7d1c1290d01c533100868516 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_mapper.py @@ -414,37 +414,25 @@ class MLPLayer(nn.Layer): def forward(self, input): if _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh[0], - "dims_mapping": [-1, 1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh[0], - "dims_mapping": [1, -1] - }) - auto.shard_tensor(self.linear2.weight, - dist_attr={ - "process_mesh": _global_process_mesh[1], - "dims_mapping": [-1, 1] - }) - auto.shard_tensor(self.linear3.weight, - dist_attr={ - "process_mesh": _global_process_mesh[1], - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.linear0.weight, _global_process_mesh[0], + [None, "y"]) + + auto.shard_tensor(self.linear1.weight, _global_process_mesh[0], + ["y", None]) + + auto.shard_tensor(self.linear2.weight, _global_process_mesh[1], + [None, "y"]) + + auto.shard_tensor(self.linear3.weight, _global_process_mesh[1], + ["y", None]) out = self.norm(input) out = self.linear0(out) out = F.gelu(out, approximate=True) out = self.linear1(out) - auto.shard_tensor(out, - dist_attr={ - "process_mesh": _global_process_mesh[1], - "dims_mapping": [0, -1] - }) + auto.shard_tensor(out, _global_process_mesh[1], ["x", None]) + out = self.linear2(out) out = F.gelu(out, approximate=True) out = self.linear3(out) @@ -464,11 +452,7 @@ def mlp_forward(train_program, start_program): dtype='float32') if _global_parallel_strategy == "dp_mp_pp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh[0], - "dims_mapping": [0, -1] - }) + auto.shard_tensor(input, _global_process_mesh[0], ["x", None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, initializer_range=0.02) @@ -548,7 +532,10 @@ class TestAutoParallelMapper(unittest.TestCase): global _global_num_stages _global_num_stages = 2 global _global_process_mesh - _global_process_mesh = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] + _global_process_mesh = [ + auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]), + auto.ProcessMesh([[4, 5], [6, 7]], dim_names=["x", "y"]) + ] processes = [0, 1, 2, 3, 4, 5, 6, 7] dist_programs = {} diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py index 80135b62885311c79fc81d51e10075d9b7d1c81c..af0f48e0676499c1fd8a7aced44112bebf433028 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner.py @@ -276,39 +276,20 @@ class MLPLayer(nn.Layer): self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") def forward(self, input): - if _global_parallel_strategy == "mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) else: auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, None]) auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, None]) out = self.norm(input) out = self.linear0(out) @@ -329,18 +310,10 @@ def mlp_pretrain_forward(train_program, start_program): shape=[batch_size, sequence_len, hidden_size], dtype='float32') - if _global_parallel_strategy == "dp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1, -1] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["dp", "dp_mp"]: auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["dp", None, None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -356,7 +329,8 @@ class TestMLPAutoPartitioner(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3], + dim_names=["dp"]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( mlp_pretrain_forward) @@ -391,7 +365,8 @@ class TestMLPAutoPartitioner(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3], + dim_names=["mp"]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( mlp_pretrain_forward) @@ -453,8 +428,9 @@ class TestMLPAutoPartitioner(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], + [4, 5, 6, 7]], + dim_names=["dp", "mp"]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( mlp_pretrain_forward) @@ -558,18 +534,10 @@ class AttentionLayer(nn.Layer): bias_attr=bias_attr) def forward(self, input): - if _global_parallel_strategy == "dp": + if _global_parallel_strategy in ["dp", "dp_mp"]: auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1, -1] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["dp", None, None]) q = self.q_proj(input) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) @@ -578,38 +546,16 @@ class AttentionLayer(nn.Layer): k = self.k_proj(input) v = self.v_proj(input) - if _global_parallel_strategy == "mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) - auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) - auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) @@ -641,18 +587,11 @@ class AttentionLayer(nn.Layer): # project to output out = self.out_proj(out) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": + + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) return out @@ -683,7 +622,8 @@ class TestAttentionAutoPartitioner(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3], + dim_names=["dp"]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( attn_pretrain_forward) @@ -717,7 +657,8 @@ class TestAttentionAutoPartitioner(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1, 2, 3], + dim_names=["mp"]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( attn_pretrain_forward) @@ -783,8 +724,9 @@ class TestAttentionAutoPartitioner(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], + [4, 5, 6, 7]], + dim_names=["dp", "mp"]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( attn_pretrain_forward) @@ -930,34 +872,18 @@ class DecoderLayer(nn.Layer): self.dropout3 = nn.Dropout(self.dropout_ratio, mode="upscale_in_train") def forward(self, input_ids, position_ids): - if _global_parallel_strategy == "dp": + if _global_parallel_strategy in ["dp", "dp_mp"]: auto.shard_tensor(input_ids, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(input_ids, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["dp", None]) input_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.word_embeddings.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.word_embeddings.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) embeddings = input_embeddings + position_embeddings embeddings = self.dropout1(embeddings) @@ -973,38 +899,16 @@ class DecoderLayer(nn.Layer): k = self.k_proj(target) v = self.v_proj(target) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) @@ -1037,24 +941,14 @@ class DecoderLayer(nn.Layer): # project to output out = self.out_proj(out) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) else: auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, None]) # Add residual residual = embeddings + self.dropout2(out) @@ -1067,28 +961,13 @@ class DecoderLayer(nn.Layer): out2 = F.gelu(out1, approximate=True) out3 = self.linear1(out2) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) # Add residual final = residual + self.dropout3(out3) @@ -1126,8 +1005,9 @@ class TestDecoderLayerPartitioner(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp_mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], + [4, 5, 6, 7]], + dim_names=["dp", "mp"]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( decoder_pretrain_forward) @@ -1208,8 +1088,9 @@ class TestDecoderLayerPartitioner(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "None" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], + [4, 5, 6, 7]], + dim_names=["x", "y"]) serial_main_prog, serial_startup_prog, dist_main_prog, dist_startup_prog, dist_context = get_programs( decoder_pretrain_forward) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py index 00ba2151fcba51aea9b8125b72f95a523f1a8dd3..b01959af2986ecc940b9982c81d280a566fbf113 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_partitioner_gpt.py @@ -163,18 +163,10 @@ class MultiHeadAttention(nn.Layer): """ q = self.q_proj(query) - if _global_parallel_strategy == "mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.q_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) q = tensor.transpose(x=q, perm=[0, 2, 1, 3]) @@ -205,34 +197,15 @@ class MultiHeadAttention(nn.Layer): to construct cache for inference. """ k = self.k_proj(key) - - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.k_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) - v = self.v_proj(value) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: + auto.shard_tensor(self.k_proj.weight, + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.v_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) k = tensor.transpose(x=k, perm=[0, 2, 1, 3]) @@ -312,18 +285,10 @@ class MultiHeadAttention(nn.Layer): # project to output out = self.out_proj(out) - if _global_parallel_strategy == "mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.out_proj.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) outs = [out] if self.need_weights: @@ -486,31 +451,13 @@ class TransformerDecoderLayer(nn.Layer): if self.normalize_before: tgt = self.norm2(tgt) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 0] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, 1] - }) - - if _global_parallel_strategy == "mp": + process_mesh=_global_process_mesh, + shard_spec=[None, "mp"]) auto.shard_tensor(self.linear2.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": - auto.shard_tensor(self.linear2.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) # tgt = self.dropout2( # self.linear2(F.gelu( @@ -570,18 +517,10 @@ class GPTEmbeddings(nn.Layer): input_embedings = self.word_embeddings(input_ids) - if _global_parallel_strategy == "mp": - auto.shard_tensor(self.word_embeddings.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["mp", "dp_mp"]: auto.shard_tensor(self.word_embeddings.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [1, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["mp", None]) position_embeddings = self.position_embeddings(position_ids) embeddings = input_embedings + position_embeddings @@ -804,18 +743,10 @@ def gpt_pretrain_forward(train_program, startup_program): shape=[batch_size, sequence_len], dtype='float64') - if _global_parallel_strategy == "dp": - auto.shard_tensor(input_ids, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) - elif _global_parallel_strategy == "dp_mp": + if _global_parallel_strategy in ["dp", "dp_mp"]: auto.shard_tensor(input_ids, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + process_mesh=_global_process_mesh, + shard_spec=["dp", None]) gpt = GPTModel(vocab_size=32768, hidden_size=768, @@ -863,8 +794,9 @@ class TestGPTPartitioner(unittest.TestCase): _global_parallel_strategy = "dp_mp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh( - mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) + _global_process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2, 3], + [4, 5, 6, 7]], + dim_names=["dp", "mp"]) train_program = static.Program() startup_program = static.Program() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py index 51926286acc1517cfe4d5c4291004fa79d6a7fb2..140ed2dae61eb7c1069261713535c1c8988eb641 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -63,27 +63,13 @@ class MLPLayer(nn.Layer): def forward(self, input): if _global_parallel_strategy == "pp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None]) + auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None]) else: - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(self.linear0.weight, _global_process_mesh, + [None, None]) + auto.shard_tensor(self.linear1.weight, _global_process_mesh, + [None, None]) out = self.norm(input) out = self.linear0(out) @@ -107,28 +93,12 @@ def mlp_forward(train_program, start_program): dtype='float32') if _global_parallel_strategy == "pp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(input, PP_MESH_0, [None, None]) + auto.shard_tensor(label, PP_MESH_1, [None, None]) elif _global_parallel_strategy == "dp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(input, _global_process_mesh, ["x", None]) else: - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(input, _global_process_mesh, [None, None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -296,11 +266,11 @@ class TestMLPReshard(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "pp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"]) global PP_MESH_0 - PP_MESH_0 = auto.ProcessMesh(mesh=[0]) + PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"]) global PP_MESH_1 - PP_MESH_1 = auto.ProcessMesh(mesh=[1]) + PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"]) train_program = paddle.static.Program() startup_program = paddle.static.Program() @@ -325,11 +295,11 @@ class TestMLPReshard(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "pp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"]) global PP_MESH_0 - PP_MESH_0 = auto.ProcessMesh(mesh=[0]) + PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"]) global PP_MESH_1 - PP_MESH_1 = auto.ProcessMesh(mesh=[1]) + PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"]) train_program = paddle.static.Program() startup_program = paddle.static.Program() @@ -352,7 +322,7 @@ class TestMLPReshard(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = "dp" global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) + _global_process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"]) train_program = paddle.static.Program() startup_program = paddle.static.Program() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py index 33396f283ec0ee78216ad0cff84e453a5bc06764..f77e0db3450e24d88a691c7cddbca96593687524 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_dpmppp.py @@ -34,9 +34,10 @@ from paddle.distributed.auto_parallel.cluster import Cluster paddle.enable_static() _global_parallel_strategy = "dp_mp_pp" -_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]]) -PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]]) -PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]]) +_global_process_mesh = auto.ProcessMesh([[[0, 1], [4, 5]], [[2, 3], [6, 7]]], + dim_names=["x", "y", "z"]) +PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"]) +PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"]) class MLPLayer(nn.Layer): @@ -63,16 +64,8 @@ class MLPLayer(nn.Layer): self.norm = nn.LayerNorm(d_model, epsilon=1e-5) def forward(self, input): - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [-1, 1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [1, -1] - }) + auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "y"]) + auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None]) out = self.norm(input) out = self.linear0(out) @@ -80,11 +73,7 @@ class MLPLayer(nn.Layer): out = self.linear1(out) param = paddle.fluid.layers.create_parameter([1024, 4096], paddle.float32) - auto.shard_tensor(param, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [-1, 1] - }) + auto.shard_tensor(param, PP_MESH_1, [None, "y"]) out = paddle.fluid.layers.mul(out, param) return out @@ -103,16 +92,8 @@ def mlp_forward(train_program, start_program): shape=[batch_size, 1], dtype='float32') - auto.shard_tensor(input, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [0, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(input, PP_MESH_0, ["x", None]) + auto.shard_tensor(label, PP_MESH_1, ["x", None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py index d5de1c128733197cebbd67ec732a0a0f31df7289..c9dbc77da8a78321c9413011214c35fd8360aa66 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_mppp.py @@ -34,9 +34,9 @@ from paddle.distributed.auto_parallel.cluster import Cluster paddle.enable_static() _global_parallel_strategy = "mp_pp" -_global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]]) -PP_MESH_0 = auto.ProcessMesh([0, 1]) -PP_MESH_1 = auto.ProcessMesh([2, 3]) +_global_process_mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"]) +PP_MESH_0 = auto.ProcessMesh([0, 1], dim_names=["x"]) +PP_MESH_1 = auto.ProcessMesh([2, 3], dim_names=["x"]) class MLPLayer(nn.Layer): @@ -73,35 +73,15 @@ class MLPLayer(nn.Layer): bias_attr=bias_attr) def forward(self, input): - auto.shard_tensor(self.word_embeddings.weight, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [0, -1] - }) - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [-1, 0] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [0, -1] - }) - auto.shard_tensor(self.linear2.weight, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(self.word_embeddings.weight, PP_MESH_0, ["x", None]) + auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "x"]) + auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["x", None]) + auto.shard_tensor(self.linear2.weight, PP_MESH_1, ["x", None]) w_out = self.word_embeddings(input) out = self.linear0(w_out) param = paddle.fluid.layers.create_parameter([4096, 4096], paddle.float32) - auto.shard_tensor(param, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(param, PP_MESH_0, ["x", None]) out = paddle.fluid.layers.mul(out, param) gelu_out = F.gelu(out, approximate=True) out = self.linear1(gelu_out) @@ -122,16 +102,8 @@ def mlp_forward(train_program, start_program): shape=[batch_size, 1], dtype='float32') - auto.shard_tensor(input, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [-1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(input, PP_MESH_0, [None]) + auto.shard_tensor(label, PP_MESH_1, [None, None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -238,7 +210,6 @@ class TestMLPReshard(unittest.TestCase): resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, dist_context, dist_params_grads) resharder.reshard() - print_program_with_dist_attr(dist_main_prog, dist_context) # check send and recv result self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) @@ -249,32 +220,15 @@ class TestMLPReshard(unittest.TestCase): def test_allgather(self): train_program = paddle.static.Program() startup_program = paddle.static.Program() - process_mesh = auto.ProcessMesh(mesh=[0, 1]) + process_mesh = auto.ProcessMesh(mesh=[0, 1], dim_names=["x"]) with static.program_guard(train_program, startup_program): x = paddle.static.data(name="x", shape=[4, 4], dtype='float32') - x = auto.shard_tensor(x, - dist_attr={ - "process_mesh": process_mesh, - "dims_mapping": [0, -1] - }) - + x = auto.shard_tensor(x, process_mesh, ["x", None]) w = paddle.static.data(name="w", shape=[4, 4], dtype='float32') - w = auto.shard_tensor(w, - dist_attr={ - "process_mesh": process_mesh, - "dims_mapping": [-1, -1] - }) - - y = paddle.distributed.shard_op(paddle.matmul, - dist_attr={ - "process_mesh": process_mesh, - x: { - "dims_mapping": [-1, -1] - }, - w: { - "dims_mapping": [-1, -1] - } - })(x, w) + w = auto.shard_tensor(w, process_mesh, [None, None]) + + y = paddle.distributed.shard_op(paddle.matmul, process_mesh, + [[None, None], [None, None]])(x, w) rank_id = 0 dist_context = DistributedContext() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py index 64ff030f5b1e2ad25a057e6ff97c0cddd802397e..e255bcbcc009619853bc6d827066dcea22e97c2e 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py @@ -62,27 +62,13 @@ class MLPLayer(nn.Layer): def forward(self, input): if _global_parallel_strategy == "pp": - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None]) + auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None]) else: - auto.shard_tensor(self.linear0.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(self.linear1.weight, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(self.linear0.weight, _global_process_mesh, + [None, None]) + auto.shard_tensor(self.linear1.weight, _global_process_mesh, + [None, None]) out = self.norm(input) out = self.linear0(out) @@ -106,28 +92,12 @@ def mlp_forward(train_program, start_program): dtype='float32') if _global_parallel_strategy == "pp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": PP_MESH_0, - "dims_mapping": [-1, -1] - }) - auto.shard_tensor(label, - dist_attr={ - "process_mesh": PP_MESH_1, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(input, PP_MESH_0, [None, None]) + auto.shard_tensor(label, PP_MESH_1, [None, None]) elif _global_parallel_strategy == "dp": - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [0, -1] - }) + auto.shard_tensor(input, _global_process_mesh, ["x", None]) else: - auto.shard_tensor(input, - dist_attr={ - "process_mesh": _global_process_mesh, - "dims_mapping": [-1, -1] - }) + auto.shard_tensor(input, _global_process_mesh, [None, None]) mlp = MLPLayer(hidden_size=hidden_size, intermediate_size=4 * hidden_size, @@ -196,7 +166,7 @@ class TestMLPReshard(unittest.TestCase): global _global_parallel_strategy _global_parallel_strategy = None global _global_process_mesh - _global_process_mesh = auto.ProcessMesh(mesh=[0]) + _global_process_mesh = auto.ProcessMesh(mesh=[0], dim_names=["x"]) train_program = paddle.static.Program() startup_program = paddle.static.Program()