#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""common ML train and eval procedure"""
from __future__ import print_function
from __future__ import absolute_import
from __future__ import unicode_literals

import os
import itertools
import six
import inspect
from collections import namedtuple
from contextlib import contextmanager
from six.moves import zip, map
import logging
from time import time

import paddle.fluid as F
import paddle.fluid.layers as L

from propeller.data.functional import unflatten
from propeller.types import RunMode, StopException, SummaryRecord
from propeller.types import ModelSpec, InferenceSpec, ProgramPair, RunConfig
from propeller.paddle import summary, collection
from propeller.paddle.data.functional import Dataset
from propeller.paddle.train import distribution
from propeller.train.model import Model
from propeller.paddle.train.monitored_executor import Saver
from propeller.paddle.train import hooks, metrics

from propeller.paddle.train.monitored_executor import MonitoredExecutor

log = logging.getLogger(__name__)

__all__ = ['train_and_eval', 'Learner']


def _get_summary_writer(path):
    summary_writer = None
    try:
        #from tensorboardX import SummaryWriter
        from visualdl import LogWriter as SummaryWriter
        if distribution.status.is_master:
            summary_writer = SummaryWriter(path)
    except ImportError:
        log.warning('VisualDL not installed, will not write visualization logs')
    return summary_writer


def _get_one_place():
    return F.cuda_places()[0] if F.core.is_compiled_with_cuda(
    ) else F.cpu_places()[0]


def _log_eval_result(name, eval_result, swriter, state):
    log.debug(eval_result)
    printable = []
    for n, val in six.iteritems(eval_result):
        #assert val.shape == (), 'metrics eval use float'
        printable.append('{}:{}'.format(n, val))
        if swriter is not None:
            swriter.add_scalar(n, val, state.gstep)
            log.debug('write to tensorboard %s' % swriter.logdir)

    if printable:
        log.info('[Eval:%s]:' % name + '\t'.join(printable))


def _build_net(model_fn, features, mode, params, run_config):
    model_spec = model_fn(
        features=features, mode=mode, params=params, run_config=run_config)

    if mode == RunMode.TRAIN or mode == RunMode.EVAL:
        if not isinstance(model_spec.loss, F.framework.Variable):
            raise ValueError('model_spec.loss should be Variable, got %s' %
                             repr(model_spec.loss))
        if not (model_spec.loss.shape == () or model_spec.loss.shape == (1, )):
            raise ValueError('expect scalar loss, got %s' %
                             repr(model_spec.loss.shape))

    if mode == RunMode.TRAIN:
        pass
    elif mode == RunMode.EVAL:
        if not isinstance(model_spec.metrics, dict):
            raise ValueError('model_spec.metrics should be dict, got %s' %
                             repr(model_spec.metrics))
    elif mode == RunMode.PREDICT:
        if not isinstance(model_spec.predictions, (list, tuple)):
            raise ValueError('model_spec.predictions should be list, got %s' %
                             repr(model_spec.predictions))
    else:
        raise ValueError('unknown mode %s' % mode)
    return model_spec


class Learner(object):
    """A Learner can train / eval / predict on a Dataset"""

    def __init__(self,
                 model_class_or_model_fn,
                 run_config,
                 params=None,
                 warm_start_setting=None):
        """
        model_class_or_model_fn (callable|propeller.train.Model): `model_class_or_model_fn` can be specified in 2 ways:
            1. subclass of propeller.train.Model which implements:
                1. __init__           (hyper_param, mode, run_config)
                2. forward            (features) => (prediction)
                3. backward           (loss) => None
                4. loss               (prediction) => (loss)
                5. metrics (optional) (prediction) => (dict of propeller.Metrics)

            2. a model_fn that takes the following args:
                1. features
                2. params
                3. mode
                4. run_config(optional)
               and returns a `propeller.ModelSpec`

        params: any python object, will be passed to your `model_fn` or `propeller.train.Model`
        run_config (propeller.RunConfig): run_config.max_steps should not be None.
        warm_start_setting (propeller.WarmStartSetting): Optional. warm start variables will overwrite model variables.
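
        Example (a hedged sketch; the feature names, network and optimizer below
        are illustrative only, and `F` / `L` refer to `paddle.fluid` /
        `paddle.fluid.layers` as imported in this module):

            def model_fn(features, mode, params, run_config):
                logits = L.fc(features['input_ids'], size=2)
                loss = L.reduce_mean(
                    L.softmax_with_cross_entropy(logits, features['labels']))
                if mode == propeller.RunMode.TRAIN:
                    F.optimizer.SGD(learning_rate=1e-3).minimize(loss)
                # an EVAL ModelSpec would additionally carry a `metrics` dict
                return propeller.ModelSpec(loss=loss, predictions=[logits], mode=mode)

            learner = Learner(model_fn, run_config, params=params)
            learner.train(train_dataset)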
        """
        if run_config.model_dir is None:
            raise ValueError('model_dir should be specified in run_config')

        if inspect.isfunction(model_class_or_model_fn):
            _model_fn = model_class_or_model_fn
        elif issubclass(model_class_or_model_fn, Model):
            _model_fn = _build_model_fn(model_class_or_model_fn)
        else:
            raise ValueError('unknown model %s' % model_class_or_model_fn)

        self.model_fn = _model_fn
        self.params = params
        self.run_config = run_config
        self.warm_start_setting = warm_start_setting

    def _build_for_train(self, train_dataset):
        train_dataset.name = 'train'
        train_program = F.Program()
        startup_prog = F.Program()
        with F.program_guard(train_program, startup_prog):
            with collection.Collections() as collections:
                log.info('Building Train Graph...')
                fea = train_dataset.features()
                fea = unflatten(fea, train_dataset.data_schema)
                model_spec = _build_net(self.model_fn, fea, RunMode.TRAIN,
                                        self.params, self.run_config)
                log.info('Building Train Graph: Done')

            scalars = collections.get(collection.Key.SUMMARY_SCALAR)
            histograms = collections.get(collection.Key.SUMMARY_HISTOGRAM)
            skip_optimize_ops = collections.get(collection.Key.SKIP_OPTIMIZE)
            skip_opt = set()
            if skip_optimize_ops is not None:
                skip_opt |= set(skip_optimize_ops)
            if scalars is not None:
                skip_opt |= {t for _, t in scalars}
            if histograms is not None:
                skip_opt |= {t for _, t in histograms}
            skip_opt = list(skip_opt)
        log.info(
            'Train with: \n> Run_config: %s\n> Params: %s\n> Train_model_spec: %s\n'
            % (repr(self.run_config), repr(self.params), repr(model_spec)))

        summary_record = SummaryRecord(
            scalar=collections.get(collection.Key.SUMMARY_SCALAR),
            histogram=collections.get(collection.Key.SUMMARY_HISTOGRAM), )
        return ProgramPair(
            train_program=train_program,
            startup_program=startup_prog), model_spec, summary_record

    def _build_for_eval(self, ds):
        ds.name = 'eval'
        program = F.Program()
        startup_prog = F.Program()
        with F.program_guard(program, startup_prog):
            #share var with Train net
            log.info('Building Eval Graph')
            fea = ds.features()
            fea = unflatten(fea, ds.data_schema)
            model_spec = _build_net(self.model_fn, fea, RunMode.EVAL,
                                    self.params, self.run_config)
            log.info('Done')
        #program = program.clone(for_test=True)
        # program check
        optimizer_ops = {'sgd', 'adam', 'adagrad'}
        for op in program.global_block().ops:
            if op.type == 'dropout':
                op._set_attr('is_test', True)
            if op.type == 'batch_norm':
                op._set_attr('is_test', True)
            if op.type in optimizer_ops:
                raise RuntimeError('Found optimizer op in eval graph, op: %s' %
                                   repr(op))

        log.info(
            'Eval with: \n> Run_config: %s\n> Params: %s\n> Eval_model_spec: %s\n'
            % (repr(self.run_config), repr(self.params), repr(model_spec)))
        return ProgramPair(
            train_program=program, startup_program=startup_prog), model_spec

    def _build_for_predict(self, ds):
        ds.name = 'predict'
        program = F.Program()
        startup_prog = F.Program()
        with F.program_guard(program, startup_prog):
            #share var with Train net
            log.info('Building Predict Graph')
            fea = ds.features()
            fea = unflatten(fea, ds.data_schema)
            model_spec = _build_net(self.model_fn, fea, RunMode.PREDICT,
                                    self.params, self.run_config)
            log.info('Done')

        optimizer_ops = {'sgd', 'adam', 'adagrad'}
        for op in program.global_block().ops:
            if op.type == 'dropout':
                op._set_attr('is_test', True)
            if op.type == 'batch_norm':
                op._set_attr('is_test', True)
            if op.type in optimizer_ops:
                raise RuntimeError('Found optimizer op in predict graph, op: %s' %
                                   repr(op))
        #program = program.clone(for_test=True)

        log.info(
            'Predict with: \n> Run_config: %s\n> Params: %s\n> Predict_model_spec: %s\n'
            % (repr(self.run_config), repr(self.params), repr(model_spec)))
        return ProgramPair(
            train_program=program, startup_program=startup_prog), model_spec

    def train(self, train_ds, train_hooks=[]):
        """train on a `Dataset`"""
        if not isinstance(train_ds, Dataset):
            raise ValueError('expect dataset to be instance of Dataset, got %s'
                             % repr(train_ds))

        train_program, model_spec, summary_record = self._build_for_train(
            train_ds)
        train_run_hooks = [
            hooks.StopAtStepHook(self.run_config.max_steps,
                                 self.run_config.run_steps),
            hooks.LoggingHook(
                model_spec.loss,
                summary_record=summary_record,
                summary_writer=_get_summary_writer(
                    os.path.join(self.run_config.model_dir, 'train_history')),
                per_step=self.run_config.log_steps,
                prefix=self.run_config.log_prefix or 'training',
                skip_step=self.run_config.skip_steps),
        ]
        if model_spec.train_hooks is not None:
            train_run_hooks.extend(model_spec.train_hooks)
        train_run_hooks.extend(train_hooks)

        train_executor = F.Executor(_get_one_place())

        mon_exe = MonitoredExecutor(
            train_executor,
            train_program,
            loss=model_spec.loss,
            run_config=self.run_config,
            run_hooks=train_run_hooks,
            warm_start_setting=self.warm_start_setting)

        distribution.init_distribuition_env(
            train_program)  # only initializes distributed training when configured (e.g. PROPELLER_DISCONFIG is set)
        mon_exe.init_or_restore_variables()
        if distribution.status.is_master:
            mon_exe._hooks.append(
                hooks.CheckpointSaverHook(
                    mon_exe._saver,
                    per_step=mon_exe._save_steps,
                    skip_step=mon_exe._skip_steps, ))

        try:
            with mon_exe:
                for data in train_ds.start():
                    mon_exe.run(feed=data)
        except (StopException, F.core.EOFException) as e:
            pass

        return mon_exe.result

    def evaluate(self, eval_dataset, eval_hooks=[]):
        """eval on a `Dataset`"""
        if not isinstance(eval_dataset, Dataset):
            raise ValueError('expect dataset to be instance of Dataset, got %s'
                             % repr(eval_dataset))
        program, model_spec = self._build_for_eval(eval_dataset)
        single_card_place = _get_one_place()
        eval_executor = F.Executor(single_card_place)

        eval_run_hooks = [
            hooks.StopAtStepHook(self.run_config.eval_max_steps,
                                 self.run_config.eval_max_steps),
            hooks.EvalHook(model_spec.metrics, )
        ]

        if model_spec.eval_hooks is not None:
            eval_run_hooks.extend(model_spec.eval_hooks)
        eval_run_hooks.extend(eval_hooks)

        mon_exe = MonitoredExecutor(
            eval_executor,
            program,
            loss=model_spec.loss,
            run_config=self.run_config,
            run_hooks=eval_run_hooks,
            warm_start_setting=self.warm_start_setting)
        distribution.init_distribuition_env(
            program)  # only initializes distributed training when configured (e.g. PROPELLER_DISCONFIG is set)
        mon_exe.init_or_restore_variables()

        try:
            with mon_exe:
                for data in eval_dataset.start():
                    mon_exe.run(feed=data)
        except (StopException, F.core.EOFException) as e:
            pass

        _, eval_result = mon_exe.result

        summary_writer = _get_summary_writer(
            os.path.join(self.run_config.model_dir, 'eval_history'))
        _log_eval_result('eval', eval_result, summary_writer, mon_exe.state)

        return eval_result

    def predict(self,
                predict_dataset,
                ckpt=-1,
                ckpt_path=None,
                steps=-1,
                split_batch=True):
        """
        Perform prediction.
        Will call `model_fn` and initialize the user-specified model in `propeller.RunMode.PREDICT` mode.

        Args:
            predict_dataset (propeller.data.Dataset): should not `shuffle` or `repeat`
            steps (int): steps to predict, if None is specified,
                will stop when `StopException` is raised in `predict_dataset`
            ckpt_path (None|str): Path of a specific checkpoint to predict.
                If None, the latest checkpoint in model_dir is used.
                If there are no checkpoints in model_dir,
                prediction is run with newly initialized Variables instead of ones restored from checkpoint.
            ckpt (int): deprecated argument
            split_batch (bool): if True, prediction of each example in a batch is returned.

        Yields:
            Evaluated values of predictions tensors.
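
        Example (a hedged sketch; `learner` and `predict_ds` are illustrative
        placeholders for a constructed Learner and a non-shuffled Dataset):

            for prediction in learner.predict(predict_ds, split_batch=True):
                print(prediction)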

        """
        if not isinstance(predict_dataset, Dataset):
            raise ValueError('expect dataset to be instance of Dataset, got %s'
                             % repr(predict_dataset))

        program, model_spec = self._build_for_predict(predict_dataset)
        single_card_place = _get_one_place()
        executor = F.Executor(single_card_place)
        pred_run_config = RunConfig(
            run_steps=steps if steps == -1 else None,
            model_dir=self.run_config.model_dir)
        mon_exe = MonitoredExecutor(
            executor,
            program,
            run_config=pred_run_config,
            warm_start_setting=self.warm_start_setting, )
        mon_exe.init_or_restore_variables(ckpt)
        if ckpt_path is not None:
            if not os.path.exists(ckpt_path):
                raise RuntimeError('ckpt path not found: %s' % ckpt_path)
            log.info('Loading ckpt path for prediction: %s' % ckpt_path)
            mon_exe._saver._load_program(ckpt_path)
        try:
            with mon_exe:
                log.info('Running predict from dir: %s' % repr(mon_exe.state))
                single_card_place = _get_one_place()
                for data in predict_dataset.start(places=[single_card_place]):
                    res = mon_exe.run(fetch_list=model_spec.predictions,
                                      feed=data)
                    if split_batch:
                        res = map(lambda i: i.tolist(), res)
                        res = zip(*res)  # transpose
                        for r in res:
                            yield r
                    else:
                        yield list(map(lambda i: i.tolist(), res))
        except (StopException, F.core.EOFException) as e:
            pass


def train_and_eval(_placeholder=None,
                   model_class_or_model_fn=None,
                   params=None,
                   run_config=None,
                   train_dataset=None,
                   eval_dataset=None,
                   warm_start_setting=None,
                   train_hooks=[],
                   eval_hooks=[],
                   exporters=[]):
    """
    Perform the train and evaluate procedure.
    Will call `model_fn` and initialize the user-specified model in `propeller.RunMode.TRAIN` and `propeller.RunMode.EVAL` modes.

    Args:
        model_class_or_model_fn(callable|propeller.train.Model): `model_class_or_model_fn` can be specified in 2 ways:
            1. subclass of propeller.train.Model
            2. a model_fn that takes the following args: 1. features; 2. params; 3. mode; 4. run_config(optional)
               and returns a `propeller.ModelSpec`

        params: any python object, will be passed to your `model_fn` or `propeller.train.Model`
        run_config (propeller.RunConfig): run_config.max_steps should not be None.
        train_dataset (propeller.paddle.data.Dataset): training will stop if global_step > run_config.max_steps.
        eval_dataset (propeller.paddle.data.Dataset|dict): Optional, if a dict of propeller.data.Dataset is specified,
            evaluation will be performed on every evaluation set and the results reported.
        warm_start_setting (propeller.WarmStartSetting): Optional. warm start variables will overwrite model variables.
        train_hooks (list of propeller.paddle.train.RunHook): Optional.
        eval_hooks (list of propeller.paddle.train.RunHook): Optional.
        exporters (list of propeller.paddle.train.Exporter): Optional.
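
    Example (a hedged sketch; `model_fn`, `hparams` and the datasets are
    illustrative placeholders, not part of this API):

        train_and_eval(
            model_class_or_model_fn=model_fn,
            params=hparams,
            run_config=RunConfig(
                model_dir='./output',
                max_steps=100000,
                log_steps=100,
                eval_steps=1000),
            train_dataset=train_ds,
            eval_dataset={'dev': dev_ds})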
    """
    if _placeholder is not None:
        raise ValueError('specify keyword args to this function')
    if model_class_or_model_fn is None or params is None or run_config is None or train_dataset is None:
        raise ValueError(
            'some argument is None: model_class_or_model_fn:%s params:%s run_config:%s train_dataset:%s'
            % (model_class_or_model_fn, params, run_config, train_dataset))

    # init distribution env if environment variable PROPELLER_DISCONFIG is set
    if train_dataset is None:
        raise ValueError('train dataset not specified')

    if eval_dataset is None:
        raise ValueError('eval dataset not specified')

    if not isinstance(eval_dataset, (dict, Dataset)):
        raise ValueError(
            'eval dataset should be a propeller.Dataset or a dict of them, got: %s'
            % eval_dataset)
    if isinstance(eval_dataset, Dataset):
        eval_dataset = {'eval': eval_dataset}
    ds_list = list(eval_dataset.values())
    for ds in ds_list:
        ds.name = 'eval'
    first = ds_list[0]
    for d in ds_list[1:]:
        if not first.__eq__(d):
            raise ValueError(
                'eval dataset has different output_shapes or types: %s' %
                repr(ds_list))

    est = Learner(
        model_class_or_model_fn,
        run_config,
        params,
        warm_start_setting=warm_start_setting)

    class _EvalHookOnTrainLoop(hooks.RunHook):
        def __init__(self):
            self.program, self.model_spec = est._build_for_eval(
                list(eval_dataset.values())[
                    0])  #eval_datasets must have same output shapes
            self.summary_writers = {
                ds_name: _get_summary_writer(
                    os.path.join(
                        os.path.join(run_config.model_dir, 'eval_history'),
                        ds_name))
                for ds_name in eval_dataset
            }

        def after_run(self, _, state):
            """doc"""
            if state.gstep > run_config.skip_steps and state.gstep % run_config.eval_steps == 0:
                eval_results = {}
                for name, ds in six.iteritems(eval_dataset):
                    ehooks = [
                        hooks.StopAtStepHook(est.run_config.eval_max_steps,
                                             est.run_config.eval_max_steps),
                        hooks.EvalHook(
                            self.model_spec.metrics,
                            summary_writer=self.summary_writers[name], )
                    ]
                    single_card_place = _get_one_place()
                    eval_executor = F.Executor(single_card_place)
                    mon_exe = MonitoredExecutor(
                        eval_executor,
                        self.program,
                        run_config=est.run_config,
                        run_hooks=ehooks + eval_hooks)
                    try:
                        with mon_exe:
                            for data in ds.start(places=[single_card_place]):
                                mon_exe.run(feed=data)
                    except (StopException, F.core.EOFException) as e:
                        pass
                    hook_results = mon_exe.result
                    eval_res = hook_results[
                        1]  # hook_results:  [StopAtStepHook, EvalHook, ...]
                    eval_results[name] = eval_res
                    _log_eval_result(name, eval_res,
                                     self.summary_writers[name], state)

                if distribution.status.is_master:
                    for exporter in exporters:
                        exporter.export(eval_executor, self.program,
                                        self.model_spec, eval_results, state)
            else:
                eval_results = {}
            return eval_results

        def after_train(self, _, __):
            for _, w in six.iteritems(self.summary_writers):
                if w:
                    w.close()

    # build a new list so the caller's (or the default) `train_hooks` list is not mutated
    train_hooks = list(train_hooks) + [_EvalHookOnTrainLoop()]
    res = est.train(train_dataset, train_hooks=train_hooks)
    return res


def _build_model_fn(model_class):
    def _model_fn(features, mode, params, run_config):
        if mode != RunMode.PREDICT:
            if isinstance(features, list) or isinstance(features, tuple):
                fea, label = features[:-1], features[-1]
            elif isinstance(features, dict):
                label = {"labels": features["labels"]}
                del features["labels"]
                fea = features
            else:
                raise TypeError
        else:
            fea = features

        model = model_class(params, mode, run_config=run_config)
        pred = model.forward(fea)
        if isinstance(pred, F.framework.Variable):
            prediction = [pred]
        else:
            prediction = pred
        if mode == RunMode.TRAIN:
            loss = model.loss(pred, label)
            model.backward(loss)
            return ModelSpec(loss=loss, predictions=prediction, mode=mode)
        elif mode == RunMode.EVAL:
            loss = model.loss(pred, label)
            me = model.metrics(pred, label)

            inf_spec = InferenceSpec(inputs=fea, outputs=prediction)
            if 'loss' not in me:
                me['loss'] = metrics.Mean(loss)
            return ModelSpec(
                loss=loss,
                predictions=prediction,
                metrics=me,
                mode=mode,
                inference_spec=inf_spec)
        elif mode == RunMode.PREDICT:
            inf_spec = InferenceSpec(inputs=fea, outputs=prediction)
            return ModelSpec(
                predictions=prediction, mode=mode, inference_spec=inf_spec)
        else:
            raise RuntimeError('unknown run mode %s' % mode)

    return _model_fn