# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import inspect
import os
import pickle
from collections import OrderedDict

import numpy as np

from paddle import fluid
from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable
from metrics import Metric

__all__ = ['Model', 'Loss', 'CrossEntropy', 'Input']


def to_list(value):
    if value is None:
        return value
    if isinstance(value, (list, tuple)):
        return value
    return [value]


def to_numpy(var):
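    # convert a dygraph VarBase directly, or fetch a static-graph Variable
    # from the global scope and convert it to a numpy array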
    assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable"
    if isinstance(var, fluid.core.VarBase):
        return var.numpy()
    t = global_scope().find_var(var.name).get_tensor()
    return np.array(t)


def flatten_list(l):
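    # flatten a list of lists into a single list plus the sub-list lengths,
    # e.g. flatten_list([[a, b], [c]]) -> ([a, b, c], [2, 1]), so that the
    # nesting can be restored later by `restore_flatten_list`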
    assert isinstance(l, list), "not a list"
    outl = []
    splits = []
    for sl in l:
        assert isinstance(sl, list), "sub content not a list"
        splits.append(len(sl))
        outl += sl
    return outl, splits


def restore_flatten_list(l, splits):
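    # inverse of `flatten_list`,
    # e.g. restore_flatten_list([a, b, c], [2, 1]) -> [[a, b], [c]]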
    outl = []
    for split in splits:
        assert len(l) >= split, "list length invalid"
        sl, l = l[:split], l[split:]
        outl.append(sl)
    return outl


def extract_args(func):
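    # `inspect.getfullargspec` only exists on Python 3; fall back to
    # `getargspec` on Python 2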
    if hasattr(inspect, 'getfullargspec'):
        return inspect.getfullargspec(func)[0]
    else:
        return inspect.getargspec(func)[0]


class Input(fluid.dygraph.Layer):
    def __init__(self, shape=None, dtype=None, name=None):
        super(Input, self).__init__()
        self.shape = shape
        self.dtype = dtype
        self.name = name

    def forward(self):
        return fluid.data(self.name, shape=self.shape, dtype=self.dtype)


class Loss(object):
    def __init__(self, average=True):
        super(Loss, self).__init__()
        self.average = average

    def forward(self, outputs, labels):
        raise NotImplementedError()

    def __call__(self, outputs, labels):
        labels = to_list(labels)
        if in_dygraph_mode():
            labels = [to_variable(l) for l in labels]
        losses = to_list(self.forward(to_list(outputs), labels))
        if self.average:
            losses = [fluid.layers.reduce_mean(l) for l in losses]
        else:
            losses = [fluid.layers.reduce_sum(l) for l in losses]
        return losses


class CrossEntropy(Loss):
    def __init__(self, average=True):
        super(CrossEntropy, self).__init__(average)

    def forward(self, outputs, labels):
        return [
            fluid.layers.cross_entropy(o, l) for o, l in zip(outputs, labels)
        ]
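

# A custom loss only needs to implement `forward`, which receives the model
# outputs and the labels as lists and returns a list of per-sample losses;
# the mean/sum reduction is applied by `Loss.__call__` according to `average`.
# As an illustrative sketch (names are placeholders), a squared-error loss
# for regression might look like:
#
#     class SquaredError(Loss):
#         def forward(self, outputs, labels):
#             return [
#                 fluid.layers.square_error_cost(o, l)
#                 for o, l in zip(outputs, labels)
#             ]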


class StaticGraphAdapter(object):
    def __init__(self, model):
        super(StaticGraphAdapter, self).__init__()
        self.model = model
        # with `_build_once` gone, parameters are now created in `__init__`
        # so we need to keep track of the parameters already created
        self._startup_prog = fluid.default_startup_program()
        self._orig_prog = fluid.default_main_program()

        self._label_vars = {}  # label variables
        self._input_vars = {}  # input variables
        self._endpoints = {}
        self._loss_endpoint = None
        self._executor = None
        self._progs = {}
        self._compiled_progs = {}

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    def train(self, inputs, labels=None):
        assert self.model._optimizer, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'train'
        return self._run(inputs, labels)

    def eval(self, inputs, labels=None):
        self.mode = 'eval'
        return self._run(inputs, labels)

    def test(self, inputs):
        self.mode = 'test'
        return self._run(inputs, None)

    def parameters(self, *args, **kwargs):
        return None

    def save(self, path):
        def _save(state, path):
            if not state:
                return
            state = {
                k: to_numpy(v) if isinstance(v, Variable) else v
                for k, v in state.items()
            }
            with open(path, 'wb') as f:
                pickle.dump(state, f)

        base = os.path.basename(path)
        assert base != "", "path should be of 'dirname/filename' format"
        dir_name = os.path.dirname(path)
        if dir_name and not os.path.exists(dir_name):
            os.makedirs(dir_name)
        param_path = path + ".pdparams"
        _save(self.model.state_dict(), param_path)
        prog = self._progs.get('train', None)
        if prog is None or self.model._optimizer is None:
            return
        # XXX `optimizer.state_dict()` only works in dygraph mode
        optim_path = path + ".pdopt"
        optim = {
            p.name: p
            for p in filter(is_belong_to_optimizer, prog.list_vars())
        }
        if not optim:
            return

        _save(optim, optim_path)

    def load(self, path):
        def _load(path):
            if not os.path.exists(path):
                return
            with open(path, 'rb') as f:
                return pickle.load(f)

        param_path = path + ".pdparams"
        param_state = _load(param_path)
        assert param_state, "failed to load parameters, please check path"

        if self._executor is None:
            executor = fluid.Executor(fluid.CPUPlace())._default_executor
        else:
            executor = self._executor._default_executor

        fluid.core._create_loaded_parameter(
            list(self.model.state_dict().values()), global_scope(), executor)

        for key, var in self.model.state_dict().items():
            assert key in param_state, \
                "parameter [{}] is not found in model file [{}]".format(
                    key, param_path)
            self._set_var(var, param_state[key])

        # FIXME what if a different optimizer is used?
        if not self.model._optimizer:
            return
        optim_path = path + ".pdopt"
        optim_state = _load(optim_path)
        if optim_state is None:
            return

        self._load_optimizer(optim_state, executor)

    def _load_optimizer(self, state, executor):
        prog = self._progs.get('train', None)
        optim = list(filter(is_belong_to_optimizer, prog.list_vars()))
        if not optim:
            return

        fluid.core._create_loaded_parameter(optim, global_scope(), executor)

        converted_state = dict(state)
        for var in optim:
            if var.name in ["@LR_DECAY_COUNTER@", "global_step"]:
                # When using a learning rate scheduler, dygraph names the
                # global step var "global_step" when saving, while static-graph
                # has a state var named "@LR_DECAY_COUNTER@".
                # NOTE: the global_step saved by dygraph is 1 larger than that
                # in static-graph, since global_step is incremented at a
                # different time.
                state_val = (
                    np.array(converted_state.pop("global_step")) - 1
                ) if "global_step" in converted_state else converted_state.pop(
                    "@LR_DECAY_COUNTER@", None)
                if state_val is not None:
                    converted_state[var.name] = state_val
            elif var.name.startswith("learning_rate_"):
                # When using a static learning rate, static-graph makes it a
                # persistable var named `unique_name.generate("learning_rate")`,
                # but dygraph does not save it.
                if var.name not in state: continue
            else:
                # moment and other accumulators
                if var.name not in converted_state:
                    # try to convert from dygraph name
                    opt_name = self.model._optimizer._name
                    opt_cls_name = self.model._optimizer.__class__.__name__
                    opt_unq_name = None
                    for name in self.model._optimizer._accumulators.keys():
                        accum_name = name if opt_name is None else name[len(
                            opt_name) + 1:]
                        for param_name, state_var in self.model._optimizer._accumulators[
                                name].items():
                            if opt_unq_name is None:
                                # cannot infer the exact unique optimizer name,
                                # thus try to extract it rather than generate it
                                for state_key in sorted(
                                        state.keys(),
                                        key=lambda x: len(x),
                                        reverse=True):
                                    prefix = param_name + "_" + (
                                        opt_cls_name if opt_name is None else
                                        opt_name) + "_"
                                    if state_key.startswith(prefix):
                                        prefix_offset = state_key[len(
                                            prefix):].find("_") + len(prefix)
                                        opt_unq_name = state_key[len(
                                            param_name + "_"):prefix_offset]
                                        # TODO: assert
                                        # assert opt_unq_name is None
                                    # gen(param.name + "_" + gen(opt_name) + "_" + accum_name)
                                    # always ends with "_0" because optimizer._name is unique
                            dy_state_name = (param_name + "_" + opt_unq_name +
                                             "_" + accum_name + "_0")
                            converted_state[
                                state_var.name] = converted_state.pop(
                                    dy_state_name)

            assert var.name in converted_state, \
                "variable [{}] is not in optimizer state file".format(var.name)
            self._set_var(var, converted_state[var.name])

    def _set_var(self, var, ndarray):
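        # copy a numpy array into the tensor that backs `var`, keeping the
        # tensor on its current place (CPU, CUDA pinned, or CUDA)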
        t = global_scope().find_var(var.name).get_tensor()
        p = t._place()
        if p.is_cpu_place():
            place = fluid.CPUPlace()
        elif p.is_cuda_pinned_place():
            place = fluid.CUDAPinnedPlace()
        else:
            p = fluid.core.Place()
            p.set_place(t._place())
            place = fluid.CUDAPlace(p.gpu_device_id())

        t.set(ndarray, place)

    def _run(self, inputs, labels=None):
        compiled_prog = self._compiled_progs.get(self.mode, None)
        assert compiled_prog, \
            "Model is not ready, please call `model.prepare()` first"

        inputs = to_list(inputs)
        if labels is not None:
            labels = to_list(labels)
        assert len(inputs) == len(self._input_vars[self.mode]), \
            "number of inputs" \
            + " does not match number of arguments of `forward` method"

        feed = {}
        input_names = [v.name for v in self._input_vars[self.mode]]
        for idx, n in enumerate(input_names):
            # train and test may take different arguments
            if inputs[idx] is not None:
                feed[n] = inputs[idx]
        if labels is not None:
            for idx, v in enumerate(self._label_vars[self.mode]):
                feed[v.name] = labels[idx]

        endpoints = self._endpoints[self.mode]
        if self.mode == 'test':
            fetch_list = endpoints['output']
        else:
            metric_list, metric_splits = flatten_list(endpoints['metric'])
            fetch_list = endpoints['loss'] + metric_list
            num_loss = len(endpoints['loss'])
        rets = self._executor.run(
            compiled_prog, feed=feed,
            fetch_list=fetch_list,
            return_numpy=False)
        # LoDTensor cannot be fetched as numpy directly
        rets = [np.array(v) for v in rets]
        if self.mode == 'test':
            return rets[:]
        losses = rets[:num_loss]
        metric_states = restore_flatten_list(rets[num_loss:], metric_splits)
        metrics = []
        for metric, state in zip(self.model._metrics, metric_states):
            metrics.append(metric.update(*state))
        return (losses, metrics) if len(metrics) > 0 else losses

    def prepare(self):
        modes = ['train', 'eval', 'test']
        for mode in modes:
            self._make_program(mode)
            self._compile_and_initialize(self._progs[mode], mode)

    def _make_program(self, mode):
        prog = self._progs.get(mode, None)
        if prog is not None:
            return

        prog = self._orig_prog.clone()
        # NOTE: When defining learning rate scheduling in static-graph, ops to
        # increase the global step var and to calculate the learning rate are
        # prepended into _orig_prog. The test program made by `_orig_prog.clone`
        # would also include these ops. Thus they must be pruned from the test
        # program, otherwise the global step would be changed during testing.
        if mode != 'train':
            for op in list(prog.global_block().ops):
                prog.global_block()._remove_op(0)
        if mode == 'train' and self.model._optimizer \
            and self.model._optimizer._learning_rate_map:
            # HACK workaround learning rate map issue
            lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
            self.model._optimizer._learning_rate_map[prog] = lr_var
        losses = []
        metrics = []
        with fluid.program_guard(prog, self._startup_prog):
            if isinstance(self.model._inputs, dict):
                ins = [self.model._inputs[n] \
                    for n in extract_args(self.model.forward) if n != 'self']
            else:
                ins = self.model._inputs
            lbls = self.model._labels if self.model._labels else []
            inputs = [k.forward() for k in to_list(ins)]
            labels = [k.forward() for k in to_list(lbls)]
            outputs = to_list(self.model.forward(*inputs))
            if mode != 'test':
                if self.model._loss_function:
                    losses = self.model._loss_function(outputs, labels)
                    for metric in self.model._metrics:
                        metrics.append(
                            to_list(metric.add_metric_op(outputs, labels)))
                if mode == 'train' and self.model._optimizer:
                    self._loss_endpoint = fluid.layers.sum(losses)
                    self.model._optimizer.minimize(self._loss_endpoint)
        if mode != 'train':  # clone again to put it in test mode
            prog = prog.clone(for_test=True)

        self._input_vars[mode] = inputs
        self._label_vars[mode] = labels
        self._progs[mode] = prog
        self._endpoints[mode] = {
            "output": outputs,
            "loss": losses,
            "metric": metrics
        }

    def _compile_and_initialize(self, prog, mode):
        compiled_prog = self._compiled_progs.get(mode, None)
        if compiled_prog is not None:
            return compiled_prog

        device = self.model._device
        device_ids = self.model._device_ids

        if device.lower() == 'gpu':
            places = fluid.cuda_places(device_ids)
        else:
            places = fluid.cpu_places(len(device_ids) if device_ids else None)

        # XXX *ALL WEIGHTS* should be initialized upon model construction,
        # even if `forward()` may run different code paths for different
        # modes, therefore the startup program only needs to run once
        if self._executor is None:
            self._executor = fluid.Executor(places[0])
            # XXX incremental initialization
            uninitialized = []
            for var_py in self._startup_prog.list_vars():
                var = fluid.global_scope().find_var(var_py.name)
                if var and var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var_py)
            if uninitialized:
                startup_prog = self._startup_prog._prune(uninitialized)
                self._executor.run(startup_prog)

        compiled_prog = fluid.CompiledProgram(prog)
        if len(places) > 1:
            loss_name = None
            if mode == 'train' and self._loss_endpoint is not None:
                loss_name = self._loss_endpoint.name
            compiled_prog = compiled_prog.with_data_parallel(
                loss_name=loss_name, places=places)
        self._compiled_progs[mode] = compiled_prog


class DynamicGraphAdapter(object):
    def __init__(self, model):
        super(DynamicGraphAdapter, self).__init__()
        self.model = model

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    # TODO multi device in dygraph mode not implemented at present
    def train(self, inputs, labels=None):
        assert self.model._optimizer, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).train()
        self.mode = 'train'
        inputs = to_list(inputs)
        if labels is not None:
            labels = to_list(labels)
        outputs = to_list(
            self.model.forward(*[to_variable(x) for x in inputs]))
        losses = self.model._loss_function(outputs, labels)
        final_loss = fluid.layers.sum(losses)
        final_loss.backward()
        self.model._optimizer.minimize(final_loss)
        self.model.clear_gradients()
        metrics = []
        for metric in self.model._metrics:
            metric_outs = metric.add_metric_op(
                outputs, [to_variable(l) for l in labels])
            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
            metrics.append(m)
        return ([to_numpy(l) for l in losses], metrics) \
                if len(metrics) > 0 else [to_numpy(l) for l in losses]

    def eval(self, inputs, labels=None):
        super(Model, self.model).eval()
        self.mode = 'eval'
        inputs = to_list(inputs)
        if labels is not None:
            labels = to_list(labels)
        outputs = to_list(
            self.model.forward(*[to_variable(x) for x in inputs]))

        if self.model._loss_function:
            losses = self.model._loss_function(outputs, labels)
        else:
            losses = []

        metrics = []
        for metric in self.model._metrics:
            metric_outs = metric.add_metric_op(
                outputs, [to_variable(l) for l in labels])
            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
            metrics.append(m)

        # To be consistent with static graph
        # return empty loss if loss_function is None
        return ([to_numpy(l) for l in losses], metrics) \
                if len(metrics) > 0 else [to_numpy(l) for l in losses]

    def test(self, inputs):
        super(Model, self.model).eval()
        self.mode = 'test'
        inputs = [to_variable(x) for x in to_list(inputs)]
        outputs = self.model.forward(*inputs)
        return [to_numpy(o) for o in to_list(outputs)]

    def parameters(self, *args, **kwargs):
        return super(Model, self.model).parameters(*args, **kwargs)

    def save(self, path):
        params = self.model.state_dict()
        fluid.save_dygraph(params, path)
        if self.model._optimizer is None:
            return
        if self.model._optimizer.state_dict():
            optim = self.model._optimizer.state_dict()
            fluid.save_dygraph(optim, path)

    def load(self, path):
        params, optim = fluid.load_dygraph(path)
        self.model.set_dict(params)
        if self.model._optimizer is None or optim is None:
            return

        # If the optimizer performs set_dict before its state vars have been
        # created, which happens when set_dict is called before minimize, the
        # state is stored in optimizer._accumulators_holder and loaded lazily.
        # To accommodate this when loading from static-graph saved states,
        # extend the state dict to include keys named according to the dygraph
        # naming rules.
        # TODO: if len(self.model._optimizer._accumulators) > 0
        converted_state = dict(optim)
        opt_unq_name = self.model._optimizer._name
        opt_cls_name = self.model._optimizer.__class__.__name__
        opt_name = opt_unq_name[:opt_unq_name.rfind("_")]  # remove suffix idx
        param_names = [param.name for param in self.model.parameters()]
        for var_name, state_var in sorted(
                optim.items(), key=lambda x: len(x[0]), reverse=True):
            if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
                # NOTE: the global_step saved by dygraph is 1 larger than that
                # in static-graph, since global_step is incremented at a
                # different time.
                if var_name == "@LR_DECAY_COUNTER@":
                    converted_state["global_step"] = np.array(
                        converted_state.pop("@LR_DECAY_COUNTER@")) + 1
            else:
                # moment and other accumulators
                # extend the state dict to include candidate dygraph names
                for param_name in param_names:
                    if var_name.startswith(param_name + "_" + opt_name):
                        # when init optimizer with name
                        accum_name = var_name[len(param_name + "_" + opt_name +
                                                  "_"):]
                    elif var_name.startswith(param_name +
                                             "_") and opt_name == opt_cls_name:
                        # when init optimizer without name
                        accum_name = var_name[len(param_name + "_"):]
                    else:
                        continue
                    # remove suffix idx
                    accum_name = accum_name[:accum_name.rfind("_")]
                    # state names always end with "_0" in dygraph because of the
                    # unique optimizer._name
                    dy_state_name = (param_name + "_" + opt_unq_name + "_" +
                                     accum_name + "_0")
                    converted_state[dy_state_name] = state_var

        self.model._optimizer.set_dict(converted_state)


class Model(fluid.dygraph.Layer):
    """
    FIXME: add more comments and usage
    """

    def __init__(self):
        super(Model, self).__init__(self.__class__.__name__)
        self.mode = 'train'
        self._inputs = None
        self._labels = None
        self._loss_function = None
        self._loss_weights = None
        self._loss = None
        self._optimizer = None
        self._device = None
        self._device_ids = None
        if in_dygraph_mode():
            self._adapter = DynamicGraphAdapter(self)
        else:
            self._adapter = StaticGraphAdapter(self)

    def train(self, *args, **kwargs):
        return self._adapter.train(*args, **kwargs)

    def eval(self, *args, **kwargs):
        return self._adapter.eval(*args, **kwargs)

    def test(self, *args, **kwargs):
        return self._adapter.test(*args, **kwargs)

    def save(self, *args, **kwargs):
        return self._adapter.save(*args, **kwargs)

    def load(self, *args, **kwargs):
        return self._adapter.load(*args, **kwargs)

    def prepare(self,
                optimizer=None,
                loss_function=None,
                metrics=None,
                inputs=None,
                labels=None,
                device=None,
                device_ids=None):
        """
        FIXME: add comments
        Args:
            optimizer (Optimizer|None): the optimizer must be set in training
                and should be an Optimizer instance. It can be None in eval
                and test mode.
            loss_function (Loss|None): loss function must be set in training
                and should be a Loss instance. It can be None when there is
                no loss.
            metrics (Metric|list of Metric|None): if metrics is set, all
                metrics will be calculated and reported in train/eval mode.
            inputs (Input|list|dict|None): inputs, the entry points of the
                network, could be an Input layer, a list of Input layers,
                a dict (name: Input), or None. For static graph,
                inputs must be set. For dynamic graph, it could be None.
            labels (Input|list|None): labels, the entry points of the network,
                could be an Input layer or a list of Input layers, or None.
                For static graph, if loss_function is set in Model.prepare(),
                labels must be set. Otherwise, it could be None.
            device (str|None): specify the device type, 'CPU' or 'GPU'.
                If None, the device is selected automatically according to
                the installed package version.
            device_ids (list[int]|None): specify the device indices. If None,
                the available devices are obtained from environment variables
                when the model is executed: if GPU is used, the available
                device IDs are taken from FLAGS_selected_gpus or
                CUDA_VISIBLE_DEVICES; if CPU is used, the number of available
                CPUs is taken from CPU_NUM (e.g. export CPU_NUM=4; if this
                environment variable is not set, the executor will set it to
                1). The default is None.
        """
        self._optimizer = optimizer
        if loss_function:
            if not isinstance(loss_function, Loss):
                raise TypeError(
                    "'loss_function' must be sub classes of 'Loss'")
        self._loss_function = loss_function
        if not in_dygraph_mode():
            if not isinstance(inputs, (list, dict, Input)):
                raise TypeError(
                    "'inputs' must be list, dict or Input in static graph mode")
            if loss_function and not isinstance(labels, (list, Input)):
                raise TypeError(
                    "'labels' must be list or Input in static graph mode")

        metrics = metrics or []
        for metric in to_list(metrics):
            assert isinstance(metric, Metric), \
                "{} is not a subclass of Metric".format(
                    metric.__class__.__name__)
        self._metrics = to_list(metrics)

        self._inputs = inputs
        self._labels = labels
        self._device = device
        if device is None:
            self._device = 'GPU' if fluid.is_compiled_with_cuda() else 'CPU'
        self._device_ids = device_ids
        if not in_dygraph_mode():
            self._adapter.prepare()

    def parameters(self, *args, **kwargs):
        return self._adapter.parameters(*args, **kwargs)
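

# A minimal end-to-end sketch of how this API is meant to be used. The class
# name, layer sizes, input names and hyper-parameters below are illustrative
# assumptions only, and `fluid.dygraph.Linear` is assumed to be available in
# the installed Paddle version. Run as a script (no dygraph guard), this
# exercises the static-graph path through `StaticGraphAdapter`.
if __name__ == '__main__':
    # keep the example on a single CPU place so no data-parallel compilation
    # is triggered
    os.environ.setdefault('CPU_NUM', '1')

    class SoftmaxRegression(Model):
        def __init__(self):
            super(SoftmaxRegression, self).__init__()
            self._fc = fluid.dygraph.Linear(784, 10, act='softmax')

        def forward(self, x):
            return self._fc(x)

    model = SoftmaxRegression()
    inputs = [Input([None, 784], 'float32', name='image')]
    labels = [Input([None, 1], 'int64', name='label')]
    optim = fluid.optimizer.SGD(learning_rate=1e-3)
    model.prepare(
        optim, CrossEntropy(), inputs=inputs, labels=labels, device='CPU')

    x = np.random.random((4, 784)).astype('float32')
    y = np.random.randint(0, 10, (4, 1)).astype('int64')
    print(model.train(x, y))  # a list of per-step loss values (numpy arrays)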