auto_pruner.py
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import socket
import logging
import numpy as np
import paddle.fluid as fluid
from .pruner import Pruner
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops

from ..common import ControllerServer
from ..common import ControllerClient

__all__ = ["AutoPruner"]

_logger = get_logger(__name__, level=logging.INFO)


class AutoPruner(object):
    def __init__(self,
                 program,
                 scope,
                 place,
                 params=[],
                 init_ratios=None,
                 pruned_flops=0.5,
                 pruned_latency=None,
                 server_addr=("", 0),
                 init_temperature=100,
                 reduce_rate=0.85,
                 max_try_number=300,
                 max_client_num=10,
                 search_steps=300,
                 max_ratios=[0.9],
                 min_ratios=[0],
                 key="auto_pruner",
                 is_server=True):
        """
        Search for a group of ratios used to prune the program.
        Args:
            program(Program): The program to be pruned.
            scope(Scope): The scope to be pruned.
            place(fluid.Place): The device place of parameters.
            params(list<str>): The names of parameters to be pruned.
            init_ratios(list<float>|float): Init ratios used to prune the parameters in `params`.
                                            A list means one ratio for each parameter in `params`;
                                            the length of `init_ratios` should be equal to the length of `params` when `init_ratios` is a list.
                                            If it is a scalar, all the parameters in `params` will be pruned by this uniform ratio.
                                            None means a group of init ratios will be derived from `pruned_flops` or `pruned_latency`. Default: None.
            pruned_flops(float): The percent of FLOPs to be pruned. Default: 0.5.
            pruned_latency(float): The percent of latency to be pruned. Default: None.
            server_addr(tuple): A tuple of server ip and server port for controller server. 
            init_temperature(float): The init temperature used in simulated annealing search strategy.
            reduce_rate(float): The decay rate used in simulated annealing search strategy.
            max_try_number(int): The maximum number of attempts to generate legal tokens.
            max_client_num(int): The max number of connections of controller server.
            search_steps(int): The steps of searching.
            max_ratios(float|list<float>): Max ratios used to prune the parameters in `params`. A list means one max ratio for each parameter in `params`.
                                           The length of `max_ratios` should be equal to the length of `params` when `max_ratios` is a list.
                                           If it is a scalar, it will be used for all the parameters in `params`.
            min_ratios(float|list<float>): Min ratios used to prune the parameters in `params`. A list means one min ratio for each parameter in `params`.
                                           The length of `min_ratios` should be equal to the length of `params` when `min_ratios` is a list.
                                           If it is a scalar, it will be used for all the parameters in `params`.
            key(str): Identity used in communication between controller server and clients.
            is_server(bool): Whether current host is controller server. Default: True.
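        Example:
            A minimal sketch of constructing an AutoPruner; it assumes `import paddle.fluid as fluid`,
            an already built `train_program`, and illustrative parameter names that depend on the
            model being pruned:

                pruner = AutoPruner(
                    train_program,
                    fluid.global_scope(),
                    fluid.CPUPlace(),
                    params=["conv2d_1.w_0", "conv2d_2.w_0"],
                    init_ratios=[0.3, 0.3],
                    pruned_flops=0.5,
                    search_steps=100,
                    is_server=True)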
        """

        self._program = program
        self._scope = scope
        self._place = place
        self._params = params
        self._init_ratios = init_ratios
        self._pruned_flops = pruned_flops
        self._pruned_latency = pruned_latency
        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._max_try_number = max_try_number
        self._is_server = is_server

        self._range_table = self._get_range_table(min_ratios, max_ratios)

        self._pruner = Pruner()
        if self._pruned_flops:
            self._base_flops = flops(program)
            self._max_flops = self._base_flops * (1 - self._pruned_flops)
            _logger.info(
                "AutoPruner - base flops: {}; pruned_flops: {}; max_flops: {}".
                format(self._base_flops, self._pruned_flops, self._max_flops))
        if self._pruned_latency:
            # NOTE: relies on a `latency(program)` evaluator that is not defined or imported in this module.
            self._base_latency = latency(program)

        if self._init_ratios is None:
            self._init_ratios = self._get_init_ratios(
                self._program, self._params, self._pruned_flops,
                self._pruned_latency)
        init_tokens = self._ratios2tokens(self._init_ratios)
        _logger.info("range table: {}".format(self._range_table))
        controller = SAController(self._range_table, self._reduce_rate,
                                  self._init_temperature, self._max_try_number,
                                  init_tokens, self._constrain_func)

        server_ip, server_port = server_addr
        if server_ip is None or server_ip == "":
            server_ip = self._get_host_ip()

        self._controller_server = ControllerServer(
            controller=controller,
            address=(server_ip, server_port),
            max_client_num=max_client_num,
            search_steps=search_steps,
            key=key)

        # start the controller server if current host is the server
        if self._is_server:
            self._controller_server.start()

        self._controller_client = ControllerClient(
            self._controller_server.ip(),
            self._controller_server.port(),
            key=key)

        self._iter = 0
        self._param_backup = {}

    def _get_host_ip(self):
        return socket.gethostbyname(socket.gethostname())

    def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
        # not implemented yet; init ratios must currently be provided by the caller
        pass

    def _get_range_table(self, min_ratios, max_ratios):
        assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
        assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
        min_ratios = min_ratios if isinstance(
            min_ratios, list) else [min_ratios] * len(self._params)
        max_ratios = max_ratios if isinstance(
            max_ratios, list) else [max_ratios] * len(self._params)
        min_tokens = self._ratios2tokens(min_ratios)
        max_tokens = self._ratios2tokens(max_ratios)
        return (min_tokens, max_tokens)

    def _constrain_func(self, tokens):
        ratios = self._tokens2ratios(tokens)
        pruned_program = self._pruner.prune(
            self._program,
            self._scope,
            self._params,
            ratios,
            place=self._place,
            only_graph=True)
        current_flops = flops(pruned_program)
        result = current_flops < self._max_flops
        if not result:
            _logger.info("Failed try ratios: {}; flops: {}; max_flops: {}".
                         format(ratios, current_flops, self._max_flops))
        else:
            _logger.info("Success try ratios: {}; flops: {}; max_flops: {}".
                         format(ratios, current_flops, self._max_flops))
        return result

    def prune(self, program, eval_program=None):
        """
        Prune program with latest tokens generated by controller.
        Args:
            program(fluid.Program): The program to be pruned.
            eval_program(fluid.Program): The evaluation program to be pruned with the same ratios, graph only. Default: None.
        Returns:
            tuple: The pruned program and the pruned evaluation program; the latter is None if `eval_program` is None.
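        Example:
            A sketch of the typical search loop; `train_program`, `val_program`, the number of
            steps, and the way the score is obtained are illustrative:

                for _ in range(100):
                    pruned_program, pruned_val_program = pruner.prune(
                        train_program, val_program)
                    # train pruned_program and evaluate pruned_val_program
                    # to obtain a score, e.g. top-1 accuracy
                    pruner.reward(float(score))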
        """
        self._current_ratios = self._next_ratios()
        pruned_program = self._pruner.prune(
            program,
            self._scope,
            self._params,
            self._current_ratios,
            place=self._place,
            only_graph=False,
            param_backup=self._param_backup)
        pruned_val_program = None
        if eval_program is not None:
            pruned_val_program = self._pruner.prune(
                eval_program,
                self._scope,
                self._params,
                self._current_ratios,
                place=self._place,
                only_graph=True)

        _logger.info("AutoPruner - pruned ratios: {}".format(
            self._current_ratios))
        return pruned_program, pruned_val_program

    def reward(self, score):
        """
        Report the reward (score) of the current pruned program to the controller.
        Args:
            score(float): The score of the pruned program.
        """
        self._restore(self._scope)
        self._param_backup = {}
        tokens = self._ratios2tokens(self._current_ratios)
        self._controller_client.update(tokens, score, self._iter)
        self._iter += 1

    def _restore(self, scope):
        for param_name in self._param_backup.keys():
            param_t = scope.find_var(param_name).get_tensor()
            param_t.set(self._param_backup[param_name], self._place)

    def _next_ratios(self):
        tokens = self._controller_client.next_tokens()
        return self._tokens2ratios(tokens)

    def _ratios2tokens(self, ratios):
        """Convert pruned ratios to tokens.
        """
        # round before truncating to avoid floating-point error, e.g. 0.29 / 0.01 == 28.999...
        return [int(round(ratio / 0.01)) for ratio in ratios]

    def _tokens2ratios(self, tokens):
        """Convert tokens to pruned ratios.
        """
        return [token * 0.01 for token in tokens]