single_train.py 6.0 KB
Newer Older
T
tangwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Training use fluid with one node only.
"""

from __future__ import print_function
import os
import time
import numpy as np
import logging
import paddle.fluid as fluid

from .trainer import Trainer
from ..utils import envs

# Module-level logger shared by the fluid-based trainers below.
logger = logging.getLogger("fluid")
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger.setLevel(logging.INFO)


class SingleTrainer(Trainer):
    """Single-node trainer.

    Drives a status machine (uninit -> init_pass -> train_pass ->
    infer_pass -> terminal_pass) on a CPU fluid executor.  Subclasses are
    expected to override :meth:`train` / :meth:`infer` with real work.
    """

    def __init__(self, config=None):
        Trainer.__init__(self, config)

        # (epoch_id, dirname) pairs recorded as models get saved.
        self.inference_models = []
        self.increment_models = []

        # Single-node training runs on CPU here.
        self.exe = fluid.Executor(fluid.CPUPlace())

        # Wire each pipeline status to its handler method.
        handlers = {
            'uninit': self.instance,
            'init_pass': self.init,
            'train_pass': self.train,
            'infer_pass': self.infer,
            'terminal_pass': self.terminal,
        }
        for status, processor in handlers.items():
            self.regist_context_processor(status, processor)

    def instance(self, context):
        """Import the configured model package and instantiate its Train class."""
        model_name = envs.get_global_env("train.model.models")
        package = __import__(
            model_name, globals(), locals(), model_name.split("."))

        self.model = getattr(package, 'Train')()

        context['status'] = 'init_pass'

    def init(self, context):
        """Build the model graph, attach the optimizer and run startup."""
        self.model.input()
        self.model.net()
        self.metrics = self.model.metrics()
        self.metric_extras = self.model.metric_extras()

        loss = self.model.avg_loss()
        optimizer = self.model.optimizer()
        optimizer.minimize(loss)

        # run startup program at once
        self.exe.run(fluid.default_startup_program())

        context['status'] = 'train_pass'

    def train(self, context):
        # Base class has no training loop; subclasses override this.
        print("Need to be implement")
        context['is_exit'] = True

    def infer(self, context):
        # Base class has no inference step; subclasses override this.
        context['is_exit'] = True

    def terminal(self, context):
        print("clean up and exit")
        context['is_exit'] = True


class SingleTrainerWithDataloader(SingleTrainer):
    """Placeholder: dataloader-fed single-node training is not implemented yet."""


class SingleTrainerWithDataset(SingleTrainer):
    """SingleTrainer variant that feeds training from a fluid Dataset
    (pipe-command reader) via ``train_from_dataset``.
    """

    def _get_dataset(self):
        """Build a fluid dataset from the ``train.reader`` config section.

        Returns a dataset bound to the model's input vars, the configured
        pipe command / batch size / thread count, and every file found
        under ``train_data_path``.
        """
        namespace = "train.reader"

        inputs = self.model.input_vars()
        threads = envs.get_global_env("train.threads", None)
        batch_size = envs.get_global_env("batch_size", None, namespace)
        pipe_command = envs.get_global_env("pipe_command", None, namespace)
        train_data_path = envs.get_global_env("train_data_path", None,
                                              namespace)

        dataset = fluid.DatasetFactory().create_dataset()
        dataset.set_use_var(inputs)
        dataset.set_pipe_command(pipe_command)
        dataset.set_batch_size(batch_size)
        dataset.set_thread(threads)
        file_list = [
            os.path.join(train_data_path, x)
            for x in os.listdir(train_data_path)
        ]

        dataset.set_filelist(file_list)
        return dataset

    def save(self, epoch_id, namespace):
        """Save increment (persistables) and inference models for *epoch_id*
        according to the ``save.*`` settings under *namespace*.

        An ``epoch_interval`` of -1 (the default when unconfigured)
        disables that save kind entirely.
        """

        def need_save(epoch_id, epoch_interval, is_last=False):
            # The final epoch is always saved.
            if is_last:
                return True

            # BUGFIX: a non-positive interval (the -1 default) means saving
            # is disabled.  The original checked ``epoch_id == -1`` instead,
            # so with interval == -1 every epoch satisfied
            # ``epoch_id % -1 == 0`` and was saved unconditionally.
            if epoch_interval <= 0:
                return False

            return epoch_id % epoch_interval == 0

        def save_inference_model():
            save_interval = envs.get_global_env(
                "save.inference.epoch_interval", -1, namespace)

            if not need_save(epoch_id, save_interval, False):
                return

            # Inference-model export is not wired up yet; the unreachable
            # draft implementation that followed this early return has been
            # removed.  When implemented it should read
            # save.inference.{feed_varnames,fetch_varnames,dirname} and call
            # fluid.io.save_inference_model, then record
            # (epoch_id, dirname) in self.inference_models.
            print("save inference model is not supported now.")
            return

        def save_persistables():
            save_interval = envs.get_global_env(
                "save.increment.epoch_interval", -1, namespace)

            if not need_save(epoch_id, save_interval, False):
                return

            dirname = envs.get_global_env("save.increment.dirname", None,
                                          namespace)

            assert dirname is not None
            dirname = os.path.join(dirname, str(epoch_id))
            fluid.io.save_persistables(self.exe, dirname)
            self.increment_models.append((epoch_id, dirname))

        save_persistables()
        save_inference_model()

    def train(self, context):
        """Run ``train.epochs`` passes over the dataset, saving after each."""
        dataset = self._get_dataset()

        epochs = envs.get_global_env("train.epochs")

        for i in range(epochs):
            self.exe.train_from_dataset(program=fluid.default_main_program(),
                                        dataset=dataset,
                                        fetch_list=self.metric_extras[0],
                                        fetch_info=self.metric_extras[1],
                                        print_period=self.metric_extras[2])
            self.save(i, "train")
        context['status'] = 'infer_pass'

    def infer(self, context):
        # No inference step for the dataset trainer; go straight to terminal.
        context['status'] = 'terminal_pass'

    def terminal(self, context):
        # Report every saved increment model before exiting.
        for epoch_id, dirname in self.increment_models:
            print("epoch :{}, dir: {}".format(epoch_id, dirname))
        context['is_exit'] = True