Commit 0d953806 authored by gavin1332

Do not transpose the weight in the dist_arcface_classification algorithm; this does not affect performance.

test=develop
test=document_preview
Parent 2be9036f
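The equivalence behind this change can be sketched outside the framework. The NumPy snippet below is illustrative only (its variable names are not from the patch): L2-normalizing a `[class_num, feature_size]` weight along `dim=1` and then transposing yields the same matrix as storing the weight as `[feature_size, class_num]` and normalizing along `dim=0`, so the transpose op can simply be dropped.

```python
import numpy as np

np.random.seed(0)
class_num, feature_size = 8, 4

# Old layout: one row per class, L2-normalize each row, then transpose.
w_old = np.random.rand(class_num, feature_size)
old_norm = (w_old / np.linalg.norm(w_old, axis=1, keepdims=True)).T

# New layout: one column per class, L2-normalize each column directly.
w_new = w_old.T
new_norm = w_new / np.linalg.norm(w_new, axis=0, keepdims=True)

assert np.allclose(old_norm, new_norm)

# Either way the cosine logits are just feature @ norm_weight, with no extra transpose.
feature = np.random.rand(2, feature_size)
feature = feature / np.linalg.norm(feature, axis=1, keepdims=True)
cos = feature @ new_norm  # shape: [batch, class_num]
```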
......@@ -151,7 +151,6 @@ class DistributedClassifier(object):
dtype=x.dtype,
in_dim=flatten_dim,
param_attr=param_attr,
transpose_weight=True,
use_bias=False)
# normalize x
......@@ -169,12 +168,12 @@ class DistributedClassifier(object):
nshards=self.nranks,
shard_id=self.rank_id,
ignore_value=-1)
# TODO check necessary
shard_label.stop_gradient = True
# normalize weight
weight_l2 = ops.sqrt(nn.reduce_sum(nn.square(weight), dim=1))
norm_weight = nn.elementwise_div(weight, weight_l2, axis=0)
norm_weight = nn.transpose(norm_weight, perm=[1, 0])
weight_l2 = ops.sqrt(nn.reduce_sum(nn.square(weight), dim=0))
norm_weight = nn.elementwise_div(weight, weight_l2, axis=1)
shard_cos = nn.mul(norm_x_all, norm_weight, x_num_col_dims=1)
......@@ -183,6 +182,7 @@ class DistributedClassifier(object):
shard_one_hot = nn.one_hot(
shard_label, depth=self.shard_dim, allow_out_of_range=True)
# TODO check necessary
shard_one_hot.stop_gradient = True
diff = (margin_cos - shard_cos) * shard_one_hot
......
......@@ -23,7 +23,6 @@ from dist_classification_base import DistClassificationRunner
from test_dist_collective_base import runtime_main
# TODO(gavin1332) check whether it is necessary to transpose weight
class DistArcfaceClassificationRunner(DistClassificationRunner):
@classmethod
def add_other_arguments(cls, parser):
......@@ -33,15 +32,15 @@ class DistArcfaceClassificationRunner(DistClassificationRunner):
def __init__(self, args):
super(DistArcfaceClassificationRunner, self).__init__(args)
np.random.seed(1024)
self.param_value = np.random.rand(self.args.class_num,
self.args.feature_size)
self.param_value = np.random.rand(self.args.feature_size,
self.args.class_num)
def local_classify_subnet(self, feature, label):
args = self.args
weight = layers.create_parameter(
dtype=feature.dtype,
shape=[args.class_num, args.feature_size],
shape=[args.feature_size, args.class_num],
default_initializer=NumpyArrayInitializer(self.param_value),
is_bias=False)
......@@ -52,9 +51,8 @@ class DistArcfaceClassificationRunner(DistClassificationRunner):
norm_feature = layers.elementwise_div(feature, feature_l2, axis=0)
# normalize weight
weight_l2 = layers.sqrt(layers.reduce_sum(layers.square(weight), dim=1))
norm_weight = layers.elementwise_div(weight, weight_l2, axis=0)
norm_weight = layers.transpose(norm_weight, perm=[1, 0])
weight_l2 = layers.sqrt(layers.reduce_sum(layers.square(weight), dim=0))
norm_weight = layers.elementwise_div(weight, weight_l2, axis=1)
cos = layers.mul(norm_feature, norm_weight)
......@@ -76,8 +74,8 @@ class DistArcfaceClassificationRunner(DistClassificationRunner):
args = self.args
shard_dim = (args.class_num + args.nranks - 1) // args.nranks
shard_start = shard_dim * args.rank
rank_param_value = self.param_value[shard_start:(shard_start + shard_dim
), :]
rank_param_value = self.param_value[:, shard_start:(shard_start +
shard_dim)]
cost = layers.dist_algo._distributed_arcface_classify(
x=feature,
label=label,
......
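For the per-rank parameter slice in the test runner, the column slice of the new `[feature_size, class_num]` layout selects the same classes as the old row slice of the `[class_num, feature_size]` layout. A small NumPy sketch of that relationship (an illustration, not code from the patch):

```python
import numpy as np

np.random.seed(1024)
class_num, feature_size, nranks, rank = 10, 4, 3, 1

w_old = np.random.rand(class_num, feature_size)  # old layout
w_new = w_old.T                                  # new layout

shard_dim = (class_num + nranks - 1) // nranks
shard_start = shard_dim * rank

old_shard = w_old[shard_start:shard_start + shard_dim, :]
new_shard = w_new[:, shard_start:shard_start + shard_dim]

# Same shard of classes, just laid out column-wise instead of row-wise.
assert np.allclose(old_shard.T, new_shard)
```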