# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training use fluid with one node only.
"""

from __future__ import print_function

import time
import logging
import os
import paddle.fluid as fluid

from paddlerec.core.trainers.transpiler_trainer import TranspileTrainer
from paddlerec.core.utils import envs
from paddlerec.core.reader import SlotReader

logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)


class SingleTrainer(TranspileTrainer):
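    """Run single-node (single-process) training for the models and datasets
    described by the yaml configuration."""
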
    def __init__(self, config=None):
        # Calls the base Trainer's __init__ directly, skipping TranspileTrainer.__init__.
        super(TranspileTrainer, self).__init__(config)
        self._env = self._config
        device = envs.get_global_env("device")
        if device == 'gpu':
            self._place = fluid.CUDAPlace(0)
        else:
            # Fall back to CPU when the device is unset or unrecognized.
            self._place = fluid.CPUPlace()
        self._exe = fluid.Executor(self._place)
        self.processor_register()
        self._model = {}
        self._dataset = {}
        envs.set_global_envs(self._config)
        envs.update_workspace()

    def processor_register(self):
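        # Context processors form the trainer's state machine:
        # uninit -> init_pass -> startup_pass -> train_pass -> terminal_pass.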
        self.regist_context_processor('uninit', self.instance)
        self.regist_context_processor('init_pass', self.init)
        self.regist_context_processor('startup_pass', self.startup)

        # Dataset/DataLoader dispatch happens at run time inside executor_train,
        # so a single train_pass processor covers both reader types.
        self.regist_context_processor('train_pass', self.executor_train)

        # The infer pass is not registered for single-node training.
        self.regist_context_processor('terminal_pass', self.terminal)

    def instance(self, context):
        context['status'] = 'init_pass'

    def dataloader_train(self, context):
        pass

    def dataset_train(self, context):
        pass

    def _create_dataset(self, dataset_name):
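        # Build a fluid dataset for the "dataset.<name>" config section; raw training
        # files are parsed through a pipe command before they reach the executor.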
        name = "dataset." + dataset_name + "."
        sparse_slots = envs.get_global_env(name + "sparse_slots")#config_dict.get("sparse_slots")#config_dict["sparse_slots"]
        dense_slots = envs.get_global_env(name + "dense_slots")#config_dict.get("dense_slots")#config_dict["dense_slots"]
        thread_num = envs.get_global_env(name + "thread_num")
        batch_size = envs.get_global_env(name + "batch_size")
        reader_type = envs.get_global_env(name + "type")
        if envs.get_platform() != "LINUX":
            print("platform ", envs.get_platform(), " change reader to DataLoader")
            reader_type = "DataLoader"
        padding = 0

        reader = envs.path_adapter("paddlerec.core.utils") + "/dataset_instance.py"
        # Every data file is streamed through dataset_instance.py with the slot
        # configuration, so QueueDataset can parse sparse/dense slots into feed data.
        pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
            reader, "slot", "slot", self._config_yaml, "fake",
            sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"), str(padding))

        #print(config_dict["type"])
        type_name = envs.get_global_env(name + "type")
        if type_name == "QueueDataset":
        #if config_dict["type"] == "QueueDataset":
            dataset = fluid.DatasetFactory().create_dataset()
            dataset.set_batch_size(envs.get_global_env(name + "batch_size"))
            dataset.set_pipe_command(pipe_cmd)
            train_data_path = envs.get_global_env(name + "data_path")
            file_list = [
                os.path.join(train_data_path, x)
                for x in os.listdir(train_data_path)
            ]
            dataset.set_filelist(file_list)
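            # Bind the dataset's feed variables to the inputs of the model that consumes it.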
            for model_dict in self._env["executor"]:
                if model_dict["dataset_name"] == dataset_name:
                    model = self._model[model_dict["name"]][3]
                    inputs = model.get_inputs()
                    dataset.set_use_var(inputs)
                    break
        else:
            # Non-QueueDataset readers (e.g. DataLoader) are not constructed here.
            dataset = None

        return dataset

    def init(self, context):
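        # Build one (train_program, startup_program, scope, model) tuple per executor
        # entry; each model is constructed under its own program and name guard.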
        for model_dict in self._env["executor"]:
            self._model[model_dict["name"]] = [None] * 4
            train_program = fluid.Program()
            startup_program = fluid.Program()
            scope = fluid.Scope()
            opt_name = envs.get_global_env("hyper_parameters.optimizer.class")
            opt_lr = envs.get_global_env("hyper_parameters.optimizer.learning_rate")
            opt_strategy = envs.get_global_env("hyper_parameters.optimizer.strategy")
            with fluid.program_guard(train_program, startup_program):
                with fluid.unique_name.guard():
                    model_path = model_dict["model"].replace(
                        "{workspace}", envs.path_adapter(self._env["workspace"]))
                    model = envs.lazy_instance_by_fliename(model_path, "Model")(self._env)
                    model._data_var = model.input_data(
                        dataset_name=model_dict["dataset_name"])
                    model.net(None)
                    optimizer = model._build_optimizer(opt_name, opt_lr, opt_strategy)
                    optimizer.minimize(model._cost)
            self._model[model_dict["name"]][0] = train_program
            self._model[model_dict["name"]][1] = startup_program
            self._model[model_dict["name"]][2] = scope
            self._model[model_dict["name"]][3] = model

        for dataset in self._env["dataset"]:
            self._dataset[dataset["name"]] = self._create_dataset(dataset["name"])

        context['status'] = 'startup_pass'

    def startup(self, context):
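        # Run each model's startup program inside its own scope to initialize parameters.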
        for model_dict in self._env["executor"]:
            with fluid.scope_guard(self._model[model_dict["name"]][2]):
                self._exe.run(self._model[model_dict["name"]][1])
        context['status'] = 'train_pass'

    def executor_train(self, context):
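        # Run every executor entry once per epoch, choosing the DataLoader or the
        # QueueDataset path based on the dataset's configured type.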
        epochs = int(self._env["epochs"])
        for j in range(epochs):
            for model_dict in self._env["executor"]:
                reader_name = model_dict["dataset_name"]
                name = "dataset." + reader_name + "."
                begin_time = time.time()
                if envs.get_global_env(name + "type") == "DataLoader":
                    self._executor_dataloader_train(model_dict)
                else:
                    self._executor_dataset_train(model_dict)
                end_time = time.time()
                seconds = end_time - begin_time
            print("epoch {} done, time elapsed: {}".format(j, seconds))
        context['status'] = "terminal_pass"

    def _executor_dataset_train(self, model_dict):
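        # QueueDataset path: the executor streams the whole dataset through
        # train_from_dataset and prints the fetched metrics every fetch_period batches.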
        reader_name = model_dict["dataset_name"]
        model_name = model_dict["name"]
        model_class = self._model[model_name][3]
        fetch_vars = []
        fetch_alias = []
        fetch_period = 20
        metrics = model_class.get_metrics()
        if metrics:
            fetch_vars = list(metrics.values())
            fetch_alias = list(metrics.keys())
        scope = self._model[model_name][2]
        program = self._model[model_name][0]
        reader = self._dataset[reader_name]
        with fluid.scope_guard(scope):
            self._exe.train_from_dataset(
                program=program,
                dataset=reader,
                fetch_list=fetch_vars,
                fetch_info=fetch_alias,
                print_period=fetch_period)


    def _executor_dataloader_train(self, model_dict):
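        # DataLoader path: compile the train program for data-parallel execution
        # and pull batches from the reader until it raises EOFException.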
        reader_name = model_dict["dataset_name"]
        model_name = model_dict["name"]
        model_class = self._model[model_name][3]
        self._model[model_name][0] = fluid.compiler.CompiledProgram(
            self._model[model_name][0]).with_data_parallel(
                loss_name=model_class.get_avg_cost().name)
        fetch_vars = []
        fetch_alias = []
        fetch_period = 20
        metrics = model_class.get_metrics()
        metrics_varnames = []
        metrics_format = ["{}: {{}}".format("batch")]
        if metrics:
            fetch_vars = list(metrics.values())
            fetch_alias = list(metrics.keys())
            for name, var in metrics.items():
                metrics_varnames.append(var.name)
                metrics_format.append("{}: {{}}".format(name))
        metrics_format = ", ".join(metrics_format)

        reader = self._dataset[reader_name]
        reader.start()
        batch_id = 0
        scope = self._model[model_name][2]
        program = self._model[model_name][0]
        with fluid.scope_guard(scope):
            try:
                while True:
                    metrics_rets = self._exe.run(program=program,
                                                 fetch_list=metrics_varnames)

                    metrics = [batch_id]
                    metrics.extend(metrics_rets)

                    if batch_id % fetch_period == 0 and batch_id != 0:
                        print(metrics_format.format(*metrics))
                    batch_id += 1
            except fluid.core.EOFException:
                reader.reset()

    def terminal(self, context):
        context['is_exit'] = True