From 5c153e0af2cdbd48acad6086c6ef85c23719e5da Mon Sep 17 00:00:00 2001 From: Bai Yifan Date: Tue, 12 May 2020 15:51:21 +0800 Subject: [PATCH] [Cherry-Pick]Add darts unittest (#271) * add darts unittest --- demo/darts/search.py | 1 - paddleslim/nas/darts/train_search.py | 2 - tests/test_darts.py | 97 ++++++++++++++++++++++++++++ tests/test_fsp_loss.py | 4 +- tests/test_l2_loss.py | 4 +- tests/test_loss.py | 4 +- tests/test_soft_label_loss.py | 4 +- 7 files changed, 105 insertions(+), 11 deletions(-) create mode 100644 tests/test_darts.py diff --git a/demo/darts/search.py b/demo/darts/search.py index d8a8c484..0b4dcce2 100644 --- a/demo/darts/search.py +++ b/demo/darts/search.py @@ -36,7 +36,6 @@ add_arg = functools.partial(add_arguments, argparser=parser) # yapf: disable add_arg('log_freq', int, 50, "Log frequency.") add_arg('use_multiprocess', bool, True, "Whether use multiprocess reader.") -add_arg('data', str, 'dataset/cifar10',"The dir of dataset.") add_arg('batch_size', int, 64, "Minibatch size.") add_arg('learning_rate', float, 0.025, "The start learning rate.") add_arg('momentum', float, 0.9, "Momentum.") diff --git a/paddleslim/nas/darts/train_search.py b/paddleslim/nas/darts/train_search.py index c8c70413..0dbf954b 100644 --- a/paddleslim/nas/darts/train_search.py +++ b/paddleslim/nas/darts/train_search.py @@ -206,8 +206,6 @@ class DARTSearch(object): if self.use_data_parallel: self.train_reader = fluid.contrib.reader.distributed_batch_reader( self.train_reader) - self.valid_reader = fluid.contrib.reader.distributed_batch_reader( - self.valid_reader) train_loader = fluid.io.DataLoader.from_generator( capacity=64, diff --git a/tests/test_darts.py b/tests/test_darts.py new file mode 100644 index 00000000..63383ea4 --- /dev/null +++ b/tests/test_darts.py @@ -0,0 +1,97 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle +import unittest +import paddle.fluid as fluid +import numpy as np +from paddleslim.nas.darts import DARTSearch +from layers import conv_bn_layer + + +class TestDARTS(unittest.TestCase): + def test_darts(self): + class SuperNet(fluid.dygraph.Layer): + def __init__(self): + super(SuperNet, self).__init__() + self._method = 'DARTS' + self._steps = 1 + self.stem = fluid.dygraph.nn.Conv2D( + num_channels=1, num_filters=3, filter_size=3, padding=1) + self.classifier = fluid.dygraph.nn.Linear( + input_dim=2352, output_dim=10) + self._multiplier = 4 + self._primitives = [ + 'none', 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect', + 'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', + 'dil_conv_5x5' + ] + self._initialize_alphas() + + def _initialize_alphas(self): + self.alphas_normal = fluid.layers.create_parameter( + shape=[14, 8], dtype="float32") + self.alphas_reduce = fluid.layers.create_parameter( + shape=[14, 8], dtype="float32") + self._arch_parameters = [ + self.alphas_normal, + self.alphas_reduce, + ] + + def arch_parameters(self): + return self._arch_parameters + + def forward(self, input): + out = self.stem(input) * self.alphas_normal[0][ + 0] * self.alphas_reduce[0][0] + out = fluid.layers.reshape(out, [0, -1]) + logits = self.classifier(out) + return logits + + def _loss(self, input, label): + logits = self.forward(input) + return fluid.layers.reduce_mean( + fluid.layers.softmax_with_cross_entropy(logits, label)) + + def batch_generator(reader): + def wrapper(): + batch_data = [] + batch_label = [] + for sample in reader(): + image = np.array(sample[0]).reshape(1, 28, 28) + label = np.array(sample[1]).reshape(1) + batch_data.append(image) + batch_label.append(label) + if len(batch_data) == 128: + batch_data = np.array(batch_data, dtype='float32') + batch_label = np.array(batch_label, dtype='int64') + yield [batch_data, batch_label] + batch_data = [] + batch_label = [] + + return wrapper + + place = fluid.CUDAPlace(0) + with fluid.dygraph.guard(place): + model = SuperNet() + trainset = paddle.dataset.mnist.train() + validset = paddle.dataset.mnist.test() + train_reader = batch_generator(trainset) + valid_reader = batch_generator(validset) + searcher = DARTSearch( + model, train_reader, valid_reader, place, num_epochs=5) + searcher.train() + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_fsp_loss.py b/tests/test_fsp_loss.py index ec9b0364..a71cf143 100644 --- a/tests/test_fsp_loss.py +++ b/tests/test_fsp_loss.py @@ -19,8 +19,8 @@ from paddleslim.dist import merge, fsp_loss from layers import conv_bn_layer -class TestMerge(unittest.TestCase): - def test_merge(self): +class TestFSPLoss(unittest.TestCase): + def test_fsp_loss(self): student_main = fluid.Program() student_startup = fluid.Program() with fluid.program_guard(student_main, student_startup): diff --git a/tests/test_l2_loss.py b/tests/test_l2_loss.py index 49e89f53..b9f50479 100644 --- a/tests/test_l2_loss.py +++ b/tests/test_l2_loss.py @@ -19,8 +19,8 @@ from paddleslim.dist import merge, l2_loss from layers import conv_bn_layer -class TestMerge(unittest.TestCase): - def test_merge(self): +class TestL2Loss(unittest.TestCase): + def test_l2_loss(self): student_main = fluid.Program() student_startup = fluid.Program() with fluid.program_guard(student_main, student_startup): diff --git a/tests/test_loss.py b/tests/test_loss.py index b4cd4329..8afa5018 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -19,8 +19,8 @@ from paddleslim.dist import merge, loss from layers import conv_bn_layer -class TestMerge(unittest.TestCase): - def test_merge(self): +class TestLoss(unittest.TestCase): + def test_loss(self): student_main = fluid.Program() student_startup = fluid.Program() with fluid.program_guard(student_main, student_startup): diff --git a/tests/test_soft_label_loss.py b/tests/test_soft_label_loss.py index 22458200..965edee3 100644 --- a/tests/test_soft_label_loss.py +++ b/tests/test_soft_label_loss.py @@ -19,8 +19,8 @@ from paddleslim.dist import merge, soft_label_loss from layers import conv_bn_layer -class TestMerge(unittest.TestCase): - def test_merge(self): +class TestSoftLabelLoss(unittest.TestCase): + def test_soft_label_loss(self): student_main = fluid.Program() student_startup = fluid.Program() with fluid.program_guard(student_main, student_startup): -- GitLab