tdm_cluster_trainer.py 4.8 KB
Newer Older
C
chengmo 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Cluster training with fluid for TDM models.
"""

from __future__ import print_function
T
tangwei 已提交
19

C
chengmo 已提交
20
import logging
T
tangwei 已提交
21

C
chengmo 已提交
22
import numpy as np
C
chengmo 已提交
23 24 25
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

26 27
from paddlerec.core.utils import envs
from paddlerec.core.trainers.cluster_trainer import ClusterTrainer
C
chengmo 已提交
28

C
chengmo 已提交
29 30 31 32
# Names of the fluid variables that `TDMClusterTrainer.startup` fills with
# tree data loaded from disk.
special_param = [
    "TDM_Tree_Travel",
    "TDM_Tree_Layer",
    "TDM_Tree_Info",
]

# Root logging format: timestamp, level name, then the message.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")

# This module logs through the "fluid" logger at INFO level.
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
C
chengmo 已提交
33 34


C
fix  
chengmo 已提交
35
class TDMClusterTrainer(ClusterTrainer):
C
chengmo 已提交
36
    def server(self, context):
        """Initialise the fleet server from a saved model and run it.

        Reads ``cluster.init_model_path`` from the ``train.startup``
        namespace, requires it to be non-empty, passes it to
        ``fleet.init_server`` and then calls ``fleet.run_server``.

        Args:
            context: mutable mapping shared with the caller; ``'is_exit'``
                is set to ``True`` after ``fleet.run_server()`` returns.

        Raises:
            AssertionError: if no init model path is configured.
        """
        model_dir = envs.get_global_env("cluster.init_model_path", "",
                                        "train.startup")
        assert model_dir != "", "Cluster train must has init_model for TDM"

        fleet.init_server(model_dir)
        logger.info("TDM: load model from {}".format(model_dir))

        fleet.run_server()
        context['is_exit'] = True

C
chengmo 已提交
46 47 48
    def startup(self, context):
        """Run the startup program, load tree tensors, optionally save.

        Steps, in order:
          1. Run ``fleet.startup_program`` on ``self._exe``.
          2. Read the ``tree.*`` / ``cluster.*`` options from the
             ``train.startup`` namespace and store the three tree file
             paths on ``self``.
          3. If ``tree.load_tree`` is set, fill every variable named in
             ``special_param`` with the array returned by
             ``_tdm_prepare``, cast to ``int32``.
          4. If ``cluster.save_init_model`` is set, save persistables to
             ``cluster.init_model_path``.

        Args:
            context: mutable mapping shared with the caller;
                ``'status'`` is set to ``'train_pass'`` at the end.
        """
        self._exe.run(fleet.startup_program)

        def _startup_env(key, default):
            # Every option read below lives in the same namespace.
            return envs.get_global_env(key, default, "train.startup")

        should_load_tree = _startup_env("tree.load_tree", True)
        self.tree_layer_path = _startup_env("tree.tree_layer_path", "")
        self.tree_travel_path = _startup_env("tree.tree_travel_path", "")
        self.tree_info_path = _startup_env("tree.tree_info_path", "")
        should_save_model = _startup_env("cluster.save_init_model", False)
        model_dir = _startup_env("cluster.init_model_path", "")

        if should_load_tree:
            # Convert the tree files to tensors and set them into the
            # matching fluid variables.
            for name in special_param:
                tensor = fluid.global_scope().find_var(name).get_tensor()
                tree_values = self._tdm_prepare(name)
                tensor.set(tree_values.astype('int32'), self._place)

        if should_save_model:
            logger.info("Begin Save Init model.")
            fluid.io.save_persistables(executor=self._exe, dirname=model_dir)
            logger.info("End Save Init model.")

        context['status'] = 'train_pass'

C
chengmo 已提交
82
    def _tdm_prepare(self, param_name):
C
chengmo 已提交
83
        if param_name == "TDM_Tree_Travel":
C
chengmo 已提交
84
            travel_array = self._tdm_travel_prepare()
C
chengmo 已提交
85 86
            return travel_array
        elif param_name == "TDM_Tree_Layer":
C
chengmo 已提交
87
            layer_array, _ = self._tdm_layer_prepare()
C
chengmo 已提交
88 89
            return layer_array
        elif param_name == "TDM_Tree_Info":
C
chengmo 已提交
90
            info_array = self._tdm_info_prepare()
C
chengmo 已提交
91 92 93 94
            return info_array
        else:
            raise " {} is not a special tdm param name".format(param_name)

C
chengmo 已提交
95
    def _tdm_travel_prepare(self):
        """Load the tree travel array from ``self.tree_travel_path``.

        The file is read with ``np.load``; the size of the array's first
        axis is logged as the number of leaf nodes.

        Returns:
            The loaded array, unchanged.
        """
        loaded = np.load(self.tree_travel_path)
        leaf_count = loaded.shape[0]
        logger.info("TDM Tree leaf node nums: {}".format(leaf_count))
        return loaded

C
chengmo 已提交
102
    def _tdm_layer_prepare(self):
        """Parse the tree layer text file at ``self.tree_layer_path``.

        Each line of the file is one layer: a comma-separated list of
        nodes. Empty fields are dropped; nodes are kept as strings.

        Returns:
            A ``(layer_array, layer_list)`` pair, where ``layer_list`` is
            a list with one list of nodes per line and ``layer_array`` is
            all nodes, in file order, as an array reshaped to one column.
        """
        with open(self.tree_layer_path, 'r') as layer_file:
            # Drop the trailing newline before splitting on commas, and
            # skip empty fields (e.g. from a trailing comma).
            layer_list = [
                [node for node in row.split('\n')[0].split(',') if node]
                for row in layer_file
            ]

        layer_list_flat = [node for nodes in layer_list for node in nodes]
        layer_array = np.array(layer_list_flat).reshape([-1, 1])

        layer_sizes = [len(nodes) for nodes in layer_list]
        logger.info("TDM Tree max layer: {}".format(len(layer_list)))
        logger.info("TDM Tree layer_node_num_list: {}".format(layer_sizes))
        return layer_array, layer_list

C
chengmo 已提交
122
    def _tdm_info_prepare(self):
C
chengmo 已提交
123 124 125
        """load tdm tree param from list file"""
        info_array = np.load(self.tree_info_path)
        return info_array