提交 2be9036f 编写于 作者: G gavin1332

extract a common distributed testing class for ut

test=develop
test=document_preview
上级 cc7f2bb0
......@@ -290,7 +290,7 @@ paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'tran
paddle.fluid.layers.match_matrix_tensor (ArgSpec(args=['x', 'y', 'channel_num', 'act', 'param_attr', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, 'float32', None)), ('document', 'b6ea7d4ddeacae85e37d1e47d5262948'))
paddle.fluid.layers.filter_by_instag (ArgSpec(args=['ins', 'ins_tag', 'filter_tag', 'is_lod'], varargs=None, keywords=None, defaults=None), ('document', '7703a2088af8de4128b143ff1164ca4a'))
paddle.fluid.layers.var_conv_2d (ArgSpec(args=['input', 'row', 'col', 'input_channel', 'output_channel', 'filter_size', 'stride', 'param_attr', 'act', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, None, 'float32', None)), ('document', '7a8b8ade5512c95f9ea30261d33ded6c'))
paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '5786fdbba6753ecd6cbce5e6b0889924'))
paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '3a209cbe5f648c00f8d7c2187dc23674'))
paddle.fluid.layers.hard_swish (ArgSpec(args=['x', 'threshold', 'scale', 'offset', 'name'], varargs=None, keywords=None, defaults=(6.0, 6.0, 3.0, None)), ('document', '6a5152a7015c62cb8278fc24cb456459'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545'))
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '88367daf9a30c9ab83adc5d7221e23ef'))
......
......@@ -19,7 +19,6 @@ from ..layer_helper import LayerHelper, unique_name
from ..framework import Variable, default_startup_program
from ..param_attr import ParamAttr
from ..initializer import Normal, Constant
import nn, ops
def _allreduce(x, out=None, reduce_type="sum", sync_mode=False):
......@@ -183,305 +182,3 @@ def _c_sync_comm_stream(x, ring_id):
outputs={'Out': [x]},
attrs={'ring_id': ring_id})
return x
class DistributedClassifier(object):
'''
Tookit for distributed classification, in which the parameter of the last
full-connected layer is distributed to all trainers
'''
def __init__(self, nclasses, nranks, rank_id, layer_helper):
self.nclasses = nclasses
self.nranks = nranks
self.rank_id = rank_id
self._layer_helper = layer_helper
self.shard_dim = (nclasses + nranks - 1) // nranks
self.padding_dim = 0
self.is_equal_division = True
if nclasses % nranks != 0:
self.is_equal_division = False
if rank_id == nranks - 1:
other_shard_dim = self.shard_dim
self.shard_dim = nclasses % other_shard_dim
self.padding_dim = other_shard_dim - self.shard_dim
def create_parameter(self,
dtype,
in_dim,
param_attr=None,
transpose_weight=False,
use_bias=True):
if param_attr is None:
stdv = math.sqrt(2.0 / (in_dim + self.nclasses))
param_attr = ParamAttr(initializer=Normal(scale=stdv))
weight_shape = [self.shard_dim, in_dim
] if transpose_weight else [in_dim, self.shard_dim]
weight = self._layer_helper.create_parameter(
shape=weight_shape, dtype=dtype, attr=param_attr, is_bias=False)
# avoid distributed parameter allreduce gradients
weight.is_distributed = True
# avoid distributed parameter broadcasting in startup program
default_startup_program().global_block().vars[
weight.name].is_distributed = True
bias = None
if use_bias:
bias = self._layer_helper.create_parameter(
shape=[self.shard_dim],
attr=ParamAttr(),
dtype=dtype,
is_bias=True)
bias.is_distributed = True
default_startup_program().global_block().vars[
bias.name].is_distributed = True
return weight, bias
def softmax_with_cross_entropy(self, shard_logit, shard_label):
shard_max = nn.reduce_max(shard_logit, dim=1, keep_dim=True)
global_max = _c_allreduce(
shard_max, reduce_type='max', use_calc_stream=True)
shard_logit_new = nn.elementwise_sub(shard_logit, global_max)
shard_exp = ops.exp(shard_logit_new)
shard_demon = nn.reduce_sum(shard_exp, dim=1, keep_dim=True)
global_demon = _c_allreduce(
shard_demon, reduce_type='sum', use_calc_stream=True)
global_log_demon = nn.log(global_demon)
shard_log_prob = shard_logit_new - global_log_demon
shard_prob = ops.exp(shard_log_prob)
shard_one_hot = nn.one_hot(
shard_label, depth=self.shard_dim, allow_out_of_range=True)
target_log_prob = nn.reduce_min(
shard_log_prob * shard_one_hot, dim=1, keep_dim=True)
shard_loss = nn.scale(target_log_prob, scale=-1.0)
global_loss = _c_reducescatter(
shard_loss, nranks=self.nranks, use_calc_stream=True)
return global_loss, shard_prob
def fc_classify(self, x, label, param_attr=None, use_bias=True):
flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1)
weight, bias = self.create_parameter(
dtype=x.dtype,
in_dim=flatten_dim,
param_attr=param_attr,
use_bias=use_bias)
x_all = _c_allgather(x, nranks=self.nranks, use_calc_stream=True)
label_all = _c_allgather(
label, nranks=self.nranks, use_calc_stream=True)
label_all.stop_gradient = True
shard_fc = nn.mul(x_all, weight, x_num_col_dims=1)
if use_bias:
shard_fc = nn.elementwise_add(shard_fc, bias)
shard_label = nn.shard_index(
label_all,
index_num=self.nclasses,
nshards=self.nranks,
shard_id=self.rank_id,
ignore_value=-1)
shard_label.stop_gradient = True
global_loss, shard_prob = self.softmax_with_cross_entropy(shard_fc,
shard_label)
avg_loss = nn.mean(global_loss)
avg_loss._set_info('shard_logit', shard_fc)
avg_loss._set_info('shard_prob', shard_prob)
avg_loss._set_info('shard_label', shard_label)
avg_loss._set_info('shard_dim', self.shard_dim)
return avg_loss
def arcface_classify(self,
x,
label,
margin=0.5,
logit_scale=64,
param_attr=None):
'''
reference: ArcFace. https://arxiv.org/abs/1801.07698
'''
flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1)
weight, bias = self.create_parameter(
dtype=x.dtype,
in_dim=flatten_dim,
param_attr=param_attr,
transpose_weight=True,
use_bias=False)
# normalize x
x_l2 = ops.sqrt(nn.reduce_sum(nn.square(x), dim=1))
norm_x = nn.elementwise_div(x, x_l2, axis=0)
norm_x_all = _c_allgather(
norm_x, nranks=self.nranks, use_calc_stream=True)
label_all = _c_allgather(
label, nranks=self.nranks, use_calc_stream=True)
label_all.stop_gradient = True
shard_label = nn.shard_index(
label_all,
index_num=self.nclasses,
nshards=self.nranks,
shard_id=self.rank_id,
ignore_value=-1)
shard_label.stop_gradient = True
# normalize weight
weight_l2 = ops.sqrt(nn.reduce_sum(nn.square(weight), dim=1))
norm_weight = nn.elementwise_div(weight, weight_l2, axis=0)
norm_weight = nn.transpose(norm_weight, perm=[1, 0])
shard_cos = nn.mul(norm_x_all, norm_weight, x_num_col_dims=1)
theta = ops.acos(shard_cos)
margin_cos = ops.cos(theta + margin)
shard_one_hot = nn.one_hot(
shard_label, depth=self.shard_dim, allow_out_of_range=True)
shard_one_hot.stop_gradient = True
diff = (margin_cos - shard_cos) * shard_one_hot
shard_target_cos = shard_cos + diff
shard_logit = nn.scale(shard_target_cos, scale=logit_scale)
global_loss, shard_prob = self.softmax_with_cross_entropy(shard_logit,
shard_label)
avg_loss = nn.mean(global_loss)
avg_loss._set_info('shard_logit', shard_logit)
avg_loss._set_info('shard_prob', shard_prob)
avg_loss._set_info('shard_label', shard_label)
avg_loss._set_info('shard_dim', self.shard_dim)
return avg_loss
def _distributed_fc_classify(x,
label,
class_num,
nranks,
rank_id,
param_attr=None,
use_bias=True,
name=None):
'''
Classification layer with FC, softmax and cross entropy calculation of
distibuted version in case of too large number of classes.
Args:
x (Variable): The feature representation of the input samples. This
feature will be flattened into 2-D tensor from dimension index
1. E.g. [32, 1024, 1, 1] will be flattened to [32, 1024].
label (Variable): The label corresponding to the input samples.
class_num (integer): The number of classes of the classification problem.
nranks (integer): The number of ranks of distributed trainers.
rank_id (integer): The rank index of the current trainer.
param_attr (ParamAttr, default None): The parameter attribute for
learnable distributed parameters/weights of this layer.
use_bias (float, default 64.0): The scale factor for logit value
of cosine range.
name (str, default None): The name of this layer.
Returns:
Variable: The ArcFace loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.layers.data(name="input",
shape=[32, 1024],
dtype='float32',
append_batch_size=False)
label = fluid.layers.data(name="label",
shape=[32, 1],
dtype='int64',
append_batch_size=False)
y = fluid.layers.collective.distributed_fc_classify(x=input,
label=label,
class_num=1000,
nranks=8,
rank_id=0)
'''
if name is None:
name = 'dist_fc'
helper = LayerHelper(name, **locals())
classifier = DistributedClassifier(class_num, nranks, rank_id, helper)
return classifier.fc_classify(x, label, param_attr, use_bias)
def _distributed_arcface_classify(x,
label,
class_num,
nranks,
rank_id,
margin=0.5,
logit_scale=64.0,
param_attr=None,
name=None):
'''
Classification layer with ArcFace loss of distibuted version in case of
too large number of classes. the equation is
.. math::
L=-\frac{1}{N}\sum^N_{i=1}\log\frac{e^{s(cos(\theta_{y_i}+m))}}{e^{s(cos(\theta_{y_i}+m))}+\sum^n_{j=1,j\neq y_i} e^{scos\theta_{y_i}}}
where the :math: `\theta_{y_i}` is the angle between the feature :math: `x` and
the representation of class :math: `i`. The details of ArcFace loss
could be referred to https://arxiv.org/abs/1801.07698.
Args:
x (Variable): The feature representation of the input samples. This
feature will be flattened into 2-D tensor from dimension index
1. E.g. [32, 1024, 1, 1] will be flattened to [32, 1024].
label (Variable): The label corresponding to the input samples.
class_num (integer): The number of classes of the classification problem.
nranks (integer): The number of ranks of distributed trainers.
rank_id (integer): The rank index of the current trainer.
margin (float, default 0.5): The angular margin penalty to enhance
the intra-class compactness and inter-class discrepancy.
logit_scale (float, default 64.0): The scale factor for logit value
of cosine range.
param_attr (ParamAttr, default None): The parameter attribute for
learnable distributed parameters/weights of this layer.
name (str, default None): The name of this layer.
Returns:
Variable: The ArcFace loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.layers.data(name="input",
shape=[32, 1024],
dtype='float32',
append_batch_size=False)
label = fluid.layers.data(name="label",
shape=[32, 1],
dtype='int64',
append_batch_size=False)
y = fluid.layers.collective.distributed_arcface_classify(x=input,
label=label,
class_num=1000,
nranks=8,
rank_id=0)
'''
if name is None:
name = 'dist_fc'
helper = LayerHelper(name, **locals())
classifier = DistributedClassifier(class_num, nranks, rank_id, helper)
return classifier.arcface_classify(
x=x,
label=label,
margin=margin,
logit_scale=logit_scale,
param_attr=param_attr)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import math
from six.moves import reduce
from ..layer_helper import LayerHelper
from ..framework import Variable, default_startup_program
from ..param_attr import ParamAttr
from ..initializer import Normal, Constant
from . import nn, ops, collective
class DistributedClassifier(object):
'''
Tookit for distributed classification, in which the parameter of the last
full-connected layer is distributed to all trainers
'''
def __init__(self, nclasses, nranks, rank_id, layer_helper):
self.nclasses = nclasses
self.nranks = nranks
self.rank_id = rank_id
self._layer_helper = layer_helper
self.shard_dim = (nclasses + nranks - 1) // nranks
self.padding_dim = 0
self.is_equal_division = True
if nclasses % nranks != 0:
self.is_equal_division = False
if rank_id == nranks - 1:
other_shard_dim = self.shard_dim
self.shard_dim = nclasses % other_shard_dim
self.padding_dim = other_shard_dim - self.shard_dim
def create_parameter(self,
dtype,
in_dim,
param_attr=None,
transpose_weight=False,
use_bias=True):
if param_attr is None:
stdv = math.sqrt(2.0 / (in_dim + self.nclasses))
param_attr = ParamAttr(initializer=Normal(scale=stdv))
weight_shape = [self.shard_dim, in_dim
] if transpose_weight else [in_dim, self.shard_dim]
weight = self._layer_helper.create_parameter(
shape=weight_shape, dtype=dtype, attr=param_attr, is_bias=False)
# avoid distributed parameter allreduce gradients
weight.is_distributed = True
# avoid distributed parameter broadcasting in startup program
default_startup_program().global_block().vars[
weight.name].is_distributed = True
bias = None
if use_bias:
bias = self._layer_helper.create_parameter(
shape=[self.shard_dim],
attr=ParamAttr(),
dtype=dtype,
is_bias=True)
bias.is_distributed = True
default_startup_program().global_block().vars[
bias.name].is_distributed = True
return weight, bias
def softmax_with_cross_entropy(self, shard_logit, shard_label):
shard_max = nn.reduce_max(shard_logit, dim=1, keep_dim=True)
global_max = collective._c_allreduce(
shard_max, reduce_type='max', use_calc_stream=True)
shard_logit_new = nn.elementwise_sub(shard_logit, global_max)
shard_exp = ops.exp(shard_logit_new)
shard_demon = nn.reduce_sum(shard_exp, dim=1, keep_dim=True)
global_demon = collective._c_allreduce(
shard_demon, reduce_type='sum', use_calc_stream=True)
global_log_demon = nn.log(global_demon)
shard_log_prob = shard_logit_new - global_log_demon
shard_prob = ops.exp(shard_log_prob)
shard_one_hot = nn.one_hot(
shard_label, depth=self.shard_dim, allow_out_of_range=True)
target_log_prob = nn.reduce_min(
shard_log_prob * shard_one_hot, dim=1, keep_dim=True)
shard_loss = nn.scale(target_log_prob, scale=-1.0)
global_loss = collective._c_reducescatter(
shard_loss, nranks=self.nranks, use_calc_stream=True)
return global_loss, shard_prob
def softmax_classify(self, x, label, param_attr=None, use_bias=True):
flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1)
weight, bias = self.create_parameter(
dtype=x.dtype,
in_dim=flatten_dim,
param_attr=param_attr,
use_bias=use_bias)
x_all = collective._c_allgather(
x, nranks=self.nranks, use_calc_stream=True)
label_all = collective._c_allgather(
label, nranks=self.nranks, use_calc_stream=True)
label_all.stop_gradient = True
shard_fc = nn.mul(x_all, weight, x_num_col_dims=1)
if use_bias:
shard_fc = nn.elementwise_add(shard_fc, bias)
shard_label = nn.shard_index(
label_all,
index_num=self.nclasses,
nshards=self.nranks,
shard_id=self.rank_id,
ignore_value=-1)
shard_label.stop_gradient = True
global_loss, shard_prob = self.softmax_with_cross_entropy(shard_fc,
shard_label)
avg_loss = nn.mean(global_loss)
avg_loss._set_info('shard_logit', shard_fc)
avg_loss._set_info('shard_prob', shard_prob)
avg_loss._set_info('shard_label', shard_label)
avg_loss._set_info('shard_dim', self.shard_dim)
return avg_loss
def arcface_classify(self,
x,
label,
margin=0.5,
logit_scale=64,
param_attr=None):
'''
reference: ArcFace. https://arxiv.org/abs/1801.07698
'''
flatten_dim = reduce(lambda a, b: a * b, x.shape[1:], 1)
weight, bias = self.create_parameter(
dtype=x.dtype,
in_dim=flatten_dim,
param_attr=param_attr,
transpose_weight=True,
use_bias=False)
# normalize x
x_l2 = ops.sqrt(nn.reduce_sum(nn.square(x), dim=1))
norm_x = nn.elementwise_div(x, x_l2, axis=0)
norm_x_all = collective._c_allgather(
norm_x, nranks=self.nranks, use_calc_stream=True)
label_all = collective._c_allgather(
label, nranks=self.nranks, use_calc_stream=True)
label_all.stop_gradient = True
shard_label = nn.shard_index(
label_all,
index_num=self.nclasses,
nshards=self.nranks,
shard_id=self.rank_id,
ignore_value=-1)
shard_label.stop_gradient = True
# normalize weight
weight_l2 = ops.sqrt(nn.reduce_sum(nn.square(weight), dim=1))
norm_weight = nn.elementwise_div(weight, weight_l2, axis=0)
norm_weight = nn.transpose(norm_weight, perm=[1, 0])
shard_cos = nn.mul(norm_x_all, norm_weight, x_num_col_dims=1)
theta = ops.acos(shard_cos)
margin_cos = ops.cos(theta + margin)
shard_one_hot = nn.one_hot(
shard_label, depth=self.shard_dim, allow_out_of_range=True)
shard_one_hot.stop_gradient = True
diff = (margin_cos - shard_cos) * shard_one_hot
shard_target_cos = shard_cos + diff
shard_logit = nn.scale(shard_target_cos, scale=logit_scale)
global_loss, shard_prob = self.softmax_with_cross_entropy(shard_logit,
shard_label)
avg_loss = nn.mean(global_loss)
avg_loss._set_info('shard_logit', shard_logit)
avg_loss._set_info('shard_prob', shard_prob)
avg_loss._set_info('shard_label', shard_label)
avg_loss._set_info('shard_dim', self.shard_dim)
return avg_loss
def _distributed_softmax_classify(x,
label,
class_num,
nranks,
rank_id,
param_attr=None,
use_bias=True,
name=None):
'''
Classification layer with FC, softmax and cross entropy calculation of
distibuted version in case of too large number of classes.
Args:
x (Variable): The feature representation of the input samples. This
feature will be flattened into 2-D tensor from dimension index
1. E.g. [32, 1024, 1, 1] will be flattened to [32, 1024].
label (Variable): The label corresponding to the input samples.
class_num (integer): The number of classes of the classification problem.
nranks (integer): The number of ranks of distributed trainers.
rank_id (integer): The rank index of the current trainer.
param_attr (ParamAttr, default None): The parameter attribute for
learnable distributed parameters/weights of this layer.
use_bias (float, default 64.0): The scale factor for logit value
of cosine range.
name (str, default None): The name of this layer.
Returns:
Variable: The ArcFace loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.layers.data(name="input",
shape=[32, 1024],
dtype='float32',
append_batch_size=False)
label = fluid.layers.data(name="label",
shape=[32, 1],
dtype='int64',
append_batch_size=False)
y = fluid.layers.collective.distributed_softmax_classify(x=input,
label=label,
class_num=1000,
nranks=8,
rank_id=0)
'''
if name is None:
name = 'dist_softmax'
helper = LayerHelper(name, **locals())
classifier = DistributedClassifier(class_num, nranks, rank_id, helper)
return classifier.softmax_classify(x, label, param_attr, use_bias)
def _distributed_arcface_classify(x,
label,
class_num,
nranks,
rank_id,
margin=0.5,
logit_scale=64.0,
param_attr=None,
name=None):
'''
Classification layer with ArcFace loss of distibuted version in case of
too large number of classes. the equation is
.. math::
L=-\frac{1}{N}\sum^N_{i=1}\log\frac{e^{s(cos(\theta_{y_i}+m))}}{e^{s(cos(\theta_{y_i}+m))}+\sum^n_{j=1,j\neq y_i} e^{scos\theta_{y_i}}}
where the :math: `\theta_{y_i}` is the angle between the feature :math: `x` and
the representation of class :math: `i`. The details of ArcFace loss
could be referred to https://arxiv.org/abs/1801.07698.
Args:
x (Variable): The feature representation of the input samples. This
feature will be flattened into 2-D tensor from dimension index
1. E.g. [32, 1024, 1, 1] will be flattened to [32, 1024].
label (Variable): The label corresponding to the input samples.
class_num (integer): The number of classes of the classification problem.
nranks (integer): The number of ranks of distributed trainers.
rank_id (integer): The rank index of the current trainer.
margin (float, default 0.5): The angular margin penalty to enhance
the intra-class compactness and inter-class discrepancy.
logit_scale (float, default 64.0): The scale factor for logit value
of cosine range.
param_attr (ParamAttr, default None): The parameter attribute for
learnable distributed parameters/weights of this layer.
name (str, default None): The name of this layer.
Returns:
Variable: The ArcFace loss.
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.layers.data(name="input",
shape=[32, 1024],
dtype='float32',
append_batch_size=False)
label = fluid.layers.data(name="label",
shape=[32, 1],
dtype='int64',
append_batch_size=False)
y = fluid.layers.collective.distributed_arcface_classify(x=input,
label=label,
class_num=1000,
nranks=8,
rank_id=0)
'''
if name is None:
name = 'dist_fc'
helper = LayerHelper(name, **locals())
classifier = DistributedClassifier(class_num, nranks, rank_id, helper)
return classifier.arcface_classify(
x=x,
label=label,
margin=margin,
logit_scale=logit_scale,
param_attr=param_attr)
......@@ -17,22 +17,24 @@ import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.layers.collective as collective
import paddle.fluid.layers.dist_algo as dist_algo
from paddle.fluid.initializer import NumpyArrayInitializer
from test_dist_classification_base import DistClassificationRunner, runtime_main
from dist_classification_base import DistClassificationRunner
from test_dist_collective_base import runtime_main
# TODO donot transpose weight
# TODO(gavin1332) check whether it is necessary to transpose weight
class DistArcfaceClassificationRunner(DistClassificationRunner):
@classmethod
def add_arguments(cls, parser):
def add_other_arguments(cls, parser):
parser.add_argument('--arcface_margin', type=float, default=0.0)
parser.add_argument('--arcface_scale', type=float, default=1.0)
def __init__(self, args):
super(DistArcfaceClassificationRunner, self).__init__(args)
np.random.seed(1024)
self.param_value = np.random.rand(args.class_num, args.feature_size)
self.param_value = np.random.rand(self.args.class_num,
self.args.feature_size)
def local_classify_subnet(self, feature, label):
args = self.args
......@@ -76,7 +78,7 @@ class DistArcfaceClassificationRunner(DistClassificationRunner):
shard_start = shard_dim * args.rank
rank_param_value = self.param_value[shard_start:(shard_start + shard_dim
), :]
cost = layers.collective._distributed_arcface_classify(
cost = layers.dist_algo._distributed_arcface_classify(
x=feature,
label=label,
class_num=args.class_num,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from datetime import datetime
import unittest
import os
import sys
import subprocess
import six
import argparse
import pickle
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.transpiler.collective import \
GradAllReduce, DistributedClassificationOptimizer
from test_dist_collective_base import DistCollectiveRunner, elog
DEFAULT_FEATURE_SIZE = 4
DEFAULT_CLASS_NUM = 4
class DistClassificationRunner(DistCollectiveRunner):
##################################
##### user specified methods #####
@classmethod
def add_other_arguments(cls, parser):
pass
def local_classify_subnet(self, feature, label):
raise NotImplementedError(
'local_classifiy_subnet should be implemented by child classes.')
def parall_classify_subnet(self, feature, label):
raise NotImplementedError(
'parall_classify_subnet should be implemented by child classes.')
##### user specified methods #####
##################################
@classmethod
def add_arguments(cls, parser):
parser.add_argument(
'--feature_size', type=int, default=DEFAULT_FEATURE_SIZE)
parser.add_argument('--class_num', type=int, default=DEFAULT_CLASS_NUM)
cls.add_other_arguments(parser)
def build_local_net(self):
return self.build_classification_net()
def build_parall_net(self):
return self.build_classification_net()
def yield_sample(self, np_random):
yield [
np_random.rand(self.args.feature_size),
np_random.randint(self.args.class_num)
]
def dist_optimize(self, optimizer, loss):
args = self.args
optimizer_wrapper = DistributedClassificationOptimizer(optimizer,
args.batch_size)
optimizer_wrapper.minimize(loss)
transpiler = GradAllReduce()
transpiler.transpile(
rank=args.rank,
endpoints=args.endpoints,
current_endpoint=args.current_endpoint,
wait_port=True)
def build_classification_net(self):
args = self.args
feature = fluid.layers.data(
name='feature', shape=[args.feature_size], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
if args.nranks <= 1:
elog(self, 'build local network')
loss = self.local_classify_subnet(feature, label)
else:
elog(self, 'build parallel network')
loss = self.parall_classify_subnet(feature, label)
return [feature, label], loss
......@@ -17,16 +17,13 @@ import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.layers.dist_algo as dist_algo
from paddle.fluid.initializer import NumpyArrayInitializer
from test_dist_classification_base import DistClassificationRunner, runtime_main
from dist_classification_base import DistClassificationRunner
from test_dist_collective_base import runtime_main
# TODO bias attr
class DistSoftmaxClassificationRunner(DistClassificationRunner):
@classmethod
def add_arguments(cls, parser):
pass
def __init__(self, args):
super(DistSoftmaxClassificationRunner, self).__init__(args)
np.random.seed(1024)
......@@ -47,7 +44,7 @@ class DistSoftmaxClassificationRunner(DistClassificationRunner):
shard_start = shard_dim * args.rank
rank_param_value = self.param_value[:, shard_start:(shard_start +
shard_dim)]
cost = layers.collective._distributed_fc_classify(
cost = layers.dist_algo._distributed_softmax_classify(
x=feature,
label=label,
class_num=args.class_num,
......
......@@ -14,17 +14,17 @@
import unittest
import paddle.fluid as fluid
from test_dist_classification_base import TestDistClassificationBase
from test_dist_collective_base import TestDistCollectiveBase
class TestDistArcfaceClassification(TestDistClassificationBase):
class TestDistArcfaceClassification(TestDistCollectiveBase):
def test_training(self):
if fluid.core.is_compiled_with_cuda():
self.compare_parall_to_local(
'dist_arcface_classification.py', delta=1e-5)
class TestDistArcfaceClassificationParam(TestDistClassificationBase):
class TestDistArcfaceClassificationParam(TestDistCollectiveBase):
def append_common_cmd(self):
return '--arcface_margin 0.5 --arcface_scale 64'
......
......@@ -25,14 +25,9 @@ import pickle
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.transpiler.collective import \
GradAllReduce, DistributedClassificationOptimizer
from paddle.fluid.transpiler.collective import GradAllReduce
DEFAULT_BATCH_SIZE = 2
DEFAULT_FEATURE_SIZE = 4
DEFAULT_CLASS_NUM = 4
DEFAULT_LR = 0.001
RUN_STEPS = 5
......@@ -55,51 +50,64 @@ def elog(ref, message, to_pipe=False):
print(log_str, file=sys.stderr)
class DistClassificationRunner(object):
def __init__(self, args):
args.rank = int(os.getenv('PADDLE_TRAINER_ID', '0'))
args.current_endpoint = os.getenv('PADDLE_CURRENT_ENDPOINT')
args.nranks = int(os.getenv('PADDLE_TRAINERS_NUM', '1'))
args.endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS', '').split(',')
args.device_id = int(os.getenv('FLAGS_selected_gpus', '0'))
self.args = args
class DistCollectiveRunner(object):
##################################
##### user specified methods #####
@classmethod
def add_arguments(cls, parser):
pass
def elog(self, message, to_pipe=False):
elog(self, message, to_pipe)
def build_local_net(self):
raise NotImplementedError(
'local_net should be implemented by child classes.')
def local_classify_subnet(self, feature, label):
def build_parall_net(self):
raise NotImplementedError(
'get_local_model should be implemented by child classes.')
'parall_net should be implemented by child classes.')
def parall_classify_subnet(self, feature, label):
def yield_sample(self, np_random):
raise NotImplementedError(
'get_parall_model should be implemented by child classes.')
'data_generator should be implemented by child classes')
def create_optimizer(self):
return fluid.optimizer.SGD(learning_rate=0.001)
def dist_optimize(self, optimizer, loss):
args = self.args
optimizer.minimize(loss)
transpiler = GradAllReduce()
transpiler.transpile(
rank=args.rank,
endpoints=args.endpoints,
current_endpoint=args.current_endpoint,
wait_port=True)
##### user specified methods #####
##################################
def __init__(self, args):
self.args = args
def build_net(self):
args = self.args
main_prog = fluid.Program()
start_prog = fluid.Program()
optimizer = fluid.optimizer.SGD(learning_rate=args.lr)
with fluid.program_guard(main_prog, start_prog):
feature = fluid.layers.data(
name='feature', shape=[args.feature_size], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
if args.nranks <= 1:
elog(self, 'build local network')
loss = self.local_classify_subnet(feature, label)
optimizer.minimize(loss)
data, loss = self.build_local_net()
else:
elog(self, 'build parallel network')
loss = self.parall_classify_subnet(feature, label)
# TODO why need batch size?
optimizer_wrapper = DistributedClassificationOptimizer(
optimizer, args.batch_size)
optimizer_wrapper.minimize(loss)
self.transpile(main_prog, start_prog)
elog(self, '[r%d] build parallel network' % args.rank)
data, loss = self.build_parall_net()
return data, loss
return [feature, label], loss, start_prog
def optimize(self, loss):
args = self.args
optimizer = self.create_optimizer()
if args.nranks <= 1:
optimizer.minimize(loss)
else:
self.dist_optimize(optimizer, loss)
def gen_rank_batch(self):
def get_rank_batch(self):
args = self.args
def generate_global_batch():
......@@ -109,10 +117,10 @@ class DistClassificationRunner(object):
self.seed += 1
global_batch_size = args.batch_size * args.nranks
return [[
np.random.rand(args.feature_size),
np.random.randint(args.class_num)
] for i in range(global_batch_size)]
return [
next(self.yield_sample(np.random))
for i in range(global_batch_size)
]
rank_batch = []
global_batch = generate_global_batch()
......@@ -122,34 +130,26 @@ class DistClassificationRunner(object):
return rank_batch
def transpile(self, main_prog, start_prog):
args = self.args
transpiler = GradAllReduce()
transpiler.transpile(
startup_program=start_prog,
main_program=main_prog,
rank=args.rank,
endpoints=args.endpoints,
current_endpoint=args.current_endpoint,
wait_port=True)
def run(self):
feed_vars, loss, start_prog = self.build_net()
main_prog = loss.block.program
main_prog = fluid.Program()
start_prog = fluid.Program()
with fluid.program_guard(main_prog, start_prog):
data, loss = self.build_net()
self.optimize(loss)
place = fluid.CUDAPlace(self.args.device_id)
exe = fluid.Executor(place)
exe.run(start_prog)
elog(self, 'finish running startup program.')
feeder = fluid.DataFeeder(feed_vars, place)
feeder = fluid.DataFeeder(data, place)
elog(self, 'start to train')
out_losses = []
for i in range(RUN_STEPS):
losses = exe.run(main_prog,
fetch_list=[loss],
feed=feeder.feed(self.gen_rank_batch()))
feed=feeder.feed(self.get_rank_batch()))
out_losses.append(losses[0][0])
elog(self, "step %d loss: %f" % (i, losses[0][0]))
......@@ -157,22 +157,20 @@ class DistClassificationRunner(object):
print2pipe(out_losses)
@classmethod
def add_arguments(cls, parser):
pass
def runtime_main(test_class):
parser = argparse.ArgumentParser(
description='Run distributed classification test.')
parser.add_argument('--batch_size', type=int, required=True)
parser.add_argument(
'--feature_size', type=int, default=DEFAULT_FEATURE_SIZE)
parser.add_argument('--class_num', type=int, default=DEFAULT_CLASS_NUM)
parser.add_argument('--lr', type=float, default=DEFAULT_LR)
test_class.add_arguments(parser)
args = parser.parse_args()
args.rank = int(os.getenv('PADDLE_TRAINER_ID', '0'))
args.current_endpoint = os.getenv('PADDLE_CURRENT_ENDPOINT')
args.nranks = int(os.getenv('PADDLE_TRAINERS_NUM', '1'))
args.endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS', '').split(',')
args.device_id = int(os.getenv('FLAGS_selected_gpus', '0'))
trainer = test_class(args)
trainer.run()
......@@ -181,7 +179,26 @@ import socket
from contextlib import closing
class TestDistClassificationBase(unittest.TestCase):
class TestDistCollectiveBase(unittest.TestCase):
##################################
##### user specified methods #####
# override configurations in setUp
def update_config(self):
pass
def append_common_cmd(self):
return ''
def append_local_cmd(self):
return ''
def append_parall_cmd(self):
return ''
##### user specified methods #####
##################################
def setUp(self):
self.nranks = 2
self.batch_size = DEFAULT_BATCH_SIZE
......@@ -201,42 +218,36 @@ class TestDistClassificationBase(unittest.TestCase):
port = s.getsockname()[1]
return port
# override configurations in setUp
def update_config(self):
pass
def append_common_cmd(self):
return ''
def append_local_cmd(self):
return ''
def append_parall_cmd(self):
return ''
def run_local(self, train_script, user_env):
def run_local(self, train_script, update_env):
env = {}
cmd = '%s -u %s --batch_size %d' % (sys.executable, train_script,
self.global_batch_size)
cmd = sys.executable + ' -u'
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
cmd += ' -m coverage run --branch -p'
cmd += ' %s --batch_size %d' % (train_script, self.global_batch_size)
if self.append_common_cmd():
cmd += ' ' + self.append_common_cmd().strip()
if self.append_local_cmd():
cmd += ' ' + self.append_local_cmd().strip()
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
cmd += ' -m coverage run --branch -p'
env.update(user_env)
env.update(update_env)
elog(self, 'local_cmd: %s' % cmd)
elog(self, 'local_env: %s' % env)
ferr = open('/tmp/local.log', 'w')
local_log = '/tmp/local.log'
with open(local_log, 'w') as ferr:
proc = subprocess.Popen(
cmd.split(' '), stdout=subprocess.PIPE, stderr=ferr, env=env)
out, err = proc.communicate()
ferr.close()
with open(local_log, 'r') as fin:
proc_log_str = ''.join(fin.readlines())
message = 'local_stderr:\n%s\nlocal_stderr end' % proc_log_str
if proc.returncode != 0:
raise RuntimeError(message)
elog(self, message)
elog(self, 'local_stdout: %s' % pickle.loads(out))
......@@ -254,25 +265,27 @@ class TestDistClassificationBase(unittest.TestCase):
env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
return env
def run_parall(self, train_script, user_env):
cmd = '%s -u %s --batch_size %d' % (sys.executable, train_script,
self.batch_size)
def run_parall(self, train_script, update_env):
cmd = sys.executable + ' -u'
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
cmd += ' -m coverage run --branch -p'
cmd += ' %s --batch_size %d' % (train_script, self.batch_size)
if self.append_common_cmd():
cmd += ' ' + self.append_common_cmd().strip()
if self.append_parall_cmd():
cmd += ' ' + self.append_parall_cmd().strip()
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
cmd += ' -m coverage run --branch -p'
procs = []
ferrs = []
parall_log_format = '/tmp/parall_tr%d.log'
for rank in range(self.nranks):
env = self.get_parall_env(rank)
env.update(user_env)
env.update(update_env)
elog(self, '[r%d] parall_cmd: %s' % (rank, cmd))
elog(self, '[r%d] parall_env: %s' % (rank, env))
ferr = open('/tmp/parall_tr%d.log' % rank, 'w')
ferr = open(parall_log_format % rank, 'w')
proc = subprocess.Popen(
cmd.strip().split(' '),
stdout=subprocess.PIPE,
......@@ -286,22 +299,31 @@ class TestDistClassificationBase(unittest.TestCase):
out, err = procs[rank].communicate()
ferrs[rank].close()
with open(parall_log_format % rank, 'r') as fin:
proc_log_str = ''.join(fin.readlines())
message = '[r%d] parall_stderr:\n%s\nparall_stderr end' % (
rank, proc_log_str)
if procs[rank].returncode != 0:
raise RuntimeError(message)
elog(self, message)
elog(self, '[r%d] parall_stdout: %s' % (rank, pickle.loads(out)))
outs.append(out)
return [pickle.loads(outs[i]) for i in range(self.nranks)]
def compare_parall_to_local(self, train_script, delta=1e-3, user_envs={}):
def compare_parall_to_local(self, train_script, delta=1e-3, update_envs={}):
required_envs = {
'PATH': os.getenv('PATH', ''),
'PYTHONPATH': os.getenv('PYTHONPATH', ''),
'LD_LIBRARY_PATH': os.getenv('LD_LIBRARY_PATH', ''),
'FLAGS_fraction_of_gpu_memory_to_use': '0.15',
'FLAGS_rpc_deadline': '30000', # 5s to fail fast
'FLAGS_rpc_deadline': '5000', # 5s to fail fast
'FLAGS_cudnn_deterministic': '1',
'NCCL_P2P_DISABLE': '1',
'NCCL_SHM_DISABLE': '1'
}
required_envs.update(user_envs)
required_envs.update(update_envs)
local_losses = self.run_local(train_script, required_envs)
parall_losses = self.run_parall(train_script, required_envs)
......
......@@ -14,14 +14,11 @@
import unittest
import paddle.fluid as fluid
from test_dist_classification_base import TestDistClassificationBase
from test_dist_collective_base import TestDistCollectiveBase
class TestDistSoftmaxClassification(TestDistClassificationBase):
def setup_config(self):
pass
def test_dist_train(self):
class TestDistSoftmaxClassification(TestDistCollectiveBase):
def test_training(self):
if fluid.core.is_compiled_with_cuda():
self.compare_parall_to_local(
"dist_softmax_classification.py", delta=1e-5)
......
......@@ -51,19 +51,26 @@ class Collective(object):
self.op_role_key = op_maker.kOpRoleAttrName()
self.op_role_var_key = op_maker.kOpRoleVarAttrName()
def transpile(self, startup_program, main_program, rank, endpoints,
current_endpoint, wait_port):
def transpile(self,
rank,
endpoints,
current_endpoint,
wait_port,
startup_program=None,
main_program=None):
# in case of '127.0.0.1:6700,127.0.0.1:6701,...'
if isinstance(endpoints, str):
endpoints = endpoints.split(',')
self.startup_program = startup_program
if startup_program is None:
self.startup_program = default_startup_program()
else:
self.startup_program = startup_program
self.main_program = main_program
if main_program is None:
self.main_program = default_main_program()
else:
self.main_program = main_program
self.nranks = len(endpoints)
if self.nranks == 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册