#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition of TrainerFactory."""

import threading
import time
import logging
import numpy as np

logging.basicConfig()

from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer
from .device_worker import Hogwild, DownpourSGD, Section
from .framework import Variable
from multiprocessing import Process, Manager

__all__ = ["TrainerFactory", "FetchHandler", "FetchHandlerMonitor"]


class TrainerFactory(object):
    """
    Create a trainer together with its device worker.

    If ``opt_info`` is None, the default ``MultiTrainer`` + ``Hogwild``
    pair is produced; otherwise the trainer and device-worker classes
    (and their optional settings) are looked up from ``opt_info``.
    """

    def __init__(self):
        pass

    def _create_trainer(self, opt_info=None):
        """Build and return a configured trainer.

        Args:
            opt_info (dict|None): optimizer info from a distributed
                optimizer. ``opt_info["trainer"]`` and
                ``opt_info["device_worker"]`` name classes imported into
                this module; the remaining keys are optional settings
                forwarded to the trainer when present.

        Returns:
            The trainer instance with its device worker attached.
        """
        trainer = None
        device_worker = None
        if opt_info is None:  # PEP 8: compare to None with `is`
            # default is MultiTrainer + Hogwild
            trainer = MultiTrainer()
            device_worker = Hogwild()
            trainer._set_device_worker(device_worker)
        else:
            trainer_class = opt_info["trainer"]
            device_worker_class = opt_info["device_worker"]
            # Class names are resolved against this module's imports
            # (e.g. DistMultiTrainer, DownpourSGD).
            trainer = globals()[trainer_class]()
            device_worker = globals()[device_worker_class]()
            if "fleet_desc" in opt_info:
                device_worker._set_fleet_desc(opt_info["fleet_desc"])
                trainer._set_fleet_desc(opt_info["fleet_desc"])
                # Optional settings: forward each value to its trainer
                # setter only when the key is present and not None.
                optional_setters = [
                    ("use_cvm", trainer._set_use_cvm),
                    ("no_cvm", trainer._set_no_cvm),
                    ("scale_datanorm", trainer._set_scale_datanorm),
                    ("dump_slot", trainer._set_dump_slot),
                    ("mpi_rank", trainer._set_mpi_rank),
                    ("mpi_size", trainer._set_mpi_size),
                    ("dump_fields", trainer._set_dump_fields),
                    ("dump_fields_path", trainer._set_dump_fields_path),
                    ("dump_file_num", trainer._set_dump_file_num),
                    ("dump_converter", trainer._set_dump_converter),
                    ("adjust_ins_weight", trainer._set_adjust_ins_weight),
                    ("copy_table", trainer._set_copy_table_config),
                    ("check_nan_var_names",
                     trainer._set_check_nan_var_names),
                    ("dump_param", trainer._set_dump_param),
                ]
                for key, setter in optional_setters:
                    if opt_info.get(key) is not None:
                        setter(opt_info[key])
            trainer._set_device_worker(device_worker)
        return trainer


class FetchHandlerMonitor(object):
    """
    Monitor thread driving a user FetchHandler.

    Every ``handler.period_secs`` seconds the monitor looks up the
    variables named in ``handler.var_dict`` from ``scope``, converts
    initialized tensors to numpy arrays (uninitialized or missing ones
    become None) and passes the resulting dict to ``handler.handler``.
    """

    def __init__(self, scope, handler):
        """
        Args:
            scope: scope in which fetched variables are looked up.
            handler: object exposing ``period_secs`` (int seconds),
                ``var_dict`` (user key -> Variable) and a
                ``handler(res_dict)`` callback.
        """
        self.fetch_instance = handler
        self.fetch_thread = threading.Thread(
            target=self.handler_launch_func, args=(scope, self.fetch_instance))
        self.running_lock = threading.Lock()
        self.running = False

    def handler_launch_func(self, scope, handler):
        """Thread body: poll once per second until stopped, invoking the
        user handler every ``period_secs`` seconds."""
        fetch_instance = handler
        period_secs = fetch_instance.period_secs
        # Map Variable name -> user-supplied key for reverse lookup.
        var_name_to_key = {}
        for key in fetch_instance.var_dict:
            if isinstance(fetch_instance.var_dict[key], Variable):
                var_name_to_key[fetch_instance.var_dict[key].name] = key
            else:
                logging.warning("the value of {} is not a Variable".format(key))
                var_name_to_key["None.var"] = key
        elapsed_secs = 0
        while True:
            # `with` guarantees the lock is released even when we break
            # out of the loop or the LoD check raises below. The original
            # code broke (and raised) while still holding the lock, which
            # deadlocked any subsequent start()/stop().
            with self.running_lock:
                if not self.running:
                    break
                if elapsed_secs < period_secs:
                    # TODO(guru4elephant): needs customized condition
                    time.sleep(1)
                    elapsed_secs += 1
                else:
                    elapsed_secs = 0
                    fetch_dict = {}
                    for key in var_name_to_key:
                        var = scope.find_var(key)
                        fetch_dict[key] = var
                        if var is None:
                            logging.warning("{} value currently not available".
                                            format(var_name_to_key[key]))
                    res_dict = {}
                    for key in fetch_dict:
                        user_name = var_name_to_key[key]
                        if fetch_dict[key] is None:
                            res_dict[user_name] = None
                            continue
                        else:
                            res_dict[user_name] = fetch_dict[key].get_tensor()

                        lod = res_dict[user_name].lod()
                        if len(lod) > 0:
                            raise RuntimeError("Some of your fetched tensors \
                                            hold LoD information. \
                                            They can not be completely cast \
                                            to Python ndarray. We can \
                                            not return LoDTensor itself directly, \
                                            please choose another targets")
                        if res_dict[user_name]._is_initialized():
                            res_dict[user_name] = np.array(res_dict[user_name])
                        else:
                            res_dict[user_name] = None
                    fetch_instance.handler(res_dict)

    def start(self):
        """
        Start the monitor: mark it running and launch the fetch thread.
        """
        self.running_lock.acquire()
        self.running = True
        self.running_lock.release()
        # Daemon thread, so the process can exit without an explicit
        # stop(). `.daemon =` replaces the deprecated setDaemon().
        self.fetch_thread.daemon = True
        self.fetch_thread.start()

    def stop(self):
        """Signal the fetch thread to exit on its next poll."""
        self.running_lock.acquire()
        self.running = False
        self.running_lock.release()