From 2be9036fcc911bbbaaf7e92e790cc0817f01ce67 Mon Sep 17 00:00:00 2001 From: gavin1332 Date: Wed, 4 Sep 2019 16:48:00 +0800 Subject: [PATCH] extract a common distributed testing class for ut test=develop test=document_preview --- paddle/fluid/API.spec | 2 +- python/paddle/fluid/layers/collective.py | 303 ---------------- python/paddle/fluid/layers/dist_algo.py | 326 ++++++++++++++++++ .../unittests/dist_arcface_classification.py | 14 +- .../unittests/dist_classification_base.py | 97 ++++++ .../unittests/dist_softmax_classification.py | 11 +- .../test_dist_arcface_classification.py | 6 +- ...n_base.py => test_dist_collective_base.py} | 234 +++++++------ .../test_dist_softmax_classification.py | 9 +- python/paddle/fluid/transpiler/collective.py | 15 +- 10 files changed, 581 insertions(+), 436 deletions(-) create mode 100644 python/paddle/fluid/layers/dist_algo.py create mode 100644 python/paddle/fluid/tests/unittests/dist_classification_base.py rename python/paddle/fluid/tests/unittests/{test_dist_classification_base.py => test_dist_collective_base.py} (66%) diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index b8658b17be2..a955cff302b 100755 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -290,7 +290,7 @@ paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'tran paddle.fluid.layers.match_matrix_tensor (ArgSpec(args=['x', 'y', 'channel_num', 'act', 'param_attr', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, 'float32', None)), ('document', 'b6ea7d4ddeacae85e37d1e47d5262948')) paddle.fluid.layers.filter_by_instag (ArgSpec(args=['ins', 'ins_tag', 'filter_tag', 'is_lod'], varargs=None, keywords=None, defaults=None), ('document', '7703a2088af8de4128b143ff1164ca4a')) paddle.fluid.layers.var_conv_2d (ArgSpec(args=['input', 'row', 'col', 'input_channel', 'output_channel', 'filter_size', 'stride', 'param_attr', 'act', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, None, 'float32', None)), 
('document', '7a8b8ade5512c95f9ea30261d33ded6c')) -paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '5786fdbba6753ecd6cbce5e6b0889924')) +paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '3a209cbe5f648c00f8d7c2187dc23674')) paddle.fluid.layers.hard_swish (ArgSpec(args=['x', 'threshold', 'scale', 'offset', 'name'], varargs=None, keywords=None, defaults=(6.0, 6.0, 3.0, None)), ('document', '6a5152a7015c62cb8278fc24cb456459')) paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545')) paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '88367daf9a30c9ab83adc5d7221e23ef')) diff --git a/python/paddle/fluid/layers/collective.py b/python/paddle/fluid/layers/collective.py index 290dc96b634..d0be098c7a4 100644 --- a/python/paddle/fluid/layers/collective.py +++ b/python/paddle/fluid/layers/collective.py @@ -19,7 +19,6 @@ from ..layer_helper import LayerHelper, unique_name from ..framework import Variable, default_startup_program from ..param_attr import ParamAttr from ..initializer import Normal, Constant -import nn, ops def _allreduce(x, out=None, reduce_type="sum", sync_mode=False): @@ -183,305 +182,3 @@ def _c_sync_comm_stream(x, ring_id): outputs={'Out': [x]}, attrs={'ring_id': ring_id}) return x - - -class DistributedClassifier(object): - ''' - Tookit for distributed classification, in which the parameter of the last - full-connected layer is distributed to all trainers - ''' - - def __init__(self, nclasses, nranks, rank_id, layer_helper): - self.nclasses = nclasses - 
self.nranks = nranks - self.rank_id = rank_id - self._layer_helper = layer_helper - - self.shard_dim = (nclasses + nranks - 1) // nranks - self.padding_dim = 0 - self.is_equal_division = True - if nclasses % nranks != 0: - self.is_equal_division = False - if rank_id == nranks - 1: - other_shard_dim = self.shard_dim - self.shard_dim = nclasses % other_shard_dim - self.padding_dim = other_shard_dim - self.shard_dim - - def create_parameter(self, - dtype, - in_dim, - param_attr=None, - transpose_weight=False, - use_bias=True): - if param_attr is None: - stdv = math.sqrt(2.0 / (in_dim + self.nclasses)) - param_attr = ParamAttr(initializer=Normal(scale=stdv)) - weight_shape = [self.shard_dim, in_dim - ] if transpose_weight else [in_dim, self.shard_dim] - weight = self._layer_helper.create_parameter( - shape=weight_shape, dtype=dtype, attr=param_attr, is_bias=False) - # avoid distributed parameter allreduce gradients - weight.is_distributed = True - # avoid distributed parameter broadcasting in startup program - default_startup_program().global_block().vars[ - weight.name].is_distributed = True - - bias = None - if use_bias: - bias = self._layer_helper.create_parameter( - shape=[self.shard_dim], - attr=ParamAttr(), - dtype=dtype, - is_bias=True) - bias.is_distributed = True - default_startup_program().global_block().vars[ - bias.name].is_distributed = True - return weight, bias - - def softmax_with_cross_entropy(self, shard_logit, shard_label): - shard_max = nn.reduce_max(shard_logit, dim=1, keep_dim=True) - global_max = _c_allreduce( - shard_max, reduce_type='max', use_calc_stream=True) - shard_logit_new = nn.elementwise_sub(shard_logit, global_max) - - shard_exp = ops.exp(shard_logit_new) - shard_demon = nn.reduce_sum(shard_exp, dim=1, keep_dim=True) - global_demon = _c_allreduce( - shard_demon, reduce_type='sum', use_calc_stream=True) - - global_log_demon = nn.log(global_demon) - shard_log_prob = shard_logit_new - global_log_demon - shard_prob = 
ops.exp(shard_log_prob) - - shard_one_hot = nn.one_hot( - shard_label, depth=self.shard_dim, allow_out_of_range=True) - target_log_prob = nn.reduce_min( - shard_log_prob * shard_one_hot, dim=1, keep_dim=True) - shard_loss = nn.scale(target_log_prob, scale=-1.0) - global_loss = _c_reducescatter( - shard_loss, nranks=self.nranks, use_calc_stream=True) - return global_loss, shard_prob - - def fc_classify(self, x, label, param_attr=None, use_bias=True): - flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1) - weight, bias = self.create_parameter( - dtype=x.dtype, - in_dim=flatten_dim, - param_attr=param_attr, - use_bias=use_bias) - - x_all = _c_allgather(x, nranks=self.nranks, use_calc_stream=True) - label_all = _c_allgather( - label, nranks=self.nranks, use_calc_stream=True) - label_all.stop_gradient = True - - shard_fc = nn.mul(x_all, weight, x_num_col_dims=1) - if use_bias: - shard_fc = nn.elementwise_add(shard_fc, bias) - - shard_label = nn.shard_index( - label_all, - index_num=self.nclasses, - nshards=self.nranks, - shard_id=self.rank_id, - ignore_value=-1) - shard_label.stop_gradient = True - - global_loss, shard_prob = self.softmax_with_cross_entropy(shard_fc, - shard_label) - avg_loss = nn.mean(global_loss) - - avg_loss._set_info('shard_logit', shard_fc) - avg_loss._set_info('shard_prob', shard_prob) - avg_loss._set_info('shard_label', shard_label) - avg_loss._set_info('shard_dim', self.shard_dim) - - return avg_loss - - def arcface_classify(self, - x, - label, - margin=0.5, - logit_scale=64, - param_attr=None): - ''' - reference: ArcFace. 
https://arxiv.org/abs/1801.07698 - ''' - flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1) - weight, bias = self.create_parameter( - dtype=x.dtype, - in_dim=flatten_dim, - param_attr=param_attr, - transpose_weight=True, - use_bias=False) - - # normalize x - x_l2 = ops.sqrt(nn.reduce_sum(nn.square(x), dim=1)) - norm_x = nn.elementwise_div(x, x_l2, axis=0) - - norm_x_all = _c_allgather( - norm_x, nranks=self.nranks, use_calc_stream=True) - label_all = _c_allgather( - label, nranks=self.nranks, use_calc_stream=True) - label_all.stop_gradient = True - shard_label = nn.shard_index( - label_all, - index_num=self.nclasses, - nshards=self.nranks, - shard_id=self.rank_id, - ignore_value=-1) - shard_label.stop_gradient = True - - # normalize weight - weight_l2 = ops.sqrt(nn.reduce_sum(nn.square(weight), dim=1)) - norm_weight = nn.elementwise_div(weight, weight_l2, axis=0) - norm_weight = nn.transpose(norm_weight, perm=[1, 0]) - - shard_cos = nn.mul(norm_x_all, norm_weight, x_num_col_dims=1) - - theta = ops.acos(shard_cos) - margin_cos = ops.cos(theta + margin) - - shard_one_hot = nn.one_hot( - shard_label, depth=self.shard_dim, allow_out_of_range=True) - shard_one_hot.stop_gradient = True - - diff = (margin_cos - shard_cos) * shard_one_hot - shard_target_cos = shard_cos + diff - shard_logit = nn.scale(shard_target_cos, scale=logit_scale) - - global_loss, shard_prob = self.softmax_with_cross_entropy(shard_logit, - shard_label) - avg_loss = nn.mean(global_loss) - - avg_loss._set_info('shard_logit', shard_logit) - avg_loss._set_info('shard_prob', shard_prob) - avg_loss._set_info('shard_label', shard_label) - avg_loss._set_info('shard_dim', self.shard_dim) - - return avg_loss - - -def _distributed_fc_classify(x, - label, - class_num, - nranks, - rank_id, - param_attr=None, - use_bias=True, - name=None): - ''' - Classification layer with FC, softmax and cross entropy calculation of - distibuted version in case of too large number of classes. 
- - Args: - x (Variable): The feature representation of the input samples. This - feature will be flattened into 2-D tensor from dimension index - 1. E.g. [32, 1024, 1, 1] will be flattened to [32, 1024]. - label (Variable): The label corresponding to the input samples. - class_num (integer): The number of classes of the classification problem. - nranks (integer): The number of ranks of distributed trainers. - rank_id (integer): The rank index of the current trainer. - param_attr (ParamAttr, default None): The parameter attribute for - learnable distributed parameters/weights of this layer. - use_bias (float, default 64.0): The scale factor for logit value - of cosine range. - name (str, default None): The name of this layer. - Returns: - Variable: The ArcFace loss. - - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - input = fluid.layers.data(name="input", - shape=[32, 1024], - dtype='float32', - append_batch_size=False) - label = fluid.layers.data(name="label", - shape=[32, 1], - dtype='int64', - append_batch_size=False) - y = fluid.layers.collective.distributed_fc_classify(x=input, - label=label, - class_num=1000, - nranks=8, - rank_id=0) - ''' - - if name is None: - name = 'dist_fc' - helper = LayerHelper(name, **locals()) - classifier = DistributedClassifier(class_num, nranks, rank_id, helper) - return classifier.fc_classify(x, label, param_attr, use_bias) - - -def _distributed_arcface_classify(x, - label, - class_num, - nranks, - rank_id, - margin=0.5, - logit_scale=64.0, - param_attr=None, - name=None): - ''' - Classification layer with ArcFace loss of distibuted version in case of - too large number of classes. the equation is - - .. math:: - - L=-\frac{1}{N}\sum^N_{i=1}\log\frac{e^{s(cos(\theta_{y_i}+m))}}{e^{s(cos(\theta_{y_i}+m))}+\sum^n_{j=1,j\neq y_i} e^{scos\theta_{y_i}}} - - where the :math: `\theta_{y_i}` is the angle between the feature :math: `x` and - the representation of class :math: `i`. 
The details of ArcFace loss - could be referred to https://arxiv.org/abs/1801.07698. - - Args: - x (Variable): The feature representation of the input samples. This - feature will be flattened into 2-D tensor from dimension index - 1. E.g. [32, 1024, 1, 1] will be flattened to [32, 1024]. - label (Variable): The label corresponding to the input samples. - class_num (integer): The number of classes of the classification problem. - nranks (integer): The number of ranks of distributed trainers. - rank_id (integer): The rank index of the current trainer. - margin (float, default 0.5): The angular margin penalty to enhance - the intra-class compactness and inter-class discrepancy. - logit_scale (float, default 64.0): The scale factor for logit value - of cosine range. - param_attr (ParamAttr, default None): The parameter attribute for - learnable distributed parameters/weights of this layer. - name (str, default None): The name of this layer. - Returns: - Variable: The ArcFace loss. - - - Examples: - .. code-block:: python - - import paddle.fluid as fluid - input = fluid.layers.data(name="input", - shape=[32, 1024], - dtype='float32', - append_batch_size=False) - label = fluid.layers.data(name="label", - shape=[32, 1], - dtype='int64', - append_batch_size=False) - y = fluid.layers.collective.distributed_arcface_classify(x=input, - label=label, - class_num=1000, - nranks=8, - rank_id=0) - ''' - if name is None: - name = 'dist_fc' - helper = LayerHelper(name, **locals()) - classifier = DistributedClassifier(class_num, nranks, rank_id, helper) - return classifier.arcface_classify( - x=x, - label=label, - margin=margin, - logit_scale=logit_scale, - param_attr=param_attr) diff --git a/python/paddle/fluid/layers/dist_algo.py b/python/paddle/fluid/layers/dist_algo.py new file mode 100644 index 00000000000..f157fceb02c --- /dev/null +++ b/python/paddle/fluid/layers/dist_algo.py @@ -0,0 +1,326 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import math + +from six.moves import reduce +from ..layer_helper import LayerHelper +from ..framework import Variable, default_startup_program +from ..param_attr import ParamAttr +from ..initializer import Normal, Constant +from . import nn, ops, collective + + +class DistributedClassifier(object): + ''' + Toolkit for distributed classification, in which the parameter of the last + fully-connected layer is distributed to all trainers. + ''' + + def __init__(self, nclasses, nranks, rank_id, layer_helper): + self.nclasses = nclasses + self.nranks = nranks + self.rank_id = rank_id + self._layer_helper = layer_helper + + self.shard_dim = (nclasses + nranks - 1) // nranks + self.padding_dim = 0 + self.is_equal_division = True + if nclasses % nranks != 0: + self.is_equal_division = False + if rank_id == nranks - 1: + other_shard_dim = self.shard_dim + self.shard_dim = nclasses % other_shard_dim + self.padding_dim = other_shard_dim - self.shard_dim + + def create_parameter(self, + dtype, + in_dim, + param_attr=None, + transpose_weight=False, + use_bias=True): + if param_attr is None: + stdv = math.sqrt(2.0 / (in_dim + self.nclasses)) + param_attr = ParamAttr(initializer=Normal(scale=stdv)) + weight_shape = [self.shard_dim, in_dim + ] if transpose_weight else [in_dim, self.shard_dim] + weight = self._layer_helper.create_parameter( + shape=weight_shape, dtype=dtype, attr=param_attr,
is_bias=False) + # avoid distributed parameter allreduce gradients + weight.is_distributed = True + # avoid distributed parameter broadcasting in startup program + default_startup_program().global_block().vars[ + weight.name].is_distributed = True + + bias = None + if use_bias: + bias = self._layer_helper.create_parameter( + shape=[self.shard_dim], + attr=ParamAttr(), + dtype=dtype, + is_bias=True) + bias.is_distributed = True + default_startup_program().global_block().vars[ + bias.name].is_distributed = True + return weight, bias + + def softmax_with_cross_entropy(self, shard_logit, shard_label): + shard_max = nn.reduce_max(shard_logit, dim=1, keep_dim=True) + global_max = collective._c_allreduce( + shard_max, reduce_type='max', use_calc_stream=True) + shard_logit_new = nn.elementwise_sub(shard_logit, global_max) + + shard_exp = ops.exp(shard_logit_new) + shard_demon = nn.reduce_sum(shard_exp, dim=1, keep_dim=True) + global_demon = collective._c_allreduce( + shard_demon, reduce_type='sum', use_calc_stream=True) + + global_log_demon = nn.log(global_demon) + shard_log_prob = shard_logit_new - global_log_demon + shard_prob = ops.exp(shard_log_prob) + + shard_one_hot = nn.one_hot( + shard_label, depth=self.shard_dim, allow_out_of_range=True) + target_log_prob = nn.reduce_min( + shard_log_prob * shard_one_hot, dim=1, keep_dim=True) + shard_loss = nn.scale(target_log_prob, scale=-1.0) + global_loss = collective._c_reducescatter( + shard_loss, nranks=self.nranks, use_calc_stream=True) + return global_loss, shard_prob + + def softmax_classify(self, x, label, param_attr=None, use_bias=True): + flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1) + weight, bias = self.create_parameter( + dtype=x.dtype, + in_dim=flatten_dim, + param_attr=param_attr, + use_bias=use_bias) + + x_all = collective._c_allgather( + x, nranks=self.nranks, use_calc_stream=True) + label_all = collective._c_allgather( + label, nranks=self.nranks, use_calc_stream=True) + label_all.stop_gradient = 
True + + shard_fc = nn.mul(x_all, weight, x_num_col_dims=1) + if use_bias: + shard_fc = nn.elementwise_add(shard_fc, bias) + + shard_label = nn.shard_index( + label_all, + index_num=self.nclasses, + nshards=self.nranks, + shard_id=self.rank_id, + ignore_value=-1) + shard_label.stop_gradient = True + + global_loss, shard_prob = self.softmax_with_cross_entropy(shard_fc, + shard_label) + avg_loss = nn.mean(global_loss) + + avg_loss._set_info('shard_logit', shard_fc) + avg_loss._set_info('shard_prob', shard_prob) + avg_loss._set_info('shard_label', shard_label) + avg_loss._set_info('shard_dim', self.shard_dim) + + return avg_loss + + def arcface_classify(self, + x, + label, + margin=0.5, + logit_scale=64, + param_attr=None): + ''' + reference: ArcFace. https://arxiv.org/abs/1801.07698 + ''' + flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1) + weight, bias = self.create_parameter( + dtype=x.dtype, + in_dim=flatten_dim, + param_attr=param_attr, + transpose_weight=True, + use_bias=False) + + # normalize x + x_l2 = ops.sqrt(nn.reduce_sum(nn.square(x), dim=1)) + norm_x = nn.elementwise_div(x, x_l2, axis=0) + + norm_x_all = collective._c_allgather( + norm_x, nranks=self.nranks, use_calc_stream=True) + label_all = collective._c_allgather( + label, nranks=self.nranks, use_calc_stream=True) + label_all.stop_gradient = True + shard_label = nn.shard_index( + label_all, + index_num=self.nclasses, + nshards=self.nranks, + shard_id=self.rank_id, + ignore_value=-1) + shard_label.stop_gradient = True + + # normalize weight + weight_l2 = ops.sqrt(nn.reduce_sum(nn.square(weight), dim=1)) + norm_weight = nn.elementwise_div(weight, weight_l2, axis=0) + norm_weight = nn.transpose(norm_weight, perm=[1, 0]) + + shard_cos = nn.mul(norm_x_all, norm_weight, x_num_col_dims=1) + + theta = ops.acos(shard_cos) + margin_cos = ops.cos(theta + margin) + + shard_one_hot = nn.one_hot( + shard_label, depth=self.shard_dim, allow_out_of_range=True) + shard_one_hot.stop_gradient = True + + diff = 
(margin_cos - shard_cos) * shard_one_hot + shard_target_cos = shard_cos + diff + shard_logit = nn.scale(shard_target_cos, scale=logit_scale) + + global_loss, shard_prob = self.softmax_with_cross_entropy(shard_logit, + shard_label) + avg_loss = nn.mean(global_loss) + + avg_loss._set_info('shard_logit', shard_logit) + avg_loss._set_info('shard_prob', shard_prob) + avg_loss._set_info('shard_label', shard_label) + avg_loss._set_info('shard_dim', self.shard_dim) + + return avg_loss + + +def _distributed_softmax_classify(x, + label, + class_num, + nranks, + rank_id, + param_attr=None, + use_bias=True, + name=None): + ''' + Classification layer with FC, softmax and cross entropy calculation of + distributed version in case of too large number of classes. + + Args: + x (Variable): The feature representation of the input samples. This + feature will be flattened into 2-D tensor from dimension index + 1. E.g. [32, 1024, 1, 1] will be flattened to [32, 1024]. + label (Variable): The label corresponding to the input samples. + class_num (integer): The number of classes of the classification problem. + nranks (integer): The number of ranks of distributed trainers. + rank_id (integer): The rank index of the current trainer. + param_attr (ParamAttr, default None): The parameter attribute for + learnable distributed parameters/weights of this layer. + use_bias (bool, default True): Whether to add a learnable bias to + the output of the distributed FC layer. + name (str, default None): The name of this layer. + Returns: + Variable: The mean softmax cross-entropy loss. + + + Examples: + ..
code-block:: python + + import paddle.fluid as fluid + input = fluid.layers.data(name="input", + shape=[32, 1024], + dtype='float32', + append_batch_size=False) + label = fluid.layers.data(name="label", + shape=[32, 1], + dtype='int64', + append_batch_size=False) + y = fluid.layers.dist_algo._distributed_softmax_classify(x=input, + label=label, + class_num=1000, + nranks=8, + rank_id=0) + ''' + + if name is None: + name = 'dist_softmax' + helper = LayerHelper(name, **locals()) + classifier = DistributedClassifier(class_num, nranks, rank_id, helper) + return classifier.softmax_classify(x, label, param_attr, use_bias) + + +def _distributed_arcface_classify(x, + label, + class_num, + nranks, + rank_id, + margin=0.5, + logit_scale=64.0, + param_attr=None, + name=None): + ''' + Classification layer with ArcFace loss of distributed version in case of + too large number of classes. The equation is + + .. math:: + + L=-\frac{1}{N}\sum^N_{i=1}\log\frac{e^{s(cos(\theta_{y_i}+m))}}{e^{s(cos(\theta_{y_i}+m))}+\sum^n_{j=1,j\neq y_i} e^{scos\theta_{j}}} + + where the :math: `\theta_{y_i}` is the angle between the feature :math: `x` and + the representation of class :math: `i`. The details of ArcFace loss + could be referred to https://arxiv.org/abs/1801.07698. + + Args: + x (Variable): The feature representation of the input samples. This + feature will be flattened into 2-D tensor from dimension index + 1. E.g. [32, 1024, 1, 1] will be flattened to [32, 1024]. + label (Variable): The label corresponding to the input samples. + class_num (integer): The number of classes of the classification problem. + nranks (integer): The number of ranks of distributed trainers. + rank_id (integer): The rank index of the current trainer. + margin (float, default 0.5): The angular margin penalty to enhance + the intra-class compactness and inter-class discrepancy. + logit_scale (float, default 64.0): The scale factor for logit value + of cosine range.
+ param_attr (ParamAttr, default None): The parameter attribute for + learnable distributed parameters/weights of this layer. + name (str, default None): The name of this layer. + Returns: + Variable: The ArcFace loss. + + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + input = fluid.layers.data(name="input", + shape=[32, 1024], + dtype='float32', + append_batch_size=False) + label = fluid.layers.data(name="label", + shape=[32, 1], + dtype='int64', + append_batch_size=False) + y = fluid.layers.dist_algo._distributed_arcface_classify(x=input, + label=label, + class_num=1000, + nranks=8, + rank_id=0) + ''' + if name is None: + name = 'dist_fc' + helper = LayerHelper(name, **locals()) + classifier = DistributedClassifier(class_num, nranks, rank_id, helper) + return classifier.arcface_classify( + x=x, + label=label, + margin=margin, + logit_scale=logit_scale, + param_attr=param_attr) diff --git a/python/paddle/fluid/tests/unittests/dist_arcface_classification.py b/python/paddle/fluid/tests/unittests/dist_arcface_classification.py index ec043956d1d..59766f31eab 100644 --- a/python/paddle/fluid/tests/unittests/dist_arcface_classification.py +++ b/python/paddle/fluid/tests/unittests/dist_arcface_classification.py @@ -17,22 +17,24 @@ import unittest import numpy as np import paddle.fluid as fluid import paddle.fluid.layers as layers -import paddle.fluid.layers.collective as collective +import paddle.fluid.layers.dist_algo as dist_algo from paddle.fluid.initializer import NumpyArrayInitializer -from test_dist_classification_base import DistClassificationRunner, runtime_main +from dist_classification_base import DistClassificationRunner +from test_dist_collective_base import runtime_main -# TODO donot transpose weight +# TODO(gavin1332) check whether it is necessary to transpose weight class DistArcfaceClassificationRunner(DistClassificationRunner): @classmethod - def add_arguments(cls, parser): + def add_other_arguments(cls, parser):
parser.add_argument('--arcface_margin', type=float, default=0.0) parser.add_argument('--arcface_scale', type=float, default=1.0) def __init__(self, args): super(DistArcfaceClassificationRunner, self).__init__(args) np.random.seed(1024) - self.param_value = np.random.rand(args.class_num, args.feature_size) + self.param_value = np.random.rand(self.args.class_num, + self.args.feature_size) def local_classify_subnet(self, feature, label): args = self.args @@ -76,7 +78,7 @@ class DistArcfaceClassificationRunner(DistClassificationRunner): shard_start = shard_dim * args.rank rank_param_value = self.param_value[shard_start:(shard_start + shard_dim ), :] - cost = layers.collective._distributed_arcface_classify( + cost = layers.dist_algo._distributed_arcface_classify( x=feature, label=label, class_num=args.class_num, diff --git a/python/paddle/fluid/tests/unittests/dist_classification_base.py b/python/paddle/fluid/tests/unittests/dist_classification_base.py new file mode 100644 index 00000000000..4091afec2fc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/dist_classification_base.py @@ -0,0 +1,97 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function +from datetime import datetime + +import unittest +import os +import sys +import subprocess +import six +import argparse +import pickle +import numpy as np +import paddle.fluid as fluid + +from paddle.fluid.transpiler.collective import \ + GradAllReduce, DistributedClassificationOptimizer +from test_dist_collective_base import DistCollectiveRunner, elog + +DEFAULT_FEATURE_SIZE = 4 +DEFAULT_CLASS_NUM = 4 + + +class DistClassificationRunner(DistCollectiveRunner): + ################################## + ##### user specified methods ##### + + @classmethod + def add_other_arguments(cls, parser): + pass + + def local_classify_subnet(self, feature, label): + raise NotImplementedError( + 'local_classifiy_subnet should be implemented by child classes.') + + def parall_classify_subnet(self, feature, label): + raise NotImplementedError( + 'parall_classify_subnet should be implemented by child classes.') + + ##### user specified methods ##### + ################################## + + @classmethod + def add_arguments(cls, parser): + parser.add_argument( + '--feature_size', type=int, default=DEFAULT_FEATURE_SIZE) + parser.add_argument('--class_num', type=int, default=DEFAULT_CLASS_NUM) + cls.add_other_arguments(parser) + + def build_local_net(self): + return self.build_classification_net() + + def build_parall_net(self): + return self.build_classification_net() + + def yield_sample(self, np_random): + yield [ + np_random.rand(self.args.feature_size), + np_random.randint(self.args.class_num) + ] + + def dist_optimize(self, optimizer, loss): + args = self.args + optimizer_wrapper = DistributedClassificationOptimizer(optimizer, + args.batch_size) + optimizer_wrapper.minimize(loss) + transpiler = GradAllReduce() + transpiler.transpile( + rank=args.rank, + endpoints=args.endpoints, + current_endpoint=args.current_endpoint, + wait_port=True) + + def build_classification_net(self): + args = self.args + feature = fluid.layers.data( + 
name='feature', shape=[args.feature_size], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + if args.nranks <= 1: + elog(self, 'build local network') + loss = self.local_classify_subnet(feature, label) + else: + elog(self, 'build parallel network') + loss = self.parall_classify_subnet(feature, label) + return [feature, label], loss diff --git a/python/paddle/fluid/tests/unittests/dist_softmax_classification.py b/python/paddle/fluid/tests/unittests/dist_softmax_classification.py index c041f0bb4dd..09dffc99ba4 100644 --- a/python/paddle/fluid/tests/unittests/dist_softmax_classification.py +++ b/python/paddle/fluid/tests/unittests/dist_softmax_classification.py @@ -17,16 +17,13 @@ import unittest import numpy as np import paddle.fluid as fluid import paddle.fluid.layers as layers +import paddle.fluid.layers.dist_algo as dist_algo from paddle.fluid.initializer import NumpyArrayInitializer -from test_dist_classification_base import DistClassificationRunner, runtime_main +from dist_classification_base import DistClassificationRunner +from test_dist_collective_base import runtime_main -# TODO bias attr class DistSoftmaxClassificationRunner(DistClassificationRunner): - @classmethod - def add_arguments(cls, parser): - pass - def __init__(self, args): super(DistSoftmaxClassificationRunner, self).__init__(args) np.random.seed(1024) @@ -47,7 +44,7 @@ class DistSoftmaxClassificationRunner(DistClassificationRunner): shard_start = shard_dim * args.rank rank_param_value = self.param_value[:, shard_start:(shard_start + shard_dim)] - cost = layers.collective._distributed_fc_classify( + cost = layers.dist_algo._distributed_softmax_classify( x=feature, label=label, class_num=args.class_num, diff --git a/python/paddle/fluid/tests/unittests/test_dist_arcface_classification.py b/python/paddle/fluid/tests/unittests/test_dist_arcface_classification.py index 3dbfbf7306c..8865c2dcdc2 100644 --- 
a/python/paddle/fluid/tests/unittests/test_dist_arcface_classification.py +++ b/python/paddle/fluid/tests/unittests/test_dist_arcface_classification.py @@ -14,17 +14,17 @@ import unittest import paddle.fluid as fluid -from test_dist_classification_base import TestDistClassificationBase +from test_dist_collective_base import TestDistCollectiveBase -class TestDistArcfaceClassification(TestDistClassificationBase): +class TestDistArcfaceClassification(TestDistCollectiveBase): def test_training(self): if fluid.core.is_compiled_with_cuda(): self.compare_parall_to_local( 'dist_arcface_classification.py', delta=1e-5) -class TestDistArcfaceClassificationParam(TestDistClassificationBase): +class TestDistArcfaceClassificationParam(TestDistCollectiveBase): def append_common_cmd(self): return '--arcface_margin 0.5 --arcface_scale 64' diff --git a/python/paddle/fluid/tests/unittests/test_dist_classification_base.py b/python/paddle/fluid/tests/unittests/test_dist_collective_base.py similarity index 66% rename from python/paddle/fluid/tests/unittests/test_dist_classification_base.py rename to python/paddle/fluid/tests/unittests/test_dist_collective_base.py index 4cef1c73ca3..c9352fbd6fa 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_classification_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_collective_base.py @@ -25,14 +25,9 @@ import pickle import numpy as np import paddle.fluid as fluid -from paddle.fluid.transpiler.collective import \ - GradAllReduce, DistributedClassificationOptimizer +from paddle.fluid.transpiler.collective import GradAllReduce DEFAULT_BATCH_SIZE = 2 -DEFAULT_FEATURE_SIZE = 4 -DEFAULT_CLASS_NUM = 4 -DEFAULT_LR = 0.001 - RUN_STEPS = 5 @@ -55,51 +50,64 @@ def elog(ref, message, to_pipe=False): print(log_str, file=sys.stderr) -class DistClassificationRunner(object): - def __init__(self, args): - args.rank = int(os.getenv('PADDLE_TRAINER_ID', '0')) - args.current_endpoint = os.getenv('PADDLE_CURRENT_ENDPOINT') - args.nranks = 
int(os.getenv('PADDLE_TRAINERS_NUM', '1')) - args.endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS', '').split(',') - args.device_id = int(os.getenv('FLAGS_selected_gpus', '0')) - self.args = args +class DistCollectiveRunner(object): + ################################## + ##### user specified methods ##### - def elog(self, message, to_pipe=False): - elog(self, message, to_pipe) + @classmethod + def add_arguments(cls, parser): + pass - def local_classify_subnet(self, feature, label): + def build_local_net(self): raise NotImplementedError( - 'get_local_model should be implemented by child classes.') + 'local_net should be implemented by child classes.') - def parall_classify_subnet(self, feature, label): + def build_parall_net(self): raise NotImplementedError( - 'get_parall_model should be implemented by child classes.') + 'parall_net should be implemented by child classes.') + + def yield_sample(self, np_random): + raise NotImplementedError( + 'data_generator should be implemented by child classes') + + def create_optimizer(self): + return fluid.optimizer.SGD(learning_rate=0.001) + + def dist_optimize(self, optimizer, loss): + args = self.args + optimizer.minimize(loss) + transpiler = GradAllReduce() + transpiler.transpile( + rank=args.rank, + endpoints=args.endpoints, + current_endpoint=args.current_endpoint, + wait_port=True) + + ##### user specified methods ##### + ################################## + + def __init__(self, args): + self.args = args def build_net(self): args = self.args - main_prog = fluid.Program() - start_prog = fluid.Program() - optimizer = fluid.optimizer.SGD(learning_rate=args.lr) - with fluid.program_guard(main_prog, start_prog): - feature = fluid.layers.data( - name='feature', shape=[args.feature_size], dtype='float32') - label = fluid.layers.data(name='label', shape=[1], dtype='int64') - if args.nranks <= 1: - elog(self, 'build local network') - loss = self.local_classify_subnet(feature, label) - optimizer.minimize(loss) - else: - elog(self, 
'build parallel network') - loss = self.parall_classify_subnet(feature, label) - # TODO why need batch size? - optimizer_wrapper = DistributedClassificationOptimizer( - optimizer, args.batch_size) - optimizer_wrapper.minimize(loss) - self.transpile(main_prog, start_prog) - - return [feature, label], loss, start_prog - - def gen_rank_batch(self): + if args.nranks <= 1: + elog(self, 'build local network') + data, loss = self.build_local_net() + else: + elog(self, '[r%d] build parallel network' % args.rank) + data, loss = self.build_parall_net() + return data, loss + + def optimize(self, loss): + args = self.args + optimizer = self.create_optimizer() + if args.nranks <= 1: + optimizer.minimize(loss) + else: + self.dist_optimize(optimizer, loss) + + def get_rank_batch(self): args = self.args def generate_global_batch(): @@ -109,10 +117,10 @@ class DistClassificationRunner(object): self.seed += 1 global_batch_size = args.batch_size * args.nranks - return [[ - np.random.rand(args.feature_size), - np.random.randint(args.class_num) - ] for i in range(global_batch_size)] + return [ + next(self.yield_sample(np.random)) + for i in range(global_batch_size) + ] rank_batch = [] global_batch = generate_global_batch() @@ -122,34 +130,26 @@ class DistClassificationRunner(object): return rank_batch - def transpile(self, main_prog, start_prog): - args = self.args - transpiler = GradAllReduce() - transpiler.transpile( - startup_program=start_prog, - main_program=main_prog, - rank=args.rank, - endpoints=args.endpoints, - current_endpoint=args.current_endpoint, - wait_port=True) - def run(self): - feed_vars, loss, start_prog = self.build_net() - main_prog = loss.block.program + main_prog = fluid.Program() + start_prog = fluid.Program() + with fluid.program_guard(main_prog, start_prog): + data, loss = self.build_net() + self.optimize(loss) place = fluid.CUDAPlace(self.args.device_id) exe = fluid.Executor(place) exe.run(start_prog) elog(self, 'finish running startup program.') - feeder = 
fluid.DataFeeder(feed_vars, place) + feeder = fluid.DataFeeder(data, place) elog(self, 'start to train') out_losses = [] for i in range(RUN_STEPS): losses = exe.run(main_prog, fetch_list=[loss], - feed=feeder.feed(self.gen_rank_batch())) + feed=feeder.feed(self.get_rank_batch())) out_losses.append(losses[0][0]) elog(self, "step %d loss: %f" % (i, losses[0][0])) @@ -157,22 +157,20 @@ class DistClassificationRunner(object): print2pipe(out_losses) - @classmethod - def add_arguments(cls, parser): - pass - def runtime_main(test_class): parser = argparse.ArgumentParser( description='Run distributed classification test.') parser.add_argument('--batch_size', type=int, required=True) - parser.add_argument( - '--feature_size', type=int, default=DEFAULT_FEATURE_SIZE) - parser.add_argument('--class_num', type=int, default=DEFAULT_CLASS_NUM) - parser.add_argument('--lr', type=float, default=DEFAULT_LR) test_class.add_arguments(parser) args = parser.parse_args() + args.rank = int(os.getenv('PADDLE_TRAINER_ID', '0')) + args.current_endpoint = os.getenv('PADDLE_CURRENT_ENDPOINT') + args.nranks = int(os.getenv('PADDLE_TRAINERS_NUM', '1')) + args.endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS', '').split(',') + args.device_id = int(os.getenv('FLAGS_selected_gpus', '0')) + trainer = test_class(args) trainer.run() @@ -181,7 +179,26 @@ import socket from contextlib import closing -class TestDistClassificationBase(unittest.TestCase): +class TestDistCollectiveBase(unittest.TestCase): + ################################## + ##### user specified methods ##### + + # override configurations in setUp + def update_config(self): + pass + + def append_common_cmd(self): + return '' + + def append_local_cmd(self): + return '' + + def append_parall_cmd(self): + return '' + + ##### user specified methods ##### + ################################## + def setUp(self): self.nranks = 2 self.batch_size = DEFAULT_BATCH_SIZE @@ -201,42 +218,36 @@ class TestDistClassificationBase(unittest.TestCase): port = 
s.getsockname()[1] return port - # override configurations in setUp - def update_config(self): - pass - - def append_common_cmd(self): - return '' - - def append_local_cmd(self): - return '' - - def append_parall_cmd(self): - return '' - - def run_local(self, train_script, user_env): + def run_local(self, train_script, update_env): env = {} - cmd = '%s -u %s --batch_size %d' % (sys.executable, train_script, - self.global_batch_size) + cmd = sys.executable + ' -u' + if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': + env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '') + cmd += ' -m coverage run --branch -p' + + cmd += ' %s --batch_size %d' % (train_script, self.global_batch_size) if self.append_common_cmd(): cmd += ' ' + self.append_common_cmd().strip() if self.append_local_cmd(): cmd += ' ' + self.append_local_cmd().strip() - if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': - env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '') - cmd += ' -m coverage run --branch -p' - env.update(user_env) + env.update(update_env) elog(self, 'local_cmd: %s' % cmd) elog(self, 'local_env: %s' % env) - ferr = open('/tmp/local.log', 'w') - proc = subprocess.Popen( - cmd.split(' '), stdout=subprocess.PIPE, stderr=ferr, env=env) + local_log = '/tmp/local.log' + with open(local_log, 'w') as ferr: + proc = subprocess.Popen( + cmd.split(' '), stdout=subprocess.PIPE, stderr=ferr, env=env) + out, err = proc.communicate() - out, err = proc.communicate() - ferr.close() + with open(local_log, 'r') as fin: + proc_log_str = ''.join(fin.readlines()) + message = 'local_stderr:\n%s\nlocal_stderr end' % proc_log_str + if proc.returncode != 0: + raise RuntimeError(message) + elog(self, message) elog(self, 'local_stdout: %s' % pickle.loads(out)) @@ -254,25 +265,27 @@ class TestDistClassificationBase(unittest.TestCase): env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '') return env - def run_parall(self, train_script, user_env): - cmd = '%s -u %s --batch_size %d' % (sys.executable, train_script, - 
self.batch_size) + def run_parall(self, train_script, update_env): + cmd = sys.executable + ' -u' + if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': + cmd += ' -m coverage run --branch -p' + + cmd += ' %s --batch_size %d' % (train_script, self.batch_size) if self.append_common_cmd(): cmd += ' ' + self.append_common_cmd().strip() if self.append_parall_cmd(): cmd += ' ' + self.append_parall_cmd().strip() - if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': - cmd += ' -m coverage run --branch -p' procs = [] ferrs = [] + parall_log_format = '/tmp/parall_tr%d.log' for rank in range(self.nranks): env = self.get_parall_env(rank) - env.update(user_env) + env.update(update_env) elog(self, '[r%d] parall_cmd: %s' % (rank, cmd)) elog(self, '[r%d] parall_env: %s' % (rank, env)) - ferr = open('/tmp/parall_tr%d.log' % rank, 'w') + ferr = open(parall_log_format % rank, 'w') proc = subprocess.Popen( cmd.strip().split(' '), stdout=subprocess.PIPE, @@ -286,22 +299,31 @@ class TestDistClassificationBase(unittest.TestCase): out, err = procs[rank].communicate() ferrs[rank].close() + with open(parall_log_format % rank, 'r') as fin: + proc_log_str = ''.join(fin.readlines()) + message = '[r%d] parall_stderr:\n%s\nparall_stderr end' % ( + rank, proc_log_str) + if procs[rank].returncode != 0: + raise RuntimeError(message) + elog(self, message) + + elog(self, '[r%d] parall_stdout: %s' % (rank, pickle.loads(out))) outs.append(out) return [pickle.loads(outs[i]) for i in range(self.nranks)] - def compare_parall_to_local(self, train_script, delta=1e-3, user_envs={}): + def compare_parall_to_local(self, train_script, delta=1e-3, update_envs={}): required_envs = { 'PATH': os.getenv('PATH', ''), 'PYTHONPATH': os.getenv('PYTHONPATH', ''), 'LD_LIBRARY_PATH': os.getenv('LD_LIBRARY_PATH', ''), 'FLAGS_fraction_of_gpu_memory_to_use': '0.15', - 'FLAGS_rpc_deadline': '30000', # 5s to fail fast + 'FLAGS_rpc_deadline': '5000', # 5s to fail fast 'FLAGS_cudnn_deterministic': '1', 'NCCL_P2P_DISABLE': '1', 
'NCCL_SHM_DISABLE': '1' } - required_envs.update(user_envs) + required_envs.update(update_envs) local_losses = self.run_local(train_script, required_envs) parall_losses = self.run_parall(train_script, required_envs) diff --git a/python/paddle/fluid/tests/unittests/test_dist_softmax_classification.py b/python/paddle/fluid/tests/unittests/test_dist_softmax_classification.py index c872412ed03..f890ec4e2e2 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_softmax_classification.py +++ b/python/paddle/fluid/tests/unittests/test_dist_softmax_classification.py @@ -14,14 +14,11 @@ import unittest import paddle.fluid as fluid -from test_dist_classification_base import TestDistClassificationBase +from test_dist_collective_base import TestDistCollectiveBase -class TestDistSoftmaxClassification(TestDistClassificationBase): - def setup_config(self): - pass - - def test_dist_train(self): +class TestDistSoftmaxClassification(TestDistCollectiveBase): + def test_training(self): if fluid.core.is_compiled_with_cuda(): self.compare_parall_to_local( "dist_softmax_classification.py", delta=1e-5) diff --git a/python/paddle/fluid/transpiler/collective.py b/python/paddle/fluid/transpiler/collective.py index ecca3734921..636df20c8d8 100644 --- a/python/paddle/fluid/transpiler/collective.py +++ b/python/paddle/fluid/transpiler/collective.py @@ -51,19 +51,26 @@ class Collective(object): self.op_role_key = op_maker.kOpRoleAttrName() self.op_role_var_key = op_maker.kOpRoleVarAttrName() - def transpile(self, startup_program, main_program, rank, endpoints, - current_endpoint, wait_port): + def transpile(self, + rank, + endpoints, + current_endpoint, + wait_port, + startup_program=None, + main_program=None): # in case of '127.0.0.1:6700,127.0.0.1:6701,...' 
if isinstance(endpoints, str): endpoints = endpoints.split(',') - self.startup_program = startup_program if startup_program is None: self.startup_program = default_startup_program() + else: + self.startup_program = startup_program - self.main_program = main_program if main_program is None: self.main_program = default_main_program() + else: + self.main_program = main_program self.nranks = len(endpoints) if self.nranks == 1: -- GitLab