From 81244fbfabe40971284f37faf2e35d80f39d6ffa Mon Sep 17 00:00:00 2001 From: mapingshuo Date: Mon, 26 Oct 2020 10:08:10 +0800 Subject: [PATCH] add sharding strategy in fleet(#27900) * add sharding --- .../framework/distributed_strategy.proto | 6 + .../fleet/base/distributed_strategy.py | 49 +++ .../fleet/meta_optimizers/__init__.py | 1 + .../fleet/meta_optimizers/common.py | 6 + .../fleet/meta_optimizers/dgc_optimizer.py | 8 + .../meta_optimizers/sharding/__init__.py | 13 + .../meta_optimizers/sharding/fp16_helper.py | 154 +++++++ .../sharding/gradient_clip_helper.py | 90 ++++ .../fleet/meta_optimizers/sharding/prune.py | 131 ++++++ .../fleet/meta_optimizers/sharding/shard.py | 144 ++++++ .../fleet/meta_optimizers/sharding/utils.py | 274 ++++++++++++ .../sharding/weight_decay_helper.py | 37 ++ .../meta_optimizers/sharding_optimizer.py | 411 ++++++++++++++++++ python/paddle/fluid/clip.py | 4 +- python/paddle/fluid/framework.py | 36 +- .../fluid/tests/unittests/CMakeLists.txt | 2 + .../unittests/fleet_meta_optimizer_base.py | 17 +- ...est_fleet_gradient_merge_meta_optimizer.py | 3 - .../test_fleet_sharding_meta_optimizer.py | 275 ++++++++++++ python/setup.py.in | 1 + 20 files changed, 1648 insertions(+), 14 deletions(-) create mode 100644 python/paddle/distributed/fleet/meta_optimizers/sharding/__init__.py create mode 100644 python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py create mode 100644 python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py create mode 100644 python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py create mode 100644 python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py create mode 100644 python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py create mode 100644 python/paddle/distributed/fleet/meta_optimizers/sharding/weight_decay_helper.py create mode 100644 python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py create mode 100644 python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 881ef30ffe6..50b7d62547b 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -24,6 +24,10 @@ enum Mode { message RecomputeConfig { repeated string checkpoints = 1; } +message ShardingConfig { + optional float fuse_broadcast_MB = 1 [ default = 32.0 ]; +} + message AMPConfig { optional float init_loss_scaling = 1 [ default = 32768.0 ]; optional int32 incr_every_n_steps = 2 [ default = 1000 ]; @@ -130,6 +134,7 @@ message DistributedStrategy { optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ]; optional bool adaptive_localsgd = 24 [ default = false ]; optional bool fp16_allreduce = 25 [ default = false ]; + optional bool sharding = 26 [ default = false ]; optional RecomputeConfig recompute_configs = 101; optional AMPConfig amp_configs = 102; @@ -141,6 +146,7 @@ message DistributedStrategy { optional LarsConfig lars_configs = 108; optional LambConfig lamb_configs = 109; optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110; + optional ShardingConfig sharding_configs = 111; optional BuildStrategy build_strategy = 201; optional ExecutionStrategy execution_strategy = 202; } diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 847050b404f..71eca424fe6 100755 --- 
a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -611,6 +611,55 @@ class DistributedStrategy(object): "checkpoint_configs") assign_configs_value(self.strategy.recompute_configs, configs) + @property + def sharding(self): + """ + Indicate whether to use the sharding optimizer for memory + optimization. It shards the parameters and optimizer states across workers. + + Default value: False + + Examples: + .. code-block:: python + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.sharding = True + """ + return self.strategy.sharding + + @sharding.setter + @is_strict_auto + def sharding(self, flag): + if isinstance(flag, bool): + self.strategy.sharding = flag + else: + print("WARNING: sharding should have value of bool type") + + @property + def sharding_configs(self): + """ + Set sharding configurations. + + **Note**: + fuse_broadcast_MB(float): accumulated parameter size (in MB) at which broadcast ops are fused into one group. Default value is 32.0. + + Examples: + .. code-block:: python + + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.sharding = True + strategy.sharding_configs = {"fuse_broadcast_MB": 32} + """ + return get_msg_dict(self.strategy.sharding_configs) + + @sharding_configs.setter + @is_strict_auto + def sharding_configs(self, configs): + check_configs_key(self.strategy.sharding_configs, configs, + "sharding_configs") + assign_configs_value(self.strategy.sharding_configs, configs) + @property def pipeline(self): """ diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index 2e63e82e630..cdc8162f6de 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -24,3 +24,4 @@ from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer from .dgc_optimizer import DGCOptimizer from .lamb_optimizer import LambOptimizer from .fp16_allreduce_optimizer import FP16AllReduceOptimizer +from .sharding_optimizer import ShardingOptimizer diff --git a/python/paddle/distributed/fleet/meta_optimizers/common.py b/python/paddle/distributed/fleet/meta_optimizers/common.py index 8ff4114bf8e..0f7ca4f4294 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/common.py +++ b/python/paddle/distributed/fleet/meta_optimizers/common.py @@ -99,6 +99,12 @@ class CollectiveHelper(object): OP_ROLE_KEY: OpRole.Forward }) + def _wait(self, current_endpoint, endpoints): + assert (self.wait_port) + other_endpoints = endpoints[:] + other_endpoints.remove(current_endpoint) + wait_server_ready(other_endpoints) + def _broadcast_params(self): block = self.startup_program.global_block() ring_id = -1 diff --git a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py index be614a05147..7bd68325569 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/dgc_optimizer.py @@ -30,6 +30,10 @@ class DGCOptimizer(MetaOptimizerBase): super(DGCOptimizer, self)._set_basic_info( loss, role_maker, user_defined_optimizer, user_defined_strategy) + def _init_dgc_opt(self): + if self.dgc_opt is not None: + return + opt = self.inner_opt if not self.role_maker._is_collective: @@ -86,13 +90,16 @@ class DGCOptimizer(MetaOptimizerBase): parameter_list=None, no_grad_set=None, callbacks=None): + self._init_dgc_opt() return self.dgc_opt.backward(loss, startup_program,
parameter_list, no_grad_set, callbacks) def apply_gradients(self, params_grads): + self._init_dgc_opt() return self.dgc_opt.apply_gradients(params_grads=params_grads) def apply_optimize(self, loss, startup_program, params_grads): + self._init_dgc_opt() return self.dgc_opt.apply_optimize( loss, startup_program=startup_program, params_grads=params_grads) @@ -101,6 +108,7 @@ class DGCOptimizer(MetaOptimizerBase): startup_program=None, parameter_list=None, no_grad_set=None): + self._init_dgc_opt() optimize_ops, params_grads = \ self.dgc_opt.minimize(loss, startup_program, parameter_list, no_grad_set) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/__init__.py new file mode 100644 index 00000000000..5d358dbd35f --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py new file mode 100644 index 00000000000..cf6ab514b0b --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py @@ -0,0 +1,154 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
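Before the sharding helper modules that follow, here is a minimal user-side sketch of the two new strategy fields added above. The toy network, optimizer choice, and the assumption of a multi-worker fleetrun/collective launch environment are illustrative only and not part of this patch:

    import paddle
    import paddle.fluid as fluid
    import paddle.distributed.fleet as fleet
    import paddle.distributed.fleet.base.role_maker as role_maker

    paddle.enable_static()
    fleet.init(role_maker.PaddleCloudRoleMaker(is_collective=True))

    # toy static-graph network, for illustration only
    x = fluid.data(name="x", shape=[None, 32], dtype="float32")
    y = fluid.data(name="y", shape=[None, 1], dtype="int64")
    fc = fluid.layers.fc(input=x, size=64, act="tanh")
    prediction = fluid.layers.fc(input=fc, size=2, act="softmax")
    cost = fluid.layers.cross_entropy(input=prediction, label=y)
    avg_cost = fluid.layers.mean(x=cost)

    # the new strategy switch and its config added by this patch
    strategy = fleet.DistributedStrategy()
    strategy.sharding = True
    strategy.sharding_configs = {"fuse_broadcast_MB": 32}

    optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(avg_cost)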
+ +from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op, OP_ROLE_KEY, OpRole +from paddle.distributed.fleet.meta_optimizers.sharding.utils import * + +from paddle.fluid import core + + +class FP16Utils(object): + def __init__(self): + pass + + @staticmethod + def is_fp16_cast_op(block, op, params): + if op.type != "cast": + return False + if is_optimizer_op(op): + return False + assert (len(op.desc.input_arg_names()) == 1) + assert (len(op.desc.output_arg_names()) == 1) + input_name, output_name = op.desc.input_arg_names()[ + 0], op.desc.output_arg_names()[0] + if input_name not in params: + return False + input_var = block.var(input_name) + output_var = block.var(output_name) + if input_var.dtype != core.VarDesc.VarType.FP32 or \ + output_var.dtype != core.VarDesc.VarType.FP16: + return False + return True + + @staticmethod + def is_fp32_cast_op(block, op): + if op.type != "cast": + return False + if not is_optimizer_op(op): + return False + assert (len(op.desc.input_arg_names()) == 1) + assert (len(op.desc.output_arg_names()) == 1) + input_name, output_name = op.desc.input_arg_names()[ + 0], op.desc.output_arg_names()[0] + input_var = block.var(input_name) + output_var = block.var(output_name) + if input_var.dtype != core.VarDesc.VarType.FP16 or \ + output_var.dtype != core.VarDesc.VarType.FP32: + return False + return True + + @staticmethod + def remove_cast_op(block, params, segment, offset): + inserted_op_num = 0 + for op_idx in reversed( + range(offset + segment._start_idx, offset + segment._end_idx)): + op = block.ops[op_idx] + if FP16Utils.is_fp16_cast_op(block, op, params): + block._remove_op(op_idx, sync=False) + inserted_op_num -= 1 + block._sync_with_cpp() + return inserted_op_num + + @staticmethod + def prune_fp16(block, shard, reduced_grads_to_param, nrings): + # remove cast + for idx, op in reversed(list(enumerate(block.ops))): + if not FP16Utils.is_fp32_cast_op(block, op): + continue + output_name = op.desc.output_arg_names()[0] + param_name = output_name.strip("@GRAD") + if param_name not in shard.global_params: + raise ValueError("Input 'X' of check_finite_and_unscale must " + "be grads, but {} is not a grad".format( + output_name)) + if output_name in reduced_grads_to_param: + continue + if shard.has_param(param_name): + continue + block._remove_op(idx, sync=False) + block._remove_var(output_name, sync=False) + + block._sync_with_cpp() + update_loss_scaling_op_idx = -1 + inf_var_name = '' + for idx, op in reversed(list(enumerate(block.ops))): + if op.type == "update_loss_scaling": + update_loss_scaling_op_idx = idx + inf_var_name = op.desc.input('FoundInfinite')[0] + op._rename_input(inf_var_name, inf_var_name + "@sharding") + if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: + reversed_x = [] + for input_name in op.desc.input('X'): + param_name = input_name.strip("@GRAD") + if param_name not in shard.global_params: + raise ValueError( + "Input 'X' of check_finite_and_unscale must " + "be grads, but {} is not a grad".format(input_name)) + if shard.has_param(param_name): + reversed_x.append(input_name) + op.desc.set_input('X', reversed_x) + op.desc.set_output('Out', reversed_x) + if update_loss_scaling_op_idx == -1: + return + inf_var = block.var(inf_var_name) + inf_var_fp32 = block.create_var( + name=inf_var_name + "@cast_int32", + shape=inf_var.shape, + dtype=core.VarDesc.VarType.INT32) + inf_var_sharding = block.create_var( + name=inf_var_name + "@sharding", + shape=inf_var.shape, + dtype=inf_var.dtype) +
block._insert_op_without_sync( + update_loss_scaling_op_idx, + type='cast', + inputs={'X': inf_var}, + outputs={'Out': inf_var_fp32}, + attrs={ + "in_dtype": inf_var.dtype, + "out_dtype": inf_var_fp32.dtype, + OP_ROLE_KEY: OpRole.Optimize + }) + insert_sync_calc_op(block, update_loss_scaling_op_idx + 1, + [inf_var_fp32]) + block._insert_op_without_sync( + update_loss_scaling_op_idx + 2, + type='c_allreduce_max', + inputs={'X': inf_var_fp32}, + outputs={'Out': inf_var_fp32}, + attrs={'ring_id': 0, + OP_ROLE_KEY: OpRole.Optimize}) + comm_op_num = insert_sync_comm_ops( + block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32]) + block._insert_op_without_sync( + update_loss_scaling_op_idx + 3 + comm_op_num, + type='cast', + inputs={'X': inf_var_fp32}, + outputs={'Out': inf_var_sharding}, + attrs={ + "in_dtype": inf_var_fp32.dtype, + "out_dtype": inf_var_sharding.dtype, + OP_ROLE_KEY: OpRole.Optimize + }) + block._sync_with_cpp() diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py new file mode 100644 index 00000000000..afa46f43fc0 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
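To make the FP16 pruning above more concrete: after pruning, each worker only checks the gradients it owns, so its local FoundInfinite flag no longer reflects the whole model. The inserted cast / c_allreduce_max / cast sequence recovers a global flag; numerically it amounts to the following (flag values are made up):

    # one int32 overflow flag per worker after the per-shard check_finite_and_unscale
    local_found_inf = [0, 1, 0, 0]
    # what c_allreduce_max computes across ranks: any worker's overflow poisons all workers
    global_found_inf = max(local_found_inf)
    assert global_found_inf == 1  # update_loss_scaling then skips the step on every worker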
+ +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole + + +class GradientClipHelper(object): + def __init__(self): + pass + + def _is_gradient_clip_op(self, op): + return op.desc.has_attr("op_namescope") \ + and op.desc.attr("op_namescope").startswith("/gradient_clip") + + def prune_gradient_clip(self, block, shard): + deperated_vars = set() + deperate_op_idx = set() + for idx, op in enumerate(block.ops): + if not self._is_gradient_clip_op(op): + continue + if op.type == "sum": + continue + deperate_op = False + for input_name in op.desc.input_arg_names(): + if input_name in deperated_vars: + deperate_op = True + param_name = input_name.strip("@GRAD") + if shard.is_param(param_name) and \ + not shard.has_param(param_name): + deperate_op = True + + if deperate_op: + deperate_op_idx.add(idx) + for output_name in op.desc.output_arg_names(): + deperated_vars.add(output_name) + + if not deperated_vars: + # got no gradient_clip op + return + + for idx, op in reversed(list(enumerate(block.ops))): + if not self._is_gradient_clip_op(op): + continue + if idx in deperate_op_idx: + block._remove_op(idx, sync=False) + continue + reversed_inputs = [] + if op.type == "sum": + for input_name in op.desc.input_arg_names(): + if input_name not in deperated_vars: + reversed_inputs.append(input_name) + op.desc.set_input("X", reversed_inputs) + assert (len(op.desc.output_arg_names()) == 1) + sum_res = op.desc.output_arg_names()[0] + block._insert_op_without_sync( + idx + 1, + type='c_sync_comm_stream', + inputs={'X': sum_res}, + outputs={'Out': sum_res}, + attrs={'ring_id': 0, + OP_ROLE_KEY: OpRole.Optimize}) + block._insert_op_without_sync( + idx + 1, + type='c_allreduce_sum', + inputs={'X': sum_res}, + outputs={'Out': sum_res}, + attrs={'ring_id': 0, + OP_ROLE_KEY: OpRole.Optimize}) + block._insert_op_without_sync( + idx + 1, + type='c_sync_calc_stream', + inputs={'X': sum_res}, + outputs={'Out': sum_res}, + attrs={OP_ROLE_KEY: OpRole.Optimize}) + + for var_name in deperated_vars: + block._remove_var(var_name, sync=False) + block._sync_with_cpp() + return diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py new file mode 100644 index 00000000000..7348e5f6d14 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py @@ -0,0 +1,131 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
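Similarly, the gradient-clip pruning above keeps only the square-sum terms of locally owned parameters, so the sum op yields a per-shard partial result; the inserted c_allreduce_sum restores the global norm before clipping. A made-up numeric sketch of what is being reconstructed (clip_norm is a hypothetical threshold, not a value from this patch):

    import math

    # per-worker sum of squared gradients over the parameters each worker owns
    partial_square_sums = [0.7, 1.1, 0.4]
    # what the inserted c_allreduce_sum makes identical on every worker
    global_norm = math.sqrt(sum(partial_square_sums))
    # global-norm clipping then scales grads by roughly min(1, clip_norm / global_norm)
    clip_norm = 1.0
    scale = min(1.0, clip_norm / (global_norm + 1e-6))
    print(global_norm, scale)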
+ + +class ProgramDeps(object): + def __init__(self, block, start_vars, end_vars): + self._block = block + # vars where to start to build the deps + self._start_vars = start_vars + # vars where to stop to build the deps + self._end_vars = end_vars + # var name -> op idxs which depends on this var + self._var_to_use_op = {} + # sub block deps which is a subset of this topo + self._sub_block_deps = {} + # var name -> op idxs which generate var + self._var_to_generate_op = {} + self._should_removed_var = set() + self._father_block_deps = None + self._build_deps() + + def get_sub_block_deps(self, idx): + if idx in self._sub_block_deps: + return self._sub_block_deps[idx] + else: + return None + + def get_var_deps(self, var_name): + if var_name in self._var_to_use_op: + return self._var_to_use_op[var_name] + else: + return None + + def _build_deps(self, ): + for var_name in self._start_vars: + self._var_to_use_op[var_name] = [] + self._var_to_generate_op[var_name] = [] + + for idx, op in enumerate(self._block.ops): + if op.type in [ + "c_allreduce_sum", "c_sync_comm_stream", + "c_calc_comm_stream" + ]: + continue + input_vars = op.desc.input_arg_names() + output_vars = op.desc.output_arg_names() + deps_reduce = False + for input_name in input_vars: + if input_name in self._var_to_use_op: + deps_reduce = True + if not deps_reduce: + continue + for input_name in input_vars: + if input_name in self._var_to_use_op: + self._var_to_use_op[input_name].append(idx) + for output_name in output_vars: + if output_name not in self._var_to_use_op: + self._var_to_use_op[output_name] = [] + if output_name not in self._var_to_generate_op: + self._var_to_generate_op[output_name] = [idx] + else: + self._var_to_generate_op[output_name].append(idx) + if op.type == "conditional_block": + # subblock + assert (op.desc.has_attr("sub_block")) + subblock_idx = op.desc.attr("sub_block").id + subblock_deps = ProgramDeps( + self._block.program.block(subblock_idx), + op.desc.input_arg_names(), op.desc.output_arg_names()) + self._sub_block_deps[subblock_idx] = subblock_deps + subblock_deps._father_block_deps = self + + def crop_input_var_from_op(self, op_idx, var_name): + if var_name in self._var_to_use_op: + # update var -> dep_var_op + if self._var_to_use_op[var_name] != []: + if op_idx not in self._var_to_use_op[var_name]: + raise ValueError( + "op_idx: {} is not in self._var_to_use_op[{}], " + "self._var_to_use_op[{}] is {}".format( + op_idx, var_name, var_name, self._var_to_use_op[ + var_name])) + self._var_to_use_op[var_name].remove(op_idx) + # update _should_removed_var + if var_name in self._start_vars: + self._should_removed_var.discard(var_name) + elif self._var_to_use_op[ + var_name] == []: # no more deps of this var + self._should_removed_var.add(var_name) + elif self._var_to_generate_op[var_name][-1] >= self._var_to_use_op[ + var_name][-1]: + # there are circle in the graph + self._should_removed_var.add(var_name) + else: # input_name should not be deleted + self._should_removed_var.discard(var_name) + + def crop_output_var_from_op(self, op_idx, var_name): + if var_name in self._var_to_generate_op: + assert (op_idx in self._var_to_generate_op[var_name]) + self._var_to_generate_op[var_name].remove(op_idx) + if self._block.has_var(var_name): + if var_name not in self._var_to_generate_op or self._var_to_generate_op[ + var_name] == []: + self._block._remove_var(var_name, sync=False) + + def remove_op(self, op_idx): + # update deps + op = self._block.ops[op_idx] + for input_name in op.desc.input_arg_names(): + 
self.crop_input_var_from_op(op_idx, input_name) + for output_name in op.desc.output_arg_names(): + self.crop_output_var_from_op(op_idx, output_name) + self._block._remove_op(op_idx, sync=False) + + def should_remove_op(self, op_idx): + op = self._block.ops[op_idx] + for output_name in op.desc.output_arg_names(): + if output_name not in self._should_removed_var: + return False + return True diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py new file mode 100644 index 00000000000..27c63fc406f --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py @@ -0,0 +1,144 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op +from paddle.distributed.fleet.meta_optimizers.sharding.utils import * +from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils + + +class Shard(object): + def __init__(self, ): + self.global_params = set([]) + self.worker_idx = -1 + self.worker_num = -1 + self.global_param2device = {} + + def setup(self, params_grads, worker_idx, worker_num): + # param names of all devices + self.global_params = set([x[0].name for x in params_grads]) + # _param(str) -> device_id(int) + self.worker_idx = worker_idx + self.worker_num = worker_num + # global_param2device contains fp32 params and fp16 params + self.global_param2device = self._split_params(params_grads, worker_idx, + worker_num) + + def has_param(self, var_name): + return var_name in self.global_param2device and \ + self._var_device_id(var_name) == self.worker_idx + + def has_opt_var(self, var_name): + return self._var_device_id(var_name) == self.worker_idx + + def has_var(self, var_name): + return self._var_device_id(var_name) == -1 or \ + self._var_device_id(var_name) == self.worker_idx + + def _split_params(self, params_grads, worker_idx, worker_num): + param2device = {} + total_param_mem = 0.0 + param2mem = [] + for param in [x[0] for x in params_grads]: + mem = get_var_size(param) + total_param_mem += mem + param2mem.append((param.name, mem)) + device2params = {x: [] for x in range(worker_num)} + device_idx = 0 + mem_accu = 0.0 + for param_name, mem in param2mem: + if mem_accu > total_param_mem * 1.0 * (device_idx + 1) / worker_num: + device_idx += 1 + device2params[device_idx].append(param_name) + param2device[param_name] = device_idx + mem_accu += mem + return param2device + + def _var_device_id(self, var_name): + if var_name in self.global_param2device: + return self.global_param2device[var_name] + for suffix in [ + "_moment1_0", "_moment2_0", "_beta1_pow_acc_0", + "_beta2_pow_acc_0", "_velocity_0" + ]: + base_name = re.sub(suffix, '', var_name) + if base_name in self.global_param2device: + return self.global_param2device[base_name] + return -1 + + def find_broadcast_params(self, block): + broadcast_vars = set([]) + fp16_params = set([]) + 
fp16_to_fp32 = {} + + param_usage = {x: 0 for x in self.global_params} + for op in block.ops: + if is_optimizer_op(op): + continue + for input_name in op.desc.input_arg_names(): + if input_name in self.global_params: + param_usage[input_name] += 1 + + for op in block.ops: + if not FP16Utils.is_fp16_cast_op(block, op, self.global_params): + continue + input_name = op.input_arg_names[0] + output_name = op.output_arg_names[0] + broadcast_vars.add(output_name) + fp16_params.add(output_name) + fp16_to_fp32[output_name] = input_name + param_usage[input_name] -= 1 + self.global_param2device[output_name] = self.global_param2device[ + input_name] + + for param, usage in param_usage.items(): + if usage > 0: + broadcast_vars.add(param) + return broadcast_vars + + def device(self, var_name): + return self._var_device_id(var_name) + + def is_param(self, var_name): + return var_name in self.global_params + + def is_opti_var(self, var_name): + if var_name in self.global_params: + return True + for suffix in [ + "_moment1_0", "_moment2_0", "_beta1_pow_acc_0", + "_beta2_pow_acc_0", "_velocity_0" + ]: + base_name = re.sub(suffix, '', var_name) + if base_name in self.global_params: + return True + return False + + +class ProgramSegment(object): + def __init__(self, block): + self._block = block + self._allreduce_vars = [] + # sub program start idx + self._start_idx = -1 + # sub program end idx + self._end_idx = -1 + # param name to broadcast name + self._param2broadcast = {} + self._broadcast_vars = [] + # cast op pairs, fp16 name (str) -> fp32 name (str) + self._cast_ops = {} + # fill constant vars + self._fill_constant_vars = [] + # parameter mems + self._param_mem = 0.0 diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py new file mode 100644 index 00000000000..51435ebb9e5 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -0,0 +1,274 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid import core +from functools import reduce +from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op +from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY + +import re + + +def check_broadcast(block): + """ + if a var is broadcasted, it should have a sync_comm before + this var is used, if not, raise error. + if the broadcasted var has a fill_constant op, the fill_constant + op should stay forward before the broadcast op, and before a + sync_calc op. Otherwise, raise error. + """ + broadcast_vars = {} + for idx, op in enumerate(block.ops): + if op.type == "c_broadcast": + var_name = op.desc.input_arg_names()[0] + if "@BroadCast" in var_name: + if var_name in broadcast_vars: + raise ValueError("var_name areadly exist: {}" + "the old pos is {}, the new pos is {}". 
+ format(var_name, broadcast_vars[var_name][ + "broadcast_pos"], idx)) + broadcast_vars[var_name] = { + "fill_constant_pos": -1, + "broadcast_pos": idx, + } + + for idx, op in enumerate(block.ops): + if op.type == "fill_constant": + var_name = op.desc.output_arg_names()[0] + if var_name in broadcast_vars: + broadcast_vars[var_name]["fill_constant_pos"] = idx + continue + + last_sync_comm_op_idx = -1 + last_sync_calc_op_idx = -1 + for idx, op in enumerate(block.ops): + if op.type == "c_sync_comm_stream": + last_sync_comm_op_idx = idx + continue + if op.type == "c_sync_calc_stream": + last_sync_calc_op_idx = idx + continue + if op.type == "c_broadcast": + var_name = op.desc.input_arg_names()[0] + if "@BroadCast" in var_name: + if broadcast_vars[var_name]["fill_constant_pos"] != -1: + assert (last_sync_calc_op_idx != -1) + assert (broadcast_vars[var_name]["fill_constant_pos"] < + last_sync_calc_op_idx) + assert (last_sync_calc_op_idx < idx) + continue + for input_name in op.desc.input_arg_names(): + if input_name in broadcast_vars: + assert (broadcast_vars[input_name]["broadcast_pos"] != -1) + assert (broadcast_vars[input_name]["broadcast_pos"] < + last_sync_comm_op_idx) + assert (last_sync_comm_op_idx < idx) + return + + +def check_allreduce_sum(block): + """ + if a Var is allreduced, the op order should be: + - 0: op that generate Var + - 1: sync_calc + - 2: allreduce_sum op + - 3: sync_comm + - 4: op that use Var + """ + var_status = {} + for op in block.ops: + if op.type == "c_allreduce_sum": + var_name = op.desc.input_arg_names()[0] + var_status[var_name] = -1 + + for op in block.ops: + if op.type == "c_sync_calc_stream": + for var_name in var_status: + if var_name in var_status and var_status[var_name] == 0: + var_status[var_name] = 1 + elif op.type == "c_allreduce_sum": + var_name = op.desc.input_arg_names()[0] + if var_status[var_name] == -1: + raise ValueError("{} is not generated, but you are" + "trying to all-reduce it".format(var_name)) + if var_status[var_name] == 0: + raise ValueError("There should be a sync_calc op " + "after generate Var: {} and before the" + "c_allreduce_sum op".format(var_name)) + assert (var_status[var_name] == 1) + var_status[var_name] = 2 + elif op.type == "c_sync_comm_stream": + for var_name in op.desc.input_arg_names(): + if var_name in var_status and var_status[var_name] == 2: + var_status[var_name] = 3 + else: + for input_name in op.desc.input_arg_names(): + if input_name in var_status: + if var_status[input_name] != 3: + raise ValueError("There should be a sync_comm op " + "after allreduce the Var: {}".format( + var_name)) + for output_name in op.desc.output_arg_names(): + if output_name in var_status and \ + var_status[output_name] == -1: + var_status[output_name] = 0 + return + + +def insert_sync_calc_op(block, insert_idx, calc_dep_vars): + """ + _insert_sync_calc_op + """ + op_role = block.ops[insert_idx].attr('op_role') + block._insert_op_without_sync( + insert_idx, + type='c_sync_calc_stream', + inputs={'X': calc_dep_vars}, + outputs={'Out': calc_dep_vars}, + attrs={OP_ROLE_KEY: op_role}) + return + + +def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars): + """ + _insert_sync_comm_ops + """ + op_role = block.ops[insert_idx].attr('op_role') + for i in range(nrings): + block._insert_op_without_sync( + insert_idx, + type='c_sync_comm_stream', + inputs={'X': comm_dep_vars}, + outputs={'Out': comm_dep_vars}, + attrs={'ring_id': i, + OP_ROLE_KEY: op_role}) + return nrings + + +def insert_fill_constant_ops(block, insert_idx, 
fill_constant_vars): + """ + _add_fill_constant_ops + """ + op_role = block.ops[insert_idx].attr('op_role') + for broadcast_name in fill_constant_vars: + broadcast_var = block.var(broadcast_name) + block._insert_op_without_sync( + insert_idx, + type="fill_constant", + outputs={"Out": broadcast_var.name}, + attrs={ + "shape": broadcast_var.shape, + "dtype": broadcast_var.dtype, + "value": 0.0, + OP_ROLE_KEY: op_role + }) + return + + +def insert_cast_ops(block, insert_idx, cast_ops): + """ + _add_cast_ops + """ + op_role = block.ops[insert_idx].attr('op_role') + for fp16_name, fp32_name in cast_ops.items(): + block._insert_op_without_sync( + insert_idx, + type="cast", + inputs={"X": fp32_name}, + outputs={"Out": fp16_name}, + attrs={ + "in_dtype": core.VarDesc.VarType.FP32, + "out_dtype": core.VarDesc.VarType.FP16, + OP_ROLE_KEY: op_role + }) + return + + +def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars): + """ + _add_allreduce_ops + """ + ring_id = -1 + for var in allreduce_vars: + ring_id = (ring_id + 1) % nrings + block._insert_op_without_sync( + insert_idx, + type='c_allreduce_sum', + inputs={'X': var}, + outputs={'Out': var}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward}) + return + + +def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root): + """ + _add_broadcast_ops + """ + ring_id = -1 + op_role = block.ops[insert_idx].attr('op_role') + for broadcast_name, root_device in broadcast2root: + ring_id = (ring_id + 1) % nrings + block._insert_op_without_sync( + insert_idx, + type='c_broadcast', + inputs={'X': broadcast_name}, + outputs={'Out': broadcast_name}, + attrs={ + 'ring_id': ring_id, + 'root': root_device, + OP_ROLE_KEY: op_role + }) + return + + +DtypeToSize = { + core.VarDesc.VarType.FP16: 2, + core.VarDesc.VarType.FP32: 4, + core.VarDesc.VarType.FP64: 8, + core.VarDesc.VarType.INT16: 2, + core.VarDesc.VarType.INT32: 4, + core.VarDesc.VarType.INT64: 8, + core.VarDesc.VarType.BOOL: 1, + core.VarDesc.VarType.UINT8: 1, +} + + +def get_var_size(param): + """ + input: + - param: var + return: + var size in Bytes + """ + assert -1 not in param.shape + return reduce(lambda x, y: x * y, + param.shape) * DtypeToSize[param.dtype] / 1024.0 / 1024.0 + + +def insert_scale_loss_grad_ops(block, scale=1.0): + ''' + In order to keep the learning rate consistent in different numbers of + training workers, we scale the loss grad by the number of workers + ''' + for idx, op in reversed(list(enumerate(block.ops))): + if is_loss_grad_op(op): + loss_grad_var = block.vars[op.output_arg_names[0]] + block._insert_op_without_sync( + idx + 1, + type='scale', + inputs={'X': loss_grad_var}, + outputs={'Out': loss_grad_var}, + attrs={'scale': scale, + OP_ROLE_KEY: OpRole.Backward}) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/weight_decay_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/weight_decay_helper.py new file mode 100644 index 00000000000..2833e8c6dac --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/weight_decay_helper.py @@ -0,0 +1,37 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_VAR_KEY + + +class WeightDecayHelper(object): + def __init__(self): + pass + + def _is_weight_decay_op(self, op): + return op.desc.has_attr("op_namescope") \ + and op.desc.attr("op_namescope").startswith("/regularization") + + def prune_weight_decay(self, block, shard): + for idx, op in reversed(list(enumerate(block.ops))): + if not self._is_weight_decay_op(op): + continue + if OP_ROLE_VAR_KEY not in op.attr_names: + raise ValueError( + "The weight decay op should hold the op_role_var attribute, " + "but the {} op does not hold op_role_var".format(op.type)) + op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] + if not shard.has_param(op_role_var[0]): + block._remove_op(idx, sync=False) + block._sync_with_cpp() diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py new file mode 100644 index 00000000000..a449821f8c2 --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -0,0 +1,411 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
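Before reading the optimizer body below, it may help to see the greedy parameter-to-device split that Shard._split_params (defined earlier in this patch) performs, in isolation. The condensed re-statement below uses made-up parameter names and sizes in MB:

    def split_params(param2mem, worker_num):
        # same greedy rule as Shard._split_params: advance to the next device
        # once the accumulated size passes that device's share of the total
        total = sum(mem for _, mem in param2mem)
        param2device, device_idx, acc = {}, 0, 0.0
        for name, mem in param2mem:
            if acc > total * (device_idx + 1) / worker_num:
                device_idx += 1
            param2device[name] = device_idx
            acc += mem
        return param2device

    print(split_params(
        [("fc_0.w_0", 5.0), ("fc_0.b_0", 0.1), ("fc_1.w_0", 3.0), ("fc_1.b_0", 0.1)], 2))
    # -> {'fc_0.w_0': 0, 'fc_0.b_0': 1, 'fc_1.w_0': 1, 'fc_1.b_0': 1}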
+ +from paddle.fluid import unique_name, core +import paddle.fluid as fluid + +from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper +from paddle.distributed.fleet.meta_optimizers.common import is_backward_op +from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase +from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment +from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils +from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper +from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper +from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps +from paddle.distributed.fleet.meta_optimizers.sharding.utils import * + +from functools import reduce + +__all__ = ["ShardingOptimizer"] + + +class ShardingOptimizer(MetaOptimizerBase): + def __init__(self, optimizer): + super(ShardingOptimizer, self).__init__(optimizer) + self.inner_opt = optimizer + self.meta_optimizers_white_list = [ + "RecomputeOptimizer", + "AMPOptimizer", + ] + self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] + self._main_program = None + self._startup_program = None + self._segments = [] + # params and fp16 params is for broadcast + self._params = set([]) + self._broadcast_vars = set([]) + # reduced grads to param name + self._reduced_grads_to_param = {} + self._shard = Shard() + + def _can_apply(self): + if not self.role_maker._is_collective: + return False + if self.role_maker._worker_num() <= 1: + return False + return self.user_defined_strategy.sharding + + def _disable_strategy(self, dist_strategy): + dist_strategy.sharding = False + dist_strategy.sharding_configs = {} + + def _enable_strategy(self, dist_strategy, context): + dist_strategy.sharding = True + dist_strategy.sharding_configs = {"fuse_broadcast_MB": 32} + + def minimize_impl(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + self._nrings = self.user_defined_strategy.nccl_comm_num + self._fuse_broadcast_MB = self.user_defined_strategy.sharding_configs[ + "fuse_broadcast_MB"] + + if self.inner_opt is None: + raise ValueError( + "self.inner_opt of ShardingOptimizer should not be None.") + optimize_ops, params_grads = self.inner_opt.minimize( + loss, startup_program, parameter_list, no_grad_set) + + if startup_program is None: + startup_program = default_startup_program() + main_block = loss.block + startup_block = startup_program.global_block() + self._main_program = main_block.program + self._startup_program = startup_program + + # step1: set_up + self._set_up(params_grads) + + # step2: split_program + self._split_program(main_block) + + # step3: add broadcast and reduce ops + self._add_broadcast_allreduce(main_block) + main_block._sync_with_cpp() + startup_block._sync_with_cpp() + + # step4: insert reduce_sum for grad + insert_scale_loss_grad_ops( + main_block, scale=1.0 / self.role_maker._worker_num()) + main_block._sync_with_cpp() + + # step5: remove unneeded ops and vars from block + self._prune_main_program(main_block) + self._prune_startup_program(startup_block) + + # check op dependecy + check_broadcast(main_block) + check_allreduce_sum(main_block) + self._wait() + return optimize_ops, params_grads + + def _set_up(self, params_grads): + # step 1: initialize nccl + worker_idx = self.role_maker._worker_index() + endpoints = 
self.role_maker._get_trainer_endpoints() + current_endpoint = endpoints[worker_idx] + self._collective_helper = CollectiveHelper(self.role_maker, + self._nrings) + for ring_id in range(self._nrings): + self._collective_helper._init_communicator( + self._startup_program, current_endpoint, endpoints, worker_idx, + ring_id, None) + startup_block = self._startup_program.global_block() + startup_block._sync_with_cpp() + + # step 2: split params + self._params = set([x[0].name for x in params_grads]) + self._shard.setup(params_grads, worker_idx, + self.role_maker._worker_num()) + + # step 3: get broadcast vars + self._broadcast_vars = self._shard.find_broadcast_params( + self._main_program.global_block()) + + def _wait(self, ): + endpoints = self.role_maker._get_trainer_endpoints() + current_endpoint = endpoints[self.role_maker._worker_index()] + if self.role_maker._worker_index() == 0: + self._collective_helper._wait(current_endpoint, endpoints) + + def _split_program(self, block): + for op_idx, op in reversed(list(enumerate(block.ops))): + if int(op.attr('op_role')) != int(OpRole.Optimize): + last_backward_op_idx = op_idx + 1 + break + segment = ProgramSegment(block) + segment._end_idx = last_backward_op_idx + for op_idx in reversed(range(last_backward_op_idx)): + op = block.ops[op_idx] + assert (int(op.attr('op_role')) != int(OpRole.Optimize)) + if segment._param_mem >= self._fuse_broadcast_MB: + segment._start_idx = op_idx + 1 + self._segments.insert(0, segment) + segment = ProgramSegment(block) + segment._end_idx = op_idx + 1 + + # find broadcast vars + for input_name in op.desc.input_arg_names(): + if input_name not in self._broadcast_vars: + continue + if input_name in segment._param2broadcast: + # skip broadcast because it reuse the old broadcast var + broadcast_name = segment._param2broadcast[input_name] + if input_name != broadcast_name: + op._rename_input(input_name, broadcast_name) + continue + if self._shard.has_param(input_name): + broadcast_var_name = input_name + else: + broadcast_var_name = unique_name.generate(input_name + + "@BroadCast") + segment._fill_constant_vars.append(broadcast_var_name) + segment._param2broadcast[input_name] = broadcast_var_name + segment._broadcast_vars.append((broadcast_var_name, + self._shard.device(input_name))) + segment._param_mem += get_var_size( + self._main_program.global_block().var(input_name)) + + # find reduce vars + if is_backward_op(op) and \ + OP_ROLE_VAR_KEY in op.attr_names: + op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] + if len(op_role_var) != 0: + assert len(op_role_var) % 2 == 0 + for i in range(0, len(op_role_var), 2): + param, reduced_grad = op_role_var[i], op_role_var[i + 1] + segment._allreduce_vars.append(reduced_grad) + assert ( + reduced_grad not in self._reduced_grads_to_param) + self._reduced_grads_to_param[reduced_grad] = param + + # find cast op + if FP16Utils.is_fp16_cast_op(block, op, self._params): + fp32_param = op.desc.input_arg_names()[0] + fp16_param = op.desc.output_arg_names()[0] + if self._shard.has_param(fp32_param): + segment._cast_ops[fp16_param] = fp32_param + + if segment._param_mem > 0: + segment._start_idx = 0 + self._segments.insert(0, segment) + return + + def _prune_main_program(self, block): + """ + calculate deps from allredce op to optimize op, + remove ops and vars not needed in this worker + """ + weightdecay_helper = WeightDecayHelper() + weightdecay_helper.prune_weight_decay(block, self._shard) + FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param, + self._nrings) + 
gradientclip_helper = GradientClipHelper() + gradientclip_helper.prune_gradient_clip(block, self._shard) + + # build prog deps + reduced_grads = [] + for idx, op in enumerate(block.ops): + input_names = op.desc.input_arg_names() + output_names = op.desc.output_arg_names() + if op.type == "c_allreduce_sum": + assert (len(output_names) == 1) + output_name = output_names[0] + reduced_grads.append(output_name) + + pruned_opti_vars = [] + for var_name in list(block.vars.keys()): + if self._shard.is_opti_var(var_name) and \ + not self._shard.has_opt_var(var_name): + pruned_opti_vars.append(var_name) + program_deps = ProgramDeps(block, reduced_grads, pruned_opti_vars) + + # Init + for var_name in program_deps._end_vars: + program_deps._should_removed_var.add(var_name) + + # Prune + for idx, op in reversed(list(enumerate(block.ops))): + if op.type in [ + "c_allreduce_sum", "c_sync_comm_stream", + "c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init" + ]: + pass + elif op.type == "conditional_block": + assert (op.desc.has_attr("sub_block")) + subblock_idx = op.desc.attr("sub_block").id + subblock_deps = program_deps.get_sub_block_deps(subblock_idx) + # only prune amp subblock + if subblock_deps is None or not self._is_amp_subblock(op): + continue + # init + reversed_output_vars = [] + for output_name in op.desc.output("Out"): + if output_name in program_deps._should_removed_var: + subblock_deps._should_removed_var.add(output_name) + program_deps.crop_output_var_from_op(idx, output_name) + else: + reversed_output_vars.append(output_name) + # prune + for sub_op_idx, _ in reversed( + list(enumerate(subblock_deps._block.ops))): + if subblock_deps.should_remove_op(sub_op_idx): + subblock_deps.remove_op(sub_op_idx) + reversed_input_vars = [] + for input_name in op.desc.input('Input'): + if input_name not in subblock_deps._should_removed_var: + reversed_input_vars.append(input_name) + else: + program_deps.crop_input_var_from_op(idx, input_name) + op.desc.set_input('Input', reversed_input_vars) + op.desc.set_output('Out', reversed_output_vars) + else: + if program_deps.should_remove_op(idx): + program_deps.remove_op(idx) + + block._sync_with_cpp() + return + + def _add_broadcast_allreduce(self, block): + """ + _add_broadcast_allreduce + """ + ring_id = -1 + if len(self._segments) < 1: + return + + if self._segments[-1]._allreduce_vars: + insert_sync_comm_ops(block, self._segments[-1]._end_idx, + self._nrings, + self._segments[-1]._allreduce_vars) + insert_allreduce_ops(block, self._segments[-1]._end_idx, + self._nrings, + self._segments[-1]._allreduce_vars) + + for idx, segment in reversed(list(enumerate(self._segments))): + allreduce_vars = self._segments[ + idx - 1]._allreduce_vars if idx > 0 else [] + broadcast_vars = self._segments[idx + + 1]._broadcast_vars if idx < len( + self._segments) - 1 else [] + fill_constant_vars = self._segments[ + idx + 2]._fill_constant_vars if idx < len( + self._segments) - 2 else [] + cast_ops = self._segments[idx + 2]._cast_ops if idx < len( + self._segments) - 2 else {} + + for op_idx in reversed(range(segment._start_idx, segment._end_idx)): + op = block.ops[op_idx] + for input_name in op.desc.input_arg_names(): + if input_name in segment._param2broadcast and \ + input_name != segment._param2broadcast[input_name]: + op._rename_input(input_name, + segment._param2broadcast[input_name]) + + for param_name, broadcast_name in segment._param2broadcast.items(): + if param_name != broadcast_name: + block.create_var( + name=broadcast_name, + 
shape=self._main_program.global_block().var( + param_name).shape, + dtype=self._main_program.global_block().var(param_name) + .dtype, + persistable=False) + + # step1: remove cast ops + block._sync_with_cpp() + segment._end_idx += FP16Utils.remove_cast_op(block, self._params, + segment, 0) + + # step2: add Sync ops + comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars] + if len(comm_dep_vars) > 0: + insert_sync_comm_ops( + block, + segment._end_idx, + self._nrings, + comm_dep_vars, ) + calc_dep_vars = fill_constant_vars + [ + k for k, v in cast_ops.items() + ] + self._segments[idx]._allreduce_vars + + if len(calc_dep_vars) > 0: + insert_sync_calc_op(block, segment._end_idx, + [calc_dep_vars[-1]]) + + # step3: insert `fill_constant` ops + insert_fill_constant_ops(block, segment._end_idx, + fill_constant_vars) + + # step4: add `cast` ops + insert_cast_ops(block, segment._end_idx, cast_ops) + + # step5: add broadcast ops + insert_broadcast_ops(block, segment._start_idx, self._nrings, + broadcast_vars) + + # step6: add all_reduce ops + insert_allreduce_ops(block, segment._start_idx, self._nrings, + allreduce_vars) + + block._sync_with_cpp() + + if self._segments[0]._broadcast_vars: + insert_sync_comm_ops( + block, self._segments[0]._start_idx, self._nrings, + [x[0] for x in self._segments[0]._broadcast_vars]) + insert_broadcast_ops(block, self._segments[0]._start_idx, + self._nrings, + self._segments[0]._broadcast_vars) + + fill_constant_vars = [] + for x in self._segments[:2]: + fill_constant_vars += x._fill_constant_vars + + # Join + cast_ops = {} + for x in self._segments[:2]: + for k, v in x._cast_ops.items(): + cast_ops[k] = v + + calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()] + if fill_constant_vars or cast_ops: + insert_sync_calc_op(block, self._segments[0]._start_idx, + [calc_deps_vars[-1]]) + + if fill_constant_vars: + insert_fill_constant_ops(block, self._segments[0]._start_idx, + fill_constant_vars) + + if cast_ops: + insert_cast_ops(block, self._segments[0]._start_idx, cast_ops) + + return + + def _prune_startup_program(self, block): + for idx, op in reversed(list(enumerate(block.ops))): + for output_name in op.desc.output_arg_names(): + if self._shard.has_var(output_name): + continue + #TODO why do we remove op, when only one var is removed + block._remove_op(idx, sync=False) + break + + for var_name in list(block.vars.keys()): + if self._shard.has_var(var_name): + continue + block._remove_var(var_name, sync=False) + block._sync_with_cpp() diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 505d6fef8fb..f20716c3a15 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -669,7 +669,7 @@ def append_gradient_clip_ops(param_grads): if g is None: continue with p.block.program._optimized_guard( - [p, g]), framework.name_scope('gradient_clip_@CLIP'): + [p, g]), framework.name_scope('gradient_clip'): clip_attr = getattr(p, 'gradient_clip_attr', None) if clip_attr is None: return param_grads @@ -685,7 +685,7 @@ def append_gradient_clip_ops(param_grads): if g is None: continue with p.block.program._optimized_guard( - [p, g]), framework.name_scope('graident_clip_@CLIP'): + [p, g]), framework.name_scope('gradient_clip'): param, new_grad = clip_attr._create_operators(param=p, grad=g) param_new_grad_name_dict[param.name] = new_grad.name res.append([param, new_grad]) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index aaceb22b98d..6be7fe0612e 100644 --- 
a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2100,10 +2100,16 @@ class Operator(object): % (out_proto.name, len(out_args))) out_arg_names = [] for arg in out_args: - out_arg_names.append(cpt.to_text(arg.name)) + if isinstance(arg, six.string_types): + out_arg_names.append(arg) + else: + out_arg_names.append(cpt.to_text(arg.name)) # TODO(minqiyang): could we remove variable's op in static mode? if not in_dygraph_mode(): - arg.op = self + if isinstance(arg, six.string_types): + block.var(arg).op = self + else: + arg.op = self self.desc.set_output(out_proto.name, out_arg_names) if op_attrs is not None: @@ -2837,8 +2843,9 @@ class Block(object): self._sync_with_cpp() return var - def _remove_var(self, name): - self._sync_with_cpp() + def _remove_var(self, name, sync=True): + if sync == True: + self._sync_with_cpp() self.desc._remove_var(cpt.to_bytes(name)) del self.vars[name] @@ -2936,7 +2943,23 @@ class Block(object): self.ops.insert(index, op) return op - def _remove_op(self, index): + def _insert_op_without_sync(self, index, *args, **kwargs): + """ + Insert an Operator according to the given arguments, + without calling _sync_with_cpp(), to make the compilation faster. + + Args: + index(int): the index at which to insert the operator. + + Returns: + Operator: the inserted Operator. + """ + op_desc = self.desc._insert_op(index) + op = Operator(block=self, desc=op_desc, *args, **kwargs) + self.ops.insert(index, op) + return op + + def _remove_op(self, index, sync=True): """ Remove the specific position operator. @@ -2946,7 +2969,8 @@ class Block(object): Returns: None """ - self._sync_with_cpp() + if sync == True: + self._sync_with_cpp() self.desc._remove_op(index, index + 1) del self.ops[index] diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 39e44f6aaa1..101242808b2 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -41,6 +41,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer) +list(APPEND MIXED_DIST_TEST_OPS test_fleet_sharding_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_localsgd_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_lars_meta_optimizer) list(APPEND MIXED_DIST_TEST_OPS test_fleet_lamb_meta_optimizer) @@ -461,6 +462,7 @@ if(WITH_DISTRIBUTE) py_test_modules(test_fleet_recompute_meta_optimizer MODULES test_fleet_recompute_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_graph_executor MODULES test_fleet_graph_executor ENVS ${dist_ENVS}) py_test_modules(test_fleet_gradient_merge_meta_optimizer MODULES test_fleet_gradient_merge_meta_optimizer ENVS ${dist_ENVS}) + py_test_modules(test_fleet_sharding_meta_optimizer MODULES test_fleet_sharding_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_amp_meta_optimizer MODULES test_fleet_amp_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_fp16_allreduce_meta_optimizer MODULES test_fleet_fp16_allreduce_meta_optimizer ENVS ${dist_ENVS}) py_test_modules(test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py index
48df06cddd9..b6ecc07fd9f 100755 --- a/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py +++ b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py @@ -55,14 +55,22 @@ class TestFleetMetaOptimizer(unittest.TestCase): strategy, train_prog, startup_prog, - name='momentum'): + name='momentum', + regularization=None, + grad_clip=None): with fluid.program_guard(train_prog, startup_prog): with fluid.unique_name.guard(): if name == 'momentum': optimizer = paddle.fluid.optimizer.Momentum( - learning_rate=0.01, momentum=0.9) + learning_rate=0.01, + momentum=0.9, + regularization=regularization, + grad_clip=grad_clip) elif name == 'adam': - optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01) + optimizer = paddle.fluid.optimizer.Adam( + learning_rate=0.01, + regularization=regularization, + grad_clip=grad_clip) optimizer = fleet.distributed_optimizer( optimizer, strategy=strategy) optimizer.minimize(loss) @@ -121,5 +129,8 @@ class TestFleetMetaOptimizer(unittest.TestCase): elif name == "gradient_merge": strategy.gradient_merge = True strategy.gradient_merge_configs = {"k_steps": 2, "avg": True} + elif name == "sharding": + strategy.sharding = True + strategy.sharding_configs = {"fuse_broadcast_MB": 0.2} else: raise NotImplementedError() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py index 29eb3d9ab16..a40bc9a9fba 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_gradient_merge_meta_optimizer.py @@ -32,9 +32,6 @@ class TestFleetGradientMergeMetaOptimizer(TestFleetMetaOptimizer): self.optimizer(avg_cost, strategy, train_prog, startup_prog) vars = [x.name for x in train_prog.list_vars()] - with open("main_program", 'w') as f: - f.write(str(train_prog)) - self.assertIn('@GradientMerge', ''.join(vars)) def test_recom_gm_optimizer(self): diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py new file mode 100644 index 00000000000..6a9f3e3ba7b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -0,0 +1,275 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
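A short usage sketch of the sync=False / _insert_op_without_sync helpers patched into framework.py above: block edits can be batched and synchronized with the C++ desc once at the end, which is what the sharding passes rely on. The toy program below is illustrative only:

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    main = fluid.Program()
    with fluid.program_guard(main):
        x = fluid.data(name="x", shape=[None, 8], dtype="float32")
        y = fluid.layers.scale(x, scale=2.0)

    block = main.global_block()
    # batch several removals without a per-op _sync_with_cpp round trip
    for idx, op in reversed(list(enumerate(block.ops))):
        if op.type == "scale":
            block._remove_op(idx, sync=False)
    block._sync_with_cpp()  # one sync after all edits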
+ +import unittest +import paddle +import os +import paddle.distributed.fleet as fleet +import paddle.distributed.fleet.base.role_maker as role_maker + +from fleet_meta_optimizer_base import TestFleetMetaOptimizer + +paddle.enable_static() + + +class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): + def test_sharding_optimizer(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'sharding') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + parameters = [ + x.name for x in train_prog.list_vars() if x.persistable == True + ] + ops = [op.type for op in avg_cost.block.ops] + vars = [x.name for x in train_prog.list_vars()] + self.assertIn('@BroadCast', ''.join(vars)) + self.assertEqual( + set(parameters), + set([ + "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0", + "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0" + ])) + self.assertEqual(ops, [ + 'fill_constant', 'fill_constant', 'fill_constant', + 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', + 'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', + 'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean', + 'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2', + 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', + 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', + 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', + 'c_sync_comm_stream', 'momentum', 'momentum', 'momentum' + ]) + + def test_sharding_amp_optimizer(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'sharding') + self.set_strategy(strategy, 'amp') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + ops = [op.type for op in avg_cost.block.ops] + vars = [x.name for x in train_prog.list_vars()] + parameters = [ + x.name for x in train_prog.list_vars() if x.persistable == True + ] + self.assertIn('@BroadCast', ''.join(vars)) + self.assertIn('cast', ops) + self.assertIn('check_finite_and_unscale', ops) + self.assertEqual( + set(parameters), + set([ + "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0", + "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0", + "loss_scaling_0", "num_bad_steps_0", "num_good_steps_0" + ])) + self.assertEqual(ops, [ + 'cast', 'cast', 'cast', 'fill_constant', 'fill_constant', + 'fill_constant', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_sync_comm_stream', 'cast', 'mul', 'elementwise_add', 'cast', + 'tanh', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast', + 'mul', 'elementwise_add', 'softmax', 'cast', 'cross_entropy2', + 'mean', 'elementwise_mul', 'fill_constant', 'scale', + 'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast', + 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast', + 'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast', + 'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', + 'c_sync_calc_stream', 'c_allreduce_sum', 'c_allreduce_sum', + 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', + 'c_allreduce_sum', 'c_sync_comm_stream', 'cast', 'cast', 
'cast', + 'check_finite_and_unscale', 'cast', 'c_sync_calc_stream', + 'c_allreduce_max', 'c_sync_comm_stream', 'cast', + 'update_loss_scaling', 'momentum', 'momentum', 'momentum' + ]) + + def test_sharding_recompute_optimizer(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'sharding') + self.set_strategy(strategy, 'recompute') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + + ops = [op.type for op in avg_cost.block.ops] + vars = [x.name for x in train_prog.list_vars()] + parameters = [ + x.name for x in train_prog.list_vars() if x.persistable == True + ] + + self.assertIn('@BroadCast', ''.join(vars)) + self.assertIn('subprog', ''.join(vars)) + self.assertEqual( + set(parameters), + set([ + "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0", + "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0" + ])) + self.assertEqual(ops, [ + 'fill_constant', 'fill_constant', 'fill_constant', + 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', + 'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', + 'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean', + 'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2', + 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'mul', + 'elementwise_add', 'tanh_grad', 'elementwise_add_grad', 'mul_grad', + 'mul', 'elementwise_add', 'tanh_grad', 'elementwise_add_grad', + 'mul_grad', 'c_sync_calc_stream', 'c_allreduce_sum', + 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', + 'c_allreduce_sum', 'c_allreduce_sum', 'c_sync_comm_stream', + 'momentum', 'momentum', 'momentum' + ]) + + def test_sharding_amp_recompute_optimizer(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'sharding') + self.set_strategy(strategy, 'recompute') + self.set_strategy(strategy, 'amp') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + + ops = [op.type for op in avg_cost.block.ops] + vars = [x.name for x in train_prog.list_vars()] + parameters = [ + x.name for x in train_prog.list_vars() if x.persistable == True + ] + + self.assertIn('@BroadCast', ''.join(vars)) + self.assertIn('subprog', ''.join(vars)) + self.assertIn('cast', ops) + self.assertIn('check_finite_and_unscale', ops) + + self.assertEqual( + set(parameters), + set([ + "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0", + "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0", + "loss_scaling_0", "num_bad_steps_0", "num_good_steps_0" + ])) + + self.assertEqual(ops, [ + 'cast', 'cast', 'cast', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', + 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', 'cast', 'cast', + 'mul', 'cast', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul', + 'elementwise_add', 'cast', 'tanh', 'cast', 'mul', 'elementwise_add', + 'softmax', 'cast', 'cross_entropy2', 'mean', 'elementwise_mul', + 'fill_constant', 'scale', 'elementwise_mul_grad', 'mean_grad', + 'cross_entropy_grad2', 'cast', 'softmax_grad', + 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', 'cast', + 'elementwise_add', 'cast', 'tanh_grad', 'cast', + 
'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', 'cast', + 'elementwise_add', 'cast', 'tanh_grad', 'cast', + 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', + 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', + 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', + 'c_sync_comm_stream', 'cast', 'cast', 'cast', + 'check_finite_and_unscale', 'cast', 'c_sync_calc_stream', + 'c_allreduce_max', 'c_sync_comm_stream', 'cast', + 'update_loss_scaling', 'momentum', 'momentum', 'momentum' + ]) + + def test_sharding_weight_decay(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'sharding') + regularization = paddle.fluid.regularizer.L2Decay(0.0001) + self.optimizer( + avg_cost, + strategy, + train_prog, + startup_prog, + regularization=regularization) + parameters = [ + x.name for x in train_prog.list_vars() if x.persistable == True + ] + ops = [op.type for op in avg_cost.block.ops] + vars = [x.name for x in train_prog.list_vars()] + self.assertIn('@BroadCast', ''.join(vars)) + self.assertEqual( + set(parameters), + set([ + "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0", + "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0" + ])) + + self.assertEqual(ops, [ + 'fill_constant', 'fill_constant', 'fill_constant', + 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', + 'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', + 'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean', + 'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2', + 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', + 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', + 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', + 'c_sync_comm_stream', 'scale', 'sum', 'scale', 'sum', 'scale', + 'sum', 'momentum', 'momentum', 'momentum' + ]) + + def test_sharding_gradient_clip(self): + train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program( + ) + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'sharding') + clip = paddle.fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0) + self.optimizer( + avg_cost, strategy, train_prog, startup_prog, grad_clip=clip) + parameters = [ + x.name for x in train_prog.list_vars() if x.persistable == True + ] + ops = [op.type for op in avg_cost.block.ops] + vars = [x.name for x in train_prog.list_vars()] + self.assertIn('@BroadCast', ''.join(vars)) + self.assertEqual( + set(parameters), + set([ + "fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0", + "fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0" + ])) + self.assertEqual(ops, [ + 'fill_constant', 'fill_constant', 'fill_constant', + 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', + 'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh', + 'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean', + 'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2', + 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'tanh_grad', + 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', + 'c_allreduce_sum', 
'c_allreduce_sum', 'c_allreduce_sum', + 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', + 'c_sync_comm_stream', 'square', 'reduce_sum', 'square', + 'reduce_sum', 'square', 'reduce_sum', 'sum', 'c_sync_calc_stream', + 'c_allreduce_sum', 'c_sync_comm_stream', 'sqrt', 'fill_constant', + 'elementwise_max', 'elementwise_div', 'elementwise_mul', + 'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum', + 'momentum' + ]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index f09c189a68e..f9395f8dd31 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -148,6 +148,7 @@ packages=['paddle', 'paddle.distributed.fleet', 'paddle.distributed.fleet.base', 'paddle.distributed.fleet.meta_optimizers', + 'paddle.distributed.fleet.meta_optimizers.sharding', 'paddle.distributed.fleet.runtime', 'paddle.distributed.fleet.dataset', 'paddle.distributed.fleet.data_generator', -- GitLab
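
Note (usage sketch, not part of the diff): the Block helpers added to python/paddle/fluid/framework.py above, _insert_op_without_sync() together with the new sync flag of _remove_op() and _remove_var(), let a graph-rewriting pass batch many edits and pay the cost of _sync_with_cpp() only once at the end. The snippet below is a minimal illustration of that pattern and assumes a Paddle build that contains this patch; the toy network and the inserted scale op are made up for the example.

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    main_prog, startup_prog = fluid.Program(), fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        x = fluid.layers.data(name='x', shape=[8], dtype='float32')
        y = fluid.layers.fc(input=x, size=4)

    block = main_prog.global_block()

    # Batch several edits without a per-op _sync_with_cpp() round trip.
    for idx in range(3):
        block._insert_op_without_sync(
            idx,
            type='scale',
            inputs={'X': x},
            outputs={'Out': x},
            attrs={'scale': 1.0})

    # Remove the last op; sync=False defers the reconciliation to the caller.
    block._remove_op(len(block.ops) - 1, sync=False)

    # A single explicit sync brings the Python-side op list back in line
    # with the underlying C++ BlockDesc.
    block._sync_with_cpp()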
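
The unit tests above assert the visible effect of sharding on a single rank: only the parameters owned by that rank (and their optimizer states, e.g. the *_velocity_0 variables of momentum) stay persistable, while the remaining parameters arrive through @BroadCast variables and c_broadcast ops before the forward pass. The toy function below sketches one possible size-balanced ownership rule; it is an assumption for illustration only, not the logic implemented in python/paddle/distributed/fleet/meta_optimizers/sharding/shard.py, and the parameter sizes are invented.

    # Toy illustration: give every parameter an owner rank; the owner keeps the
    # parameter persistable, all other ranks receive it via broadcast.
    def assign_params_to_ranks(param_sizes, num_ranks):
        loads = [0] * num_ranks
        owner = {}
        # Greedy, largest-first assignment to the currently least-loaded rank.
        for name, size in sorted(param_sizes.items(), key=lambda kv: -kv[1]):
            rank = loads.index(min(loads))
            owner[name] = rank
            loads[rank] += size
        return owner

    if __name__ == "__main__":
        sizes = {"fc_1.w_0": 64 * 64, "fc_1.b_0": 64,
                 "fc_2.w_0": 64 * 2, "fc_2.b_0": 2}
        print(assign_params_to_ranks(sizes, num_ranks=2))
        # -> {'fc_1.w_0': 0, 'fc_2.w_0': 1, 'fc_1.b_0': 1, 'fc_2.b_0': 1}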