# coding:utf-8
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
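
"""PaddleHub `autofinetune` command: automatically search hyperparameters for
a user-provided PaddleHub fine-tuning script."""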

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import shutil

from paddlehub.commands.base_command import BaseCommand, ENTRY
from paddlehub.autofinetune.autoft import PSHE2
from paddlehub.autofinetune.autoft import HAZero
from paddlehub.autofinetune.evaluator import FullTrailEvaluator
from paddlehub.autofinetune.evaluator import PopulationBasedEvaluator


class AutoFineTuneCommand(BaseCommand):
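    """Automatically search hyperparameters for a PaddleHub fine-tuning task."""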
    name = "autofinetune"

    def __init__(self, name):
        super(AutoFineTuneCommand, self).__init__(name)
        self.show_in_help = True
        self.name = name
        self.description = "PaddleHub helps to finetune a task by searching hyperparameters automatically."
        self.parser = argparse.ArgumentParser(
            description=self.__class__.__doc__,
            prog='%s %s <task to be finetuned in python script>' % (ENTRY,
                                                                    self.name),
            usage='%(prog)s',
            add_help=False)
        self.module = None

    def add_params_file_arg(self):
        self.arg_params_to_be_searched_group.add_argument(
            "--param_file",
            type=str,
            default=None,
            required=True,
            help=
            "Hyperparameters to be searched, in YAML format. The number of hyperparameters to be searched must be greater than 1."
        )

    def add_autoft_config_arg(self):
        self.arg_config_group.add_argument(
            "--popsize", type=int, default=5, help="Population size")
        self.arg_config_group.add_argument(
            "--gpu",
            type=str,
            default="0",
            required=True,
            help="Comma-separated list of GPU devices to be used")
        self.arg_config_group.add_argument(
            "--round", type=int, default=10, help="Number of search rounds")
        self.arg_config_group.add_argument(
            "--output_dir",
            type=str,
            default=None,
            help="Directory to save model checkpoints")
        self.arg_config_group.add_argument(
            "--evaluator",
            type=str,
            default="populationbased",
            help="Choices: fulltrail or populationbased.")
        self.arg_config_group.add_argument(
            "--tuning_strategy",
            type=str,
            default="pshe2",
            help="Choices: HAZero or PSHE2.")
        self.arg_config_group.add_argument(
            'opts',
            help='See utils/config.py for all options',
            default=None,
            nargs=argparse.REMAINDER)

    def convert_to_other_options(self, config_list):
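        """Convert a flat list of key/value pairs (for example, a hypothetical
        ["max_epoch", "5"]) into an option string ("--max_epoch=5 ")."""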
        if len(config_list) % 2 != 0:
            raise ValueError(
                "Options for the finetuned task must be key-value pairs. Please check them: {}"
                .format(config_list))
        options_str = ""
        for key, value in zip(config_list[0::2], config_list[1::2]):
            options_str += "--" + key + "=" + value + " "
        return options_str

    def execute(self, argv):
        if not argv:
            print("ERROR: Please specify a Python script to be finetuned.\n")
            self.help()
            return False

        self.finetunee_script = argv[0]

        self.parser.prog = '%s %s %s' % (ENTRY, self.name, self.finetunee_script)
        self.arg_params_to_be_searched_group = self.parser.add_argument_group(
            title="Input options",
            description="Hyperparameters to be searched.")
        self.arg_config_group = self.parser.add_argument_group(
            title="Autofinetune config options",
            description=
            "Autofinetune configuration for controlling autofinetune behavior, not required"
        )
        self.arg_finetuned_task_group = self.parser.add_argument_group(
            title="Finetuned task config options",
            description=
            "Finetuned task configuration for controlling finetuned task behavior, not required"
        )

        self.add_params_file_arg()
        self.add_autoft_config_arg()

        if not argv[1:]:
            self.help()
            return False

        self.args = self.parser.parse_args(argv[1:])
        options_str = ""
        if self.args.opts is not None:
            options_str = self.convert_to_other_options(self.args.opts)

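        # Parse the comma-separated GPU list, e.g. "0,1" -> [0, 1].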
        device_ids = self.args.gpu.strip().split(",")
        device_ids = [int(device_id) for device_id in device_ids]

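        # Choose how each candidate set of hyperparameters is evaluated:
        # a full trial per candidate ("fulltrail") or a population-based
        # scheme ("populationbased").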
        if self.args.evaluator.lower() == "fulltrail":
            evaluator = FullTrailEvaluator(
                self.args.param_file,
                self.finetunee_script,
                options_str=options_str)
        elif self.args.evaluator.lower() == "populationbased":
            evaluator = PopulationBasedEvaluator(
                self.args.param_file,
                self.finetunee_script,
                options_str=options_str)
        else:
            raise ValueError(
                "The evaluator %s is not defined!" % self.args.evaluator)

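        # Select the tuning strategy that proposes new hyperparameter
        # candidates for each round: HAZero or PSHE2.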
        if self.args.tuning_strategy.lower() == "hazero":
            autoft = HAZero(
                evaluator,
                cudas=device_ids,
                popsize=self.args.popsize,
                output_dir=self.args.output_dir)
        elif self.args.tuning_strategy.lower() == "pshe2":
            autoft = PSHE2(
                evaluator,
                cudas=device_ids,
                popsize=self.args.popsize,
                output_dir=self.args.output_dir)
        else:
            raise ValueError("The tuning strategy %s is not defined!" %
                             self.args.tuning_strategy)

        run_round_cnt = 0
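        # Maps each evaluated solution to the directory (and, on multiple
        # machines, the rank) where its model parameters were saved.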
        solutions_modeldirs = {}
        print("PaddleHub Autofinetune starts.")
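        # Run at most `round` search rounds, one population step per round;
        # each round writes its checkpoints to a separate subdirectory.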
        while (not autoft.is_stop()) and run_round_cnt < self.args.round:
            print("PaddleHub Autofinetune starts round %s." % run_round_cnt)
            output_dir = autoft._output_dir + "/round" + str(run_round_cnt)
            res = autoft.step(output_dir)
            solutions_modeldirs.update(res)
            evaluator.new_round()
            run_round_cnt += 1
        print("PaddleHub Autofinetune ends.")

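        # Fetch the best hyperparameters and broadcast them to all MPI workers.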
        best_hparams_origin = autoft.get_best_hparams()
        best_hparams_origin = autoft.mpi.bcast(best_hparams_origin)

        with open(autoft._output_dir + "/log_file.txt", "w") as f:
            best_hparams = evaluator.convert_params(best_hparams_origin)
            print("The final best hyperparameters:")
            f.write("The final best hyperparameters:\n")
            for index, hparam_name in enumerate(autoft.hparams_name_list):
                print("%s=%s" % (hparam_name, best_hparams[index]))
                f.write(hparam_name + "\t:\t" + str(best_hparams[index]) + "\n")

            best_hparams_dir, best_hparams_rank = solutions_modeldirs[tuple(
                best_hparams_origin)]

            print("The final best eval score is %s." %
                  autoft.get_best_eval_value())

            if autoft.mpi.multi_machine:
                print("The final best model parameters are saved as " +
                      autoft._output_dir + "/best_model on rank " +
                      str(best_hparams_rank) + ".")
            else:
                print("The final best model parameters are saved as " +
                      autoft._output_dir + "/best_model.")
            f.write("The final best eval score is %s.\n" %
                    autoft.get_best_eval_value())

            best_model_dir = autoft._output_dir + "/best_model"

            if autoft.mpi.rank == best_hparams_rank:
                shutil.copytree(best_hparams_dir, best_model_dir)

            if autoft.mpi.multi_machine:
                f.write(
                    "The final best model parameters are saved as ./best_model on rank " \
                    + str(best_hparams_rank) + ".\n")
                f.write("\t".join(autoft.hparams_name_list) +
                        "\tsaved_params_dir\trank\n")
            else:
                f.write(
                    "The final best model parameters are saved as ./best_model.\n"
                )
                f.write("\t".join(autoft.hparams_name_list) +
                        "\tsaved_params_dir\n")

            print(
                "Details of the searched hyperparameters are saved to %s/log_file.txt."
                % autoft._output_dir)
            for solution, modeldir in solutions_modeldirs.items():
                param = evaluator.convert_params(solution)
                param = [str(p) for p in param]
                if autoft.mpi.multi_machine:
                    f.write("\t".join(param) + "\t" + modeldir[0] + "\t" +
                            str(modeldir[1]) + "\n")
                else:
                    f.write("\t".join(param) + "\t" + modeldir[0] + "\n")

        return True


command = AutoFineTuneCommand.instance()
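
# A hypothetical invocation (assuming the `hub` CLI entry point; the script
# and YAML file names below are placeholders):
#
#   hub autofinetune finetune_script.py \
#       --param_file=hparam.yaml \
#       --gpu=0,1 \
#       --popsize=5 \
#       --round=10 \
#       --output_dir=./autoft_output \
#       --evaluator=populationbased \
#       --tuning_strategy=pshe2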