diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 5e2c3394520bcc86be85adf00db043ec4106906a..81b355d24ba889806e445ff8036940d314b4ba3f 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -729,6 +729,35 @@ class Variable(object):
         """
         self.error_clip = error_clip
 
+    def _set_info(self, key, value):
+        """
+        Set a key-value information entry associated with this variable.
+
+        Args:
+            key(str): The key of the information entry.
+            value(object): The value of the information entry.
+
+        Returns:
+            None
+        """
+        if not hasattr(self, "_info"):
+            self._info = {}
+        self._info[key] = value
+
+    def _get_info(self, key):
+        """
+        Get the information entry associated with this variable.
+
+        Args:
+            key(str): The key of the information entry.
+
+        Returns:
+            object: The stored value, or None if the key does not exist.
+        """
+        if hasattr(self, "_info") and key in self._info:
+            return self._info[key]
+        return None
+
     def _slice_indices(self, slice, length):
         """
         Reference implementation for the slice.indices method.
diff --git a/python/paddle/fluid/layers/collective.py b/python/paddle/fluid/layers/collective.py
index 9e96624cf7c70f24a7f65a91c7ee41af45ddeb6c..712c0a7cde3845d81890e068fb426aa892b987c7 100644
--- a/python/paddle/fluid/layers/collective.py
+++ b/python/paddle/fluid/layers/collective.py
@@ -13,8 +13,14 @@
 # limitations under the License.
 
 from __future__ import print_function
+import math
+from functools import reduce
+
 from ..layer_helper import LayerHelper, unique_name
-from ..framework import Variable
+from ..framework import Variable, default_startup_program
+from ..param_attr import ParamAttr
+from ..initializer import Normal, Constant
+from . import nn, ops
 
 
 def _allreduce(x, out=None, reduce_type="sum", sync_mode=False):
@@ -178,3 +183,251 @@ def _c_sync_comm_stream(x, ring_id):
         outputs={'Out': [x]},
         attrs={'ring_id': ring_id})
     return x
+
+
+class DistributedClassifier(object):
+    '''
+    Toolkit for distributed classification, in which the parameter of the last
+    fully-connected layer is sharded across all trainers.
+    '''
+
+    def __init__(self, nclasses, nranks, rank_id, layer_helper):
+        self.nclasses = nclasses
+        self.nranks = nranks
+        self.rank_id = rank_id
+        self._layer_helper = layer_helper
+
+        self.shard_dim = (nclasses + nranks - 1) // nranks
+        self.padding_dim = 0
+        self.is_equal_division = True
+        if nclasses % nranks != 0:
+            self.is_equal_division = False
+            if rank_id == nranks - 1:
+                other_shard_dim = self.shard_dim
+                self.shard_dim = nclasses % other_shard_dim
+                self.padding_dim = other_shard_dim - self.shard_dim
+
+    def create_parameter(self,
+                         dtype,
+                         in_dim,
+                         param_attr=None,
+                         transpose_weight=False,
+                         use_bias=True):
+        if param_attr is None:
+            stdv = math.sqrt(2.0 / (in_dim + self.nclasses))
+            param_attr = ParamAttr(initializer=Normal(scale=stdv))
+        weight_shape = [self.shard_dim, in_dim
+                        ] if transpose_weight else [in_dim, self.shard_dim]
+        weight = self._layer_helper.create_parameter(
+            shape=weight_shape, dtype=dtype, attr=param_attr, is_bias=False)
+        # mark as distributed so its gradient is not allreduced across trainers
+        weight.is_distributed = True
+        # also mark it in the startup program to skip parameter broadcasting
+        default_startup_program().global_block().vars[
+            weight.name].is_distributed = True
+
+        bias = None
+        if use_bias:
+            bias = self._layer_helper.create_parameter(
+                shape=[self.shard_dim],
+                attr=ParamAttr(),
+                dtype=dtype,
+                is_bias=True)
+            bias.is_distributed = True
+            default_startup_program().global_block().vars[
+                bias.name].is_distributed = True
+        return weight, bias
+
+    def softmax_with_cross_entropy(self, shard_logit,
+                                   shard_label):
+        # make the softmax numerically stable by subtracting the global max
+        # of the logits, obtained with a max-allreduce over all shards
+        shard_max = nn.reduce_max(shard_logit, dim=1, keep_dim=True)
+        global_max = _c_allreduce(
+            shard_max, reduce_type='max', use_calc_stream=True)
+        shard_logit_new = nn.elementwise_sub(shard_logit, global_max)
+
+        # the softmax denominator is the sum of exp over all shards
+        shard_exp = ops.exp(shard_logit_new)
+        shard_demon = nn.reduce_sum(shard_exp, dim=1, keep_dim=True)
+        global_demon = _c_allreduce(
+            shard_demon, reduce_type='sum', use_calc_stream=True)
+
+        global_log_demon = nn.log(global_demon)
+        shard_log_prob = shard_logit_new - global_log_demon
+        shard_prob = ops.exp(shard_log_prob)
+
+        # pick the log-probability of the target class; labels that do not
+        # fall into this shard yield an all-zero one-hot row and a zero loss
+        shard_one_hot = nn.one_hot(
+            shard_label, depth=self.shard_dim, allow_out_of_range=True)
+        target_log_prob = nn.reduce_min(
+            shard_log_prob * shard_one_hot, dim=1, keep_dim=True)
+        shard_loss = nn.scale(target_log_prob, scale=-1.0)
+        global_loss = _c_reducescatter(
+            shard_loss, nranks=self.nranks, use_calc_stream=True)
+        return global_loss, shard_prob
+
+    def fc_classify(self, x, label, param_attr=None, use_bias=True):
+        flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1)
+        weight, bias = self.create_parameter(
+            dtype=x.dtype,
+            in_dim=flatten_dim,
+            param_attr=param_attr,
+            use_bias=use_bias)
+
+        # gather features and labels from all trainers, so that each trainer
+        # computes the logits of its own shard of classes for the full batch
+        x_all = _c_allgather(x, nranks=self.nranks, use_calc_stream=True)
+        label_all = _c_allgather(
+            label, nranks=self.nranks, use_calc_stream=True)
+        label_all.stop_gradient = True
+
+        shard_fc = nn.mul(x_all, weight, x_num_col_dims=1)
+        if use_bias:
+            shard_fc = nn.elementwise_add(shard_fc, bias)
+
+        # map global class labels to shard-local indices; labels outside this
+        # shard are replaced by the ignore_value
+        shard_label = nn.shard_index(
+            label_all,
+            index_num=self.nclasses,
+            nshards=self.nranks,
+            shard_id=self.rank_id,
+            ignore_value=-1)
+        shard_label.stop_gradient = True
+
+        global_loss, shard_prob = self.softmax_with_cross_entropy(shard_fc,
+                                                                  shard_label)
+        avg_loss = nn.mean(global_loss)
+
+        avg_loss._set_info('shard_logit', shard_fc)
+        avg_loss._set_info('shard_prob', shard_prob)
+        avg_loss._set_info('shard_label', shard_label)
+        avg_loss._set_info('shard_dim', self.shard_dim)
+
+        return avg_loss
+
+    def arcmargin_classify(self,
+                           x,
+                           label,
+                           margin=0.5,
+                           logit_scale=64,
+                           param_attr=None):
+        '''
+        reference: ArcFace. https://arxiv.org/abs/1801.07698
+        '''
+        flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1)
+        weight, bias = self.create_parameter(
+            dtype=x.dtype,
+            in_dim=flatten_dim,
+            param_attr=param_attr,
+            transpose_weight=True,
+            use_bias=False)
+
+        # normalize x
+        x_l2 = ops.sqrt(nn.reduce_sum(nn.square(x), dim=1))
+        norm_x = nn.elementwise_div(x, x_l2, axis=0)
+
+        norm_x_all = _c_allgather(
+            norm_x, nranks=self.nranks, use_calc_stream=True)
+        label_all = _c_allgather(
+            label, nranks=self.nranks, use_calc_stream=True)
+        label_all.stop_gradient = True
+        shard_label = nn.shard_index(
+            label_all,
+            index_num=self.nclasses,
+            nshards=self.nranks,
+            shard_id=self.rank_id,
+            ignore_value=-1)
+        shard_label.stop_gradient = True
+
+        # normalize weight
+        weight_l2 = ops.sqrt(nn.reduce_sum(nn.square(weight), dim=1))
+        norm_weight = nn.elementwise_div(weight, weight_l2, axis=0)
+        norm_weight = nn.transpose(norm_weight, perm=[1, 0])
+
+        shard_cos = nn.mul(norm_x_all, norm_weight, x_num_col_dims=1)
+
+        theta = ops.acos(shard_cos)
+        margin_cos = ops.cos(theta + margin)
+
+        shard_one_hot = nn.one_hot(
+            shard_label, depth=self.shard_dim, allow_out_of_range=True)
+        shard_one_hot.stop_gradient = True
+
+        # replace cos(theta) with cos(theta + margin) for the target class only
+        diff = (margin_cos - shard_cos) * shard_one_hot
+        shard_target_cos = shard_cos + diff
+        shard_logit = nn.scale(shard_target_cos, scale=logit_scale)
+
+        global_loss, shard_prob = self.softmax_with_cross_entropy(shard_logit,
+                                                                  shard_label)
+        avg_loss = nn.mean(global_loss)
+
+        avg_loss._set_info('shard_logit', shard_logit)
+        avg_loss._set_info('shard_prob', shard_prob)
+        avg_loss._set_info('shard_label', shard_label)
+        avg_loss._set_info('shard_dim', self.shard_dim)
+
+        return avg_loss
+
+
+def distributed_fc_classify(x,
+                            label,
+                            class_num,
+                            nranks,
+                            rank_id,
+                            param_attr=None,
+                            use_bias=True,
+                            name='dist_fc'):
+    '''
+    Distributed fully-connected classification layer whose weight is sharded
+    across all trainers, trained with softmax-with-cross-entropy loss.
+    '''
+    helper = LayerHelper(name, **locals())
+    classifier = DistributedClassifier(class_num, nranks, rank_id, helper)
+    return classifier.fc_classify(x, label, param_attr, use_bias)
+
+
+def distributed_arcmargin_classify(x,
+                                   label,
+                                   class_num,
+                                   nranks,
+                                   rank_id,
+                                   margin=0.5,
+                                   logit_scale=64,
+                                   param_attr=None,
+                                   name='dist_fc'):
+    '''
+    Distributed classification layer with the additive angular margin
+    (ArcFace) loss, with the last fully-connected weight sharded across
+    all trainers.
+    '''
+    helper = LayerHelper(name, **locals())
+    classifier = DistributedClassifier(class_num, nranks, rank_id, helper)
+    return classifier.arcmargin_classify(
+        x=x,
+        label=label,
+        margin=margin,
+        logit_scale=logit_scale,
+        param_attr=param_attr)
+
+
+def distributed_fc(x,
+                   out_dim,
+                   nranks,
+                   rank_id,
+                   param_attr=None,
+                   use_bias=True,
+                   name='dist_fc'):
+    '''
+    Distributed fully-connected layer with its weight sharded across all
+    trainers. Not implemented yet.
+    '''
+    helper = LayerHelper(name, **locals())
+    classifier = DistributedClassifier(out_dim, nranks, rank_id, helper)
+    weight, bias = classifier.create_parameter(
+        dtype=x.dtype,
+        in_dim=x.shape[-1],
+        param_attr=param_attr,
+        use_bias=use_bias)
+    x_all = _c_allgather(x, nranks=nranks, use_calc_stream=True)
+
+    shard_fc = nn.mul(x_all, weight)
+    if use_bias:
+        shard_fc = nn.elementwise_add(shard_fc, bias)
+
+    # sample code
+    #if not classifier.is_equal_division:
+    #    shard_fc = nn.pad(shard_fc)
+    #fc = _c_slice_allgather(shard_fc,
+    #                        nranks=nranks,
+    #                        rank_id=rank_id)
+    #if not classifier.is_equal_division:
+    #    fc = nn.depad(fc)
+    #return fc
+    raise NotImplementedError('distributed_fc')
diff --git a/python/paddle/fluid/transpiler/collective.py b/python/paddle/fluid/transpiler/collective.py
index 6b5131e58c6d8ea3e2fd15b75c8ebd9169e21ae1..e8d9dad8c4008f09a6edc44a970781b47db99ba6 100644
--- a/python/paddle/fluid/transpiler/collective.py
+++ b/python/paddle/fluid/transpiler/collective.py
@@ -26,9 +26,11 @@
 import numpy as np
 
 from .. import core, unique_name
 from ..framework import Program, default_main_program, default_startup_program
+from ..backward import _append_grad_suffix_
 from .details import wait_server_ready
+from .. import layers
 
-__all__ = ['GradAllReduce', 'LocalSGD']
+__all__ = ['GradAllReduce', 'LocalSGD', 'DistributedClassificationOptimizer']
 
 OpRole = core.op_proto_and_checker_maker.OpRole
@@ -168,7 +170,7 @@ class Collective(object):
 
     def _is_update_op(self, op):
         return 'Param' in op.input_names and 'Grad' in op.input_names and \
-            "LearningRate" in op.input_names
+            'LearningRate' in op.input_names
 
     def _is_optimizer_op(self, op):
         return self.op_role_key in op.attr_names and \
@@ -370,3 +372,81 @@ class LocalSGD(Collective):
             inputs={'X': [param]},
             outputs={'Out': [snapshot]},
             attrs={self.op_role_key: OpRole.Optimize})
+
+
+class DistributedClassificationOptimizer(object):
+    '''
+    An optimizer wrapper for the distributed classification layers. Instead of
+    relying on the auto-generated backward pass, it inserts the analytic
+    gradient of softmax-with-cross-entropy with respect to the sharded logits,
+    i.e. (shard_prob - one_hot(shard_label)) / batch_size.
+    '''
+
+    def __init__(self, optimizer, batch_size):
+        self._optimizer = optimizer
+        self._batch_size = batch_size
+
+    def minimize(self, loss):
+        # TODO: use paddle enforce
+        assert loss._get_info('shard_logit')
+
+        shard_logit = loss._get_info('shard_logit')
+        shard_prob = loss._get_info('shard_prob')
+        shard_label = loss._get_info('shard_label')
+        shard_dim = loss._get_info('shard_dim')
+
+        op_maker = core.op_proto_and_checker_maker
+        op_role_key = op_maker.kOpRoleAttrName()
+        op_role_var_key = op_maker.kOpRoleVarAttrName()
+        backward_role = int(op_maker.OpRole.Backward)
+        loss_backward_role = int(op_maker.OpRole.Loss) | int(
+            op_maker.OpRole.Backward)
+
+        # minimize a reduce_sum surrogate of the sharded logits so that the
+        # optimizer generates the backward network and the update ops
+        scalar = layers.reduce_sum(shard_logit)
+        ret = self._optimizer.minimize(scalar)
+
+        block = loss.block
+        # locate and remove the auto-generated backward ops of the surrogate
+        index = 0
+        for i, op in enumerate(block.ops):
+            if op.all_attrs()[op_role_key] == loss_backward_role:
+                index = i
+                break
+
+        # TODO: use paddle enforce
+        assert block.ops[index - 1].type == 'reduce_sum'
+        assert block.ops[index].type == 'fill_constant'
+        assert block.ops[index + 1].type == 'reduce_sum_grad'
+        block._remove_op(index + 1)
+        block._remove_op(index)
+        block._remove_op(index - 1)
+
+        # insert the analytically calculated gradient of the logits:
+        # grad(shard_logit) = (shard_prob - one_hot(shard_label)) / batch_size
+        dtype = shard_logit.dtype
+        shard_one_hot = layers.create_tensor(dtype, name='shard_one_hot')
+        block._insert_op(
+            index - 1,
+            type='one_hot',
+            inputs={'X': shard_label},
+            outputs={'Out': shard_one_hot},
+            attrs={
+                'depth': shard_dim,
+                'allow_out_of_range': True,
+                op_role_key: backward_role
+            })
+        shard_logit_grad = layers.create_tensor(
+            dtype, name=_append_grad_suffix_(shard_logit.name))
+        block._insert_op(
+            index,
+            type='elementwise_sub',
+            inputs={'X': shard_prob,
+                    'Y': shard_one_hot},
+            outputs={'Out': shard_logit_grad},
+            attrs={op_role_key: backward_role})
+        block._insert_op(
+            index + 1,
+            type='scale',
+            inputs={'X': shard_logit_grad},
+            outputs={'Out': shard_logit_grad},
+            attrs={
+                'scale': 1.0 / self._batch_size,
+                op_role_key: loss_backward_role
+            })
+        return ret
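Usage note (not part of the patch): a minimal sketch of how the new pieces are intended to fit together, assuming the module paths introduced above. The rank layout, batch size, class count, and feature size are placeholder values, and collective-communication setup (NCCL initialization, the collective transpiler for the rest of the network) is omitted.

```python
import paddle.fluid as fluid
from paddle.fluid.layers import collective
from paddle.fluid.transpiler.collective import DistributedClassificationOptimizer

NRANKS, RANK_ID = 8, 0   # placeholder trainer count and this trainer's rank
BATCH_SIZE = 32          # placeholder batch size used to scale the inserted gradient
NUM_CLASSES = 1000000    # placeholder class count, sharded across trainers

feat = fluid.layers.data(name='feat', shape=[512], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

# sharded FC + ArcFace margin + distributed softmax-with-cross-entropy
avg_loss = collective.distributed_arcmargin_classify(
    x=feat,
    label=label,
    class_num=NUM_CLASSES,
    nranks=NRANKS,
    rank_id=RANK_ID,
    margin=0.5,
    logit_scale=64)

# wrap the optimizer so that minimize() replaces the auto-generated backward
# ops of the loss with the analytic softmax-with-cross-entropy gradient
optimizer = fluid.optimizer.SGD(learning_rate=0.1)
optimizer = DistributedClassificationOptimizer(optimizer, BATCH_SIZE)
optimizer.minimize(avg_loss)
```

The patch does not pin down whether `BATCH_SIZE` should be the per-trainer or the post-allgather global batch size, so treat that argument as a placeholder to be confirmed against the intended gradient scaling.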