# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import inspect
import os
import pickle
import numpy as np
import itertools
try:  # `Iterable` lives in collections.abc on Python 3.3+
    from collections.abc import Iterable
except ImportError:  # Python 2 fallback
    from collections import Iterable
from collections import OrderedDict

from paddle import fluid
from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable

from metrics import Metric
from callbacks import config_callbacks

__all__ = ['Model', 'Loss', 'CrossEntropy', 'Input']


def to_list(value):
    if value is None:
        return value
    if isinstance(value, (list, tuple)):
        return value
    return [value]


def to_numpy(var):
    assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable"
    if isinstance(var, fluid.core.VarBase):
        return var.numpy()
    t = global_scope().find_var(var.name).get_tensor()
    return np.array(t)


def flatten_list(l):
    assert isinstance(l, list), "not a list"
    outl = []
    splits = []
    for sl in l:
        assert isinstance(sl, list), "sub content not a list"
        splits.append(len(sl))
        outl += sl
    return outl, splits


def restore_flatten_list(l, splits):
    outl = []
    for split in splits:
        assert len(l) >= split, "list length invalid"
        sl, l = l[:split], l[split:]
        outl.append(sl)
    return outl
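
# Illustrative example (assumed, not from the original file):
#   flatten_list([[1, 2], [3]])             ->  ([1, 2, 3], [2, 1])
#   restore_flatten_list([1, 2, 3], [2, 1]) ->  [[1, 2], [3]]
# `_run` below relies on this pair to fetch all metric states in a single
# executor run and then split the results back per metric.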


def extract_args(func):
    if hasattr(inspect, 'getfullargspec'):
        return inspect.getfullargspec(func)[0]
    else:
        return inspect.getargspec(func)[0]


class Input(fluid.dygraph.Layer):
    def __init__(self, shape=None, dtype=None, name=None):
        super(Input, self).__init__()
        self.shape = shape
        self.dtype = dtype
        self.name = name

    def forward(self):
        return fluid.data(self.name, shape=self.shape, dtype=self.dtype)
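
    # Usage note (illustrative): Input(shape=[None, 784], dtype='float32',
    # name='image') declares a feed slot; `forward` creates the matching
    # `fluid.data` variable when the static graph is built.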


class Loss(object):
    def __init__(self, average=True):
        super(Loss, self).__init__()
        self.average = average

    def forward(self, outputs, labels):
        raise NotImplementedError()

    def __call__(self, outputs, labels):
        labels = to_list(labels)
        if in_dygraph_mode():
            labels = [to_variable(l) for l in labels]
        losses = to_list(self.forward(to_list(outputs), labels))
        if self.average:
            losses = [fluid.layers.reduce_mean(l) for l in losses]
        else:
            losses = [fluid.layers.reduce_sum(l) for l in losses]
        return losses


class CrossEntropy(Loss):
    def __init__(self, average=True):
        super(CrossEntropy, self).__init__(average)

    def forward(self, outputs, labels):
        return [
            fluid.layers.cross_entropy(o, l) for o, l in zip(outputs, labels)
        ]
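

# A custom loss only needs to subclass `Loss` and implement `forward`; the
# mean/sum reduction is applied by `Loss.__call__`. A minimal sketch
# (illustrative only, not part of the original file):
#
#     class MeanSquaredError(Loss):
#         def forward(self, outputs, labels):
#             return [fluid.layers.square_error_cost(o, l)
#                     for o, l in zip(outputs, labels)]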


class StaticGraphAdapter(object):
    def __init__(self, model):
        super(StaticGraphAdapter, self).__init__()
        self.model = model
        # with `_build_once` gone, parameters are now created in `__init__`
        # so we need to keep track of the parameters already created
        self._startup_prog = fluid.default_startup_program()
        self._orig_prog = fluid.default_main_program()

        self._label_vars = {}  # label variables
        self._input_vars = {}  # input variables
        self._endpoints = {}
        self._loss_endpoint = None
        self._executor = None
        self._progs = {}
        self._compiled_progs = {}

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    def train(self, inputs, labels=None):
        assert self.model._optimizer, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'train'
        return self._run(inputs, labels)

    def eval(self, inputs, labels=None):
        self.mode = 'eval'
        return self._run(inputs, labels)

    def test(self, inputs):
        self.mode = 'test'
        return self._run(inputs, None)

    def parameters(self, *args, **kwargs):
        return None

    def save(self, path):
        def _save(state, path):
            if not state:
                return
            state = {
                k: to_numpy(v) if isinstance(v, Variable) else v
                for k, v in state.items()
            }
            with open(path, 'wb') as f:
                pickle.dump(state, f)

        base = os.path.basename(path)
        assert base != "", "path should be of 'dirname/filename' format"
        dir_name = os.path.dirname(path)
        if dir_name and not os.path.exists(dir_name):
            os.makedirs(dir_name)
        param_path = path + ".pdparams"
        _save(self.model.state_dict(), param_path)
        prog = self._progs.get('train', None)
        if prog is None or self.model._optimizer is None:
            return
        # XXX `optimizer.state_dict()` only works in dygraph mode
        optim_path = path + ".pdopt"
        optim = {
            p.name: p
            for p in filter(is_belong_to_optimizer, prog.list_vars())
        }
        if not optim:
            return

        _save(optim, optim_path)

    def load(self, path):
        def _load(path):
            if not os.path.exists(path):
                return
            with open(path, 'rb') as f:
                return pickle.load(f)

        param_path = path + ".pdparams"
        param_state = _load(param_path)
        assert param_state, "failed to load parameters, please check path"

        if self._executor is None:
            executor = fluid.Executor(fluid.CPUPlace())._default_executor
        else:
            executor = self._executor._default_executor

        fluid.core._create_loaded_parameter(
            list(self.model.state_dict().values()), global_scope(), executor)

        for key, var in self.model.state_dict().items():
            assert key in param_state, \
                "parameter [{}] is not found in model file [{}]".format(
                    key, param_path)
            self._set_var(var, param_state[key])

        # FIXME what if a different optimizer is used?
        if not self.model._optimizer:
            return
        optim_path = path + ".pdopt"
        optim_state = _load(optim_path)
        if optim_state is None:
            return

        self._load_optimizer(optim_state, executor)

    def _load_optimizer(self, state, executor):
        prog = self._progs.get('train', None)
        optim = list(filter(is_belong_to_optimizer, prog.list_vars()))
        if not optim:
            return

        fluid.core._create_loaded_parameter(optim, global_scope(), executor)

        converted_state = dict(state)
        for var in optim:
            if var.name in ["@LR_DECAY_COUNTER@", "global_step"]:
                # When using a learning rate scheduler, dygraph saves the
                # global step var under the name "global_step", while
                # static-graph has a state var named "@LR_DECAY_COUNTER@".
                # NOTE: the global_step saved by dygraph is 1 larger than the
                # static-graph counter, since the step is incremented at
                # different times in the two modes.
                state_val = (
                    np.array(converted_state.pop("global_step")) - 1
                ) if "global_step" in converted_state else converted_state.pop(
                    "@LR_DECAY_COUNTER@", None)
                if state_val is not None:
                    converted_state[var.name] = state_val
            elif var.name.startswith("learning_rate_"):
                # When using a static learning rate, static-graph makes it a
                # persistable var named 'unique_name.generate("learning_rate")';
                # dygraph, however, does not save it.
                if var.name not in state: continue
            else:
                # moment and other accumulators
                if var.name not in converted_state:
                    # try to convert from dygraph name
                    opt_name = self.model._optimizer._name
                    opt_cls_name = self.model._optimizer.__class__.__name__
                    opt_unq_name = None
                    for name in self.model._optimizer._accumulators.keys():
                        accum_name = name if opt_name is None else name[len(
                            opt_name) + 1:]
                        for param_name, state_var in self.model._optimizer._accumulators[
                                name].items():
                            if opt_unq_name is None:
                                # can not infer out the exact unique(opt_name),
                                # thus try to extract rather than generate
                                for state_key in sorted(
                                        state.keys(),
                                        key=lambda x: len(x),
                                        reverse=True):
                                    prefix = param_name + "_" + (
                                        opt_cls_name if opt_name is None else
                                        opt_name) + "_"
                                    if state_key.startswith(prefix):
                                        prefix_offset = state_key[len(
                                            prefix):].find("_") + len(prefix)
                                        opt_unq_name = state_key[len(
                                            param_name + "_"):prefix_offset]
                                        # TODO: assert
                                        # assert opt_unq_name is None
                                    # gen(param.name + "_" + gen(opt_name) + "_" + accum_name)
                                    # always end with "_0" since the unique optimizer._name
                            dy_state_name = (param_name + "_" + opt_unq_name +
                                             "_" + accum_name + "_0")
                            converted_state[
                                state_var.name] = converted_state.pop(
                                    dy_state_name)

            assert var.name in converted_state, \
                "variable [{}] is not in optimizer state file".format(var.name)
            self._set_var(var, converted_state[var.name])

    def _set_var(self, var, ndarray):
        t = global_scope().find_var(var.name).get_tensor()
        p = t._place()
        if p.is_cpu_place():
            place = fluid.CPUPlace()
        elif p.is_cuda_pinned_place():
            place = fluid.CUDAPinnedPlace()
        else:
            p = fluid.core.Place()
            p.set_place(t._place())
            place = fluid.CUDAPlace(p.gpu_device_id())

        t.set(ndarray, place)

    def _run(self, inputs, labels=None):
        compiled_prog = self._compiled_progs.get(self.mode, None)
        assert compiled_prog, \
            "Model is not ready, please call `model.prepare()` first"

        inputs = to_list(inputs)
        if labels is not None:
            labels = to_list(labels)
        assert len(inputs) == len(self._input_vars[self.mode]), \
            "number of inputs" \
            + " does not match number of arguments of `forward` method"

        feed = {}
        input_names = [v.name for v in self._input_vars[self.mode]]
        for idx, n in enumerate(input_names):
            # train and test may take different arguments
            if inputs[idx] is not None:
                feed[n] = inputs[idx]
        if labels is not None:
            for idx, v in enumerate(self._label_vars[self.mode]):
                feed[v.name] = labels[idx]

        endpoints = self._endpoints[self.mode]
        if self.mode == 'test':
            fetch_list = endpoints['output']
        else:
            metric_list, metric_splits = flatten_list(endpoints['metric'])
            fetch_list = endpoints['loss'] + metric_list
            num_loss = len(endpoints['loss'])
        rets = self._executor.run(compiled_prog,
                                  feed=feed,
                                  fetch_list=fetch_list,
                                  return_numpy=False)
        # LoDTensor cannot be fetched as numpy directly
        rets = [np.array(v) for v in rets]
        if self.mode == 'test':
            return rets[:]
        losses = rets[:num_loss]
        metric_states = restore_flatten_list(rets[num_loss:], metric_splits)
        metrics = []
        for metric, state in zip(self.model._metrics, metric_states):
            metrics.append(metric.update(*state))
        return (losses, metrics) if len(metrics) > 0 else losses

    def prepare(self):
        modes = ['train', 'eval', 'test']
        for mode in modes:
            self._make_program(mode)
            self._compile_and_initialize(self._progs[mode], mode)

    def _make_program(self, mode):
        prog = self._progs.get(mode, None)
        if prog is not None:
            return

        prog = self._orig_prog.clone()
        # NOTE: When learning rate scheduling is defined in static-graph, ops
        # that increase the global step var and calculate the learning rate
        # are prepended into _orig_prog. A test program made by
        # `_orig_prog.clone` would also include these ops, so they must be
        # pruned from the test program, otherwise the global step would be
        # changed during test.
        if mode != 'train':
            for op in list(prog.global_block().ops):
                prog.global_block()._remove_op(0)
        if mode == 'train' and self.model._optimizer \
            and self.model._optimizer._learning_rate_map:
            # HACK workaround learning rate map issue
            lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
            self.model._optimizer._learning_rate_map[prog] = lr_var
        losses = []
        metrics = []
        with fluid.program_guard(prog, self._startup_prog):
            if isinstance(self.model._inputs, dict):
                ins = [self.model._inputs[n] \
                    for n in extract_args(self.model.forward) if n != 'self']
            else:
                ins = self.model._inputs
            lbls = self.model._labels if self.model._labels else []
            inputs = [k.forward() for k in to_list(ins)]
            labels = [k.forward() for k in to_list(lbls)]
            outputs = to_list(self.model.forward(*inputs))
            if mode != 'test':
                if self.model._loss_function:
                    losses = self.model._loss_function(outputs, labels)
                    for metric in self.model._metrics:
                        metrics.append(
                            to_list(metric.add_metric_op(outputs, labels)))
                if mode == 'train' and self.model._optimizer:
                    self._loss_endpoint = fluid.layers.sum(losses)
                    self.model._optimizer.minimize(self._loss_endpoint)
        if mode != 'train':  # clone again to put it in test mode
            prog = prog.clone(for_test=True)

        self._input_vars[mode] = inputs
        self._label_vars[mode] = labels
        self._progs[mode] = prog
        self._endpoints[mode] = {
            "output": outputs,
            "loss": losses,
            "metric": metrics
        }

    def _compile_and_initialize(self, prog, mode):
        compiled_prog = self._compiled_progs.get(mode, None)
        if compiled_prog is not None:
            return compiled_prog

        device = self.model._device
        device_ids = self.model._device_ids

        if device.lower() == 'gpu':
            places = fluid.cuda_places(device_ids)
        else:
            places = fluid.cpu_places(len(device_ids) if device_ids else None)

        # XXX *ALL WEIGHTS* should be initialized upon model construction,
        # even though `forward()` may run a different code path for each mode,
        # therefore the startup program only needs to run once
        if self._executor is None:
            self._executor = fluid.Executor(places[0])
            # XXX incremental initialization
            uninitialized = []
            for var_py in self._startup_prog.list_vars():
                var = fluid.global_scope().find_var(var_py.name)
                if var and var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var_py)
            if uninitialized:
                startup_prog = self._startup_prog._prune(uninitialized)
                self._executor.run(startup_prog)

        compiled_prog = fluid.CompiledProgram(prog)
        if len(places) > 1:
            loss_name = None
            if mode == 'train' and self._loss_endpoint is not None:
                loss_name = self._loss_endpoint.name
            compiled_prog = compiled_prog.with_data_parallel(
                loss_name=loss_name, places=places)
        self._compiled_progs[mode] = compiled_prog


class DynamicGraphAdapter(object):
    def __init__(self, model):
        super(DynamicGraphAdapter, self).__init__()
        self.model = model

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    # TODO multi device in dygraph mode not implemented at present time
    def train(self, inputs, labels=None):
        assert self.model._optimizer, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).train()
        self.mode = 'train'
        inputs = to_list(inputs)
        if labels is not None:
            labels = [to_variable(l) for l in to_list(labels)]
        outputs = to_list(
            self.model.forward(*[to_variable(x) for x in inputs]))
        losses = self.model._loss_function(outputs, labels)
        final_loss = fluid.layers.sum(losses)
        final_loss.backward()
        self.model._optimizer.minimize(final_loss)
        self.model.clear_gradients()
        metrics = []
        for metric in self.model._metrics:
            metric_outs = metric.add_metric_op(outputs, to_list(labels))
            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
            metrics.append(m)
        return ([to_numpy(l) for l in losses], metrics) \
                if len(metrics) > 0 else [to_numpy(l) for l in losses]

    def eval(self, inputs, labels=None):
        super(Model, self.model).eval()
        self.mode = 'eval'
        inputs = to_list(inputs)
        if labels is not None:
            labels = [to_variable(l) for l in to_list(labels)]
        outputs = to_list(
            self.model.forward(*[to_variable(x) for x in inputs]))

        if self.model._loss_function:
            losses = self.model._loss_function(outputs, labels)
        else:
            losses = []

        metrics = []
        for metric in self.model._metrics:
            metric_outs = metric.add_metric_op(outputs, labels)
            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
            metrics.append(m)

        # To be consistent with static graph
        # return empty loss if loss_function is None
        return ([to_numpy(l) for l in losses], metrics) \
                if len(metrics) > 0 else [to_numpy(l) for l in losses]

    def test(self, inputs):
        super(Model, self.model).eval()
        self.mode = 'test'
        inputs = [to_variable(x) for x in to_list(inputs)]
        outputs = self.model.forward(*inputs)
        return [to_numpy(o) for o in to_list(outputs)]

    def parameters(self, *args, **kwargs):
        return super(Model, self.model).parameters(*args, **kwargs)

    def save(self, path):
        params = self.model.state_dict()
        fluid.save_dygraph(params, path)
        if self.model._optimizer is None:
            return
        if self.model._optimizer.state_dict():
            optim = self.model._optimizer.state_dict()
            fluid.save_dygraph(optim, path)

    def load(self, path):
        params, optim = fluid.load_dygraph(path)
        self.model.set_dict(params)
        if self.model._optimizer is None or optim is None:
            return

        # If the optimizer performs set_dict before its state vars have been
        # created, which happens when set_dict is called before minimize, the
        # state is stored in optimizer._accumulators_holder and loaded lazily.
        # To handle this when loading states saved from static-graph, extend
        # the state dict to include keys named according to dygraph naming rules.
        # TODO: if len(self.model._optimizer._accumulators) > 0
        converted_state = dict(optim)
        opt_unq_name = self.model._optimizer._name
        opt_cls_name = self.model._optimizer.__class__.__name__
        opt_name = opt_unq_name[:opt_unq_name.rfind("_")]  # remove suffix idx
        param_names = [param.name for param in self.model.parameters()]
        for var_name, state_var in sorted(
                optim.items(), key=lambda x: len(x[0]), reverse=True):
            if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
                # NOTE: the global_step saved by dygraph is 1 larger than the
                # static-graph counter, since the step is incremented at
                # different times in the two modes.
                if var_name == "@LR_DECAY_COUNTER@":
                    converted_state["global_step"] = np.array(
                        converted_state.pop("@LR_DECAY_COUNTER@")) + 1
            else:
                # moment and other accumulators
                # extend state dict to include promising dygraph names
                for param_name in param_names:
                    if var_name.startswith(param_name + "_" + opt_name):
                        # when init optimizer with name
                        accum_name = var_name[len(param_name + "_" + opt_name +
                                                  "_"):]
                    elif var_name.startswith(param_name +
                                             "_") and opt_name == opt_cls_name:
                        # when init optimizer without name
                        accum_name = var_name[len(param_name + "_"):]
                    else:
                        continue
                    # remove suffix idx
                    accum_name = accum_name[:accum_name.rfind("_")]
                    # state names always end with "_0" in dygraph because of the
                    # unique optimizer._name
                    dy_state_name = (param_name + "_" + opt_unq_name + "_" +
                                     accum_name + "_0")
                    converted_state[dy_state_name] = state_var

        self.model._optimizer.set_dict(converted_state)


class Model(fluid.dygraph.Layer):
    """
    FIXME: add more comments and usage
    """

    def __init__(self):
        super(Model, self).__init__(self.__class__.__name__)
        self.mode = 'train'
        self._inputs = None
        self._labels = None
        self._loss_function = None
        self._loss_weights = None
        self._optimizer = None
        self._device = None
        self._device_ids = None
        if in_dygraph_mode():
            self._adapter = DynamicGraphAdapter(self)
        else:
            self._adapter = StaticGraphAdapter(self)

    def train(self, *args, **kwargs):
        return self._adapter.train(*args, **kwargs)

    def eval(self, *args, **kwargs):
        return self._adapter.eval(*args, **kwargs)

    def test(self, *args, **kwargs):
        return self._adapter.test(*args, **kwargs)

    def save(self, *args, **kwargs):
        return self._adapter.save(*args, **kwargs)

    def load(self, *args, **kwargs):
        return self._adapter.load(*args, **kwargs)

    def parameters(self, *args, **kwargs):
        return self._adapter.parameters(*args, **kwargs)

    def prepare(self,
                optimizer=None,
                loss_function=None,
                metrics=None,
                inputs=None,
                labels=None,
                device=None,
                device_ids=None):
        """
        FIXME: add comments
        Args:
            optimizer (Optimizer|None): the optimizer must be set for training
                and should be an Optimizer instance. It can be None in eval
                and test mode.
            loss_function (Loss|None): the loss function must be set for
                training and should be a Loss instance. It can be None when
                there is no loss.
            metrics (Metric|list of Metric|None): if metrics is set, all
                metrics will be calculated and reported in train/eval mode.
            inputs (Input|list|dict|None): inputs, the entry points of the
                network, could be an Input layer, a list of Input layers,
                a dict (name: Input), or None. For static graph,
                inputs must be set. For dynamic graph, it could be None.
            labels (Input|list|None): labels, the entry points of the network,
                could be an Input layer, a list of Input layers, or None.
                For static graph, if loss_function is set in Model.prepare(),
                labels must be set as well. Otherwise, it could be None.
            device (str|None): specify device type, 'CPU' or 'GPU'.
                If None, the device is selected automatically according to
                the installed package version.
            device_ids (list[int]|None): specify device indices. If None,
                the available devices are obtained from environment variables
                when the model is executed: for GPU, the currently available
                device IDs are taken from FLAGS_selected_gpus or
                CUDA_VISIBLE_DEVICES; for CPU, the number of available CPUs is
                taken from CPU_NUM (e.g. export CPU_NUM=4). If CPU_NUM is not
                set, the executor adds it to the environment with a value of 1.
                The default is None.
        """
        self._optimizer = optimizer
        if loss_function:
            if not isinstance(loss_function, Loss):
                raise TypeError(
                    "'loss_function' must be a subclass of 'Loss'")
        self._loss_function = loss_function
        if not in_dygraph_mode():
            if not isinstance(inputs, (list, dict, Input)):
                raise TypeError(
                    "'inputs' must be list, dict or Input in static graph mode")
            if loss_function and not isinstance(labels, (list, Input)):
                raise TypeError(
                    "'labels' must be list or Input in static graph mode")

        metrics = metrics or []
        for metric in to_list(metrics):
            assert isinstance(metric, Metric), \
                "{} is not a subclass of Metric".format(
                    metric.__class__.__name__)
        self._metrics = to_list(metrics)

        self._inputs = inputs
        self._labels = labels
        self._device = device
        if device is None:
            self._device = 'GPU' if fluid.is_compiled_with_cuda() else 'CPU'
        self._device_ids = device_ids
        if not in_dygraph_mode():
            self._adapter.prepare()

    def fit(
            self,
            train_loader=None,
            eval_loader=None,
            epochs=1,
            eval_freq=1,
            log_freq=10,
            save_freq=1,
            verbose=2,
            callbacks=None, ):
        """
        FIXME: add more comments and usage
        Args:
            train_loader (DataLoader): an iterable data loader is used for train.
            eval_loader (DataLoader): an iterable data loader is used for
                evaluation at the end of epoch. If None, will not do evaluation.
            epochs (int): number of epochs to train the model.
            eval_freq (int): evaluation frequency in epoch.
            log_freq (int): frequency to print log during training.
            save_freq (int): frequency to save checkpoint during training.
            verbose (int): verbosity mode, should be 0, 1, or 2.
                0 = silent, 1 = progress bar, 2 = one line per epoch.
            callbacks (Callback|None): list of `Callback` instances to apply
                during training.
        """
        do_eval = eval_loader is not None
        metrics_name = self._metrics_name()
        cbks = config_callbacks(
            callbacks,
            model=self,
            epochs=epochs,
            steps=None,
            log_freq=log_freq,
            save_freq=save_freq,
            verbose=verbose,
            metrics=self._metrics_name(), )

        def _run_one_epoch(data_loader, callbacks, mode):
            size = data_loader.size if hasattr(data_loader, 'size') else None
            logs = {
                'steps': size,
                'metrics_name': metrics_name,
            }
            for step, data in enumerate(data_loader):
                cbks.on_batch_begin(mode, step, logs)
                if mode == 'train':
                    outs = self.train(*data)
                else:
                    outs = self.eval(*data)

                metrics = list(itertools.chain.from_iterable(outs))
                metrics = [np.mean(metrics[0])]
                for metric in self._metrics:
                    res = metric.accumulate()
                    metrics.extend(to_list(res))
                assert len(metrics_name) == len(metrics)
                for k, v in zip(metrics_name, metrics):
                    logs[k] = np.mean(v)

                logs['step'] = step
                logs['batch_size'] = data[0].shape[0]

                cbks.on_batch_end(mode, step, logs)
            self._reset_metrics()
            return logs

        cbks.on_begin('train')
        for epoch in range(epochs):
            cbks.on_epoch_begin(epoch)
            # FIXME: adapt to DataLoader
            loader = train_loader
            if not isinstance(train_loader, Iterable):
                loader = train_loader()
            logs = _run_one_epoch(loader, cbks, 'train')
            cbks.on_epoch_end(epoch, logs)

            if do_eval and epoch % eval_freq == 0:
                cbks.on_begin('eval', logs)
                # FIXME: adapt to DataLoader
                loader = eval_loader
                if not isinstance(eval_loader, Iterable):
                    loader = eval_loader()
                logs = _run_one_epoch(loader, cbks, 'eval')
                cbks.on_end('eval', logs)

        cbks.on_end('train', logs)

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def _metrics_name(self):
        metrics_name = ['loss']
        for m in self._metrics:
            metrics_name.extend(to_list(m.name()))
        return metrics_name