#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defination of TrainerFactory."""

import threading
import time
import logging
import numpy as np

logging.basicConfig()

from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer
from .device_worker import Hogwild, DownpourSGD, Section
from .framework import Variable
from multiprocessing import Process, Manager

__all__ = ["TrainerFactory", "FetchHandler", "FetchHandlerMonitor"]


class TrainerFactory(object):
    """
    Create trainer and device worker.
    If opt_info is not None, it will get configs from opt_info,
    otherwise create MultiTrainer and Hogwild.
    """

    def __init__(self):
        pass

    def _create_trainer(self, opt_info=None):
        """Build a trainer paired with its device worker.

        Args:
            opt_info (dict|None): optimizer-provided configuration. When
                None, the default MultiTrainer + Hogwild pair is returned.
                Otherwise the "trainer" and "device_worker" entries name the
                classes to instantiate (resolved from this module's
                globals), and, when "fleet_desc" is present, any optional
                config keys are forwarded to the matching trainer setters.

        Returns:
            The configured trainer instance, with its device worker set.
        """
        trainer = None
        device_worker = None
        if opt_info is None:
            # Default is MultiTrainer + Hogwild.
            trainer = MultiTrainer()
            device_worker = Hogwild()
            trainer._set_device_worker(device_worker)
        else:
            trainer_class = opt_info["trainer"]
            device_worker_class = opt_info["device_worker"]
            # Classes are referenced by name; look them up in this module.
            trainer = globals()[trainer_class]()
            device_worker = globals()[device_worker_class]()
            if "fleet_desc" in opt_info:
                device_worker._set_fleet_desc(opt_info["fleet_desc"])
                trainer._set_fleet_desc(opt_info["fleet_desc"])
                # Forward each optional config that is present and not None
                # to its trainer setter. A key explicitly set to None is
                # treated the same as an absent key, matching the original
                # per-key `is not None` guards.
                optional_setters = [
                    ("use_cvm", trainer._set_use_cvm),
                    ("no_cvm", trainer._set_no_cvm),
                    ("scale_datanorm", trainer._set_scale_datanorm),
                    ("dump_slot", trainer._set_dump_slot),
                    ("mpi_rank", trainer._set_mpi_rank),
                    ("mpi_size", trainer._set_mpi_size),
                    ("dump_fields", trainer._set_dump_fields),
                    ("dump_fields_path", trainer._set_dump_fields_path),
                    ("dump_file_num", trainer._set_dump_file_num),
                    ("dump_converter", trainer._set_dump_converter),
                    ("adjust_ins_weight", trainer._set_adjust_ins_weight),
                    ("copy_table", trainer._set_copy_table_config),
                    ("check_nan_var_names", trainer._set_check_nan_var_names),
                    ("dump_param", trainer._set_dump_param),
                ]
                for conf_key, setter in optional_setters:
                    value = opt_info.get(conf_key)
                    if value is not None:
                        setter(value)
            trainer._set_device_worker(device_worker)
        return trainer


class FetchHandlerMonitor(object):
    """
    Definition of FetchHandlerMonitor class,
    it's for fetch handler.

    Every ``handler.period_secs`` seconds the monitor thread reads the
    variables listed in ``handler.var_dict`` from ``scope``, converts the
    initialized ones to numpy arrays and passes the result dict to
    ``handler.handler``.
    """

    def __init__(self, scope, handler):
        # handler must expose: period_secs (int), var_dict (name -> Variable)
        # and handler(res_dict) — see handler_launch_func below.
        self.fetch_instance = handler
        self.fetch_thread = threading.Thread(
            target=self.handler_launch_func, args=(scope, self.fetch_instance))
        # Protects self.running, the flag that tells the thread to exit.
        self.running_lock = threading.Lock()
        self.running = False

    def handler_launch_func(self, scope, handler):
        """Monitor loop executed on the background thread.

        Each iteration runs with ``running_lock`` held, so ``stop()`` waits
        at most one polling step (~1 second) before the flag is observed.
        """
        fetch_instance = handler
        period_secs = fetch_instance.period_secs
        # Map runtime variable names back to the user-facing keys.
        var_name_to_key = {}
        for key in fetch_instance.var_dict:
            if isinstance(fetch_instance.var_dict[key], Variable):
                var_name_to_key[fetch_instance.var_dict[key].name] = key
            else:
                logging.warning("the value of {} is not a Variable".format(key))
                var_name_to_key["None.var"] = key
        elapsed_secs = 0
        while True:
            self.running_lock.acquire()
            # NOTE(fix): release the lock on every exit path. The previous
            # version broke out of the loop — and could raise RuntimeError —
            # while still holding running_lock, which deadlocked any later
            # start()/stop() call.
            try:
                if not self.running:
                    break
                if elapsed_secs < period_secs:
                    # TODO(guru4elephant): needs customized condition
                    time.sleep(1)
                    elapsed_secs += 1
                else:
                    elapsed_secs = 0
                    fetch_dict = {}
                    for var_name in var_name_to_key:
                        var = scope.find_var(var_name)
                        fetch_dict[var_name] = var
                        if var is None:
                            logging.warning("{} value currently not available".
                                            format(var_name_to_key[var_name]))
                    res_dict = {}
                    for var_name in fetch_dict:
                        user_name = var_name_to_key[var_name]
                        if fetch_dict[var_name] is None:
                            res_dict[user_name] = None
                            continue
                        else:
                            res_dict[user_name] = fetch_dict[
                                var_name].get_tensor()

                        # LoD tensors cannot be fully represented as a plain
                        # ndarray, so refuse them outright.
                        lod = res_dict[user_name].lod()
                        if len(lod) > 0:
                            raise RuntimeError("Some of your fetched tensors \
                                            hold LoD information. \
                                            They can not be completely cast \
                                            to Python ndarray. We can \
                                            not return LoDTensor itself directly, \
                                            please choose another targets")
                        if res_dict[user_name]._is_initialized():
                            res_dict[user_name] = np.array(res_dict[user_name])
                        else:
                            res_dict[user_name] = None
                    fetch_instance.handler(res_dict)
            finally:
                self.running_lock.release()

    def start(self):
        """
        start monitor,
        it will start a monitor thread.
        """
        with self.running_lock:
            self.running = True
        # Daemon thread so it never blocks interpreter shutdown; the
        # `daemon` attribute replaces the deprecated setDaemon().
        self.fetch_thread.daemon = True
        self.fetch_thread.start()

    def stop(self):
        """Signal the monitor thread to exit (may take up to ~1 second)."""
        with self.running_lock:
            self.running = False