Commit 4eac955c authored by gavin1332

integrate distributed classification

Parent 4cfe432c
...@@ -729,6 +729,35 @@ class Variable(object):
"""
self.error_clip = error_clip
def _set_info(self, key, value):
"""
Set a key-value information entry associated with this variable.
Args:
key(str): The key of the information entry.
value(object): The value of the information entry.
Returns:
None
"""
if not hasattr(self, "_info"):
self._info = {}
self._info[key] = value
def _get_info(self, key):
"""
Get the information entry associated with this variable, if any.
Args:
key(str): The key of the information entry.
Returns:
object: The value associated with `key`, or None if no such entry exists.
"""
if hasattr(self, "_info") and key in self._info:
return self._info[key]
return None
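A minimal usage sketch of these helpers (illustrative only; the key name and value are hypothetical, chosen to match how the classification layers below tag their loss): the entries form an ad-hoc dictionary attached to a single Variable instance.
import paddle.fluid as fluid

prog = fluid.Program()
with fluid.program_guard(prog):
    var = fluid.layers.data(name='x', shape=[8], dtype='float32')
    var._set_info('shard_dim', 128)              # attach metadata to the variable
    assert var._get_info('shard_dim') == 128     # read it back
    assert var._get_info('unknown_key') is None  # absent keys return None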
def _slice_indices(self, slice, length):
"""
Reference implementation for the slice.indices method.
...
...@@ -13,8 +13,13 @@
# limitations under the License.
from __future__ import print_function
import math
from functools import reduce
from ..layer_helper import LayerHelper, unique_name
from ..framework import Variable, default_startup_program
from ..param_attr import ParamAttr
from ..initializer import Normal, Constant
from . import nn, ops
def _allreduce(x, out=None, reduce_type="sum", sync_mode=False):
...@@ -178,3 +183,251 @@ def _c_sync_comm_stream(x, ring_id):
outputs={'Out': [x]},
attrs={'ring_id': ring_id})
return x
class DistributedClassifier(object):
'''
Toolkit for distributed classification, in which the parameters of the last
fully-connected layer are sharded across all trainers.
'''
def __init__(self, nclasses, nranks, rank_id, layer_helper):
self.nclasses = nclasses
self.nranks = nranks
self.rank_id = rank_id
self._layer_helper = layer_helper
self.shard_dim = (nclasses + nranks - 1) // nranks
self.padding_dim = 0
self.is_equal_division = True
if nclasses % nranks != 0:
self.is_equal_division = False
if rank_id == nranks - 1:
other_shard_dim = self.shard_dim
self.shard_dim = nclasses % other_shard_dim
self.padding_dim = other_shard_dim - self.shard_dim
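A quick plain-Python illustration of the sharding arithmetic in the constructor above, for hypothetical sizes: every rank holds ceil(nclasses / nranks) classes except possibly the last one, which holds the remainder and records how much padding would be needed to equalize the shards.
nclasses, nranks = 10, 4                               # hypothetical sizes
shards = []
for rank_id in range(nranks):
    shard_dim = (nclasses + nranks - 1) // nranks      # ceil(10 / 4) == 3
    padding_dim = 0
    if nclasses % nranks != 0 and rank_id == nranks - 1:
        other_shard_dim = shard_dim
        shard_dim = nclasses % other_shard_dim         # 10 % 3 == 1
        padding_dim = other_shard_dim - shard_dim      # 3 - 1 == 2
    shards.append((shard_dim, padding_dim))
assert shards == [(3, 0), (3, 0), (3, 0), (1, 2)]      # 3 + 3 + 3 + 1 == 10 classes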
def create_parameter(self,
dtype,
in_dim,
param_attr=None,
transpose_weight=False,
use_bias=True):
if param_attr is None:
stdv = math.sqrt(2.0 / (in_dim + self.nclasses))
param_attr = ParamAttr(initializer=Normal(scale=stdv))
weight_shape = [self.shard_dim, in_dim
] if transpose_weight else [in_dim, self.shard_dim]
weight = self._layer_helper.create_parameter(
shape=weight_shape, dtype=dtype, attr=param_attr, is_bias=False)
# mark as distributed so the gradient of this sharded parameter is not allreduced across trainers
weight.is_distributed = True
# also mark the corresponding variable in the startup program so the sharded parameter is not broadcast at initialization
default_startup_program().global_block().vars[
weight.name].is_distributed = True
bias = None
if use_bias:
bias = self._layer_helper.create_parameter(
shape=[self.shard_dim],
attr=ParamAttr(),
dtype=dtype,
is_bias=True)
bias.is_distributed = True
default_startup_program().global_block().vars[
bias.name].is_distributed = True
return weight, bias
def softmax_with_cross_entropy(self, shard_logit, shard_label):
shard_max = nn.reduce_max(shard_logit, dim=1, keep_dim=True)
global_max = _c_allreduce(
shard_max, reduce_type='max', use_calc_stream=True)
shard_logit_new = nn.elementwise_sub(shard_logit, global_max)
shard_exp = ops.exp(shard_logit_new)
shard_denom = nn.reduce_sum(shard_exp, dim=1, keep_dim=True)
global_denom = _c_allreduce(
shard_denom, reduce_type='sum', use_calc_stream=True)
global_log_denom = nn.log(global_denom)
shard_log_prob = shard_logit_new - global_log_denom
shard_prob = ops.exp(shard_log_prob)
shard_one_hot = nn.one_hot(
shard_label, depth=self.shard_dim, allow_out_of_range=True)
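# log-probabilities are non-positive and the one-hot mask zeroes the
# non-target entries, so reduce_min picks the target class's log-probability
# when the label falls into this shard and 0 otherwise; the reducescatter
# below sums these per-shard contributions into the full per-sample loss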
target_log_prob = nn.reduce_min(
shard_log_prob * shard_one_hot, dim=1, keep_dim=True)
shard_loss = nn.scale(target_log_prob, scale=-1.0)
global_loss = _c_reducescatter(
shard_loss, nranks=self.nranks, use_calc_stream=True)
return global_loss, shard_prob
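The method above reproduces an ordinary softmax cross-entropy even though each rank only holds a slice of the logits: the global max and the global denominator are recovered with allreduce, and the per-shard loss contributions are combined by the reducescatter. A NumPy sketch (illustrative only; made-up shapes and an equal class split) that checks this factorization against a plain, single-machine computation:
import numpy as np

np.random.seed(0)
nranks, batch, nclasses = 2, 4, 6
logits = np.random.randn(batch, nclasses)
labels = np.random.randint(0, nclasses, size=batch)

# split the class dimension across "ranks", as the sharded FC weight would
shards = np.split(logits, nranks, axis=1)

# per-shard max, then "allreduce max"
global_max = np.max([s.max(axis=1, keepdims=True) for s in shards], axis=0)
# per-shard sum of exponentials, then "allreduce sum"
global_denom = np.sum(
    [np.exp(s - global_max).sum(axis=1, keepdims=True) for s in shards], axis=0)
# per-shard log-probabilities, as in shard_log_prob above
shard_log_probs = [s - global_max - np.log(global_denom) for s in shards]

# reference: plain softmax cross-entropy on the full logits
full = logits - logits.max(axis=1, keepdims=True)
full -= np.log(np.exp(full).sum(axis=1, keepdims=True))
ref_loss = -full[np.arange(batch), labels]

# sharded loss: only the rank owning a sample's label contributes to it
shard_dim = nclasses // nranks
shard_loss = np.zeros(batch)
for rank, lp in enumerate(shard_log_probs):
    local = labels - rank * shard_dim
    owned = (local >= 0) & (local < shard_dim)
    shard_loss[owned] -= lp[owned, local[owned]]

assert np.allclose(shard_loss, ref_loss)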
def fc_classify(self, x, label, param_attr=None, use_bias=True):
flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1)
weight, bias = self.create_parameter(
dtype=x.dtype,
in_dim=flatten_dim,
param_attr=param_attr,
use_bias=use_bias)
x_all = _c_allgather(x, nranks=self.nranks, use_calc_stream=True)
label_all = _c_allgather(
label, nranks=self.nranks, use_calc_stream=True)
label_all.stop_gradient = True
shard_fc = nn.mul(x_all, weight, x_num_col_dims=1)
if use_bias:
shard_fc = nn.elementwise_add(shard_fc, bias)
shard_label = nn.shard_index(
label_all,
index_num=self.nclasses,
nshards=self.nranks,
shard_id=self.rank_id,
ignore_value=-1)
shard_label.stop_gradient = True
global_loss, shard_prob = self.softmax_with_cross_entropy(shard_fc,
shard_label)
avg_loss = nn.mean(global_loss)
avg_loss._set_info('shard_logit', shard_fc)
avg_loss._set_info('shard_prob', shard_prob)
avg_loss._set_info('shard_label', shard_label)
avg_loss._set_info('shard_dim', self.shard_dim)
return avg_loss
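The label handling relies on two op behaviors worth spelling out, sketched here with NumPy equivalents (assumed semantics, not the Paddle kernels): shard_index rewrites a global label into a shard-local index when the label is owned by this rank and into ignore_value (-1) otherwise, and one_hot with allow_out_of_range=True turns that -1 into an all-zero row, so such samples simply drop out of this shard's loss term.
import numpy as np

def shard_index_np(label, index_num, nshards, shard_id, ignore_value=-1):
    # assumed behavior: each shard owns a contiguous block of ceil(index_num / nshards) classes
    shard_size = (index_num + nshards - 1) // nshards
    local = label - shard_id * shard_size
    in_shard = (local >= 0) & (local < shard_size)
    return np.where(in_shard, local, ignore_value)

def one_hot_np(label, depth):
    # allow_out_of_range=True: indices outside [0, depth) yield an all-zero row
    out = np.zeros((label.size, depth))
    valid = (label >= 0) & (label < depth)
    out[np.arange(label.size)[valid], label[valid]] = 1.0
    return out

labels = np.array([0, 3, 5, 9])                    # hypothetical global labels
local = shard_index_np(labels, index_num=10, nshards=2, shard_id=1)
assert local.tolist() == [-1, -1, 0, 4]            # rank 1 owns classes 5..9
assert one_hot_np(local, depth=5).sum(axis=1).tolist() == [0.0, 0.0, 1.0, 1.0]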
def arcmargin_classify(self,
x,
label,
margin=0.5,
logit_scale=64,
param_attr=None):
'''
Reference: ArcFace: Additive Angular Margin Loss for Deep Face Recognition,
https://arxiv.org/abs/1801.07698
'''
flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1)
weight, bias = self.create_parameter(
dtype=x.dtype,
in_dim=flatten_dim,
param_attr=param_attr,
transpose_weight=True,
use_bias=False)
# normalize x
x_l2 = ops.sqrt(nn.reduce_sum(nn.square(x), dim=1))
norm_x = nn.elementwise_div(x, x_l2, axis=0)
norm_x_all = _c_allgather(
norm_x, nranks=self.nranks, use_calc_stream=True)
label_all = _c_allgather(
label, nranks=self.nranks, use_calc_stream=True)
label_all.stop_gradient = True
shard_label = nn.shard_index(
label_all,
index_num=self.nclasses,
nshards=self.nranks,
shard_id=self.rank_id,
ignore_value=-1)
shard_label.stop_gradient = True
# normalize weight
weight_l2 = ops.sqrt(nn.reduce_sum(nn.square(weight), dim=1))
norm_weight = nn.elementwise_div(weight, weight_l2, axis=0)
norm_weight = nn.transpose(norm_weight, perm=[1, 0])
shard_cos = nn.mul(norm_x_all, norm_weight, x_num_col_dims=1)
theta = ops.acos(shard_cos)
margin_cos = ops.cos(theta + margin)
shard_one_hot = nn.one_hot(
shard_label, depth=self.shard_dim, allow_out_of_range=True)
shard_one_hot.stop_gradient = True
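# apply the additive angular margin only at the target-class positions
# owned by this shard; all other logits keep their original cosine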
diff = (margin_cos - shard_cos) * shard_one_hot
shard_target_cos = shard_cos + diff
shard_logit = nn.scale(shard_target_cos, scale=logit_scale)
global_loss, shard_prob = self.softmax_with_cross_entropy(shard_logit,
shard_label)
avg_loss = nn.mean(global_loss)
avg_loss._set_info('shard_logit', shard_logit)
avg_loss._set_info('shard_prob', shard_prob)
avg_loss._set_info('shard_label', shard_label)
avg_loss._set_info('shard_dim', self.shard_dim)
return avg_loss
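Written out, the sharded logit for the target class becomes logit_scale * cos(theta_y + margin) while every other class keeps logit_scale * cos(theta_j), which is the ArcFace formulation. A small NumPy check of the masked update above (illustrative shapes and values only):
import numpy as np

np.random.seed(0)
margin, logit_scale = 0.5, 64
cos = np.clip(np.random.randn(4, 3), -1.0, 1.0)    # stand-in for normalized x . w
one_hot = np.eye(3)[[0, 2, 1, 1]]                  # hypothetical shard-local labels

theta = np.arccos(cos)
margin_cos = np.cos(theta + margin)
target_cos = cos + (margin_cos - cos) * one_hot    # mirrors "shard_cos + diff"
logit = logit_scale * target_cos

# non-target entries are untouched, target entries get the angular margin
assert np.allclose(logit[one_hot == 0], logit_scale * cos[one_hot == 0])
assert np.allclose(logit[one_hot == 1],
                   logit_scale * np.cos(np.arccos(cos[one_hot == 1]) + margin))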
def distributed_fc_classify(x,
label,
class_num,
nranks,
rank_id,
param_attr=None,
use_bias=True,
name='dist_fc'):
'''
Distributed fully-connected classification layer with softmax cross-entropy
loss; the FC weight is sharded across all trainers and the averaged loss is
returned.
'''
helper = LayerHelper(name, **locals())
classifier = DistributedClassifier(class_num, nranks, rank_id, helper)
return classifier.fc_classify(x, label, param_attr, use_bias)
def distributed_arcmargin_classify(x,
label,
class_num,
nranks,
rank_id,
margin=0.5,
logit_scale=64,
param_attr=None,
name='dist_fc'):
'''
Distributed classification layer with the ArcFace additive angular margin,
using the same parameter sharding scheme as distributed_fc_classify.
'''
helper = LayerHelper(name, **locals())
classifier = DistributedClassifier(class_num, nranks, rank_id, helper)
return classifier.arcmargin_classify(
x=x,
label=label,
margin=margin,
logit_scale=logit_scale,
param_attr=param_attr)
def distributed_fc(x,
out_dim,
nranks,
rank_id,
param_attr=None,
use_bias=True,
name='dist_fc'):
'''
Distributed fully-connected layer whose weight is sharded across all
trainers. Not implemented yet; see the commented sketch below.
'''
helper = LayerHelper(name, **locals())
classifier = DistributedClassifier(out_dim, nranks, rank_id, helper)
weight, bias = classifier.create_parameter(
dtype=x.dtype,
in_dim=x.shape[-1],
param_attr=param_attr,
use_bias=use_bias)
x_all = _c_allgather(x, nranks=nranks, use_calc_stream=True)
shard_fc = nn.mul(x_all, weight)
if use_bias:
shard_fc = nn.elementwise_add(shard_fc, bias)
# sample code
#if not classifier.is_equal_division:
# shard_fc = nn.pad(shard_fc)
#fc = _c_slice_allgather(shard_fc,
# nranks=nranks,
# rank_id=rank_id)
#if not classifier.is_equal_division:
# fc = nn.depad(fc)
#return fc
raise NotImplementedError('distributed_fc')
...@@ -26,9 +26,11 @@ import numpy as np
from .. import core, unique_name
from ..framework import Program, default_main_program, default_startup_program
from ..backward import _append_grad_suffix_
from .details import wait_server_ready
from .. import layers
__all__ = ['GradAllReduce', 'LocalSGD', 'DistributedClassificationOptimizer']
OpRole = core.op_proto_and_checker_maker.OpRole
...@@ -168,7 +170,7 @@ class Collective(object):
def _is_update_op(self, op):
return 'Param' in op.input_names and 'Grad' in op.input_names and \
'LearningRate' in op.input_names
def _is_optimizer_op(self, op):
return self.op_role_key in op.attr_names and \
...@@ -370,3 +372,81 @@ class LocalSGD(Collective):
inputs={'X': [param]},
outputs={'Out': [snapshot]},
attrs={self.op_role_key: OpRole.Optimize})
class DistributedClassificationOptimizer(object):
'''
Optimizer wrapper for the distributed classification layers: minimize()
generates the backward network from an auxiliary reduce_sum scalar and then
replaces the gradient of the sharded logits with the analytic softmax
cross-entropy gradient, (softmax probability - one-hot label) / batch_size.
'''
def __init__(self, optimizer, batch_size):
self._optimizer = optimizer
self._batch_size = batch_size
def minimize(self, loss):
# TODO: use paddle enforce
assert loss._get_info('shard_logit')
shard_logit = loss._get_info('shard_logit')
shard_prob = loss._get_info('shard_prob')
shard_label = loss._get_info('shard_label')
shard_dim = loss._get_info('shard_dim')
op_maker = core.op_proto_and_checker_maker
op_role_key = op_maker.kOpRoleAttrName()
op_role_var_key = op_maker.kOpRoleVarAttrName()
backward_role = int(op_maker.OpRole.Backward)
loss_backward_role = int(op_maker.OpRole.Loss) | int(
op_maker.OpRole.Backward)
# minimize an auxiliary reduce_sum scalar of the sharded logits so the
# optimizer appends a backward network; the ops created for this scalar
# are removed and replaced below
scalar = layers.reduce_sum(shard_logit)
ret = self._optimizer.minimize(scalar)
block = loss.block
# locate the loss-backward boundary and remove the three ops appended for
# the auxiliary scalar (reduce_sum, fill_constant, reduce_sum_grad)
index = 0
for i, op in enumerate(block.ops):
if op.all_attrs()[op_role_key] == loss_backward_role:
index = i
break
# TODO: use paddle enforce
assert block.ops[index - 1].type == 'reduce_sum'
assert block.ops[index].type == 'fill_constant'
assert block.ops[index + 1].type == 'reduce_sum_grad'
block._remove_op(index + 1)
block._remove_op(index)
block._remove_op(index - 1)
# insert ops computing the analytic gradient of the sharded logits:
# (shard_prob - one_hot(shard_label)) / batch_size
dtype = shard_logit.dtype
shard_one_hot = layers.create_tensor(dtype, name='shard_one_hot')
block._insert_op(
index - 1,
type='one_hot',
inputs={'X': shard_label},
outputs={'Out': shard_one_hot},
attrs={
'depth': shard_dim,
'allow_out_of_range': True,
op_role_key: backward_role
})
shard_logit_grad = layers.create_tensor(
dtype, name=_append_grad_suffix_(shard_logit.name))
block._insert_op(
index,
type='elementwise_sub',
inputs={'X': shard_prob,
'Y': shard_one_hot},
outputs={'Out': shard_logit_grad},
attrs={op_role_key: backward_role})
block._insert_op(
index + 1,
type='scale',
inputs={'X': shard_logit_grad},
outputs={'Out': shard_logit_grad},
attrs={
'scale': 1.0 / self._batch_size,
op_role_key: loss_backward_role
})
return ret
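Putting the pieces together, the intended call pattern looks roughly like the sketch below. It is illustrative only: the import paths and sizes are assumptions (the diff suggests the layers live in fluid/layers/collective.py and the wrapper in fluid/transpiler/collective.py), and the usual multi-process collective setup and program transpilation are omitted.
import paddle.fluid as fluid
# assumed import locations, inferred from the files touched by this commit
from paddle.fluid.layers.collective import distributed_fc_classify
from paddle.fluid.transpiler.collective import DistributedClassificationOptimizer

BATCH_SIZE, FEAT_DIM, CLASS_NUM = 64, 512, 100000    # made-up sizes
nranks, rank_id = 4, 0                               # normally taken from the launcher env

feat = fluid.layers.data(name='feat', shape=[FEAT_DIM], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

# each trainer holds roughly CLASS_NUM / nranks columns of the FC weight
avg_loss = distributed_fc_classify(
    x=feat, label=label, class_num=CLASS_NUM, nranks=nranks, rank_id=rank_id)

# the wrapper regenerates the backward pass with the analytic softmax gradient;
# the sharded logits cover the allgathered batch, hence the global batch size
# (an assumption -- match this to how gradients are averaged in the real script)
optimizer = fluid.optimizer.SGD(learning_rate=0.1)
DistributedClassificationOptimizer(
    optimizer, BATCH_SIZE * nranks).minimize(avg_loss)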