# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import inspect
import os
import pickle
from collections import OrderedDict

import numpy as np

from paddle import fluid
from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable

__all__ = ['shape_hints', 'Model', 'Loss', 'CrossEntropy']


def to_list(value):
    if isinstance(value, (list, tuple)):
        return value
    return [value]


def to_numpy(var):
    assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable"
    if isinstance(var, fluid.core.VarBase):
        return var.numpy()
    t = global_scope().find_var(var.name).get_tensor()
    return np.array(t)


def extract_args(func):
    if hasattr(inspect, 'getfullargspec'):
        return inspect.getfullargspec(func)[0]
    else:
        return inspect.getargspec(func)[0]


def shape_hints(**hints):
    assert hints, "hints can not be empty"
    assert all(isinstance(h, (list, tuple)) for h in hints.values()), \
        "shape hint must be a list or tuple"

    def wrapper(func):
        args = extract_args(func)
        invalid = set(hints.keys()) - set(args)
        assert not invalid, \
            "shape hint for arguments that are not present in forward method" \
            + ": ({})".format(", ".join(invalid))
        func.shape_hints = hints
        return func
    return wrapper
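
# Usage sketch (class and argument names below are hypothetical): `shape_hints`
# attaches static shape information to a Model's `forward`, which
# `StaticGraphAdapter` uses to build `fluid.data` input variables instead of
# inferring `(None, ) + shape[1:]` from the first batch, e.g.
#
#     class MNIST(Model):
#         @shape_hints(image=[None, 1, 28, 28])
#         def forward(self, image):
#             ...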


class Loss(object):
    def __init__(self, average=True):
        super(Loss, self).__init__()
        self.average = average

    def infer_shape(self, outputs):
        return [o.shape for o in outputs]

    def infer_dtype(self, outputs):
        return [o.dtype for o in outputs]

    def forward(self, outputs, labels):
        raise NotImplementedError()

    def __call__(self, outputs, labels):
        labels = to_list(labels)
        if in_dygraph_mode():
            labels = [to_variable(l) for l in labels]
        losses = to_list(self.forward(to_list(outputs), labels))
        if not self.average:
            return losses
        return [fluid.layers.reduce_mean(l) for l in losses]


class CrossEntropy(Loss):
    def __init__(self):
        super(CrossEntropy, self).__init__()

    def infer_shape(self, outputs):
        return [o.shape[:-1] + (1, ) for o in outputs]

    def infer_dtype(self, outputs):
        return ['int64' for _ in outputs]

    def forward(self, outputs, labels):
        return [fluid.layers.cross_entropy(o, l) for o, l in zip(
            outputs, labels)]
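
# A minimal sketch of a custom loss (hypothetical, for illustration only):
# subclasses implement `forward` over model outputs and label variables, and
# may override `infer_shape`/`infer_dtype` so the static-graph adapter can
# create matching label placeholders, e.g.
#
#     class MSE(Loss):
#         def forward(self, outputs, labels):
#             return [fluid.layers.square_error_cost(o, l)
#                     for o, l in zip(outputs, labels)]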


class StaticGraphAdapter(object):
    def __init__(self, model):
        super(StaticGraphAdapter, self).__init__()
        self.model = model
        # with `_build_once` gone, parameters are now created in `__init__`
        # so we need to keep track of the parameters already created
        self._startup_prog = fluid.default_startup_program()
        self._orig_prog = fluid.default_main_program()

        self._label_vars = {}  # label variables
        self._endpoints = {}
        self._loss_endpoint = None
        self._executor = None
        self._progs = {}
        self._compiled_progs = {}

        self._lazy_load_optimizer = None

        # parse shape hints
        self._input_desc = OrderedDict([
            (n, None) for n in extract_args(self.model.forward) if n != 'self'
        ])
        if hasattr(self.model.forward, 'shape_hints'):
            self._input_desc.update(self.model.forward.shape_hints)
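        # e.g. with a hypothetical `forward(self, image, label)` decorated by
        # `@shape_hints(image=[None, 1, 28, 28])`, `_input_desc` would end up
        # as OrderedDict([('image', [None, 1, 28, 28]), ('label', None)])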

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    def train(self, inputs, labels, device='CPU', device_ids=None):
        assert self.model._optimizer and self.model._loss_function, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'train'
        return self._run(inputs, labels, device, device_ids)

    def eval(self, inputs, labels, device='CPU', device_ids=None):
        assert self.model._loss_function, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'eval'
        return self._run(inputs, labels, device, device_ids)

    def test(self, inputs, device='CPU', device_ids=None):
        self.mode = 'test'
        return self._run(inputs, None, device, device_ids)

    def parameters(self, *args, **kwargs):
        return None

    def save(self, path):
        def _save(state, path):
            if not state:
                return
            state = {k: to_numpy(v) if isinstance(v, Variable) else v
                     for k, v in state.items()}
            with open(path, 'wb') as f:
                pickle.dump(state, f)

        base = os.path.basename(path)
        assert base != "", "path should be of 'dirname/filename' format"
        dir_name = os.path.dirname(path)
        if dir_name and not os.path.exists(dir_name):
            os.makedirs(dir_name)
        param_path = path + ".pdparams"
        _save(self.model.state_dict(), param_path)
        prog = self._progs.get('train', None)
        if prog is None or self.model._optimizer is None:
            return
        # XXX `optimizer.state_dict()` only works in dygraph mode
        optim_path = path + ".pdopt"
        optim = {p.name: p for p in filter(
            is_belong_to_optimizer, prog.list_vars())}
        if not optim:
            return

        _save(optim, optim_path)

    def load(self, path):
        def _load(path):
            if not os.path.exists(path):
                return
            with open(path, 'rb') as f:
                return pickle.load(f)

        param_path = path + ".pdparams"
        param_state = _load(param_path)
        assert param_state, "failed to load parameters, please check path"

        if self._executor is None:
            # TODO: loading to CPU seems to cause some transform error, and
            # only the first step gets the right result
            # executor = fluid.Executor(fluid.CPUPlace())._default_executor
            executor = fluid.Executor(fluid.CUDAPlace(0))._default_executor
        else:
            executor = self._executor._default_executor

        fluid.core._create_loaded_parameter(
            list(self.model.state_dict().values()), global_scope(), executor)

        for key, var in self.model.state_dict().items():
            assert key in param_state, \
                "parameter [{}] is not found in model file [{}]".format(
                    key, param_path)
            self._set_var(var, param_state[key])

        # FIXME what if a different optimizer is used?
        if not self.model._optimizer:
            return
        optim_path = path + ".pdopt"
        optim_state = _load(optim_path)
        if optim_state is None:
            return

        if self._executor is not None:
            self._load_optimizer(optim_state)
        else:
            self._lazy_load_optimizer = optim_state

    def _load_optimizer(self, state):
        prog = self._progs.get('train', None)
        optim = list(filter(is_belong_to_optimizer, prog.list_vars()))
        if not optim:
            return

        fluid.core._create_loaded_parameter(
            optim, global_scope(), self._executor._default_executor)

        converted_state = dict(state)
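        # Rough sketch of the renaming handled below (names are illustrative):
        # dygraph saves accumulator states under names like
        #     <param_name>_<optimizer_unique_name>_<accum_name>_0
        # while static graph uses the persistable variable names found in the
        # train program, so the loop tries to recover the optimizer's unique
        # name from the saved keys and map each saved entry to its
        # static-graph counterpart.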
        for var in optim:
            if var.name in ["@LR_DECAY_COUNTER@", "global_step"]:
                # When using a learning rate scheduler, dygraph saves the
                # global step var as "global_step", while static graph has a
                # state var named "@LR_DECAY_COUNTER@".
                # NOTE: the dygraph saved global_step is 1 larger than that in
                # static graph, since global_step is incremented at a
                # different time.
                state_val = (
                    np.array(converted_state.pop("global_step")) - 1
                ) if "global_step" in converted_state else converted_state.pop(
                    "@LR_DECAY_COUNTER@", None)
                if state_val is not None:
                    converted_state[var.name] = state_val
            elif var.name.startswith("learning_rate_"):
                # When using a static learning rate, static graph makes it a
                # persistable var named 'unique_name.generate("learning_rate")',
                # but dygraph doesn't save it.
                if var.name not in state: continue
            else:
                # moment and other accumulators
                if var.name not in converted_state:
                    # try to convert from dygraph name
                    opt_name = self.model._optimizer._name
                    opt_cls_name = self.model._optimizer.__class__.__name__
                    opt_unq_name = None
                    for name in self.model._optimizer._accumulators.keys():
                        accum_name = name if opt_name is None else name[
                            len(opt_name) + 1:]
                        for param_name, state_var in self.model._optimizer._accumulators[
                                name].items():
                            if opt_unq_name is None:
                                # cannot infer the exact unique optimizer name,
                                # so try to extract it rather than generate it
                                for state_key in sorted(state.keys(),
                                                        key=lambda x: len(x),
                                                        reverse=True):
                                    prefix = param_name + "_" + (
                                        opt_cls_name if opt_name is None else
                                        opt_name) + "_"
                                    if state_key.startswith(prefix):
                                        prefix_offset = state_key[len(
                                            prefix):].find("_") + len(prefix)
                                        opt_unq_name = state_key[len(
                                            param_name + "_"):prefix_offset]
                                        # TODO: assert
                                        # assert opt_unq_name is None
                            # gen(param.name + "_" + gen(opt_name) + "_" + accum_name)
                            # always ends with "_0" since optimizer._name is unique
                            dy_state_name = (param_name + "_" + opt_unq_name +
                                             "_" + accum_name + "_0")
                            converted_state[
                                state_var.name] = converted_state.pop(
                                    dy_state_name)

            assert var.name in converted_state, \
                "variable [{}] is not in optimizer state file".format(var.name)
            self._set_var(var, converted_state[var.name])

    def _set_var(self, var, ndarray):
        t = global_scope().find_var(var.name).get_tensor()
        p = t._place()
        if p.is_cpu_place():
            place = fluid.CPUPlace()
        elif p.is_cuda_pinned_place():
            place = fluid.CUDAPinnedPlace()
        else:
            p = fluid.core.Place()
            p.set_place(t._place())
            place = fluid.CUDAPlace(p.gpu_device_id())

        t.set(ndarray, place)

    def _run(self, inputs, labels=None, device='CPU', device_ids=None):
        inputs = to_list(inputs)
        if labels is not None:
            labels = to_list(labels)
        assert len(inputs) == len(self._input_desc), "number of inputs" \
            + " does not match number of arguments of `forward` method"

        if self._progs.get(self.mode, None) is None:
            self._make_program(self._infer_input_vars(inputs))

        compiled_prog = self._compile_and_initialize(
            self._progs[self.mode], device, device_ids)

        feed = {}
        input_names = [name for name in self._input_desc.keys()]
        for idx, n in enumerate(input_names):
            # train and test may take different arguments
            if inputs[idx] is not None:
                feed[n] = inputs[idx]
        if labels is not None:
            for idx, v in enumerate(self._label_vars[self.mode]):
                feed[v.name] = labels[idx]

        endpoints = self._endpoints[self.mode]
        fetch_list = endpoints['output'] + endpoints['loss']
        num_output = len(endpoints['output'])
        out = self._executor.run(
            compiled_prog, feed=feed,
            fetch_list=fetch_list)
        if self.mode == 'test':
            return out[:num_output]
        else:
            return out[:num_output], out[num_output:]

    def _make_program(self, inputs):
        prog = self._orig_prog.clone()
        if self.mode == 'train' and self.model._optimizer._learning_rate_map:
            # HACK workaround learning rate map issue
            lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
            self.model._optimizer._learning_rate_map[prog] = lr_var
        losses = []
        with fluid.program_guard(prog, self._startup_prog):
            outputs = to_list(self.model.forward(*inputs))
            if self.mode != 'test':
                label_vars = self._infer_label_vars(outputs)
                self._label_vars[self.mode] = label_vars
                losses = self.model._loss_function(outputs, label_vars)
                if self.mode == 'train':
                    self._loss_endpoint = fluid.layers.sum(losses)
                    self.model._optimizer.minimize(self._loss_endpoint)
        if self.mode != 'train':  # clone again to put it in test mode
            prog = prog.clone(for_test=True)
        self._progs[self.mode] = prog
        self._endpoints[self.mode] = {
            "output": outputs,
            "loss": losses
        }

    def _infer_input_vars(self, inputs):
        input_vars = []
        for idx, i in enumerate(inputs):
            if i is None:  # train and test may take different arguments
                input_vars.append(None)
                continue
            ndarray = np.array(i)
            name = list(self._input_desc.keys())[idx]
            shape = list(self._input_desc.values())[idx]
            if shape is None:
                shape = (None, ) + ndarray.shape[1:]
            input_vars.append(fluid.data(name, shape, ndarray.dtype))
        return input_vars

    def _infer_label_vars(self, outputs):
        shapes = self.model._loss_function.infer_shape(outputs)
        dtypes = self.model._loss_function.infer_dtype(outputs)
        label_vars = []
        for idx, (shape, dtype) in enumerate(zip(shapes, dtypes)):
            name = '__label{}'.format(idx)
            label_vars.append(fluid.data(name, shape, dtype))
        return label_vars

    def _compile_and_initialize(self, prog, device='CPU', device_ids=None):
        compiled_prog = self._compiled_progs.get(self.mode, None)
        if compiled_prog is not None:
            return compiled_prog

        places = [fluid.CUDAPlace(i) if device.lower() == 'gpu'
                  else fluid.CPUPlace() for i in device_ids]

        # XXX *ALL WEIGHTS* should be initialized upon model construction,
        # even if `forward()` may run different code paths for different modes;
        # therefore the startup program only needs to run once
        if self._executor is None:
            self._executor = fluid.Executor(places[0])
            # XXX incremental initialization
            uninitialized = []
            for var_py in self._startup_prog.list_vars():
                var = fluid.global_scope().find_var(var_py.name)
                if var and var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var_py)
            if uninitialized:
                startup_prog = self._startup_prog._prune(uninitialized)
                self._executor.run(startup_prog)

            if self.mode == 'train' and self._lazy_load_optimizer:
                self._load_optimizer(self._lazy_load_optimizer)
                self._lazy_load_optimizer = None

        compiled_prog = fluid.CompiledProgram(prog)
        if len(device_ids) > 1:
            loss_name = None
            if self.mode == 'train' and self._loss_endpoint is not None:
                loss_name = self._loss_endpoint.name

            share_vars_from = None
            if self.mode == 'eval' and 'train' in self._compiled_progs:
                share_vars_from = self._compiled_progs['train']
            # HACK invalidate eval program if it is compiled before the train
            # program; quite hackish, OTOH, it is generally uncommon that the
            # eval program will be run before the train program
            if self.mode == 'train' and 'eval' in self._compiled_progs:
                del self._compiled_progs['eval']

            compiled_prog = compiled_prog.with_data_parallel(
                loss_name=loss_name, places=places,
                share_vars_from=share_vars_from)

        self._compiled_progs[self.mode] = compiled_prog
        return compiled_prog


class DynamicGraphAdapter(object):
    def __init__(self, model):
        super(DynamicGraphAdapter, self).__init__()
        self.model = model

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    # TODO: multi-device in dygraph mode is not implemented at present
    def train(self, inputs, labels, device='CPU', device_ids=None):
        assert self.model._optimizer and self.model._loss_function, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).train()
        self.mode = 'train'
        inputs = to_list(inputs)
        labels = to_list(labels)
        outputs = self.model.forward(*[to_variable(x) for x in inputs])
        losses = self.model._loss_function(outputs, labels)
        final_loss = fluid.layers.sum(losses)
        final_loss.backward()
        self.model._optimizer.minimize(final_loss)
        self.model.clear_gradients()
        return [to_numpy(o) for o in to_list(outputs)], \
            [to_numpy(l) for l in losses]

    def eval(self, inputs, labels, device='CPU', device_ids=None):
        assert self.model._loss_function, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).eval()
        self.mode = 'eval'
        inputs = to_list(inputs)
        labels = to_list(labels)
        outputs = self.model.forward(*[to_variable(x) for x in inputs])
        losses = self.model._loss_function(outputs, labels)
        return [to_numpy(o) for o in to_list(outputs)], \
            [to_numpy(l) for l in losses]

    def test(self, inputs, device='CPU', device_ids=None):
        super(Model, self.model).eval()
        self.mode = 'test'
        inputs = [to_variable(x) for x in to_list(inputs)]
        outputs = self.model.forward(*inputs)
        return [to_numpy(o) for o in to_list(outputs)]

    def parameters(self, *args, **kwargs):
        return super(Model, self.model).parameters(*args, **kwargs)

    def save(self, path):
        params = self.model.state_dict()
        fluid.save_dygraph(params, path)
        if self.model._optimizer is None:
            return
        if self.model._optimizer.state_dict():
            optim = self.model._optimizer.state_dict()
            fluid.save_dygraph(optim, path)

    def load(self, path):
        params, optim = fluid.load_dygraph(path)
        self.model.set_dict(params)
        if self.model._optimizer is None or optim is None:
            return

        # If the optimizer performs set_dict when its state vars haven't been
        # created, which happens when set_dict is called before minimize, the
        # state is stored in optimizer._accumulators_holder and loaded lazily.
        # To work around this when loading from static-graph saved states,
        # extend the state dict to include keys named according to dygraph
        # naming rules.
        # TODO: if len(self.model._optimizer._accumulators) > 0
        converted_state = dict(optim)
        opt_unq_name = self.model._optimizer._name
        opt_cls_name = self.model._optimizer.__class__.__name__
        opt_name = opt_unq_name[:opt_unq_name.rfind("_")]  # remove suffix idx
        param_names = [param.name for param in self.model.parameters()]
        for var_name, state_var in sorted(optim.items(),
                                          key=lambda x: len(x[0]),
                                          reverse=True):
            if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
                # NOTE: the dygraph saved global_step is 1 larger than that in
                # static graph, since global_step is incremented at a
                # different time.
                if var_name == "@LR_DECAY_COUNTER@":
                    converted_state["global_step"] = np.array(
                        converted_state.pop("@LR_DECAY_COUNTER@")) + 1
            else:
                # moment and other accumulators
                # extend state dict to include promising dygraph names
                for param_name in param_names:
                    if var_name.startswith(param_name + "_" + opt_name):
                        # when init optimizer with name
                        accum_name = var_name[len(param_name + "_" + opt_name +
                                                  "_"):]
                    elif var_name.startswith(param_name +
                                             "_") and opt_name == opt_cls_name:
                        # when init optimizer without name
                        accum_name = var_name[len(param_name + "_"):]
                    else:
                        continue
                    # remove suffix idx
                    accum_name = accum_name[:accum_name.rfind("_")]
                    # state names always end with "_0" in dygraph because of the
                    # unique optimizer._name
                    dy_state_name = (param_name + "_" + opt_unq_name + "_" +
                                     accum_name + "_0")
                    converted_state[dy_state_name] = state_var

        self.model._optimizer.set_dict(converted_state)


class Model(fluid.dygraph.Layer):
    def __init__(self):
        super(Model, self).__init__(self.__class__.__name__)
        self.mode = 'train'
        self._loss_function = None
        self._loss_weights = None
        self._optimizer = None
        if in_dygraph_mode():
            self._adapter = DynamicGraphAdapter(self)
        else:
            self._adapter = StaticGraphAdapter(self)

    def train(self, *args, **kwargs):
        return self._adapter.train(*args, **kwargs)

    def eval(self, *args, **kwargs):
        return self._adapter.eval(*args, **kwargs)

    def test(self, *args, **kwargs):
        return self._adapter.test(*args, **kwargs)

    def save(self, *args, **kwargs):
        return self._adapter.save(*args, **kwargs)

    def load(self, *args, **kwargs):
        return self._adapter.load(*args, **kwargs)

    def prepare(self, optimizer, loss_function):
        self._optimizer = optimizer
        assert isinstance(loss_function, Loss), \
            "'loss_function' must be an instance of a 'Loss' subclass"
        self._loss_function = loss_function

    def parameters(self, *args, **kwargs):
        return self._adapter.parameters(*args, **kwargs)
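
# End-to-end usage sketch (the model, layer, optimizer arguments and data
# names below are hypothetical, not part of this module): subclass `Model`,
# implement `forward`, call `prepare` with an optimizer and a `Loss` instance,
# then drive training step by step with `train`/`eval`/`test`:
#
#     class MLP(Model):
#         def __init__(self):
#             super(MLP, self).__init__()
#             self._fc = fluid.dygraph.Linear(784, 10, act='softmax')
#
#         def forward(self, x):
#             return self._fc(x)
#
#     model = MLP()
#     optim = fluid.optimizer.SGD(learning_rate=1e-3,
#                                 parameter_list=model.parameters())
#     model.prepare(optim, CrossEntropy())
#     outputs, losses = model.train(batch_images, batch_labels)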