# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Train models with fluid on a single node only.
"""

from __future__ import print_function

import time
import logging
import os
import paddle.fluid as fluid

from paddlerec.core.trainers.transpiler_trainer import TranspileTrainer
from paddlerec.core.utils import envs
from paddlerec.core.reader import SlotReader
from paddlerec.core.utils import dataloader_instance

# Root logging format: timestamp, level name, then the message.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT)

# Logger named "fluid", emitting records at INFO level and above.
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)


class SingleTrainer(TranspileTrainer):
    """Single-node trainer that runs every configured phase with fluid.

    For each entry of ``config["phase"]`` a separate train/startup program
    pair, scope and model instance is built (see ``init``).  Every epoch then
    trains all phases in order, reading data either through a ``fluid``
    Dataset or through the model's DataLoader, depending on the dataset's
    effective reader type.
    """

    def __init__(self, config=None):
        # NOTE(review): super(TranspileTrainer, self) skips
        # TranspileTrainer.__init__ and runs the next __init__ in the MRO;
        # presumably intentional, since the setup below replaces it.
        super(TranspileTrainer, self).__init__(config)
        self._env = self._config
        self.processor_register()
        # phase name -> [train_program, startup_program, scope, model,
        #                clone of train_program used by save()]
        self._model = {}
        # dataset name -> value returned by _create_dataset()
        self._dataset = {}
        envs.set_global_envs(self._config)
        envs.update_workspace()
        self._runner_name = envs.get_global_env("mode")
        device = envs.get_global_env("runner." + self._runner_name +
                                     ".device")
        if device == 'gpu':
            self._place = fluid.CUDAPlace(0)
        elif device == 'cpu':
            self._place = fluid.CPUPlace()
        else:
            # Fail fast: otherwise self._place stays unset and the Executor
            # construction below dies with an opaque AttributeError.
            raise ValueError(
                "runner.{}.device must be 'cpu' or 'gpu', got {!r}".format(
                    self._runner_name, device))
        self._exe = fluid.Executor(self._place)

    def processor_register(self):
        """Map each context status to the method that handles it."""
        self.regist_context_processor('uninit', self.instance)
        self.regist_context_processor('init_pass', self.init)
        self.regist_context_processor('startup_pass', self.startup)
        self.regist_context_processor('train_pass', self.executor_train)
        self.regist_context_processor('terminal_pass', self.terminal)

    def instance(self, context):
        """Advance from 'uninit' straight to 'init_pass'."""
        context['status'] = 'init_pass'

    @staticmethod
    def _get_fleet():
        """Return the parameter-server ``fleet`` object, imported lazily.

        Only the ``is_fleet=True`` paths of ``load``/``save`` need it, so the
        import is deferred to keep single-node runs free of it.
        """
        from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
        return fleet

    @staticmethod
    def _effective_dataset_type(dataset_name):
        """Return the reader type actually used for ``dataset_name``.

        This is ``dataset.<name>.type`` from the config, except on non-Linux
        platforms, where ``_create_dataset`` always falls back to
        "DataLoader".  Routing model setup and training through this helper
        keeps them consistent with the reader that was really created.
        """
        if envs.get_platform() != "LINUX":
            return "DataLoader"
        return envs.get_global_env("dataset." + dataset_name + ".type")

    def _get_dataset(self, dataset_name):
        """Build the ``fluid`` Dataset for ``dataset_name``.

        Records come from running ``utils/dataset_instance.py`` as the pipe
        command: with the configured ``data_converter`` when no slots are
        declared, otherwise in generic "slot" mode.  All entries directly
        under ``dataset.<name>.data_path`` are used as input files, and the
        input variables are taken from the first phase that reads this
        dataset.
        """
        name = "dataset." + dataset_name + "."
        thread_num = envs.get_global_env(name + "thread_num")
        batch_size = envs.get_global_env(name + "batch_size")
        reader_class = envs.get_global_env(name + "data_converter")
        abs_dir = os.path.dirname(os.path.abspath(__file__))
        reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
        sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
        dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
        if sparse_slots == "" and dense_slots == "":
            pipe_cmd = "python {} {} {} {}".format(reader, reader_class,
                                                   "TRAIN", self._config_yaml)
        else:
            # Slot lists are positional arguments of the pipe command: an
            # empty list is sent as "?" and spaces become "?" (presumably so
            # each list stays a single argument; decoded by the reader).
            if sparse_slots == "":
                sparse_slots = "?"
            if dense_slots == "":
                dense_slots = "?"
            padding = envs.get_global_env(name + "padding", 0)
            pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
                reader, "slot", "slot", self._config_yaml, "fake",
                sparse_slots.replace(" ", "?"), dense_slots.replace(" ", "?"),
                str(padding))

        dataset = fluid.DatasetFactory().create_dataset()
        dataset.set_batch_size(batch_size)
        if thread_num:
            # Honor the configured reader thread count when one is given.
            dataset.set_thread(int(thread_num))
        dataset.set_pipe_command(pipe_cmd)
        train_data_path = envs.get_global_env(name + "data_path")
        # Sorted so the file (and therefore data) order is reproducible.
        file_list = [
            os.path.join(train_data_path, x)
            for x in sorted(os.listdir(train_data_path))
        ]
        dataset.set_filelist(file_list)
        for model_dict in self._env["phase"]:
            if model_dict["dataset_name"] == dataset_name:
                model = self._model[model_dict["name"]][3]
                dataset.set_use_var(model._data_var)
                break
        return dataset

    def _get_dataloader(self, dataset_name, dataloader):
        """Attach the sample generator for ``dataset_name`` to ``dataloader``.

        Uses the configured ``data_converter`` reader when no slots are
        declared, otherwise the generic ``SlotReader``.  Readers exposing
        ``generate_batch_from_trainfiles`` are attached with
        ``set_sample_list_generator`` (presumably they yield whole batches);
        all others are batched by ``dataset.<name>.batch_size``.

        Returns:
            The same ``dataloader`` object, now fed by the reader.
        """
        name = "dataset." + dataset_name + "."
        sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
        dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
        batch_size = envs.get_global_env(name + "batch_size")
        reader_class = envs.get_global_env(name + "data_converter")
        if sparse_slots == "" and dense_slots == "":
            reader = dataloader_instance.dataloader_by_name(
                reader_class, dataset_name, self._config_yaml)
            reader_cls = envs.lazy_instance_by_fliename(reader_class,
                                                        "TrainReader")
            reader_ins = reader_cls(self._config_yaml)
        else:
            reader = dataloader_instance.slotdataloader_by_name(
                "", dataset_name, self._config_yaml)
            reader_ins = SlotReader(self._config_yaml)
        if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
            dataloader.set_sample_list_generator(reader)
        else:
            dataloader.set_sample_generator(reader, batch_size)
        return dataloader

    def _create_dataset(self, dataset_name):
        """Create the reader object stored in ``self._dataset``.

        Returns None when ``dataset_name`` is read through a DataLoader,
        either because it is configured that way or because the platform is
        not Linux; otherwise returns the Dataset built by ``_get_dataset``.
        """
        type_name = envs.get_global_env("dataset." + dataset_name + ".type")
        if envs.get_platform() != "LINUX":
            print("platform ", envs.get_platform(),
                  " change reader to DataLoader")
            type_name = "DataLoader"
        if type_name == "DataLoader":
            return None
        return self._get_dataset(dataset_name)

    def init(self, context):
        """Build programs, scopes and models for every phase, then readers.

        Each phase gets fresh train/startup programs and its own scope.  The
        model's network and optimizer are built inside them, and the
        DataLoader is wired up when the dataset is read that way.  Datasets
        not configured as "DataLoader" are then created via
        ``_create_dataset``.
        """
        for model_dict in self._env["phase"]:
            phase_name = model_dict["name"]
            dataset_name = model_dict["dataset_name"]
            train_program = fluid.Program()
            startup_program = fluid.Program()
            scope = fluid.Scope()
            with fluid.program_guard(train_program, startup_program):
                with fluid.unique_name.guard():
                    with fluid.scope_guard(scope):
                        model_path = model_dict["model"].replace(
                            "{workspace}",
                            envs.path_adapter(self._env["workspace"]))
                        model = envs.lazy_instance_by_fliename(
                            model_path, "Model")(self._env)
                        model._data_var = model.input_data(
                            dataset_name=dataset_name)
                        if self._effective_dataset_type(
                                dataset_name) == "DataLoader":
                            model._init_dataloader(is_infer=False)
                            self._get_dataloader(dataset_name,
                                                 model._data_loader)
                        model.net(model._data_var, False)
                        optimizer = model.optimizer()
                        optimizer.minimize(model._cost)
            self._model[phase_name] = [
                train_program,
                startup_program,
                scope,
                model,
                train_program.clone(),
            ]

        for dataset in self._env["dataset"]:
            if dataset["type"] != "DataLoader":
                self._dataset[dataset["name"]] = self._create_dataset(
                    dataset["name"])

        context['status'] = 'startup_pass'

    def startup(self, context):
        """Run every phase's startup program inside that phase's scope."""
        for model_dict in self._env["phase"]:
            _, startup_program, scope, _, _ = self._model[model_dict["name"]]
            with fluid.scope_guard(scope):
                self._exe.run(startup_program)
        context['status'] = 'train_pass'

    def executor_train(self, context):
        """Train all phases for ``runner.<mode>.epochs`` epochs.

        Before the first epoch each phase loads ``init_model_path`` (if set).
        After every phase of every epoch, ``save`` runs against the cloned
        train program.  The printed time covers the whole epoch (all phases).
        """
        epochs = int(
            envs.get_global_env("runner." + self._runner_name + ".epochs"))
        for epoch_id in range(epochs):
            # Measured per epoch, not per phase, so the report covers every
            # phase and is defined even when no phases are configured.
            begin_time = time.time()
            for model_dict in self._env["phase"]:
                train_prog, startup_prog, scope, _, save_prog = \
                    self._model[model_dict["name"]]
                if epoch_id == 0:
                    with fluid.scope_guard(scope):
                        with fluid.program_guard(train_prog, startup_prog):
                            self.load()
                if self._effective_dataset_type(
                        model_dict["dataset_name"]) == "DataLoader":
                    self._executor_dataloader_train(model_dict)
                else:
                    self._executor_dataset_train(model_dict)
                with fluid.scope_guard(scope):
                    with fluid.program_guard(save_prog, startup_prog):
                        self.save(epoch_id)
            seconds = time.time() - begin_time
            print("epoch {} done, time elasped: {}".format(epoch_id, seconds))
        context['status'] = "terminal_pass"

    def _fetch_period(self):
        """Return ``runner.<mode>.fetch_period`` as an int (default 20)."""
        return int(
            envs.get_global_env("runner." + self._runner_name +
                                ".fetch_period", 20))

    def _executor_dataset_train(self, model_dict):
        """Run one epoch of a phase with ``Executor.train_from_dataset``."""
        program, _, scope, model_class, _ = self._model[model_dict["name"]]
        metrics = model_class.get_metrics() or {}
        # Materialize as lists: Paddle accesses fetch_info by position, which
        # Python 3 dict views do not support.  Names and vars stay paired.
        fetch_alias = list(metrics.keys())
        fetch_vars = [metrics[alias] for alias in fetch_alias]
        reader = self._dataset[model_dict["dataset_name"]]
        with fluid.scope_guard(scope):
            self._exe.train_from_dataset(
                program=program,
                dataset=reader,
                fetch_list=fetch_vars,
                fetch_info=fetch_alias,
                print_period=self._fetch_period())

    def _executor_dataloader_train(self, model_dict):
        """Run one epoch of a phase by draining the model's DataLoader.

        The train program is cloned and compiled with data parallelism.
        Batches are run until the loader raises EOFException, after which it
        is reset for the next epoch.  Metrics are printed every
        ``fetch_period`` batches (batch 0 excluded).
        """
        train_program, _, scope, model_class, _ = \
            self._model[model_dict["name"]]
        program = fluid.compiler.CompiledProgram(
            train_program.clone()).with_data_parallel(
                loss_name=model_class.get_avg_cost().name)
        fetch_period = self._fetch_period()

        # get_metrics() may return None; treat that as "no metrics".
        metrics = model_class.get_metrics() or {}
        metrics_varnames = []
        metrics_format = ["batch: {}"]
        for alias, var in metrics.items():
            metrics_varnames.append(var.name)
            metrics_format.append("{}: {{}}".format(alias))
        metrics_format = ", ".join(metrics_format)

        reader = model_class._data_loader
        reader.start()
        batch_id = 0
        with fluid.scope_guard(scope):
            try:
                while True:
                    metrics_rets = self._exe.run(program=program,
                                                 fetch_list=metrics_varnames)
                    if batch_id % fetch_period == 0 and batch_id != 0:
                        values = [batch_id]
                        values.extend(metrics_rets)
                        print(metrics_format.format(*values))
                    batch_id += 1
            except fluid.core.EOFException:
                reader.reset()

    def terminal(self, context):
        """Signal the driving loop to exit."""
        context['is_exit'] = True

    def load(self, is_fleet=False):
        """Load persistables from ``runner.<mode>.init_model_path`` if set.

        Loads into the current default main program; does nothing when the
        path is unset or empty.
        """
        dirname = envs.get_global_env(
            "runner." + self._runner_name + ".init_model_path", None)
        if not dirname:
            return
        print("going to load ", dirname)
        if is_fleet:
            # NOTE(review): confirm the fleet object exposes
            # load_persistables for the Paddle version in use.
            self._get_fleet().load_persistables(self._exe, dirname)
        else:
            fluid.io.load_persistables(self._exe, dirname)

    def save(self, epoch_id, is_fleet=False):
        """Save a checkpoint and/or an inference model for ``epoch_id``.

        Both are written under ``<configured path>/<epoch_id>`` from the
        current default main program.  Each is skipped when its interval
        does not select this epoch or its required settings are missing.

        Raises:
            ValueError: inference export is configured (feed and fetch
                varnames set) but ``save_inference_path`` is missing.
        """
        name = "runner." + self._runner_name + "."

        def need_save(epoch_id, epoch_interval, is_last=False):
            # NOTE: with the default interval of -1, epoch_id % -1 == 0 holds
            # for every epoch, so saving happens every epoch whenever the
            # corresponding path/varnames are configured.
            if is_last:
                return True
            if epoch_id == -1:
                return False
            return epoch_id % epoch_interval == 0

        def save_inference_model():
            save_interval = int(
                envs.get_global_env(name + "save_inference_interval", -1))
            if not need_save(epoch_id, save_interval, False):
                return
            feed_varnames = envs.get_global_env(
                name + "save_inference_feed_varnames", [])
            fetch_varnames = envs.get_global_env(
                name + "save_inference_fetch_varnames", [])
            # Nothing to export unless both feed and fetch vars are given.
            if not feed_varnames or not fetch_varnames:
                return
            global_vars = fluid.default_main_program().global_block().vars
            fetch_vars = [global_vars[varname] for varname in fetch_varnames]
            dirname = envs.get_global_env(name + "save_inference_path", None)
            if dirname is None:
                raise ValueError(
                    name + "save_inference_path must be set to save an "
                    "inference model")
            dirname = os.path.join(dirname, str(epoch_id))
            if is_fleet:
                self._get_fleet().save_inference_model(
                    self._exe, dirname, feed_varnames, fetch_vars)
            else:
                fluid.io.save_inference_model(dirname, feed_varnames,
                                              fetch_vars, self._exe)

        def save_persistables():
            save_interval = int(
                envs.get_global_env(name + "save_checkpoint_interval", -1))
            if not need_save(epoch_id, save_interval, False):
                return
            dirname = envs.get_global_env(name + "save_checkpoint_path", None)
            if not dirname:
                return
            dirname = os.path.join(dirname, str(epoch_id))
            if is_fleet:
                self._get_fleet().save_persistables(self._exe, dirname)
            else:
                fluid.io.save_persistables(self._exe, dirname)

        save_persistables()
        save_inference_model()