diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 838f0277aed6f1c90f318235379ba67b429b03c5..00b119327901affe9c1144e22cc1352f1a96d360 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -24,6 +24,8 @@ if(NOT WITH_DISTRIBUTE)
     LIST(REMOVE_ITEM TEST_OPS test_nce_remote_table_op)
     LIST(REMOVE_ITEM TEST_OPS test_hsigmoid_remote_table_op)
     LIST(REMOVE_ITEM TEST_OPS test_dist_fleet_ctr)
+    LIST(REMOVE_ITEM TEST_OPS test_dist_softmax_classification)
+    LIST(REMOVE_ITEM TEST_OPS test_dist_arcface_classification)
 endif(NOT WITH_DISTRIBUTE)
diff --git a/python/paddle/fluid/tests/unittests/dist_arcface_classification.py b/python/paddle/fluid/tests/unittests/dist_arcface_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec043956d1dbdc159a2328fc7e3ea3a7ca8bc2e9
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dist_arcface_classification.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+import numpy as np
+import paddle.fluid as fluid
+import paddle.fluid.layers as layers
+import paddle.fluid.layers.collective as collective
+from paddle.fluid.initializer import NumpyArrayInitializer
+from test_dist_classification_base import DistClassificationRunner, runtime_main
+
+
+# TODO: do not transpose weight
+class DistArcfaceClassificationRunner(DistClassificationRunner):
+    @classmethod
+    def add_arguments(cls, parser):
+        parser.add_argument('--arcface_margin', type=float, default=0.0)
+        parser.add_argument('--arcface_scale', type=float, default=1.0)
+
+    def __init__(self, args):
+        super(DistArcfaceClassificationRunner, self).__init__(args)
+        np.random.seed(1024)
+        self.param_value = np.random.rand(args.class_num, args.feature_size)
+
+    def local_classify_subnet(self, feature, label):
+        args = self.args
+
+        weight = layers.create_parameter(
+            dtype=feature.dtype,
+            shape=[args.class_num, args.feature_size],
+            default_initializer=NumpyArrayInitializer(self.param_value),
+            is_bias=False)
+
+        # normalize feature
+        feature_l2 = layers.sqrt(
+            layers.reduce_sum(
+                layers.square(feature), dim=1))
+        norm_feature = layers.elementwise_div(feature, feature_l2, axis=0)
+
+        # normalize weight
+        weight_l2 = layers.sqrt(layers.reduce_sum(layers.square(weight), dim=1))
+        norm_weight = layers.elementwise_div(weight, weight_l2, axis=0)
+        norm_weight = layers.transpose(norm_weight, perm=[1, 0])
+
+        cos = layers.mul(norm_feature, norm_weight)
+
+        theta = layers.acos(cos)
+        margin_cos = layers.cos(theta + args.arcface_margin)
+
+        one_hot = layers.one_hot(label, depth=args.class_num)
+
+        diff = (margin_cos - cos) * one_hot
+        target_cos = cos + diff
+        logit = layers.scale(target_cos, scale=args.arcface_scale)
+
+        loss = layers.softmax_with_cross_entropy(logit, label)
+        cost = layers.mean(loss)
+
+        return cost
+
+    def parall_classify_subnet(self, feature, label):
+        args = self.args
+        shard_dim = (args.class_num + args.nranks - 1) // args.nranks
+        shard_start = shard_dim * args.rank
+        rank_param_value = self.param_value[shard_start:(shard_start + shard_dim
+                                                         ), :]
+        cost = layers.collective._distributed_arcface_classify(
+            x=feature,
+            label=label,
+            class_num=args.class_num,
+            nranks=args.nranks,
+            rank_id=args.rank,
+            margin=args.arcface_margin,
+            logit_scale=args.arcface_scale,
+            param_attr=NumpyArrayInitializer(rank_param_value))
+        return cost
+
+
+if __name__ == "__main__":
+    runtime_main(DistArcfaceClassificationRunner)
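As a sanity check on the formulation, the following is a minimal NumPy sketch of the margin logic that local_classify_subnet expresses with fluid layers: L2-normalize the feature rows and the weight rows, take the cosine against every class, add the additive angular margin only on the target class, then rescale. The function name, shapes, and default values are illustrative and not part of the patch.

    import numpy as np

    def arcface_logits(feature, weight, label, margin=0.5, scale=64.0):
        # feature: [batch, feat_dim], weight: [class_num, feat_dim], label: [batch] int
        feature = feature / np.linalg.norm(feature, axis=1, keepdims=True)
        weight = weight / np.linalg.norm(weight, axis=1, keepdims=True)
        cos = feature.dot(weight.T)                 # cosine similarity to every class
        theta = np.arccos(np.clip(cos, -1.0, 1.0))  # corresponding angles
        one_hot = np.eye(weight.shape[0])[label]    # 1 only at the target class
        # push the target-class angle away by `margin`, keep other classes unchanged
        target_cos = np.where(one_hot > 0, np.cos(theta + margin), cos)
        return scale * target_cos                   # feed to softmax_with_cross_entropy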
diff --git a/python/paddle/fluid/tests/unittests/test_dist_arcface_classification.py b/python/paddle/fluid/tests/unittests/test_dist_arcface_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dbfbf7306cf52291abd075c2e31dc2ffa3b3729
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_dist_arcface_classification.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle.fluid as fluid
+from test_dist_classification_base import TestDistClassificationBase
+
+
+class TestDistArcfaceClassification(TestDistClassificationBase):
+    def test_training(self):
+        if fluid.core.is_compiled_with_cuda():
+            self.compare_parall_to_local(
+                'dist_arcface_classification.py', delta=1e-5)
+
+
+class TestDistArcfaceClassificationParam(TestDistClassificationBase):
+    def append_common_cmd(self):
+        return '--arcface_margin 0.5 --arcface_scale 64'
+
+    def test_training(self):
+        if fluid.core.is_compiled_with_cuda():
+            self.compare_parall_to_local(
+                "dist_arcface_classification.py", delta=1e-5)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_dist_classification_base.py b/python/paddle/fluid/tests/unittests/test_dist_classification_base.py
index b203088ab515f77503781ec2c2fad6ec11eb0e7a..4cef1c73ca32d8f5c084dd3e9a1410170ee1fa00
--- a/python/paddle/fluid/tests/unittests/test_dist_classification_base.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_classification_base.py
@@ -36,23 +36,23 @@ DEFAULT_LR = 0.001
 RUN_STEPS = 5
 
 
-def stdprint(value):
+def print2pipe(value):
     if six.PY2:
         print(pickle.dumps(value))
     else:
         sys.stdout.buffer.write(pickle.dumps(value))
 
 
-def log(ref, message, print2pipe=False):
+def elog(ref, message, to_pipe=False):
     localtime = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
     log_str = '[%s] [%s] %s' % (localtime, type(ref).__name__, message)
-    if print2pipe:
+    if to_pipe:
         if six.PY2:
             sys.stderr.write(pickle.dumps(log_str))
         else:
             sys.stderr.buffer.write(pickle.dumps(log_str))
     else:
-        sys.stderr.write(log_str + "\n")
+        print(log_str, file=sys.stderr)
 
 
 class DistClassificationRunner(object):
@@ -64,8 +64,8 @@ class DistClassificationRunner(object):
         args.device_id = int(os.getenv('FLAGS_selected_gpus', '0'))
         self.args = args
 
-    def log(self, message, print2pipe=False):
-        log(self, message, print2pipe)
+    def elog(self, message, to_pipe=False):
+        elog(self, message, to_pipe)
 
     def local_classify_subnet(self, feature, label):
         raise NotImplementedError(
@@ -85,11 +85,11 @@
             name='feature', shape=[args.feature_size], dtype='float32')
         label = fluid.layers.data(name='label', shape=[1], dtype='int64')
         if args.nranks <= 1:
-            log(self, 'build local network')
+            elog(self, 'build local network')
            loss = self.local_classify_subnet(feature, label)
            optimizer.minimize(loss)
         else:
-            log(self, 'build parallel network')
+            elog(self, 'build parallel network')
            loss = self.parall_classify_subnet(feature, label)
            # TODO why need batch size?
            optimizer_wrapper = DistributedClassificationOptimizer(
@@ -120,8 +120,6 @@
             if i // args.batch_size == args.rank:
                 rank_batch.append(sample)
 
-        log(self, rank_batch)
-
         return rank_batch
 
     def transpile(self, main_prog, start_prog):
@@ -142,22 +140,22 @@
         place = fluid.CUDAPlace(self.args.device_id)
         exe = fluid.Executor(place)
         exe.run(start_prog)
-        log(self, 'finish running startup program.')
+        elog(self, 'finish running startup program.')
 
         feeder = fluid.DataFeeder(feed_vars, place)
 
-        log(self, 'start to train')
+        elog(self, 'start to train')
 
         out_losses = []
         for i in range(RUN_STEPS):
             losses = exe.run(main_prog,
                              fetch_list=[loss],
                              feed=feeder.feed(self.gen_rank_batch()))
             out_losses.append(losses[0][0])
-            log(self, "step %d loss: %f" % (i, losses[0][0]))
+            elog(self, "step %d loss: %f" % (i, losses[0][0]))
 
-        log(self, 'finish training')
+        elog(self, 'finish training')
 
-        stdprint(out_losses)
+        print2pipe(out_losses)
 
     @classmethod
     def add_arguments(cls, parser):
@@ -184,14 +182,10 @@ from contextlib import closing
 
 
 class TestDistClassificationBase(unittest.TestCase):
-    # override configurations in setUp
-    def setup_config(self):
-        raise NotImplementedError('tests should have setup_config implemented')
-
     def setUp(self):
         self.nranks = 2
         self.batch_size = DEFAULT_BATCH_SIZE
-        self.setup_config()
+        self.update_config()
         self.global_batch_size = self.batch_size * self.nranks
 
         self.endpoints = [
@@ -203,35 +197,48 @@
         with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
             s.bind(('', 0))
-            log(self, 'socket port: %s' % s.getsockname()[1])
+            elog(self, 'socket port: %s' % s.getsockname()[1])
             port = s.getsockname()[1]
             return port
 
+    # override configurations in setUp
+    def update_config(self):
+        pass
+
+    def append_common_cmd(self):
+        return ''
+
+    def append_local_cmd(self):
+        return ''
+
+    def append_parall_cmd(self):
+        return ''
+
     def run_local(self, train_script, user_env):
         env = {}
         cmd = '%s -u %s --batch_size %d' % (sys.executable, train_script,
                                             self.global_batch_size)
+        if self.append_common_cmd():
+            cmd += ' ' + self.append_common_cmd().strip()
+        if self.append_local_cmd():
+            cmd += ' ' + self.append_local_cmd().strip()
+
         if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
             env['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
             cmd += ' -m coverage run --branch -p'
         env.update(user_env)
 
-        log(self, 'local_cmd: %s' % cmd)
-        log(self, 'local_env: %s' % env)
+        elog(self, 'local_cmd: %s' % cmd)
+        elog(self, 'local_env: %s' % env)
 
         ferr = open('/tmp/local.log', 'w')
         proc = subprocess.Popen(
-            cmd.split(' '),
-            stdout=subprocess.PIPE,
-            #stderr=subprocess.PIPE,
-            stderr=ferr,
-            env=env)
+            cmd.split(' '), stdout=subprocess.PIPE, stderr=ferr, env=env)
 
         out, err = proc.communicate()
         ferr.close()
 
-        log(self, 'local_stdout: %s' % pickle.loads(out))
-        #log(self, 'local_stderr: %s' % pickle.loads(err))
+        elog(self, 'local_stdout: %s' % pickle.loads(out))
 
         return pickle.loads(out)
 
@@ -250,6 +257,10 @@
     def run_parall(self, train_script, user_env):
         cmd = '%s -u %s --batch_size %d' % (sys.executable, train_script,
                                             self.batch_size)
+        if self.append_common_cmd():
+            cmd += ' ' + self.append_common_cmd().strip()
+        if self.append_parall_cmd():
+            cmd += ' ' + self.append_parall_cmd().strip()
 
         if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
             cmd += ' -m coverage run --branch -p'
@@ -258,8 +269,8 @@
         for rank in range(self.nranks):
             env = self.get_parall_env(rank)
             env.update(user_env)
-            log(self, '[r%d] parall_cmd: %s' % (rank, cmd))
-            log(self, '[r%d] parall_env: %s' % (rank, env))
+            elog(self, '[r%d] parall_cmd: %s' % (rank, cmd))
+            elog(self, '[r%d] parall_env: %s' % (rank, env))
 
             ferr = open('/tmp/parall_tr%d.log' % rank, 'w')
             proc = subprocess.Popen(
@@ -276,7 +287,6 @@
             ferrs[rank].close()
 
             outs.append(out)
-            #log(self, '[r%d] parall_stderr: %s' % (rank, pickle.loads(err)))
 
         return [pickle.loads(outs[i]) for i in range(self.nranks)]
 
@@ -296,10 +306,10 @@
         local_losses = self.run_local(train_script, required_envs)
         parall_losses = self.run_parall(train_script, required_envs)
 
+        elog(self, '======= local_loss : parall_loss =======')
         for i in range(RUN_STEPS):
             local_loss = local_losses[i]
             parall_loss = sum(
                 [parall_losses[j][i] for j in range(self.nranks)]) / self.nranks
-            log(self, '======= local_loss : parall_loss =======')
-            log(self, '======= %s : %s =======' % (local_loss, parall_loss))
+            elog(self, '======= %s : %s =======' % (local_loss, parall_loss))
             self.assertAlmostEqual(local_loss, parall_loss, delta=delta)
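The test_dist_classification_base.py changes above replace the mandatory setup_config override with optional hooks (update_config, append_common_cmd, append_local_cmd, append_parall_cmd). A hypothetical subclass sketch showing how a new test could use these hooks; the class name, flag, and runner script below are made up for illustration only:

    class TestMyDistClassification(TestDistClassificationBase):
        def update_config(self):
            # called from setUp() before global_batch_size and endpoints are derived
            self.nranks = 4

        def append_common_cmd(self):
            # extra flags appended to both the local and the parallel runner command
            return '--my_margin 0.3'

        def test_training(self):
            if fluid.core.is_compiled_with_cuda():
                self.compare_parall_to_local('my_runner_script.py', delta=1e-5)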
diff --git a/python/paddle/fluid/tests/unittests/test_dist_softmax_classification.py b/python/paddle/fluid/tests/unittests/test_dist_softmax_classification.py
index bc371a13cd524897b318fed762b81b5ac9c8f4d0..c872412ed03d92fa1da5b6ed3203ed1b1b364fc2
--- a/python/paddle/fluid/tests/unittests/test_dist_softmax_classification.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_softmax_classification.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import unittest
+import paddle.fluid as fluid
 from test_dist_classification_base import TestDistClassificationBase
 
 
@@ -21,10 +22,9 @@ class TestDistSoftmaxClassification(TestDistClassificationBase):
         pass
 
     def test_dist_train(self):
-        import paddle.fluid as fluid
         if fluid.core.is_compiled_with_cuda():
             self.compare_parall_to_local(
-                "dist_softmax_classification.py", delta=1e-4)
+                "dist_softmax_classification.py", delta=1e-5)
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/fluid/tests/unittests/test_shard_index_op.py b/python/paddle/fluid/tests/unittests/test_shard_index_op.py
index fd3c0a5458ab8cc675b4de43516164b6386a4882..9ccf1f254a5566bfebce1d18873b76f5961ff65b 100644
--- a/python/paddle/fluid/tests/unittests/test_shard_index_op.py
+++ b/python/paddle/fluid/tests/unittests/test_shard_index_op.py
@@ -31,7 +31,7 @@ def common_setup(self, index_num, nshards, shard_id, ignore_value):
     x = [np.random.randint(0, index_num - 1) for i in range(N)]
     x = np.array(x).astype('int32').reshape([N, 1])
 
-    shard_size = index_num // nshards
+    shard_size = (index_num + nshards - 1) // nshards
     out = np.zeros(shape=x.shape).astype('int32')
     for i in range(N):
         if x[i] // shard_size == shard_id:
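The test_shard_index_op.py change above switches shard_size from floor to ceiling division, the same formula used for shard_dim in parall_classify_subnet, so that every index still maps to an existing shard when index_num is not a multiple of nshards. A quick worked example with illustrative numbers:

    index_num, nshards = 10, 4
    floor_size = index_num // nshards                   # 2
    ceil_size = (index_num + nshards - 1) // nshards    # 3
    # with floor division, index 9 would land in shard 9 // 2 == 4, which does not exist
    assert 9 // floor_size == 4
    # with ceiling division, shards cover [0,3), [3,6), [6,9), [9,10); index 9 maps to shard 3
    assert 9 // ceil_size == 3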