ctr_modul_trainer.py 21.4 KB
Newer Older
T
tangwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

T
tangwei 已提交
15 16
import datetime
import json
X
xiexionghang 已提交
17 18
import sys
import time
T
tangwei 已提交
19

T
tangwei 已提交
20
import numpy as np
X
xiexionghang 已提交
21 22
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
23
from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
X
xiexionghang 已提交
24

25 26 27 28 29 30
from paddlerec.core.utils import fs as fs
from paddlerec.core.utils import util as util
from paddlerec.core.metrics.auc_metrics import AUCMetric
from paddlerec.core.modules.modul import build as model_basic
from paddlerec.core.utils import dataset
from paddlerec.core.trainer import Trainer
T
tangwei 已提交
31

T
tangwei 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72

def wroker_numric_opt(value, env, opt):
    """
    Reduce one numeric value across the workers.

    Args:
        value: this worker's local value
        env: backend label (mpi/gloo); accepted for interface compatibility,
            this function does not read it
        opt: reduce operator forwarded to all_reduce_worker
            (the wrappers in this file pass "sum", "min" and "max")
    Return:
        first element of the buffer filled by all_reduce_worker
    """
    sent = np.array([value])
    # Result buffer with the same shape as the input, made by multiplying by 0.
    received = sent * 0
    fleet._role_maker.all_reduce_worker(sent, received, opt)
    return received[0]


def worker_numric_sum(value, env="mpi"):
    """Reduce `value` across workers with the "sum" operator.

    Thin wrapper around wroker_numric_opt; `env` is passed through unchanged.
    """
    reduce_op = "sum"
    return wroker_numric_opt(value, env, reduce_op)


def worker_numric_avg(value, env="mpi"):
    """Return the cross-worker sum of `value` divided by fleet.worker_num().
    """
    total = worker_numric_sum(value, env)
    worker_count = fleet.worker_num()
    return total / worker_count


def worker_numric_min(value, env="mpi"):
    """Reduce `value` across workers with the "min" operator.

    Thin wrapper around wroker_numric_opt; `env` is passed through unchanged.
    """
    reduce_op = "min"
    return wroker_numric_opt(value, env, reduce_op)


def worker_numric_max(value, env="mpi"):
    """Reduce `value` across workers with the "max" operator.

    Thin wrapper around wroker_numric_opt; `env` is passed through unchanged.
    """
    reduce_op = "max"
    return wroker_numric_opt(value, env, reduce_op)


T
for mat  
tangwei 已提交
73
class CtrTrainer(Trainer):
X
xiexionghang 已提交
74 75
    """R
    """
T
tangwei 已提交
76

X
xiexionghang 已提交
77
    def __init__(self, config):
        """Resolve output paths and register the context processors.

        Args:
            config: trainer configuration mapping. config['output_path'] is
                overwritten in place with the result of
                util.get_absolute_path(config['output_path'],
                config['io']['afs']). An optional config['path_generator']
                entry is forwarded to the path generator as an extra template.
        """
        Trainer.__init__(self, config)
        resolved_output = util.get_absolute_path(config['output_path'],
                                                 config['io']['afs'])
        config['output_path'] = resolved_output

        self.global_config = config
        self._metrics = {}

        # (template name, suffix appended to the resolved output path)
        output_root = config['output_path']
        path_suffixes = (
            ('xbox_base_done', '/xbox_base_done.txt'),
            ('xbox_delta_done', '/xbox_patch_done.txt'),
            ('xbox_base', '/xbox/{day}/base/'),
            ('xbox_delta', '/xbox/{day}/delta-{pass_id}/'),
            ('batch_model', '/batch_model/{day}/{pass_id}/'),
        )
        self._path_generator = util.PathGenerator({
            'templates': [{
                'name': template_name,
                'template': output_root + suffix
            } for template_name, suffix in path_suffixes]
        })
        if 'path_generator' in config:
            self._path_generator.add_path_template(config['path_generator'])

        # status name -> method that handles that status
        processors = (
            ('uninit', self.init),
            ('startup', self.startup),
            ('begin_day', self.begin_day),
            ('train_pass', self.train_pass),
            ('end_day', self.end_day),
        )
        for status, handler in processors:
            self.regist_context_processor(status, handler)

    def init(self, context):
        """Initialise fleet, build each executor's model and the datasets.

        Steps, in order:
          1. fleet.init, with a GeneralRoleMaker only when
             global_config['process_mode'] is 'brilliant_cpu' (else None).
          2. For each entry of global_config['executor']: create a scope and
             a model, store both in self._exector_context, merge the model's
             metrics into self._metrics and collect its data vars (first
             occurrence of each var name wins).
          3. Build one optimizer and minimize all collected cost ops.
          4. Per executor: optionally clear "push_sparse", default
             'train_thread_num' from the global config, run the startup
             program inside the executor's scope and dump the model program.
          5. Servers stop here; workers build self._dataset and call
             fleet.init_worker().

        Args:
            context: shared context dict; context['status'] is set to
                'startup'.
        Returns:
            0 on a server process, None on a worker.
        """
        general_role_maker = None
        if self.global_config.get('process_mode', 'mpi') == 'brilliant_cpu':
            afs_conf = self.global_config['io']['afs']
            general_role_maker = GeneralRoleMaker(
                hdfs_name=afs_conf['fs_name'],
                hdfs_ugi=afs_conf['fs_ugi'],
                path=self.global_config['output_path'] + "/gloo",
                init_timeout_seconds=1200,
                run_timeout_seconds=1200)
        fleet.init(general_role_maker)

        feed_vars = []
        seen_var_names = set()
        scopes = []
        cost_ops = []
        context['status'] = 'startup'

        for executor_conf in self.global_config['executor']:
            executor_scope = fluid.Scope()
            entry = {}
            self._exector_context[executor_conf['name']] = entry
            entry['scope'] = executor_scope
            entry['model'] = model_basic.create(executor_conf)
            model = entry['model']
            self._metrics.update(model.get_metrics())
            scopes.append(executor_scope)
            cost_ops.append(model.get_avg_cost())
            for data_var in model._data_var:
                # Keep only the first variable seen for each name.
                if data_var.name not in seen_var_names:
                    feed_vars.append(data_var)
                    seen_var_names.add(data_var.name)

        optimizer = model_basic.YamlModel.build_optimizer({
            'metrics': self._metrics,
            'optimizer_conf': self.global_config['optimizer']
        })
        optimizer.minimize(cost_ops, scopes)

        for executor_conf in self.global_config['executor']:
            entry = self._exector_context[executor_conf['name']]
            executor_scope = entry['scope']
            model = entry['model']
            train_program = model._build_param['model']['train_program']
            if not executor_conf['is_update_sparse']:
                # Empty the push_sparse list for this model's program.
                program_key = str(id(model.get_avg_cost().block.program))
                program_configs = train_program._fleet_opt["program_configs"]
                program_configs[program_key]["push_sparse"] = []
            if 'train_thread_num' not in executor_conf:
                executor_conf['train_thread_num'] = self.global_config[
                    'train_thread_num']
            with fluid.scope_guard(executor_scope):
                self._exe.run(model._build_param['model']['startup_program'])
            model.dump_model_program('./')

        # server init done
        if fleet.is_server():
            return 0

        self._dataset = {}
        for dataset_conf in self.global_config['dataset']['data_list']:
            dataset_conf['data_vars'] = feed_vars
            dataset_conf.update(self.global_config['io']['afs'])
            dataset_conf["batch_size"] = self.global_config['batch_size']
            self._dataset[dataset_conf['name']] = \
                dataset.FluidTimeSplitDataset(dataset_conf)
        # A disabled "reqi_changeslot" step (util.reqi_changeslot on the
        # dnn plugin path) used to sit here before init_worker.
        fleet.init_worker()

    def print_log(self, log_str, params):
        """Print a log line and optionally record it in `params`.

        Args:
            log_str: text to print.
            params: dict that must contain 'master'. When params['master'] is
                truthy only the worker whose index is 0 prints (and flushes
                stdout); otherwise every caller prints. params['index'] is
                set to this worker's index. When a 'stdout' key is present, a
                timestamp followed by log_str is appended to params['stdout'].
        """
        worker_rank = fleet.worker_index()
        params['index'] = worker_rank
        master_only = params['master']
        if not master_only:
            print(log_str)
        elif worker_rank == 0:
            print(log_str)
            sys.stdout.flush()
        if 'stdout' in params:
            params['stdout'] += str(datetime.datetime.now()) + log_str
X
xiexionghang 已提交
199 200

    def print_global_metrics(self, scope, model, monitor_data, stdout_str):
        """Calculate, log and clear every metric of `model`.

        For each key of model.get_metrics() an AUCMetric calculator computes
        the metric in `scope`, its result string is logged through print_log
        (master only) and the metric is cleared again.

        Args:
            scope: scope handed to the metric calculator.
            model: object whose get_metrics() mapping is iterated.
            monitor_data: string the metric result strings are appended to.
            stdout_str: initial value of the 'stdout' entry given to
                print_log.
        Returns:
            `monitor_data` followed by every metric result string. Strings
            are immutable, so the caller's own variable is never modified by
            the `+=` below; a caller that needs the accumulated text has to
            use this return value.
        """
        metrics = model.get_metrics()
        metric_calculator = AUCMetric(None)
        for metric in metrics:
            metric_param = {'label': metric, 'metric_dict': metrics[metric]}
            metric_calculator.calculate(scope, metric_param)
            metric_result = metric_calculator.get_result_to_string()
            self.print_log(metric_result,
                           {'master': True,
                            'stdout': stdout_str})
            monitor_data += metric_result
            metric_calculator.clear(scope, metric_param)
        return monitor_data
T
tangwei 已提交
214

X
xiexionghang 已提交
215
    def save_model(self, day, pass_index, base_key):
        """Save all persistables as a batch model and record train progress.

        Args:
            day: value substituted for {day} in the 'batch_model' template.
            pass_index: value substituted for {pass_id}; also selects the
                save mode.
            base_key: forwarded to save_train_progress.
        Returns:
            the generated 'batch_model' path.
        """
        timer = util.CostPrinter(util.print_cost, {
            'master': True,
            'log_format': 'save model cost %s sec'
        })
        path_args = {'day': day, 'pass_id': pass_index}
        model_path = self._path_generator.generate_path('batch_model',
                                                        path_args)
        # mode 0: just save all
        # mode 3 (batch_model, pass_index < 1): unseen_day++, save all
        save_mode = 3 if pass_index < 1 else 0
        util.rank0_print("going to save_model %s" % model_path)
        fleet.save_persistables(None, model_path, mode=save_mode)
        # Only the first worker writes the progress/checkpoint record.
        if fleet._role_maker.is_first_worker():
            self._train_pass.save_train_progress(
                day, pass_index, base_key, model_path, is_checkpoint=True)
        timer.done()
        return model_path
T
tangwei 已提交
235

X
xiexionghang 已提交
236
    def save_xbox_model(self, day, pass_index, xbox_base_key, monitor_data):
        """Save an xbox base or delta model and append its done-file record.

        pass_index < 1 saves a base model: save mode 2, path template
        'xbox_base', done file 'xbox_base_done', record id = xbox_base_key.
        Any other pass_index saves a delta: save mode 1, path template
        'xbox_delta', done file 'xbox_delta_done', record id = the unix time
        (as a string) taken when this method starts.

        After the sparse save, the cache model is saved when
        global_config['save_cache_model'] is truthy, the first worker dumps
        and copies the inference layers of every executor that declares
        'layer_for_inference', and the first worker appends one JSON line to
        the done file (plus a non-checkpoint train-progress record when
        pass_index > 0).

        Args:
            day: value substituted for {day} in the path templates.
            pass_index: pass number; < 1 means "base model".
            xbox_base_key: written as "key" in the done-file record.
            monitor_data: written as "monitor_data" in the done-file record.
        Returns:
            stdout_str. NOTE(review): this local is only placed into the
            CostPrinter parameter dicts and is never rebound here, so the
            returned value is always "".
        """
        stdout_str = ""
        xbox_patch_id = str(int(time.time()))
        util.rank0_print("begin save delta model")

        model_path = ""
        xbox_model_donefile = ""
        cost_printer = util.CostPrinter(util.print_cost, {
            'master': True,
            'log_format': 'save xbox model cost %s sec',
            'stdout': stdout_str
        })
        if pass_index < 1:
            save_mode = 2
            xbox_patch_id = xbox_base_key
            model_path = self._path_generator.generate_path('xbox_base',
                                                            {'day': day})
            xbox_model_donefile = self._path_generator.generate_path(
                'xbox_base_done', {'day': day})
        else:
            save_mode = 1
            model_path = self._path_generator.generate_path(
                'xbox_delta', {'day': day,
                               'pass_id': pass_index})
            xbox_model_donefile = self._path_generator.generate_path(
                'xbox_delta_done', {'day': day})
        fleet.save_persistables(None, model_path, mode=save_mode)
        cost_printer.done()

        cost_printer = util.CostPrinter(util.print_cost, {
            'master': True,
            'log_format': 'save cache model cost %s sec',
            'stdout': stdout_str
        })
        model_file_handler = fs.FileHandler(self.global_config['io']['afs'])
        # Stays None when the cache model is disabled, so the summary log
        # below is skipped instead of failing on an unassigned local.
        cache_save_num = None
        if self.global_config['save_cache_model']:
            cache_save_num = fleet.save_cache_model(
                None, model_path, mode=save_mode)
            model_file_handler.write(
                "file_prefix:part\npart_num:16\nkey_num:%d\n" % cache_save_num,
                model_path + '/000_cache/sparse_cache.meta', 'w')
        cost_printer.done()
        if cache_save_num is not None:
            util.rank0_print("save xbox cache model done, key_num=%s" %
                             cache_save_num)

        save_env_param = {'executor': self._exe, 'save_combine': True}
        cost_printer = util.CostPrinter(util.print_cost, {
            'master': True,
            'log_format': 'save dense model cost %s sec',
            'stdout': stdout_str
        })
        if fleet._role_maker.is_first_worker():
            for executor in self.global_config['executor']:
                if 'layer_for_inference' not in executor:
                    continue
                executor_name = executor['name']
                model = self._exector_context[executor_name]['model']
                save_env_param['inference_list'] = executor[
                    'layer_for_inference']
                save_env_param['scope'] = self._exector_context[executor_name][
                    'scope']
                model.dump_inference_param(save_env_param)
                # Copy each dumped layer file under <model_path>/dnn_plugin/.
                for dnn_layer in executor['layer_for_inference']:
                    model_file_handler.cp(dnn_layer['save_file_name'],
                                          model_path + '/dnn_plugin/' +
                                          dnn_layer['save_file_name'])
        # Every worker waits here until the first worker finished the dump.
        fleet._role_maker._barrier_worker()
        cost_printer.done()

        xbox_done_info = {
            "id": xbox_patch_id,
            "key": xbox_base_key,
            "ins_path": "",
            "ins_tag": "feasign",
            "partition_type": "2",
            "record_count": "111111",
            "monitor_data": monitor_data,
            "mpi_size": str(fleet.worker_num()),
            "input": model_path.rstrip("/") + "/000",
            "job_id": util.get_env_value("JOB_ID"),
            "job_name": util.get_env_value("JOB_NAME")
        }
        if fleet._role_maker.is_first_worker():
            model_file_handler.write(
                json.dumps(xbox_done_info) + "\n", xbox_model_donefile, 'a')
            if pass_index > 0:
                self._train_pass.save_train_progress(
                    day,
                    pass_index,
                    xbox_base_key,
                    model_path,
                    is_checkpoint=False)
        fleet._role_maker._barrier_worker()
        return stdout_str

X
xiexionghang 已提交
332
    def run_executor(self, executor_config, dataset, stdout_str):
        """Train one executor on `dataset` for the current pass.

        Runs train_from_dataset inside the executor's scope, reduces the
        local wall time (in minutes) to avg/min/max across workers, stores
        the max as the executor's 'cost', prints the global metrics and,
        when the pass needs an inference dump and the executor enables
        'dump_inference_model', saves the xbox model. All workers meet at a
        barrier before returning.

        Args:
            executor_config: executor entry; reads 'name', 'train_thread_num'
                and 'dump_inference_model'.
            dataset: dataset object handed to train_from_dataset.
            stdout_str: log text forwarded to print_global_metrics. It is
                only rebound locally; the caller's string is not changed.
        """
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        xbox_base_key = self._train_pass._base_key
        executor_name = executor_config['name']
        scope = self._exector_context[executor_name]['scope']
        model = self._exector_context[executor_name]['model']
        with fluid.scope_guard(scope):
            util.rank0_print("Begin " + executor_name + " pass")
            begin = time.time()
            program = model._build_param['model']['train_program']
            self._exe.train_from_dataset(
                program,
                dataset,
                scope,
                thread=executor_config['train_thread_num'],
                debug=self.global_config['debug'])
            end = time.time()
            local_cost = (end - begin) / 60.0
            avg_cost = worker_numric_avg(local_cost)
            min_cost = worker_numric_min(local_cost)
            max_cost = worker_numric_max(local_cost)
            util.rank0_print("avg train time %s mins, min %s mins, max %s mins"
                             % (avg_cost, min_cost, max_cost))
            self._exector_context[executor_name]['cost'] = max_cost

            monitor_data = ""
            # A callee cannot modify our local string in place (str is
            # immutable), so take the accumulated metrics from the return
            # value when one is provided; a None return keeps "".
            reported = self.print_global_metrics(scope, model, monitor_data,
                                                 stdout_str)
            if reported is not None:
                monitor_data = reported
            util.rank0_print("End " + executor_name + " pass")
            if self._train_pass.need_dump_inference(
                    pass_id) and executor_config['dump_inference_model']:
                stdout_str += self.save_xbox_model(day, pass_id, xbox_base_key,
                                                   monitor_data)
        fleet._role_maker._barrier_worker()
X
xiexionghang 已提交
368 369

    def startup(self, context):
        """Start the server loop, or prepare a worker for the first day.

        On a server: runs fleet.run_server() and sets status 'wait'.
        On a worker: creates the TimeTrainPass, loads the checkpoint model
        unless global_config['cold_start'] is truthy, saves a first xbox base
        model when global_config['save_first_base'] is truthy, then sets
        status 'begin_day'.

        Args:
            context: shared context dict; context['status'] is updated.
        """
        if fleet.is_server():
            fleet.run_server()
            context['status'] = 'wait'
            return

        stdout_str = ""
        self._train_pass = util.TimeTrainPass(self.global_config)

        warm_start = not self.global_config['cold_start']
        if warm_start:
            load_timer = util.CostPrinter(util.print_cost, {
                'master': True,
                'log_format': 'load model cost %s sec',
                'stdout': stdout_str
            })
            self.print_log("going to load model %s" %
                           self._train_pass._checkpoint_model_path,
                           {'master': True})
            # A disabled alternative used to load a single table here
            # (fleet.load_one_table) for the "reqi_changeslot" case.
            fleet.init_server(self._train_pass._checkpoint_model_path, mode=0)
            load_timer.done()

        if self.global_config['save_first_base']:
            self.print_log("save_first_base=True", {'master': True})
            self.print_log("going to save xbox base model",
                           {'master': True,
                            'stdout': stdout_str})
            # The new base key is the current unix time.
            self._train_pass._base_key = int(time.time())
            first_base_day = self._train_pass.date()
            stdout_str += self.save_xbox_model(first_base_day, 0,
                                               self._train_pass._base_key, "")
        context['status'] = 'begin_day'
T
tangwei 已提交
402

X
xiexionghang 已提交
403
    def begin_day(self, context):
        """Advance to the next pass and choose the following step.

        Sets context['is_exit'] when self._train_pass.next() is falsy. The
        status becomes 'end_day' when the current pass id equals
        max_pass_num_day(), otherwise 'train_pass'.

        Args:
            context: shared context dict; 'status' (and possibly 'is_exit')
                is written.
        """
        stdout_str = ""
        has_next_pass = self._train_pass.next()
        if not has_next_pass:
            context['is_exit'] = True
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        log_params = {'master': True, 'stdout': stdout_str}
        self.print_log("======== BEGIN DAY:%s ========" % day, log_params)
        day_finished = pass_id == self._train_pass.max_pass_num_day()
        context['status'] = 'end_day' if day_finished else 'train_pass'
T
tangwei 已提交
418

X
xiexionghang 已提交
419
    def end_day(self, context):
        """Shrink the tables and save the next day's base and batch models.

        Sets status 'begin_day', shrinks the sparse table and every
        executor's model (decay = optimizer 'dense_decay_rate'), then saves
        an xbox base model and a batch model for date(delta_day=1) under a
        fresh base key (current unix time) and stores that key on the train
        pass. Ends with a worker barrier.

        Args:
            context: shared context dict; context['status'] is updated.
        """
        # Read as in the original flow; the values are not used below.
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        new_base_key = int(time.time())
        context['status'] = 'begin_day'

        util.rank0_print("shrink table")
        shrink_timer = util.CostPrinter(util.print_cost, {
            'master': True,
            'log_format': 'shrink table done, cost %s sec'
        })
        fleet.shrink_sparse_table()
        for executor_name in self._exector_context:
            entry = self._exector_context[executor_name]
            entry['model'].shrink({
                'scope': self._exector_context[executor_name]['scope'],
                'decay': self.global_config['optimizer']['dense_decay_rate']
            })
        shrink_timer.done()

        next_date = self._train_pass.date(delta_day=1)
        util.rank0_print("going to save xbox base model")
        self.save_xbox_model(next_date, 0, new_base_key, "")
        util.rank0_print("going to save batch model")
        self.save_model(next_date, 0, new_base_key)
        self._train_pass._base_key = new_base_key
        fleet._role_maker._barrier_worker()
X
xiexionghang 已提交
447 448

    def train_pass(self, context):
        """Run one training pass: load, shuffle, train, release, checkpoint.

        Steps:
          1. Load every dataset for the current pass window and barrier.
          2. Global-shuffle the loaded datasets and barrier.
          3. When global_config['prefetch_data'] is truthy, start preloading
             the window that begins one pass interval later.
          4. Run every executor on its dataset, then release dataset memory.
          5. Save a checkpoint when the pass is a checkpoint pass.
          6. Log the timing summary and barrier.
          7. Status becomes 'end_day' when the pass id equals
             max_pass_num_day(); otherwise the train pass is advanced with
             next() and context['is_exit'] is set when that returns a falsy
             value (status is left unchanged in that branch).

        Args:
            context: shared context dict; 'status' or 'is_exit' may be
                written.
        """
        stdout_str = ""
        day = self._train_pass.date()
        pass_id = self._train_pass._pass_id
        base_key = self._train_pass._base_key
        pass_time = self._train_pass._current_train_time.strftime("%Y%m%d%H%M")
        self.print_log("    ==== begin delta:%s ========" % pass_id,
                       {'master': True,
                        'stdout': stdout_str})
        train_begin_time = time.time()

        cost_printer = util.CostPrinter(util.print_cost, {
            'master': True,
            'log_format': 'load into memory done, cost %s sec',
            'stdout': stdout_str
        })
        current_dataset = {}
        for name in self._dataset:
            current_dataset[name] = self._dataset[name].load_dataset({
                'node_num': fleet.worker_num(),
                'node_idx': fleet.worker_index(),
                'begin_time': pass_time,
                'time_window_min': self._train_pass._interval_per_pass
            })
        fleet._role_maker._barrier_worker()
        cost_printer.done()

        util.rank0_print("going to global shuffle")
        cost_printer = util.CostPrinter(util.print_cost, {
            'master': True,
            'stdout': stdout_str,
            'log_format': 'global shuffle done, cost %s sec'
        })
        for name in current_dataset:
            current_dataset[name].global_shuffle(
                fleet, self.global_config['dataset']['shuffle_thread'])
        cost_printer.done()
        fleet._role_maker._barrier_worker()

        if self.global_config['prefetch_data']:
            # The next window starts one pass interval after the current one.
            next_pass_time = (
                self._train_pass._current_train_time + datetime.timedelta(
                    minutes=self._train_pass._interval_per_pass)
            ).strftime("%Y%m%d%H%M")
            for name in self._dataset:
                self._dataset[name].preload_dataset({
                    'node_num': fleet.worker_num(),
                    'node_idx': fleet.worker_index(),
                    'begin_time': next_pass_time,
                    'time_window_min': self._train_pass._interval_per_pass
                })

        fleet._role_maker._barrier_worker()
        pure_train_begin = time.time()
        for executor in self.global_config['executor']:
            self.run_executor(executor,
                              current_dataset[executor['dataset_name']],
                              stdout_str)
        cost_printer = util.CostPrinter(util.print_cost, {
            'master': True,
            'log_format': 'release_memory cost %s sec'
        })
        for name in current_dataset:
            current_dataset[name].release_memory()
        # Finish the timer explicitly, like every other CostPrinter in this
        # file, so the release cost is reported right after the release.
        cost_printer.done()
        pure_train_cost = time.time() - pure_train_begin

        if self._train_pass.is_checkpoint_pass(pass_id):
            self.save_model(day, pass_id, base_key)

        train_end_time = time.time()
        train_cost = train_end_time - train_begin_time
        other_cost = train_cost - pure_train_cost
        log_str = "finished train day %s pass %s time cost:%s sec job time cost:" % (
            day, pass_id, train_cost)
        for executor in self._exector_context:
            log_str += '[' + executor + ':' + str(self._exector_context[
                executor]['cost']) + ']'
        log_str += '[other_cost:' + str(other_cost) + ']'
        util.rank0_print(log_str)
        stdout_str += util.now_time_str() + log_str
        sys.stdout.write(stdout_str)
        fleet._role_maker._barrier_worker()
        stdout_str = ""
        if pass_id == self._train_pass.max_pass_num_day():
            context['status'] = 'end_day'
            return
        elif not self._train_pass.next():
            context['is_exit'] = True