Commit 0d953806 authored by gavin1332

Do not transpose the weight in the dist_arcface_classification algorithm; this does not affect performance.

test=develop
test=document_preview
Parent 2be9036f
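Why dropping the transpose is safe: storing the weight as [in_dim, class_num] and L2-normalizing its columns is mathematically equivalent to the old [class_num, in_dim] layout normalized row-wise and then transposed. The following standalone NumPy sketch (illustrative only, not code from this patch; all names and shapes are made up) checks that equivalence:

import numpy as np

np.random.seed(0)
in_dim, class_num, batch = 8, 5, 3
x = np.random.rand(batch, in_dim)
w_old = np.random.rand(class_num, in_dim)   # old layout: [class_num, in_dim]
w_new = w_old.T                             # new layout: [in_dim, class_num]

# old path: L2-normalize each row, then transpose before the matmul
cos_old = x @ (w_old / np.linalg.norm(w_old, axis=1, keepdims=True)).T

# new path: L2-normalize each column, no transpose needed
cos_new = x @ (w_new / np.linalg.norm(w_new, axis=0, keepdims=True))

assert np.allclose(cos_old, cos_new)        # identical cosine logits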
@@ -151,7 +151,6 @@ class DistributedClassifier(object):
             dtype=x.dtype,
             in_dim=flatten_dim,
             param_attr=param_attr,
-            transpose_weight=True,
             use_bias=False)

         # normalize x
@@ -169,12 +168,12 @@ class DistributedClassifier(object):
             nshards=self.nranks,
             shard_id=self.rank_id,
             ignore_value=-1)
+        # TODO check necessary
         shard_label.stop_gradient = True

         # normalize weight
-        weight_l2 = ops.sqrt(nn.reduce_sum(nn.square(weight), dim=1))
-        norm_weight = nn.elementwise_div(weight, weight_l2, axis=0)
-        norm_weight = nn.transpose(norm_weight, perm=[1, 0])
+        weight_l2 = ops.sqrt(nn.reduce_sum(nn.square(weight), dim=0))
+        norm_weight = nn.elementwise_div(weight, weight_l2, axis=1)

         shard_cos = nn.mul(norm_x_all, norm_weight, x_num_col_dims=1)
@@ -183,6 +182,7 @@ class DistributedClassifier(object):
         shard_one_hot = nn.one_hot(
             shard_label, depth=self.shard_dim, allow_out_of_range=True)
+        # TODO check necessary
         shard_one_hot.stop_gradient = True

         diff = (margin_cos - shard_cos) * shard_one_hot
......
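For reference, a minimal NumPy rendering of the math in the hunk above (purely illustrative, not the Paddle ops themselves; shapes are assumed): the per-column L2 norm computed over dim=0 is broadcast along axis 1, so the non-transposed weight can be multiplied with the normalized features directly.

import numpy as np

batch, in_dim, shard_dim = 4, 16, 5                      # assumed shapes
norm_x_all = np.random.rand(batch, in_dim)
norm_x_all /= np.linalg.norm(norm_x_all, axis=1, keepdims=True)  # "normalize x"

weight = np.random.rand(in_dim, shard_dim)               # non-transposed layout
weight_l2 = np.sqrt(np.sum(np.square(weight), axis=0))   # dim=0 -> per-column norm
norm_weight = weight / weight_l2                         # broadcast along axis 1
shard_cos = norm_x_all @ norm_weight                     # no transpose required
assert shard_cos.shape == (batch, shard_dim)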
@@ -23,7 +23,6 @@ from dist_classification_base import DistClassificationRunner
 from test_dist_collective_base import runtime_main

-# TODO(gavin1332) check whether it is necessary to transpose weight
 class DistArcfaceClassificationRunner(DistClassificationRunner):
     @classmethod
     def add_other_arguments(cls, parser):
@@ -33,15 +32,15 @@ class DistArcfaceClassificationRunner(DistClassificationRunner):
     def __init__(self, args):
         super(DistArcfaceClassificationRunner, self).__init__(args)
         np.random.seed(1024)
-        self.param_value = np.random.rand(self.args.class_num,
-                                          self.args.feature_size)
+        self.param_value = np.random.rand(self.args.feature_size,
+                                          self.args.class_num)

     def local_classify_subnet(self, feature, label):
         args = self.args
         weight = layers.create_parameter(
             dtype=feature.dtype,
-            shape=[args.class_num, args.feature_size],
+            shape=[args.feature_size, args.class_num],
             default_initializer=NumpyArrayInitializer(self.param_value),
             is_bias=False)
@@ -52,9 +51,8 @@ class DistArcfaceClassificationRunner(DistClassificationRunner):
         norm_feature = layers.elementwise_div(feature, feature_l2, axis=0)

         # normalize weight
-        weight_l2 = layers.sqrt(layers.reduce_sum(layers.square(weight), dim=1))
-        norm_weight = layers.elementwise_div(weight, weight_l2, axis=0)
-        norm_weight = layers.transpose(norm_weight, perm=[1, 0])
+        weight_l2 = layers.sqrt(layers.reduce_sum(layers.square(weight), dim=0))
+        norm_weight = layers.elementwise_div(weight, weight_l2, axis=1)

         cos = layers.mul(norm_feature, norm_weight)
@@ -76,8 +74,8 @@ class DistArcfaceClassificationRunner(DistClassificationRunner):
         args = self.args
         shard_dim = (args.class_num + args.nranks - 1) // args.nranks
         shard_start = shard_dim * args.rank
-        rank_param_value = self.param_value[shard_start:(shard_start + shard_dim
-                                                         ), :]
+        rank_param_value = self.param_value[:, shard_start:(shard_start +
+                                                            shard_dim)]
         cost = layers.dist_algo._distributed_arcface_classify(
             x=feature,
             label=label,
......
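With the weight now kept as [feature_size, class_num], the test shards classes along the second axis. A small hypothetical NumPy sketch of that column-wise sharding (values and shapes made up, not code from the repository):

import numpy as np

feature_size, class_num, nranks = 4, 10, 3
param_value = np.random.rand(feature_size, class_num)

shard_dim = (class_num + nranks - 1) // nranks            # ceil division, as in the test
for rank in range(nranks):
    shard_start = shard_dim * rank
    rank_param = param_value[:, shard_start:shard_start + shard_dim]
    print(rank, rank_param.shape)                         # the last shard may be narrower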