# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import socket
import logging

import numpy as np
import paddle.fluid as fluid
from .pruner import Pruner
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops

from ..common import ControllerServer
from ..common import ControllerClient

# Public API of this module.
__all__ = ["AutoPruner"]

# Module-level logger shared by AutoPruner.
_logger = get_logger(__name__, level=logging.INFO)


class AutoPruner(object):
    """
    Search a group of ratios used to prune program.

    Args:
        program(Program): The program to be pruned.
        scope(Scope): The scope to be pruned.
        place(fluid.Place): The device place of parameters.
        params(list<str>): The names of parameters to be pruned.
        init_ratios(list<float>|float): Init ratios used to prune parameters in `params`.
            List means ratios used for pruning each parameter in `params`.
            The length of `init_ratios` should be equal to length of `params` when `init_ratios` is a list.
            If it is a scalar, all the parameters in `params` will be pruned by uniform ratio.
            None means get a group of init ratios by `pruned_flops` or `pruned_latency`. Default: None.
        pruned_flops(float): The percent of FLOPS to be pruned. Default: None.
        pruned_latency(float): The percent of latency to be pruned. Default: None.
        server_addr(tuple): A tuple of server ip and server port for controller server.
        init_temperature(float): The init temperature used in simulated annealing search strategy.
        reduce_rate(float): The decay rate used in simulated annealing search strategy.
        max_try_times(int): The max number of trying to generate legal tokens.
        max_client_num(int): The max number of connections of controller server.
        search_steps(int): The steps of searching.
        max_ratios(float|list<float>): Max ratios used to prune parameters in `params`.
            List means max ratios for each parameter in `params`.
            The length of `max_ratios` should be equal to length of params when `max_ratios` is a list.
            If it is a scalar, it will used for all the parameters in `params`. Default: [0.9].
        min_ratios(float|list<float>): Min ratios used to prune parameters in `params`.
            List means min ratios for each parameter in `params`.
            The length of `min_ratios` should be equal to length of params when `min_ratios` is a list.
            If it is a scalar, it will used for all the parameters in `params`. Default: [0].
        key(str): Identity used in communication between controller server and clients.
        is_server(bool): Whether current host is controller server. Default: True.
    """

    def __init__(self,
                 program,
                 scope,
                 place,
                 params=None,
                 init_ratios=None,
                 pruned_flops=0.5,
                 pruned_latency=None,
                 server_addr=("", 0),
                 init_temperature=100,
                 reduce_rate=0.85,
                 max_try_times=300,
                 max_client_num=10,
                 search_steps=300,
                 max_ratios=None,
                 min_ratios=None,
                 key="auto_pruner",
                 is_server=True):
        # None-sentinel instead of mutable default arguments; the effective
        # defaults ([], [0.9], [0]) are unchanged for callers.
        params = [] if params is None else params
        max_ratios = [0.9] if max_ratios is None else max_ratios
        min_ratios = [0] if min_ratios is None else min_ratios

        self._program = program
        self._scope = scope
        self._place = place
        self._params = params
        self._init_ratios = init_ratios
        self._pruned_flops = pruned_flops
        self._pruned_latency = pruned_latency
        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._max_try_times = max_try_times
        self._is_server = is_server

        # (min_tokens, max_tokens) search space used by the controller.
        self._range_table = self._get_range_table(min_ratios, max_ratios)

        self._pruner = Pruner()
        if self._pruned_flops:
            self._base_flops = flops(program)
            self._max_flops = self._base_flops * (1 - self._pruned_flops)
            _logger.info(
                "AutoPruner - base flops: {}; pruned_flops: {}; max_flops: {}".
                format(self._base_flops, self._pruned_flops, self._max_flops))
        if self._pruned_latency:
            # NOTE(review): `latency` is not defined or imported anywhere in
            # this file; the latency-based constraint appears unimplemented
            # and this line will raise NameError if reached — confirm upstream.
            self._base_latency = latency(program)

        if self._init_ratios is None:
            # Fixed: original code passed `self` explicitly and referenced the
            # undefined bare name `_program`, raising NameError at runtime.
            self._init_ratios = self._get_init_ratios(
                self._program, self._params, self._pruned_flops,
                self._pruned_latency)
        init_tokens = self._ratios2tokens(self._init_ratios)
        _logger.info("range table: {}".format(self._range_table))

        controller = SAController(
            self._range_table,
            self._reduce_rate,
            self._init_temperature,
            self._max_try_times,
            init_tokens,
            constrain_func=self._constrain_func)

        server_ip, server_port = server_addr
        if server_ip is None or server_ip == "":
            server_ip = self._get_host_ip()

        self._controller_server = ControllerServer(
            controller=controller,
            address=(server_ip, server_port),
            max_client_num=max_client_num,
            search_steps=search_steps,
            key=key)

        # Only the designated server host actually starts the socket server;
        # every host (including the server) talks to it through a client.
        if self._is_server:
            self._controller_server.start()

        self._controller_client = ControllerClient(
            self._controller_server.ip(),
            self._controller_server.port(),
            key=key)

        self._iter = 0
        # Backup of original parameter tensors so `reward` can restore them
        # after evaluating one pruned candidate.
        self._param_backup = {}

    def _get_host_ip(self):
        # Resolve this host's IP so the controller server binds a reachable
        # address when none was given in `server_addr`.
        return socket.gethostbyname(socket.gethostname())

    def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
        # TODO: derive initial ratios from the FLOPS/latency budget;
        # currently unimplemented (returns None).
        pass

    def _get_range_table(self, min_ratios, max_ratios):
        """Build the (min_tokens, max_tokens) range table from scalar or
        per-parameter min/max ratios."""
        assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
        assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
        # Broadcast scalar ratios to one entry per pruned parameter.
        min_ratios = min_ratios if isinstance(
            min_ratios, list) else [min_ratios] * len(self._params)
        max_ratios = max_ratios if isinstance(
            max_ratios, list) else [max_ratios] * len(self._params)
        min_tokens = self._ratios2tokens(min_ratios)
        max_tokens = self._ratios2tokens(max_ratios)
        return (min_tokens, max_tokens)

    def _constrain_func(self, tokens):
        """Return True if the candidate `tokens` satisfy the FLOPS budget.

        Pruning is graph-only here, so no parameter tensors are modified.
        """
        ratios = self._tokens2ratios(tokens)
        pruned_program, _, _ = self._pruner.prune(
            self._program,
            self._scope,
            self._params,
            ratios,
            place=self._place,
            only_graph=True)
        current_flops = flops(pruned_program)
        result = current_flops < self._max_flops
        if not result:
            _logger.info("Failed try ratios: {}; flops: {}; max_flops: {}".
                         format(ratios, current_flops, self._max_flops))
        else:
            _logger.info("Success try ratios: {}; flops: {}; max_flops: {}".
                         format(ratios, current_flops, self._max_flops))
        return result

    def prune(self, program, eval_program=None):
        """
        Prune program with latest tokens generated by controller.

        Args:
            program(fluid.Program): The program to be pruned.
            eval_program(fluid.Program): The evaluation program to be pruned
                graph-only with the same ratios. Default: None.

        Returns:
            tuple: (pruned program, pruned evaluation program or None).
        """
        self._current_ratios = self._next_ratios()
        pruned_program, _, _ = self._pruner.prune(
            program,
            self._scope,
            self._params,
            self._current_ratios,
            place=self._place,
            only_graph=False,
            param_backup=self._param_backup)
        pruned_val_program = None
        if eval_program is not None:
            # Fixed: original pruned `program` again here, leaving
            # `eval_program` entirely unused.
            pruned_val_program, _, _ = self._pruner.prune(
                eval_program,
                self._scope,
                self._params,
                self._current_ratios,
                place=self._place,
                only_graph=True)
        _logger.info("AutoPruner - pruned ratios: {}".format(
            self._current_ratios))
        return pruned_program, pruned_val_program

    def reward(self, score):
        """
        Report the reward of the current pruned program to the controller
        and restore the original parameters for the next trial.

        Args:
            score(float): The score of pruned program.
        """
        self._restore(self._scope)
        self._param_backup = {}
        tokens = self._ratios2tokens(self._current_ratios)
        self._controller_client.update(tokens, score, self._iter)
        self._iter += 1

    def _restore(self, scope):
        # Write the backed-up tensors of the last `prune` call back into
        # `scope`, undoing the in-place pruning.
        for param_name in self._param_backup.keys():
            param_t = scope.find_var(param_name).get_tensor()
            param_t.set(self._param_backup[param_name], self._place)

    def _next_ratios(self):
        # Ask the controller server for the next candidate tokens.
        tokens = self._controller_client.next_tokens()
        return self._tokens2ratios(tokens)

    def _ratios2tokens(self, ratios):
        """Convert pruned ratios to tokens (one token per 0.01 of ratio).
        """
        return [int(ratio / 0.01) for ratio in ratios]

    def _tokens2ratios(self, tokens):
        """Convert tokens to pruned ratios.
        """
        return [token * 0.01 for token in tokens]