未验证 提交 0205e9f8 编写于 作者: L lilong12 提交者: GitHub

remove the send/recv of tensor size (#31460)

* Remove the send/recv of the tensor size. As a result, users must explicitly specify the shape of the received variable (the `out_shape` attribute of `recv_v2`).
上级 c8ae837d
...@@ -117,7 +117,10 @@ message AsyncConfig { ...@@ -117,7 +117,10 @@ message AsyncConfig {
optional int32 lr_decay_steps = 11 [ default = 10 ]; optional int32 lr_decay_steps = 11 [ default = 10 ];
} }
message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; } message PipelineConfig {
optional int32 micro_batch_size = 1 [ default = 1 ];
optional int32 accumulate_steps = 2 [ default = 1 ];
}
message DistributedStrategy { message DistributedStrategy {
// bool options // bool options
......
...@@ -40,6 +40,14 @@ class RecvOpV2 : public framework::OperatorWithKernel { ...@@ -40,6 +40,14 @@ class RecvOpV2 : public framework::OperatorWithKernel {
"The size of the output shape must be greater than 0 " "The size of the output shape must be greater than 0 "
"but the value given is %d.", "but the value given is %d.",
out_shape.size())); out_shape.size()));
for (size_t i = 0; i < out_shape.size(); ++i) {
PADDLE_ENFORCE_GE(out_shape[i], 1,
platform::errors::InvalidArgument(
"The shape attribute for recv_v2 must be set "
"explicitly, but the %dth element is %d which "
"is less than 1.",
i, out_shape[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(out_shape)); ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
} }
......
...@@ -42,6 +42,7 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> { ...@@ -42,6 +42,7 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
auto out = ctx.Output<framework::LoDTensor>("Out"); auto out = ctx.Output<framework::LoDTensor>("Out");
auto out_dims = out->dims(); auto out_dims = out->dims();
auto numel = out->numel();
int data_type = ctx.Attr<int>("dtype"); int data_type = ctx.Attr<int>("dtype");
framework::proto::VarType::Type type = framework::proto::VarType::Type type =
framework::proto::VarType::Type(data_type); framework::proto::VarType::Type(data_type);
...@@ -61,34 +62,8 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> { ...@@ -61,34 +62,8 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument("The value of peer (%d) you set must " platform::errors::InvalidArgument("The value of peer (%d) you set must "
"be less than comm->nranks (%d).", "be less than comm->nranks (%d).",
peer, comm->nranks())); peer, comm->nranks()));
ncclDataType_t dtype = platform::ToNCCLDataType(type);
// Recv the number of elements to receive first
int numel = 0;
int *numel_ptr = nullptr;
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&numel_ptr, sizeof(int)));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
#endif
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::ncclRecv(static_cast<void *>(numel_ptr), 1, ncclInt,
peer, comm->comm(), stream));
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS(
hipMemcpy(&numel, numel_ptr, sizeof(int), hipMemcpyDeviceToHost));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemcpy(&numel, numel_ptr, sizeof(int), cudaMemcpyDeviceToHost));
#endif
int rest_numel = 1;
for (int i = 1; i < out_dims.size(); ++i) {
rest_numel = rest_numel * out_dims[i];
}
out_dims[0] = numel / rest_numel;
out->mutable_data<T>(out_dims, place); out->mutable_data<T>(out_dims, place);
ncclDataType_t dtype = platform::ToNCCLDataType(type);
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
out->data<T>(), numel, dtype, peer, comm->comm(), stream)); out->data<T>(), numel, dtype, peer, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " recv " VLOG(3) << "rank " << comm->rank() << " recv "
......
...@@ -57,21 +57,6 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> { ...@@ -57,21 +57,6 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
"be less than comm->nranks (%d).", "be less than comm->nranks (%d).",
peer, comm->nranks())); peer, comm->nranks()));
ncclDataType_t dtype = platform::ToNCCLDataType(x->type()); ncclDataType_t dtype = platform::ToNCCLDataType(x->type());
// Send number of elements to the receiver, as the receiver may have
// no information of the Tensor size.
int* numel_ptr = nullptr;
#ifdef PADDLE_WITH_RCCL
PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&numel_ptr, sizeof(int)));
PADDLE_ENFORCE_CUDA_SUCCESS(
hipMemcpy(numel_ptr, &numel, sizeof(int), hipMemcpyHostToDevice));
#else
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemcpy(numel_ptr, &numel, sizeof(int), cudaMemcpyHostToDevice));
#endif
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
numel_ptr, 1, ncclInt, peer, comm->comm(), stream));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
x->data<T>(), numel, dtype, peer, comm->comm(), stream)); x->data<T>(), numel, dtype, peer, comm->comm(), stream));
VLOG(3) << "rank " << comm->rank() << " send " VLOG(3) << "rank " << comm->rank() << " send "
......
...@@ -145,8 +145,10 @@ class PipelineOptimizer(MetaOptimizerBase): ...@@ -145,8 +145,10 @@ class PipelineOptimizer(MetaOptimizerBase):
user_defined_strategy): user_defined_strategy):
super(PipelineOptimizer, self)._set_basic_info( super(PipelineOptimizer, self)._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy) loss, role_maker, user_defined_optimizer, user_defined_strategy)
self.micro_batch_size = user_defined_strategy.pipeline_configs[
'micro_batch_size']
self.num_microbatches = user_defined_strategy.pipeline_configs[ self.num_microbatches = user_defined_strategy.pipeline_configs[
'micro_batch'] 'accumulate_steps']
def _can_apply(self): def _can_apply(self):
if not self.role_maker._is_collective: if not self.role_maker._is_collective:
...@@ -162,7 +164,10 @@ class PipelineOptimizer(MetaOptimizerBase): ...@@ -162,7 +164,10 @@ class PipelineOptimizer(MetaOptimizerBase):
def _enable_strategy(self, dist_strategy, context): def _enable_strategy(self, dist_strategy, context):
dist_strategy.pipeline = True dist_strategy.pipeline = True
dist_strategy.pipeline_configs = {"micro_batch": 1, } dist_strategy.pipeline_configs = {
"micro_batch_size": 1,
"accumulate_steps": 1,
}
def minimize_impl(self, def minimize_impl(self,
loss, loss,
...@@ -185,6 +190,8 @@ class PipelineOptimizer(MetaOptimizerBase): ...@@ -185,6 +190,8 @@ class PipelineOptimizer(MetaOptimizerBase):
loss.block.program._pipeline_opt = dict() loss.block.program._pipeline_opt = dict()
loss.block.program._pipeline_opt['local_rank'] = self.rank loss.block.program._pipeline_opt['local_rank'] = self.rank
loss.block.program._pipeline_opt[
'micro_batch_size'] = self.micro_batch_size
optimize_ops, params_grads, prog_list = self.wrapped_opt.minimize( optimize_ops, params_grads, prog_list = self.wrapped_opt.minimize(
loss, startup_program, parameter_list, no_grad_set) loss, startup_program, parameter_list, no_grad_set)
assert prog_list assert prog_list
......
...@@ -4075,12 +4075,15 @@ class PipelineOptimizer(object): ...@@ -4075,12 +4075,15 @@ class PipelineOptimizer(object):
break break
source_var = main_program.block(0).var(var_name) source_var = main_program.block(0).var(var_name)
new_var = self._create_var(block, source_var, var_name) new_var = self._create_var(block, source_var, var_name)
new_var_shape = list(new_var.shape)
new_var_shape[0] = self.micro_batch_size if new_var_shape[
0] < 0 else new_var_shape[0]
block._insert_op( block._insert_op(
index=index, index=index,
type='recv_v2', type='recv_v2',
outputs={'Out': [new_var]}, outputs={'Out': [new_var]},
attrs={ attrs={
'out_shape': new_var.shape, 'out_shape': new_var_shape,
'dtype': new_var.dtype, 'dtype': new_var.dtype,
self._op_device_key: device, self._op_device_key: device,
self._op_role_key: self._op_role.Forward, self._op_role_key: self._op_role.Forward,
...@@ -4243,12 +4246,15 @@ class PipelineOptimizer(object): ...@@ -4243,12 +4246,15 @@ class PipelineOptimizer(object):
'peer': cur_device_index, 'peer': cur_device_index,
}) })
extra_index += 1 extra_index += 1
var_shape = list(var.shape)
var_shape[0] = self.micro_batch_size if var_shape[
0] < 0 else var_shape[0]
block._insert_op( block._insert_op(
index=index + extra_index, index=index + extra_index,
type='recv_v2', type='recv_v2',
outputs={'Out': [var]}, outputs={'Out': [var]},
attrs={ attrs={
'out_shape': var.shape, 'out_shape': var_shape,
'dtype': var.dtype, 'dtype': var.dtype,
self._op_device_key: cur_device_spec, self._op_device_key: cur_device_spec,
self._op_role_key: op_role, self._op_role_key: op_role,
...@@ -4455,6 +4461,8 @@ class PipelineOptimizer(object): ...@@ -4455,6 +4461,8 @@ class PipelineOptimizer(object):
optimize_ops, params_grads = self._optimizer.minimize( optimize_ops, params_grads = self._optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set) loss, startup_program, parameter_list, no_grad_set)
self._param_device_map = self._optimizer._param_device_map self._param_device_map = self._optimizer._param_device_map
self.micro_batch_size = main_block.program._pipeline_opt[
'micro_batch_size']
# Step1: add default op_device attribute for regulization and clip ops # Step1: add default op_device attribute for regulization and clip ops
self._add_opdevice_attr_for_regularization_clip(main_block) self._add_opdevice_attr_for_regularization_clip(main_block)
......
...@@ -46,7 +46,10 @@ class TestCollectiveSendRecv(TestCollectiveRunnerBase): ...@@ -46,7 +46,10 @@ class TestCollectiveSendRecv(TestCollectiveRunnerBase):
ring_id = self.global_ring_id ring_id = self.global_ring_id
with fluid.program_guard(main_prog, startup_program): with fluid.program_guard(main_prog, startup_program):
tindata = layers.data( tindata = layers.data(
name="tindata", shape=[10, 1000], dtype='float64') name="tindata",
shape=[10, 1000],
dtype='float64',
append_batch_size=False)
if self.rank == 0: if self.rank == 0:
main_prog.global_block().append_op( main_prog.global_block().append_op(
type="send_v2", type="send_v2",
......
...@@ -120,6 +120,7 @@ class TestDistMnist2x2(TestDistRunnerBase): ...@@ -120,6 +120,7 @@ class TestDistMnist2x2(TestDistRunnerBase):
fleet.init(is_collective=True) fleet.init(is_collective=True)
strategy = fleet.DistributedStrategy() strategy = fleet.DistributedStrategy()
strategy.pipeline = True strategy.pipeline = True
strategy.pipeline_configs = {'micro_batch_size': batch_size, }
dist_opt = fleet.distributed_optimizer( dist_opt = fleet.distributed_optimizer(
optimizer=opt, strategy=strategy) optimizer=opt, strategy=strategy)
dist_opt.minimize(avg_cost) dist_opt.minimize(avg_cost)
......
...@@ -66,9 +66,12 @@ class TestStrategyConfig(unittest.TestCase): ...@@ -66,9 +66,12 @@ class TestStrategyConfig(unittest.TestCase):
def test_pipeline_configs(self): def test_pipeline_configs(self):
strategy = paddle.distributed.fleet.DistributedStrategy() strategy = paddle.distributed.fleet.DistributedStrategy()
configs = {"micro_batch": 4} configs = {"micro_batch_size": 4}
strategy.pipeline_configs = configs strategy.pipeline_configs = configs
self.assertEqual(strategy.pipeline_configs["micro_batch"], 4) self.assertEqual(strategy.pipeline_configs["micro_batch_size"], 4)
configs = {"accumulate_steps": 2}
strategy.pipeline_configs = configs
self.assertEqual(strategy.pipeline_configs["accumulate_steps"], 2)
def test_localsgd(self): def test_localsgd(self):
strategy = paddle.distributed.fleet.DistributedStrategy() strategy = paddle.distributed.fleet.DistributedStrategy()
......
...@@ -48,7 +48,10 @@ class TestFleetMetaOptimizer(unittest.TestCase): ...@@ -48,7 +48,10 @@ class TestFleetMetaOptimizer(unittest.TestCase):
strategy = paddle.distributed.fleet.DistributedStrategy() strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.pipeline = True strategy.pipeline = True
strategy.pipeline_configs = {'micro_batch': 2} strategy.pipeline_configs = {
'micro_batch_size': 1,
'accumulate_steps': 2
}
optimizer = paddle.fluid.optimizer.Adam(0.01) optimizer = paddle.fluid.optimizer.Adam(0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册