diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto old mode 100755 new mode 100644 index e6a7d74cc43433318ea825927e72c779b14ab43c..654b88920acaf68f1ea5b7b1513735f25255b118 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -139,6 +139,10 @@ message PipelineConfig { optional string schedule_mode = 3 [ default = '1F1B' ]; } +message TensorParallelConfig { + optional int32 tensor_parallel_degree = 1 [ default = 1 ]; +} + message DistributedStrategy { // bool options optional Mode mode = 1 [ default = COLLECTIVE ]; @@ -169,6 +173,7 @@ message DistributedStrategy { optional bool sharding = 26 [ default = false ]; optional float last_comm_group_size_MB = 27 [ default = 1 ]; optional bool find_unused_parameters = 28 [ default = true ]; + optional bool tensor_parallel = 29 [ default = false ]; optional RecomputeConfig recompute_configs = 101; optional AMPConfig amp_configs = 102; @@ -182,6 +187,7 @@ message DistributedStrategy { optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110; optional ShardingConfig sharding_configs = 111; optional HybridConfig hybrid_configs = 112; + optional TensorParallelConfig tensor_parallel_configs = 113; optional BuildStrategy build_strategy = 201; optional ExecutionStrategy execution_strategy = 202; } diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index 2756dea72e84a998680cddf53d2150a860ba0e34..32c607ec672a3bfafa070221026bd8ac1e7cadd9 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -692,6 +692,79 @@ def scatter(tensor, tensor_list=None, src=0, group=None, use_calc_stream=True): }) +def _c_identity(tensor, group=0): + """ + Return a copy of the tensor, mainly used with model parallel. + + Args: + tensor (Tensor): The input Tensor. Its data type + should be float16, float32, float64, int32 or int64. + group (int): The id of the process group to work on. + + Returns: + Tensor. + """ + op_type = 'c_identity' + helper = LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference(dtype=tensor.dtype) + if in_dygraph_mode(): + return core.ops.c_identity(out, tensor, 'use_calc_stream', True, + 'ring_id', group, 'use_model_parallel', True) + check_variable_and_dtype( + tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], + '_c_identity') + if not isinstance(group, int): + raise ValueError("The type of 'group' for _c_identity should be int.") + helper.append_op( + type=op_type, + inputs={'X': tensor}, + outputs={'Out': out}, + attrs={ + 'ring_id': group, + 'use_calc_stream': True, + 'use_model_parallel': True, + }) + return out + + +def _c_split(tensor, rank, nranks, group=0): + """ + Split tensor evenly among all members, mainly used with model parallel. + + Args: + tensor (Tensor): The input Tensor. Its data type + should be float16, float32, float64, int32 or int64. + rank (int): The rank of the current process. + group (int): The id of the process group to work on. + + Returns: + Tensor. + """ + op_type = 'c_split' + helper = LayerHelper(op_type, **locals()) + out = helper.create_variable_for_type_inference(dtype=tensor.dtype) + if in_dygraph_mode(): + return core.ops.c_split(out, tensor, 'use_calc_stream', True, 'ring_id', + group, 'rank', rank, 'use_model_parallel', True) + check_variable_and_dtype( + tensor, 'tensor', ['float16', 'float32', 'float64', 'int32', 'int64'], + '_c_split') + if not isinstance(group, int): + raise ValueError("The type of 'group' for _identity should be int.") + helper.append_op( + type=op_type, + inputs={'X': tensor}, + outputs={'Out': out}, + attrs={ + 'ring_id': group, + 'use_calc_stream': True, + 'rank': rank, + 'nranks': nranks, + 'use_model_parallel': True, + }) + return out + + def barrier(group=None): """ @@ -732,15 +805,27 @@ def barrier(group=None): attrs={'ring_id': ring_id}) -def _parallel_linear(x, num_rows, num_cols, axis, param_attr, bias_attr, - gather_out, inner_rank, name): +def _parallel_linear(x, + num_rows, + num_cols, + axis, + param_attr, + bias_attr, + gather_out, + inner_rank, + nranks, + split_tensor, + name, + group=0): """ Parallel Linear """ - if not name: - name = "fc_by_row_rank_%d" % inner_rank if axis == 0 else "fc_by_col_rank_%d" % inner_rank + if axis == 0: + if split_tensor: + x = _c_split(x, inner_rank, nranks, group=group) else: - name = name + "_by_row_rank_%d" % inner_rank if axis == 0 else name + "_by_col_rank_%d" % inner_rank + x = _c_identity(x, group=group) + linear = paddle.nn.Linear( num_rows, num_cols, @@ -748,34 +833,60 @@ def _parallel_linear(x, num_rows, num_cols, axis, param_attr, bias_attr, bias_attr=bias_attr, name=name) - weight = linear.weight - weight.is_distributed = True linear_out = linear(x) startup_block = paddle.static.default_startup_program().global_block() main_block = paddle.static.default_main_program().global_block() - startup_block.vars[weight.name].is_distributed = True - main_block.vars[weight.name].is_distributed = True - - if gather_out: - if axis == 0: - paddle.distributed.all_reduce(linear_out) - else: - output = [] - paddle.distributed.all_gather(output, linear_out) - linear_out = paddle.concat(output, axis=len(linear_out.shape) - 1) - return linear_out + startup_block.vars[linear.weight.name].is_distributed = True + main_block.vars[linear.weight.name].is_distributed = True + + if not gather_out: return linear_out + + op_type = 'c_allreduce_sum' if axis == 0 else 'c_concat' + out_shape = list(linear_out.shape) + out_shape[0] *= 1 if axis == 0 else nranks + out = main_block.create_var( + shape=out_shape, + dtype=linear_out.dtype, + type=linear_out.type, + lod_level=linear_out.lod_level, + persistable=False, + is_data=False, + need_check_feed=linear_out.desc.need_check_feed()) + if axis == 0: + main_block.append_op( + type='c_allreduce_sum', + inputs={'X': linear_out}, + outputs={'Out': out}, + attrs={ + 'ring_id': group, + 'use_calc_stream': True, + 'use_model_parallel': True + }) + else: + main_block.append_op( + type='c_concat', + inputs={'X': linear_out}, + outputs={'Out': out}, + attrs={ + 'ring_id': group, + 'nranks': nranks, + 'use_calc_stream': True, + 'use_model_parallel': True + }) + return out -def _parallel_embedding(x, per_part_embeddings, origin_size, param_attr, - inner_rank, num_partitions, name): +def _parallel_embedding(x, + per_part_embeddings, + origin_size, + param_attr, + inner_rank, + num_partitions, + name, + group=0): """ Parallel Embedding """ - if not name: - name = "emb_rank_%d" % inner_rank - else: - name = name + "_rank_%d" % inner_rank - origin_num_embeddings = origin_size[0] embedding = paddle.nn.Embedding( per_part_embeddings, @@ -795,15 +906,29 @@ def _parallel_embedding(x, per_part_embeddings, origin_size, param_attr, inner_rank, per_part_embeddings - 1) if len(origin_input_shape) == 2: x_shard = paddle.squeeze(x_shard, axis=-1) - - embedding.weight.is_distributed = True emb_out = embedding(x_shard) startup_block = paddle.static.default_startup_program().global_block() main_block = paddle.static.default_main_program().global_block() startup_block.vars[embedding.weight.name].is_distributed = True main_block.vars[embedding.weight.name].is_distributed = True - paddle.distributed.all_reduce(emb_out, group=None) - return emb_out + out = main_block.create_var( + shape=emb_out.shape, + dtype=emb_out.dtype, + type=emb_out.type, + lod_level=emb_out.lod_level, + persistable=False, + is_data=False, + need_check_feed=emb_out.desc.need_check_feed()) + main_block.append_op( + type='c_allreduce_sum', + inputs={'X': emb_out}, + outputs={'Out': out}, + attrs={ + 'ring_id': group, + 'use_calc_stream': True, + 'use_model_parallel': True + }) + return out def split(x, @@ -896,8 +1021,10 @@ def split(x, "paddle.distributed.split must be one of {}.".format( supported_operations)) if in_dygraph_mode(): - rank = paddle.distributed.get_rank() - nranks = paddle.distributed.get_world_size() + raise ValueError( + "paddle.distributed.split cannot be used in dynamic " + "graph mode, plese use ParallelEmbedding, ParallelRowLinear, " + "ParallelColumnLinear instead.") else: assert fleet._role_maker, ("To use paddle.distributed.split, " "you must call fleet.init() firstly.") @@ -915,10 +1042,18 @@ def split(x, if inner_rank == num_partitions - 1: per_part_size = last_part_size per_part_size += 1 # make the last row as the padding index - emb_out = _parallel_embedding(x, per_part_size, size, weight_attr, - inner_rank, num_partitions, name) + emb_out = _parallel_embedding( + x, + per_part_size, + size, + weight_attr, + inner_rank, + num_partitions, + name, + group=0) return emb_out else: + should_split = False if axis == 0: assert size[0] % num_partitions == 0, ( "Number of rows of the weight for linear ({}) must be" @@ -926,11 +1061,7 @@ def split(x, num_partitions)) per_part_size = size[0] // num_partitions linear_size = (per_part_size, size[1]) - assert x.shape[-1] == per_part_size, ( - "The width ({}) of the input " - "x must be equal to the height ({}) of the weight. Maybe you " - "should split the input x using paddle.split.".format( - x.shape[-1], per_part_size)) + if x.shape[-1] == size[0]: should_split = True elif axis == 1: assert size[1] % num_partitions == 0, ( @@ -952,5 +1083,8 @@ def split(x, bias_attr, gather_out, inner_rank, - name=name) + num_partitions, + should_split, + name=name, + group=0) return linear_out diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 443c5a2954b0c571ac2cf2cff54c670900f38dae..9fed3a8550c407491c37a0eab9e7e0b1f96db5ed 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -891,6 +891,58 @@ class DistributedStrategy(object): "pipeline_configs") assign_configs_value(self.strategy.pipeline_configs, configs) + @property + def tensor_parallel(self): + """ + Indicating whether we are using tensor parallel for distributed training. + + Examples: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.tensor_parallel = True + + """ + return self.strategy.tensor_parallel + + @tensor_parallel.setter + @is_strict_auto + def tensor_parallel(self, flag): + if isinstance(flag, bool): + self.strategy.tensor_parallel = flag + else: + print("WARNING: tensor_parallel should have value of bool type") + + @property + def tensor_parallel_configs(self): + """ + Set tensor_parallel configurations. + + **Notes**: + **Detailed arguments for tensor_parallel_configs** + **tensor_parallel_degree**: degree of tensor parallel + + Examples: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.tensor_parallel = True + strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4} + + """ + return get_msg_dict(self.strategy.tensor_parallel_configs) + + @tensor_parallel_configs.setter + @is_strict_auto + def tensor_parallel_configs(self, configs): + check_configs_key(self.strategy.tensor_parallel_configs, configs, + "tensor_parallel_configs") + assign_configs_value(self.strategy.tensor_parallel_configs, configs) + @property def hybrid_configs(self): """ diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index 3be8a479491dc04042e2d333163ff38a0ac67f3b..827835fde20e3e662124b24929d18c53151dbd92 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -27,3 +27,4 @@ from .fp16_allreduce_optimizer import FP16AllReduceOptimizer from .sharding_optimizer import ShardingOptimizer from .dygraph_optimizer import HybridParallelOptimizer from .dygraph_optimizer import HybridParallelGradScaler +from .tensor_parallel_optimizer import TensorParallelOptimizer diff --git a/python/paddle/distributed/fleet/meta_optimizers/tensor_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/tensor_parallel_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2ba0195156082cf1b72ff95e90feea724113adf9 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/tensor_parallel_optimizer.py @@ -0,0 +1,231 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +from __future__ import print_function +from __future__ import division + +import paddle.fluid as fluid +from paddle.fluid import core, unique_name +from .meta_optimizer_base import MetaOptimizerBase +from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op + + +class TensorParallelOptimizer(MetaOptimizerBase): + def __init__(self, optimizer): + super(TensorParallelOptimizer, self).__init__(optimizer) + self.inner_opt = optimizer + self.meta_optimizers_white_list = [ + "RecomputeOptimizer", + "AMPOptimizer", + "LarsOptimizer", + "LambOptimizer", + ] + self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] + self.mp_ring_id = 0 + self.global_ring_id = 1 + self.dp_ring_id = 2 + + def _set_basic_info(self, loss, role_maker, user_defined_optimizer, + user_defined_strategy): + super(TensorParallelOptimizer, self)._set_basic_info( + loss, role_maker, user_defined_optimizer, user_defined_strategy) + self.mp_degree = user_defined_strategy.tensor_parallel_configs[ + 'tensor_parallel_degree'] + + def _can_apply(self): + if not self.role_maker._is_collective: + return False + + if self.user_defined_strategy.tensor_parallel == True: + return True + return False + + def _disable_strategy(self, dist_strategy): + dist_strategy.tensor_parallel = False + dist_strategy.tensor_parallel_configs = {} + + def _enable_strategy(self, dist_strategy, context): + dist_strategy.tensor_parallel = True + dist_strategy.tensor_parallel_configs = {"tensor_parallel_degree": 1, } + + def _broadcast_params(self, ring_id, mp_mode): + block = self.startup_program.global_block() + param = None + for param in block.iter_parameters(): + if param.is_distributed and mp_mode: + continue + + block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': ring_id, + 'root': 0, + OP_ROLE_KEY: OpRole.Forward + }) + + if not param: return # no parameter on this device + block.append_op( + type='c_sync_comm_stream', + inputs={'X': param}, + outputs={'Out': param}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Forward}) + + def _get_process_group_info(self): + # global ring info + self.global_endpoints = self.endpoints + self.global_rank = self.rank + self.global_nranks = self.nranks + + # model parallel ring info + self.mp_rank = self.rank % self.mp_degree + self.mp_nranks = self.mp_degree + mp_group = self.rank // self.mp_degree + self.mp_endpoints = [ + self.endpoints[i] for i in range(self.global_nranks) + if i // self.mp_degree == mp_group + ] + + # data parallel ring info + if self.nranks > self.mp_degree: + self.dp_rank = self.rank // self.mp_degree + self.dp_nranks = self.nranks // self.mp_degree + start_index = self.rank % self.mp_degree + self.dp_endpoints = [ + self.endpoints[start_index + i * self.mp_degree] + for i in range(self.dp_nranks) + ] + + def _init_process_group(self): + self._get_process_group_info() + collective_helper = CollectiveHelper(self.role_maker, wait_port=False) + + # Create global ring for all gpus + collective_helper._init_communicator( + self.startup_program, self.current_endpoint, self.global_endpoints, + self.global_rank, self.global_ring_id, True, self.global_ring_id, + True) + + # Create model parallel ring for all gpus + collective_helper._init_communicator( + self.startup_program, self.current_endpoint, self.mp_endpoints, + self.mp_rank, self.mp_ring_id, True, self.global_ring_id, True) + #self._broadcast_params(self.mp_ring_id, mp_mode=True) + + # Create dp rings + if self.nranks > self.mp_degree: + collective_helper._init_communicator( + self.startup_program, self.current_endpoint, self.dp_endpoints, + self.dp_rank, self.dp_ring_id, True, self.global_ring_id, True) + self._broadcast_params(self.dp_ring_id, mp_mode=False) + + def minimize_impl(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + self.endpoints = self.role_maker._get_trainer_endpoints() + self.current_endpoint = self.endpoints[self.role_maker._worker_index()] + self.startup_program = startup_program + if startup_program is None: + self.startup_program = fluid.default_startup_program() + + optimize_ops, params_grads = self.inner_opt.minimize( + loss, self.startup_program, parameter_list, no_grad_set) + + self.main_program = loss.block.program + self.nranks = len(self.endpoints) + self.rank = self.role_maker._worker_index() + + self._init_process_group() + + assert self.nranks % self.mp_degree == 0 + + if self.nranks > self.mp_degree: + # data parallelism + dp_degree = self.nranks // self.mp_degree + self._transpile_main_program(loss, dp_degree) + return optimize_ops, params_grads + + def _transpile_main_program(self, loss, dp_degree): + self._insert_loss_grad_ops(loss, dp_degree) + self._insert_allreduce_ops(loss, self.dp_ring_id) + + def _insert_loss_grad_ops(self, loss, dp_degree): + """ + In order to keep the learning rate consistent in different numbers of + training workers, we scale the loss grad by the number of workers + """ + block = loss.block + for idx, op in reversed(list(enumerate(block.ops))): + if is_loss_grad_op(op): + loss_grad_var = block.vars[op.output_arg_names[0]] + block._insert_op( + idx + 1, + type='scale', + inputs={'X': loss_grad_var}, + outputs={'Out': loss_grad_var}, + attrs={ + 'scale': 1.0 / dp_degree, + OP_ROLE_KEY: OpRole.Backward + }) + break + + def _insert_allreduce_ops(self, loss, ring_id): + block = loss.block + grad = None + for idx, op in reversed(list(enumerate(block.ops))): + if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names: + op_role_var = op.attr(OP_ROLE_VAR_KEY) + if len(op_role_var) == 0: + continue + assert len(op_role_var) % 2 == 0 + offset = idx + for i in range(0, len(op_role_var), 2): + param = block.vars[op_role_var[i]] + grad = block.vars[op_role_var[i + 1]] + if offset == idx: + offset += 1 + block._insert_op( + offset, + type='c_sync_calc_stream', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={OP_ROLE_KEY: OpRole.Backward}) + offset += 1 + + block._insert_op( + offset, + type='c_allreduce_sum', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={ + 'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward + }) + + if grad is None: + return + + for idx, op in list(enumerate(block.ops)): + if is_optimizer_op(op): + block._insert_op( + idx, + type='c_sync_comm_stream', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward}) + break diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index c4d1c6fb3552f3a920822c4c8443f292b2ca7a31..89950fd62e2648ed8a07dba7eb546eedf3bdf0c6 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -11,6 +11,7 @@ endif() string(REPLACE ".py" "" DIST_TEST_OPS "${DIST_TEST_OPS}") list(APPEND DIST_TEST_OPS test_parallel_dygraph_mnist) list(APPEND DIST_TEST_OPS test_pipeline) +list(APPEND DIST_TEST_OPS test_static_model_parallel) list(APPEND DIST_TEST_OPS test_parallel_dygraph_se_resnext) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding) list(APPEND DIST_TEST_OPS test_parallel_dygraph_sparse_embedding_over_height) @@ -869,6 +870,7 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32) set_tests_properties(test_new_group_api PROPERTIES TIMEOUT 120) if(WITH_DISTRIBUTE) set_tests_properties(test_pipeline PROPERTIES TIMEOUT 120) + set_tests_properties(test_static_model_parallel PROPERTIES TIMEOUT 240) endif() set_tests_properties(test_reducescatter_api PROPERTIES TIMEOUT 120) set_tests_properties(test_broadcast PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/static_model_parallel_by_col.py b/python/paddle/fluid/tests/unittests/static_model_parallel_by_col.py new file mode 100644 index 0000000000000000000000000000000000000000..416f6bc4f0d417db6ae82380787f2a715b398ca6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/static_model_parallel_by_col.py @@ -0,0 +1,119 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import time +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main +import paddle.distributed.fleet as fleet + +paddle.enable_static() + +DTYPE = "float32" +MODEL_PARALLEL_SIZE = 2 +IN_SIZE = 2 * MODEL_PARALLEL_SIZE +OUT_SIZE = 2 * MODEL_PARALLEL_SIZE + +# Fix seed for test +#fluid.default_startup_program().random_seed = 1 +#fluid.default_main_program().random_seed = 1 + + +def create_model(data, rank): + np.random.seed(2021) + np_weight = np.random.uniform(-1, 1, size=(IN_SIZE, OUT_SIZE)).astype(DTYPE) + if rank is not None: + start_col = 0 if rank == 0 else OUT_SIZE // 2 + np_weight_part = np_weight[:, start_col:start_col + OUT_SIZE // 2] + result = paddle.distributed.split( + data, + size=(IN_SIZE, OUT_SIZE), + operation='linear', + axis=1, + num_partitions=MODEL_PARALLEL_SIZE, + weight_attr=paddle.ParamAttr( + initializer=fluid.initializer.NumpyArrayInitializer( + np_weight_part)), + bias_attr=False, ) + else: + result = fluid.layers.fc( + data, + size=OUT_SIZE, + param_attr=paddle.ParamAttr( + initializer=fluid.initializer.NumpyArrayInitializer(np_weight)), + bias_attr=False, ) + + predict = paddle.sum(result) + return predict + + +class TestModelParallel(TestDistRunnerBase): + def get_model(self, batch_size=2, use_dgc=False, dist_strategy=None): + # Input data + data_in = fluid.data( + name='data_in', shape=[batch_size, IN_SIZE], dtype=DTYPE) + + if dist_strategy: + data_loader = fluid.io.DataLoader.from_generator( + feed_list=[data_in], + capacity=64, + use_double_buffer=False, + iterable=False) + + if dist_strategy: + fleet.init(is_collective=True) + strategy = fleet.DistributedStrategy() + strategy.tensor_parallel = True + strategy.tensor_parallel_configs = {'tensor_parallel_degree': 2} + + rank = fleet.worker_index() if dist_strategy else None + avg_cost = create_model(data_in, rank) + opt = fluid.optimizer.SGD(0.1) + + if dist_strategy: + dist_opt = fleet.distributed_optimizer( + optimizer=opt, strategy=strategy) + dist_opt.minimize(avg_cost) + else: + opt.minimize(avg_cost) + + def gen_data(): + np.random.seed(2021) + while True: + data = [np.random.random([IN_SIZE]).astype(DTYPE)] + yield data + + train_reader = paddle.batch(gen_data, batch_size=batch_size) + + if dist_strategy: + return None, avg_cost, train_reader, None, None, None, data_loader + else: + return None, avg_cost, train_reader, None, None, None + + +if __name__ == "__main__": + runtime_main(TestModelParallel) diff --git a/python/paddle/fluid/tests/unittests/static_model_parallel_by_row.py b/python/paddle/fluid/tests/unittests/static_model_parallel_by_row.py new file mode 100644 index 0000000000000000000000000000000000000000..4a98792f8a0473ec855dccb0d06c7c5751e72f41 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/static_model_parallel_by_row.py @@ -0,0 +1,119 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import time +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main +import paddle.distributed.fleet as fleet + +paddle.enable_static() + +DTYPE = "float32" +MODEL_PARALLEL_SIZE = 2 +IN_SIZE = 2 * MODEL_PARALLEL_SIZE +OUT_SIZE = 2 * MODEL_PARALLEL_SIZE + +# Fix seed for test +#fluid.default_startup_program().random_seed = 1 +#fluid.default_main_program().random_seed = 1 + + +def create_model(data, rank): + np.random.seed(2021) + np_weight = np.random.uniform(-1, 1, size=(IN_SIZE, OUT_SIZE)).astype(DTYPE) + if rank is not None: + start_row = 0 if rank == 0 else IN_SIZE // 2 + np_weight_part = np_weight[start_row:start_row + IN_SIZE // 2, :] + result = paddle.distributed.split( + data, + size=(IN_SIZE, OUT_SIZE), + operation='linear', + axis=0, + num_partitions=MODEL_PARALLEL_SIZE, + weight_attr=paddle.ParamAttr( + initializer=fluid.initializer.NumpyArrayInitializer( + np_weight_part)), + bias_attr=False, ) + else: + result = fluid.layers.fc( + data, + size=OUT_SIZE, + param_attr=paddle.ParamAttr( + initializer=fluid.initializer.NumpyArrayInitializer(np_weight)), + bias_attr=False, ) + + predict = paddle.sum(result) + return predict + + +class TestModelParallel(TestDistRunnerBase): + def get_model(self, batch_size=2, use_dgc=False, dist_strategy=None): + # Input data + data_in = fluid.data( + name='data_in', shape=[batch_size, IN_SIZE], dtype=DTYPE) + + if dist_strategy: + data_loader = fluid.io.DataLoader.from_generator( + feed_list=[data_in], + capacity=64, + use_double_buffer=False, + iterable=False) + + if dist_strategy: + fleet.init(is_collective=True) + strategy = fleet.DistributedStrategy() + strategy.tensor_parallel = True + strategy.tensor_parallel_configs = {'tensor_parallel_degree': 2} + + rank = fleet.worker_index() if dist_strategy else None + avg_cost = create_model(data_in, rank) + opt = fluid.optimizer.SGD(0.1) + + if dist_strategy: + dist_opt = fleet.distributed_optimizer( + optimizer=opt, strategy=strategy) + dist_opt.minimize(avg_cost) + else: + opt.minimize(avg_cost) + + def gen_data(): + np.random.seed(2021) + while True: + data = [np.random.random([IN_SIZE]).astype(DTYPE)] + yield data + + train_reader = paddle.batch(gen_data, batch_size=batch_size) + + if dist_strategy: + return None, avg_cost, train_reader, None, None, None, data_loader + else: + return None, avg_cost, train_reader, None, None, None + + +if __name__ == "__main__": + runtime_main(TestModelParallel) diff --git a/python/paddle/fluid/tests/unittests/static_model_parallel_embedding.py b/python/paddle/fluid/tests/unittests/static_model_parallel_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..4a98792f8a0473ec855dccb0d06c7c5751e72f41 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/static_model_parallel_embedding.py @@ -0,0 +1,119 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import argparse +import time +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.profiler as profiler +from paddle.fluid import core +import unittest +from multiprocessing import Process +import os +import signal +from functools import reduce +from test_dist_base import TestDistRunnerBase, runtime_main +import paddle.distributed.fleet as fleet + +paddle.enable_static() + +DTYPE = "float32" +MODEL_PARALLEL_SIZE = 2 +IN_SIZE = 2 * MODEL_PARALLEL_SIZE +OUT_SIZE = 2 * MODEL_PARALLEL_SIZE + +# Fix seed for test +#fluid.default_startup_program().random_seed = 1 +#fluid.default_main_program().random_seed = 1 + + +def create_model(data, rank): + np.random.seed(2021) + np_weight = np.random.uniform(-1, 1, size=(IN_SIZE, OUT_SIZE)).astype(DTYPE) + if rank is not None: + start_row = 0 if rank == 0 else IN_SIZE // 2 + np_weight_part = np_weight[start_row:start_row + IN_SIZE // 2, :] + result = paddle.distributed.split( + data, + size=(IN_SIZE, OUT_SIZE), + operation='linear', + axis=0, + num_partitions=MODEL_PARALLEL_SIZE, + weight_attr=paddle.ParamAttr( + initializer=fluid.initializer.NumpyArrayInitializer( + np_weight_part)), + bias_attr=False, ) + else: + result = fluid.layers.fc( + data, + size=OUT_SIZE, + param_attr=paddle.ParamAttr( + initializer=fluid.initializer.NumpyArrayInitializer(np_weight)), + bias_attr=False, ) + + predict = paddle.sum(result) + return predict + + +class TestModelParallel(TestDistRunnerBase): + def get_model(self, batch_size=2, use_dgc=False, dist_strategy=None): + # Input data + data_in = fluid.data( + name='data_in', shape=[batch_size, IN_SIZE], dtype=DTYPE) + + if dist_strategy: + data_loader = fluid.io.DataLoader.from_generator( + feed_list=[data_in], + capacity=64, + use_double_buffer=False, + iterable=False) + + if dist_strategy: + fleet.init(is_collective=True) + strategy = fleet.DistributedStrategy() + strategy.tensor_parallel = True + strategy.tensor_parallel_configs = {'tensor_parallel_degree': 2} + + rank = fleet.worker_index() if dist_strategy else None + avg_cost = create_model(data_in, rank) + opt = fluid.optimizer.SGD(0.1) + + if dist_strategy: + dist_opt = fleet.distributed_optimizer( + optimizer=opt, strategy=strategy) + dist_opt.minimize(avg_cost) + else: + opt.minimize(avg_cost) + + def gen_data(): + np.random.seed(2021) + while True: + data = [np.random.random([IN_SIZE]).astype(DTYPE)] + yield data + + train_reader = paddle.batch(gen_data, batch_size=batch_size) + + if dist_strategy: + return None, avg_cost, train_reader, None, None, None, data_loader + else: + return None, avg_cost, train_reader, None, None, None + + +if __name__ == "__main__": + runtime_main(TestModelParallel) diff --git a/python/paddle/fluid/tests/unittests/test_static_model_parallel.py b/python/paddle/fluid/tests/unittests/test_static_model_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..6f2f7408262d94e2aa97e908fbac1057a900bc2e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_static_model_parallel.py @@ -0,0 +1,63 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import unittest +from test_dist_base import TestDistBase + +import os +import paddle + +paddle.enable_static() +flag_name = os.path.splitext(__file__)[0] + + +class TestStaticModelParallel(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + self._use_reader_alloc = False + self._nccl_comm_num = 1 + self._pipeline_mode = True + + def test_dist_static_model_parallel(self): + import paddle.fluid as fluid + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "static_model_parallel_by_row.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + def test_dist_static_model_parallel2(self): + import paddle.fluid as fluid + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "static_model_parallel_by_col.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + def test_dist_static_model_parallel3(self): + import paddle.fluid as fluid + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "static_model_parallel_embedding.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +if __name__ == '__main__': + unittest.main()