# transpiler_trainer.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Training using fluid with the DistributeTranspiler.
"""
import os
import sys

import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

from paddlerec.core.trainer import Trainer
from paddlerec.core.utils import envs
from paddlerec.core.utils import dataloader_instance
from paddlerec.core.reader import SlotReader


class TranspileTrainer(Trainer):
    """Base trainer for fluid training with the DistributeTranspiler / fleet.

    Subclasses are expected to register their context processors (see
    ``processor_register``) and to implement ``init``, ``dataloader_train``
    and ``dataset_train``; the versions defined here only print a reminder
    and request exit via ``context['is_exit']``.
    """

    def __init__(self, config=None):
        Trainer.__init__(self, config)
        device = envs.get_global_env("train.device", "cpu")
        if device == 'gpu':
            # Only the GPU case is configured here; for other devices the
            # place/executor are presumably set up by ``Trainer`` -- TODO
            # confirm against the base class.
            self._place = fluid.CUDAPlace(0)
            self._exe = fluid.Executor(self._place)
        self.processor_register()
        self.model = None
        # Lists of (epoch_id, dirname) pairs appended by ``save``;
        # ``infer`` evaluates the entries of ``increment_models``.
        self.inference_models = []
        self.increment_models = []

    def processor_register(self):
        """Hook for subclasses to register their context processors."""
        print(
            "Need implement by trainer, `self.regist_context_processor('uninit', self.instance)` must be the first"
        )

    def _get_dataloader(self, state="TRAIN"):
        """Attach a sample generator to the model's data loader and return it.

        Args:
            state: "TRAIN" selects ``self.model._data_loader`` and the
                ``train.reader`` config namespace; any other value selects
                ``self.model._infer_data_loader`` and ``evaluate.reader``.

        Returns:
            The model's data loader with its generator installed.

        If ``reader_debug_mode`` is set in the namespace, the first 10
        samples are printed and the process exits.
        """
        if state == "TRAIN":
            dataloader = self.model._data_loader
            namespace = "train.reader"
            class_name = "TrainReader"
        else:
            dataloader = self.model._infer_data_loader
            namespace = "evaluate.reader"
            class_name = "EvaluateReader"

        sparse_slots = envs.get_global_env("sparse_slots", None, namespace)
        dense_slots = envs.get_global_env("dense_slots", None, namespace)

        batch_size = envs.get_global_env("batch_size", None, namespace)
        print("batch_size: {}".format(batch_size))

        if sparse_slots is None and dense_slots is None:
            # User-provided reader: "class" names where the reader lives
            # (presumably a file path, given lazy_instance_by_fliename);
            # the TrainReader/EvaluateReader class is loaded from it.
            reader_class = envs.get_global_env("class", None, namespace)
            reader = dataloader_instance.dataloader(reader_class, state,
                                                    self._config_yaml)
            reader_class = envs.lazy_instance_by_fliename(reader_class,
                                                          class_name)
            reader_ins = reader_class(self._config_yaml)
        else:
            # Slot-based configuration uses the generic SlotReader.
            reader = dataloader_instance.slotdataloader("", state,
                                                        self._config_yaml)
            reader_ins = SlotReader(self._config_yaml)

        # set_sample_list_generator consumes ready-made batches, while
        # set_sample_generator batches single samples using batch_size.
        if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
            dataloader.set_sample_list_generator(reader)
        else:
            dataloader.set_sample_generator(reader, batch_size)

        debug_mode = envs.get_global_env("reader_debug_mode", False, namespace)
        if debug_mode:
            print("--- DataLoader Debug Mode Begin , show pre 10 data ---")
            for idx, line in enumerate(reader()):
                print(line)
                if idx >= 9:
                    break
            print("--- DataLoader Debug Mode End , show pre 10 data ---")
            # sys.exit instead of the site-provided ``exit`` builtin, which
            # is meant for interactive use and is absent under ``python -S``.
            sys.exit(0)
        return dataloader

    def _get_dataset_ins(self):
        """Return the total number of lines across ``self.files``.

        ``self.files`` is populated by ``_get_dataset``.
        """
        count = 0
        for path in self.files:
            # The context manager closes each handle promptly instead of
            # leaking it until garbage collection.
            with open(path, 'r') as fh:
                count += sum(1 for _ in fh)
        return count

    def _get_dataset(self, state="TRAIN"):
        """Build a fluid dataset whose samples come from a reader pipe command.

        Args:
            state: "TRAIN" uses ``self.model.get_inputs()`` and
                ``train.reader.train_data_path``; any other value uses
                ``self.model.get_infer_inputs()`` and
                ``evaluate.reader.test_data_path``.

        Returns:
            The configured dataset. As a side effect, ``self.files`` is set
            to the list of files found in the data directory.

        If ``reader_debug_mode`` is set in the namespace, the first 10 parsed
        lines of the first file are shown and the process exits.
        """
        if state == "TRAIN":
            inputs = self.model.get_inputs()
            namespace = "train.reader"
            train_data_path = envs.get_global_env("train_data_path", None,
                                                  namespace)
        else:
            inputs = self.model.get_infer_inputs()
            namespace = "evaluate.reader"
            train_data_path = envs.get_global_env("test_data_path", None,
                                                  namespace)

        sparse_slots = envs.get_global_env("sparse_slots", None, namespace)
        dense_slots = envs.get_global_env("dense_slots", None, namespace)

        threads = int(envs.get_runtime_environ("train.trainer.threads"))
        batch_size = envs.get_global_env("batch_size", None, namespace)
        reader_class = envs.get_global_env("class", None, namespace)
        abs_dir = os.path.dirname(os.path.abspath(__file__))
        reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')

        if sparse_slots is None and dense_slots is None:
            pipe_cmd = "python {} {} {} {}".format(reader, reader_class, state,
                                                   self._config_yaml)
        else:
            # Slot lists are passed as shell arguments; spaces are encoded as
            # '#', presumably so each list stays a single argument.
            # NOTE(review): if only one of sparse_slots / dense_slots is set,
            # the other is None and ``.replace`` raises AttributeError --
            # confirm whether both are required.
            padding = envs.get_global_env("padding", 0, namespace)
            pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
                reader, "slot", "slot", self._config_yaml, namespace,
                sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"),
                str(padding))

        if train_data_path.startswith("paddlerec::"):
            # "paddlerec::<relative path>" is resolved against PACKAGE_BASE.
            package_base = envs.get_runtime_environ("PACKAGE_BASE")
            assert package_base is not None
            train_data_path = os.path.join(package_base,
                                           train_data_path.split("::")[1])

        dataset = fluid.DatasetFactory().create_dataset()
        dataset.set_use_var(inputs)
        dataset.set_pipe_command(pipe_cmd)
        dataset.set_batch_size(batch_size)
        dataset.set_thread(threads)
        file_list = [
            os.path.join(train_data_path, x)
            for x in os.listdir(train_data_path)
        ]
        self.files = file_list
        dataset.set_filelist(self.files)

        debug_mode = envs.get_global_env("reader_debug_mode", False, namespace)
        if debug_mode:
            print("--- Dataset Debug Mode Begin , show pre 10 data of {}---".
                  format(file_list[0]))
            # NOTE: runs a shell command built from config values; only use
            # with trusted configuration.
            os.system("cat {} | {} | head -10".format(file_list[0], pipe_cmd))
            print("--- Dataset Debug Mode End , show pre 10 data of {}---".
                  format(file_list[0]))
            sys.exit(0)

        return dataset

    def save(self, epoch_id, namespace, is_fleet=False):
        """Save persistables and/or an inference model for ``epoch_id``.

        Args:
            epoch_id: epoch being saved; -1 disables interval-based saving.
            namespace: config namespace holding the ``save.increment.*`` and
                ``save.inference.*`` keys.
            is_fleet: save through ``fleet`` instead of ``fluid.io``.

        Persistables are saved (and recorded in ``self.increment_models``)
        when ``save.increment.epoch_interval`` divides ``epoch_id``. An
        inference model is saved likewise (and recorded in
        ``self.inference_models``) for ``save.inference.epoch_interval``,
        provided feed and fetch variable names are configured.
        """

        def need_save(epoch_id, epoch_interval, is_last=False):
            if is_last:
                return True

            if epoch_id == -1:
                return False

            # A missing (None) or zero interval would otherwise raise
            # TypeError / ZeroDivisionError below; treat it as "don't save".
            if not epoch_interval:
                return False

            # NOTE(review): the configured default interval is -1, and
            # ``x % -1 == 0`` for every int, so an unset interval means
            # "save every epoch" -- confirm this is intended.
            return epoch_id % epoch_interval == 0

        def save_inference_model():
            save_interval = envs.get_global_env(
                "save.inference.epoch_interval", -1, namespace)

            if not need_save(epoch_id, save_interval, False):
                return

            feed_varnames = envs.get_global_env("save.inference.feed_varnames",
                                                None, namespace)
            fetch_varnames = envs.get_global_env(
                "save.inference.fetch_varnames", None, namespace)
            if feed_varnames is None or fetch_varnames is None:
                return

            # Fetch targets are looked up by name in the default main program.
            fetch_vars = [
                fluid.default_main_program().global_block().vars[varname]
                for varname in fetch_varnames
            ]
            dirname = envs.get_global_env("save.inference.dirname", None,
                                          namespace)

            assert dirname is not None
            dirname = os.path.join(dirname, str(epoch_id))

            if is_fleet:
                fleet.save_inference_model(self._exe, dirname, feed_varnames,
                                           fetch_vars)
            else:
                fluid.io.save_inference_model(dirname, feed_varnames,
                                              fetch_vars, self._exe)
            self.inference_models.append((epoch_id, dirname))

        def save_persistables():
            save_interval = envs.get_global_env(
                "save.increment.epoch_interval", -1, namespace)

            if not need_save(epoch_id, save_interval, False):
                return

            dirname = envs.get_global_env("save.increment.dirname", None,
                                          namespace)

            assert dirname is not None
            dirname = os.path.join(dirname, str(epoch_id))

            if is_fleet:
                fleet.save_persistables(self._exe, dirname)
            else:
                fluid.io.save_persistables(self._exe, dirname)
            self.increment_models.append((epoch_id, dirname))

        save_persistables()
        save_inference_model()

    def instance(self, context):
        """Instantiate the ``Model`` class named by ``train.model.models``."""
        models = envs.get_global_env("train.model.models")
        model_class = envs.lazy_instance_by_fliename(models, "Model")
        self.model = model_class(None)
        context['status'] = 'init_pass'

    def init(self, context):
        """Placeholder; subclasses must implement."""
        print("Need to be implement")
        context['is_exit'] = True

    def dataloader_train(self, context):
        """Placeholder; subclasses must implement."""
        print("Need to be implement")
        context['is_exit'] = True

    def dataset_train(self, context):
        """Placeholder; subclasses must implement."""
        print("Need to be implement")
        context['is_exit'] = True

    def infer(self, context):
        """Evaluate saved models on the evaluation data loader.

        Models come from ``self.increment_models``, or from the single
        ``evaluate_model_path`` when ``evaluate_only`` is set in the
        ``evaluate`` namespace. Metrics from ``model.get_infer_results()``
        are printed for every even, non-zero batch. Sets
        ``context['status']`` to ``'terminal_pass'`` when done (or when the
        model has no infer data loader).
        """
        infer_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.unique_name.guard():
            with fluid.program_guard(infer_program, startup_program):
                self.model.infer_net()

        # The infer data loader is only known after infer_net() has run.
        if self.model._infer_data_loader is None:
            context['status'] = 'terminal_pass'
            return

        reader = self._get_dataloader("Evaluate")

        metrics_varnames = []
        metrics_format = []

        metrics_format.append("{}: {{}}".format("epoch"))
        metrics_format.append("{}: {{}}".format("batch"))

        for name, var in self.model.get_infer_results().items():
            metrics_varnames.append(var.name)
            metrics_format.append("{}: {{}}".format(name))

        metrics_format = ", ".join(metrics_format)
        self._exe.run(startup_program)

        model_list = self.increment_models

        evaluate_only = envs.get_global_env(
            'evaluate_only', False, namespace='evaluate')
        if evaluate_only:
            model_list = [(0, envs.get_global_env(
                'evaluate_model_path', "", namespace='evaluate'))]

        is_return_numpy = envs.get_global_env(
            'is_return_numpy', True, namespace='evaluate')

        for (epoch, model_dir) in model_list:
            print("Begin to infer No.{} model, model_dir: {}".format(
                epoch, model_dir))
            program = infer_program.clone()
            fluid.io.load_persistables(self._exe, model_dir, program)
            reader.start()
            batch_id = 0
            try:
                while True:
                    metrics_rets = self._exe.run(program=program,
                                                 fetch_list=metrics_varnames,
                                                 return_numpy=is_return_numpy)

                    metrics = [epoch, batch_id]
                    metrics.extend(metrics_rets)

                    if batch_id % 2 == 0 and batch_id != 0:
                        print(metrics_format.format(*metrics))
                    batch_id += 1
            except fluid.core.EOFException:
                # End of data: reset so the loader can restart for the
                # next model.
                reader.reset()

        context['status'] = 'terminal_pass'

    def terminal(self, context):
        """Final processor: announce cleanup and request exit."""
        print("clean up and exit")
        context['is_exit'] = True