Unverified commit 78465703 authored by W WangXi, committed by GitHub

[hybrid] out data parallel as optimizer sharding parallel (#35593)

Parent 2b88057f
......@@ -43,6 +43,8 @@ message ShardingConfig {
optional bool pp_allreduce_in_optimize = 10 [ default = false ];
optional int32 pp_degree = 11 [ default = 1 ];
optional bool optimize_cast = 12 [ default = false ];
// Optimizer sharding. Temporary plan; may be deprecated
optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
}
message HybridConfig {
......
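For context, a minimal sketch (assuming the fleet DistributedStrategy API; degrees chosen only for illustration) of enabling the new flag. Per _get_hybrid_dp_mode below, it only takes effect with sharding_degree == 1, dp_degree > 1 and pp_degree > 1:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_degree": 1,                # no grad/param sharding group
    "dp_degree": 2,                      # outer data-parallel group
    "pp_degree": 2,                      # pipeline-parallel group
    "_dp_as_optimizer_sharding": True,   # reuse outer dp as optimizer sharding
}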
......@@ -26,27 +26,28 @@ class CheckFiniteAndUnscaleOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X",
"check_finite_and_unscale");
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
"check_finite_and_unscale");
PADDLE_ENFORCE_EQ(
ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
platform::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(check_finite_and_unscale), size of input(X) is %d "
"and size of output(Out) is %d.",
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {
PADDLE_ENFORCE_EQ(
ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
platform::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(check_finite_and_unscale), size of input(X) is %d "
"and size of output(Out) is %d.",
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
}
ctx->SetOutputDim("FoundInfinite", {1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
auto dtype = framework::proto::VarType::FP32;
if (ctx.MultiInputVar("X").size() >= 1) {
dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
}
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};
......
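The guard above lets the op run on ranks that own no tensors, which happens under optimizer sharding. A rough standalone Python sketch of the semantics (not the Paddle kernel):

import math

scale = 2.0 ** 15
xs = []  # under optimizer sharding this rank may own no grads at all
found_inf = any(not math.isfinite(v / scale) for g in xs for v in g)  # False on empty input
outs = [[v / scale for v in g] for g in xs]  # Out stays empty, FoundInfinite is still produced
assert found_inf is False and outs == []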
......@@ -97,6 +97,8 @@ class CheckFiniteAndUnscaleGpuKernel : public framework::OpKernel<T> {
scale_data, inverse_scale_v, found_inf_data);
size_t xs_size = xs.size();
if (xs_size == 0) return;
const auto& cpu_place = platform::CPUPlace();
// calculate each tensor's start index and copy to device
auto h_starts_tensor =
......
......@@ -26,7 +26,6 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("FoundInfinite"), "Input", "FoundInfinite",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("PrevLossScaling"), "Input", "PrevLossScaling",
......@@ -35,16 +34,25 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasInput("InBadSteps"), "Input", "InBadSteps",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutputs("Out"), "Output", "Out",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("LossScaling"), "Output", "LossScaling",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("OutGoodSteps"), "Output", "OutGoodSteps",
"update_loss_scaling");
OP_INOUT_CHECK(ctx->HasOutput("OutBadSteps"), "Output", "OutBadSteps",
"update_loss_scaling");
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
if (ctx->HasInputs("X") || ctx->HasOutputs("Out")) {
PADDLE_ENFORCE_EQ(
ctx->Inputs("X").size(), ctx->Outputs("Out").size(),
platform::errors::InvalidArgument(
"The input(X) and output(Out) should have same size in "
"Operator(update_loss_scaling), size of input(X) is %d "
"and size of output(Out) is %d.",
ctx->Inputs("X").size(), ctx->Outputs("Out").size()));
auto x_dims = ctx->GetInputsDim("X");
ctx->SetOutputsDim("Out", x_dims);
}
ctx->SetOutputDim("LossScaling", {1});
ctx->SetOutputDim("OutGoodSteps", {1});
ctx->SetOutputDim("OutBadSteps", {1});
......@@ -53,8 +61,12 @@ class UpdateLossScalingOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
auto dtype = framework::proto::VarType::FP32;
if (ctx.MultiInputVar("X").size() >= 1) {
dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
}
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};
......
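For context, a rough Python sketch of the dynamic loss-scaling rule this op implements (attribute names are the op's, default values assumed); the diff only makes it tolerate an empty X/Out list:

def update_loss_scaling_sketch(found_inf, scale, good, bad,
                               incr_every_n_steps=1000,
                               decr_every_n_nan_or_inf=2,
                               incr_ratio=2.0, decr_ratio=0.5):
    if found_inf:
        good, bad = 0, bad + 1
        if bad == decr_every_n_nan_or_inf:
            scale, bad = max(scale * decr_ratio, 1.0), 0  # shrink, keep scale >= 1
    else:
        bad, good = 0, good + 1
        if good == incr_every_n_steps:
            scale, good = scale * incr_ratio, 0  # grow after enough clean steps
    return scale, good, bad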
......@@ -95,6 +95,8 @@ class LazyZeros<platform::CUDADeviceContext, T> {
const std::vector<const framework::Tensor*>& xs,
const std::vector<framework::Tensor*>& outs) const {
size_t xs_size = xs.size();
if (xs_size == 0) return;
const auto& cpu_place = platform::CPUPlace();
// alloc each tensor's start index and copy to device
auto h_in_starts_mem =
......
......@@ -105,7 +105,6 @@ class FP16Utils(object):
if op.type == "update_loss_scaling":
update_loss_scaling_op_idx = idx
inf_var_name = op.desc.input('FoundInfinite')[0]
op._rename_input(inf_var_name, inf_var_name + "@sharding")
if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
reversed_x = []
reversed_x_paramname = []
......@@ -142,10 +141,6 @@ class FP16Utils(object):
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
inf_var_sharding = block.create_var(
name=inf_var_name + "@sharding",
shape=inf_var.shape,
dtype=inf_var.dtype)
block._insert_op_without_sync(
update_loss_scaling_op_idx,
......@@ -179,10 +174,10 @@ class FP16Utils(object):
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_sharding},
outputs={'Out': inf_var},
attrs={
"in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var_sharding.dtype,
"out_dtype": inf_var.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
update_loss_scaling_op_idx += 1
......@@ -210,10 +205,6 @@ class FP16Utils(object):
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
inf_var_global = block.create_var(
name=inf_var_name + "@GLOBAL_WORLD",
shape=inf_var.shape,
dtype=inf_var.dtype)
block._insert_op_without_sync(
update_loss_scaling_op_idx,
type='cast',
......
......@@ -39,6 +39,7 @@ class GradientClipHelper(object):
if not self._is_gradient_clip_op(op):
continue
if op.type == "sum":
global_norm_sum_op_idx = idx
continue
deperate_op = False
for input_name in op.desc.input_arg_names():
......@@ -61,7 +62,10 @@ class GradientClipHelper(object):
if output_name not in op.desc.input_arg_names():
deperated_vars.add(output_name)
if not deperated_vars:
# NOTE(wangxi): With 2 sharding ranks and only 1 param,
# sharding rank 0 has no deperated_vars and would return here,
# while sharding rank 1 inserts an allreduce, causing a hang.
if not deperated_vars and global_norm_sum_op_idx == -1:
# got no gradient_clip op
return
......@@ -71,8 +75,8 @@ class GradientClipHelper(object):
if idx in deperate_op_idx:
block._remove_op(idx, sync=False)
continue
reversed_inputs = []
if op.type == "sum":
reversed_inputs = []
global_norm_sum_op_idx = idx
for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars:
......@@ -82,6 +86,28 @@ class GradientClipHelper(object):
assert (len(op.desc.output_arg_names()) == 1)
sum_res = op.desc.output_arg_names()[0]
# NOTE(wangxi): If we have 2 params but the sharding degree is 4,
# the sum op on some cards will have no input.
# So we use a fill_constant op to set `sum_var` to zero,
# which does not affect correctness.
if len(reversed_inputs) == 0:
sum_var = block.var(sum_res)
namescope = op.attr("op_namescope")
block._remove_op(idx, sync=False)
op = block._insert_op_without_sync(
idx,
type='fill_constant',
inputs={},
outputs={'Out': sum_res},
attrs={
'shape': sum_var.shape,
'dtype': sum_var.dtype,
'value': 0.0,
OP_ROLE_KEY: OpRole.Optimize
})
op._set_attr('op_namescope', namescope)
# allreduce(mp)->allreduce(sharding)->allreduce(pp)
idx_offset = 1
for ring_id in ring_ids:
......
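A tiny numeric illustration of why the fill_constant replacement is safe, assuming clip-by-global-norm sums squared norms across ranks before the allreduce chain above:

local_sq_norms = [[9.0], [16.0], [], []]  # 2 params sharded over 4 ranks
local_sums = [sum(x) if x else 0.0 for x in local_sq_norms]  # empty ranks contribute fill_constant(0.)
global_norm = sum(local_sums) ** 0.5  # after allreduce(sum), every rank sees 5.0
assert global_norm == 5.0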
......@@ -117,21 +117,28 @@ class ProgramDeps(object):
var_name] == []:
self._block._remove_var(var_name, sync=False)
def remove_op(self, op_idx):
def remove_op(self, op_idx, reserved_vars=None):
# update deps
op = self._block.ops[op_idx]
for input_name in op.desc.input_arg_names():
if reserved_vars is not None and input_name in reserved_vars:
continue
self.crop_input_var_from_op(op_idx, input_name)
for output_name in op.desc.output_arg_names():
if reserved_vars is not None and output_name in reserved_vars:
continue
self.crop_output_var_from_op(op_idx, output_name)
self._block._remove_op(op_idx, sync=False)
def should_remove_op(self, op_idx):
op = self._block.ops[op_idx]
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
# remove check_finite_and_unscale op if its input 'X' is empty
if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
return True
# NOTE: At present, the only ops without outputs are the send_v2
# and partial_send ops, which are used on all devices.
if len(op.desc.output_arg_names()) == 0:
return False
for output_name in op.desc.output_arg_names():
if output_name not in self._should_removed_var:
return False
......
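A usage sketch of the new reserved_vars argument, mirroring the call site added in _prune_main_program further down:

reserved_vars = self._params if self._optimizer_sharding else None  # keep params alive
if program_deps.should_remove_op(idx):
    program_deps.remove_op(idx, reserved_vars)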
......@@ -24,7 +24,8 @@ class Shard(object):
self.global_params = set([])
self.worker_idx = -1
self.worker_num = -1
self.global_param2device = {}
self.global_param2device = dict()
self.device2global_params = dict()
def setup(self, params_grads, worker_idx, worker_num):
# param names of all devices
......@@ -33,8 +34,9 @@ class Shard(object):
self.worker_idx = worker_idx
self.worker_num = worker_num
# global_param2device contains fp32 params and fp16 params
self.global_param2device = self._split_params(params_grads, worker_idx,
worker_num)
# device2global_params only contains fp32 params
self.global_param2device, self.device2global_params \
= self._split_params(params_grads, worker_idx, worker_num)
def has_param(self, var_name):
return var_name in self.global_param2device and \
......@@ -64,7 +66,7 @@ class Shard(object):
device2params[device_idx].append(param_name)
param2device[param_name] = device_idx
mem_accu += mem
return param2device
return param2device, device2params
def _var_device_id(self, var_name):
if var_name in self.global_param2device:
......
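A self-contained sketch of a sequential, size-balanced split that produces both mappings; the exact thresholding inside Shard._split_params may differ:

def split_params_sketch(param_sizes, worker_num):
    total = float(sum(param_sizes.values()))
    param2device, device2params = {}, {i: [] for i in range(worker_num)}
    device_idx, mem_accu = 0, 0.0
    for name, size in param_sizes.items():
        if mem_accu >= total / worker_num * (device_idx + 1) and device_idx + 1 < worker_num:
            device_idx += 1  # move on once this device holds its share
        device2params[device_idx].append(name)
        param2device[name] = device_idx
        mem_accu += size
    return param2device, device2params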
......@@ -365,6 +365,65 @@ def insert_allreduce_ops(block,
return
class FuseHelper(object):
@staticmethod
def get_fused_groups(block, vars_name, fuse_size=32.):
""" coalesce tensor, get fused group """
groups = []
cur_size = 0.
last_dtype = None
for var_name in vars_name:
real_var = block.var(var_name)
var_size = get_var_size(real_var)
if cur_size + var_size > fuse_size \
or len(groups) == 0 \
or real_var.dtype != last_dtype:
groups.append([real_var])
cur_size = var_size
last_dtype = real_var.dtype
else:
groups[-1].append(real_var)
cur_size += var_size
return groups
@staticmethod
def insert_coalesce_tensor(block,
index,
groups,
op_role=OpRole.Backward,
prefix="Output"):
fused_vars = []
insert_num = 0
for group in groups:
assert len(group) >= 1
if len(group) == 1:
# no need to fuse
fused_vars.append(group[0])
continue
fused_var = block.create_var(
name=unique_name.generate('Fused{}_{}'.format(prefix, group[0]
.name)),
dtype=group[0].dtype,
persistable=False,
stop_gradient=True)
fused_vars.append(fused_var)
block._insert_op_without_sync(
index,
type="coalesce_tensor",
inputs={"Input": group},
outputs={"Output": group,
"FusedOutput": fused_var},
attrs={
"copy_data": True,
"use_align": True,
"dtype": group[0].dtype,
OP_ROLE_KEY: op_role
})
insert_num += 1
return fused_vars, insert_num
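A tiny trace of the grouping rule in get_fused_groups: a new group starts whenever the running size would exceed fuse_size (MB) or the dtype changes:

vars_meta = [("g0", 20.0, "fp32"), ("g1", 20.0, "fp32"),
             ("g2", 5.0, "fp16"), ("g3", 10.0, "fp16")]
groups, cur_size, last_dtype = [], 0.0, None
for name, size, dtype in vars_meta:
    if cur_size + size > 32.0 or not groups or dtype != last_dtype:
        groups.append([name])
        cur_size, last_dtype = size, dtype
    else:
        groups[-1].append(name)
        cur_size += size
assert groups == [["g0"], ["g1"], ["g2", "g3"]]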
def insert_fused_allreduce_ops(block,
insert_idx,
ring_id,
......@@ -372,46 +431,15 @@ def insert_fused_allreduce_ops(block,
op_role=OpRole.Backward,
use_calc_stream=False,
fuse_grad_size_in_MB=32):
segments = []
cur_size = 0.
last_dtype = None
for var in allreduce_vars:
real_var = block.var(var)
var_size = get_var_size(real_var)
if cur_size + var_size > fuse_grad_size_in_MB \
or len(segments) == 0 \
or real_var.dtype != last_dtype:
segments.append([real_var])
cur_size = var_size
last_dtype = real_var.dtype
else:
segments[-1].append(real_var)
cur_size += var_size
fused_vars = []
for segment in segments:
tmp_var = block.create_var(
name=unique_name.generate('FusedOutput_{}'.format(segment[0].name)),
dtype=segment[0].dtype,
persistable=False,
stop_gradient=True)
fused_vars.append(tmp_var)
block._insert_op_without_sync(
insert_idx,
type="coalesce_tensor",
inputs={"Input": segment},
outputs={"Output": segment,
"FusedOutput": tmp_var},
attrs={
"copy_data": True,
"use_align": True,
"dtype": segment[0].dtype,
OP_ROLE_KEY: op_role
})
groups = FuseHelper.get_fused_groups(block, allreduce_vars,
fuse_grad_size_in_MB)
fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
block, insert_idx, groups, op_role, prefix="Grad")
for fused_var in fused_vars:
block._insert_op_without_sync(
insert_idx + len(fused_vars),
insert_idx + insert_num,
type='c_allreduce_sum',
inputs={'X': fused_var},
outputs={'Out': fused_var},
......@@ -422,13 +450,61 @@ def insert_fused_allreduce_ops(block,
})
if not use_calc_stream:
block._insert_op_without_sync(
insert_idx + len(fused_vars),
insert_idx + insert_num,
type='c_sync_calc_stream',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={OP_ROLE_KEY: op_role})
def insert_fused_reduce_ops(block,
insert_idx,
ring_id,
reduce_vars,
shard,
op_role=OpRole.Backward,
use_calc_stream=False,
rank=None,
fuse_grad_size=32):
nranks = shard.worker_num
device_to_vars = [[] for _ in range(nranks)]
for var in reduce_vars:
root_id = get_grad_device(var, shard)
assert 0 <= root_id < nranks, "root_id should >=0 and < nranks, " \
"but now nranks={}, the root_id of var={} is {}"\
.format(nranks, var, root_id)
device_to_vars[root_id].append(var)
for root_id, vars_name in enumerate(device_to_vars):
groups = FuseHelper.get_fused_groups(block, vars_name, fuse_grad_size)
fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
block, insert_idx, groups, op_role, prefix="Grad")
for fused_var in fused_vars:
block._insert_op_without_sync(
insert_idx + insert_num,
type='c_reduce_sum',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={
'ring_id': ring_id,
'root_id': root_id,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
if not use_calc_stream:
block._insert_op_without_sync(
insert_idx + insert_num,
type='c_sync_calc_stream',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={OP_ROLE_KEY: op_role})
return [] if rank is None else device_to_vars[rank]
def insert_reduce_ops(block,
insert_idx,
ring_id,
......@@ -436,14 +512,26 @@ def insert_reduce_ops(block,
shard,
op_role=OpRole.Backward,
use_calc_stream=False,
rank=None):
rank=None,
strategy=None):
"""
_add_allreduce_ops
_add_reduce_ops
"""
if strategy and strategy.fuse_all_reduce_ops and \
not strategy.fuse_grad_merge:
return insert_fused_reduce_ops(block, insert_idx, ring_id, reduce_vars,
shard, op_role, use_calc_stream, rank,
strategy.fuse_grad_size_in_MB)
grad_in_this_device = []
for var in reduce_vars:
root_id = get_grad_device(var, shard)
grad_var = var
if strategy and strategy.fuse_all_reduce_ops and \
strategy.fuse_grad_merge:
# TODO(wangxi): if fp16_allreduce is supported, this needs to be
# 'FusedMergedGrad.cast_fp16._'
grad_var = var.replace('FusedMergedGrad_', '')
root_id = get_grad_device(grad_var, shard)
assert root_id >= 0, "root id should be a positive int, but now root id is {}".format(
root_id)
if rank is not None and rank == root_id:
......@@ -463,6 +551,94 @@ def insert_reduce_ops(block,
return grad_in_this_device
def insert_fused_broadcast_param_ops(block,
insert_idx,
ring_id,
params,
shard,
op_role=OpRole.Optimize,
use_calc_stream=False,
rank=None,
fuse_size=32):
nranks = shard.worker_num
device_to_vars = [[] for _ in range(nranks)]
for var in params:
root_id = shard.device(var)
assert 0 <= root_id < nranks, "root_id should >=0 and < nranks, " \
"but now nranks={}, the root_id of var={} is {}"\
.format(nranks, var, root_id)
device_to_vars[root_id].append(var)
for root_id, vars_name in enumerate(device_to_vars):
groups = FuseHelper.get_fused_groups(block, vars_name, fuse_size)
fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
block, insert_idx, groups, op_role, prefix="Param")
for fused_var in fused_vars:
block._insert_op_without_sync(
insert_idx + insert_num,
type='c_broadcast',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={
'ring_id': ring_id,
'root': root_id,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
if not use_calc_stream:
block._insert_op_without_sync(
insert_idx + insert_num,
type='c_sync_calc_stream',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={OP_ROLE_KEY: op_role})
return [] if rank is None else device_to_vars[rank]
def insert_broadcast_param_ops(block,
insert_idx,
ring_id,
params,
shard,
op_role=OpRole.Optimize,
use_calc_stream=False,
rank=None,
strategy=None):
"""
add broadcast param ops
"""
if strategy and strategy.fuse_all_reduce_ops:
# TODO(wangxi): put the fused var in startup_program; it only needs to be executed once
return insert_fused_broadcast_param_ops(
block, insert_idx, ring_id, params, shard, op_role, use_calc_stream,
rank, strategy.fuse_grad_size_in_MB)
param_in_this_device = []
for param in params:
root_id = shard.device(param)
assert root_id >= 0, "root id should be a positive int, but now root id is {}".format(
root_id)
if rank is not None and rank == root_id:
param_in_this_device.append(param)
block._insert_op_without_sync(
insert_idx,
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring_id,
'root': root_id,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
return param_in_this_device
def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
......
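A hedged sketch of the grad-name to owning-rank lookup that this assert guards; the real get_grad_device likely also strips merged/cast suffixes such as '@GRAD@MERGED':

def get_grad_device_sketch(grad_name, param2device):
    assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(grad_name)
    base_name = grad_name.split("@GRAD")[0].replace(".cast_fp16", "")
    return param2device[base_name]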
......@@ -15,19 +15,22 @@
import paddle
from paddle.fluid import unique_name, core
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper
from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
from .sharding.offload_helper import OffloadHelper
from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
from paddle.static import default_startup_program, device_guard
from paddle.fluid import layers
from .common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from .common import is_backward_op, is_optimizer_op, is_update_op
from .meta_optimizer_base import MetaOptimizerBase
from .sharding.shard import Shard, ProgramSegment
from .sharding.fp16_helper import FP16Utils
from .sharding.weight_decay_helper import WeightDecayHelper
from .sharding.gradient_clip_helper import GradientClipHelper
from .sharding.offload_helper import OffloadHelper
from .sharding.prune import ProgramDeps
from .sharding import utils
# FIXME: import *
from .sharding.utils import *
import logging
logger = logging.getLogger(__name__)
formatter = logging.Formatter(
......@@ -35,7 +38,6 @@ formatter = logging.Formatter(
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
from functools import reduce
__all__ = []
......@@ -154,10 +156,11 @@ class ShardingOptimizer(MetaOptimizerBase):
def _get_hybrid_dp_mode(self):
""" get
self.hybrid_dp_mode
self.gradient_merge_mode
self.hybrid_dp_mode = 'pp_hybrid_dp' or 'sharding_hybrid_dp'
self.gradient_merge_mode = 'pp_gm' or 'sharding_gm'
self._gradient_merge_acc_step
self.pp_allreduce_in_optimize
self._optimizer_sharding
"""
strategy = self.user_defined_strategy
sharding_configs = strategy.sharding_configs
......@@ -194,9 +197,18 @@ class ShardingOptimizer(MetaOptimizerBase):
logger.info("Gradient merge in [{}], acc step = [{}]".format(
gm_mode, gm_acc_step))
optimizer_sharding = False
# TODO(wangxi): need to support dp_as_opt_sharding together with sharding,
# and to support it without pp in the future
if self.sharding_degree == 1 and self.dp_degree > 1 \
and sharding_configs['_dp_as_optimizer_sharding'] \
and self.pp_degree > 1:
optimizer_sharding = True
self.hybrid_dp_mode = dp_mode
self.gradient_merge_mode = gm_mode
self._gradient_merge_acc_step = gm_acc_step
self._optimizer_sharding = optimizer_sharding
# this feature is designed for Ascend, and should NOT be used in GPU training
self.pp_allreduce_in_optimize = sharding_configs[
......@@ -276,7 +288,8 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_block = self._startup_program.global_block()
# step1: build shard
self._build_shard(params_grads)
self._build_shard(params_grads, self.sharding_rank,
self.sharding_degree)
# step2: split_program
self._split_program(main_block)
......@@ -287,14 +300,35 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_block._sync_with_cpp()
# step4: remove unneeded ops and vars from block
self._prune_main_program(main_block)
self._prune_startup_program(startup_block)
self._prune_main_program(
main_block, self._shard,
[self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id])
self._prune_startup_program(startup_block, self._shard)
def _apply_opt_sharding_pass(self, params_grads):
""" outer dp as optimizer sharding """
if self._optimizer_sharding is False: return
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
def _insert_allreduce_for_pp(self):
# step1: build shard
self._build_shard(params_grads, self.dp_rank, self.dp_degree)
# NOTE(wangxi): prune_main_program will prune the cast ops if this is not added
for param, grad in params_grads:
self._reduced_grads_to_param[grad.name] = param.name
# step4: remove unneeded ops and vars from block
self._prune_main_program(
main_block, self._shard,
[self.mp_ring_id, self.pp_ring_id, self.dp_ring_id])
self._prune_startup_program(startup_block, self._shard)
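A hedged illustration of the ownership pattern this pass sets up and _insert_allreduce_for_pp below consumes: each outer-dp rank reduce-sums only the grads of the params it owns, updates them, then broadcasts the updated params (names below are made up):

params = ["w0", "w1", "w2", "w3"]
dp_degree, rank = 2, 0
owner = {p: i % dp_degree for i, p in enumerate(params)}  # stands in for shard.device(p)
grads_in_this_rank = [p + "@GRAD" for p in params if owner[p] == rank]  # c_reduce_sum targets kept here
broadcast_roots = {p: owner[p] for p in params}  # c_broadcast root per param
assert grads_in_this_rank == ["w0@GRAD", "w2@GRAD"]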
def _insert_allreduce_for_pp(self, params_grads):
if self.pp_degree == 1: return
strategy = self.user_defined_strategy
fp16_allreduce = strategy.fp16_allreduce
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
......@@ -318,10 +352,13 @@ class ShardingOptimizer(MetaOptimizerBase):
if in_name not in main_block.vars:
main_block._remove_op(idx)
if self._optimizer_sharding:
# TODO(wangxi): support fp16_allreduce with optimizer sharding
strategy.fp16_allreduce = False
shard = self._shard if self._optimizer_sharding else None
accumulated_grad_names = self._pp_optimizer._accumulate_gradients(
main_block,
fp16_allreduce=fp16_allreduce,
user_defined_strategy=strategy)
main_block, strategy=strategy, shard=shard)
len_of_ops = len(main_block.ops)
first_optimize_op_index = get_first_optimize_op_idx(main_block)
......@@ -346,7 +383,36 @@ class ShardingOptimizer(MetaOptimizerBase):
first_optimize_op_index += (len(main_block.ops) - len_of_ops)
len_of_ops = len(main_block.ops)
if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
if self._optimizer_sharding:
accumulated_grad_names = utils.insert_reduce_ops(
main_block,
first_optimize_op_index,
self.dp_ring_id,
accumulated_grad_names,
self._shard,
OpRole.Optimize,
use_calc_stream=True,
rank=self.dp_rank,
strategy=strategy)
logger.info("Optimizer grad in this rank {}".format(
accumulated_grad_names))
first_optimize_op_index += (len(main_block.ops) - len_of_ops)
len_of_ops = len(main_block.ops)
optimizer_param = utils.insert_broadcast_param_ops(
main_block,
len_of_ops,
self.dp_ring_id, [x[0].name for x in params_grads],
self._shard,
OpRole.Optimize,
use_calc_stream=True,
rank=self.dp_rank,
strategy=strategy)
logger.info("Optimizer param in this rank {}".format(
optimizer_param))
if not strategy.fuse_grad_merge:
assert len(accumulated_grad_names) == len(optimizer_param)
elif self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
insert_allreduce_ops(
main_block,
first_optimize_op_index,
......@@ -361,9 +427,10 @@ class ShardingOptimizer(MetaOptimizerBase):
# FIXME(wangxi): if fp16_allreduce, should the fp16->fp32 cast be put there?
def _adapt_amp_clip_without_sharding(self):
if self.sharding_degree > 1: return
# if sharding is not used, adapt amp/clip for the remaining parallelism.
# cast --> amp --> clip --> opt
if self.sharding_degree > 1: return
if self._optimizer_sharding: return
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
......@@ -449,7 +516,9 @@ class ShardingOptimizer(MetaOptimizerBase):
self._apply_sharding_pass(params_grads)
self._insert_allreduce_for_pp()
self._apply_opt_sharding_pass(params_grads)
self._insert_allreduce_for_pp(params_grads)
self._adapt_amp_clip_without_sharding()
......@@ -630,11 +699,10 @@ class ShardingOptimizer(MetaOptimizerBase):
startup_block._sync_with_cpp()
def _build_shard(self, params_grads):
def _build_shard(self, params_grads, shard_rank, shard_size):
# step 2: split params
self._params = set([x[0].name for x in params_grads])
self._shard.setup(params_grads, self.sharding_rank,
self.sharding_degree)
self._shard.setup(params_grads, shard_rank, shard_size)
# step 3: get broadcast vars
self._broadcast_vars = self._shard.find_broadcast_params(
......@@ -787,7 +855,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self._segments[idx_]._end_idx].desc.input_arg_names()))
return
def _prune_main_program(self, block):
def _prune_main_program(self, block, shard, rings):
"""
calculate deps from allreduce op to optimize op,
remove ops and vars not needed in this worker
......@@ -799,28 +867,26 @@ class ShardingOptimizer(MetaOptimizerBase):
"""
weightdecay_helper = WeightDecayHelper()
weightdecay_helper.prune_weight_decay(block, self._shard)
weightdecay_helper.prune_weight_decay(block, shard)
# FIXME(wangxi): mp should prune duplicated param_grads
# NOTE(JZ-LIANG) the sync of FoundInfinite should be within one entire Model Parallelism
# group, and each Data Parallelism group should have its own sync of FoundInfinite.
# amp could use the global group for the sync
FP16Utils.prune_fp16(
block, self._shard, self._reduced_grads_to_param,
[self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id])
FP16Utils.prune_fp16(block, shard, self._reduced_grads_to_param, rings)
# clipbyglobalnorm should only use the Model parallelism group (mp-sharding-pp)
gradientclip_helper = GradientClipHelper(None)
gradientclip_helper.prune_gradient_clip(
block, self._shard,
[self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id])
gradientclip_helper.prune_gradient_clip(block, shard, rings)
# build prog deps
reduced_grads = []
for idx, op in enumerate(block.ops):
input_names = op.desc.input_arg_names()
output_names = op.desc.output_arg_names()
if op.type == "c_allreduce_sum":
# FIXME(wangxi): need to use grads; the pipeline grad is @GRAD@MERGE
if op.type == "c_allreduce_sum" and \
op.attr('use_model_parallel') is False:
assert (len(output_names) == 1)
output_name = output_names[0]
reduced_grads.append(output_name)
......@@ -828,8 +894,8 @@ class ShardingOptimizer(MetaOptimizerBase):
# prune optimizer state and param
pruned_opti_vars = []
for var_name in list(block.vars.keys()):
if self._shard.is_opti_var(var_name) and \
not self._shard.has_opt_var(var_name):
if shard.is_opti_var(var_name) and \
not shard.has_opt_var(var_name):
pruned_opti_vars.append(var_name)
program_deps = ProgramDeps(block, reduced_grads, pruned_opti_vars)
......@@ -881,7 +947,9 @@ class ShardingOptimizer(MetaOptimizerBase):
# if all outputs of this op are in _should_removed_var
# _should_removed_var: opt state not cur shard
if program_deps.should_remove_op(idx):
program_deps.remove_op(idx)
# NOTE(wangxi): need to reserve all params in optimizer_sharding
reserved_vars = self._params if self._optimizer_sharding else None
program_deps.remove_op(idx, reserved_vars)
# NOTE (JZ-LIANG) revise and unify logic here
# sharding support fp16_allreduce logic
......@@ -1112,17 +1180,21 @@ class ShardingOptimizer(MetaOptimizerBase):
return
def _prune_startup_program(self, block):
def _prune_startup_program(self, block, shard):
for idx, op in reversed(list(enumerate(block.ops))):
for output_name in op.desc.output_arg_names():
if self._shard.has_var(output_name):
if shard.has_var(output_name):
continue
if self._optimizer_sharding and shard.is_param(output_name):
continue
# TODO: why do we remove the op when only one var is removed?
block._remove_op(idx, sync=False)
break
for var_name in list(block.vars.keys()):
if self._shard.has_var(var_name):
if shard.has_var(var_name):
continue
if self._optimizer_sharding and shard.is_param(var_name):
continue
block._remove_var(var_name, sync=False)
block._sync_with_cpp()
......
......@@ -5049,16 +5049,16 @@ class PipelineOptimizer(object):
def _accumulate_gradients(self,
block,
pp_allreduce_in_optimize=False,
fp16_allreduce=False,
user_defined_strategy=None):
strategy=None,
shard=None):
"""
Create a new merged gradient for each parameter and accumulate the
corresponding gradient to it.
"""
if user_defined_strategy and user_defined_strategy.fuse_grad_merge:
fp16_allreduce = strategy.fp16_allreduce if strategy else False
if strategy and strategy.fuse_grad_merge:
fused_gradient_names = self._accumulate_gradients_with_fuse(
block, fp16_allreduce,
user_defined_strategy.fuse_grad_size_in_MB)
block, fp16_allreduce, strategy.fuse_grad_size_in_MB, shard)
return fused_gradient_names
merged_gradient_names = []
......@@ -5079,8 +5079,8 @@ class PipelineOptimizer(object):
if self._is_backward_op(op) and first_opt_op_idx is None:
first_opt_op_idx = index + 1
# no optimize phase
if first_opt_op_idx == len(block.ops): return
# there may be no optimize phase
# if first_opt_op_idx == len(block.ops): return
if self._is_backward_op(op) and (
self._op_role_var_key in op.attr_names):
......@@ -5190,44 +5190,9 @@ class PipelineOptimizer(object):
return merged_gradient_names
def _accumulate_gradients_with_fuse(self, main_block, fp16, fused_size):
first_opt_op_idx = None
grad_param_pairs = []
# obtain all param/grad pairs that need to be fused
for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
# remove the cast op of fp16 grad to fp32 grad
if self._is_optimize_op(op) and op.type == 'cast':
in_name = op.input_arg_names[0]
out_name = op.output_arg_names[0]
if out_name.strip('@GRAD') in self._param_device_map:
assert in_name.replace('.cast_fp16', '') == out_name
main_block._remove_op(index)
continue
if self._is_backward_op(op) and first_opt_op_idx is None:
first_opt_op_idx = index + 1
# no optimize phase
if first_opt_op_idx == len(main_block.ops):
return
if self._is_backward_op(op) and (
self._op_role_var_key in op.attr_names):
op_role_var = op.attr(self._op_role_var_key)
if len(op_role_var) == 0:
continue
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param_name = op_role_var[i]
if not main_block.has_var(param_name):
continue
if '@BroadCast' in param_name:
continue
grad_param_pairs.append(
(op_role_var[i + 1], op_role_var[i]))
if len(grad_param_pairs) == 0:
return
def _insert_accumulate_gradients_with_fuse(self, main_block, fp16,
fused_size, grad_param_pairs,
first_opt_op_idx):
grad_param_pairs = self._sort_grad_param_by_dtype(main_block,
grad_param_pairs)
......@@ -5426,9 +5391,66 @@ class PipelineOptimizer(object):
for i in range(len(fused_merged_gradients)):
fused_merged_gradients[i] = fused_merged_gradients[i].name
main_block._sync_with_cpp()
return fused_merged_gradients, first_opt_op_idx
return fused_merged_gradients
def _accumulate_gradients_with_fuse(self,
main_block,
fp16,
fused_size,
shard=None):
first_opt_op_idx = None
grad_param_pairs = []
# obtain all param/grad pairs that need to be fused
for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
# remove the cast op of fp16 grad to fp32 grad
if self._is_optimize_op(op) and op.type == 'cast':
in_name = op.input_arg_names[0]
out_name = op.output_arg_names[0]
if out_name.strip('@GRAD') in self._param_device_map:
assert in_name.replace('.cast_fp16', '') == out_name
main_block._remove_op(index)
continue
if self._is_backward_op(op) and first_opt_op_idx is None:
first_opt_op_idx = index + 1
# no optimize phase
if first_opt_op_idx == len(main_block.ops):
return
if self._is_backward_op(op) and (
self._op_role_var_key in op.attr_names):
op_role_var = op.attr(self._op_role_var_key)
if len(op_role_var) == 0:
continue
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param_name = op_role_var[i]
if not main_block.has_var(param_name):
continue
if '@BroadCast' in param_name:
continue
grad_param_pairs.append(
(op_role_var[i + 1], op_role_var[i]))
if len(grad_param_pairs) == 0:
return
nranks = shard.worker_num if shard else 1
device_to_pairs = [[] for _ in range(nranks)]
for pair in grad_param_pairs:
root_id = shard.device(pair[1]) if shard else 0
assert 0 <= root_id < nranks
device_to_pairs[root_id].append(pair)
all_fused_merged_gradients = []
for pairs in device_to_pairs:
fused_merged_gradients, first_opt_op_idx = \
self._insert_accumulate_gradients_with_fuse(
main_block, fp16, fused_size, pairs, first_opt_op_idx)
all_fused_merged_gradients += fused_merged_gradients
main_block._sync_with_cpp()
return all_fused_merged_gradients
def _sort_grad_param_by_dtype(self, main_block, grad_param_pairs):
# sort the grad param pairs by dtype
......
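A hedged usage sketch of the revised _accumulate_gradients signature, mirroring the call site in sharding_optimizer.py; the old fp16_allreduce / user_defined_strategy keywords are folded into strategy:

shard = self._shard if self._optimizer_sharding else None
accumulated_grad_names = self._pp_optimizer._accumulate_gradients(
    main_block, strategy=strategy, shard=shard)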
......@@ -70,6 +70,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_amp_init)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_gradient_merge_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_sharding_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_hybrid_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_localsgd_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_lars_meta_optimizer)
list(APPEND MIXED_DIST_TEST_OPS test_fleet_lamb_meta_optimizer)
......@@ -568,6 +569,7 @@ if(WITH_DISTRIBUTE)
py_test_modules(test_fleet_graph_executor MODULES test_fleet_graph_executor ENVS ${dist_ENVS})
py_test_modules(test_fleet_gradient_merge_meta_optimizer MODULES test_fleet_gradient_merge_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_sharding_meta_optimizer MODULES test_fleet_sharding_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_hybrid_meta_optimizer MODULES test_fleet_hybrid_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_amp_meta_optimizer MODULES test_fleet_amp_meta_optimizer ENVS ${dist_ENVS})
py_test_modules(test_fleet_amp_init MODULES test_fleet_amp_init ENVS ${dist_ENVS})
py_test_modules(test_fleet_fp16_allreduce_meta_optimizer MODULES test_fleet_fp16_allreduce_meta_optimizer ENVS ${dist_ENVS})
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import paddle
from paddle import fluid
......@@ -25,6 +26,23 @@ class TestFleetMetaOptimizer(unittest.TestCase):
os.environ["PADDLE_TRAINER_ID"] = "1"
os.environ[
"PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.1:36002"
self._debug = False
def debug_program(self, main_prog, startup_prog):
if not self._debug: return
main_prog_ops = main_prog.global_block().ops
startup_prog_ops = startup_prog.global_block().ops
main_prog_op_types = [op.type for op in main_prog_ops]
startup_prog_op_types = [op.type for op in startup_prog_ops]
print("=== debug program and ops in func [{}] ==="
.format(inspect.stack()[1].function))
print(main_prog)
print(main_prog_op_types)
print(startup_prog)
print(startup_prog_op_types)
def net(self, main_prog, startup_prog):
with fluid.program_guard(main_prog, startup_prog):
......@@ -82,6 +100,20 @@ class TestFleetMetaOptimizer(unittest.TestCase):
strategy = paddle.distributed.fleet.DistributedStrategy()
return avg_cost, strategy
def boundary_net(self, main_prog, startup_prog):
with fluid.program_guard(main_prog, startup_prog):
fleet.init(is_collective=True)
x = paddle.static.data(name='x', shape=[-1, 4], dtype='float32')
with paddle.static.device_guard('gpu:0'):
linear = fluid.Linear(4, 8, bias_attr=False)
out = linear(x)
with paddle.static.device_guard('gpu:1'):
linear = fluid.Linear(8, 5, bias_attr=False)
out = linear(out)
avg_cost = paddle.mean(out)
strategy = fleet.DistributedStrategy()
return avg_cost, strategy
def optimizer(self,
loss,
strategy,
......@@ -190,5 +222,12 @@ class TestFleetMetaOptimizer(unittest.TestCase):
"enable_offload": True,
"checkpoint_shape": [256]
}
elif name == "pipeline":
strategy.pipeline = True
strategy.pipeline_configs = {
"schedule_mode": "1F1B",
"micro_batch_size": 2,
"accumulate_steps": 4,
}
else:
raise NotImplementedError()
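A hedged usage sketch of the new test helpers in a derived test (set_strategy/optimizer signatures assumed from this base class):

train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program()
avg_cost, strategy = self.boundary_net(train_prog, startup_prog)
self.set_strategy(strategy, 'pipeline')
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
self.debug_program(train_prog, startup_prog)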
......@@ -612,7 +612,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
for op in main_prog_ops:
if op.type == 'c_allreduce_sum':
assert 'FusedOutput' in op.input_arg_names[0]
assert 'FusedGrad' in op.input_arg_names[0]
def test_hybrid_with_mp_pp_amp_gclip(self):
train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
......