# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import time
import json
import datetime
import numpy as np

import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker

from fleet_rec.utils import fs as fs
from fleet_rec.utils import util as util
from fleet_rec.metrics.auc_metrics import AUCMetric
from fleet_rec.models import base as model_basic
from fleet_rec.reader import dataset
from .trainer import Trainer


def worker_numric_opt(value, env, opt):
    """
    Numeric all-reduce over workers.
    Args:
        value: local value to reduce
        env: communication backend, mpi/gloo
        opt: reduce operator, SUM/MAX/MIN/AVG
    Return:
        the reduced result
    """
    local_value = np.array([value])
    global_value = np.copy(local_value) * 0
    fleet._role_maker.all_reduce_worker(local_value, global_value, opt)
    return global_value[0]


def worker_numric_sum(value, env="mpi"):
    """Sum value over all workers."""
    return worker_numric_opt(value, env, "sum")


def worker_numric_avg(value, env="mpi"):
    """Average value over all workers."""
    return worker_numric_sum(value, env) / fleet.worker_num()


def worker_numric_min(value, env="mpi"):
    """Min of value over all workers."""
    return worker_numric_opt(value, env, "min")


def worker_numric_max(value, env="mpi"):
    """Max of value over all workers."""
    return worker_numric_opt(value, env, "max")


class CtrPaddleTrainer(Trainer):
    """CTR trainer that drives day/pass training on the fleet parameter server.
    """

    def __init__(self, config):
        """Set up the executor, output path templates and context processors from config.
        """
        Trainer.__init__(self, config)
        config['output_path'] = util.get_absolute_path(
            config['output_path'], config['io']['afs'])

        self._place = fluid.CPUPlace()
        self._exe = fluid.Executor(self._place)
        self._exector_context = {}

        self.global_config = config
        self._metrics = {}

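        # Output path templates for model dumps and their done-files; the names below
        # are referenced by save_model and save_xbox_model.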
        self._path_generator = util.PathGenerator({
            'templates': [
                {'name': 'xbox_base_done', 'template': config['output_path'] + '/xbox_base_done.txt'},
                {'name': 'xbox_delta_done', 'template': config['output_path'] + '/xbox_patch_done.txt'},
                {'name': 'xbox_base', 'template': config['output_path'] + '/xbox/{day}/base/'},
                {'name': 'xbox_delta', 'template': config['output_path'] + '/xbox/{day}/delta-{pass_id}/'},
                {'name': 'batch_model', 'template': config['output_path'] + '/batch_model/{day}/{pass_id}/'}
            ]
        })
        if 'path_generator' in config:
            self._path_generator.add_path_template(config['path_generator'])

        self.regist_context_processor('uninit', self.init)
        self.regist_context_processor('startup', self.startup)
        self.regist_context_processor('begin_day', self.begin_day)
        self.regist_context_processor('train_pass', self.train_pass)
        self.regist_context_processor('end_day', self.end_day)

    def init(self, context):
        """Init fleet, build per-executor scopes/models, the shared optimizer and datasets.
        """
        role_maker = None
        if self.global_config.get('process_mode', 'mpi') == 'brilliant_cpu':
            afs_config = self.global_config['io']['afs']
            role_maker = GeneralRoleMaker(
                hdfs_name=afs_config['fs_name'], hdfs_ugi=afs_config['fs_ugi'],
                path=self.global_config['output_path'] + "/gloo",
                init_timeout_seconds=1200, run_timeout_seconds=1200)
        fleet.init(role_maker)
        data_var_list = []
        data_var_name_dict = {}
        runnable_scope = []
        runnable_cost_op = []
        context['status'] = 'startup'

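        # Build one scope and model per configured executor; collect their data vars
        # (deduplicated by name) and cost ops so a single optimizer can cover them all.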
        for executor in self.global_config['executor']:
            scope = fluid.Scope()
            self._exector_context[executor['name']] = {}
            self._exector_context[executor['name']]['scope'] = scope
            self._exector_context[executor['name']]['model'] = model_basic.create(executor)
            model = self._exector_context[executor['name']]['model']
            self._metrics.update(model.get_metrics())
            runnable_scope.append(scope)
            runnable_cost_op.append(model.get_cost_op())
            for var in model._data_var:
                if var.name in data_var_name_dict:
                    continue
                data_var_list.append(var)
                data_var_name_dict[var.name] = var

        optimizer = model_basic.YamlModel.build_optimizer({
            'metrics': self._metrics,
            'optimizer_conf': self.global_config['optimizer']
        })
        optimizer.minimize(runnable_cost_op, runnable_scope)
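        # Per-executor program setup: optionally drop sparse pushes, fill in the
        # thread count and run the startup program inside the executor's scope.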
        for executor in self.global_config['executor']:
            scope = self._exector_context[executor['name']]['scope']
            model = self._exector_context[executor['name']]['model']
            program = model._build_param['model']['train_program']
            if not executor['is_update_sparse']:
                program._fleet_opt["program_configs"][str(id(model.get_cost_op().block.program))]["push_sparse"] = []
            if 'train_thread_num' not in executor:
                executor['train_thread_num'] = self.global_config['train_thread_num']
            with fluid.scope_guard(scope):
                self._exe.run(model._build_param['model']['startup_program'])
            model.dump_model_program('./')

        # server init done
        if fleet.is_server():
            return 0

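        # Workers only from here on: build one time-split dataset per data source.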
        self._dataset = {}
        for dataset_item in self.global_config['dataset']['data_list']:
            dataset_item['data_vars'] = data_var_list
            dataset_item.update(self.global_config['io']['afs'])
            dataset_item["batch_size"] = self.global_config['batch_size']
            self._dataset[dataset_item['name']] = dataset.FluidTimeSplitDataset(dataset_item)
        # if config.need_reqi_changeslot and config.reqi_dnn_plugin_day >= last_day and config.reqi_dnn_plugin_pass >= last_pass:
        #    util.reqi_changeslot(config.hdfs_dnn_plugin_path, join_save_params, common_save_params, update_save_params, scope2, scope3)
        fleet.init_worker()
        pass

    def print_log(self, log_str, params):
        """Print a log line; when params['master'] is set, only worker 0 prints.
        """
        params['index'] = fleet.worker_index()
        if params['master']:
            if fleet.worker_index() == 0:
                print(log_str)
                sys.stdout.flush()
        else:
            print(log_str)
        if 'stdout' in params:
            params['stdout'] += str(datetime.datetime.now()) + log_str

    def print_global_metrics(self, scope, model, monitor_data, stdout_str):
        """Calculate and print AUC metrics for each metric group of the model.
        """
        metrics = model.get_metrics()
        metric_calculator = AUCMetric(None)
        for metric in metrics:
            metric_param = {'label': metric, 'metric_dict': metrics[metric]}
            metric_calculator.calculate(scope, metric_param)
            metric_result = metric_calculator.get_result_to_string()
            self.print_log(metric_result, {'master': True, 'stdout': stdout_str})
            monitor_data += metric_result
            metric_calculator.clear(scope, metric_param)

    def save_model(self, day, pass_index, base_key):
        """Save a full batch-model checkpoint for the given day and pass.
        """
        cost_printer = util.CostPrinter(util.print_cost,
                                        {'master': True, 'log_format': 'save model cost %s sec'})
        model_path = self._path_generator.generate_path('batch_model', {'day': day, 'pass_id': pass_index})
        save_mode = 0  # just save all
        if pass_index < 1:  # batch_model
            save_mode = 3  # unseen_day++, save all
        util.rank0_print("going to save_model %s" % model_path)
        fleet.save_persistables(None, model_path, mode=save_mode)
        if fleet._role_maker.is_first_worker():
            self._train_pass.save_train_progress(day, pass_index, base_key, model_path, is_checkpoint=True)
        cost_printer.done()
        return model_path

    def save_xbox_model(self, day, pass_index, xbox_base_key, monitor_data):
        """Save an xbox base/delta model for serving and write its donefile.
        """
        stdout_str = ""
        xbox_patch_id = str(int(time.time()))
        util.rank0_print("begin save delta model")

        model_path = ""
        xbox_model_donefile = ""
        cost_printer = util.CostPrinter(util.print_cost, {'master': True, \
                                                          'log_format': 'save xbox model cost %s sec',
                                                          'stdout': stdout_str})
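        # pass_index < 1 dumps a full base model keyed by xbox_base_key; otherwise a delta.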
        if pass_index < 1:
            save_mode = 2
            xbox_patch_id = xbox_base_key
            model_path = self._path_generator.generate_path('xbox_base', {'day': day})
            xbox_model_donefile = self._path_generator.generate_path('xbox_base_done', {'day': day})
        else:
            save_mode = 1
            model_path = self._path_generator.generate_path('xbox_delta', {'day': day, 'pass_id': pass_index})
            xbox_model_donefile = self._path_generator.generate_path('xbox_delta_done', {'day': day})
        total_save_num = fleet.save_persistables(None, model_path, mode=save_mode)
        cost_printer.done()

        cost_printer = util.CostPrinter(util.print_cost, {'master': True,
                                                          'log_format': 'save cache model cost %s sec',
                                                          'stdout': stdout_str})
        model_file_handler = fs.FileHandler(self.global_config['io']['afs'])
        cache_save_num = 0  # default keeps the log below valid when cache saving is off
        if self.global_config['save_cache_model']:
            cache_save_num = fleet.save_cache_model(None, model_path, mode=save_mode)
            model_file_handler.write(
                "file_prefix:part\npart_num:16\nkey_num:%d\n" % cache_save_num,
                model_path + '/000_cache/sparse_cache.meta', 'w')
        cost_printer.done()
        util.rank0_print("save xbox cache model done, key_num=%s" % cache_save_num)

        save_env_param = {
            'executor': self._exe,
            'save_combine': True
        }
        cost_printer = util.CostPrinter(util.print_cost, {'master': True,
                                                          'log_format': 'save dense model cost %s sec',
                                                          'stdout': stdout_str})
        if fleet._role_maker.is_first_worker():
            for executor in self.global_config['executor']:
                if 'layer_for_inference' not in executor:
                    continue
                executor_name = executor['name']
                model = self._exector_context[executor_name]['model']
                save_env_param['inference_list'] = executor['layer_for_inference']
                save_env_param['scope'] = self._exector_context[executor_name]['scope']
                model.dump_inference_param(save_env_param)
                for dnn_layer in executor['layer_for_inference']:
                    model_file_handler.cp(dnn_layer['save_file_name'],
                                          model_path + '/dnn_plugin/' + dnn_layer['save_file_name'])
        fleet._role_maker._barrier_worker()
        cost_printer.done()

        xbox_done_info = {
            "id": xbox_patch_id,
            "key": xbox_base_key,
            "ins_path": "",
            "ins_tag": "feasign",
            "partition_type": "2",
            "record_count": "111111",
            "monitor_data": monitor_data,
            "mpi_size": str(fleet.worker_num()),
            "input": model_path.rstrip("/") + "/000",
            "job_id": util.get_env_value("JOB_ID"),
            "job_name": util.get_env_value("JOB_NAME")
        }
        if fleet._role_maker.is_first_worker():
            model_file_handler.write(json.dumps(xbox_done_info) + "\n", xbox_model_donefile, 'a')
            if pass_index > 0:
                self._train_pass.save_train_progress(day, pass_index, xbox_base_key, model_path, is_checkpoint=False)
        fleet._role_maker._barrier_worker()
        return stdout_str

    def run_executor(self, executor_config, dataset, stdout_str):
        """Run one training pass of a single executor on the given dataset.
        """
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        xbox_base_key = self._train_pass._base_key
        executor_name = executor_config['name']
        scope = self._exector_context[executor_name]['scope']
        model = self._exector_context[executor_name]['model']
        with fluid.scope_guard(scope):
            util.rank0_print("Begin " + executor_name + " pass")
            begin = time.time()
            program = model._build_param['model']['train_program']
            self._exe.train_from_dataset(program, dataset, scope,
                                         thread=executor_config['train_thread_num'], debug=self.global_config['debug'])
            end = time.time()
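            # Aggregate per-worker training time (in minutes) across the fleet.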
            local_cost = (end - begin) / 60.0
            avg_cost = worker_numric_avg(local_cost)
            min_cost = worker_numric_min(local_cost)
            max_cost = worker_numric_max(local_cost)
            util.rank0_print("avg train time %s mins, min %s mins, max %s mins" % (avg_cost, min_cost, max_cost))
            self._exector_context[executor_name]['cost'] = max_cost

            monitor_data = ""
            self.print_global_metrics(scope, model, monitor_data, stdout_str)
            util.rank0_print("End " + executor_name + " pass")
            if self._train_pass.need_dump_inference(pass_id) and executor_config['dump_inference_model']:
                stdout_str += self.save_xbox_model(day, pass_id, xbox_base_key, monitor_data)
        fleet._role_maker._barrier_worker()

    def startup(self, context):
        """Run the server loop on servers; on workers load or init the model before training.
        """
        if fleet.is_server():
            fleet.run_server()
            context['status'] = 'wait'
            return
        stdout_str = ""
        self._train_pass = util.TimeTrainPass(self.global_config)
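        # Warm start: unless cold_start is set, load the latest checkpoint into the servers.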
        if not self.global_config['cold_start']:
            cost_printer = util.CostPrinter(util.print_cost,
                                            {'master': True, 'log_format': 'load model cost %s sec',
                                             'stdout': stdout_str})
            self.print_log("going to load model %s" % self._train_pass._checkpoint_model_path, {'master': True})
            # if config.need_reqi_changeslot and config.reqi_dnn_plugin_day >= self._train_pass.date()
            #    and config.reqi_dnn_plugin_pass >= self._pass_id:
            #    fleet.load_one_table(0, self._train_pass._checkpoint_model_path)
            # else:
            fleet.init_server(self._train_pass._checkpoint_model_path, mode=0)
            cost_printer.done()
        if self.global_config['save_first_base']:
            self.print_log("save_first_base=True", {'master': True})
            self.print_log("going to save xbox base model", {'master': True, 'stdout': stdout_str})
            self._train_pass._base_key = int(time.time())
            stdout_str += self.save_xbox_model(self._train_pass.date(), 0, self._train_pass._base_key, "")
        context['status'] = 'begin_day'

    def begin_day(self, context):
        """Advance to the next pass and route to train_pass or end_day.
        """
        stdout_str = ""
        if not self._train_pass.next():
            context['is_exit'] = True
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        self.print_log("======== BEGIN DAY:%s ========" % day, {'master': True, 'stdout': stdout_str})
        if pass_id == self._train_pass.max_pass_num_day():
            context['status'] = 'end_day'
        else:
            context['status'] = 'train_pass'

    def end_day(self, context):
        """Close the day: shrink tables, then save base and batch models for the next day.
        """
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        xbox_base_key = int(time.time())
        context['status'] = 'begin_day'

        util.rank0_print("shrink table")
        cost_printer = util.CostPrinter(util.print_cost,
                                        {'master': True, 'log_format': 'shrink table done, cost %s sec'})
        fleet.shrink_sparse_table()
        for executor in self._exector_context:
            self._exector_context[executor]['model'].shrink({
                'scope': self._exector_context[executor]['scope'],
                'decay': self.global_config['optimizer']['dense_decay_rate']
            })
        cost_printer.done()

        next_date = self._train_pass.date(delta_day=1)
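        # Roll over to the next day: dump a fresh base model and checkpoint under a new base key.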
        util.rank0_print("going to save xbox base model")
        self.save_xbox_model(next_date, 0, xbox_base_key, "")
        util.rank0_print("going to save batch model")
        self.save_model(next_date, 0, xbox_base_key)
        self._train_pass._base_key = xbox_base_key
        fleet._role_maker._barrier_worker()

    def train_pass(self, context):
        """Train one pass: load and shuffle data, run all executors, then save models.
        """
        stdout_str = ""
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        base_key = self._train_pass._base_key
        pass_time = self._train_pass._current_train_time.strftime("%Y%m%d%H%M")
        self.print_log("    ==== begin delta:%s ========" % pass_id, {'master': True, 'stdout': stdout_str})
        train_begin_time = time.time()

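        # Load each dataset's time window for this pass into memory on every worker.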
        cost_printer = util.CostPrinter(util.print_cost, \
                                        {'master': True, 'log_format': 'load into memory done, cost %s sec',
                                         'stdout': stdout_str})
        current_dataset = {}
        for name in self._dataset:
            current_dataset[name] = self._dataset[name].load_dataset({
                'node_num': fleet.worker_num(), 'node_idx': fleet.worker_index(),
                'begin_time': pass_time, 'time_window_min': self._train_pass._interval_per_pass
            })
        fleet._role_maker._barrier_worker()
        cost_printer.done()

        util.rank0_print("going to global shuffle")
        cost_printer = util.CostPrinter(util.print_cost, {
            'master': True, 'stdout': stdout_str,
            'log_format': 'global shuffle done, cost %s sec'})
        for name in current_dataset:
            current_dataset[name].global_shuffle(fleet, self.global_config['dataset']['shuffle_thread'])
        cost_printer.done()
        # str(dataset.get_shuffle_data_size(fleet))
        fleet._role_maker._barrier_worker()

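        # Optionally kick off preloading of the next pass's time window.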
        if self.global_config['prefetch_data']:
            next_pass_time = (self._train_pass._current_train_time +
                              datetime.timedelta(minutes=self._train_pass._interval_per_pass)).strftime("%Y%m%d%H%M")
            for name in self._dataset:
                self._dataset[name].preload_dataset({
                    'node_num': fleet.worker_num(), 'node_idx': fleet.worker_index(),
                    'begin_time': next_pass_time, 'time_window_min': self._train_pass._interval_per_pass
                })

        fleet._role_maker._barrier_worker()
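        # Run every configured executor on its dataset for this pass.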
        pure_train_begin = time.time()
        for executor in self.global_config['executor']:
            self.run_executor(executor, current_dataset[executor['dataset_name']], stdout_str)
        cost_printer = util.CostPrinter(util.print_cost, \
                                        {'master': True, 'log_format': 'release_memory cost %s sec'})
        for name in current_dataset:
            current_dataset[name].release_memory()
        cost_printer.done()
        pure_train_cost = time.time() - pure_train_begin

        if self._train_pass.is_checkpoint_pass(pass_id):
            self.save_model(day, pass_id, base_key)

        train_end_time = time.time()
        train_cost = train_end_time - train_begin_time
        other_cost = train_cost - pure_train_cost
        log_str = "finished train day %s pass %s time cost:%s sec job time cost:" % (day, pass_id, train_cost)
        for executor in self._exector_context:
            log_str += '[' + executor + ':' + str(self._exector_context[executor]['cost']) + ']'
        log_str += '[other_cost:' + str(other_cost) + ']'
        util.rank0_print(log_str)
        stdout_str += util.now_time_str() + log_str
        sys.stdout.write(stdout_str)
        fleet._role_maker._barrier_worker()
        stdout_str = ""
        if pass_id == self._train_pass.max_pass_num_day():
            context['status'] = 'end_day'
            return
        elif not self._train_pass.next():
            context['is_exit'] = True