# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Training use fluid with one node only.
"""

from __future__ import print_function
import logging
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

from paddlerec.core.utils import envs
from paddlerec.core.trainers.cluster_trainer import ClusterTrainer


logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
special_param = ["TDM_Tree_Travel", "TDM_Tree_Layer", "TDM_Tree_Info"]


class TDMClusterTrainer(ClusterTrainer):
C
chengmo 已提交
36
    def server(self, context):
C
chengmo 已提交
37
        namespace = "train.startup"
C
chengmo 已提交
38 39 40 41 42
        init_model_path = envs.get_global_env(
            "cluster.init_model_path", "", namespace)
        assert init_model_path != "", "Cluster train must has init_model for TDM"
        fleet.init_server(init_model_path)
        logger.info("TDM: load model from {}".format(init_model_path))
C
chengmo 已提交
43 44 45
        fleet.run_server()
        context['is_exit'] = True

C
chengmo 已提交
46 47 48
    def startup(self, context):
        self._exe.run(fleet.startup_program)

C
chengmo 已提交
49 50
        namespace = "train.startup"
        load_tree = envs.get_global_env(
C
fix  
chengmo 已提交
51
            "tree.load_tree", True, namespace)
C
chengmo 已提交
52
        self.tree_layer_path = envs.get_global_env(
C
fix  
chengmo 已提交
53
            "tree.tree_layer_path", "", namespace)
C
chengmo 已提交
54
        self.tree_travel_path = envs.get_global_env(
C
fix  
chengmo 已提交
55
            "tree.tree_travel_path", "", namespace)
C
chengmo 已提交
56
        self.tree_info_path = envs.get_global_env(
C
fix  
chengmo 已提交
57
            "tree.tree_info_path", "", namespace)
C
chengmo 已提交
58 59 60 61 62 63 64

        save_init_model = envs.get_global_env(
            "cluster.save_init_model", False, namespace)
        init_model_path = envs.get_global_env(
            "cluster.init_model_path", "", namespace)

        if load_tree:
T
tangwei 已提交
65
            # covert tree to tensor, set it into Fluid's variable.
C
chengmo 已提交
66 67
            for param_name in special_param:
                param_t = fluid.global_scope().find_var(param_name).get_tensor()
C
chengmo 已提交
68
                param_array = self._tdm_prepare(param_name)
C
chengmo 已提交
69 70 71 72 73 74 75 76 77 78
                param_t.set(param_array.astype('int32'), self._place)

        if save_init_model:
            logger.info("Begin Save Init model.")
            fluid.io.save_persistables(
                executor=self._exe, dirname=init_model_path)
            logger.info("End Save Init model.")

        context['status'] = 'train_pass'

C
chengmo 已提交
79
    def _tdm_prepare(self, param_name):
C
chengmo 已提交
80
        if param_name == "TDM_Tree_Travel":
C
chengmo 已提交
81
            travel_array = self._tdm_travel_prepare()
C
chengmo 已提交
82 83
            return travel_array
        elif param_name == "TDM_Tree_Layer":
C
chengmo 已提交
84
            layer_array, _ = self._tdm_layer_prepare()
C
chengmo 已提交
85 86
            return layer_array
        elif param_name == "TDM_Tree_Info":
C
chengmo 已提交
87
            info_array = self._tdm_info_prepare()
C
chengmo 已提交
88 89 90 91
            return info_array
        else:
            raise " {} is not a special tdm param name".format(param_name)

C
chengmo 已提交
92
    def _tdm_travel_prepare(self):
C
chengmo 已提交
93 94 95 96 97 98
        """load tdm tree param from npy/list file"""
        travel_array = np.load(self.tree_travel_path)
        logger.info("TDM Tree leaf node nums: {}".format(
            travel_array.shape[0]))
        return travel_array

C
chengmo 已提交
99
    def _tdm_layer_prepare(self):
C
chengmo 已提交
100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
        """load tdm tree param from npy/list file"""
        layer_list = []
        layer_list_flat = []
        with open(self.tree_layer_path, 'r') as fin:
            for line in fin.readlines():
                l = []
                layer = (line.split('\n'))[0].split(',')
                for node in layer:
                    if node:
                        layer_list_flat.append(node)
                        l.append(node)
                layer_list.append(l)
        layer_array = np.array(layer_list_flat)
        layer_array = layer_array.reshape([-1, 1])
        logger.info("TDM Tree max layer: {}".format(len(layer_list)))
        logger.info("TDM Tree layer_node_num_list: {}".format(
            [len(i) for i in layer_list]))
        return layer_array, layer_list

C
chengmo 已提交
119
    def _tdm_info_prepare(self):
C
chengmo 已提交
120 121 122
        """load tdm tree param from list file"""
        info_array = np.load(self.tree_info_path)
        return info_array