From 0205e9f84ecb74ab0bcb3e06ba45779440da4c75 Mon Sep 17 00:00:00 2001
From: lilong12
Date: Wed, 10 Mar 2021 15:59:47 +0800
Subject: [PATCH] remove the send/recv of tensor size (#31460)

* remove the send/recv of tensor size, but users have to specify the
  shape of the received var explicitly.
---
 .../framework/distributed_strategy.proto      |  5 +++-
 .../fluid/operators/collective/recv_v2_op.cc  |  8 +++++
 .../operators/collective/recv_v2_op.cu.cc     | 29 ++-----------------
 .../operators/collective/send_v2_op.cu.cc     | 15 ----------
 .../meta_optimizers/pipeline_optimizer.py     | 11 +++++--
 python/paddle/fluid/optimizer.py              | 12 ++++++--
 .../tests/unittests/collective_sendrecv_op.py |  5 +++-
 .../fluid/tests/unittests/pipeline_mnist.py   |  1 +
 .../test_fleet_distributed_strategy.py        |  7 +++--
 .../test_fleet_pipeline_meta_optimizer.py     |  5 +++-
 10 files changed, 47 insertions(+), 51 deletions(-)

diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 208ab9a93c..300f0eb0db 100644
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -117,7 +117,10 @@ message AsyncConfig {
   optional int32 lr_decay_steps = 11 [ default = 10 ];
 }
 
-message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
+message PipelineConfig {
+  optional int32 micro_batch_size = 1 [ default = 1 ];
+  optional int32 accumulate_steps = 2 [ default = 1 ];
+}
 
 message DistributedStrategy {
   // bool options
diff --git a/paddle/fluid/operators/collective/recv_v2_op.cc b/paddle/fluid/operators/collective/recv_v2_op.cc
index 1040882038..0ae7b82161 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cc
+++ b/paddle/fluid/operators/collective/recv_v2_op.cc
@@ -40,6 +40,14 @@ class RecvOpV2 : public framework::OperatorWithKernel {
             "The size of the output shape must be greater than 0 "
             "but the value given is %d.",
             out_shape.size()));
+    for (size_t i = 0; i < out_shape.size(); ++i) {
+      PADDLE_ENFORCE_GE(out_shape[i], 1,
+                        platform::errors::InvalidArgument(
+                            "The shape attribute for recv_v2 must be set "
+                            "explicitly, but the %dth element is %d which "
+                            "is less than 1.",
+                            i, out_shape[i]));
+    }
     ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
   }
 
diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
index 5b846598b8..7912733fa5 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
@@ -42,6 +42,7 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
 
     auto out = ctx.Output<framework::LoDTensor>("Out");
     auto out_dims = out->dims();
+    auto numel = out->numel();
    int data_type = ctx.Attr<int>("dtype");
     framework::proto::VarType::Type type =
         framework::proto::VarType::Type(data_type);
@@ -61,34 +62,8 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument("The value of peer (%d) you set must "
                                           "be less than comm->nranks (%d).",
                                           peer, comm->nranks()));
-    ncclDataType_t dtype = platform::ToNCCLDataType(type);
-
-    // Recv the number of elements to receive first
-    int numel = 0;
-    int *numel_ptr = nullptr;
-#ifdef PADDLE_WITH_RCCL
-    PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&numel_ptr, sizeof(int)));
-#else
-    PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
-#endif
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::ncclRecv(static_cast<int *>(numel_ptr), 1, ncclInt,
-                                    peer, comm->comm(), stream));
-#ifdef PADDLE_WITH_RCCL
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        hipMemcpy(&numel, numel_ptr, sizeof(int), hipMemcpyDeviceToHost));
-#else
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        cudaMemcpy(&numel, numel_ptr, sizeof(int), cudaMemcpyDeviceToHost));
-#endif
-
-    int rest_numel = 1;
-    for (int i = 1; i < out_dims.size(); ++i) {
-      rest_numel = rest_numel * out_dims[i];
-    }
-    out_dims[0] = numel / rest_numel;
     out->mutable_data<T>(out_dims, place);
-
+    ncclDataType_t dtype = platform::ToNCCLDataType(type);
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
         out->data<T>(), numel, dtype, peer, comm->comm(), stream));
     VLOG(3) << "rank " << comm->rank() << " recv "
diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc
index b70124a7bf..c4f5d05e68 100644
--- a/paddle/fluid/operators/collective/send_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc
@@ -57,21 +57,6 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
                                           "be less than comm->nranks (%d).",
                                           peer, comm->nranks()));
     ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
-    // Send number of elements to the receiver, as the receiver may have
-    // no information of the Tensor size.
-    int* numel_ptr = nullptr;
-#ifdef PADDLE_WITH_RCCL
-    PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&numel_ptr, sizeof(int)));
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        hipMemcpy(numel_ptr, &numel, sizeof(int), hipMemcpyHostToDevice));
-#else
-    PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        cudaMemcpy(numel_ptr, &numel, sizeof(int), cudaMemcpyHostToDevice));
-#endif
-
-    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
-        numel_ptr, 1, ncclInt, peer, comm->comm(), stream));
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
         x->data<T>(), numel, dtype, peer, comm->comm(), stream));
     VLOG(3) << "rank " << comm->rank() << " send "
diff --git a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
index 9e46bf3368..1b79de03fd 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
@@ -145,8 +145,10 @@ class PipelineOptimizer(MetaOptimizerBase):
                         user_defined_strategy):
         super(PipelineOptimizer, self)._set_basic_info(
             loss, role_maker, user_defined_optimizer, user_defined_strategy)
+        self.micro_batch_size = user_defined_strategy.pipeline_configs[
+            'micro_batch_size']
         self.num_microbatches = user_defined_strategy.pipeline_configs[
-            'micro_batch']
+            'accumulate_steps']
 
     def _can_apply(self):
         if not self.role_maker._is_collective:
@@ -162,7 +164,10 @@ class PipelineOptimizer(MetaOptimizerBase):
 
     def _enable_strategy(self, dist_strategy, context):
         dist_strategy.pipeline = True
-        dist_strategy.pipeline_configs = {"micro_batch": 1, }
+        dist_strategy.pipeline_configs = {
+            "micro_batch_size": 1,
+            "accumulate_steps": 1,
+        }
 
     def minimize_impl(self,
                       loss,
@@ -185,6 +190,8 @@ class PipelineOptimizer(MetaOptimizerBase):
 
         loss.block.program._pipeline_opt = dict()
         loss.block.program._pipeline_opt['local_rank'] = self.rank
+        loss.block.program._pipeline_opt[
+            'micro_batch_size'] = self.micro_batch_size
         optimize_ops, params_grads, prog_list = self.wrapped_opt.minimize(
             loss, startup_program, parameter_list, no_grad_set)
         assert prog_list
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 01a0a78fba..80f49ea939 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4075,12 +4075,15 @@ class PipelineOptimizer(object):
                     break
             source_var = main_program.block(0).var(var_name)
             new_var = self._create_var(block, source_var, var_name)
+            new_var_shape = list(new_var.shape)
+            new_var_shape[0] = self.micro_batch_size if new_var_shape[
+                0] < 0 else new_var_shape[0]
             block._insert_op(
                 index=index,
                 type='recv_v2',
                 outputs={'Out': [new_var]},
                 attrs={
-                    'out_shape': new_var.shape,
+                    'out_shape': new_var_shape,
                     'dtype': new_var.dtype,
                     self._op_device_key: device,
                     self._op_role_key: self._op_role.Forward,
@@ -4243,12 +4246,15 @@ class PipelineOptimizer(object):
                             'peer': cur_device_index,
                         })
                     extra_index += 1
+                    var_shape = list(var.shape)
+                    var_shape[0] = self.micro_batch_size if var_shape[
+                        0] < 0 else var_shape[0]
                     block._insert_op(
                         index=index + extra_index,
                         type='recv_v2',
                         outputs={'Out': [var]},
                         attrs={
-                            'out_shape': var.shape,
+                            'out_shape': var_shape,
                             'dtype': var.dtype,
                             self._op_device_key: cur_device_spec,
                             self._op_role_key: op_role,
@@ -4455,6 +4461,8 @@ class PipelineOptimizer(object):
         optimize_ops, params_grads = self._optimizer.minimize(
             loss, startup_program, parameter_list, no_grad_set)
         self._param_device_map = self._optimizer._param_device_map
+        self.micro_batch_size = main_block.program._pipeline_opt[
+            'micro_batch_size']
 
         # Step1: add default op_device attribute for regulization and clip ops
         self._add_opdevice_attr_for_regularization_clip(main_block)
diff --git a/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py b/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py
index 0a1967aa65..18a7aeccf4 100644
--- a/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py
+++ b/python/paddle/fluid/tests/unittests/collective_sendrecv_op.py
@@ -46,7 +46,10 @@ class TestCollectiveSendRecv(TestCollectiveRunnerBase):
         ring_id = self.global_ring_id
         with fluid.program_guard(main_prog, startup_program):
             tindata = layers.data(
-                name="tindata", shape=[10, 1000], dtype='float64')
+                name="tindata",
+                shape=[10, 1000],
+                dtype='float64',
+                append_batch_size=False)
             if self.rank == 0:
                 main_prog.global_block().append_op(
                     type="send_v2",
diff --git a/python/paddle/fluid/tests/unittests/pipeline_mnist.py b/python/paddle/fluid/tests/unittests/pipeline_mnist.py
index 8987646b3e..d06be76b33 100644
--- a/python/paddle/fluid/tests/unittests/pipeline_mnist.py
+++ b/python/paddle/fluid/tests/unittests/pipeline_mnist.py
@@ -120,6 +120,7 @@ class TestDistMnist2x2(TestDistRunnerBase):
             fleet.init(is_collective=True)
             strategy = fleet.DistributedStrategy()
             strategy.pipeline = True
+            strategy.pipeline_configs = {'micro_batch_size': batch_size, }
             dist_opt = fleet.distributed_optimizer(
                 optimizer=opt, strategy=strategy)
             dist_opt.minimize(avg_cost)
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
index 7375049b3c..31771ddbd6 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py
@@ -66,9 +66,12 @@ class TestStrategyConfig(unittest.TestCase):
 
     def test_pipeline_configs(self):
         strategy = paddle.distributed.fleet.DistributedStrategy()
-        configs = {"micro_batch": 4}
+        configs = {"micro_batch_size": 4}
         strategy.pipeline_configs = configs
-        self.assertEqual(strategy.pipeline_configs["micro_batch"], 4)
+        self.assertEqual(strategy.pipeline_configs["micro_batch_size"], 4)
+        configs = {"accumulate_steps": 2}
+        strategy.pipeline_configs = configs
+        self.assertEqual(strategy.pipeline_configs["accumulate_steps"], 2)
 
     def test_localsgd(self):
         strategy = paddle.distributed.fleet.DistributedStrategy()
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py
index 68702562dd..a9c37d7853 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_pipeline_meta_optimizer.py
@@ -48,7 +48,10 @@ class TestFleetMetaOptimizer(unittest.TestCase):
 
         strategy = paddle.distributed.fleet.DistributedStrategy()
         strategy.pipeline = True
-        strategy.pipeline_configs = {'micro_batch': 2}
+        strategy.pipeline_configs = {
+            'micro_batch_size': 1,
+            'accumulate_steps': 2
+        }
 
         optimizer = paddle.fluid.optimizer.Adam(0.01)
         optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
-- 
GitLab
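
Usage note (editor's sketch, not part of the patch): the single micro_batch
key is replaced by micro_batch_size (the size of each micro batch) and
accumulate_steps (how many micro batches run before each optimizer step).
A minimal configuration sketch mirroring the updated tests above; the
surrounding model, optimizer, and fleet.init() setup are omitted:

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.pipeline = True
    # Both keys default to 1; the old 'micro_batch' key is no longer
    # recognized by PipelineConfig.
    strategy.pipeline_configs = {
        'micro_batch_size': 1,
        'accumulate_steps': 2,
    }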
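
Receiver-side note (editor's sketch, not part of the patch): because the
sender no longer transmits the element count, every recv_v2 must declare a
fully specified out_shape, and the new InferShape check rejects any
dimension below 1. That is why collective_sendrecv_op.py now creates
layers.data() with append_batch_size=False (the default would give a -1
leading dimension), and why the pipeline pass substitutes micro_batch_size
for a -1 batch dimension. A sketch of the new send/recv contract, adapted
from that test; in the real test each rank appends only its own op, and
attribute names beyond out_shape, dtype, and peer follow the existing op
definitions rather than this patch, so treat them as assumptions:

    import paddle.fluid as fluid
    import paddle.fluid.layers as layers

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        # append_batch_size=False keeps the leading dim at 10; otherwise
        # it would be -1, which recv_v2 now rejects.
        tindata = layers.data(
            name="tindata",
            shape=[10, 1000],
            dtype='float64',
            append_batch_size=False)
        # Rank 0 would append the sender; no shape metadata is sent.
        main_prog.global_block().append_op(
            type="send_v2",
            inputs={'X': tindata},
            attrs={'ring_id': 0, 'peer': 1, 'use_calc_stream': True})
        # Rank 1 would append the receiver; the full shape is declared
        # up front through the out_shape attribute.
        main_prog.global_block().append_op(
            type="recv_v2",
            outputs={'Out': tindata},
            attrs={
                'out_shape': tindata.shape,
                'dtype': tindata.dtype,
                'ring_id': 0,
                'peer': 0,
                'use_calc_stream': True,
            })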