Unverified commit 78465703, authored by WangXi, committed by GitHub

[hybrid] outer data parallel as optimizer sharding parallel (#35593)

Parent 2b88057f
@@ -43,6 +43,8 @@ message ShardingConfig {
  optional bool pp_allreduce_in_optimize = 10 [ default = false ];
  optional int32 pp_degree = 11 [ default = 1 ];
  optional bool optimize_cast = 12 [ default = false ];
+ // Optimizer sharding. Temporary plans and may be deprecated
+ optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
}

message HybridConfig {
......
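The new `_dp_as_optimizer_sharding` switch is consumed by the sharding meta optimizer below; it only takes effect when sharding_degree == 1, dp_degree > 1 and pp_degree > 1. A minimal usage sketch, not part of this commit and assuming the usual fleet `sharding_configs` keys (`sharding_degree`, `mp_degree`, `pp_degree`, `dp_degree`) plus a pipeline setup like the one added to the test base further down:

# Hypothetical example of turning the flag on via fleet's DistributedStrategy.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "mp_degree": 1,
    "sharding_degree": 1,               # no sharding group ...
    "pp_degree": 2,                     # ... pipeline parallel is required ...
    "dp_degree": 2,                     # ... and an outer data-parallel group
    "_dp_as_optimizer_sharding": True,  # shard optimizer states/updates over dp
}
strategy.pipeline = True
strategy.pipeline_configs = {
    "schedule_mode": "1F1B",
    "micro_batch_size": 2,
    "accumulate_steps": 4,
}
# optimizer = fleet.distributed_optimizer(optimizer, strategy)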
@@ -26,27 +26,28 @@ class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel {
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext* ctx) const override {
-   OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X",
-                  "check_finite_and_unscale");
-   OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
-                  "check_finite_and_unscale");
-   PADDLE_ENFORCE_EQ(
-       ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
-       platform::errors::InvalidArgument(
-           "The input(X) and output(Out) should have same size in "
-           "Operator(check_finite_and_unscale), size of input(X) is %d "
-           "and size of output(Out) is %d.",
-           ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
-   auto x_dims = ctx->GetInputsDim("X");
-   ctx->SetOutputsDim("Out", x_dims);
+   if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {
+     PADDLE_ENFORCE_EQ(
+         ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
+         platform::errors::InvalidArgument(
+             "The input(X) and output(Out) should have same size in "
+             "Operator(check_finite_and_unscale), size of input(X) is %d "
+             "and size of output(Out) is %d.",
+             ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
+     auto x_dims = ctx->GetInputsDim("X");
+     ctx->SetOutputsDim("Out", x_dims);
+   }
    ctx->SetOutputDim("FoundInfinite", {1});
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-   return framework::OpKernelType(
-       OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+   auto dtype = framework::proto::VarType::FP32;
+   if (ctx.MultiInputVar("X").size() >= 1) {
+     dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+   }
+   return framework::OpKernelType(dtype, ctx.GetPlace());
  }
};
......
@@ -97,6 +97,8 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
                               scale_data, inverse_scale_v, found_inf_data);

    size_t xs_size = xs.size();
+   if (xs_size == 0) return;
+
    const auto& cpu_place = platform::CPUPlace();
    // calculate each tensor's start index and copy to device
    auto h_starts_tensor =
......
@@ -26,7 +26,6 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
-   OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasInput("FoundInfinite"), "Input", "FoundInfinite",
                   "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasInput("PrevLossScaling"), "Input", "PrevLossScaling",
@@ -35,16 +34,25 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
                   "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasInput("InBadSteps"), "Input", "InBadSteps",
                   "update_loss_scaling");
-   OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
-                  "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasOutput("LossScaling"), "Output", "LossScaling",
                   "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasOutput("OutGoodSteps"), "Output", "OutGoodSteps",
                   "update_loss_scaling");
    OP_INOUT_CHECK(ctx->HasOutput("OutBadSteps"), "Output", "OutBadSteps",
                   "update_loss_scaling");
-   auto x_dims = ctx->GetInputsDim("X");
-   ctx->SetOutputsDim("Out", x_dims);
+   if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {
+     PADDLE_ENFORCE_EQ(
+         ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
+         platform::errors::InvalidArgument(
+             "The input(X) and output(Out) should have same size in "
+             "Operator(update_loss_scaling), size of input(X) is %d "
+             "and size of output(Out) is %d.",
+             ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
+     auto x_dims = ctx->GetInputsDim("X");
+     ctx->SetOutputsDim("Out", x_dims);
+   }
    ctx->SetOutputDim("LossScaling", {1});
    ctx->SetOutputDim("OutGoodSteps", {1});
    ctx->SetOutputDim("OutBadSteps", {1});
@@ -53,8 +61,12 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-   return framework::OpKernelType(
-       OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+   auto dtype = framework::proto::VarType::FP32;
+   if (ctx.MultiInputVar("X").size() >= 1) {
+     dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+   }
+   return framework::OpKernelType(dtype, ctx.GetPlace());
  }
};
......
@@ -95,6 +95,8 @@ class LazyZeros<platform::CUDADeviceContext, T> {
                  const std::vector<const framework::Tensor*>& xs,
                  const std::vector<framework::Tensor*>& outs) const {
    size_t xs_size = xs.size();
+   if (xs_size == 0) return;
+
    const auto& cpu_place = platform::CPUPlace();
    // alloc each tensor's start index and copy to device
    auto h_in_starts_mem =
......
@@ -105,7 +105,6 @@ class FP16Utils(object):
            if op.type == "update_loss_scaling":
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
-               op._rename_input(inf_var_name, inf_var_name + "@sharding")
            if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                reversed_x = []
                reversed_x_paramname = []
@@ -142,10 +141,6 @@ class FP16Utils(object):
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32)
-       inf_var_sharding = block.create_var(
-           name=inf_var_name + "@sharding",
-           shape=inf_var.shape,
-           dtype=inf_var.dtype)

        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
@@ -179,10 +174,10 @@ class FP16Utils(object):
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var_int32},
-           outputs={'Out': inf_var_sharding},
+           outputs={'Out': inf_var},
            attrs={
                "in_dtype": inf_var_int32.dtype,
-               "out_dtype": inf_var_sharding.dtype,
+               "out_dtype": inf_var.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        update_loss_scaling_op_idx += 1
@@ -210,10 +205,6 @@ class FP16Utils(object):
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32)
-       inf_var_global = block.create_var(
-           name=inf_var_name + "@GLOBAL_WORLD",
-           shape=inf_var.shape,
-           dtype=inf_var.dtype)
        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
......
@@ -39,6 +39,7 @@ class GradientClipHelper(object):
            if not self._is_gradient_clip_op(op):
                continue
            if op.type == "sum":
+               global_norm_sum_op_idx = idx
                continue
            deperate_op = False
            for input_name in op.desc.input_arg_names():
@@ -61,7 +62,10 @@ class GradientClipHelper(object):
                    if output_name not in op.desc.input_arg_names():
                        deperated_vars.add(output_name)

-       if not deperated_vars:
+       # NOTE(wangxi): If only have 2 sharding, and 1 param.
+       # sharding 0 will not deperated_vars, will return, only
+       # sharding 1 will insert allreduce, then hang.
+       if not deperated_vars and global_norm_sum_op_idx == -1:
            # got no gradient_clip op
            return
@@ -71,8 +75,8 @@ class GradientClipHelper(object):
            if idx in deperate_op_idx:
                block._remove_op(idx, sync=False)
                continue
+           reversed_inputs = []
            if op.type == "sum":
-               reversed_inputs = []
                global_norm_sum_op_idx = idx
                for input_name in op.desc.input_arg_names():
                    if input_name not in deperated_vars:
@@ -82,6 +86,28 @@ class GradientClipHelper(object):
                assert (len(op.desc.output_arg_names()) == 1)
                sum_res = op.desc.output_arg_names()[0]

+               # NOTE(wangxi): If we have 2 param, but sharding is 4,
+               # then the sum op in some cards will not have input.
+               # So we use fill_constant_op to set `sum_var` to zero,
+               # which does not affect correctness.
+               if len(reversed_inputs) == 0:
+                   sum_var = block.var(sum_res)
+                   namescope = op.attr("op_namescope")
+
+                   block._remove_op(idx, sync=False)
+                   op = block._insert_op_without_sync(
+                       idx,
+                       type='fill_constant',
+                       inputs={},
+                       outputs={'Out': sum_res},
+                       attrs={
+                           'shape': sum_var.shape,
+                           'dtype': sum_var.dtype,
+                           'value': 0.0,
+                           OP_ROLE_KEY: OpRole.Optimize
+                       })
+                   op._set_attr('op_namescope', namescope)
+
                # allreduce(mp)->allreduce(sharding)->allreduce(pp)
                idx_offset = 1
                for ring_id in ring_ids:
......
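A tiny numeric illustration (not from the patch) of why the fill_constant(0.0) replacement above is safe for clip-by-global-norm: the allreduced sum of squared norms is unchanged when a rank that owns no clipped parameters contributes zero, while every rank still issues the same collective and none of them hangs.

# Made-up values, for illustration only.
local_sum_sq = [0.09, 0.16, 0.0]        # rank 2 owns no clip params -> fill 0.0
global_norm = sum(local_sum_sq) ** 0.5  # what c_allreduce_sum + sqrt computes
assert abs(global_norm - 0.5) < 1e-12   # identical to the two-owning-ranks result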
@@ -117,21 +117,28 @@ class ProgramDeps(object):
                    var_name] == []:
                self._block._remove_var(var_name, sync=False)

-   def remove_op(self, op_idx):
+   def remove_op(self, op_idx, reserved_vars=None):
        # update deps
        op = self._block.ops[op_idx]
        for input_name in op.desc.input_arg_names():
+           if reserved_vars is not None and input_name in reserved_vars:
+               continue
            self.crop_input_var_from_op(op_idx, input_name)
        for output_name in op.desc.output_arg_names():
+           if reserved_vars is not None and output_name in reserved_vars:
+               continue
            self.crop_output_var_from_op(op_idx, output_name)
        self._block._remove_op(op_idx, sync=False)

    def should_remove_op(self, op_idx):
        op = self._block.ops[op_idx]
-       # TODO (JZ-LIANG) revise this for uniform mixed parallelism
-       # remove check_finite_and_unscale op if its input 'X' is empty
-       if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
-           return True
+       # NOTE: At present, it is found that the OP without output is
+       # only send_v2 and partial_send op, which will be used in
+       # all device
+       if len(op.desc.output_arg_names()) == 0:
+           return False
        for output_name in op.desc.output_arg_names():
            if output_name not in self._should_removed_var:
                return False
......
@@ -24,7 +24,8 @@ class Shard(object):
        self.global_params = set([])
        self.worker_idx = -1
        self.worker_num = -1
-       self.global_param2device = {}
+       self.global_param2device = dict()
+       self.device2global_params = dict()

    def setup(self, params_grads, worker_idx, worker_num):
        # param names of all devices
@@ -33,8 +34,9 @@ class Shard(object):
        self.worker_idx = worker_idx
        self.worker_num = worker_num
        # global_param2device contains fp32 params and fp16 params
-       self.global_param2device = self._split_params(params_grads, worker_idx,
-                                                     worker_num)
+       # device2global_params only contains fp32 params
+       self.global_param2device, self.device2global_params \
+           = self._split_params(params_grads, worker_idx, worker_num)

    def has_param(self, var_name):
        return var_name in self.global_param2device and \
@@ -64,7 +66,7 @@ class Shard(object):
                device2params[device_idx].append(param_name)
                param2device[param_name] = device_idx
                mem_accu += mem
-       return param2device
+       return param2device, device2params

    def _var_device_id(self, var_name):
        if var_name in self.global_param2device:
......
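For reference, a self-contained sketch of the greedy split that `_split_params` performs, now returning the mapping in both directions as in the hunk above. The function name and the `(name, size)` input format are illustrative assumptions, not the real signature:

def split_params(param_sizes, worker_num):
    """param_sizes: [(param_name, size_in_bytes)] -> (param2device, device2params)."""
    total = sum(size for _, size in param_sizes)
    mem_per_device = total / worker_num
    device2params = [[] for _ in range(worker_num)]
    param2device = {}
    device_idx, mem_accu = 0, 0.0
    for name, size in param_sizes:
        # move to the next device once this one holds roughly its share
        if mem_accu + size > (device_idx + 1) * mem_per_device \
                and device_idx < worker_num - 1:
            device_idx += 1
        device2params[device_idx].append(name)
        param2device[name] = device_idx
        mem_accu += size
    return param2device, device2params

# split_params([("w0", 4e6), ("w1", 4e6), ("b0", 1e4)], 2)
# -> ({'w0': 0, 'w1': 1, 'b0': 1}, [['w0'], ['w1', 'b0']])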
@@ -365,6 +365,65 @@ def insert_allreduce_ops(block,
    return


+class FuseHelper(object):
+    @staticmethod
+    def get_fused_groups(block, vars_name, fuse_size=32.):
+        """ coalesce tensor, get fused group """
+        groups = []
+        cur_size = 0.
+        last_dtype = None
+        for var_name in vars_name:
+            real_var = block.var(var_name)
+            var_size = get_var_size(real_var)
+            if cur_size + var_size > fuse_size \
+                    or len(groups) == 0 \
+                    or real_var.dtype != last_dtype:
+                groups.append([real_var])
+                cur_size = var_size
+                last_dtype = real_var.dtype
+            else:
+                groups[-1].append(real_var)
+                cur_size += var_size
+        return groups
+
+    @staticmethod
+    def insert_coalesce_tensor(block,
+                               index,
+                               groups,
+                               op_role=OpRole.Backward,
+                               prefix="Output"):
+        fused_vars = []
+        insert_num = 0
+        for group in groups:
+            assert len(group) >= 1
+            if len(group) == 1:
+                # no need fuse
+                fused_vars.append(group[0])
+                continue
+
+            fused_var = block.create_var(
+                name=unique_name.generate('Fused{}_{}'.format(prefix, group[0]
+                                                              .name)),
+                dtype=group[0].dtype,
+                persistable=False,
+                stop_gradient=True)
+            fused_vars.append(fused_var)
+            block._insert_op_without_sync(
+                index,
+                type="coalesce_tensor",
+                inputs={"Input": group},
+                outputs={"Output": group,
+                         "FusedOutput": fused_var},
+                attrs={
+                    "copy_data": True,
+                    "use_align": True,
+                    "dtype": group[0].dtype,
+                    OP_ROLE_KEY: op_role
+                })
+            insert_num += 1
+        return fused_vars, insert_num
+
+
def insert_fused_allreduce_ops(block,
                               insert_idx,
                               ring_id,
@@ -372,46 +431,15 @@ def insert_fused_allreduce_ops(block,
                               op_role=OpRole.Backward,
                               use_calc_stream=False,
                               fuse_grad_size_in_MB=32):
-   segments = []
-   cur_size = 0.
-   last_dtype = None
-   for var in allreduce_vars:
-       real_var = block.var(var)
-       var_size = get_var_size(real_var)
-       if cur_size + var_size > fuse_grad_size_in_MB \
-               or len(segments) == 0 \
-               or real_var.dtype != last_dtype:
-           segments.append([real_var])
-           cur_size = var_size
-           last_dtype = real_var.dtype
-       else:
-           segments[-1].append(real_var)
-           cur_size += var_size
-
-   fused_vars = []
-   for segment in segments:
-       tmp_var = block.create_var(
-           name=unique_name.generate('FusedOutput_{}'.format(segment[0].name)),
-           dtype=segment[0].dtype,
-           persistable=False,
-           stop_gradient=True)
-       fused_vars.append(tmp_var)
-       block._insert_op_without_sync(
-           insert_idx,
-           type="coalesce_tensor",
-           inputs={"Input": segment},
-           outputs={"Output": segment,
-                    "FusedOutput": tmp_var},
-           attrs={
-               "copy_data": True,
-               "use_align": True,
-               "dtype": segment[0].dtype,
-               OP_ROLE_KEY: op_role
-           })
+   groups = FuseHelper.get_fused_groups(block, allreduce_vars,
+                                        fuse_grad_size_in_MB)
+
+   fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
+       block, insert_idx, groups, op_role, prefix="Grad")

    for fused_var in fused_vars:
        block._insert_op_without_sync(
-           insert_idx + len(fused_vars),
+           insert_idx + insert_num,
            type='c_allreduce_sum',
            inputs={'X': fused_var},
            outputs={'Out': fused_var},
@@ -422,13 +450,61 @@ def insert_fused_allreduce_ops(block,
            })
        if not use_calc_stream:
            block._insert_op_without_sync(
-               insert_idx + len(fused_vars),
+               insert_idx + insert_num,
                type='c_sync_calc_stream',
                inputs={'X': fused_var},
                outputs={'Out': fused_var},
                attrs={OP_ROLE_KEY: op_role})


+def insert_fused_reduce_ops(block,
+                            insert_idx,
+                            ring_id,
+                            reduce_vars,
+                            shard,
+                            op_role=OpRole.Backward,
+                            use_calc_stream=False,
+                            rank=None,
+                            fuse_grad_size=32):
+    nranks = shard.worker_num
+    device_to_vars = [[] for _ in range(nranks)]
+
+    for var in reduce_vars:
+        root_id = get_grad_device(var, shard)
+        assert 0 <= root_id < nranks, "root_id should >=0 and < nranks, " \
+            "but now nranks={}, the root_id of var={} is {}"\
+            .format(nranks, var, root_id)
+        device_to_vars[root_id].append(var)
+
+    for root_id, vars_name in enumerate(device_to_vars):
+        groups = FuseHelper.get_fused_groups(block, vars_name, fuse_grad_size)
+
+        fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
+            block, insert_idx, groups, op_role, prefix="Grad")
+
+        for fused_var in fused_vars:
+            block._insert_op_without_sync(
+                insert_idx + insert_num,
+                type='c_reduce_sum',
+                inputs={'X': fused_var},
+                outputs={'Out': fused_var},
+                attrs={
+                    'ring_id': ring_id,
+                    'root_id': root_id,
+                    'use_calc_stream': use_calc_stream,
+                    OP_ROLE_KEY: op_role
+                })
+            if not use_calc_stream:
+                block._insert_op_without_sync(
+                    insert_idx + insert_num,
+                    type='c_sync_calc_stream',
+                    inputs={'X': fused_var},
+                    outputs={'Out': fused_var},
+                    attrs={OP_ROLE_KEY: op_role})
+
+    return [] if rank is None else device_to_vars[rank]
+
+
def insert_reduce_ops(block,
                      insert_idx,
                      ring_id,
@@ -436,14 +512,26 @@ def insert_reduce_ops(block,
                      shard,
                      op_role=OpRole.Backward,
                      use_calc_stream=False,
-                     rank=None):
+                     rank=None,
+                     strategy=None):
    """
-   _add_allreduce_ops
+   _add_reduce_ops
    """
+   if strategy and strategy.fuse_all_reduce_ops and \
+           not strategy.fuse_grad_merge:
+       return insert_fused_reduce_ops(block, insert_idx, ring_id, reduce_vars,
+                                      shard, op_role, use_calc_stream, rank,
+                                      strategy.fuse_grad_size_in_MB)
+
    grad_in_this_device = []
    for var in reduce_vars:
-       root_id = get_grad_device(var, shard)
+       grad_var = var
+       if strategy and strategy.fuse_all_reduce_ops and \
+               strategy.fuse_grad_merge:
+           # TODO(wangxi): if support fp16_allreduce, need be
+           # 'FusedMergedGrad.cast_fp16._'
+           grad_var = var.replace('FusedMergedGrad_', '')
+       root_id = get_grad_device(grad_var, shard)
        assert root_id >= 0, "root id should be a positive int, but now root id is {}".format(
            root_id)
        if rank is not None and rank == root_id:
@@ -463,6 +551,94 @@ def insert_reduce_ops(block,
    return grad_in_this_device


+def insert_fused_broadcast_param_ops(block,
+                                     insert_idx,
+                                     ring_id,
+                                     params,
+                                     shard,
+                                     op_role=OpRole.Optimize,
+                                     use_calc_stream=False,
+                                     rank=None,
+                                     fuse_size=32):
+    nranks = shard.worker_num
+    device_to_vars = [[] for _ in range(nranks)]
+
+    for var in params:
+        root_id = shard.device(var)
+        assert 0 <= root_id < nranks, "root_id should >=0 and < nranks, " \
+            "but now nranks={}, the root_id of var={} is {}"\
+            .format(nranks, var, root_id)
+        device_to_vars[root_id].append(var)
+
+    for root_id, vars_name in enumerate(device_to_vars):
+        groups = FuseHelper.get_fused_groups(block, vars_name, fuse_size)
+
+        fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
+            block, insert_idx, groups, op_role, prefix="Param")
+
+        for fused_var in fused_vars:
+            block._insert_op_without_sync(
+                insert_idx + insert_num,
+                type='c_broadcast',
+                inputs={'X': fused_var},
+                outputs={'Out': fused_var},
+                attrs={
+                    'ring_id': ring_id,
+                    'root': root_id,
+                    'use_calc_stream': use_calc_stream,
+                    OP_ROLE_KEY: op_role
+                })
+            if not use_calc_stream:
+                block._insert_op_without_sync(
+                    insert_idx + insert_num,
+                    type='c_sync_calc_stream',
+                    inputs={'X': fused_var},
+                    outputs={'Out': fused_var},
+                    attrs={OP_ROLE_KEY: op_role})
+
+    return [] if rank is None else device_to_vars[rank]
+
+
+def insert_broadcast_param_ops(block,
+                               insert_idx,
+                               ring_id,
+                               params,
+                               shard,
+                               op_role=OpRole.Optimize,
+                               use_calc_stream=False,
+                               rank=None,
+                               strategy=None):
+    """
+    add broadcast param ops
+    """
+    if strategy and strategy.fuse_all_reduce_ops:
+        # TODO(wangxi): put fused var in startup_program, only need exec once
+        return insert_fused_broadcast_param_ops(
+            block, insert_idx, ring_id, params, shard, op_role, use_calc_stream,
+            rank, strategy.fuse_grad_size_in_MB)
+
+    param_in_this_device = []
+    for param in params:
+        root_id = shard.device(param)
+        assert root_id >= 0, "root id should be a positive int, but now root id is {}".format(
+            root_id)
+        if rank is not None and rank == root_id:
+            param_in_this_device.append(param)
+        block._insert_op_without_sync(
+            insert_idx,
+            type='c_broadcast',
+            inputs={'X': param},
+            outputs={'Out': param},
+            attrs={
+                'ring_id': ring_id,
+                'root': root_id,
+                'use_calc_stream': use_calc_stream,
+                OP_ROLE_KEY: op_role
+            })
+
+    return param_in_this_device
+
+
def get_grad_device(grad_name, shard):
    assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
        grad_name)
......
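The new helpers above share two bookkeeping steps: grouping variables into same-dtype segments of roughly fuse_size MB before `coalesce_tensor`, and bucketing variables by the rank that owns them before `c_reduce_sum`/`c_broadcast`. A stripped-down sketch with plain tuples standing in for block variables (an illustration of the rule, not the real block/var API):

def get_fused_groups(vars_info, fuse_size=32.0):
    """vars_info: [(name, size_in_MB, dtype)]; mirrors FuseHelper.get_fused_groups."""
    groups, cur_size, last_dtype = [], 0.0, None
    for name, size, dtype in vars_info:
        if cur_size + size > fuse_size or not groups or dtype != last_dtype:
            groups.append([name])          # start a new fused segment
            cur_size, last_dtype = size, dtype
        else:
            groups[-1].append(name)
            cur_size += size
    return groups

def partition_by_owner(var_names, owner_of, nranks):
    """Bucket vars by owning rank, as the fused reduce/broadcast helpers do per root_id."""
    device_to_vars = [[] for _ in range(nranks)]
    for name in var_names:
        root_id = owner_of(name)
        assert 0 <= root_id < nranks
        device_to_vars[root_id].append(name)
    return device_to_vars

# get_fused_groups([("a", 20, "fp32"), ("b", 20, "fp32"), ("c", 4, "fp16")])
# -> [['a'], ['b'], ['c']]   # a+b would exceed 32 MB; c has a different dtype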
@@ -15,19 +15,22 @@
import paddle
from paddle.fluid import unique_name, core
import paddle.fluid as fluid
-from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
-from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
-from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
-from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
-from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
-from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper
-from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
-from .sharding.offload_helper import OffloadHelper
-from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
-from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
-from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
+from paddle.static import default_startup_program, device_guard
from paddle.fluid import layers
+
+from .common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
+from .common import is_backward_op, is_optimizer_op, is_update_op
+from .meta_optimizer_base import MetaOptimizerBase
+from .sharding.shard import Shard, ProgramSegment
+from .sharding.fp16_helper import FP16Utils
+from .sharding.weight_decay_helper import WeightDecayHelper
+from .sharding.gradient_clip_helper import GradientClipHelper
+from .sharding.offload_helper import OffloadHelper
+from .sharding.prune import ProgramDeps
+from .sharding import utils
+# FIXME: import *
+from .sharding.utils import *

import logging

logger = logging.getLogger(__name__)

formatter = logging.Formatter(
@@ -35,7 +38,6 @@ formatter = logging.Formatter(
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
-from functools import reduce

__all__ = []
@@ -154,10 +156,11 @@ class ShardingOptimizer(MetaOptimizerBase):
    def _get_hybrid_dp_mode(self):
        """ get
-       self.hybrid_dp_mode
-       self.gradient_merge_mode
+       self.hybrid_dp_mode = 'pp_hybrid_dp' or 'sharding_hybrid_dp'
+       self.gradient_merge_mode = 'pp_gm' or 'sharding_gm'
        self._gradient_merge_acc_step
        self.pp_allreduce_in_optimize
+       self._optimizer_sharding
        """
        strategy = self.user_defined_strategy
        sharding_configs = strategy.sharding_configs
@@ -194,9 +197,18 @@ class ShardingOptimizer(MetaOptimizerBase):
        logger.info("Gradient merge in [{}], acc step = [{}]".format(
            gm_mode, gm_acc_step))

+       optimizer_sharding = False
+       # TODO(wangxi): need support dp_as_opt_sharding with sharding
+       #               need support without pp in future
+       if self.sharding_degree == 1 and self.dp_degree > 1 \
+               and sharding_configs['_dp_as_optimizer_sharding'] \
+               and self.pp_degree > 1:
+           optimizer_sharding = True
+
        self.hybrid_dp_mode = dp_mode
        self.gradient_merge_mode = gm_mode
        self._gradient_merge_acc_step = gm_acc_step
+       self._optimizer_sharding = optimizer_sharding

        # this feature is design for ascend, and should NOT be used in GPU training
        self.pp_allreduce_in_optimize = sharding_configs[
@@ -276,7 +288,8 @@ class ShardingOptimizer(MetaOptimizerBase):
        startup_block = self._startup_program.global_block()

        # step1: build shard
-       self._build_shard(params_grads)
+       self._build_shard(params_grads, self.sharding_rank,
+                         self.sharding_degree)

        # step2: split_program
        self._split_program(main_block)
@@ -287,14 +300,35 @@ class ShardingOptimizer(MetaOptimizerBase):
        startup_block._sync_with_cpp()

        # step4: remove unneeded ops and vars from block
-       self._prune_main_program(main_block)
-       self._prune_startup_program(startup_block)
+       self._prune_main_program(
+           main_block, self._shard,
+           [self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id])
+       self._prune_startup_program(startup_block, self._shard)
+
+   def _apply_opt_sharding_pass(self, params_grads):
+       """ outer dp as optimizer sharding """
+       if self._optimizer_sharding is False: return
+
+       main_block = self._main_program.global_block()
+       startup_block = self._startup_program.global_block()
+
+       # step1: build shard
+       self._build_shard(params_grads, self.dp_rank, self.dp_degree)
+
+       # NOTE(wangxi): prune_main_program will prune cast if not add this
+       for param, grad in params_grads:
+           self._reduced_grads_to_param[grad.name] = param.name
+
+       # step4: remove unneeded ops and vars from block
+       self._prune_main_program(
+           main_block, self._shard,
+           [self.mp_ring_id, self.pp_ring_id, self.dp_ring_id])
+       self._prune_startup_program(startup_block, self._shard)

-   def _insert_allreduce_for_pp(self):
+   def _insert_allreduce_for_pp(self, params_grads):
        if self.pp_degree == 1: return

        strategy = self.user_defined_strategy
-       fp16_allreduce = strategy.fp16_allreduce

        main_block = self._main_program.global_block()
        startup_block = self._startup_program.global_block()
@@ -318,10 +352,13 @@ class ShardingOptimizer(MetaOptimizerBase):
                if in_name not in main_block.vars:
                    main_block._remove_op(idx)

+       if self._optimizer_sharding:
+           # TODO(wangxi): support fp16_allreduce with optimizer sharding
+           strategy.fp16_allreduce = False
+
+       shard = self._shard if self._optimizer_sharding else None
        accumulated_grad_names = self._pp_optimizer._accumulate_gradients(
-           main_block,
-           fp16_allreduce=fp16_allreduce,
-           user_defined_strategy=strategy)
+           main_block, strategy=strategy, shard=shard)

        len_of_ops = len(main_block.ops)
        first_optimize_op_index = get_first_optimize_op_idx(main_block)
@@ -346,7 +383,36 @@ class ShardingOptimizer(MetaOptimizerBase):
            first_optimize_op_index += (len(main_block.ops) - len_of_ops)
            len_of_ops = len(main_block.ops)

-       if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
+       if self._optimizer_sharding:
+           accumulated_grad_names = utils.insert_reduce_ops(
+               main_block,
+               first_optimize_op_index,
+               self.dp_ring_id,
+               accumulated_grad_names,
+               self._shard,
+               OpRole.Optimize,
+               use_calc_stream=True,
+               rank=self.dp_rank,
+               strategy=strategy)
+           logger.info("Optimizer grad in this rank {}".format(
+               accumulated_grad_names))
+           first_optimize_op_index += (len(main_block.ops) - len_of_ops)
+           len_of_ops = len(main_block.ops)
+
+           optimizer_param = utils.insert_broadcast_param_ops(
+               main_block,
+               len_of_ops,
+               self.dp_ring_id, [x[0].name for x in params_grads],
+               self._shard,
+               OpRole.Optimize,
+               use_calc_stream=True,
+               rank=self.dp_rank,
+               strategy=strategy)
+           logger.info("Optimizer param in this rank {}".format(
+               optimizer_param))
+           if not strategy.fuse_grad_merge:
+               assert len(accumulated_grad_names) == len(optimizer_param)
+       elif self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
            insert_allreduce_ops(
                main_block,
                first_optimize_op_index,
@@ -361,9 +427,10 @@ class ShardingOptimizer(MetaOptimizerBase):
        # FIXME(wangxi): if fp16_allreduce, put cast fp16->fp32 to there?

    def _adapt_amp_clip_without_sharding(self):
-       if self.sharding_degree > 1: return
        # if not use sharding, adapt amp/clip, for remain parallelism.
        # cast --> amp --> clip --> opt
+       if self.sharding_degree > 1: return
+       if self._optimizer_sharding: return

        main_block = self._main_program.global_block()
        startup_block = self._startup_program.global_block()
@@ -449,7 +516,9 @@ class ShardingOptimizer(MetaOptimizerBase):

        self._apply_sharding_pass(params_grads)

-       self._insert_allreduce_for_pp()
+       self._apply_opt_sharding_pass(params_grads)
+
+       self._insert_allreduce_for_pp(params_grads)

        self._adapt_amp_clip_without_sharding()
@@ -630,11 +699,10 @@ class ShardingOptimizer(MetaOptimizerBase):
        startup_block._sync_with_cpp()

-   def _build_shard(self, params_grads):
+   def _build_shard(self, params_grads, shard_rank, shard_size):
        # step 2: split params
        self._params = set([x[0].name for x in params_grads])

-       self._shard.setup(params_grads, self.sharding_rank,
-                         self.sharding_degree)
+       self._shard.setup(params_grads, shard_rank, shard_size)

        # step 3: get broadcast vars
        self._broadcast_vars = self._shard.find_broadcast_params(
@@ -787,7 +855,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                    self._segments[idx_]._end_idx].desc.input_arg_names()))
        return

-   def _prune_main_program(self, block):
+   def _prune_main_program(self, block, shard, rings):
        """
        calculate deps from allredce op to optimize op,
        remove ops and vars not needed in this worker
@@ -799,28 +867,26 @@ class ShardingOptimizer(MetaOptimizerBase):
        """
        weightdecay_helper = WeightDecayHelper()
-       weightdecay_helper.prune_weight_decay(block, self._shard)
+       weightdecay_helper.prune_weight_decay(block, shard)

+       # FIXME(wangxi): mp should prune duplicated param_grads
        # NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism
        # group. and each Data Parallelism group should have its own sync of FoundInfinite
        # amp could use global group for sync
-       FP16Utils.prune_fp16(
-           block, self._shard, self._reduced_grads_to_param,
-           [self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id])
+       FP16Utils.prune_fp16(block, shard, self._reduced_grads_to_param, rings)

        # clipbyglobalnorm should only use the Model paramllelism group (mp-sharding-pp)
        gradientclip_helper = GradientClipHelper(None)
-       gradientclip_helper.prune_gradient_clip(
-           block, self._shard,
-           [self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id])
+       gradientclip_helper.prune_gradient_clip(block, shard, rings)

        # build prog deps
        reduced_grads = []
        for idx, op in enumerate(block.ops):
            input_names = op.desc.input_arg_names()
            output_names = op.desc.output_arg_names()
-           if op.type == "c_allreduce_sum":
+           # FIXME(wangxi): need use grads, pipeline grad is @GRAD@MERGE
+           if op.type == "c_allreduce_sum" and \
+                   op.attr('use_model_parallel') is False:
                assert (len(output_names) == 1)
                output_name = output_names[0]
                reduced_grads.append(output_name)
@@ -828,8 +894,8 @@ class ShardingOptimizer(MetaOptimizerBase):
        # prune optimizer state and param
        pruned_opti_vars = []
        for var_name in list(block.vars.keys()):
-           if self._shard.is_opti_var(var_name) and \
-               not self._shard.has_opt_var(var_name):
+           if shard.is_opti_var(var_name) and \
+               not shard.has_opt_var(var_name):
                pruned_opti_vars.append(var_name)
        program_deps = ProgramDeps(block, reduced_grads, pruned_opti_vars)
@@ -881,7 +947,9 @@ class ShardingOptimizer(MetaOptimizerBase):
                # if all outputs of this op are in _should_removed_var
                # _should_removed_var: opt state not cur shard
                if program_deps.should_remove_op(idx):
-                   program_deps.remove_op(idx)
+                   # NOTE(wangxi): need reserve all param in optimizer_sharding
+                   reserved_vars = self._params if self._optimizer_sharding else None
+                   program_deps.remove_op(idx, reserved_vars)

        # NOTE (JZ-LIANG) revise and unify logic here
        # sharding support fp16_allreduce logic
@@ -1112,17 +1180,21 @@ class ShardingOptimizer(MetaOptimizerBase):
        return

-   def _prune_startup_program(self, block):
+   def _prune_startup_program(self, block, shard):
        for idx, op in reversed(list(enumerate(block.ops))):
            for output_name in op.desc.output_arg_names():
-               if self._shard.has_var(output_name):
+               if shard.has_var(output_name):
+                   continue
+               if self._optimizer_sharding and shard.is_param(output_name):
                    continue
                #TODO why do we remove op, when only one var is removed
                block._remove_op(idx, sync=False)
                break

        for var_name in list(block.vars.keys()):
-           if self._shard.has_var(var_name):
+           if shard.has_var(var_name):
+               continue
+           if self._optimizer_sharding and shard.is_param(var_name):
                continue
            block._remove_var(var_name, sync=False)
        block._sync_with_cpp()
......
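Putting the optimizer-sharding pieces of this file together: gradients are c_reduce_sum'ed to the rank that owns each parameter, only that rank keeps the optimizer op (and its states) for it, and the updated parameter is c_broadcast back over the dp ring. Below is a single-process mock of that ordering, using plain dicts instead of programs and collectives; it is purely illustrative and all names are made up.

def optimizer_sharding_step(param_copies, grads_per_rank, owner_of, lr=0.1):
    """param_copies / grads_per_rank: one {name: value} dict per dp rank."""
    nranks = len(param_copies)
    names = list(param_copies[0])
    # 1) c_reduce_sum: the owner rank receives the summed gradient
    reduced = {n: sum(g[n] for g in grads_per_rank) for n in names}
    # 2) only the owner runs the (here: plain SGD) update for its params,
    #    which is why other ranks' optimizer states for them can be pruned
    for n in names:
        param_copies[owner_of(n)][n] -= lr * reduced[n]
    # 3) c_broadcast from the owner so every rank holds the updated value again
    for n in names:
        for rank in range(nranks):
            param_copies[rank][n] = param_copies[owner_of(n)][n]
    return param_copies

params = [{"w": 1.0, "b": 0.5} for _ in range(2)]              # 2 dp ranks
grads = [{"w": 0.2, "b": 0.1}, {"w": 0.4, "b": 0.3}]
print(optimizer_sharding_step(params, grads, lambda n: 0 if n == "w" else 1))
# both ranks end with w = 1.0 - 0.1*0.6 = 0.94 and b = 0.5 - 0.1*0.4 = 0.46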
@@ -5049,16 +5049,16 @@ class PipelineOptimizer(object):
    def _accumulate_gradients(self,
                              block,
                              pp_allreduce_in_optimize=False,
-                             fp16_allreduce=False,
-                             user_defined_strategy=None):
+                             strategy=None,
+                             shard=None):
        """
        Create a new merged gradient for each parameter and accumulate the
        corresponding gradient to it.
        """
-       if user_defined_strategy and user_defined_strategy.fuse_grad_merge:
+       fp16_allreduce = strategy.fp16_allreduce if strategy else False
+       if strategy and strategy.fuse_grad_merge:
            fused_gradient_names = self._accumulate_gradients_with_fuse(
-               block, fp16_allreduce,
-               user_defined_strategy.fuse_grad_size_in_MB)
+               block, fp16_allreduce, strategy.fuse_grad_size_in_MB, shard)
            return fused_gradient_names

        merged_gradient_names = []
@@ -5079,8 +5079,8 @@ class PipelineOptimizer(object):
            if self._is_backward_op(op) and first_opt_op_idx is None:
                first_opt_op_idx = index + 1
-               # no optimize phase
-               if first_opt_op_idx == len(block.ops): return
+               # maybe have no optimize
+               # if first_opt_op_idx == len(block.ops): return

            if self._is_backward_op(op) and (
                    self._op_role_var_key in op.attr_names):
@@ -5190,44 +5190,9 @@ class PipelineOptimizer(object):

        return merged_gradient_names

-   def _accumulate_gradients_with_fuse(self, main_block, fp16, fused_size):
-       first_opt_op_idx = None
-       grad_param_pairs = []
-       # obtain all param/grad pairs that needed to be fused
-       for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
-           # remove the cast op of fp16 grad to fp32 grad
-           if self._is_optimize_op(op) and op.type == 'cast':
-               in_name = op.input_arg_names[0]
-               out_name = op.output_arg_names[0]
-               if out_name.strip('@GRAD') in self._param_device_map:
-                   assert in_name.replace('.cast_fp16', '') == out_name
-                   main_block._remove_op(index)
-                   continue
-
-           if self._is_backward_op(op) and first_opt_op_idx is None:
-               first_opt_op_idx = index + 1
-               # no optimize phase
-               if first_opt_op_idx == len(main_block.ops):
-                   return
-
-           if self._is_backward_op(op) and (
-                   self._op_role_var_key in op.attr_names):
-               op_role_var = op.attr(self._op_role_var_key)
-               if len(op_role_var) == 0:
-                   continue
-               assert len(op_role_var) % 2 == 0
-               for i in range(0, len(op_role_var), 2):
-                   param_name = op_role_var[i]
-                   if not main_block.has_var(param_name):
-                       continue
-                   if '@BroadCast' in param_name:
-                       continue
-                   grad_param_pairs.append(
-                       (op_role_var[i + 1], op_role_var[i]))
-
-       if len(grad_param_pairs) == 0:
-           return
-
+   def _insert_accumulate_gradients_with_fuse(self, main_block, fp16,
+                                              fused_size, grad_param_pairs,
+                                              first_opt_op_idx):
        grad_param_pairs = self._sort_grad_param_by_dtype(main_block,
                                                          grad_param_pairs)
@@ -5426,9 +5391,66 @@ class PipelineOptimizer(object):
        for i in range(len(fused_merged_gradients)):
            fused_merged_gradients[i] = fused_merged_gradients[i].name

-       main_block._sync_with_cpp()
+       return fused_merged_gradients, first_opt_op_idx

-       return fused_merged_gradients
+   def _accumulate_gradients_with_fuse(self,
+                                       main_block,
+                                       fp16,
+                                       fused_size,
+                                       shard=None):
+       first_opt_op_idx = None
+       grad_param_pairs = []
+       # obtain all param/grad pairs that needed to be fused
+       for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
+           # remove the cast op of fp16 grad to fp32 grad
+           if self._is_optimize_op(op) and op.type == 'cast':
+               in_name = op.input_arg_names[0]
+               out_name = op.output_arg_names[0]
+               if out_name.strip('@GRAD') in self._param_device_map:
+                   assert in_name.replace('.cast_fp16', '') == out_name
+                   main_block._remove_op(index)
+                   continue
+
+           if self._is_backward_op(op) and first_opt_op_idx is None:
+               first_opt_op_idx = index + 1
+               # no optimize phase
+               if first_opt_op_idx == len(main_block.ops):
+                   return
+
+           if self._is_backward_op(op) and (
+                   self._op_role_var_key in op.attr_names):
+               op_role_var = op.attr(self._op_role_var_key)
+               if len(op_role_var) == 0:
+                   continue
+               assert len(op_role_var) % 2 == 0
+               for i in range(0, len(op_role_var), 2):
+                   param_name = op_role_var[i]
+                   if not main_block.has_var(param_name):
+                       continue
+                   if '@BroadCast' in param_name:
+                       continue
+                   grad_param_pairs.append(
+                       (op_role_var[i + 1], op_role_var[i]))
+
+       if len(grad_param_pairs) == 0:
+           return
+
+       nranks = shard.worker_num if shard else 1
+       device_to_pairs = [[] for _ in range(nranks)]
+       for pair in grad_param_pairs:
+           root_id = shard.device(pair[1]) if shard else 0
+           assert 0 <= root_id < nranks
+           device_to_pairs[root_id].append(pair)
+
+       all_fused_merged_gradients = []
+       for pairs in device_to_pairs:
+           fused_merged_gradients, first_opt_op_idx = \
+               self._insert_accumulate_gradients_with_fuse(
+                   main_block, fp16, fused_size, pairs, first_opt_op_idx)
+           all_fused_merged_gradients += fused_merged_gradients
+
+       main_block._sync_with_cpp()
+       return all_fused_merged_gradients

    def _sort_grad_param_by_dtype(self, main_block, grad_param_pairs):
        # sort the grad param paris by the dtype
......
@@ -70,6 +70,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_init)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_sharding_meta_optimizer)
+list(APPEND MIXED_DIST_TEST_OPS test_fleet_hybrid_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_localsgd_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_lars_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_lamb_meta_optimizer)
@@ -568,6 +569,7 @@ if(WITH_DISTRIBUTE)
    py_test_modules(test_fleet_graph_executor MODULES test_fleet_graph_executor ENVS ${dist_ENVS})
    py_test_modules(test_fleet_gradient_merge_meta_optimizer MODULES test_fleet_gradient_merge_meta_optimizer ENVS ${dist_ENVS})
    py_test_modules(test_fleet_sharding_meta_optimizer MODULES test_fleet_sharding_meta_optimizer ENVS ${dist_ENVS})
+   py_test_modules(test_fleet_hybrid_meta_optimizer MODULES test_fleet_hybrid_meta_optimizer ENVS ${dist_ENVS})
    py_test_modules(test_fleet_amp_meta_optimizer MODULES test_fleet_amp_meta_optimizer ENVS ${dist_ENVS})
    py_test_modules(test_fleet_amp_init MODULES test_fleet_amp_init ENVS ${dist_ENVS})
    py_test_modules(test_fleet_fp16_allreduce_meta_optimizer MODULES test_fleet_fp16_allreduce_meta_optimizer ENVS ${dist_ENVS})
......
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import inspect
import unittest
import paddle
from paddle import fluid
@@ -25,6 +26,23 @@ class TestFleetMetaOptimizer(unittest.TestCase):
        os.environ["PADDLE_TRAINER_ID"] = "1"
        os.environ[
            "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"
+       self._debug = False
+
+   def debug_program(self, main_prog, startup_prog):
+       if not self._debug: return
+
+       main_prog_ops = main_prog.global_block().ops
+       startup_prog_ops = startup_prog.global_block().ops
+
+       main_prog_op_types = [op.type for op in main_prog_ops]
+       startup_prog_op_types = [op.type for op in startup_prog_ops]
+
+       print("=== debug program and ops in func [{}] ==="
+             .format(inspect.stack()[1].function))
+       print(main_prog)
+       print(main_prog_op_types)
+       print(startup_prog)
+       print(startup_prog_op_types)

    def net(self, main_prog, startup_prog):
        with fluid.program_guard(main_prog, startup_prog):
@@ -82,6 +100,20 @@ class TestFleetMetaOptimizer(unittest.TestCase):
            strategy = paddle.distributed.fleet.DistributedStrategy()
            return avg_cost, strategy

+   def boundary_net(self, main_prog, startup_prog):
+       with fluid.program_guard(main_prog, startup_prog):
+           fleet.init(is_collective=True)
+           x = paddle.static.data(name='x', shape=[-1, 4], dtype='float32')
+           with paddle.static.device_guard('gpu:0'):
+               linear = fluid.Linear(4, 8, bias_attr=False)
+               out = linear(x)
+           with paddle.static.device_guard('gpu:1'):
+               linear = fluid.Linear(8, 5, bias_attr=False)
+               out = linear(out)
+           avg_cost = paddle.mean(out)
+           strategy = fleet.DistributedStrategy()
+       return avg_cost, strategy
+
    def optimizer(self,
                  loss,
                  strategy,
@@ -190,5 +222,12 @@ class TestFleetMetaOptimizer(unittest.TestCase):
                "enable_offload": True,
                "checkpoint_shape": [256]
            }
+       elif name == "pipeline":
+           strategy.pipeline = True
+           strategy.pipeline_configs = {
+               "schedule_mode": "1F1B",
+               "micro_batch_size": 2,
+               "accumulate_steps": 4,
+           }
        else:
            raise NotImplementedError()
@@ -612,7 +612,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):

        for op in main_prog_ops:
            if op.type == 'c_allreduce_sum':
-               assert 'FusedOutput' in op.input_arg_names[0]
+               assert 'FusedGrad' in op.input_arg_names[0]

    def test_hybrid_with_mp_pp_amp_gclip(self):
        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
......