diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py
index ee6bee45fd7fe4aae3d95fe162049bcb2a692d89..07b505efbd4c584bda56b3d8455ea0466bd57a25 100644
--- a/python/paddle/distributed/auto_parallel/engine.py
+++ b/python/paddle/distributed/auto_parallel/engine.py
@@ -54,8 +54,8 @@ from .interface import _get_fetches
 
 class Engine:
     """
-    An Engine object can provide the full power of auto parallel to users. 
-    With the help of it, users can easily obtain the abilities of the 
+    An Engine object can provide the full power of auto parallel to users.
+    With the help of it, users can easily obtain the abilities of the
     distributed training and inference. It also support the dynamic graph
     and static graph at the same time.
 
@@ -63,8 +63,8 @@ class Engine:
         model (paddle.nn.Layer, optional): The model is an instance of
             paddle.nn.Layer.
         loss (Loss|Callable|None, optional): The loss can be a `paddle.nn.Layer`
-            instance or any callable function taken the predicted values and 
-            ground truth values as input. It can be None when there is no loss. 
+            instance or any callable function taken the predicted values and
+            ground truth values as input. It can be None when there is no loss.
             Default: None.
         optimizer (Optimizer|None, optional): The optimizer need to be set in training
             and should be None in eval and predict mode. Default: None.
@@ -92,17 +92,17 @@ class Engine:
             valid_dataset = MNIST(mode='test', transform=transform)
 
             model = paddle.vision.models.LeNet()
-            loss = paddle.nn.CrossEntropyLoss() 
+            loss = paddle.nn.CrossEntropyLoss()
             optimizer = paddle.optimizer.Adam(
                 learning_rate=0.001, parameters=model.parameters())
             metrics = paddle.metric.Accuracy(topk=(1, 2))
 
-            engine = auto.Engine(model, loss, optimizer, metrics) 
-            # fit 
+            engine = auto.Engine(model, loss, optimizer, metrics)
+            # fit
             engine.fit(train_dataset,
                        epochs=2,
                        batch_size=64)
-            # evaluate 
+            # evaluate
             engine.evaluate(valid_dataset,
                             batch_size=64)
             # predict
@@ -110,7 +110,7 @@ class Engine:
                            batch_size=64)
             # save
             engine.save("./my_model")
-            # load 
+            # load
             engine.load("./my_model")
 
     """
@@ -502,32 +502,32 @@ class Engine:
             train_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None.
             train_sample_split (int, optional): Each sample of the train dataset is assumed
                 to be a (input, label) pair by default and has two items. If each sample has
-                more than two items, train_sample_split specifies how to split these items into 
+                more than two items, train_sample_split specifies how to split these items into
                 input and label. The items before it are input and the left are label. Default: None.
-            batch_size (int, optional): The batch size of train_data and valid_data if provided. 
+            batch_size (int, optional): The batch size of train_data and valid_data if provided.
                 The user's data will be used directly without batching if set to None. Default: 1.
             epochs (int, optional): The number of epochs to train the model. Default: 1.
             steps_per_epoch (int, optional): The total number of steps (batches of samples)
-                is executed in one epoch before stating the next one. If None, it is equal to 
+                is executed in one epoch before stating the next one. If None, it is equal to
                 the number samples in your dataset divided by the batch size. Default: None.
             valid_data (Dataset, optional): An instance of paddle paddle.io.Dataset used for
-                evaluation at the end of epoch. No evaluation will be done if set to None. 
+                evaluation at the end of epoch. No evaluation will be done if set to None.
                 Default: None. (Unsupported for now)
-            valid_freq (int, optional): Only relevant if valid_data is provided. This specifies 
+            valid_freq (int, optional): Only relevant if valid_data is provided. This specifies
                 how many training epochs before a new evaluation is performed. Default: 1.
             valid_sample_split (int, optional): Only relevant if valid_data is provided.
-                Each sample of the valid dataset is assumed to be a (input, label) pair 
-                by default and has two items. If each sample has more than two items, 
+                Each sample of the valid dataset is assumed to be a (input, label) pair
+                by default and has two items. If each sample has more than two items,
                 valid_sample_split specifies how to split these items into input and label.
                 The items before it are input and the left are label. Default: None.
             valid_steps (int, optional): Only relevant if valid_data is provided.
-                It is the total number of steps (batches of samples) to draw before 
-                stopping validation at the end of every epoch. If None, validation will run until the 
+                It is the total number of steps (batches of samples) to draw before
+                stopping validation at the end of every epoch. If None, validation will run until the
                 `valid_data` dataset is exhausted. The validation will start from the beginning of the
                 dataset at each epoch. Default: None.
             collate_fn(callable, optional): function to generate mini-batch data by merging
                 the sample list, None for only stack each fields of sample in axis
-                0. Default None. 
+                0. Default None.
             callbacks (Callback|None, optional): A list of `Callback` instances to apply
                 during training. Default: None. (Unused for now)
@@ -550,12 +550,12 @@ class Engine:
                 train_dataset = MNIST(mode='train', transform=transform)
 
                 model = paddle.vision.models.LeNet()
-                loss = paddle.nn.CrossEntropyLoss() 
+                loss = paddle.nn.CrossEntropyLoss()
                 optimizer = paddle.optimizer.Adam(
                     learning_rate=0.001, parameters=model.parameters())
                 metrics = paddle.metric.Accuracy(topk=(1, 2))
 
-                engine = auto.Engine(model, loss, optimizer, metrics) 
+                engine = auto.Engine(model, loss, optimizer, metrics)
                 engine.fit(train_dataset,
                            epochs=2,
                            batch_size=64)
@@ -636,15 +636,15 @@ class Engine:
         Evaluate the loss and metrics of the model on evaluation data.
 
         Args:
-            eval_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None.
-            eval_sample_split (int, optional): Each sample of the eval dataset is assumed
+            valid_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None.
+            valid_sample_split (int, optional): Each sample of the eval dataset is assumed
                 to be a (input, label) pair by default and has two items. If each sample has
-                more than two items, eval_sample_split specifies how to split these items into
+                more than two items, valid_sample_split specifies how to split these items into
                 input and label. The items before it are input and the left are label. Default: None.
-            batch_size (int, optional): The batch size of eval_data. The user's data will
+            batch_size (int, optional): The batch size of valid_data. The user's data will
                 be used directly without batching if set to None. Default: 1.
-            steps (int, optional): It is the total number of steps (batches of samples) to draw before 
-                stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted. 
+            steps (int, optional): It is the total number of steps (batches of samples) to draw before
+                stopping evaluation. If None, evaluation will run until the `valid_data` dataset is exhausted.
                 The evaluation will start from the beginning of the dataset in each run. Default: None.
             collate_fn(callable, optional): function to generate mini-batch data by merging
                 the sample list, None for only stack each fields of sample in axis
@@ -671,10 +671,10 @@ class Engine:
                 valid_dataset = MNIST(mode='test', transform=transform)
 
                 model = paddle.vision.models.LeNet()
-                loss = paddle.nn.CrossEntropyLoss() 
+                loss = paddle.nn.CrossEntropyLoss()
                 metrics = paddle.metric.Accuracy(topk=(1, 2))
 
-                engine = auto.Engine(model, loss, metrics=metrics) 
+                engine = auto.Engine(model, loss, metrics=metrics)
                 engine.evaluate(valid_dataset, batch_size=64)
 
         """
@@ -745,12 +745,12 @@ class Engine:
             test_data (Dataset): An instance of paddle paddle.io.Dataset. Default: None.
             test_sample_split (int, optional): Each sample of the test dataset is assumed
                 to be a (input, label) pair by default and has two items. If each sample has
-                more than two items, test_sample_split specifies how to split these items into 
+                more than two items, test_sample_split specifies how to split these items into
                 input and label. The items before it are input and the left are label. Default: None.
             batch_size (int, optional): The batch size of test_data. The user's data will
                 be used directly without batching if set to None. Default: 1.
-            steps (int, optional): It is the total number of steps (batches of samples) to draw before 
-                stopping predict. If None, predict will run until the `test_data` dataset is exhausted. 
+            steps (int, optional): It is the total number of steps (batches of samples) to draw before
+                stopping predict. If None, predict will run until the `test_data` dataset is exhausted.
                 The predict will start from the beginning of the dataset in each run. Default: None.
             collate_fn(callable, optional): function to generate mini-batch data by merging
                 the sample list, None for only stack each fields of sample in axis
@@ -778,7 +778,7 @@ class Engine:
                 model = paddle.vision.models.LeNet()
 
-                engine = auto.Engine(model) 
+                engine = auto.Engine(model)
                 engine.predict(valid_dataset, batch_size=64)
 
         """
         self.mode = 'predict'
@@ -1013,8 +1013,8 @@ class Engine:
             program.set_state_dict(state_dict)
 
     def save(self, path, training=True):
-        """ 
-        Saves the model, parameters, optimizer state to path. 
+        """
+        Saves the model, parameters, optimizer state to path.
         If `training` is set to False, only inference model will be saved.
 
         Args:
@@ -1045,12 +1045,12 @@ class Engine:
                 train_dataset = MNIST(mode='train', transform=transform)
 
                 model = paddle.vision.models.LeNet()
-                loss = paddle.nn.CrossEntropyLoss() 
+                loss = paddle.nn.CrossEntropyLoss()
                 optimizer = paddle.optimizer.Adam(
                     learning_rate=0.001, parameters=model.parameters())
                 metrics = paddle.metric.Accuracy(topk=(1, 2))
 
-                engine = auto.Engine(model, loss, optimizer, metrics) 
+                engine = auto.Engine(model, loss, optimizer, metrics)
                 engine.fit(train_dataset,
                            epochs=1,
                            batch_size=64)
@@ -1084,7 +1084,7 @@ class Engine:
 
         Args:
             path (str): The prefix of files storing the model states and
-                optimizer states. 
+                optimizer states.
             strict (bool, optional): Whether to skip the loading of mismatch
                 parameter or raise an error when mismatch happens (not found
                 the parameter in file storing model states of or receives a
@@ -1111,12 +1111,12 @@ class Engine:
                 train_dataset = MNIST(mode='train', transform=transform)
 
                 model = paddle.vision.models.LeNet()
-                loss = paddle.nn.CrossEntropyLoss() 
+                loss = paddle.nn.CrossEntropyLoss()
                 optimizer = paddle.optimizer.Adam(
                     learning_rate=0.001, parameters=model.parameters())
                 metrics = paddle.metric.Accuracy(topk=(1, 2))
 
-                engine = auto.Engine(model, loss, optimizer, metrics) 
+                engine = auto.Engine(model, loss, optimizer, metrics)
                 engine.fit(train_dataset,
                            epochs=1,
                            batch_size=64)
diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py
index ad3078c449048e17cb99b23ce9db93f440bc5d03..c6951012ee8633a85d883d16403ac4f1eb8e9db1 100644
--- a/python/paddle/distributed/auto_parallel/interface.py
+++ b/python/paddle/distributed/auto_parallel/interface.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import paddle
 from paddle.fluid import core
 from .process_mesh import ProcessMesh
 from .process_mesh import get_current_process_mesh
@@ -31,11 +32,11 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
         x (Tensor): the tensor to be sharded.
         process_mesh (ProcessMesh, optional): An instance of ProcessMesh describes a mesh
             topology of the used logical processes where the tensor is sharded. If it is None,
-            the found current process mesh will be used. And an error will be raised if the 
+            the found current process mesh will be used. And an error will be raised if the
             current process mesh cannot be found. Default: None.
         shard_spec (list, optional): a list to describe the sharding mapping between `x` and `process_mesh`,
             which means the dimension `i` of `x` is split across the dimension `shard_spec[i]` of `process_mesh`,
-            where `None` means that tensor dimension is not split. For example, given a tensor wih 
+            where `None` means that tensor dimension is not split. For example, given a tensor wih
             the shape [6, 12] and a process mesh with the shape [2, 3] and the dimension names ["x", "y"]:
             If `shard_spec=["x", "y"]`, each shard of the tensor will have a shape [3, 4];
             If `shard_spec=["y", "x"]`, each shard of the tensor will have a shape [2, 6];
@@ -48,13 +49,13 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
             In the above example, the `shard_spec=None` is same as 'shard_spec=[None, None]'. Defaults: None.
 
     Returns:
-        Tensor: the tensor `x` annotated with sharding information. 
+        Tensor: the tensor `x` annotated with sharding information.
 
     Examples:
         .. code-block:: python
 
            import paddle
-           import paddle.distributed.auto_parallel as auto 
+           import paddle.distributed.auto_parallel as auto
 
            mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
            x = paddle.ones([4, 6])
@@ -111,13 +112,13 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
         in_shard_specs (list of list, optional): a list of list to describe the sharding specifications
             for the inputs. Each item of `in_shard_specs` is a `shard_spec` between the correspoinding input
             and `process_mesh`. If one item is None, the cooresponding input is replicated across all processes
-            If it is None, all inputs are replicated accross all processes. Note that the lenght of the 
+            If it is None, all inputs are replicated accross all processes. Note that the lenght of the
             `in_shard_specs` should be equal to the actual number of inputs when calling this operation.
             Default: None.
         out_shard_specs (list of list, optional): a list of list to describe the sharding specifications
             for the outputs. Each item of `out_shard_specs` is a `shard_spec` between the correspoinding output
             and `process_mesh`. If one item is None, the cooresponding output is replicated across all processes
-            If it is None, all outputs are replicated accross all processes. Note that the lenght of the 
+            If it is None, all outputs are replicated accross all processes. Note that the lenght of the
             `in_shard_specs` should be equal to the actual number of inputs when calling this operation.
             Default: None.
             Default: None.
@@ -128,8 +129,8 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
         .. code-block:: python
 
            import paddle
-           import paddle.distributed.auto_parallel as auto 
-           
+           import paddle.distributed.auto_parallel as auto
+
            x = paddle.ones([4, 6])
            y = paddle.zeros([4, 6])
            mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
diff --git a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
index 44f504887cf165a73ad5510d2113d557f85fa5e0..e2515cedbd3ea802aa8eaec4e11a0c09626494ad 100644
--- a/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
+++ b/python/paddle/distributed/passes/auto_parallel_data_parallel_optimization.py
@@ -111,14 +111,9 @@ class DataParallelOptimizationPass(PassBase):
         if not self._could_be_fuse():
             return []
 
-        with open('./before_program.txt.' + str(paddle.distributed.get_rank()),
-                  'w') as f:
-            f.write(str(default_main_program()))
         grad_group = self._group_grads()
         self._update_program(grad_group)
-        with open('./after_program.txt.' + str(paddle.distributed.get_rank()),
-                  'w') as f:
-            f.write(str(default_main_program()))
+
         return grad_group
 
     def _analyze_program(self):
@@ -569,6 +564,11 @@ class GradientsGroup(object):
                 self.remove_scale_op_indices.append(i + 1)
 
         if len(self.gradients) == 1:
+            # TODO: remove this temporary hack for Tensor Parallel; the logic
+            # for finding grad_op should be more general.
+            if self.ops[grad_op_idx].type == "c_allreduce_sum":
+                grad_op_idx -= 1
+
             grad_op = self.ops[grad_op_idx]
             assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format(
                 grad_var.name, str(grad_op))
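
The sketch below restates the idea behind the new `c_allreduce_sum` check in the last hunk, in isolation: when tensor parallelism places an allreduce right after the op that produces a gradient, the lookup index is stepped back one op so that `grad_op` ends up being the actual producer rather than the allreduce. This is a minimal, self-contained illustration; `FakeOp` and `locate_grad_op` are made-up names for this sketch, not Paddle's `GradientsGroup` implementation.

    # Minimal sketch with stand-in objects (not Paddle's GradientsGroup code).

    class FakeOp:
        """Stand-in for a framework op record; only the fields used below."""

        def __init__(self, op_type, output_arg_names):
            self.type = op_type
            self.output_arg_names = output_arg_names


    def locate_grad_op(ops, grad_op_idx, grad_name):
        # If tensor parallelism appended an allreduce right after the grad
        # producer, step the index back so it points at the producer itself.
        if ops[grad_op_idx].type == "c_allreduce_sum":
            grad_op_idx -= 1
        grad_op = ops[grad_op_idx]
        assert grad_name in grad_op.output_arg_names, \
            "grad [{}] should be output of {}".format(grad_name, grad_op.type)
        return grad_op


    if __name__ == "__main__":
        ops = [
            FakeOp("matmul_v2_grad", ["linear_0.w_0@GRAD"]),
            FakeOp("c_allreduce_sum", ["linear_0.w_0@GRAD"]),
        ]
        # Index 1 is the allreduce; the helper steps back to the real grad op
        # and prints "matmul_v2_grad".
        print(locate_grad_op(ops, grad_op_idx=1, grad_name="linear_0.w_0@GRAD").type)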