Unverified commit 81244fbf authored by: M mapingshuo committed by: GitHub

add sharding strategy in fleet (#27900)

* add sharding
Parent 4877bd59
......@@ -24,6 +24,10 @@ enum Mode {
message RecomputeConfig { repeated string checkpoints = 1; }
message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
}
message AMPConfig {
optional float init_loss_scaling = 1 [ default = 32768.0 ];
optional int32 incr_every_n_steps = 2 [ default = 1000 ];
......@@ -130,6 +134,7 @@ message DistributedStrategy {
optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
optional bool adaptive_localsgd = 24 [ default = false ];
optional bool fp16_allreduce = 25 [ default = false ];
optional bool sharding = 26 [ default = false ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
......@@ -141,6 +146,7 @@ message DistributedStrategy {
optional LarsConfig lars_configs = 108;
optional LambConfig lamb_configs = 109;
optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
optional ShardingConfig sharding_configs = 111;
optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202;
}
......
......@@ -611,6 +611,55 @@ class DistributedStrategy(object):
"checkpoint_configs")
assign_configs_value(self.strategy.recompute_configs, configs)
@property
def sharding(self):
"""
Indicating whether or not we are using the sharding optimizer for
memory optimization
Default value: False
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.sharding = True
"""
return self.strategy.sharding
@sharding.setter
@is_strict_auto
def sharding(self, flag):
if isinstance(flag, bool):
self.strategy.sharding = flag
else:
print("WARNING: sharding should have value of bool type")
@property
def sharding_configs(self):
"""
Set sharding configurations.
**Note**:
fuse_broadcast_MB(float): size threshold, in MB, of a fused group of broadcast parameters.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {"fuse_broadcast_MB": 32}
"""
return get_msg_dict(self.strategy.sharding_configs)
@sharding_configs.setter
@is_strict_auto
def sharding_configs(self, configs):
check_configs_key(self.strategy.sharding_configs, configs,
"sharding_configs")
assign_configs_value(self.strategy.sharding_configs, configs)
@property
def pipeline(self):
"""
......
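As a usage sketch, the snippet below mirrors the docstring examples and the unit tests in this commit: build a small static-graph net with fluid, switch sharding on, and hand the optimizer to fleet.distributed_optimizer. It assumes the script is launched with Paddle's collective launcher so that the trainer endpoints environment is already set.

import paddle
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

train_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
    x = fluid.data(name="x", shape=[None, 32], dtype="float32")
    y = fluid.data(name="y", shape=[None, 1], dtype="int64")
    fc_1 = fluid.layers.fc(input=x, size=64, act="tanh")
    prediction = fluid.layers.fc(input=fc_1, size=2, act="softmax")
    cost = fluid.layers.cross_entropy(input=prediction, label=y)
    avg_cost = fluid.layers.mean(x=cost)

    strategy = fleet.DistributedStrategy()
    strategy.sharding = True
    strategy.sharding_configs = {"fuse_broadcast_MB": 32}

    optimizer = fluid.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(avg_cost)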
......@@ -24,3 +24,4 @@ from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer
from .dgc_optimizer import DGCOptimizer
from .lamb_optimizer import LambOptimizer
from .fp16_allreduce_optimizer import FP16AllReduceOptimizer
from .sharding_optimizer import ShardingOptimizer
......@@ -99,6 +99,12 @@ class CollectiveHelper(object):
OP_ROLE_KEY: OpRole.Forward
})
def _wait(self, current_endpoint, endpoints):
assert (self.wait_port)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
wait_server_ready(other_endpoints)
def _broadcast_params(self):
block = self.startup_program.global_block()
ring_id = -1
......
......@@ -30,6 +30,10 @@ class DGCOptimizer(MetaOptimizerBase):
super(DGCOptimizer, self)._set_basic_info(
loss, role_maker, user_defined_optimizer, user_defined_strategy)
def _init_dgc_opt(self):
if self.dgc_opt is not None:
return
opt = self.inner_opt
if not self.role_maker._is_collective:
......@@ -86,13 +90,16 @@ class DGCOptimizer(MetaOptimizerBase):
parameter_list=None,
no_grad_set=None,
callbacks=None):
self._init_dgc_opt()
return self.dgc_opt.backward(loss, startup_program, parameter_list,
no_grad_set, callbacks)
def apply_gradients(self, params_grads):
self._init_dgc_opt()
return self.dgc_opt.apply_gradients(params_grads=params_grads)
def apply_optimize(self, loss, startup_program, params_grads):
self._init_dgc_opt()
return self.dgc_opt.apply_optimize(
loss, startup_program=startup_program, params_grads=params_grads)
......@@ -101,6 +108,7 @@ class DGCOptimizer(MetaOptimizerBase):
startup_program=None,
parameter_list=None,
no_grad_set=None):
self._init_dgc_opt()
optimize_ops, params_grads = \
self.dgc_opt.minimize(loss, startup_program,
parameter_list, no_grad_set)
......
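The DGC changes above are a lazy-initialization refactor: _init_dgc_opt() is called at the top of each public entry point and returns immediately once self.dgc_opt exists, so the wrapped optimizer is only built when first needed. A minimal sketch of the same pattern, with purely illustrative names:

class LazyOptimizerWrapper(object):
    def __init__(self, factory):
        self._factory = factory   # callable that builds the real optimizer
        self._inner = None

    def _init_inner(self):
        # build the wrapped optimizer only on first use
        if self._inner is not None:
            return
        self._inner = self._factory()

    def minimize(self, loss):
        self._init_inner()        # every public entry point calls this first
        return self._inner.minimize(loss)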
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.fluid import core
class FP16Utils(object):
def __init__(self):
pass
@staticmethod
def is_fp16_cast_op(block, op, params):
if op.type != "cast":
return False
if is_optimizer_op(op):
return False
assert (len(op.desc.input_arg_names()) == 1)
assert (len(op.desc.output_arg_names()) == 1)
input_name, output_name = op.desc.input_arg_names()[
0], op.desc.output_arg_names()[0]
if input_name not in params:
return False
input_var = block.var(input_name)
output_var = block.var(output_name)
if input_var.dtype != core.VarDesc.VarType.FP32 or \
output_var.dtype != core.VarDesc.VarType.FP16:
return False
return True
@staticmethod
def is_fp32_cast_op(block, op):
if op.type != "cast":
return False
if not is_optimizer_op(op):
return False
assert (len(op.desc.input_arg_names()) == 1)
assert (len(op.desc.output_arg_names()) == 1)
input_name, output_name = op.desc.input_arg_names()[
0], op.desc.output_arg_names()[0]
input_var = block.var(input_name)
output_var = block.var(output_name)
if input_var.dtype != core.VarDesc.VarType.FP16 or \
output_var.dtype != core.VarDesc.VarType.FP32:
return False
return True
@staticmethod
def remove_cast_op(block, params, segment, offset):
inserted_op_num = 0
for op_idx in reversed(
range(offset + segment._start_idx, offset + segment._end_idx)):
op = block.ops[op_idx]
if FP16Utils.is_fp16_cast_op(block, op, params):
block._remove_op(op_idx, sync=False)
inserted_op_num -= 1
block._sync_with_cpp()
return inserted_op_num
@staticmethod
def prune_fp16(block, shard, reduced_grads_to_param, nrings):
# remove cast
for idx, op in reversed(list(enumerate(block.ops))):
if not FP16Utils.is_fp32_cast_op(block, op):
continue
output_name = op.desc.output_arg_names()[0]
param_name = output_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError("Input 'X' of check_finite_and_unscale must"
"be grads, but {} is not a grad".format(
input_name))
if output_name in reduced_grads_to_param:
continue
if shard.has_param(param_name):
continue
block._remove_op(idx, sync=False)
block._remove_var(output_name, sync=False)
block._sync_with_cpp()
update_loss_scaling_op_idx = -1
inf_var_name = ''
for idx, op in reversed(list(enumerate(block.ops))):
if op.type == "update_loss_scaling":
update_loss_scaling_op_idx = idx
inf_var_name = op.desc.input('FoundInfinite')[0]
op._rename_input(inf_var_name, inf_var_name + "@sharding")
if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
reversed_x = []
for input_name in op.desc.input('X'):
param_name = input_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError(
"Input 'X' of check_finite_and_unscale must"
"be grads, but {} is not a grad".format(input_name))
if shard.has_param(param_name):
reversed_x.append(input_name)
op.desc.set_input('X', reversed_x)
op.desc.set_output('Out', reversed_x)
if update_loss_scaling_op_idx == -1:
return
inf_var = block.var(inf_var_name)
inf_var_fp32 = block.create_var(
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
inf_var_sharding = block.create_var(
name=inf_var_name + "@sharding",
shape=inf_var.shape,
dtype=inf_var.dtype)
block._insert_op_without_sync(
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var},
outputs={'Out': inf_var_fp32},
attrs={
"in_dtype": inf_var.dtype,
"out_dtype": inf_var_fp32.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
[inf_var_fp32])
block._insert_op_without_sync(
update_loss_scaling_op_idx + 2,
type='c_allreduce_max',
inputs={'X': inf_var_fp32},
outputs={'Out': inf_var_fp32},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
comm_op_num = insert_sync_comm_ops(
block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])
block._insert_op_without_sync(
update_loss_scaling_op_idx + 3 + comm_op_num,
type='cast',
inputs={'X': inf_var_fp32},
outputs={'Out': inf_var_sharding},
attrs={
"in_dtype": inf_var_fp32.dtype,
"out_dtype": inf_var_sharding.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
block._sync_with_cpp()
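prune_fp16 restricts check_finite_and_unscale to the gradients this shard owns, so each worker's FoundInfinite flag is only a local verdict; the cast / c_allreduce_max / cast sequence inserted above turns it into a global one. A small numpy sketch of that reduction, with plain Python standing in for the collective op:

import numpy as np

# one local found_inf flag per sharding worker, cast to int32 as above
local_found_inf = [np.array([0], dtype=np.int32),   # worker 0: all local grads finite
                   np.array([1], dtype=np.int32),   # worker 1: saw an inf/nan grad
                   np.array([0], dtype=np.int32)]   # worker 2: all local grads finite

# c_allreduce_max over the int32 flags acts as a logical OR across workers
global_found_inf = np.maximum.reduce(local_found_inf)
assert global_found_inf[0] == 1   # every worker now agrees this step overflowed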
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
class GradientClipHelper(object):
def __init__(self):
pass
def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")
def prune_gradient_clip(self, block, shard):
deperated_vars = set()
deperate_op_idx = set()
for idx, op in enumerate(block.ops):
if not self._is_gradient_clip_op(op):
continue
if op.type == "sum":
continue
deperate_op = False
for input_name in op.desc.input_arg_names():
if input_name in deperated_vars:
deperate_op = True
param_name = input_name.strip("@GRAD")
if shard.is_param(param_name) and \
not shard.has_param(param_name):
deperate_op = True
if deperate_op:
deperate_op_idx.add(idx)
for output_name in op.desc.output_arg_names():
deperated_vars.add(output_name)
if not deperated_vars:
# got no gradient_clip op
return
for idx, op in reversed(list(enumerate(block.ops))):
if not self._is_gradient_clip_op(op):
continue
if idx in deperate_op_idx:
block._remove_op(idx, sync=False)
continue
reversed_inputs = []
if op.type == "sum":
for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars:
reversed_inputs.append(input_name)
op.desc.set_input("X", reversed_inputs)
assert (len(op.desc.output_arg_names()) == 1)
sum_res = op.desc.output_arg_names()[0]
block._insert_op_without_sync(
idx + 1,
type='c_sync_comm_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
block._insert_op_without_sync(
idx + 1,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={'ring_id': 0,
OP_ROLE_KEY: OpRole.Optimize})
block._insert_op_without_sync(
idx + 1,
type='c_sync_calc_stream',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={OP_ROLE_KEY: OpRole.Optimize})
for var_name in deperated_vars:
block._remove_var(var_name, sync=False)
block._sync_with_cpp()
return
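prune_gradient_clip keeps only the square/reduce_sum ops for parameters owned by this shard and lets the inserted c_allreduce_sum rebuild the global norm from per-shard partial sums. A numpy sketch of why that is equivalent (toy gradient values, plain Python in place of the collective):

import numpy as np

# after sharding, each worker only keeps the grads of the parameters it owns
grads_per_worker = [
    [np.array([3.0, 4.0])],             # worker 0 owns one parameter
    [np.array([1.0]), np.array([2.0])]  # worker 1 owns two parameters
]

# each worker computes the squared-norm sum of its local grads ...
local_sums = [sum(float(np.sum(g * g)) for g in grads) for grads in grads_per_worker]
# ... the inserted c_allreduce_sum adds them up, and sqrt gives the global norm
global_norm = np.sqrt(sum(local_sums))
assert abs(global_norm - np.sqrt(25.0 + 1.0 + 4.0)) < 1e-6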
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class ProgramDeps(object):
def __init__(self, block, start_vars, end_vars):
self._block = block
# vars from which to start building the deps
self._start_vars = start_vars
# vars at which to stop building the deps
self._end_vars = end_vars
# var name -> idxs of the ops which depend on this var
self._var_to_use_op = {}
# sub-block deps, a subset of this topology
self._sub_block_deps = {}
# var name -> idxs of the ops which generate this var
self._var_to_generate_op = {}
self._should_removed_var = set()
self._father_block_deps = None
self._build_deps()
def get_sub_block_deps(self, idx):
if idx in self._sub_block_deps:
return self._sub_block_deps[idx]
else:
return None
def get_var_deps(self, var_name):
if var_name in self._var_to_use_op:
return self._var_to_use_op[var_name]
else:
return None
def _build_deps(self, ):
for var_name in self._start_vars:
self._var_to_use_op[var_name] = []
self._var_to_generate_op[var_name] = []
for idx, op in enumerate(self._block.ops):
if op.type in [
"c_allreduce_sum", "c_sync_comm_stream",
"c_calc_comm_stream"
]:
continue
input_vars = op.desc.input_arg_names()
output_vars = op.desc.output_arg_names()
deps_reduce = False
for input_name in input_vars:
if input_name in self._var_to_use_op:
deps_reduce = True
if not deps_reduce:
continue
for input_name in input_vars:
if input_name in self._var_to_use_op:
self._var_to_use_op[input_name].append(idx)
for output_name in output_vars:
if output_name not in self._var_to_use_op:
self._var_to_use_op[output_name] = []
if output_name not in self._var_to_generate_op:
self._var_to_generate_op[output_name] = [idx]
else:
self._var_to_generate_op[output_name].append(idx)
if op.type == "conditional_block":
# subblock
assert (op.desc.has_attr("sub_block"))
subblock_idx = op.desc.attr("sub_block").id
subblock_deps = ProgramDeps(
self._block.program.block(subblock_idx),
op.desc.input_arg_names(), op.desc.output_arg_names())
self._sub_block_deps[subblock_idx] = subblock_deps
subblock_deps._father_block_deps = self
def crop_input_var_from_op(self, op_idx, var_name):
if var_name in self._var_to_use_op:
# update var -> dep_var_op
if self._var_to_use_op[var_name] != []:
if op_idx not in self._var_to_use_op[var_name]:
raise ValueError(
"op_idx: {} is not in self._var_to_use_op[{}], "
"self._var_to_use_op[{}] is {}".format(
op_idx, var_name, var_name, self._var_to_use_op[
var_name]))
self._var_to_use_op[var_name].remove(op_idx)
# update _should_removed_var
if var_name in self._start_vars:
self._should_removed_var.discard(var_name)
elif self._var_to_use_op[
var_name] == []: # no more deps of this var
self._should_removed_var.add(var_name)
elif self._var_to_generate_op[var_name][-1] >= self._var_to_use_op[
var_name][-1]:
# there is a cycle in the graph
self._should_removed_var.add(var_name)
else: # input_name should not be deleted
self._should_removed_var.discard(var_name)
def crop_output_var_from_op(self, op_idx, var_name):
if var_name in self._var_to_generate_op:
assert (op_idx in self._var_to_generate_op[var_name])
self._var_to_generate_op[var_name].remove(op_idx)
if self._block.has_var(var_name):
if var_name not in self._var_to_generate_op or self._var_to_generate_op[
var_name] == []:
self._block._remove_var(var_name, sync=False)
def remove_op(self, op_idx):
# update deps
op = self._block.ops[op_idx]
for input_name in op.desc.input_arg_names():
self.crop_input_var_from_op(op_idx, input_name)
for output_name in op.desc.output_arg_names():
self.crop_output_var_from_op(op_idx, output_name)
self._block._remove_op(op_idx, sync=False)
def should_remove_op(self, op_idx):
op = self._block.ops[op_idx]
for output_name in op.desc.output_arg_names():
if output_name not in self._should_removed_var:
return False
return True
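ProgramDeps records, for every var, which ops consume it and which produce it; the optimizer then walks the ops in reverse and drops an op once all of its outputs are in _should_removed_var, which in turn may free the op's inputs. A toy sketch of that reverse sweep over plain tuples (simplified: it ignores the multi-consumer bookkeeping that crop_input_var_from_op performs):

# ops are (op_name, inputs, outputs); op2 writes a var this worker does not own
toy_ops = [
    ("op0", ["x"], ["m0"]),
    ("op1", ["m0"], ["m1"]),
    ("op2", ["m1"], ["opt_var"]),
]
should_remove = {"opt_var"}         # seeded with the pruned optimizer vars

kept = []
for name, inputs, outputs in reversed(toy_ops):
    if all(o in should_remove for o in outputs):
        # removing the op frees its inputs for removal as well
        should_remove.update(inputs)
        continue
    kept.append(name)

assert kept == []  # the whole toy chain is dead once "opt_var" is pruned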
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
class Shard(object):
def __init__(self, ):
self.global_params = set([])
self.worker_idx = -1
self.worker_num = -1
self.global_param2device = {}
def setup(self, params_grads, worker_idx, worker_num):
# param names of all devices
self.global_params = set([x[0].name for x in params_grads])
# _param(str) -> device_id(int)
self.worker_idx = worker_idx
self.worker_num = worker_num
# global_param2device contains fp32 params and fp16 params
self.global_param2device = self._split_params(params_grads, worker_idx,
worker_num)
def has_param(self, var_name):
return var_name in self.global_param2device and \
self._var_device_id(var_name) == self.worker_idx
def has_opt_var(self, var_name):
return self._var_device_id(var_name) == self.worker_idx
def has_var(self, var_name):
return self._var_device_id(var_name) == -1 or \
self._var_device_id(var_name) == self.worker_idx
def _split_params(self, params_grads, worker_idx, worker_num):
param2device = {}
total_param_mem = 0.0
param2mem = []
for param in [x[0] for x in params_grads]:
mem = get_var_size(param)
total_param_mem += mem
param2mem.append((param.name, mem))
device2params = {x: [] for x in range(worker_num)}
device_idx = 0
mem_accu = 0.0
for param_name, mem in param2mem:
if mem_accu > total_param_mem * 1.0 * (device_idx + 1) / worker_num:
device_idx += 1
device2params[device_idx].append(param_name)
param2device[param_name] = device_idx
mem_accu += mem
return param2device
def _var_device_id(self, var_name):
if var_name in self.global_param2device:
return self.global_param2device[var_name]
for suffix in [
"_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
"_beta2_pow_acc_0", "_velocity_0"
]:
base_name = re.sub(suffix, '', var_name)
if base_name in self.global_param2device:
return self.global_param2device[base_name]
return -1
def find_broadcast_params(self, block):
broadcast_vars = set([])
fp16_params = set([])
fp16_to_fp32 = {}
param_usage = {x: 0 for x in self.global_params}
for op in block.ops:
if is_optimizer_op(op):
continue
for input_name in op.desc.input_arg_names():
if input_name in self.global_params:
param_usage[input_name] += 1
for op in block.ops:
if not FP16Utils.is_fp16_cast_op(block, op, self.global_params):
continue
input_name = op.input_arg_names[0]
output_name = op.output_arg_names[0]
broadcast_vars.add(output_name)
fp16_params.add(output_name)
fp16_to_fp32[output_name] = input_name
param_usage[input_name] -= 1
self.global_param2device[output_name] = self.global_param2device[
input_name]
for param, usage in param_usage.items():
if usage > 0:
broadcast_vars.add(param)
return broadcast_vars
def device(self, var_name):
return self._var_device_id(var_name)
def is_param(self, var_name):
return var_name in self.global_params
def is_opti_var(self, var_name):
if var_name in self.global_params:
return True
for suffix in [
"_moment1_0", "_moment2_0", "_beta1_pow_acc_0",
"_beta2_pow_acc_0", "_velocity_0"
]:
base_name = re.sub(suffix, '', var_name)
if base_name in self.global_params:
return True
return False
class ProgramSegment(object):
def __init__(self, block):
self._block = block
self._allreduce_vars = []
# sub program start idx
self._start_idx = -1
# sub program end idx
self._end_idx = -1
# param name to broadcast name
self._param2broadcast = {}
self._broadcast_vars = []
# cast op pairs, fp16 name (str) -> fp32 name (str)
self._cast_ops = {}
# fill constant vars
self._fill_constant_vars = []
# parameter mems
self._param_mem = 0.0
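Shard._split_params assigns parameters to devices greedily, in program order, moving to the next device whenever the accumulated size crosses that device's proportional share of the total; a standalone sketch of the same rule over plain (name, MB) pairs:

def split_params(param2mem, worker_num):
    # param2mem: list of (param_name, size_in_MB), in program order
    total = sum(m for _, m in param2mem)
    param2device = {}
    device_idx = 0
    mem_accu = 0.0
    for name, mem in param2mem:
        # advance to the next device once we pass its proportional share
        if mem_accu > total * (device_idx + 1) / worker_num:
            device_idx += 1
        param2device[name] = device_idx
        mem_accu += mem
    return param2device

print(split_params([("w0", 5.0), ("w1", 4.0), ("w2", 4.0), ("w3", 3.0)], 2))
# {'w0': 0, 'w1': 0, 'w2': 1, 'w3': 1}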
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import core
from functools import reduce
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
import re
def check_broadcast(block):
"""
if a var is broadcast, there must be a sync_comm op between the
broadcast op and any op that uses the var; otherwise an error is raised.
if the broadcast var is filled by a fill_constant op, that fill_constant
op must come before the broadcast op and before a sync_calc op;
otherwise an error is raised.
"""
broadcast_vars = {}
for idx, op in enumerate(block.ops):
if op.type == "c_broadcast":
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if var_name in broadcast_vars:
raise ValueError("var_name areadly exist: {}"
"the old pos is {}, the new pos is {}".
format(var_name, broadcast_vars[var_name][
"broadcast_pos"], idx))
broadcast_vars[var_name] = {
"fill_constant_pos": -1,
"broadcast_pos": idx,
}
for idx, op in enumerate(block.ops):
if op.type == "fill_constant":
var_name = op.desc.output_arg_names()[0]
if var_name in broadcast_vars:
broadcast_vars[var_name]["fill_constant_pos"] = idx
continue
last_sync_comm_op_idx = -1
last_sync_calc_op_idx = -1
for idx, op in enumerate(block.ops):
if op.type == "c_sync_comm_stream":
last_sync_comm_op_idx = idx
continue
if op.type == "c_sync_calc_stream":
last_sync_calc_op_idx = idx
continue
if op.type == "c_broadcast":
var_name = op.desc.input_arg_names()[0]
if "@BroadCast" in var_name:
if broadcast_vars[var_name]["fill_constant_pos"] != -1:
assert (last_sync_calc_op_idx != -1)
assert (broadcast_vars[var_name]["fill_constant_pos"] <
last_sync_calc_op_idx)
assert (last_sync_calc_op_idx < idx)
continue
for input_name in op.desc.input_arg_names():
if input_name in broadcast_vars:
assert (broadcast_vars[input_name]["broadcast_pos"] != -1)
assert (broadcast_vars[input_name]["broadcast_pos"] <
last_sync_comm_op_idx)
assert (last_sync_comm_op_idx < idx)
return
def check_allreduce_sum(block):
"""
if a Var is allreduced, the op order should be:
- 0: op that generate Var
- 1: sync_calc
- 2: allreduce_sum op
- 3: sync_comm
- 4: op that use Var
"""
var_status = {}
for op in block.ops:
if op.type == "c_allreduce_sum":
var_name = op.desc.input_arg_names()[0]
var_status[var_name] = -1
for op in block.ops:
if op.type == "c_sync_calc_stream":
for var_name in var_status:
if var_name in var_status and var_status[var_name] == 0:
var_status[var_name] = 1
elif op.type == "c_allreduce_sum":
var_name = op.desc.input_arg_names()[0]
if var_status[var_name] == -1:
raise ValueError("{} is not generated, but you are"
"trying to all-reduce it".format(var_name))
if var_status[var_name] == 0:
raise ValueError("There should be a sync_calc op "
"after generate Var: {} and before the"
"c_allreduce_sum op".format(var_name))
assert (var_status[var_name] == 1)
var_status[var_name] = 2
elif op.type == "c_sync_comm_stream":
for var_name in op.desc.input_arg_names():
if var_name in var_status and var_status[var_name] == 2:
var_status[var_name] = 3
else:
for input_name in op.desc.input_arg_names():
if input_name in var_status:
if var_status[input_name] != 3:
raise ValueError("There should be a sync_comm op "
"after allreduce the Var: {}".format(
var_name))
for output_name in op.desc.output_arg_names():
if output_name in var_status and \
var_status[output_name] == -1:
var_status[output_name] = 0
return
def insert_sync_calc_op(block, insert_idx, calc_dep_vars):
"""
_insert_sync_calc_op
"""
op_role = block.ops[insert_idx].attr('op_role')
block._insert_op_without_sync(
insert_idx,
type='c_sync_calc_stream',
inputs={'X': calc_dep_vars},
outputs={'Out': calc_dep_vars},
attrs={OP_ROLE_KEY: op_role})
return
def insert_sync_comm_ops(block, insert_idx, nrings, comm_dep_vars):
"""
_insert_sync_comm_ops
"""
op_role = block.ops[insert_idx].attr('op_role')
for i in range(nrings):
block._insert_op_without_sync(
insert_idx,
type='c_sync_comm_stream',
inputs={'X': comm_dep_vars},
outputs={'Out': comm_dep_vars},
attrs={'ring_id': i,
OP_ROLE_KEY: op_role})
return nrings
def insert_fill_constant_ops(block, insert_idx, fill_constant_vars):
"""
_add_fill_constant_ops
"""
op_role = block.ops[insert_idx].attr('op_role')
for broadcast_name in fill_constant_vars:
broadcast_var = block.var(broadcast_name)
block._insert_op_without_sync(
insert_idx,
type="fill_constant",
outputs={"Out": broadcast_var.name},
attrs={
"shape": broadcast_var.shape,
"dtype": broadcast_var.dtype,
"value": 0.0,
OP_ROLE_KEY: op_role
})
return
def insert_cast_ops(block, insert_idx, cast_ops):
"""
_add_cast_ops
"""
op_role = block.ops[insert_idx].attr('op_role')
for fp16_name, fp32_name in cast_ops.items():
block._insert_op_without_sync(
insert_idx,
type="cast",
inputs={"X": fp32_name},
outputs={"Out": fp16_name},
attrs={
"in_dtype": core.VarDesc.VarType.FP32,
"out_dtype": core.VarDesc.VarType.FP16,
OP_ROLE_KEY: op_role
})
return
def insert_allreduce_ops(block, insert_idx, nrings, allreduce_vars):
"""
_add_allreduce_ops
"""
ring_id = -1
for var in allreduce_vars:
ring_id = (ring_id + 1) % nrings
block._insert_op_without_sync(
insert_idx,
type='c_allreduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
return
def insert_broadcast_ops(block, insert_idx, nrings, broadcast2root):
"""
_add_broadcast_ops
"""
ring_id = -1
op_role = block.ops[insert_idx].attr('op_role')
for broadcast_name, root_device in broadcast2root:
ring_id = (ring_id + 1) % nrings
block._insert_op_without_sync(
insert_idx,
type='c_broadcast',
inputs={'X': broadcast_name},
outputs={'Out': broadcast_name},
attrs={
'ring_id': ring_id,
'root': root_device,
OP_ROLE_KEY: op_role
})
return
DtypeToSize = {
core.VarDesc.VarType.FP16: 2,
core.VarDesc.VarType.FP32: 4,
core.VarDesc.VarType.FP64: 8,
core.VarDesc.VarType.INT16: 2,
core.VarDesc.VarType.INT32: 4,
core.VarDesc.VarType.INT64: 8,
core.VarDesc.VarType.BOOL: 1,
core.VarDesc.VarType.UINT8: 1,
}
def get_var_size(param):
"""
input:
- param: var
return:
var size in MB
"""
assert -1 not in param.shape
return reduce(lambda x, y: x * y,
param.shape) * DtypeToSize[param.dtype] / 1024.0 / 1024.0
def insert_scale_loss_grad_ops(block, scale=1.0):
'''
In order to keep the learning rate consistent across different numbers of
training workers, we scale the loss grad by 1 / worker_num
'''
for idx, op in reversed(list(enumerate(block.ops))):
if is_loss_grad_op(op):
loss_grad_var = block.vars[op.output_arg_names[0]]
block._insert_op_without_sync(
idx + 1,
type='scale',
inputs={'X': loss_grad_var},
outputs={'Out': loss_grad_var},
attrs={'scale': scale,
OP_ROLE_KEY: OpRole.Backward})
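Two bits of arithmetic worth spelling out: get_var_size reports megabytes (element count * element size / 1024 / 1024), and insert_scale_loss_grad_ops multiplies the loss gradient by 1 / worker_num so the summed gradients keep the same magnitude regardless of how many workers join. A plain-Python check of both, assuming 4-byte FP32 elements as in DtypeToSize:

from functools import reduce

FP32_BYTES = 4

def var_size_mb(shape, dtype_bytes=FP32_BYTES):
    # mirrors get_var_size: element count * element size, reported in MB
    return reduce(lambda x, y: x * y, shape) * dtype_bytes / 1024.0 / 1024.0

assert var_size_mb([1024, 1024]) == 4.0

worker_num = 8
loss_grad = 1.0                     # d(loss)/d(loss) before scaling
scaled = loss_grad * (1.0 / worker_num)
assert scaled == 0.125              # each worker contributes 1/8 of the gradient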
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_VAR_KEY
class WeightDecayHelper(object):
def __init__(self):
pass
def _is_weight_decay_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/regularization")
def prune_weight_decay(self, block, shard):
for idx, op in reversed(list(enumerate(block.ops))):
if not self._is_weight_decay_op(op):
continue
if OP_ROLE_VAR_KEY not in op.attr_names:
raise ValueError(
"The Weight Dacay op should hold op_role_var attribute"
"but the {} op does not hold op_role_var".format(op.type))
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if not shard.has_param(op_role_var[0]):
block._remove_op(idx, sync=False)
block._sync_with_cpp()
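prune_weight_decay keeps a regularization op only when the parameter recorded in its op_role_var attribute belongs to this shard; a toy filter expressing the same ownership rule over plain (param, grad) pairs:

# op_role_var of a weight-decay op is a [param_name, grad_name] pair
weight_decay_ops = [
    ("fc_1.w_0", "fc_1.w_0@GRAD"),
    ("fc_2.w_0", "fc_2.w_0@GRAD"),
]
local_params = {"fc_2.w_0"}   # params this worker owns after sharding

kept = [op for op in weight_decay_ops if op[0] in local_params]
assert kept == [("fc_2.w_0", "fc_2.w_0@GRAD")]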
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import unique_name, core
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op
from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper
from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from functools import reduce
__all__ = ["ShardingOptimizer"]
class ShardingOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(ShardingOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
self.meta_optimizers_white_list = [
"RecomputeOptimizer",
"AMPOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
self._main_program = None
self._startup_program = None
self._segments = []
# params and fp16 params are for broadcast
self._params = set([])
self._broadcast_vars = set([])
# reduced grads to param name
self._reduced_grads_to_param = {}
self._shard = Shard()
def _can_apply(self):
if not self.role_maker._is_collective:
return False
if self.role_maker._worker_num() <= 1:
return False
return self.user_defined_strategy.sharding
def _disable_strategy(self, dist_strategy):
dist_strategy.sharding = False
dist_strategy.sharding_configs = {}
def _enable_strategy(self, dist_strategy, context):
dist_strategy.sharding = True
dist_strategy.sharding_configs = {"fuse_broadcast_MB": 32}
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
self._nrings = self.user_defined_strategy.nccl_comm_num
self._fuse_broadcast_MB = self.user_defined_strategy.sharding_configs[
"fuse_broadcast_MB"]
if self.inner_opt is None:
raise ValueError(
"self.inner_opt of ShardingOptimizer should not be None.")
optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
if startup_program is None:
startup_program = default_startup_program()
main_block = loss.block
startup_block = startup_program.global_block()
self._main_program = main_block.program
self._startup_program = startup_program
# step1: set_up
self._set_up(params_grads)
# step2: split_program
self._split_program(main_block)
# step3: add broadcast and reduce ops
self._add_broadcast_allreduce(main_block)
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
# step4: scale the loss grad by 1 / worker_num
insert_scale_loss_grad_ops(
main_block, scale=1.0 / self.role_maker._worker_num())
main_block._sync_with_cpp()
# step5: remove unneeded ops and vars from block
self._prune_main_program(main_block)
self._prune_startup_program(startup_block)
# check op dependency
check_broadcast(main_block)
check_allreduce_sum(main_block)
self._wait()
return optimize_ops, params_grads
def _set_up(self, params_grads):
# step 1: initialize nccl
worker_idx = self.role_maker._worker_index()
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[worker_idx]
self._collective_helper = CollectiveHelper(self.role_maker,
self._nrings)
for ring_id in range(self._nrings):
self._collective_helper._init_communicator(
self._startup_program, current_endpoint, endpoints, worker_idx,
ring_id, None)
startup_block = self._startup_program.global_block()
startup_block._sync_with_cpp()
# step 2: split params
self._params = set([x[0].name for x in params_grads])
self._shard.setup(params_grads, worker_idx,
self.role_maker._worker_num())
# step 3: get broadcast vars
self._broadcast_vars = self._shard.find_broadcast_params(
self._main_program.global_block())
def _wait(self, ):
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[self.role_maker._worker_index()]
if self.role_maker._worker_index() == 0:
self._collective_helper._wait(current_endpoint, endpoints)
def _split_program(self, block):
for op_idx, op in reversed(list(enumerate(block.ops))):
if int(op.attr('op_role')) != int(OpRole.Optimize):
last_backward_op_idx = op_idx + 1
break
segment = ProgramSegment(block)
segment._end_idx = last_backward_op_idx
for op_idx in reversed(range(last_backward_op_idx)):
op = block.ops[op_idx]
assert (int(op.attr('op_role')) != int(OpRole.Optimize))
if segment._param_mem >= self._fuse_broadcast_MB:
segment._start_idx = op_idx + 1
self._segments.insert(0, segment)
segment = ProgramSegment(block)
segment._end_idx = op_idx + 1
# find broadcast vars
for input_name in op.desc.input_arg_names():
if input_name not in self._broadcast_vars:
continue
if input_name in segment._param2broadcast:
# skip the broadcast because it reuses the old broadcast var
broadcast_name = segment._param2broadcast[input_name]
if input_name != broadcast_name:
op._rename_input(input_name, broadcast_name)
continue
if self._shard.has_param(input_name):
broadcast_var_name = input_name
else:
broadcast_var_name = unique_name.generate(input_name +
"@BroadCast")
segment._fill_constant_vars.append(broadcast_var_name)
segment._param2broadcast[input_name] = broadcast_var_name
segment._broadcast_vars.append((broadcast_var_name,
self._shard.device(input_name)))
segment._param_mem += get_var_size(
self._main_program.global_block().var(input_name))
# find reduce vars
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) != 0:
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param, reduced_grad = op_role_var[i], op_role_var[i + 1]
segment._allreduce_vars.append(reduced_grad)
assert (
reduced_grad not in self._reduced_grads_to_param)
self._reduced_grads_to_param[reduced_grad] = param
# find cast op
if FP16Utils.is_fp16_cast_op(block, op, self._params):
fp32_param = op.desc.input_arg_names()[0]
fp16_param = op.desc.output_arg_names()[0]
if self._shard.has_param(fp32_param):
segment._cast_ops[fp16_param] = fp32_param
if segment._param_mem > 0:
segment._start_idx = 0
self._segments.insert(0, segment)
return
def _prune_main_program(self, block):
"""
calculate deps from allreduce ops to optimize ops,
and remove the ops and vars that are not needed on this worker
"""
weightdecay_helper = WeightDecayHelper()
weightdecay_helper.prune_weight_decay(block, self._shard)
FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
self._nrings)
gradientclip_helper = GradientClipHelper()
gradientclip_helper.prune_gradient_clip(block, self._shard)
# build prog deps
reduced_grads = []
for idx, op in enumerate(block.ops):
input_names = op.desc.input_arg_names()
output_names = op.desc.output_arg_names()
if op.type == "c_allreduce_sum":
assert (len(output_names) == 1)
output_name = output_names[0]
reduced_grads.append(output_name)
pruned_opti_vars = []
for var_name in list(block.vars.keys()):
if self._shard.is_opti_var(var_name) and \
not self._shard.has_opt_var(var_name):
pruned_opti_vars.append(var_name)
program_deps = ProgramDeps(block, reduced_grads, pruned_opti_vars)
# Init
for var_name in program_deps._end_vars:
program_deps._should_removed_var.add(var_name)
# Prune
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in [
"c_allreduce_sum", "c_sync_comm_stream",
"c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init"
]:
pass
elif op.type == "conditional_block":
assert (op.desc.has_attr("sub_block"))
subblock_idx = op.desc.attr("sub_block").id
subblock_deps = program_deps.get_sub_block_deps(subblock_idx)
# only prune amp subblock
if subblock_deps is None or not self._is_amp_subblock(op):
continue
# init
reversed_output_vars = []
for output_name in op.desc.output("Out"):
if output_name in program_deps._should_removed_var:
subblock_deps._should_removed_var.add(output_name)
program_deps.crop_output_var_from_op(idx, output_name)
else:
reversed_output_vars.append(output_name)
# prune
for sub_op_idx, _ in reversed(
list(enumerate(subblock_deps._block.ops))):
if subblock_deps.should_remove_op(sub_op_idx):
subblock_deps.remove_op(sub_op_idx)
reversed_input_vars = []
for input_name in op.desc.input('Input'):
if input_name not in subblock_deps._should_removed_var:
reversed_input_vars.append(input_name)
else:
program_deps.crop_input_var_from_op(idx, input_name)
op.desc.set_input('Input', reversed_input_vars)
op.desc.set_output('Out', reversed_output_vars)
else:
if program_deps.should_remove_op(idx):
program_deps.remove_op(idx)
block._sync_with_cpp()
return
def _add_broadcast_allreduce(self, block):
"""
_add_broadcast_allreduce
"""
ring_id = -1
if len(self._segments) < 1:
return
if self._segments[-1]._allreduce_vars:
insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self._nrings,
self._segments[-1]._allreduce_vars)
insert_allreduce_ops(block, self._segments[-1]._end_idx,
self._nrings,
self._segments[-1]._allreduce_vars)
for idx, segment in reversed(list(enumerate(self._segments))):
allreduce_vars = self._segments[
idx - 1]._allreduce_vars if idx > 0 else []
broadcast_vars = self._segments[idx +
1]._broadcast_vars if idx < len(
self._segments) - 1 else []
fill_constant_vars = self._segments[
idx + 2]._fill_constant_vars if idx < len(
self._segments) - 2 else []
cast_ops = self._segments[idx + 2]._cast_ops if idx < len(
self._segments) - 2 else {}
for op_idx in reversed(range(segment._start_idx, segment._end_idx)):
op = block.ops[op_idx]
for input_name in op.desc.input_arg_names():
if input_name in segment._param2broadcast and \
input_name != segment._param2broadcast[input_name]:
op._rename_input(input_name,
segment._param2broadcast[input_name])
for param_name, broadcast_name in segment._param2broadcast.items():
if param_name != broadcast_name:
block.create_var(
name=broadcast_name,
shape=self._main_program.global_block().var(
param_name).shape,
dtype=self._main_program.global_block().var(param_name)
.dtype,
persistable=False)
# step1: remove cast ops
block._sync_with_cpp()
segment._end_idx += FP16Utils.remove_cast_op(block, self._params,
segment, 0)
# step2: add Sync ops
comm_dep_vars = allreduce_vars + [x[0] for x in broadcast_vars]
if len(comm_dep_vars) > 0:
insert_sync_comm_ops(
block,
segment._end_idx,
self._nrings,
comm_dep_vars, )
calc_dep_vars = fill_constant_vars + [
k for k, v in cast_ops.items()
] + self._segments[idx]._allreduce_vars
if len(calc_dep_vars) > 0:
insert_sync_calc_op(block, segment._end_idx,
[calc_dep_vars[-1]])
# step3: insert `fill_constant` ops
insert_fill_constant_ops(block, segment._end_idx,
fill_constant_vars)
# step4: add `cast` ops
insert_cast_ops(block, segment._end_idx, cast_ops)
# step5: add broadcast ops
insert_broadcast_ops(block, segment._start_idx, self._nrings,
broadcast_vars)
# step6: add all_reduce ops
insert_allreduce_ops(block, segment._start_idx, self._nrings,
allreduce_vars)
block._sync_with_cpp()
if self._segments[0]._broadcast_vars:
insert_sync_comm_ops(
block, self._segments[0]._start_idx, self._nrings,
[x[0] for x in self._segments[0]._broadcast_vars])
insert_broadcast_ops(block, self._segments[0]._start_idx,
self._nrings,
self._segments[0]._broadcast_vars)
fill_constant_vars = []
for x in self._segments[:2]:
fill_constant_vars += x._fill_constant_vars
# Join
cast_ops = {}
for x in self._segments[:2]:
for k, v in x._cast_ops.items():
cast_ops[k] = v
calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()]
if fill_constant_vars or cast_ops:
insert_sync_calc_op(block, self._segments[0]._start_idx,
[calc_deps_vars[-1]])
if fill_constant_vars:
insert_fill_constant_ops(block, self._segments[0]._start_idx,
fill_constant_vars)
if cast_ops:
insert_cast_ops(block, self._segments[0]._start_idx, cast_ops)
return
def _prune_startup_program(self, block):
for idx, op in reversed(list(enumerate(block.ops))):
for output_name in op.desc.output_arg_names():
if self._shard.has_var(output_name):
continue
# TODO: why do we remove the whole op when only one of its output vars is pruned?
block._remove_op(idx, sync=False)
break
for var_name in list(block.vars.keys()):
if self._shard.has_var(var_name):
continue
block._remove_var(var_name, sync=False)
block._sync_with_cpp()
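minimize_impl ties the passes together: _split_program cuts the backward section into segments whose broadcast parameters add up to roughly fuse_broadcast_MB, and _add_broadcast_allreduce then emits one fused broadcast/allreduce group per segment. A toy version of the segmenting rule over a flat list of per-op parameter sizes, walking backward like the real pass (the real code also tracks fill_constant vars, cast ops, and reduced grads per segment):

def split_into_segments(param_mb_per_op, fuse_broadcast_mb):
    # param_mb_per_op[i]: MB of broadcast parameters first used by op i
    segments = []            # each segment is a list of op indices
    current, mem = [], 0.0
    for op_idx in reversed(range(len(param_mb_per_op))):
        if mem >= fuse_broadcast_mb and current:
            segments.insert(0, current)
            current, mem = [], 0.0
        current.insert(0, op_idx)
        mem += param_mb_per_op[op_idx]
    if current:
        segments.insert(0, current)
    return segments

# with a 32 MB fuse size, ops carrying 20+20 MB and 30+10 MB land in two segments
print(split_into_segments([20.0, 20.0, 30.0, 10.0], 32.0))
# [[0, 1], [2, 3]]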
......@@ -669,7 +669,7 @@ def append_gradient_clip_ops(param_grads):
if g is None:
continue
with p.block.program._optimized_guard(
[p, g]), framework.name_scope('gradient_clip_@CLIP'):
[p, g]), framework.name_scope('gradient_clip'):
clip_attr = getattr(p, 'gradient_clip_attr', None)
if clip_attr is None:
return param_grads
......@@ -685,7 +685,7 @@ def append_gradient_clip_ops(param_grads):
if g is None:
continue
with p.block.program._optimized_guard(
[p, g]), framework.name_scope('graident_clip_@CLIP'):
[p, g]), framework.name_scope('gradient_clip'):
param, new_grad = clip_attr._create_operators(param=p, grad=g)
param_new_grad_name_dict[param.name] = new_grad.name
res.append([param, new_grad])
......
......@@ -2100,10 +2100,16 @@ class Operator(object):
% (out_proto.name, len(out_args)))
out_arg_names = []
for arg in out_args:
out_arg_names.append(cpt.to_text(arg.name))
if isinstance(arg, six.string_types):
out_arg_names.append(arg)
else:
out_arg_names.append(cpt.to_text(arg.name))
# TODO(minqiyang): could we remove variable's op in static mode?
if not in_dygraph_mode():
arg.op = self
if isinstance(arg, six.string_types):
block.var(arg).op = self
else:
arg.op = self
self.desc.set_output(out_proto.name, out_arg_names)
if op_attrs is not None:
......@@ -2837,8 +2843,9 @@ class Block(object):
self._sync_with_cpp()
return var
def _remove_var(self, name):
self._sync_with_cpp()
def _remove_var(self, name, sync=True):
if sync:
self._sync_with_cpp()
self.desc._remove_var(cpt.to_bytes(name))
del self.vars[name]
......@@ -2936,7 +2943,23 @@ class Block(object):
self.ops.insert(index, op)
return op
def _remove_op(self, index):
def _insert_op_without_sync(self, index, *args, **kwargs):
"""
Insert an Operator according to the given arguments,
without calling _sync_with_cpp(), to make compilation faster.
Args:
index(int): the index at which to insert the operator.
Returns:
Operator: the inserted Operator.
"""
op_desc = self.desc._insert_op(index)
op = Operator(block=self, desc=op_desc, *args, **kwargs)
self.ops.insert(index, op)
return op
def _remove_op(self, index, sync=True):
"""
Remove the specific position operator.
......@@ -2946,7 +2969,8 @@ class Block(object):
Returns:
None
"""
self._sync_with_cpp()
if sync:
self._sync_with_cpp()
self.desc._remove_op(index, index + 1)
del self.ops[index]
......
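The sync flag added to _remove_op/_remove_var and the new _insert_op_without_sync let a pass that edits many ops batch its Python-side bookkeeping and reconcile with the C++ desc once at the end. A minimal sketch of the intended calling pattern, using a throwaway two-constant program built with fluid in static-graph mode:

import paddle
import paddle.fluid as fluid

paddle.enable_static()
prog = fluid.Program()
with fluid.program_guard(prog):
    a = fluid.layers.fill_constant(shape=[2, 2], dtype="float32", value=1.0)
    b = fluid.layers.fill_constant(shape=[2, 2], dtype="float32", value=2.0)
    c = a + b   # adds an elementwise_add op to the block

block = prog.global_block()
n_ops = len(block.ops)
# drop the last op without syncing after the edit ...
block._remove_op(n_ops - 1, sync=False)
# ... then reconcile the Python-side op list with the C++ desc once
block._sync_with_cpp()
assert len(block.ops) == n_ops - 1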
......@@ -41,6 +41,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_recompute_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_pipeline_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_sharding_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_localsgd_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_lars_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_lamb_meta_optimizer)
......@@ -461,6 +462,7 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_fleet_recompute_meta_optimizer MODULES test_fleet_recompute_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_graph_executor MODULES test_fleet_graph_executor ENVS ${dist_ENVS})
py_test_modules(test_fleet_gradient_merge_meta_optimizer MODULES test_fleet_gradient_merge_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_sharding_meta_optimizer MODULES test_fleet_sharding_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_amp_meta_optimizer MODULES test_fleet_amp_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_fp16_allreduce_meta_optimizer MODULES test_fleet_fp16_allreduce_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_pipeline_meta_optimizer MODULES test_fleet_pipeline_meta_optimizer ENVS ${dist_ENVS})
......
......@@ -55,14 +55,22 @@ class TestFleetMetaOptimizer(unittest.TestCase):
strategy,
train_prog,
startup_prog,
name='momentum'):
name='momentum',
regularization=None,
grad_clip=None):
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard():
if name == 'momentum':
optimizer = paddle.fluid.optimizer.Momentum(
learning_rate=0.01, momentum=0.9)
learning_rate=0.01,
momentum=0.9,
regularization=regularization,
grad_clip=grad_clip)
elif name == 'adam':
optimizer = paddle.fluid.optimizer.Adam(learning_rate=0.01)
optimizer = paddle.fluid.optimizer.Adam(
learning_rate=0.01,
regularization=regularization,
grad_clip=grad_clip)
optimizer = fleet.distributed_optimizer(
optimizer, strategy=strategy)
optimizer.minimize(loss)
......@@ -121,5 +129,8 @@ class TestFleetMetaOptimizer(unittest.TestCase):
elif name == "gradient_merge":
strategy.gradient_merge = True
strategy.gradient_merge_configs = {"k_steps": 2, "avg": True}
elif name == "sharding":
strategy.sharding = True
strategy.sharding_configs = {"fuse_broadcast_MB": 0.2}
else:
raise NotImplementedError()
......@@ -32,9 +32,6 @@ class TestFleetGradientMergeMetaOptimizer(TestFleetMetaOptimizer):
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
vars = [x.name for x in train_prog.list_vars()]
with open("main_program", 'w') as f:
f.write(str(train_prog))
self.assertIn('@GradientMerge', ''.join(vars))
def test_recom_gm_optimizer(self):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import os
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
from fleet_meta_optimizer_base import TestFleetMetaOptimizer
paddle.enable_static()
class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
def test_sharding_optimizer(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.net(train_prog, startup_prog)
self.set_strategy(strategy, 'sharding')
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
parameters = [
x.name for x in train_prog.list_vars() if x.persistable == True
]
ops = [op.type for op in avg_cost.block.ops]
vars = [x.name for x in train_prog.list_vars()]
self.assertIn('@BroadCast', ''.join(vars))
self.assertEqual(
set(parameters),
set([
"fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
]))
self.assertEqual(ops, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_sync_comm_stream', 'momentum', 'momentum', 'momentum'
])
def test_sharding_amp_optimizer(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.net(train_prog, startup_prog)
self.set_strategy(strategy, 'sharding')
self.set_strategy(strategy, 'amp')
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
ops = [op.type for op in avg_cost.block.ops]
vars = [x.name for x in train_prog.list_vars()]
parameters = [
x.name for x in train_prog.list_vars() if x.persistable == True
]
self.assertIn('@BroadCast', ''.join(vars))
self.assertIn('cast', ops)
self.assertIn('check_finite_and_unscale', ops)
self.assertEqual(
set(parameters),
set([
"fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0",
"loss_scaling_0", "num_bad_steps_0", "num_good_steps_0"
]))
self.assertEqual(ops, [
'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
'fill_constant', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_sync_comm_stream', 'cast', 'mul', 'elementwise_add', 'cast',
'tanh', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast',
'mul', 'elementwise_add', 'softmax', 'cast', 'cross_entropy2',
'mean', 'elementwise_mul', 'fill_constant', 'scale',
'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast',
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast',
'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'cast',
'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad',
'c_sync_calc_stream', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_sync_comm_stream', 'cast', 'cast', 'cast',
'check_finite_and_unscale', 'cast', 'c_sync_calc_stream',
'c_allreduce_max', 'c_sync_comm_stream', 'cast',
'update_loss_scaling', 'momentum', 'momentum', 'momentum'
])
def test_sharding_recompute_optimizer(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.net(train_prog, startup_prog)
self.set_strategy(strategy, 'sharding')
self.set_strategy(strategy, 'recompute')
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
ops = [op.type for op in avg_cost.block.ops]
vars = [x.name for x in train_prog.list_vars()]
parameters = [
x.name for x in train_prog.list_vars() if x.persistable == True
]
self.assertIn('@BroadCast', ''.join(vars))
self.assertIn('subprog', ''.join(vars))
self.assertEqual(
set(parameters),
set([
"fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
]))
self.assertEqual(ops, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'mul',
'elementwise_add', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
'mul', 'elementwise_add', 'tanh_grad', 'elementwise_add_grad',
'mul_grad', 'c_sync_calc_stream', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_sync_comm_stream',
'momentum', 'momentum', 'momentum'
])
def test_sharding_amp_recompute_optimizer(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.net(train_prog, startup_prog)
self.set_strategy(strategy, 'sharding')
self.set_strategy(strategy, 'recompute')
self.set_strategy(strategy, 'amp')
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
ops = [op.type for op in avg_cost.block.ops]
vars = [x.name for x in train_prog.list_vars()]
parameters = [
x.name for x in train_prog.list_vars() if x.persistable == True
]
self.assertIn('@BroadCast', ''.join(vars))
self.assertIn('subprog', ''.join(vars))
self.assertIn('cast', ops)
self.assertIn('check_finite_and_unscale', ops)
self.assertEqual(
set(parameters),
set([
"fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0",
"loss_scaling_0", "num_bad_steps_0", "num_good_steps_0"
]))
self.assertEqual(ops, [
'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', 'cast', 'cast',
'mul', 'cast', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
'elementwise_add', 'cast', 'tanh', 'cast', 'mul', 'elementwise_add',
'softmax', 'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
'fill_constant', 'scale', 'elementwise_mul_grad', 'mean_grad',
'cross_entropy_grad2', 'cast', 'softmax_grad',
'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_sync_comm_stream', 'cast', 'cast', 'cast',
'check_finite_and_unscale', 'cast', 'c_sync_calc_stream',
'c_allreduce_max', 'c_sync_comm_stream', 'cast',
'update_loss_scaling', 'momentum', 'momentum', 'momentum'
])
def test_sharding_weight_decay(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.net(train_prog, startup_prog)
self.set_strategy(strategy, 'sharding')
regularization = paddle.fluid.regularizer.L2Decay(0.0001)
self.optimizer(
avg_cost,
strategy,
train_prog,
startup_prog,
regularization=regularization)
parameters = [
x.name for x in train_prog.list_vars() if x.persistable == True
]
ops = [op.type for op in avg_cost.block.ops]
vars = [x.name for x in train_prog.list_vars()]
self.assertIn('@BroadCast', ''.join(vars))
self.assertEqual(
set(parameters),
set([
"fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
]))
self.assertEqual(ops, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_sync_comm_stream', 'scale', 'sum', 'scale', 'sum', 'scale',
'sum', 'momentum', 'momentum', 'momentum'
])
def test_sharding_gradient_clip(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
)
avg_cost, strategy = self.net(train_prog, startup_prog)
self.set_strategy(strategy, 'sharding')
clip = paddle.fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
self.optimizer(
avg_cost, strategy, train_prog, startup_prog, grad_clip=clip)
parameters = [
x.name for x in train_prog.list_vars() if x.persistable == True
]
ops = [op.type for op in avg_cost.block.ops]
vars = [x.name for x in train_prog.list_vars()]
self.assertIn('@BroadCast', ''.join(vars))
self.assertEqual(
set(parameters),
set([
"fc_1.b_0", "fc_2.b_0", "fc_2.w_0", "fc_1.b_0_velocity_0",
"fc_2.b_0_velocity_0", "fc_2.w_0_velocity_0", "learning_rate_0"
]))
self.assertEqual(ops, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'mul', 'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'tanh',
'mul', 'elementwise_add', 'softmax', 'cross_entropy2', 'mean',
'fill_constant', 'scale', 'mean_grad', 'cross_entropy_grad2',
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'tanh_grad',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_sync_comm_stream', 'square', 'reduce_sum', 'square',
'reduce_sum', 'square', 'reduce_sum', 'sum', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_sync_comm_stream', 'sqrt', 'fill_constant',
'elementwise_max', 'elementwise_div', 'elementwise_mul',
'elementwise_mul', 'elementwise_mul', 'momentum', 'momentum',
'momentum'
])
if __name__ == "__main__":
unittest.main()
......@@ -148,6 +148,7 @@ packages=['paddle',
'paddle.distributed.fleet',
'paddle.distributed.fleet.base',
'paddle.distributed.fleet.meta_optimizers',
'paddle.distributed.fleet.meta_optimizers.sharding',
'paddle.distributed.fleet.runtime',
'paddle.distributed.fleet.dataset',
'paddle.distributed.fleet.data_generator',
......