# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import socket
import logging
import numpy as np
import paddle.fluid as fluid
from .pruner import Pruner
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops

from ..common import ControllerServer
from ..common import ControllerClient

__all__ = ["AutoPruner"]

_logger = get_logger(__name__, level=logging.INFO)


class AutoPruner(object):
    """
    Search a group of ratios used to prune program.

    Args:
        program(Program): The program to be pruned.
        scope(Scope): The scope to be pruned.
        place(fluid.Place): The device place of parameters.
        params(list<str>): The names of parameters to be pruned. Default: [].
        init_ratios(list<float>|float): Init ratios used to pruned parameters in `params`.
            List means ratios used for pruning each parameter in `params`.
            The length of `init_ratios` should be equal to length of params when `init_ratios` is a list.
            If it is a scalar, all the parameters in `params` will be pruned by uniform ratio.
            None means get a group of init ratios by `pruned_flops` or `pruned_latency`. Default: None.
        pruned_flops(float): The percent of FLOPS to be pruned. Default: 0.5.
        pruned_latency(float): The percent of latency to be pruned. Default: None.
        server_addr(tuple): A tuple of server ip and server port for controller server.
            An empty ip means the ip of current host is used.
        init_temperature(float): The init temperature used in simulated annealing search strategy.
        reduce_rate(float): The decay rate used in simulated annealing search strategy.
        max_try_times(int): The max number of trying to generate legal tokens.
        max_client_num(int): The max number of connections of controller server.
        search_steps(int): The steps of searching.
        max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`.
            List means max ratios for each parameter in `params`.
            The length of `max_ratios` should be equal to length of params when `max_ratios` is a list.
            If it is a scalar, it will used for all the parameters in `params`. Default: [0.9].
        min_ratios(float|list<float>): Min ratios used to pruned parameters in `params`.
            List means min ratios for each parameter in `params`.
            The length of `min_ratios` should be equal to length of params when `min_ratios` is a list.
            If it is a scalar, it will used for all the parameters in `params`. Default: [0].
        key(str): Identity used in communication between controller server and clients.
        is_server(bool): Whether current host is controller server. Default: True.
    """

    def __init__(self,
                 program,
                 scope,
                 place,
                 params=None,
                 init_ratios=None,
                 pruned_flops=0.5,
                 pruned_latency=None,
                 server_addr=("", 0),
                 init_temperature=100,
                 reduce_rate=0.85,
                 max_try_times=300,
                 max_client_num=10,
                 search_steps=300,
                 max_ratios=None,
                 min_ratios=None,
                 key="auto_pruner",
                 is_server=True):
        # Replace mutable default arguments ([], [0.9], [0]) with None
        # sentinels; the effective defaults are unchanged.
        params = [] if params is None else params
        max_ratios = [0.9] if max_ratios is None else max_ratios
        min_ratios = [0] if min_ratios is None else min_ratios

        self._program = program
        self._scope = scope
        self._place = place
        self._params = params
        self._init_ratios = init_ratios
        self._pruned_flops = pruned_flops
        self._pruned_latency = pruned_latency
        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._max_try_times = max_try_times
        self._is_server = is_server

        # Token bounds (one token per parameter) for the SA controller.
        self._range_table = self._get_range_table(min_ratios, max_ratios)

        self._pruner = Pruner()
        if self._pruned_flops:
            self._base_flops = flops(program)
            self._max_flops = self._base_flops * (1 - self._pruned_flops)
            _logger.info(
                "AutoPruner - base flops: {}; pruned_flops: {}; max_flops: {}".
                format(self._base_flops, self._pruned_flops, self._max_flops))
        if self._pruned_latency:
            # NOTE(review): `latency` is not imported in this module, so this
            # branch raises NameError when `pruned_latency` is set — confirm
            # the intended latency helper before enabling latency-based search.
            self._base_latency = latency(program)

        if self._init_ratios is None:
            # BUGFIX: was `self._get_init_ratios(self, _program, ...)` —
            # `_program` was an undefined name (NameError) and `self` was
            # passed twice.
            self._init_ratios = self._get_init_ratios(
                self._program, self._params, self._pruned_flops,
                self._pruned_latency)
        init_tokens = self._ratios2tokens(self._init_ratios)
        _logger.info("range table: {}".format(self._range_table))
        controller = SAController(self._range_table, self._reduce_rate,
                                  self._init_temperature, self._max_try_times,
                                  init_tokens, self._constrain_func)

        server_ip, server_port = server_addr
        if server_ip is None or server_ip == "":
            server_ip = self._get_host_ip()

        self._controller_server = ControllerServer(
            controller=controller,
            address=(server_ip, server_port),
            max_client_num=max_client_num,
            search_steps=search_steps,
            key=key)

        # Only the designated server host runs the controller server.
        if self._is_server:
            self._controller_server.start()

        self._controller_client = ControllerClient(
            self._controller_server.ip(),
            self._controller_server.port(),
            key=key)

        self._iter = 0
        # Backup of pruned parameter tensors, filled by `prune` and used by
        # `reward` to restore the original weights before the next trial.
        self._param_backup = {}

    def _get_host_ip(self):
        """Return the IPv4 address of the current host."""
        return socket.gethostbyname(socket.gethostname())

    def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
        """Derive init ratios from a FLOPS/latency budget. Not implemented yet."""
        pass

    def _get_range_table(self, min_ratios, max_ratios):
        """Convert per-parameter ratio bounds into token bounds.

        Args:
            min_ratios(float|list<float>): Min pruning ratio(s); a scalar is
                broadcast to every parameter in `self._params`.
            max_ratios(float|list<float>): Max pruning ratio(s); same rule.

        Returns:
            tuple(list<int>, list<int>): (min_tokens, max_tokens).
        """
        assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
        assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
        min_ratios = min_ratios if isinstance(
            min_ratios, list) else [min_ratios] * len(self._params)
        max_ratios = max_ratios if isinstance(
            max_ratios, list) else [max_ratios] * len(self._params)
        min_tokens = self._ratios2tokens(min_ratios)
        max_tokens = self._ratios2tokens(max_ratios)
        return (min_tokens, max_tokens)

    def _constrain_func(self, tokens):
        """Check whether pruning with `tokens` satisfies the FLOPS budget.

        Prunes the graph only (no parameter mutation) and compares the
        resulting FLOPS against `self._max_flops`.
        """
        ratios = self._tokens2ratios(tokens)
        pruned_program, _, _ = self._pruner.prune(
            self._program,
            self._scope,
            self._params,
            ratios,
            place=self._place,
            only_graph=True)
        current_flops = flops(pruned_program)
        result = current_flops < self._max_flops
        if not result:
            _logger.info("Failed try ratios: {}; flops: {}; max_flops: {}".
                         format(ratios, current_flops, self._max_flops))
        else:
            _logger.info("Success try ratios: {}; flops: {}; max_flops: {}".
                         format(ratios, current_flops, self._max_flops))
        return result

    def prune(self, program, eval_program=None):
        """
        Prune program with latest tokens generated by controller.

        Args:
            program(fluid.Program): The program to be pruned.
            eval_program(fluid.Program): Optional program to be pruned
                graph-only for evaluation. Default: None.

        Returns:
            tuple: (pruned_program, pruned_val_program); the second element is
            None when `eval_program` is not given.
        """
        self._current_ratios = self._next_ratios()
        pruned_program, _, _ = self._pruner.prune(
            program,
            self._scope,
            self._params,
            self._current_ratios,
            place=self._place,
            only_graph=False,
            param_backup=self._param_backup)
        pruned_val_program = None
        if eval_program is not None:
            pruned_val_program, _, _ = self._pruner.prune(
                program,
                self._scope,
                self._params,
                self._current_ratios,
                place=self._place,
                only_graph=True)

        _logger.info("AutoPruner - pruned ratios: {}".format(
            self._current_ratios))
        return pruned_program, pruned_val_program

    def reward(self, score):
        """
        Report the reward of the current pruned program to the controller.

        Restores the backed-up parameters first so the next trial starts from
        the original weights.

        Args:
            score(float): The score of pruned program.
        """
        self._restore(self._scope)
        self._param_backup = {}
        tokens = self._ratios2tokens(self._current_ratios)
        self._controller_client.update(tokens, score, self._iter)
        self._iter += 1

    def _restore(self, scope):
        """Restore backed-up parameter tensors into `scope`."""
        for param_name, backup in self._param_backup.items():
            param_t = scope.find_var(param_name).get_tensor()
            param_t.set(backup, self._place)

    def _next_ratios(self):
        """Fetch the next tokens from the controller and convert them to ratios."""
        tokens = self._controller_client.next_tokens()
        return self._tokens2ratios(tokens)

    def _ratios2tokens(self, ratios):
        """Convert pruned ratios to tokens (granularity 0.01)."""
        return [int(ratio / 0.01) for ratio in ratios]

    def _tokens2ratios(self, tokens):
        """Convert tokens to pruned ratios."""
        return [token * 0.01 for token in tokens]