ctr_trainer.py 19.9 KB
Newer Older
X
xiexionghang 已提交
1
"""
X
xiexionghang 已提交
2
A Paddle trainer adapted for CTR models.
X
xiexionghang 已提交
3 4
"""
import abc
X
xiexionghang 已提交
5 6 7 8 9 10
import sys
import copy
import yaml
import time
import json
import datetime
T
tangwei 已提交
11 12 13

import numpy as np

X
xiexionghang 已提交
14
import paddle.fluid as fluid
T
tangwei 已提交
15 16 17 18 19 20 21 22

from .. utils import fs as fs
from .. utils import util as util
from .. metrics.auc_metrics import AUCMetric
from .. models import base as model_basic
from .. reader import dataset
from . import trainer

X
xiexionghang 已提交
23
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
24
from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
X
xiexionghang 已提交
25

T
tangwei 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66

def wroker_numric_opt(value, env, opt):
    """All-reduce a single scalar across every worker.

    Args:
        value: this worker's scalar contribution.
        env: backend name ("mpi"/"gloo") kept for API compatibility; it is
            not consulted here -- the reduction always goes through
            ``fleet._role_maker.all_reduce_worker``.
        opt: reduce operator name forwarded to ``all_reduce_worker``; the
            helpers in this module pass "sum", "min" and "max".

    Returns:
        Element 0 of the reduction output buffer.
    """
    send_buf = np.array([value])
    # Output buffer shaped like the input, produced via `copy * 0`; the
    # role maker writes the reduced value into it.
    recv_buf = np.copy(send_buf) * 0
    fleet._role_maker.all_reduce_worker(send_buf, recv_buf, opt)
    return recv_buf[0]


def worker_numric_sum(value, env="mpi"):
    """Return the sum of ``value`` over all workers."""
    return wroker_numric_opt(value=value, env=env, opt="sum")


def worker_numric_avg(value, env="mpi"):
    """Return the mean of ``value`` over all workers.

    Computed as the global sum divided by ``fleet.worker_num()``.
    """
    global_total = worker_numric_sum(value, env)
    return global_total / fleet.worker_num()


def worker_numric_min(value, env="mpi"):
    """Return the minimum of ``value`` over all workers."""
    return wroker_numric_opt(value=value, env=env, opt="min")


def worker_numric_max(value, env="mpi"):
    """Return the maximum of ``value`` over all workers."""
    return wroker_numric_opt(value=value, env=env, opt="max")


T
tangwei 已提交
67
class CtrPaddleTrainer(trainer.Trainer):
    """Parameter-server (fleet/pslib) trainer for CTR models.

    One handler is registered per context status via
    ``regist_context_processor`` ('uninit' -> init, 'startup' -> startup,
    'begin_day' -> begin_day, 'train_pass' -> train_pass,
    'end_day' -> end_day). Each handler sets ``context['status']`` (and
    possibly ``context['is_exit']``) to choose what runs next. The dispatch
    loop itself lives in ``trainer.Trainer``, which is not shown here.
    """

    def __init__(self, config):
        """Build the trainer from the global config dict.

        Args:
            config: global config dict. ``config['output_path']`` is
                rewritten in place to an absolute path resolved through
                ``config['io']['afs']``. An optional
                ``config['path_generator']`` adds extra path templates.
        """
        trainer.Trainer.__init__(self, config)
        config['output_path'] = util.get_absolute_path(
            config['output_path'], config['io']['afs'])
        self.global_config = config
        self._place = fluid.CPUPlace()
        self._exe = fluid.Executor(self._place)
        self._exector_context = {}
        self._metrics = {}
        output_path = config['output_path']
        self._path_generator = util.PathGenerator({
            'templates': [
                {'name': 'xbox_base_done', 'template': output_path + '/xbox_base_done.txt'},
                {'name': 'xbox_delta_done', 'template': output_path + '/xbox_patch_done.txt'},
                {'name': 'xbox_base', 'template': output_path + '/xbox/{day}/base/'},
                {'name': 'xbox_delta', 'template': output_path + '/xbox/{day}/delta-{pass_id}/'},
                {'name': 'batch_model', 'template': output_path + '/batch_model/{day}/{pass_id}/'}
            ]
        })
        if 'path_generator' in config:
            self._path_generator.add_path_template(config['path_generator'])

        self.regist_context_processor('uninit', self.init)
        self.regist_context_processor('startup', self.startup)
        self.regist_context_processor('begin_day', self.begin_day)
        self.regist_context_processor('train_pass', self.train_pass)
        self.regist_context_processor('end_day', self.end_day)

    def init(self, context):
        """Handle 'uninit': init fleet, build every executor's model and datasets.

        Args:
            context: run context; ``context['status']`` is set to 'startup'.

        Returns:
            0 on parameter servers, which return once the startup programs
            have run. None on workers, which also build their time-split
            datasets and call ``fleet.init_worker()``.
        """
        role_maker = None
        if self.global_config.get('process_mode', 'mpi') == 'brilliant_cpu':
            # Gloo-style rendezvous through files under <output_path>/gloo
            # on the configured AFS/HDFS cluster.
            afs_config = self.global_config['io']['afs']
            role_maker = GeneralRoleMaker(
                hdfs_name=afs_config['fs_name'], hdfs_ugi=afs_config['fs_ugi'],
                path=self.global_config['output_path'] + "/gloo",
                init_timeout_seconds=1200, run_timeout_seconds=1200)
        fleet.init(role_maker)
        data_var_list = []
        data_var_name_dict = {}
        runnnable_scope = []
        runnnable_cost_op = []
        context['status'] = 'startup'

        # One scope + model per configured executor. Data vars from all
        # models are merged and deduplicated by name (first one wins).
        for executor in self.global_config['executor']:
            scope = fluid.Scope()
            self._exector_context[executor['name']] = {}
            self._exector_context[executor['name']]['scope'] = scope
            self._exector_context[executor['name']]['model'] = model_basic.create(executor)
            model = self._exector_context[executor['name']]['model']
            self._metrics.update(model.get_metrics())
            runnnable_scope.append(scope)
            runnnable_cost_op.append(model.get_cost_op())
            for var in model._data_var:
                if var.name in data_var_name_dict:
                    continue
                data_var_list.append(var)
                data_var_name_dict[var.name] = var

        # A single optimizer minimizes all executors' costs together.
        optimizer = model_basic.YamlModel.build_optimizer({
            'metrics': self._metrics,
            'optimizer_conf': self.global_config['optimizer']
        })
        optimizer.minimize(runnnable_cost_op, runnnable_scope)
        for executor in self.global_config['executor']:
            scope = self._exector_context[executor['name']]['scope']
            model = self._exector_context[executor['name']]['model']
            program = model._build_param['model']['train_program']
            if not executor['is_update_sparse']:
                # Executors that must not update sparse tables get an empty
                # push_sparse list in their fleet program config.
                program._fleet_opt["program_configs"][str(id(model.get_cost_op().block.program))]["push_sparse"] = []
            if 'train_thread_num' not in executor:
                executor['train_thread_num'] = self.global_config['train_thread_num']
            with fluid.scope_guard(scope):
                self._exe.run(model._build_param['model']['startup_program'])
            model.dump_model_program('./')

        # server init done
        if fleet.is_server():
            return 0

        self._dataset = {}
        for dataset_item in self.global_config['dataset']['data_list']:
            dataset_item['data_vars'] = data_var_list
            dataset_item.update(self.global_config['io']['afs'])
            dataset_item["batch_size"] = self.global_config['batch_size']
            self._dataset[dataset_item['name']] = dataset.FluidTimeSplitDataset(dataset_item)
        # if config.need_reqi_changeslot and config.reqi_dnn_plugin_day >= last_day and config.reqi_dnn_plugin_pass >= last_pass:
        #    util.reqi_changeslot(config.hdfs_dnn_plugin_path, join_save_params, common_save_params, update_save_params, scope2, scope3)
        fleet.init_worker()

    def print_log(self, log_str, params):
        """Print ``log_str`` on this worker, or only on worker 0 when master.

        Args:
            log_str: message to print.
            params: dict with the required key 'master'. When 'master' is
                true, only worker 0 prints. If the optional key 'stdout' is
                present, the timestamped message is appended to
                ``params['stdout']``. That only updates the dict entry, not
                any str variable the caller used to build the dict. As a
                side effect, ``params['index']`` is set to this worker's
                index.
        """
        params['index'] = fleet.worker_index()
        if params['master']:
            if fleet.worker_index() == 0:
                print(log_str)
                sys.stdout.flush()
        else:
            print(log_str)
            sys.stdout.flush()
        if 'stdout' in params:
            params['stdout'] += str(datetime.datetime.now()) + log_str

    def print_global_metrics(self, scope, model, monitor_data, stdout_str):
        """Calculate, log and then clear every metric registered by ``model``.

        Args:
            scope: fluid scope holding the metric variables.
            model: model whose ``get_metrics()`` dict names the metrics.
            monitor_data: text to which each metric's result string is
                appended.
            stdout_str: forwarded to ``print_log`` as the 'stdout' buffer.

        Returns:
            ``monitor_data`` with every metric result appended. A str
            argument cannot be updated in place, so callers must use this
            return value to see the accumulated text.
        """
        metrics = model.get_metrics()
        metric_calculator = AUCMetric(None)
        for metric in metrics:
            metric_param = {'label': metric, 'metric_dict': metrics[metric]}
            metric_calculator.calculate(scope, metric_param)
            metric_result = metric_calculator.get_result_to_string()
            self.print_log(metric_result, {'master': True, 'stdout': stdout_str})
            monitor_data += metric_result
            metric_calculator.clear(scope, metric_param)
        return monitor_data

    def save_model(self, day, pass_index, base_key):
        """Save a batch-model checkpoint via ``fleet.save_persistables``.

        Args:
            day: day used in the 'batch_model' path template.
            pass_index: pass number; < 1 selects save mode 3, otherwise 0.
            base_key: base key recorded with the training progress.

        Returns:
            The path the model was saved to.
        """
        cost_printer = util.CostPrinter(util.print_cost,
                                        {'master': True, 'log_format': 'save model cost %s sec'})
        model_path = self._path_generator.generate_path('batch_model', {'day': day, 'pass_id': pass_index})
        save_mode = 0  # just save all
        if pass_index < 1:  # batch_model
            save_mode = 3  # unseen_day++, save all
        util.rank0_print("going to save_model %s" % model_path)
        fleet.save_persistables(None, model_path, mode=save_mode)
        if fleet._role_maker.is_first_worker():
            self._train_pass.save_train_progress(day, pass_index, base_key, model_path, is_checkpoint=True)
        cost_printer.done()
        return model_path

    def save_xbox_model(self, day, pass_index, xbox_base_key, monitor_data):
        """Save an xbox model: base when ``pass_index < 1``, delta otherwise.

        Steps:
            1. Save the sparse tables (save mode 2 for base, 1 for delta).
            2. Optionally save the sparse cache and write its meta file.
            3. On the first worker, dump the dense inference params of every
               executor that declares 'layer_for_inference'.
            4. On the first worker, append a JSON record to the xbox done
               file.

        Args:
            day: day used to build the output paths.
            pass_index: pass number; < 1 selects the base model.
            xbox_base_key: base key written to the done record. For base
                models it is also used as the record id.
            monitor_data: metric text stored in the done record.

        Returns:
            The local stdout buffer. It is never reassigned in this method,
            so it is always "".
        """
        stdout_str = ""
        xbox_patch_id = str(int(time.time()))
        util.rank0_print("begin save delta model")

        cost_printer = util.CostPrinter(util.print_cost, {'master': True,
                                                          'log_format': 'save xbox model cost %s sec',
                                                          'stdout': stdout_str})
        if pass_index < 1:
            save_mode = 2
            xbox_patch_id = xbox_base_key
            model_path = self._path_generator.generate_path('xbox_base', {'day': day})
            xbox_model_donefile = self._path_generator.generate_path('xbox_base_done', {'day': day})
        else:
            save_mode = 1
            model_path = self._path_generator.generate_path('xbox_delta', {'day': day, 'pass_id': pass_index})
            xbox_model_donefile = self._path_generator.generate_path('xbox_delta_done', {'day': day})
        fleet.save_persistables(None, model_path, mode=save_mode)
        cost_printer.done()

        cost_printer = util.CostPrinter(util.print_cost, {'master': True,
                                                          'log_format': 'save cache model cost %s sec',
                                                          'stdout': stdout_str})
        model_file_handler = fs.FileHandler(self.global_config['io']['afs'])
        # None means the cache was not saved. This keeps the log below from
        # referencing an unbound name when save_cache_model is off.
        cache_save_num = None
        if self.global_config['save_cache_model']:
            cache_save_num = fleet.save_cache_model(None, model_path, mode=save_mode)
            model_file_handler.write(
                "file_prefix:part\npart_num:16\nkey_num:%d\n" % cache_save_num,
                model_path + '/000_cache/sparse_cache.meta', 'w')
        cost_printer.done()
        if cache_save_num is not None:
            util.rank0_print("save xbox cache model done, key_num=%s" % cache_save_num)

        save_env_param = {
            'executor': self._exe,
            'save_combine': True
        }
        cost_printer = util.CostPrinter(util.print_cost, {'master': True,
                                                          'log_format': 'save dense model cost %s sec',
                                                          'stdout': stdout_str})
        if fleet._role_maker.is_first_worker():
            for executor in self.global_config['executor']:
                if 'layer_for_inference' not in executor:
                    continue
                executor_name = executor['name']
                model = self._exector_context[executor_name]['model']
                save_env_param['inference_list'] = executor['layer_for_inference']
                save_env_param['scope'] = self._exector_context[executor_name]['scope']
                model.dump_inference_param(save_env_param)
                for dnn_layer in executor['layer_for_inference']:
                    model_file_handler.cp(dnn_layer['save_file_name'],
                                          model_path + '/dnn_plugin/' + dnn_layer['save_file_name'])
        fleet._role_maker._barrier_worker()
        cost_printer.done()

        # Key order is preserved by json.dumps, so it is kept as-is.
        # "record_count" is a hard-coded placeholder value.
        xbox_done_info = {
            "id": xbox_patch_id,
            "key": xbox_base_key,
            "ins_path": "",
            "ins_tag": "feasign",
            "partition_type": "2",
            "record_count": "111111",
            "monitor_data": monitor_data,
            "mpi_size": str(fleet.worker_num()),
            "input": model_path.rstrip("/") + "/000",
            "job_id": util.get_env_value("JOB_ID"),
            "job_name": util.get_env_value("JOB_NAME")
        }
        if fleet._role_maker.is_first_worker():
            model_file_handler.write(json.dumps(xbox_done_info) + "\n", xbox_model_donefile, 'a')
            if pass_index > 0:
                self._train_pass.save_train_progress(day, pass_index, xbox_base_key, model_path, is_checkpoint=False)
        fleet._role_maker._barrier_worker()
        return stdout_str

    def run_executor(self, executor_config, dataset, stdout_str):
        """Train one executor on ``dataset`` for the current pass.

        Also logs the worker-wide train time and metrics, and saves a delta
        xbox model when the pass calls for an inference dump.

        Args:
            executor_config: one entry of ``global_config['executor']``.
            dataset: the loaded dataset to train on.
            stdout_str: log buffer forwarded to metric printing.

        Returns:
            ``stdout_str`` extended with whatever ``save_xbox_model``
            returned. Existing callers may ignore this value.
        """
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        xbox_base_key = self._train_pass._base_key
        executor_name = executor_config['name']
        scope = self._exector_context[executor_name]['scope']
        model = self._exector_context[executor_name]['model']
        with fluid.scope_guard(scope):
            util.rank0_print("Begin " + executor_name + " pass")
            begin = time.time()
            program = model._build_param['model']['train_program']
            self._exe.train_from_dataset(program, dataset, scope,
                                         thread=executor_config['train_thread_num'],
                                         debug=self.global_config['debug'])
            end = time.time()
            local_cost = (end - begin) / 60.0
            avg_cost = worker_numric_avg(local_cost)
            min_cost = worker_numric_min(local_cost)
            max_cost = worker_numric_max(local_cost)
            util.rank0_print("avg train time %s mins, min %s mins, max %s mins" % (avg_cost, min_cost, max_cost))
            # The maximum over all workers is recorded as this executor's cost.
            self._exector_context[executor_name]['cost'] = max_cost

            # print_global_metrics returns the accumulated metric text. A str
            # argument cannot be filled in place, so keep the return value.
            monitor_data = self.print_global_metrics(scope, model, "", stdout_str)
            util.rank0_print("End " + executor_name + " pass")
            if self._train_pass.need_dump_inference(pass_id) and executor_config['dump_inference_model']:
                stdout_str += self.save_xbox_model(day, pass_id, xbox_base_key, monitor_data)
        fleet._role_maker._barrier_worker()
        return stdout_str

    def startup(self, context):
        """Handle 'startup'.

        Servers run ``fleet.run_server()`` and then move to 'wait'.

        Workers, unless ``cold_start`` is set, load the checkpoint from
        ``self._train_pass._checkpoint_model_path``. They then optionally
        save a first xbox base model and move to 'begin_day'.
        """
        if fleet.is_server():
            fleet.run_server()
            context['status'] = 'wait'
            return
        stdout_str = ""
        self._train_pass = util.TimeTrainPass(self.global_config)
        if not self.global_config['cold_start']:
            cost_printer = util.CostPrinter(util.print_cost,
                                            {'master': True, 'log_format': 'load model cost %s sec',
                                             'stdout': stdout_str})
            self.print_log("going to load model %s" % self._train_pass._checkpoint_model_path, {'master': True})
            # if config.need_reqi_changeslot and config.reqi_dnn_plugin_day >= self._train_pass.date()
            #    and config.reqi_dnn_plugin_pass >= self._pass_id:
            #    fleet.load_one_table(0, self._train_pass._checkpoint_model_path)
            # else:
            fleet.init_server(self._train_pass._checkpoint_model_path, mode=0)
            cost_printer.done()
        if self.global_config['save_first_base']:
            self.print_log("save_first_base=True", {'master': True})
            self.print_log("going to save xbox base model", {'master': True, 'stdout': stdout_str})
            self._train_pass._base_key = int(time.time())
            stdout_str += self.save_xbox_model(self._train_pass.date(), 0, self._train_pass._base_key, "")
        context['status'] = 'begin_day'

    def begin_day(self, context):
        """Handle 'begin_day': advance the pass and choose the next state.

        Sets ``context['is_exit']`` when ``self._train_pass.next()`` reports
        no further passes. Moves to 'end_day' if the current pass id equals
        ``max_pass_num_day()``; otherwise moves to 'train_pass'.
        """
        stdout_str = ""
        if not self._train_pass.next():
            context['is_exit'] = True
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        self.print_log("======== BEGIN DAY:%s ========" % day, {'master': True, 'stdout': stdout_str})
        if pass_id == self._train_pass.max_pass_num_day():
            context['status'] = 'end_day'
        else:
            context['status'] = 'train_pass'

    def end_day(self, context):
        """Handle 'end_day': shrink tables and save base models for the next day.

        Shrinks the sparse table and every executor's model (using
        ``optimizer.dense_decay_rate``). It then saves an xbox base model and
        a batch model dated ``date(delta_day=1)`` under a fresh time-based
        base key, which becomes the new ``_base_key``. Status returns to
        'begin_day'.
        """
        xbox_base_key = int(time.time())
        context['status'] = 'begin_day'

        util.rank0_print("shrink table")
        cost_printer = util.CostPrinter(util.print_cost,
                                        {'master': True, 'log_format': 'shrink table done, cost %s sec'})
        fleet.shrink_sparse_table()
        for executor in self._exector_context:
            self._exector_context[executor]['model'].shrink({
                'scope': self._exector_context[executor]['scope'],
                'decay': self.global_config['optimizer']['dense_decay_rate']
            })
        cost_printer.done()

        next_date = self._train_pass.date(delta_day=1)
        util.rank0_print("going to save xbox base model")
        self.save_xbox_model(next_date, 0, xbox_base_key, "")
        util.rank0_print("going to save batch model")
        self.save_model(next_date, 0, xbox_base_key)
        self._train_pass._base_key = xbox_base_key
        fleet._role_maker._barrier_worker()

    def train_pass(self, context):
        """Handle 'train_pass': load, shuffle and train one time window.

        Steps:
            1. Load every dataset for the current pass window and global
               shuffle it.
            2. Optionally preload the next window.
            3. Run every executor, then release the dataset memory.
            4. Checkpoint if this is a checkpoint pass.
            5. Log the pass timings.

        Afterwards, moves to 'end_day' on the last pass of the day;
        otherwise advances the pass and sets ``context['is_exit']`` when
        none remain.
        """
        stdout_str = ""
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        base_key = self._train_pass._base_key
        pass_time = self._train_pass._current_train_time.strftime("%Y%m%d%H%M")
        self.print_log("    ==== begin delta:%s ========" % pass_id, {'master': True, 'stdout': stdout_str})
        train_begin_time = time.time()

        cost_printer = util.CostPrinter(util.print_cost,
                                        {'master': True, 'log_format': 'load into memory done, cost %s sec',
                                         'stdout': stdout_str})
        current_dataset = {}
        for name in self._dataset:
            current_dataset[name] = self._dataset[name].load_dataset({
                'node_num': fleet.worker_num(), 'node_idx': fleet.worker_index(),
                'begin_time': pass_time, 'time_window_min': self._train_pass._interval_per_pass
            })
        fleet._role_maker._barrier_worker()
        cost_printer.done()

        util.rank0_print("going to global shuffle")
        cost_printer = util.CostPrinter(util.print_cost, {
            'master': True, 'stdout': stdout_str,
            'log_format': 'global shuffle done, cost %s sec'})
        for name in current_dataset:
            current_dataset[name].global_shuffle(fleet, self.global_config['dataset']['shuffle_thread'])
        cost_printer.done()
        fleet._role_maker._barrier_worker()

        if self.global_config['prefetch_data']:
            # Preload the next pass window. Presumably this overlaps with
            # training; confirm in dataset.preload_dataset.
            next_pass_time = (self._train_pass._current_train_time +
                              datetime.timedelta(minutes=self._train_pass._interval_per_pass)).strftime("%Y%m%d%H%M")
            for name in self._dataset:
                self._dataset[name].preload_dataset({
                    'node_num': fleet.worker_num(), 'node_idx': fleet.worker_index(),
                    'begin_time': next_pass_time, 'time_window_min': self._train_pass._interval_per_pass
                })

        fleet._role_maker._barrier_worker()
        pure_train_begin = time.time()
        for executor in self.global_config['executor']:
            self.run_executor(executor, current_dataset[executor['dataset_name']], stdout_str)
        cost_printer = util.CostPrinter(util.print_cost,
                                        {'master': True, 'log_format': 'release_memory cost %s sec'})
        for name in current_dataset:
            current_dataset[name].release_memory()
        # Finish the printer so the release_memory cost is actually reported.
        cost_printer.done()
        pure_train_cost = time.time() - pure_train_begin

        if self._train_pass.is_checkpoint_pass(pass_id):
            self.save_model(day, pass_id, base_key)

        train_end_time = time.time()
        train_cost = train_end_time - train_begin_time
        other_cost = train_cost - pure_train_cost
        log_str = "finished train day %s pass %s time cost:%s sec job time cost:" % (day, pass_id, train_cost)
        for executor in self._exector_context:
            log_str += '[' + executor + ':' + str(self._exector_context[executor]['cost']) + ']'
        log_str += '[other_cost:' + str(other_cost) + ']'
        util.rank0_print(log_str)
        stdout_str += util.now_time_str() + log_str
        sys.stdout.write(stdout_str)
        fleet._role_maker._barrier_worker()
        if pass_id == self._train_pass.max_pass_num_day():
            context['status'] = 'end_day'
            return
        elif not self._train_pass.next():
            context['is_exit'] = True