提交 39f08eb7 编写于 作者: Z Ziyan

enable optimizer parallel

上级 7976d775
......@@ -100,7 +100,10 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic
auto parallel_context = parallel::ParallelContext::GetInstance();
MS_EXCEPTION_IF_NULL(parallel_context);
const auto &split_indices = parallel_context->GetAllReduceFusionSplitIndices(group);
std::vector<uint32_t> split_indices;
if (!parallel_context->enable_parallel_optimizer()) {
split_indices = parallel_context->GetAllReduceFusionSplitIndices(group);
}
size_t segments = 0;
if (split_indices.size() != 0) {
......
......@@ -443,7 +443,7 @@ def _context():
@args_type_check(device_num=int, global_rank=int, mirror_mean=bool, cast_before_mirror=bool, parallel_mode=str,
auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str,
strategy_ckpt_save_file=str, full_batch=bool)
strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool)
def set_auto_parallel_context(**kwargs):
"""
Set auto parallel context.
......@@ -487,6 +487,9 @@ def set_auto_parallel_context(**kwargs):
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in
data parallel training in the benefit of time and memory saving.
Raises:
ValueError: If input key is not attribute in auto parallel context.
......@@ -532,6 +535,7 @@ def reset_auto_parallel_context():
- parameter_broadcast: False.
- strategy_ckpt_load_file: "".
- strategy_ckpt_save_file: "".
- enable_parallel_optimizer: False.
"""
_reset_auto_parallel_context()
......
......@@ -28,8 +28,8 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from mindspore import log as logger
from mindspore.parallel._utils import _get_global_rank, _get_device_num, _get_parallel_mode
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.train.parallel_utils import ParallelMode
from mindspore import context
__all__ = ['Optimizer']
......@@ -157,13 +157,12 @@ class Optimizer(Cell):
self.param_length = len(self.parameters)
self.map_ = C.Map()
use_parallel = auto_parallel_context().get_enable_parallel_optimizer()
use_parallel = context.get_auto_parallel_context("enable_parallel_optimizer")
self.use_parallel = use_parallel
if use_parallel:
if self.cls_name not in ["Lamb", "AdamWeightDecayDynamicLR", "AdamWeightDecay"]:
raise RuntimeError("Optimizer segmentation does not support optimizer {}".format(self.cls_name))
if _get_parallel_mode() not in [ParallelMode.HYBRID_PARALLEL, ParallelMode.DATA_PARALLEL,
ParallelMode.AUTO_PARALLEL]:
if _get_parallel_mode() != ParallelMode.DATA_PARALLEL:
raise RuntimeError("Optimizer segmentation does not support parallel mode {}".format
(_get_parallel_mode()))
self.dev_num = _get_device_num()
......@@ -175,6 +174,7 @@ class Optimizer(Cell):
self.param_names = []
for param in self.parameters:
self.param_names.append(param.name)
else:
self.optim_filter = (True,) * self.param_length
......
......@@ -13,107 +13,95 @@
# limitations under the License.
# ============================================================================
"""grad reducer cell for distributed training"""
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp, AllGather
from mindspore.ops.operations.comm_ops import AllReduce, AllGather
from mindspore.parallel._auto_parallel_context import auto_parallel_context
import mindspore.common.dtype as mstype
reduce_opt = C.MultitypeFuncGraph("reduce_opt")
_all_reduce = AllReduce()
_all_gather = None
def _init_optimizer_communication():
global _all_reduce
global _all_gather
_all_reduce = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
_all_reduce.add_prim_attr('fusion', 1)
_all_gather = AllGather(GlobalComm.WORLD_COMM_GROUP)
@reduce_opt.register("Function", "Number", "Bool", "Tensor")
def _tensors_allreduce_mean(mul, degree, allreduce_filter, grad):
def _init_allreduce_operators(length):
""" initialize allreduce communication operators"""
is_parallel_optimizer = context.get_auto_parallel_context("enable_parallel_optimizer")
split_indices = auto_parallel_context().get_all_reduce_fusion_split_indices()
if is_parallel_optimizer and split_indices:
group = 1
fusion = ()
for i in range(length):
fusion = fusion + (group,)
if split_indices[group - 1] <= i + 1:
if group >= len(split_indices):
continue
group = group + 1
index = tuple(range(1, length + 1))
else:
fusion = (1,) * length
index = (0,) * length
opt_list = ()
for i in range(length):
opt = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
opt.add_prim_attr('fusion', fusion[i])
opt.add_prim_attr('index', index[i])
opt_list = opt_list + (opt,)
return opt_list
@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tensor", "Function")
def _tensors_allreduce(degree, mean, allgather, allreduce_filter, grad, allreduce):
"""
Apply mean and allreduce on gradient. Allreduce is a communication operation used for distributed deep learning.
Apply allreduce on gradient.
Args:
mul (Primitive): Div operation.
degree (int): The mean coefficient.
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
allgather (Primitive): The communication operator for sparse gradients.
allreduce_filter (bool): When it is true, allreduce would apply.
grad (Tensor): The gradient tensor before operation.
allreduce (Primitive): The communication operator for gradients.
Returns:
Tensor, the gradient tensor after operation.
"""
if allreduce_filter:
degree = F.scalar_cast(degree, F.dtype(grad))
grad = _all_reduce(grad)
cast_op = P.Cast()
return mul(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad)))
grad = allreduce(grad)
if mean:
degree = F.scalar_cast(degree, F.dtype(grad))
cast_op = P.Cast()
mul_op = P.Mul()
grad = mul_op(grad, cast_op(F.scalar_to_array(1.0/degree), F.dtype(grad)))
return grad
return grad
@reduce_opt.register("Function", "Number", "Bool", "Tuple")
def _tensors_allreduce_mean_with_sparse(mul, degree, allreduce_filter, grad):
@reduce_opt.register("Number", "Bool", "Function", "Bool", "Tuple", "Function")
def _tensors_allreduce_with_sparse(degree, mean, allgather, allreduce_filter, grad, allreduce):
"""
Apply mean and allgather on gradient instead of allreduce for sparse feature.
Apply allgather on gradient instead of allreduce for sparse feature.
Allgather is a communication operation used for distributed deep learning.
Args:
mul (Primitive): Div operation.
degree (int): The mean coefficient.
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients.
allgather (Primitive): The communication operator for sparse gradients.
allreduce_filter (bool): When it is true, allgather would apply.
grad (Tuple): The indices, gradient tensor and tensor_shape before operation.
grad (tuple): The indices, gradient tensor and tensor_shape before operation.
allreduce (Primitive): The communication operator for gradients.
Returns:
Tuple, include indices, the gradient tensor and tensor_shape after operation.
"""
if allreduce_filter:
indices = _all_gather(grad[0])
degree = F.scalar_cast(degree, F.dtype(grad[1]))
dout = _all_gather(grad[1])
cast_op = P.Cast()
dout = mul(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
grad = (indices, dout, grad[2])
return grad
@reduce_opt.register("Bool", "Tensor")
def _tensors_allreduce(allreduce_filter, grad):
"""
Apply allreduce on gradient.
Args:
allreduce_filter (bool): When it is true, allreduce would apply.
grad (Tensor): The gradient tensor before operation.
Returns:
Tensor, the gradient tensor after operation.
"""
if allreduce_filter:
return _all_reduce(grad)
return grad
@reduce_opt.register("Bool", "Tuple")
def _tensors_allreduce_with_sparse(allreduce_filter, grad):
"""
Apply mean and allgather on gradient instead of allreduce for sparse feature.
Allgather is a communication operation used for distributed deep learning.
Args:
allreduce_filter (bool): When it is true, allgather would apply.
grad (Tuple): The indices, gradient tensor and tensor_shape before operation.
Returns:
Tuple, include indices, the gradient tensor and tensor_shape after operation.
"""
if allreduce_filter:
indices = _all_gather(grad[0])
dout = _all_gather(grad[1])
indices = allgather(grad[0])
dout = allgather(grad[1])
if mean:
degree = F.scalar_cast(degree, F.dtype(grad[1]))
cast_op = P.Cast()
mul_op = P.Mul()
dout = mul_op(dout, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(dout)))
grad = (indices, dout, grad[2])
return grad
......@@ -259,7 +247,6 @@ class DistributedGradReducer(Cell):
def __init__(self, parameters, mean=True, degree=None):
super(DistributedGradReducer, self).__init__(auto_prefix=False)
self.map_ = C.Map()
self.mul = P.Mul()
if degree is None:
self.degree = get_group_size()
else:
......@@ -268,7 +255,8 @@ class DistributedGradReducer(Cell):
self.degree = degree
self.mean = mean
self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
_init_optimizer_communication()
self.opt_list = _init_allreduce_operators(len(parameters))
self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP)
def construct(self, grads):
"""
......@@ -284,11 +272,8 @@ class DistributedGradReducer(Cell):
"""
datatypes = self.map_(F.partial(_get_datatype), grads)
grads = self.map_(F.partial(_cast_datatype, mstype.float32), grads)
if self.mean:
new_grad = self.map_(F.partial(reduce_opt, self.mul, self.degree), self.allreduce_filter, grads)
else:
new_grad = self.map_(F.partial(reduce_opt), self.allreduce_filter, grads)
new_grad = self.map_(F.partial(reduce_opt, self.degree, self.mean, self.allgather),
self.allreduce_filter, grads, self.opt_list)
new_grad = self.map_(F.partial(_cast_datatype), datatypes, new_grad)
return new_grad
......@@ -513,7 +513,7 @@ def _set_auto_parallel_context(**kwargs):
strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer (bool): Enable using optimizer segmentation or noe. Default: False.
enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
Raises:
ValueError: If input key is not attribute in auto parallel context.
......
......@@ -22,7 +22,6 @@ from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam, AdamWeightDecay, AdamWeightDecayDynamicLR, Lamb
from mindspore.ops import operations as P
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore import context
......@@ -54,8 +53,7 @@ class Net(nn.Cell):
def test_AdamWeightDecayDynamicLR():
""" test_AdamWeightDecayDynamicLR """
auto_parallel_context().set_enable_parallel_optimizer(True)
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
label = Tensor(np.zeros([32, 768]).astype(np.float32))
net = Net()
......@@ -70,8 +68,7 @@ def test_AdamWeightDecayDynamicLR():
def test_AdamWeightDecay():
""" test_AdamWeightDecayDynamicLR """
auto_parallel_context().set_enable_parallel_optimizer(True)
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2)
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
label = Tensor(np.zeros([32, 768]).astype(np.float32))
net = Net()
......@@ -86,8 +83,7 @@ def test_AdamWeightDecay():
def test_lamb_compile():
""" test_Lamb_compile """
auto_parallel_context().set_enable_parallel_optimizer(True)
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=2)
context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True)
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
label = Tensor(np.zeros([32, 768]).astype(np.float32))
net = Net()
......@@ -102,7 +98,7 @@ def test_lamb_compile():
def test_edge_case():
""" test_edge_case """
auto_parallel_context().set_enable_parallel_optimizer(True)
context.set_auto_parallel_context(enable_parallel_optimizer=True)
net = Net()
with pytest.raises(RuntimeError):
context.set_auto_parallel_context(parallel_mode="stand_alone")
......
......@@ -81,8 +81,8 @@ def test_set_auto_parallel_context():
with pytest.raises(ValueError):
set_algo_parameters(tensor_slice_align_size=1025)
auto_parallel_context().set_enable_parallel_optimizer(True)
assert auto_parallel_context().get_enable_parallel_optimizer() is True
context.set_auto_parallel_context(enable_parallel_optimizer=True)
assert context.get_auto_parallel_context("enable_parallel_optimizer")
assert not auto_parallel_context().get_all_reduce_fusion_split_indices()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册