tdm_single_trainer.py 5.3 KB
Newer Older
C
chengmo 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Training use fluid with one node only.
"""

from __future__ import print_function
import logging

T
tangwei 已提交
22 23
import numpy as np
import paddle.fluid as fluid
24 25
from paddlerec.core.trainers.single_trainer import SingleTrainer
from paddlerec.core.utils import envs
T
tangwei 已提交
26

C
chengmo 已提交
27 28 29
# Module logger for this trainer: records carry a timestamp and level, and
# the "fluid" logger is set to emit INFO and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)

# Names of the Fluid variables that hold the TDM tree structure.
# TDMSingleTrainer.startup fills each of them from files when
# ``tree.load_tree`` is enabled.
special_param = [
    "TDM_Tree_Travel",
    "TDM_Tree_Layer",
    "TDM_Tree_Info",
    "TDM_Tree_Emb",
]


class TDMSingleTrainer(SingleTrainer):
    """Single-node trainer that initializes TDM tree parameters at startup.

    On top of ``SingleTrainer`` it can, depending on the ``train.startup``
    configuration, restore persistables from disk, fill the variables named
    in ``special_param`` with tree data loaded from files, and save the
    resulting initial model.
    """

    def startup(self, context):
        """Run the startup program and prepare the TDM tree parameters.

        Configuration is read with ``envs.get_global_env`` under the
        ``train.startup`` namespace:

        * ``single.load_persistables`` / ``single.persistables_model_path``:
          restore persistables from a saved model directory.
        * ``tree.load_tree`` and ``tree.tree_{layer,travel,info,emb}_path``:
          load the tree data files into the ``special_param`` variables.
        * ``single.save_init_model`` / ``single.init_model_path``: save all
          persistables once initialization is done.

        Args:
            context: trainer context; ``context['status']`` is set to
                ``'train_pass'`` when startup completes.

        Raises:
            ValueError: if ``tree.load_tree`` is enabled but one of the
                ``special_param`` variables is not in the global scope.
        """
        namespace = "train.startup"
        load_persistables = envs.get_global_env(
            "single.load_persistables", False, namespace)
        persistables_model_path = envs.get_global_env(
            "single.persistables_model_path", "", namespace)

        load_tree = envs.get_global_env(
            "tree.load_tree", False, namespace)
        # The tree file paths are stored on the instance because the
        # _tdm_*_prepare helpers below read them from there.
        self.tree_layer_path = envs.get_global_env(
            "tree.tree_layer_path", "", namespace)
        self.tree_travel_path = envs.get_global_env(
            "tree.tree_travel_path", "", namespace)
        self.tree_info_path = envs.get_global_env(
            "tree.tree_info_path", "", namespace)
        self.tree_emb_path = envs.get_global_env(
            "tree.tree_emb_path", "", namespace)

        save_init_model = envs.get_global_env(
            "single.save_init_model", False, namespace)
        init_model_path = envs.get_global_env(
            "single.init_model_path", "", namespace)
        # NOTE(review): self._exe / self._place are presumably set up by
        # SingleTrainer; they are not defined in this class.
        self._exe.run(fluid.default_startup_program())

        if load_persistables:
            # Load parameters from a Paddle binary model directory.
            fluid.io.load_persistables(
                executor=self._exe,
                dirname=persistables_model_path,
                main_program=fluid.default_main_program())
            logger.info("Load persistables from \"{}\"".format(
                persistables_model_path))

        if load_tree:
            # Convert each tree file into an array and copy it into the
            # tensor of the matching Fluid variable.
            scope = fluid.global_scope()
            for param_name in special_param:
                var = scope.find_var(param_name)
                if var is None:
                    # Fail with a clear message instead of an
                    # AttributeError on None further down.
                    raise ValueError(
                        "TDM param {} not found in global scope".format(
                            param_name))
                param_t = var.get_tensor()
                param_array = self._tdm_prepare(param_name)
                # The tree embedding is stored as float32; every other
                # tree tensor is stored as int32.
                if param_name == 'TDM_Tree_Emb':
                    param_t.set(param_array.astype('float32'), self._place)
                else:
                    param_t.set(param_array.astype('int32'), self._place)

        if save_init_model:
            logger.info("Begin Save Init model.")
            fluid.io.save_persistables(
                executor=self._exe, dirname=init_model_path)
            logger.info("End Save Init model.")

        context['status'] = 'train_pass'

    def _tdm_prepare(self, param_name):
        """Return the numpy array to load into the variable ``param_name``.

        Raises:
            ValueError: if ``param_name`` is not one of ``special_param``.
        """
        if param_name == "TDM_Tree_Travel":
            return self._tdm_travel_prepare()
        elif param_name == "TDM_Tree_Layer":
            layer_array, _ = self._tdm_layer_prepare()
            return layer_array
        elif param_name == "TDM_Tree_Info":
            return self._tdm_info_prepare()
        elif param_name == "TDM_Tree_Emb":
            return self._tdm_emb_prepare()
        # Raising a plain string is invalid; raise a real exception type.
        raise ValueError(
            "{} is not a special tdm param name".format(param_name))

    def _tdm_travel_prepare(self):
        """Load the tree travel array from the .npy file at
        ``self.tree_travel_path`` and log its first-dimension size."""
        travel_array = np.load(self.tree_travel_path)
        logger.info("TDM Tree leaf node nums: {}".format(
            travel_array.shape[0]))
        return travel_array

    def _tdm_emb_prepare(self):
        """Load the tree embedding array from the .npy file at
        ``self.tree_emb_path`` and log its first-dimension size."""
        emb_array = np.load(self.tree_emb_path)
        logger.info("TDM Tree node nums from emb: {}".format(
            emb_array.shape[0]))
        return emb_array

    def _tdm_layer_prepare(self):
        """Parse the tree layer text file at ``self.tree_layer_path``.

        Each line of the file is one layer holding comma-separated node
        tokens; empty tokens (e.g. from a trailing comma) are skipped.

        Returns:
            tuple: ``(layer_array, layer_list)``. ``layer_array`` is a
            column array (shape ``[-1, 1]``) of all node tokens as strings,
            in file order. ``layer_list`` is a list with one list of node
            tokens per line.
        """
        layer_list = []
        with open(self.tree_layer_path, 'r') as fin:
            for line in fin:
                # Strip both '\n' and '\r' so CRLF files do not leave a
                # stray '\r' token (text mode does not translate on py2).
                fields = line.rstrip('\r\n').split(',')
                layer_list.append([node for node in fields if node])
        layer_list_flat = [node for layer in layer_list for node in layer]
        layer_array = np.array(layer_list_flat).reshape([-1, 1])
        logger.info("TDM Tree max layer: {}".format(len(layer_list)))
        logger.info("TDM Tree layer_node_num_list: {}".format(
            [len(layer) for layer in layer_list]))
        return layer_array, layer_list

    def _tdm_info_prepare(self):
        """Load the tree info array from the .npy file at
        ``self.tree_info_path``."""
        info_array = np.load(self.tree_info_path)
        return info_array