# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 NVIDIA Corporation.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle import fluid
from paddle.fluid import core
from paddle.incubate.asp import ASPHelper

# All tests in this file exercise the static-graph (declarative) API, so
# switch out of the default dynamic-graph mode at import time.
paddle.enable_static()


class TestASPStaticOptimize(unittest.TestCase):
    """Static-graph tests for ASP (Automatic SParsity) optimizer integration.

    Covers: identifying non-ASP-relevant variables, supported-layer
    detection (including layer exclusion), optimizer decoration, and
    end-to-end 2:4 sparse training with and without AMP.
    """

    def setUp(self):
        """Build a small conv + fc classifier and an SGD optimizer in
        fresh main/startup programs for every test."""
        self.main_program = fluid.Program()
        self.startup_program = fluid.Program()

        def build_model():
            # Minimal image classifier: conv2d(relu) -> fc(relu) -> fc(softmax).
            img = paddle.static.data(
                name='img', shape=[None, 3, 24, 24], dtype='float32'
            )
            label = paddle.static.data(
                name='label', shape=[None, 1], dtype='int64'
            )
            hidden = paddle.static.nn.conv2d(
                input=img, num_filters=4, filter_size=3, padding=2, act="relu"
            )
            hidden = paddle.static.nn.fc(x=hidden, size=32, activation='relu')
            prediction = paddle.static.nn.fc(
                x=hidden, size=10, activation='softmax'
            )
            return img, label, prediction

        with fluid.program_guard(self.main_program, self.startup_program):
            self.img, self.label, predict = build_model()
            # The last fc layer already applies softmax, so disable the
            # one inside cross_entropy.
            self.loss = paddle.mean(
                paddle.nn.functional.cross_entropy(
                    input=predict,
                    label=self.label,
                    reduction='none',
                    use_softmax=False,
                )
            )
            self.optimizer = fluid.optimizer.SGD(learning_rate=0.01)

    def test_get_not_ASP_relevant_vars(self):
        """_get_not_ASP_relevant_vars must list every parameter, in order,
        both before and after ASP's minimize pass adds mask variables."""

        def check_params(params, params_from_asp):
            # Same length and pairwise-identical names.
            if len(params_from_asp) != len(params):
                return False
            return all(
                p.name == params[i].name
                for i, p in enumerate(params_from_asp)
            )

        params = self.main_program.global_block().all_parameters()
        params_from_asp = ASPHelper._get_not_ASP_relevant_vars(
            self.main_program
        )
        self.assertTrue(check_params(params, params_from_asp))

        with fluid.program_guard(self.main_program, self.startup_program):
            ASPHelper._minimize(
                self.optimizer,
                self.loss,
                self.main_program,
                self.startup_program,
            )
        params_from_asp_after_opt = ASPHelper._get_not_ASP_relevant_vars(
            self.main_program
        )
        self.assertTrue(check_params(params, params_from_asp_after_opt))

    def test_is_supported_layers(self):
        """Only weight (w_0) parameters of conv2d/fc/linear layers are ASP
        supported; set/reset_excluded_layers must update that set."""
        program = paddle.static.default_main_program()

        names = [
            'embedding_0.w_0',
            'fack_layer_0.w_0',
            'conv2d_0.w_0',
            'conv2d_0.b_0',
            'conv2d_1.w_0',
            'conv2d_1.b_0',
            'fc_0.w_0',
            'fc_0.b_0',
            'fc_1.w_0',
            'fc_1.b_0',
            'linear_2.w_0',
            'linear_2.b_0',
        ]
        ref = [
            False,
            False,
            True,
            False,
            True,
            False,
            True,
            False,
            True,
            False,
            True,
            False,
        ]
        for name, expected in zip(names, ref):
            self.assertEqual(
                expected, ASPHelper._is_supported_layer(program, name)
            )

        # Excluding layers removes their weights from the supported set.
        paddle.incubate.asp.set_excluded_layers(['fc_1', 'conv2d_0'], program)
        ref = [
            False,
            False,
            False,
            False,
            True,
            False,
            True,
            False,
            False,
            False,
            True,
            False,
        ]
        for name, expected in zip(names, ref):
            self.assertEqual(
                expected, ASPHelper._is_supported_layer(program, name)
            )

        # Resetting the exclusions restores the original supported set.
        paddle.incubate.asp.reset_excluded_layers(program)
        ref = [
            False,
            False,
            True,
            False,
            True,
            False,
            True,
            False,
            True,
            False,
            True,
            False,
        ]
        for name, expected in zip(names, ref):
            self.assertEqual(
                expected, ASPHelper._is_supported_layer(program, name)
            )

    def test_decorate(self):
        """asp.decorate + minimize must add a mask variable (and masking op)
        for every supported parameter."""
        param_names = self.__get_param_names(
            self.main_program.global_block().all_parameters()
        )
        with fluid.program_guard(self.main_program, self.startup_program):
            self.optimizer = paddle.incubate.asp.decorate(self.optimizer)
            self.optimizer.minimize(self.loss, self.startup_program)
        param_names_after_minimize = self.__get_param_names(
            self.main_program.global_block().all_parameters()
        )

        self.__check_mask_variables_and_ops(
            param_names, param_names_after_minimize
        )

    def test_asp_training(self):
        """After prune_model and one training step, every supported (and
        large-enough) parameter must satisfy 2:4 sparsity."""
        with fluid.program_guard(self.main_program, self.startup_program):
            self.optimizer = paddle.incubate.asp.decorate(self.optimizer)
            self.optimizer.minimize(self.loss, self.startup_program)

        place = paddle.CPUPlace()
        if core.is_compiled_with_cuda():
            place = paddle.CUDAPlace(0)
        exe = fluid.Executor(place)
        feeder = fluid.DataFeeder(feed_list=[self.img, self.label], place=place)

        exe.run(self.startup_program)
        paddle.incubate.asp.prune_model(self.main_program)

        data = (
            np.random.randn(32, 3, 24, 24),
            np.random.randint(10, size=(32, 1)),
        )
        exe.run(self.main_program, feed=feeder.feed([data]))

        for param in self.main_program.global_block().all_parameters():
            if ASPHelper._is_supported_layer(self.main_program, param.name):
                mat = np.array(
                    fluid.global_scope().find_var(param.name).get_tensor()
                )
                # Parameters too small along the pruned dimension cannot
                # hold a valid 2:4 pattern, so pruning skips them.
                if (len(param.shape) == 4 and param.shape[1] < 4) or (
                    len(param.shape) == 2 and param.shape[0] < 4
                ):
                    self.assertFalse(
                        paddle.incubate.asp.check_sparsity(mat.T, n=2, m=4)
                    )
                else:
                    self.assertTrue(
                        paddle.incubate.asp.check_sparsity(mat.T, n=2, m=4)
                    )

    def test_asp_training_with_amp(self):
        """Same as test_asp_training but with the AMP-decorated optimizer;
        only meaningful (and only run) on CUDA builds."""
        if core.is_compiled_with_cuda():
            place = paddle.CUDAPlace(0)
            with fluid.program_guard(self.main_program, self.startup_program):
                # ASP decoration wraps the AMP-decorated optimizer.
                self.optimizer = paddle.static.amp.decorate(self.optimizer)
                self.optimizer = paddle.incubate.asp.decorate(self.optimizer)
                self.optimizer.minimize(self.loss, self.startup_program)

            exe = fluid.Executor(place)
            feeder = fluid.DataFeeder(
                feed_list=[self.img, self.label], place=place
            )

            exe.run(self.startup_program)
            paddle.incubate.asp.prune_model(self.main_program)

            data = (
                np.random.randn(32, 3, 24, 24),
                np.random.randint(10, size=(32, 1)),
            )
            exe.run(self.main_program, feed=feeder.feed([data]))

            for param in self.main_program.global_block().all_parameters():
                if ASPHelper._is_supported_layer(self.main_program, param.name):
                    mat = np.array(
                        fluid.global_scope().find_var(param.name).get_tensor()
                    )
                    # Same small-parameter exemption as test_asp_training.
                    if (len(param.shape) == 4 and param.shape[1] < 4) or (
                        len(param.shape) == 2 and param.shape[0] < 4
                    ):
                        self.assertFalse(
                            paddle.incubate.asp.check_sparsity(mat.T, n=2, m=4)
                        )
                    else:
                        self.assertTrue(
                            paddle.incubate.asp.check_sparsity(mat.T, n=2, m=4)
                        )

    def __get_param_names(self, params):
        """Return the names of *params*, preserving order."""
        return [p.name for p in params]

    def __check_mask_variables_and_ops(
        self, param_names, param_names_after_minimize
    ):
        """Assert every supported parameter gained a mask variable and a
        corresponding elementwise_mul masking op after minimize."""
        # Every supported layer's mask name must appear among the
        # post-minimize parameters.
        for n in param_names:
            self.assertFalse(
                ASPHelper._is_supported_layer(self.main_program, n)
                and ASPHelper._get_mask_name(n)
                not in param_names_after_minimize
            )

        mask_names = [
            ASPHelper._get_mask_name(n)
            for n in param_names
            if ASPHelper._is_supported_layer(self.main_program, n)
        ]

        # Masking is applied as `param * mask` via elementwise_mul ops
        # whose Y input is the mask variable.
        masking_ops = [
            op.input('Y')[0]
            for op in self.main_program.global_block().ops
            if op.type == 'elementwise_mul' and op.input('Y')[0] in mask_names
        ]

        # Exact one-to-one correspondence between masks and masking ops.
        self.assertEqual(len(masking_ops), len(mask_names))
        for n in masking_ops:
            self.assertIn(n, mask_names)

        for n in mask_names:
            self.assertIn(n, masking_ops)


# Allow running this test file directly (e.g. `python <this file>`).
if __name__ == '__main__':
    unittest.main()