# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import inspect
import os
import pickle
from collections import OrderedDict

import numpy as np

from paddle import fluid
from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable

__all__ = ['Model', 'Loss', 'CrossEntropy', 'Input']


def to_list(value):
    if value is None:
        return value
    if isinstance(value, (list, tuple)):
        return value
    return [value]


def to_numpy(var):
    assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable"
    if isinstance(var, fluid.core.VarBase):
        return var.numpy()
    t = global_scope().find_var(var.name).get_tensor()
    return np.array(t)


def extract_args(func):
    if hasattr(inspect, 'getfullargspec'):
        return inspect.getfullargspec(func)[0]
    else:
        return inspect.getargspec(func)[0]


class Input(fluid.dygraph.Layer):
    def __init__(self, shape=None, dtype=None, name=None):
        self.shape = shape
        self.dtype = dtype
        self.name = name

    def forward(self):
        return fluid.data(self.name, shape=self.shape, dtype=self.dtype)
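
# A minimal sketch (not part of this module) of how `Input` is meant to be
# used: it only records shape/dtype/name, and `forward()` materializes the
# corresponding `fluid.data` variable when the static-graph program is built
# in `StaticGraphAdapter.prepare()`. The shapes and names below are
# illustrative assumptions, e.g.:
#
#   image = Input(shape=[None, 784], dtype='float32', name='image')
#   label = Input(shape=[None, 1], dtype='int64', name='label')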


class Loss(object):
    def __init__(self, average=True):
        super(Loss, self).__init__()
        self.average = average

    def forward(self, outputs, labels):
        raise NotImplementedError()

    def __call__(self, outputs, labels):
        labels = to_list(labels)
        if in_dygraph_mode():
            labels = [to_variable(l) for l in labels]
        losses = to_list(self.forward(to_list(outputs), labels))
        if self.average:
            losses = [fluid.layers.reduce_mean(l) for l in losses]
        else:
            losses = [fluid.layers.reduce_sum(l) for l in losses]
        return losses


class CrossEntropy(Loss):
    def __init__(self, average=True):
        super(CrossEntropy, self).__init__(average)

    def forward(self, outputs, labels):
        return [
            fluid.layers.cross_entropy(o, l) for o, l in zip(outputs, labels)
        ]
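
# A minimal sketch (not part of this module) of a custom loss: a subclass only
# overrides `forward(outputs, labels)` and returns a list of per-sample loss
# tensors; `Loss.__call__` handles list conversion and the mean/sum reduction.
# The layer choice below is an illustrative assumption.
#
#   class SoftmaxWithCrossEntropy(Loss):
#       def __init__(self, average=True):
#           super(SoftmaxWithCrossEntropy, self).__init__(average)
#
#       def forward(self, outputs, labels):
#           return [
#               fluid.layers.softmax_with_cross_entropy(o, l)
#               for o, l in zip(outputs, labels)
#           ]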


class StaticGraphAdapter(object):
    def __init__(self, model):
        super(StaticGraphAdapter, self).__init__()
        self.model = model
        # with `_build_once` gone, parameters are now created in `__init__`
        # so we need to keep track of the parameters already created
        self._startup_prog = fluid.default_startup_program()
        self._orig_prog = fluid.default_main_program()

        self._label_vars = {}  # label variables
        self._input_vars = {}  # input variables
        self._endpoints = {}
        self._loss_endpoint = None
        self._executor = None
        self._progs = {}
        self._compiled_progs = {}

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    def train(self, inputs, labels=None):
        assert self.model._optimizer, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'train'
        return self._run(inputs, labels)

    def eval(self, inputs, labels=None):
        self.mode = 'eval'
        return self._run(inputs, labels)

    def test(self, inputs):
        self.mode = 'test'
        return self._run(inputs, None)

    def parameters(self, *args, **kwargs):
        return None

    def save(self, path):
        def _save(state, path):
            if not state:
                return
            state = {
                k: to_numpy(v) if isinstance(v, Variable) else v
                for k, v in state.items()
            }
            with open(path, 'wb') as f:
                pickle.dump(state, f)

        base = os.path.basename(path)
        assert base != "", "path should be of 'dirname/filename' format"
        dir_name = os.path.dirname(path)
        if dir_name and not os.path.exists(dir_name):
            os.makedirs(dir_name)
        param_path = path + ".pdparams"
        _save(self.model.state_dict(), param_path)
        prog = self._progs.get('train', None)
        if prog is None or self.model._optimizer is None:
            return
        # XXX `optimizer.state_dict()` only works in dygraph mode
        optim_path = path + ".pdopt"
        optim = {
            p.name: p
            for p in filter(is_belong_to_optimizer, prog.list_vars())
        }
        if not optim:
            return

        _save(optim, optim_path)

    def load(self, path):
        def _load(path):
            if not os.path.exists(path):
                return
            with open(path, 'rb') as f:
                return pickle.load(f)

        param_path = path + ".pdparams"
        param_state = _load(param_path)
        assert param_state, "failed to load parameters, please check path"

        if self._executor is None:
            executor = fluid.Executor(fluid.CPUPlace())._default_executor
        else:
            executor = self._executor._default_executor

        fluid.core._create_loaded_parameter(
            list(self.model.state_dict().values()), global_scope(), executor)

        for key, var in self.model.state_dict().items():
            assert key in param_state, \
                "parameter [{}] is not found in model file [{}]".format(
                    key, param_path)
            self._set_var(var, param_state[key])

        # FIXME what if a different optimizer is used?
        if not self.model._optimizer:
            return
        optim_path = path + ".pdopt"
        optim_state = _load(optim_path)
        if optim_state is None:
            return

        assert self._executor
        self._load_optimizer(optim_state)

    def _load_optimizer(self, state):
        prog = self._progs.get('train', None)
        optim = list(filter(is_belong_to_optimizer, prog.list_vars()))
        if not optim:
            return

        fluid.core._create_loaded_parameter(optim,
                                            global_scope(),
                                            self._executor._default_executor)

        converted_state = dict(state)
        for var in optim:
            if var.name in ["@LR_DECAY_COUNTER@", "global_step"]:
                # When using learning rate scheduler, dygraph would name the
                # global step var as "global_step" to save, while static-graph
                # would have a state var named "@LR_DECAY_COUNTER@".
                # NOTE: dygraph saved global_step is 1 larger than that in
                # static-graph, since the time of global_step to increase is
                # different.
                state_val = (
                    np.array(converted_state.pop("global_step")) - 1
                ) if "global_step" in converted_state else converted_state.pop(
                    "@LR_DECAY_COUNTER@", None)
                if state_val is not None:
                    converted_state[var.name] = state_val
            elif var.name.startswith("learning_rate_"):
                # When using static learning rate, static-graph would make it
                # a persistable var named 'unique_name.generate("learning_rate")'.
                # However, dygraph wouldn't save it.
                if var.name not in state: continue
            else:
                # moment and other accumulators
                if var.name not in converted_state:
                    # try to convert from dygraph name
                    opt_name = self.model._optimizer._name
                    opt_cls_name = self.model._optimizer.__class__.__name__
                    opt_unq_name = None
                    for name in self.model._optimizer._accumulators.keys():
                        accum_name = name if opt_name is None else name[len(
                            opt_name) + 1:]
                        for param_name, state_var in self.model._optimizer._accumulators[
                                name].items():
                            if opt_unq_name is None:
                                # can not infer out the exact unique(opt_name),
                                # thus try to extract rather than generate
                                for state_key in sorted(
                                        state.keys(),
                                        key=lambda x: len(x),
                                        reverse=True):
                                    prefix = param_name + "_" + (
                                        opt_cls_name if opt_name is None else
                                        opt_name) + "_"
                                    if state_key.startswith(prefix):
                                        prefix_offset = state_key[len(
                                            prefix):].find("_") + len(prefix)
                                        opt_unq_name = state_key[len(
                                            param_name + "_"):prefix_offset]
                                        # TODO: assert
                                        # assert opt_unq_name is None
                                    # gen(param.name + "_" + gen(opt_name) + "_" + accum_name)
                                    # always end with "_0" since the unique optimizer._name
                            dy_state_name = (param_name + "_" + opt_unq_name +
                                             "_" + accum_name + "_0")
                            converted_state[
                                state_var.name] = converted_state.pop(
                                    dy_state_name)

            assert var.name in converted_state, \
                "variable [{}] is not in optimizer state file".format(var.name)
            self._set_var(var, converted_state[var.name])

    def _set_var(self, var, ndarray):
        t = global_scope().find_var(var.name).get_tensor()
        p = t._place()
        if p.is_cpu_place():
            place = fluid.CPUPlace()
        elif p.is_cuda_pinned_place():
            place = fluid.CUDAPinnedPlace()
        else:
            p = fluid.core.Place()
            p.set_place(t._place())
            place = fluid.CUDAPlace(p.gpu_device_id())

        t.set(ndarray, place)

    def _run(self, inputs, labels=None, device='CPU'):
        compiled_prog = self.prepare()

        inputs = to_list(inputs)
        if labels is not None:
            labels = to_list(labels)
        assert len(inputs) == len(self._input_vars[self.mode]), \
            "number of inputs does not match the number of arguments of " \
            "the `forward` method"

        feed = {}
        input_names = [v.name for v in self._input_vars[self.mode]]
        for idx, n in enumerate(input_names):
            # train and test may take different arguments
            if inputs[idx] is not None:
                feed[n] = inputs[idx]
        if labels is not None:
            for idx, v in enumerate(self._label_vars[self.mode]):
                feed[v.name] = labels[idx]

        endpoints = self._endpoints[self.mode]
        fetch_list = endpoints['output']
        if 'loss' in endpoints:
            fetch_list = endpoints['output'] + endpoints['loss']
        num_output = len(endpoints['output'])
        out = self._executor.run(compiled_prog,
                                 feed=feed,
                                 fetch_list=fetch_list)
        if self.mode == 'test':
            return out[:num_output]
        else:
            return out[:num_output], out[num_output:]

    def _get_loss(self, outputs):
        assert self.model._loss_function
        label_vars = [k.forward() for k in to_list(self.model._labels)]
        self._label_vars[self.mode] = label_vars
        losses = self.model._loss_function(outputs, label_vars)
        return losses

    def _make_program(self, inputs):
        prog = self._orig_prog.clone()
        # change inputs to the same var in cloned program
        inputs = fluid.layers.utils.map_structure(
            lambda var: prog.global_block().var(var.name), inputs)
        # NOTE: When defining a learning rate scheduler in static-graph mode,
        # ops to increase the global step var and to calculate the learning
        # rate are prepended into _orig_prog. The test program made by
        # `_orig_prog.clone` would also include these ops, so they must be
        # pruned from the test program, otherwise the global step would be
        # changed during testing.
        if self.mode != 'train':
            for op in list(prog.global_block().ops):
                prog.global_block()._remove_op(0)
        if self.mode == 'train' and self.model._optimizer._learning_rate_map:
            # HACK workaround learning rate map issue
            lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
            self.model._optimizer._learning_rate_map[prog] = lr_var
        losses = []
        with fluid.program_guard(prog, self._startup_prog):
            outputs = to_list(self.model.forward(*inputs))
            if self.mode != 'test':
                losses = self._get_loss(outputs)
                if self.mode == 'train' and self.model._optimizer:
                    self._loss_endpoint = fluid.layers.sum(losses)
                    self.model._optimizer.minimize(self._loss_endpoint)
        if self.mode != 'train':  # clone again to put it in test mode
            prog = prog.clone(for_test=True)
        self._progs[self.mode] = prog
        self._endpoints[self.mode] = {
            "output": outputs,
            "loss": losses
        } if self.model._loss_function else {
            'output': outputs
        }

    def prepare(self):
        compiled_prog = self._compiled_progs.get(self.mode, None)
        if compiled_prog is not None:
            return compiled_prog

        if isinstance(self.model._inputs, dict):
            ins = [self.model._inputs[n] \
                for n in extract_args(self.model.forward) if n != 'self']
        else:
            ins = self.model._inputs
        self._input_vars[self.mode] = [k.forward() for k in to_list(ins)]
        self._make_program(self._input_vars[self.mode])
        return self._compile_and_initialize(self._progs[self.mode])

    def _compile_and_initialize(self, prog):
        compiled_prog = self._compiled_progs.get(self.mode, None)
        if compiled_prog is not None:
            return compiled_prog

        device = self.model._device
        device_ids = self.model._device_ids

        if device.lower() == 'gpu':
            places = fluid.cuda_places(device_ids)
        else:
            places = fluid.cpu_places(len(device_ids) if device_ids else None)

        # XXX *ALL WEIGHTS* should be initialized upon model construction,
        # even if `forward()` may run different code paths for different modes;
        # therefore the startup program only needs to run once
        if self._executor is None:
            self._executor = fluid.Executor(places[0])
            # XXX incremental initialization
            uninitialized = []
            for var_py in self._startup_prog.list_vars():
                var = fluid.global_scope().find_var(var_py.name)
                if var and var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var_py)
            if uninitialized:
                startup_prog = self._startup_prog._prune(uninitialized)
                self._executor.run(startup_prog)

        compiled_prog = fluid.CompiledProgram(prog)
        if len(places) > 1:
            loss_name = None
            if self.mode == 'train' and self._loss_endpoint is not None:
                loss_name = self._loss_endpoint.name

            share_vars_from = None
            if self.mode == 'eval' and 'train' in self._compiled_progs:
                share_vars_from = self._compiled_progs['train']
            # HACK invalidate the eval program if it was compiled before train
            # quite hackish, OTOH, it is generally uncommon that the eval
            # program will be run before the train program
            if self.mode == 'train' and 'eval' in self._compiled_progs:
                del self._compiled_progs['eval']
            compiled_prog = compiled_prog.with_data_parallel(
                loss_name=loss_name,
                places=places,
                share_vars_from=share_vars_from)
        self._compiled_progs[self.mode] = compiled_prog
        return compiled_prog


class DynamicGraphAdapter(object):
    def __init__(self, model):
        super(DynamicGraphAdapter, self).__init__()
        self.model = model

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    # TODO: multi-device in dygraph mode is not implemented at present
    def train(self, inputs, labels=None):
        assert self.model._optimizer, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).train()
        self.mode = 'train'
        inputs = to_list(inputs)
        if labels is not None:
            labels = to_list(labels)
        outputs = self.model.forward(*[to_variable(x) for x in inputs])
        losses = self._get_loss(outputs, labels)
        final_loss = fluid.layers.sum(losses)
        final_loss.backward()
        self.model._optimizer.minimize(final_loss)
        self.model.clear_gradients()
        return [to_numpy(o) for o in to_list(outputs)], \
            [to_numpy(l) for l in losses]

    def eval(self, inputs, labels=None):
        super(Model, self.model).eval()
        self.mode = 'eval'
        inputs = to_list(inputs)
        if labels is not None:
            labels = to_list(labels)
        outputs = self.model.forward(*[to_variable(x) for x in inputs])
        losses = self._get_loss(outputs, labels)
        return [to_numpy(o) for o in to_list(outputs)], \
            [to_numpy(l) for l in losses]

    def test(self, inputs):
        super(Model, self.model).eval()
        self.mode = 'test'
        inputs = [to_variable(x) for x in to_list(inputs)]
        outputs = self.model.forward(*inputs)
        return [to_numpy(o) for o in to_list(outputs)]

    def _get_loss(self, outputs, labels):
        assert self.model._loss_function
        return self.model._loss_function(outputs, labels)

    def parameters(self, *args, **kwargs):
        return super(Model, self.model).parameters(*args, **kwargs)

    def save(self, path):
        params = self.model.state_dict()
        fluid.save_dygraph(params, path)
        if self.model._optimizer is None:
            return
        if self.model._optimizer.state_dict():
            optim = self.model._optimizer.state_dict()
            fluid.save_dygraph(optim, path)

    def load(self, path):
        params, optim = fluid.load_dygraph(path)
        self.model.set_dict(params)
        if self.model._optimizer is None or optim is None:
            return

        # If the optimizer performs set_dict while its state vars have not been
        # created yet, which happens when set_dict is called before minimize,
        # the state is stored in optimizer._accumulators_holder and loaded
        # lazily. To cope with this when loading static-graph saved states,
        # extend the state dict to include keys named according to dygraph
        # naming rules.
        # TODO: if len(self.model._optimizer._accumulators) > 0
        converted_state = dict(optim)
        opt_unq_name = self.model._optimizer._name
        opt_cls_name = self.model._optimizer.__class__.__name__
        opt_name = opt_unq_name[:opt_unq_name.rfind("_")]  # remove suffix idx
        param_names = [param.name for param in self.model.parameters()]
        for var_name, state_var in sorted(
                optim.items(), key=lambda x: len(x[0]), reverse=True):
            if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
                # NOTE: dygraph saved global_step is 1 larger than that in
                # static-graph, since the time of global_step to increase is
                # different.
                if var_name == "@LR_DECAY_COUNTER@":
                    converted_state["global_step"] = np.array(
                        converted_state.pop("@LR_DECAY_COUNTER@")) + 1
            else:
                # moment and other accumulators
                # extend state dict to include promising dygraph names
                for param_name in param_names:
                    if var_name.startswith(param_name + "_" + opt_name):
                        # when init optimizer with name
                        accum_name = var_name[len(param_name + "_" + opt_name +
                                                  "_"):]
                    elif var_name.startswith(param_name +
                                             "_") and opt_name == opt_cls_name:
                        # when init optimizer without name
                        accum_name = var_name[len(param_name + "_"):]
                    else:
                        continue
                    # remove suffix idx
                    accum_name = accum_name[:accum_name.rfind("_")]
                    # state names always end with "_0" in dygraph because of the
                    # unique optimizer._name
                    dy_state_name = (param_name + "_" + opt_unq_name + "_" +
                                     accum_name + "_0")
                    converted_state[dy_state_name] = state_var

        self.model._optimizer.set_dict(converted_state)


class Model(fluid.dygraph.Layer):
    """
    FIXME: add more comments and usage
    """

    def __init__(self):
        super(Model, self).__init__(self.__class__.__name__)
        self.mode = 'train'
        self._inputs = None
        self._labels = None
        self._loss_function = None
        self._loss_weights = None
        self._loss = None
        self._optimizer = None
        self._device = None
        self._device_ids = None
        if in_dygraph_mode():
            self._adapter = DynamicGraphAdapter(self)
        else:
            self._adapter = StaticGraphAdapter(self)

    def train(self, *args, **kwargs):
        return self._adapter.train(*args, **kwargs)

    def eval(self, *args, **kwargs):
        return self._adapter.eval(*args, **kwargs)

    def test(self, *args, **kwargs):
        return self._adapter.test(*args, **kwargs)

    def save(self, *args, **kwargs):
        return self._adapter.save(*args, **kwargs)

    def load(self, *args, **kwargs):
        return self._adapter.load(*args, **kwargs)

    def prepare(self,
                optimizer=None,
                loss_function=None,
                inputs=None,
                labels=None,
                device=None,
                device_ids=None):
        """
        Configures the model for training, evaluation, and inference.

        Args:
            optimizer (Optimizer|None): the optimizer must be set for training
                and should be an Optimizer instance. It can be None in eval
                and test mode.
            loss_function (Loss|None): the loss function must be set for
                training and should be a Loss instance. It can be None when
                there is no loss.
            inputs (Input|list|dict|None): inputs, the entry points of the
                network, could be an Input layer, a list of Input layers,
                a dict (name: Input), or None. For static graph, inputs must
                be set. For dynamic graph, it could be None.
            labels (Input|list|None): labels, the entry points of the network,
                could be an Input layer, a list of Input layers, or None.
                For static graph, if loss_function is set in Model.prepare(),
                labels must be set as well. Otherwise, it could be None.
            device (str|None): specify the device type, 'CPU' or 'GPU'.
                If None, the device is selected automatically according to
                the installed package version (GPU if compiled with CUDA,
                otherwise CPU).
            device_ids (list[int]|None): specify the device indices. If None,
                the available devices are obtained from environment variables
                when the model is executed: for GPU, the currently available
                device IDs are read from FLAGS_selected_gpus or
                CUDA_VISIBLE_DEVICES; for CPU, the number of available CPUs is
                read from CPU_NUM (e.g. export CPU_NUM=4). If CPU_NUM is not
                set, the executor adds it to the environment variables with a
                value of 1. The default is None.
        """
        self._optimizer = optimizer
        if loss_function:
            if not isinstance(loss_function, Loss):
                raise TypeError(
                    "'loss_function' must be an instance of a 'Loss' subclass")
        self._loss_function = loss_function
        if not in_dygraph_mode():
            if not isinstance(inputs, (list, dict, Input)):
                raise TypeError(
                    "'inputs' must be a list, dict or Input in static graph mode")
            if loss_function and not isinstance(labels, (list, Input)):
                raise TypeError(
                    "'labels' must be a list or Input in static graph mode")
        self._inputs = inputs
        self._labels = labels
        self._device = device
        if device is None:
            self._device = 'GPU' if fluid.is_compiled_with_cuda() else 'CPU'
        self._device_ids = device_ids
        if not in_dygraph_mode():
            self._adapter.prepare()

    def parameters(self, *args, **kwargs):
        return self._adapter.parameters(*args, **kwargs)
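

# A minimal end-to-end usage sketch (not part of this module), assuming a
# single fully connected classifier; the layer size, variable names, file
# paths and optimizer settings below are illustrative assumptions only.
#
#   class SoftmaxRegression(Model):
#       def __init__(self):
#           super(SoftmaxRegression, self).__init__()
#           self._fc = fluid.dygraph.Linear(784, 10, act='softmax')
#
#       def forward(self, x):
#           return self._fc(x)
#
#   model = SoftmaxRegression()
#   optim = fluid.optimizer.SGD(learning_rate=1e-3,
#                               parameter_list=model.parameters())
#   model.prepare(optim, CrossEntropy(),
#                 inputs=[Input([None, 784], 'float32', name='x')],
#                 labels=[Input([None, 1], 'int64', name='label')])
#
#   x = np.random.random((4, 784)).astype('float32')
#   y = np.random.randint(0, 10, size=(4, 1)).astype('int64')
#   outputs, losses = model.train(x, y)   # one forward/backward/update step
#   outputs, losses = model.eval(x, y)    # forward + loss only
#   outputs = model.test(x)               # forward only
#   model.save('checkpoints/softmax')     # writes .pdparams (and .pdopt)
#   model.load('checkpoints/softmax')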