# auto_pruner.py
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import socket
import logging
import numpy as np
import paddle.fluid as fluid
from .pruner import Pruner
from ..core import GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops
from ..common import ControllerServer
from ..common import ControllerClient
__all__ = ["AutoPruner"]

_logger = get_logger(__name__, level=logging.INFO)


class AutoPruner(object):
    """
    Search a group of ratios used to prune program.

    One host runs a `ControllerServer` (simulated annealing controller);
    every worker, including the server host, talks to it through a
    `ControllerClient`. Each search step is: `prune()` a candidate,
    evaluate it, then call `reward()` to report the score and restore
    the pruned parameters.

    Args:
        program(Program): The program to be pruned.
        scope(Scope): The scope to be pruned.
        place(fluid.Place): The device place of parameters.
        params(list<str>): The names of parameters to be pruned.
        init_ratios(list<float>|float): Init ratios used to pruned parameters in `params`.
            List means ratios used for pruning each parameter in `params`.
            The length of `init_ratios` should be equal to length of params when `init_ratios` is a list.
            If it is a scalar, all the parameters in `params` will be pruned by uniform ratio.
            None means get a group of init ratios by `pruned_flops` or `pruned_latency`. Default: None.
        pruned_flops(float): The percent of FLOPS to be pruned. Default: None.
        pruned_latency(float): The percent of latency to be pruned. Default: None.
        server_addr(tuple): A tuple of server ip and server port for controller server.
        init_temperature(float): The init temperature used in simulated annealing search strategy.
        reduce_rate(float): The decay rate used in simulated annealing search strategy.
        max_try_times(int): The max number of trying to generate legal tokens.
        max_client_num(int): The max number of connections of controller server.
        search_steps(int): The steps of searching.
        max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`.
            List means max ratios for each parameter in `params`.
            The length of `max_ratios` should be equal to length of params when `max_ratios` is a list.
            If it is a scalar, it will used for all the parameters in `params`.
        min_ratios(float|list<float>): Min ratios used to pruned parameters in `params`.
            List means min ratios for each parameter in `params`.
            The length of `min_ratios` should be equal to length of params when `min_ratios` is a list.
            If it is a scalar, it will used for all the parameters in `params`.
        key(str): Identity used in communication between controller server and clients.
        is_server(bool): Whether current host is controller server. Default: True.
    """

    def __init__(self,
                 program,
                 scope,
                 place,
                 params=[],
                 init_ratios=None,
                 pruned_flops=0.5,
                 pruned_latency=None,
                 server_addr=("", 0),
                 init_temperature=100,
                 reduce_rate=0.85,
                 max_try_times=300,
                 max_client_num=10,
                 search_steps=300,
                 max_ratios=[0.9],
                 min_ratios=[0],
                 key="auto_pruner",
                 is_server=True):
        self._program = program
        self._scope = scope
        self._place = place
        self._params = params
        self._init_ratios = init_ratios
        self._pruned_flops = pruned_flops
        self._pruned_latency = pruned_latency
        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._max_try_times = max_try_times
        self._is_server = is_server

        # (min_tokens, max_tokens) bounds of the controller search space.
        self._range_table = self._get_range_table(min_ratios, max_ratios)

        self._pruner = Pruner()
        if self._pruned_flops:
            self._base_flops = flops(program)
            # Candidates whose FLOPs exceed this budget are rejected by
            # `_constrain_func`.
            self._max_flops = self._base_flops * (1 - self._pruned_flops)
            _logger.info(
                "AutoPruner - base flops: {}; pruned_flops: {}; max_flops: {}".
                format(self._base_flops, self._pruned_flops, self._max_flops))
        if self._pruned_latency:
            # NOTE(review): `latency` is not defined or imported in this file;
            # setting `pruned_latency` raises NameError here. Latency-based
            # search appears to be unimplemented.
            self._base_latency = latency(program)

        if self._init_ratios is None:
            # Bug fix: was `self._get_init_ratios(self, _program, ...)`,
            # which passed `self` twice and referenced an undefined
            # global `_program`.
            self._init_ratios = self._get_init_ratios(
                self._program, self._params, self._pruned_flops,
                self._pruned_latency)
        init_tokens = self._ratios2tokens(self._init_ratios)
        _logger.info("range table: {}".format(self._range_table))
        controller = SAController(self._range_table, self._reduce_rate,
                                  self._init_temperature, self._max_try_times,
                                  init_tokens, self._constrain_func)

        server_ip, server_port = server_addr
        if server_ip is None or server_ip == "":
            server_ip = self._get_host_ip()

        self._controller_server = ControllerServer(
            controller=controller,
            address=(server_ip, server_port),
            max_client_num=max_client_num,
            search_steps=search_steps,
            key=key)

        # Only the designated host starts the controller server; all hosts
        # (including the server host) communicate with it via a client.
        if self._is_server:
            self._controller_server.start()

        self._controller_client = ControllerClient(
            self._controller_server.ip(),
            self._controller_server.port(),
            key=key)

        self._iter = 0
        # Backup of original parameter tensors, filled by `prune` and used
        # by `reward` to restore the program between trials.
        self._param_backup = {}

    def _get_host_ip(self):
        """Return the IP address of the current host."""
        return socket.gethostbyname(socket.gethostname())

    def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
        """Derive initial ratios from a FLOPs/latency budget.

        NOTE(review): not implemented yet — returns None, so callers must
        currently pass `init_ratios` explicitly.
        """
        pass

    def _get_range_table(self, min_ratios, max_ratios):
        """Convert min/max ratios into token bounds for the controller.

        Scalar ratios (int or float) are broadcast to one entry per
        parameter in `self._params`.

        Returns:
            tuple(list<int>, list<int>): (min_tokens, max_tokens).
        """
        # Accept ints as well as floats (e.g. `min_ratios=0`).
        assert isinstance(min_ratios, (list, float, int))
        assert isinstance(max_ratios, (list, float, int))
        min_ratios = min_ratios if isinstance(
            min_ratios, list) else [min_ratios] * len(self._params)
        max_ratios = max_ratios if isinstance(
            max_ratios, list) else [max_ratios] * len(self._params)
        min_tokens = self._ratios2tokens(min_ratios)
        max_tokens = self._ratios2tokens(max_ratios)
        return (min_tokens, max_tokens)

    def _constrain_func(self, tokens):
        """Check whether candidate `tokens` satisfy the FLOPs budget.

        The program is pruned at graph level only (weights untouched) so
        the candidate can be evaluated cheaply.

        Returns:
            bool: True if FLOPs of the pruned program is under `_max_flops`.
        """
        ratios = self._tokens2ratios(tokens)
        pruned_program, _, _ = self._pruner.prune(
            self._program,
            self._scope,
            self._params,
            ratios,
            place=self._place,
            only_graph=True)
        current_flops = flops(pruned_program)
        result = current_flops < self._max_flops
        if not result:
            _logger.info("Failed try ratios: {}; flops: {}; max_flops: {}".
                         format(ratios, current_flops, self._max_flops))
        else:
            _logger.info("Success try ratios: {}; flops: {}; max_flops: {}".
                         format(ratios, current_flops, self._max_flops))
        return result

    def prune(self, program, eval_program=None):
        """
        Prune program with latest tokens generated by controller.

        Args:
            program(fluid.Program): The program to be pruned.
            eval_program(fluid.Program): The evaluation program to be pruned
                at graph level with the same ratios. Default: None.

        Returns:
            tuple: (pruned_program, pruned_val_program); the latter is None
                when `eval_program` is None.
        """
        self._current_ratios = self._next_ratios()
        pruned_program, _, _ = self._pruner.prune(
            program,
            self._scope,
            self._params,
            self._current_ratios,
            place=self._place,
            only_graph=False,
            param_backup=self._param_backup)
        pruned_val_program = None
        if eval_program is not None:
            # Bug fix: prune `eval_program` here — the original pruned
            # `program` a second time and never used `eval_program`.
            pruned_val_program, _, _ = self._pruner.prune(
                eval_program,
                self._scope,
                self._params,
                self._current_ratios,
                place=self._place,
                only_graph=True)

        _logger.info("AutoPruner - pruned ratios: {}".format(
            self._current_ratios))
        return pruned_program, pruned_val_program

    def reward(self, score):
        """
        Report the reward of the current pruned program to the controller
        and restore the parameters pruned in the last trial.

        Args:
            score(float): The score of pruned program.
        """
        self._restore(self._scope)
        self._param_backup = {}
        tokens = self._ratios2tokens(self._current_ratios)
        self._controller_client.update(tokens, score, self._iter)
        self._iter += 1

    def _restore(self, scope):
        """Write backed-up parameter values back into `scope`."""
        for param_name in self._param_backup.keys():
            param_t = scope.find_var(param_name).get_tensor()
            param_t.set(self._param_backup[param_name], self._place)

    def _next_ratios(self):
        """Fetch the next candidate tokens from the controller server."""
        tokens = self._controller_client.next_tokens()
        return self._tokens2ratios(tokens)

    def _ratios2tokens(self, ratios):
        """Convert pruned ratios to tokens (one token = 0.01 of ratio)."""
        return [int(ratio / 0.01) for ratio in ratios]

    def _tokens2ratios(self, tokens):
        """Convert tokens to pruned ratios (one token = 0.01 of ratio)."""
        return [token * 0.01 for token in tokens]