auto_pruner.py 7.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import socket
import logging
17 18 19 20
import numpy as np
import paddle.fluid as fluid
from .pruner import Pruner
from ..core import VarWrapper, OpWrapper, GraphWrapper
21 22 23
from ..common import SAController
from ..common import get_logger
from ..analysis import flops
24

25 26
from .controller_server import ControllerServer
from .controller_client import ControllerClient
27 28 29

# Names exported by `from <module> import *`.
__all__ = ["AutoPruner"]

# Module-level logger shared by everything in this file; created at INFO level.
_logger = get_logger(__name__, level=logging.INFO)

32 33 34

class AutoPruner(object):
    """Search for a group of pruning ratios through a controller server/client.

    Ratios are exchanged with the controller as integer "tokens", where one
    token unit corresponds to a ratio of 0.01 (see `_ratios2tokens` and
    `_tokens2ratios`).
    """

    def __init__(self,
                 program,
                 scope,
                 place,
                 params=[],
                 init_ratios=None,
                 pruned_flops=0.5,
                 pruned_latency=None,
                 server_addr=("", 0),
                 init_temperature=100,
                 reduce_rate=0.85,
                 max_try_number=300,
                 max_client_num=10,
                 search_steps=300,
                 max_ratios=[0.9],
                 min_ratios=[0],
                 key="auto_pruner",
                 is_server=True):
        """
        Search a group of ratios used to prune program.
        Args:
            program: The program to be pruned. It is also the program whose
                     FLOPs are measured as the baseline when `pruned_flops` is set.
            scope: The scope handed to the pruner and used to look up
                   parameter variables when restoring them.
            place: The place handed to the pruner and used when writing
                   restored parameter values back.
            params(list<str>): The names of parameters to be pruned.
            init_ratios(list<float>|float): Init ratios used to prune parameters in `params`.
                                            A list means one ratio per parameter in `params`.
                                            A scalar is expanded to `len(params)` copies, so all
                                            parameters start from the same ratio.
                                            None means the ratios are obtained from
                                            `_get_init_ratios`. Default: None.
            pruned_flops(float): The percent of FLOPS to be pruned. Default: 0.5.
            pruned_latency(float): The percent of latency to be pruned. Default: None.
            server_addr(tuple): A tuple of server ip and server port for controller server.
                                If the ip is None or an empty string, the ip of the
                                current host is used. Default: ("", 0).
            init_temperature: Forwarded to `SAController`. Default: 100.
            reduce_rate: Forwarded to `SAController`. Default: 0.85.
            max_try_number: Forwarded to `SAController`. Default: 300.
            max_client_num: Forwarded to `ControllerServer`. Default: 10.
            search_steps: Forwarded to `ControllerServer`. Default: 300.
            max_ratios(list<float>|float): Upper bound(s) of the searched ratios. Default: [0.9].
            min_ratios(list<float>|float): Lower bound(s) of the searched ratios. Default: [0].
            key(str): Forwarded to both `ControllerServer` and `ControllerClient`.
                      Default: "auto_pruner".
            is_server(bool): Whether to start the controller server from this
                             instance. Default: True.
        """
        self._program = program
        self._scope = scope
        self._place = place
        self._params = params
        self._init_ratios = init_ratios
        self._pruned_flops = pruned_flops
        self._pruned_latency = pruned_latency
        self._reduce_rate = reduce_rate
        self._init_temperature = init_temperature
        self._max_try_number = max_try_number
        self._is_server = is_server

        self._range_table = self._get_range_table(min_ratios, max_ratios)

        self._pruner = Pruner()
        if self._pruned_flops:
            # Baseline FLOPs that `_constrain_func` compares candidates against.
            self._base_flops = flops(program)
            _logger.info("AutoPruner - base flops: {};".format(
                self._base_flops))
        if self._pruned_latency:
            # NOTE(review): `latency` is not imported by the imports visible in
            # this file, so this branch raises NameError unless the name is
            # provided elsewhere -- TODO confirm and import it.
            self._base_latency = latency(program)

        if self._init_ratios is None:
            self._init_ratios = self._get_init_ratios(
                self._program, self._params, self._pruned_flops,
                self._pruned_latency)
        elif isinstance(self._init_ratios, (int, float)):
            # A scalar means "use this ratio for every parameter".
            self._init_ratios = [self._init_ratios] * len(self._params)
        init_tokens = self._ratios2tokens(self._init_ratios)

        controller = SAController(self._range_table, self._reduce_rate,
                                  self._init_temperature, self._max_try_number,
                                  init_tokens, self._constrain_func)

        server_ip, server_port = server_addr
        if server_ip is None or server_ip == "":
            server_ip = self._get_host_ip()

        self._controller_server = ControllerServer(
            controller=controller,
            address=(server_ip, server_port),
            max_client_num=max_client_num,
            search_steps=search_steps,
            key=key)

        # Only the instance flagged as server starts the controller server.
        if self._is_server:
            self._controller_server.start()

        # The client always connects to the address reported by the server object.
        self._controller_client = ControllerClient(
            self._controller_server.ip(),
            self._controller_server.port(),
            key=key)

        self._iter = 0
        self._param_backup = {}

    def _get_host_ip(self):
        """Return the ip address that the local host name resolves to."""
        return socket.gethostbyname(socket.gethostname())

    def _get_init_ratios(self, program, params, pruned_flops, pruned_latency):
        """Placeholder for deriving init ratios; currently returns None.

        NOTE(review): because this returns None, constructing `AutoPruner`
        with `init_ratios=None` fails when the result is converted to tokens.
        """
        pass

    def _get_range_table(self, min_ratios, max_ratios):
        """Build the (min_tokens, max_tokens) search range.

        Args:
            min_ratios(list<float>|float|int): Lower bound(s) of the ratios.
            max_ratios(list<float>|float|int): Upper bound(s) of the ratios.
        Returns:
            tuple: `(min_tokens, max_tokens)`, each a list of int tokens.
        """
        # Integer scalars such as 0 or 1 are accepted alongside floats.
        assert isinstance(min_ratios, (list, float, int))
        assert isinstance(max_ratios, (list, float, int))
        min_ratios = min_ratios if isinstance(min_ratios,
                                              list) else [min_ratios]
        max_ratios = max_ratios if isinstance(max_ratios,
                                              list) else [max_ratios]
        min_tokens = self._ratios2tokens(min_ratios)
        max_tokens = self._ratios2tokens(max_ratios)
        return (min_tokens, max_tokens)

    def _constrain_func(self, tokens):
        """Check whether the ratios encoded by `tokens` prune enough FLOPs.

        The program is pruned with `only_graph=True` and its FLOPs are compared
        with `base_flops * (1 - pruned_flops)`.

        NOTE: `self._base_flops` is only set in `__init__` when `pruned_flops`
        is truthy, so this method requires `pruned_flops` to be set.

        Returns:
            bool: True if the pruned program's FLOPs is below the target.
        """
        ratios = self._tokens2ratios(tokens)
        pruned_program = self._pruner.prune(
            self._program,
            self._scope,
            self._params,
            ratios,
            place=self._place,
            only_graph=True)
        return flops(pruned_program) < self._base_flops * (
            1 - self._pruned_flops)

    def prune(self, program):
        """Prune `program` with the next group of ratios from the controller.

        The ratios are remembered in `self._current_ratios` so that `reward`
        can report them back. `self._param_backup` is handed to the pruner
        (presumably so it can record original parameter values; `_restore`
        later reads it).

        Returns:
            The pruned program returned by the pruner.
        """
        self._current_ratios = self._next_ratios()
        pruned_program = self._pruner.prune(
            program,
            self._scope,
            self._params,
            self._current_ratios,
            place=self._place,
            param_backup=self._param_backup)
        _logger.info("AutoPruner - pruned ratios: {}".format(
            self._current_ratios))
        return pruned_program

    def reward(self, score):
        """Report the score of the current ratios to the controller.

        Backed-up parameters are written back to the scope first and the
        backup is cleared; then the tokens of the current ratios and `score`
        are sent through the controller client.

        Args:
            score: The score obtained with the ratios returned by the last
                   call to `prune`.
        """
        self._restore(self._scope)
        self._param_backup = {}
        tokens = self._ratios2tokens(self._current_ratios)
        self._controller_client.update(tokens, score)
        self._iter += 1

    def _restore(self, scope):
        """Write every value in `self._param_backup` back into `scope`."""
        for param_name, backup in self._param_backup.items():
            param_t = scope.find_var(param_name).get_tensor()
            param_t.set(backup, self._place)

    def _next_ratios(self):
        """Fetch the next tokens from the controller client as ratios."""
        tokens = self._controller_client.next_tokens()
        return self._tokens2ratios(tokens)

    def _ratios2tokens(self, ratios):
        """Convert pruned ratios to tokens.

        One token unit is a ratio of 0.01. The quotient is rounded to six
        decimals before truncation so that floating-point noise does not
        drop a token (e.g. 0.29 / 0.01 == 28.999999999999996 must give 29),
        while ratios that are not multiples of 0.01 are still truncated.
        """
        return [int(round(ratio / 0.01, 6)) for ratio in ratios]

    def _tokens2ratios(self, tokens):
        """Convert tokens to pruned ratios.
        """
        return [token * 0.01 for token in tokens]