auto_pruner.py 8.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import socket
import logging
17 18 19 20
import numpy as np
import paddle.fluid as fluid
from .pruner import Pruner
from ..core import VarWrapper, OpWrapper, GraphWrapper
21 22 23
from ..common import SAController
from ..common import get_logger
from ..analysis import flops
24

25 26
from ..common import ControllerServer
from ..common import ControllerClient
27 28 29

# Public API of this module.
__all__ = ["AutoPruner"]

# Module-level logger shared by everything in this file; emits INFO and above.
_logger = get_logger(__name__, level=logging.INFO)

32 33 34

class AutoPruner(object):
    """Search per-parameter pruning ratios with a simulated-annealing controller.

    Usage, as exposed by the public methods: call `prune` to get a program
    pruned with the next candidate ratios, evaluate it, then call `reward`
    with the resulting score.

    Ratios are exchanged with the controller as integer "tokens", one token
    per parameter, where a token is the ratio expressed in steps of 0.01.
    """

    def __init__(self,
                 program,
                 scope,
                 place,
                 params=[],
                 init_ratios=None,
                 pruned_flops=0.5,
                 pruned_latency=None,
                 server_addr=("", 0),
                 init_temperature=100,
                 reduce_rate=0.85,
                 max_try_number=300,
                 max_client_num=10,
                 search_steps=300,
                 max_ratios=[0.9],
                 min_ratios=[0],
                 key="auto_pruner",
                 is_server=True):
        """
        Search a group of ratios used to prune program.
        Args:
            program(Program): The program to be pruned.
            scope(Scope): The scope to be pruned.
            place(fluid.Place): The device place of parameters.
            params(list<str>): The names of parameters to be pruned.
            init_ratios(list<float>|float): Init ratios used to prune parameters in `params`.
                                            List means ratios used for pruning each parameter in `params`.
                                            The length of `init_ratios` should be equal to length of params when `init_ratios` is a list.
                                            If it is a scalar, all the parameters in `params` will be pruned by uniform ratio.
                                            None means get a group of init ratios by `pruned_flops` or `pruned_latency`
                                            (not implemented yet, so this currently raises NotImplementedError). Default: None.
            pruned_flops(float): The percent of FLOPS to be pruned. Default: 0.5.
            pruned_latency(float): The percent of latency to be pruned. Default: None.
            server_addr(tuple): A tuple of server ip and server port for controller server.
                                An empty or None ip means the ip of the current host is used.
            init_temperature(float): The init temperature used in simulated annealing search strategy.
            reduce_rate(float): The decay rate used in simulated annealing search strategy.
            max_try_number(int): The max number of trying to generate legal tokens.
            max_client_num(int): The max number of connections of controller server.
            search_steps(int): The steps of searching.
            max_ratios(float|list<float>): Max ratios used to prune parameters in `params`. List means max ratios for each parameter in `params`.
                                           The length of `max_ratios` should be equal to length of params when `max_ratios` is a list.
                                           If it is a scalar, it will be used for all the parameters in `params`.
            min_ratios(float|list<float>): Min ratios used to prune parameters in `params`. List means min ratios for each parameter in `params`.
                                           The length of `min_ratios` should be equal to length of params when `min_ratios` is a list.
                                           If it is a scalar, it will be used for all the parameters in `params`.
            key(str): Identity used in communication between controller server and clients.
            is_server(bool): Whether current host is controller server. Default: True.
        """

        self._program = program
        self._scope = scope
        self._place = place
        self._params = params
        self._init_ratios = init_ratios
        self._pruned_flops = pruned_flops
        self._pruned_latency = pruned_latency
        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._max_try_number = max_try_number
        self._is_server = is_server

        # Must run after `self._params` is assigned: scalar bounds are
        # expanded to one entry per parameter.
        self._range_table = self._get_range_table(min_ratios, max_ratios)

        self._pruner = Pruner()
        if self._pruned_flops:
            self._base_flops = flops(program)
            _logger.info("AutoPruner - base flops: {};".format(
                self._base_flops))
        if self._pruned_latency:
            # NOTE(review): `latency` is not among the imports visible in this
            # file, so this line raises NameError whenever `pruned_latency`
            # is set -- confirm where it is supposed to come from.
            self._base_latency = latency(program)

        if self._init_ratios is None:
            self._init_ratios = self._get_init_ratios(
                self._program, self._params, self._pruned_flops,
                self._pruned_latency)
        elif isinstance(self._init_ratios, (int, float)):
            # A scalar means "the same ratio for every parameter".
            self._init_ratios = self._expand_ratios(self._init_ratios)
        init_tokens = self._ratios2tokens(self._init_ratios)

        controller = SAController(self._range_table, self._reduce_rate,
                                  self._init_temperature, self._max_try_number,
                                  init_tokens, self._constrain_func)

        server_ip, server_port = server_addr
        if server_ip is None or server_ip == "":
            server_ip = self._get_host_ip()

        # The server object is always built (the client below reads its
        # ip/port), but it is only started on the host acting as server.
        self._controller_server = ControllerServer(
            controller=controller,
            address=(server_ip, server_port),
            max_client_num=max_client_num,
            search_steps=search_steps,
            key=key)

        if self._is_server:
            self._controller_server.start()

        self._controller_client = ControllerClient(
            self._controller_server.ip(),
            self._controller_server.port(),
            key=key)

        self._iter = 0
        # Passed to the pruner in `prune` and consumed by `_restore`;
        # `_restore` treats it as parameter name -> saved value.
        self._param_backup = {}

    def _get_host_ip(self):
        """Return the ip address that the local host name resolves to."""
        return socket.gethostbyname(socket.gethostname())

    def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
        """Derive init ratios from `pruned_flops` / `pruned_latency`.

        Not implemented. Raising here gives callers an actionable message
        instead of returning None and failing later with an opaque TypeError.
        """
        raise NotImplementedError(
            "Getting init ratios from `pruned_flops` or `pruned_latency` is "
            "not implemented; please pass `init_ratios` explicitly.")

    def _expand_ratios(self, ratios):
        """Return `ratios` as a list with one entry per parameter.

        A list is returned unchanged. A scalar is repeated `len(params)`
        times (at least once, so an empty `params` still yields a one-element
        list).
        """
        if isinstance(ratios, list):
            return ratios
        return [ratios] * max(1, len(self._params))

    def _get_range_table(self, min_ratios, max_ratios):
        """Build the (min_tokens, max_tokens) search range for the controller.

        Args:
            min_ratios(float|int|list<float>): Lower bound(s) of the ratios.
            max_ratios(float|int|list<float>): Upper bound(s) of the ratios.
        Returns:
            tuple: `(min_tokens, max_tokens)`, two lists of int tokens.
        """
        assert isinstance(min_ratios, (list, float, int)), \
            "min_ratios should be a list or a scalar number"
        assert isinstance(max_ratios, (list, float, int)), \
            "max_ratios should be a list or a scalar number"
        min_tokens = self._ratios2tokens(self._expand_ratios(min_ratios))
        max_tokens = self._ratios2tokens(self._expand_ratios(max_ratios))
        return (min_tokens, max_tokens)

    def _constrain_func(self, tokens):
        """Tell the controller whether `tokens` satisfy the FLOPs budget.

        Returns True when the FLOPs of the program pruned by the ratios of
        `tokens` are below `base_flops * (1 - pruned_flops)`. Relies on
        `self._base_flops`, which is only set when `pruned_flops` is truthy.
        """
        ratios = self._tokens2ratios(tokens)
        # `only_graph=True` presumably leaves the parameter values in scope
        # untouched -- confirm against Pruner.prune.
        pruned_program = self._pruner.prune(
            self._program,
            self._scope,
            self._params,
            ratios,
            place=self._place,
            only_graph=True)
        max_flops = self._base_flops * (1 - self._pruned_flops)
        return flops(pruned_program) < max_flops

    def prune(self, program):
        """
        Prune program with latest tokens generated by controller.
        Args:
            program(fluid.Program): The program to be pruned.
        Returns:
            Program: The pruned program.
        """
        self._current_ratios = self._next_ratios()
        pruned_program = self._pruner.prune(
            program,
            self._scope,
            self._params,
            self._current_ratios,
            place=self._place,
            param_backup=self._param_backup)
        _logger.info("AutoPruner - pruned ratios: {}".format(
            self._current_ratios))
        return pruned_program

    def reward(self, score):
        """
        Return reward of current pruned program.

        Restores the backed-up parameters into the scope, then reports the
        tokens of the current ratios together with `score` to the controller.
        Args:
            score(float): The score of pruned program.
        """
        self._restore(self._scope)
        self._param_backup = {}
        tokens = self._ratios2tokens(self._current_ratios)
        self._controller_client.update(tokens, score)
        self._iter += 1

    def _restore(self, scope):
        """Write every backed-up parameter value back into `scope`."""
        for param_name, backup in self._param_backup.items():
            param_t = scope.find_var(param_name).get_tensor()
            param_t.set(backup, self._place)

    def _next_ratios(self):
        """Fetch the next tokens from the controller and convert to ratios."""
        tokens = self._controller_client.next_tokens()
        return self._tokens2ratios(tokens)

    def _ratios2tokens(self, ratios):
        """Convert pruned ratios to tokens.

        Rounds to the nearest integer rather than truncating: e.g.
        `0.29 / 0.01` evaluates to 28.999999999999996, which truncation
        would turn into 28 and break the tokens -> ratios -> tokens
        round trip.
        """
        return [int(round(ratio / 0.01)) for ratio in ratios]

    def _tokens2ratios(self, tokens):
        """Convert tokens to pruned ratios.
        """
        return [token * 0.01 for token in tokens]