# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import socket
import logging
import numpy as np
import paddle.fluid as fluid
from .pruner import Pruner
from ..core import VarWrapper, OpWrapper, GraphWrapper
from ..common import SAController
from ..common import get_logger
from ..analysis import flops

from ..common import ControllerServer
from ..common import ControllerClient

__all__ = ["AutoPruner"]

_logger = get_logger(__name__, level=logging.INFO)


class AutoPruner(object):
    """Searches pruning ratios for a program with simulated annealing.

    A ``ControllerServer`` (started only on the server host) holds an
    ``SAController``; every host, including the server, talks to it through a
    ``ControllerClient``. Each call to ``prune`` fetches the next candidate
    ratios; ``reward`` reports the score back and restores pruned parameters.
    """

    def __init__(self,
                 program,
                 scope,
                 place,
                 params=None,
                 init_ratios=None,
                 pruned_flops=0.5,
                 pruned_latency=None,
                 server_addr=("", 0),
                 init_temperature=100,
                 reduce_rate=0.85,
                 max_try_times=300,
                 max_client_num=10,
                 search_steps=300,
                 max_ratios=None,
                 min_ratios=None,
                 key="auto_pruner",
                 is_server=True):
        """
        Search a group of ratios used to prune program.
        Args:
            program(Program): The program to be pruned.
            scope(Scope): The scope to be pruned.
            place(fluid.Place): The device place of parameters.
            params(list<str>): The names of parameters to be pruned. Default: [].
            init_ratios(list<float>|float): Init ratios used to pruned parameters in `params`.
                List means ratios used for pruning each parameter in `params`.
                The length of `init_ratios` should be equal to length of params when `init_ratios` is a list.
                If it is a scalar, all the parameters in `params` will be pruned by uniform ratio.
                None means get a group of init ratios by `pruned_flops` of `pruned_latency`. Default: None.
            pruned_flops(float): The percent of FLOPS to be pruned. Default: 0.5.
            pruned_latency(float): The percent of latency to be pruned. Default: None.
            server_addr(tuple): A tuple of server ip and server port for controller server.
            init_temperature(float): The init temperature used in simulated annealing search strategy.
            reduce_rate(float): The decay rate used in simulated annealing search strategy.
            max_try_times(int): The max number of trying to generate legal tokens.
            max_client_num(int): The max number of connections of controller server.
            search_steps(int): The steps of searching.
            max_ratios(float|list<float>): Max ratios used to pruned parameters in `params`.
                List means max ratios for each parameter in `params`.
                The length of `max_ratios` should be equal to length of params when `max_ratios` is a list.
                If it is a scalar, it will used for all the parameters in `params`. Default: [0.9].
            min_ratios(float|list<float>): Min ratios used to pruned parameters in `params`.
                List means min ratios for each parameter in `params`.
                The length of `min_ratios` should be equal to length of params when `min_ratios` is a list.
                If it is a scalar, it will used for all the parameters in `params`. Default: [0].
            key(str): Identity used in communication between controller server and clients.
            is_server(bool): Whether current host is controller server. Default: True.
        """
        # `params`, `max_ratios` and `min_ratios` used mutable default
        # arguments ([], [0.9], [0]); None sentinels avoid the shared-default
        # pitfall while keeping the effective defaults identical.
        params = [] if params is None else params
        max_ratios = [0.9] if max_ratios is None else max_ratios
        min_ratios = [0] if min_ratios is None else min_ratios

        self._program = program
        self._scope = scope
        self._place = place
        self._params = params
        self._init_ratios = init_ratios
        self._pruned_flops = pruned_flops
        self._pruned_latency = pruned_latency
        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._max_try_times = max_try_times
        self._is_server = is_server

        self._range_table = self._get_range_table(min_ratios, max_ratios)

        self._pruner = Pruner()
        if self._pruned_flops:
            self._base_flops = flops(program)
            self._max_flops = self._base_flops * (1 - self._pruned_flops)
            _logger.info(
                "AutoPruner - base flops: {}; pruned_flops: {}; max_flops: {}".
                format(self._base_flops, self._pruned_flops, self._max_flops))
        if self._pruned_latency:
            # NOTE(review): `latency` is not defined or imported anywhere in
            # this module, so reaching this branch raises NameError. TODO:
            # provide a latency evaluator before enabling `pruned_latency`.
            self._base_latency = latency(program)

        if self._init_ratios is None:
            # BUG FIX: original code called
            # `self._get_init_ratios(self, _program, ...)` -- the comma typo
            # passed `self` explicitly and referenced an undefined `_program`.
            self._init_ratios = self._get_init_ratios(
                self._program, self._params, self._pruned_flops,
                self._pruned_latency)
        init_tokens = self._ratios2tokens(self._init_ratios)
        _logger.info("range table: {}".format(self._range_table))
        controller = SAController(self._range_table, self._reduce_rate,
                                  self._init_temperature, self._max_try_times,
                                  init_tokens, self._constrain_func)

        server_ip, server_port = server_addr
        if server_ip is None or server_ip == "":
            server_ip = self._get_host_ip()

        self._controller_server = ControllerServer(
            controller=controller,
            address=(server_ip, server_port),
            max_client_num=max_client_num,
            search_steps=search_steps,
            key=key)

        # Only the designated server host actually starts the controller
        # server; all hosts communicate with it through the client below.
        if self._is_server:
            self._controller_server.start()

        self._controller_client = ControllerClient(
            self._controller_server.ip(),
            self._controller_server.port(),
            key=key)

        self._iter = 0
        # Backup of original parameter values, filled by the pruner and used
        # by `_restore` to undo pruning after each reward.
        self._param_backup = {}

    def _get_host_ip(self):
        """Return the IP address of the current host."""
        return socket.gethostbyname(socket.gethostname())

    def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
        """Derive init ratios from a FLOPS/latency budget. Not implemented yet."""
        pass

    def _get_range_table(self, min_ratios, max_ratios):
        """Convert per-parameter min/max ratios into a token range table.

        Scalar ratios are broadcast to every parameter in `self._params`.

        Returns:
            tuple(list<int>, list<int>): (min_tokens, max_tokens).
        """
        assert isinstance(min_ratios, list) or isinstance(min_ratios, float)
        assert isinstance(max_ratios, list) or isinstance(max_ratios, float)
        min_ratios = min_ratios if isinstance(
            min_ratios, list) else [min_ratios] * len(self._params)
        max_ratios = max_ratios if isinstance(
            max_ratios, list) else [max_ratios] * len(self._params)
        min_tokens = self._ratios2tokens(min_ratios)
        max_tokens = self._ratios2tokens(max_ratios)
        return (min_tokens, max_tokens)

    def _constrain_func(self, tokens):
        """Return True if pruning with `tokens` keeps FLOPS under the budget.

        Used by the SAController to reject illegal candidate tokens. The
        pruning here is graph-only, so no parameters are modified.
        """
        ratios = self._tokens2ratios(tokens)
        pruned_program = self._pruner.prune(
            self._program,
            self._scope,
            self._params,
            ratios,
            place=self._place,
            only_graph=True)
        current_flops = flops(pruned_program)
        result = current_flops < self._max_flops
        if not result:
            _logger.info("Failed try ratios: {}; flops: {}; max_flops: {}".
                         format(ratios, current_flops, self._max_flops))
        else:
            _logger.info("Success try ratios: {}; flops: {}; max_flops: {}".
                         format(ratios, current_flops, self._max_flops))
        return result

    def prune(self, program, eval_program=None):
        """
        Prune program with latest tokens generated by controller.
        Args:
            program(fluid.Program): The program to be pruned.
            eval_program(fluid.Program): If not None, a graph-only pruned
                evaluation program is returned as well. Default: None.
        Returns:
            tuple: (pruned program, pruned eval program or None).
        """
        self._current_ratios = self._next_ratios()
        pruned_program = self._pruner.prune(
            program,
            self._scope,
            self._params,
            self._current_ratios,
            place=self._place,
            only_graph=False,
            param_backup=self._param_backup)
        pruned_val_program = None
        if eval_program is not None:
            # NOTE(review): this prunes `program`, not `eval_program` --
            # looks suspicious but is kept as-is; confirm intended behavior.
            pruned_val_program = self._pruner.prune(
                program,
                self._scope,
                self._params,
                self._current_ratios,
                place=self._place,
                only_graph=True)

        _logger.info("AutoPruner - pruned ratios: {}".format(
            self._current_ratios))
        return pruned_program, pruned_val_program

    def reward(self, score):
        """
        Return reward of current pruned program to the controller and restore
        the backed-up parameter values so the next trial starts clean.
        Args:
            score(float): The score of pruned program.
        """
        self._restore(self._scope)
        self._param_backup = {}
        tokens = self._ratios2tokens(self._current_ratios)
        self._controller_client.update(tokens, score, self._iter)
        self._iter += 1

    def _restore(self, scope):
        """Write backed-up parameter values back into `scope`."""
        for param_name, backup_value in self._param_backup.items():
            param_t = scope.find_var(param_name).get_tensor()
            param_t.set(backup_value, self._place)

    def _next_ratios(self):
        """Fetch the next tokens from the controller and convert to ratios."""
        tokens = self._controller_client.next_tokens()
        return self._tokens2ratios(tokens)

    def _ratios2tokens(self, ratios):
        """Convert pruned ratios to integer tokens (granularity 0.01).
        """
        return [int(ratio / 0.01) for ratio in ratios]

    def _tokens2ratios(self, tokens):
        """Convert integer tokens back to pruned ratios (granularity 0.01).
        """
        return [token * 0.01 for token in tokens]