# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
T
tangwei 已提交
16
Training use fluid with DistributeTranspiler
T
tangwei 已提交
17 18
"""
import os
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

from fleetrec.core.trainer import Trainer
from fleetrec.core.utils import envs
from fleetrec.core.utils import dataloader_instance


class TranspileTrainer(Trainer):
    """Base trainer for fluid training with the DistributeTranspiler fleet API.

    Holds the model instance and the lists of models written by ``save()``.
    It also provides helpers that build either a data loader or a fluid
    Dataset from the ``train.reader`` / ``evaluate.reader`` config namespaces.

    The context processors defined here (``init``, ``dataloader_train``,
    ``dataset_train``, ``infer``) are stubs that only mark the context for
    exit. Subclasses are expected to override them and register them in
    ``processor_register``.
    """

    def __init__(self, config=None):
        Trainer.__init__(self, config)
        # Processors are registered before the attributes below are set.
        # Subclasses must register 'uninit' -> self.instance first.
        self.processor_register()
        self.model = None
        # Lists of (epoch_id, dirname) tuples appended by save().
        self.inference_models = []
        self.increment_models = []

    def processor_register(self):
        """Register context processors; must be overridden by subclasses."""
        print("Need implement by trainer, `self.regist_context_processor('uninit', self.instance)` must be the first")

    def _get_dataloader(self, state):
        """Attach the configured reader to one of the model's data loaders.

        Args:
            state: "TRAIN" selects ``self.model._data_loader`` and the
                ``train.reader`` namespace. Any other value selects
                ``self.model._infer_data_loader`` and ``evaluate.reader``.

        Returns:
            The chosen data loader, with a sample generator attached.
        """
        if state == "TRAIN":
            dataloader = self.model._data_loader
            namespace = "train.reader"
            class_name = "TrainReader"
        else:
            dataloader = self.model._infer_data_loader
            namespace = "evaluate.reader"
            class_name = "EvaluateReader"

        batch_size = envs.get_global_env("batch_size", None, namespace)
        # Config value naming the reader implementation (passed on as-is).
        reader_path = envs.get_global_env("class", None, namespace)

        reader = dataloader_instance.dataloader(reader_path, state, self._config_yaml)

        # Instantiate the reader class only to inspect its interface.
        # Readers exposing generate_batch_from_trainfiles use the list
        # generator, which takes no batch_size; all others use the
        # per-sample generator batched by batch_size.
        reader_cls = envs.lazy_instance_by_fliename(reader_path, class_name)
        reader_ins = reader_cls(self._config_yaml)
        if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
            dataloader.set_sample_list_generator(reader)
        else:
            dataloader.set_sample_generator(reader, batch_size)
        return dataloader

    def _get_dataset(self, state):
        """Build a fluid Dataset that pipes files through dataset_instance.py.

        Args:
            state: "TRAIN" uses ``self.model.get_inputs()`` and
                ``train.reader.train_data_path``. Any other value uses
                ``self.model.get_infer_inputs()`` and
                ``evaluate.reader.test_data_path``.

        Returns:
            A dataset from ``fluid.DatasetFactory().create_dataset()``. Its
            file list holds every entry of the data directory, in sorted
            order.
        """
        if state == "TRAIN":
            inputs = self.model.get_inputs()
            namespace = "train.reader"
            data_path = envs.get_global_env("train_data_path", None, namespace)
        else:
            inputs = self.model.get_infer_inputs()
            namespace = "evaluate.reader"
            data_path = envs.get_global_env("test_data_path", None, namespace)

        threads = int(envs.get_runtime_environ("train.trainer.threads"))
        batch_size = envs.get_global_env("batch_size", None, namespace)
        reader_class = envs.get_global_env("class", None, namespace)
        abs_dir = os.path.dirname(os.path.abspath(__file__))
        reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
        # NOTE(review): the arguments are not shell-quoted, and "python" is
        # resolved from PATH. Paths containing spaces would break this.
        pipe_cmd = "python {} {} {} {}".format(reader, reader_class, state, self._config_yaml)

        if data_path.startswith("fleetrec::"):
            # "fleetrec::<relative path>" is resolved against PACKAGE_BASE.
            package_base = envs.get_runtime_environ("PACKAGE_BASE")
            assert package_base is not None, \
                "PACKAGE_BASE must be set to resolve {}".format(data_path)
            # maxsplit=1 keeps any further "::" inside the relative part.
            data_path = os.path.join(package_base, data_path.split("::", 1)[1])

        dataset = fluid.DatasetFactory().create_dataset()
        dataset.set_use_var(inputs)
        dataset.set_pipe_command(pipe_cmd)
        dataset.set_batch_size(batch_size)
        dataset.set_thread(threads)
        # os.listdir returns entries in arbitrary order. Sorting keeps the
        # file (and therefore data) order reproducible across runs.
        file_list = [
            os.path.join(data_path, name)
            for name in sorted(os.listdir(data_path))
        ]

        dataset.set_filelist(file_list)
        return dataset

    def save(self, epoch_id, namespace, is_fleet=False):
        """Save persistables and report that inference export is disabled.

        Args:
            epoch_id: the current epoch id. A value of -1 never triggers a
                save.
            namespace: config namespace holding the ``save.*`` keys.
            is_fleet: save through ``fleet`` instead of ``fluid.io``.

        Raises:
            AssertionError: if a save is due but its ``dirname`` is unset.
        """
        def need_save(epoch_id, epoch_interval, is_last=False):
            if is_last:
                return True

            if epoch_id == -1:
                return False

            # NOTE(review): epoch_id % -1 == 0 for every int, so the -1
            # default interval means "save every epoch", not "never".
            # An interval of 0 raises ZeroDivisionError. Confirm intent.
            return epoch_id % epoch_interval == 0

        def save_inference_model():
            save_interval = envs.get_global_env("save.inference.epoch_interval", -1, namespace)

            if not need_save(epoch_id, save_interval, False):
                return

            # Inference-model export is disabled. Report it and skip;
            # nothing is appended to self.inference_models.
            print("save inference model is not supported now.")

        def save_persistables():
            save_interval = envs.get_global_env("save.increment.epoch_interval", -1, namespace)

            if not need_save(epoch_id, save_interval, False):
                return

            dirname = envs.get_global_env("save.increment.dirname", None, namespace)

            assert dirname is not None, \
                "save.increment.dirname must be configured in {}".format(namespace)
            dirname = os.path.join(dirname, str(epoch_id))

            if is_fleet:
                fleet.save_persistables(self._exe, dirname)
            else:
                fluid.io.save_persistables(self._exe, dirname)
            self.increment_models.append((epoch_id, dirname))

        save_persistables()
        save_inference_model()

    def instance(self, context):
        """Instantiate the configured model and advance to 'init_pass'."""
        models = envs.get_global_env("train.model.models")
        model_class = envs.lazy_instance_by_fliename(models, "Model")
        self.model = model_class(None)
        context['status'] = 'init_pass'

    def init(self, context):
        """Stub initialization processor: subclasses must override."""
        print("Need to be implement")
        context['is_exit'] = True

    def dataloader_train(self, context):
        """Stub data-loader training processor: subclasses must override."""
        print("Need to be implement")
        context['is_exit'] = True

    def dataset_train(self, context):
        """Stub Dataset training processor: subclasses must override."""
        print("Need to be implement")
        context['is_exit'] = True

    def infer(self, context):
        """Stub inference processor: only marks the context for exit."""
        context['is_exit'] = True

    def terminal(self, context):
        """Final processor: announce cleanup and mark the context for exit."""
        print("clean up and exit")
        context['is_exit'] = True