# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import inspect
import os
import pickle
import numpy as np
from collections import Iterable
from collections import OrderedDict

from paddle import fluid
from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable

from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
import distributed

from metrics import Metric
from callbacks import config_callbacks


__all__ = ['Model', 'Loss', 'CrossEntropy', 'Input']


def to_list(value):
    if value is None:
        return value
    if isinstance(value, (list, tuple)):
        return value
    return [value]


def to_numpy(var):
    assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable"
    if isinstance(var, fluid.core.VarBase):
        return var.numpy()
    t = global_scope().find_var(var.name).get_tensor()
    return np.array(t)


def flatten_list(l):
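    # flatten a list of lists into one flat list, recording each sub-list's
    # length in `splits` so the original grouping can be restored later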
    assert isinstance(l, list), "not a list"
    outl = []
    splits = []
    for sl in l:
        assert isinstance(sl, list), "sub content not a list"
        splits.append(len(sl))
        outl += sl
    return outl, splits


def restore_flatten_list(l, splits):
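    # inverse of `flatten_list`: regroup the flat list `l` into sub-lists
    # whose lengths are given by `splits`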
    outl = []
    for split in splits:
        assert len(l) >= split, "list length invalid"
        sl, l = l[:split], l[split:]
        outl.append(sl)
    return outl


def extract_args(func):
    if hasattr(inspect, 'getfullargspec'):
        return inspect.getfullargspec(func)[0]
    else:
        return inspect.getargspec(func)[0]


class Input(fluid.dygraph.Layer):
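    """Declares a network input; `forward` creates the corresponding
    `fluid.data` placeholder used for static-graph feeding."""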
    def __init__(self, shape=None, dtype=None, name=None):
        super(Input, self).__init__()
        self.shape = shape
        self.dtype = dtype
        self.name = name

    def forward(self):
        return fluid.data(self.name, shape=self.shape, dtype=self.dtype)


class Loss(object):
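    """Base class for losses; subclasses implement `forward(outputs, labels)`
    returning a list of loss tensors, which `__call__` reduces by mean or sum
    according to `average`."""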
    def __init__(self, average=True):
        super(Loss, self).__init__()
        self.average = average

    def forward(self, outputs, labels):
        raise NotImplementedError()

    def __call__(self, outputs, labels):
        labels = to_list(labels)
        if in_dygraph_mode():
            labels = [to_variable(l) for l in labels]
        losses = to_list(self.forward(to_list(outputs), labels))
        if self.average:
            losses = [fluid.layers.reduce_mean(l) for l in losses]
        else:
            losses = [fluid.layers.reduce_sum(l) for l in losses]
        return losses


class CrossEntropy(Loss):
    def __init__(self, average=True):
        super(CrossEntropy, self).__init__(average)

    def forward(self, outputs, labels):
        return [
            fluid.layers.cross_entropy(o, l) for o, l in zip(outputs, labels)
        ]


class StaticGraphAdapter(object):
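    """Adapter that runs a `Model` in static graph mode: it builds and
    compiles separate train/eval/test programs and executes them with a
    `fluid.Executor`."""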
    def __init__(self, model):
        super(StaticGraphAdapter, self).__init__()
        self.model = model
        # with `_build_once` gone, parameters are now created in `__init__`
        # so we need to keep track of the parameters already created
        self._startup_prog = fluid.default_startup_program()
        self._orig_prog = fluid.default_main_program()

        self._label_vars = {}  # label variables
        self._input_vars = {}  # input variables
        self._endpoints = {}
        self._loss_endpoint = None
        self._executor = None
        self._progs = {}
        self._compiled_progs = {}

        self._merge_count = {'eval': 0, 'test': 0}
        self._nranks = distributed.Env().nranks
        self._local_rank = distributed.Env().local_rank

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    def train(self, inputs, labels=None):
        assert self.model._optimizer, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'train'
        return self._run(inputs, labels)

    def eval(self, inputs, labels=None):
        self.mode = 'eval'
        return self._run(inputs, labels)

    def test(self, inputs):
        self.mode = 'test'
        return self._run(inputs, None)

    def parameters(self, *args, **kwargs):
        return None

    def save(self, path):
        def _save(state, path):
            if not state:
                return
            state = {
                k: to_numpy(v) if isinstance(v, Variable) else v
                for k, v in state.items()
            }
            with open(path, 'wb') as f:
                pickle.dump(state, f)

        base = os.path.basename(path)
        assert base != "", "path should be of 'dirname/filename' format"
        dir_name = os.path.dirname(path)
        if dir_name and not os.path.exists(dir_name):
            os.makedirs(dir_name)
        param_path = path + ".pdparams"
        _save(self.model.state_dict(), param_path)
        prog = self._progs.get('train', None)
        if prog is None or self.model._optimizer is None:
            return
        # XXX `optimizer.state_dict()` only work in dygraph mode
        optim_path = path + ".pdopt"
        optim = {
            p.name: p
            for p in filter(is_belong_to_optimizer, prog.list_vars())
        }
        if not optim:
            return

        _save(optim, optim_path)

    def load(self, path):
        def _load(path):
            if not os.path.exists(path):
                return
            with open(path, 'rb') as f:
                return pickle.load(f)

        param_path = path + ".pdparams"
        param_state = _load(param_path)
        assert param_state, "failed to load parameters, please check path"

        if self._executor is None:
            executor = fluid.Executor(fluid.CPUPlace())._default_executor
        else:
            executor = self._executor._default_executor

        fluid.core._create_loaded_parameter(
            list(self.model.state_dict().values()), global_scope(), executor)

        for key, var in self.model.state_dict().items():
            assert key in param_state, \
                "parameter [{}] is not found in model file [{}]".format(
                    key, param_path)
            self._set_var(var, param_state[key])

        # FIXME what if a different optimizer is used?
        if not self.model._optimizer:
            return
        optim_path = path + ".pdopt"
        optim_state = _load(optim_path)
        if optim_state is None:
            return

        self._load_optimizer(optim_state, executor)

    def _load_optimizer(self, state, executor):
        prog = self._progs.get('train', None)
        optim = list(filter(is_belong_to_optimizer, prog.list_vars()))
        if not optim:
            return

        fluid.core._create_loaded_parameter(optim, global_scope(), executor)

        converted_state = dict(state)
        for var in optim:
            if var.name in ["@LR_DECAY_COUNTER@", "global_step"]:
                # When using learning rate scheduler, dygraph would name the
                # global step var as "global_step" to save, while static-graph
                # would has a state var named as "@LR_DECAY_COUNTER@".
                # NOTE: dygraph saved global_step is 1 larger than that in
                # static-graph, since the time of global_step to increase is
                # different.
                state_val = (
                    np.array(converted_state.pop("global_step")) - 1
                ) if "global_step" in converted_state else converted_state.pop(
                    "@LR_DECAY_COUNTER@", None)
                if state_val is not None:
                    converted_state[var.name] = state_val
            elif var.name.startswith("learning_rate_"):
                # When using static learning rate, static-graph would make it
                # a persistable var named 'unique_name.generate("learning_rate")',
                # However, dygraph wouldn't save it.
                if var.name not in state: continue
            else:
                # moment and other accumulators
                if var.name not in converted_state:
                    # try to convert from dygraph name
                    opt_name = self.model._optimizer._name
                    opt_cls_name = self.model._optimizer.__class__.__name__
                    opt_unq_name = None
                    for name in self.model._optimizer._accumulators.keys():
                        accum_name = name if opt_name is None else name[len(
                            opt_name) + 1:]
                        for param_name, state_var in self.model._optimizer._accumulators[
                                name].items():
                            if opt_unq_name is None:
                                # can not infer out the exact unique(opt_name),
                                # thus try to extract rather than generate
Q
qingqing01 已提交
282 283 284 285
                                for state_key in sorted(
                                        state.keys(),
                                        key=lambda x: len(x),
                                        reverse=True):
                                    prefix = param_name + "_" + (
                                        opt_cls_name if opt_name is None else
                                        opt_name) + "_"
                                    if state_key.startswith(prefix):
                                        prefix_offset = state_key[len(
                                            prefix):].find("_") + len(prefix)
                                        opt_unq_name = state_key[len(
                                            param_name + "_"):prefix_offset]
                                        # TODO: assert
                                        # assert opt_unq_name is None
                                    # gen(param.name + "_" + gen(opt_name) + "_" + accum_name)
                                    # always end with "_0" since the unique optimizer._name
                            dy_state_name = (param_name + "_" + opt_unq_name +
                                             "_" + accum_name + "_0")
                            converted_state[
                                state_var.name] = converted_state.pop(
                                    dy_state_name)

            assert var.name in converted_state, \
                "variable [{}] is not in optimizer state file".format(var.name)
            self._set_var(var, converted_state[var.name])

    def _set_var(self, var, ndarray):
        t = global_scope().find_var(var.name).get_tensor()
        p = t._place()
        if p.is_cpu_place():
            place = fluid.CPUPlace()
        elif p.is_cuda_pinned_place():
            place = fluid.CUDAPinnedPlace()
        else:
            p = fluid.core.Place()
            p.set_place(t._place())
            place = fluid.CUDAPlace(p.gpu_device_id())

        t.set(ndarray, place)

    def _run(self, inputs, labels=None):
        compiled_prog = self._compiled_progs.get(self.mode, None)
        assert compiled_prog, \
            "Model is not ready, please call `model.prepare()` first"

        inputs = to_list(inputs)
        if labels is not None:
            labels = to_list(labels)
        assert len(inputs) == len(self._input_vars[self.mode]), \
            "number of inputs" \
            + " does not match number of arguments of `forward` method"

        feed = {}
        input_names = [v.name for v in self._input_vars[self.mode]]
        for idx, n in enumerate(input_names):
            # train and test may take different arguments
            if inputs[idx] is not None:
                feed[n] = inputs[idx]
        if labels is not None:
            for idx, v in enumerate(self._label_vars[self.mode]):
                feed[v.name] = labels[idx]

        endpoints = self._endpoints[self.mode]
        if self.mode == 'test':
            fetch_list = endpoints['output']
        else:
            metric_list, metric_splits = flatten_list(endpoints['metric'])
            fetch_list = endpoints['loss'] + metric_list
            num_loss = len(endpoints['loss'])
        rets = self._executor.run(compiled_prog,
                                  feed=feed,
                                  fetch_list=fetch_list,
                                  return_numpy=False)
        # LoDTensor cannot be fetch as numpy directly
        rets = [np.array(v) for v in rets]
        if self.mode == 'test':
            return rets[:]
        losses = rets[:num_loss]
        metric_states = restore_flatten_list(rets[num_loss:], metric_splits)
        metrics = []
        for metric, state in zip(self.model._metrics, metric_states):
            # cut off padding size
            if self.mode != 'train' and self.model._test_dataloader is not None and self._nranks > 1:
                total_size = len(self.model._test_dataloader.dataset)
                samples = state[0].shape[0]
                current_count = self._merge_count.get(self.mode, 0)
                if current_count + samples > total_size:
                    state = [s[:total_size - current_count, ...] for s in state]
                    self._merge_count[self.mode] = 0
                else:
                    self._merge_count[self.mode] += samples

            metrics.append(metric.update(*state))
        return (losses, metrics) if len(metrics) > 0 else losses

    def prepare(self):
        modes = ['train', 'eval', 'test']
        for mode in modes:
            self._make_program(mode)
            self._compile_and_initialize(self._progs[mode], mode)

    def _make_program(self, mode):
        prog = self._progs.get(mode, None)
        if prog is not None:
            return

        prog = self._orig_prog.clone()
        # NOTE: When defining learning rate scheduling in static-graph, ops to
        # increase the global step var and calculate learning rate would be
        # prepended into _orig_prog. The test program made by `_orig_prog.clone`
        # would also include these ops. Thus these ops must be pruned from the
        # test program, otherwise the global step would be changed during test.
        if mode != 'train':
            for op in list(prog.global_block().ops):
                prog.global_block()._remove_op(0)
        if mode == 'train' and self.model._optimizer \
            and self.model._optimizer._learning_rate_map:
            # HACK workaround learning rate map issue
            lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
            self.model._optimizer._learning_rate_map[prog] = lr_var
                
        losses = []
        metrics = []
        with fluid.program_guard(prog, self._startup_prog):
            if isinstance(self.model._inputs, dict):
                ins = [self.model._inputs[n] \
                    for n in extract_args(self.model.forward) if n != 'self']
            else:
                ins = self.model._inputs
            lbls = self.model._labels if self.model._labels else []
            inputs = [k.forward() for k in to_list(ins)]
            labels = [k.forward() for k in to_list(lbls)]
            self._label_vars[mode] = labels
            outputs = to_list(self.model.forward(*inputs))
            if mode != 'test':
                if self.model._loss_function:
                    losses = self.model._loss_function(outputs, labels)
                    
                if mode == 'train' and self.model._optimizer:
                    self._loss_endpoint = fluid.layers.sum(losses)
                    if self._nranks > 1:
                        role = role_maker.PaddleCloudRoleMaker(is_collective=True)
                        fleet.init(role)
                        dist_strategy = DistributedStrategy()
                        dist_strategy.mode = "collective"
                        dist_strategy.collective_mode = "grad_allreduce"
                        self.model._optimizer = fleet.distributed_optimizer(self.model._optimizer, strategy=dist_strategy)
                        
                    self.model._optimizer.minimize(self._loss_endpoint)
            if self._nranks > 1 and mode != 'train' and self.model._test_dataloader is not None:
                outputs = [distributed._all_gather(o, self._nranks) for o in outputs]
                if mode != 'test':
                    labels = [distributed._all_gather(l, self._nranks) for l in labels]
                    
            if mode != 'test':
                for metric in self.model._metrics:
                    metrics.append(to_list(metric.add_metric_op(outputs, labels)))   
                     
        if mode != 'train':  # clone again to put it in test mode
            prog = prog.clone(for_test=True)

        self._input_vars[mode] = inputs
        
        self._progs[mode] = prog
        self._endpoints[mode] = {
            "output": outputs,
            "loss": losses,
            "metric": metrics
        }


    def _compile_and_initialize(self, prog, mode):
        compiled_prog = self._compiled_progs.get(mode, None)
        if compiled_prog is not None:
            return compiled_prog

        device = self.model._device
        device_ids = self.model._device_ids

        if device.lower() == 'gpu':
            places = fluid.cuda_places(device_ids)
        else:
            places = fluid.cpu_places(len(device_ids) if device_ids else None)

        # XXX *ALL WEIGHTS* should be initialized upon model construction
        # even if `forward()` may run different code path for different mode
        # therefore startup program only needs to run once
        if self._executor is None:
            if self._nranks > 1 and device.lower() == 'gpu':
                gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
                place = fluid.CUDAPlace(gpu_id) if device.lower() == 'gpu' else fluid.CPUPlace()
            else:
                place = places[0]
            self._executor = fluid.Executor(place)
            # XXX incremental initialization
            uninitialized = []
            for var_py in self._startup_prog.list_vars():
                var = fluid.global_scope().find_var(var_py.name)
                if not var_py.name.startswith('nccl_id') and var and var.get_tensor()._is_initialized():
                    continue

                uninitialized.append(var_py)
            if uninitialized:
                startup_prog = self._startup_prog._prune(uninitialized)
                self._executor.run(startup_prog)

        if self._nranks < 2:
            compiled_prog = fluid.CompiledProgram(prog)
        else:
            compiled_prog = prog  # fleet.main_program

        if len(places) > 1:
            loss_name = None
            if mode == 'train' and self._loss_endpoint is not None:
                loss_name = self._loss_endpoint.name
            compiled_prog = compiled_prog.with_data_parallel(
                loss_name=loss_name, places=places)
        self._compiled_progs[mode] = compiled_prog


class DynamicGraphAdapter(object):
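    """Adapter that runs a `Model` eagerly in dygraph mode, optionally
    wrapping it with `DistributedDataParallel` for multi-process training."""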
    def __init__(self, model):
        super(DynamicGraphAdapter, self).__init__()
        self.model = model
        self._nranks = distributed.Env().nranks
        self._local_rank = distributed.Env().local_rank
        self._merge_count = {'eval': 0, 'test': 0}

        if self._nranks > 1:
            self.ddp_model = distributed.DistributedDataParallel(self.model)

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    # TODO multi device in dygraph mode not implemented at present time
    def train(self, inputs, labels=None):
        assert self.model._optimizer, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).train()
        self.mode = 'train'
        inputs = to_list(inputs)
        if labels is not None:
            labels = [to_variable(l) for l in to_list(labels)]
        if self._nranks > 1:
            outputs = self.ddp_model.forward(*[to_variable(x) for x in inputs])
            losses = self.model._loss_function(outputs, labels)
            final_loss = fluid.layers.sum(losses)
            final_loss = self.ddp_model.scale_loss(final_loss)
            final_loss.backward()
            self.ddp_model.apply_collective_grads()
        else:
            outputs = self.model.forward(*[to_variable(x) for x in inputs])
            losses = self.model._loss_function(outputs, labels)
            final_loss = fluid.layers.sum(losses)
            final_loss.backward()
        self.model._optimizer.minimize(final_loss)
        self.model.clear_gradients()
        metrics = []
        for metric in self.model._metrics:
            metric_outs = metric.add_metric_op(to_list(outputs), to_list(labels))
            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
            metrics.append(m)
        return ([to_numpy(l) for l in losses], metrics) \
                if len(metrics) > 0 else [to_numpy(l) for l in losses]

    def eval(self, inputs, labels, device='CPU', device_ids=None):
        assert self.model._loss_function, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).eval()
        self.mode = 'eval'
        inputs = to_list(inputs)
        if labels is not None:
            labels = [to_variable(l) for l in to_list(labels)]
        outputs = self.model.forward(*[to_variable(x) for x in inputs])
        if self.model._loss_function:
            losses = self.model._loss_function(outputs, labels)
        else:
            losses = []
        if self._nranks > 1:
            outputs = [distributed._all_gather(o, self._nranks) for o in to_list(outputs)]
            labels = [distributed._all_gather(l, self._nranks) for l in labels]
        metrics = []
        for metric in self.model._metrics:
            # cut off padding value.
            if self.model._test_dataloader is not None and self._nranks > 1:
                total_size = len(self.model._test_dataloader.dataset)
                samples = outputs[0].shape[0]
                current_count = self._merge_count.get(self.mode, 0)
                if current_count + samples > total_size:
                    outputs = [o[:total_size - metric.count[0]] for o in outputs]
                    labels = [l[:total_size - metric.count[0]] for l in labels]
                    self._merge_count[self.mode] = 0
                else:
                    self._merge_count[self.mode] += samples

            metric_outs = metric.add_metric_op(to_list(outputs), labels)
            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
            metrics.append(m)

        # To be consistent with static graph
        # return empty loss if loss_function is None
        return ([to_numpy(l) for l in losses], metrics) \
                if len(metrics) > 0 else [to_numpy(l) for l in losses]

    def test(self, inputs):
        super(Model, self.model).eval()
        self.mode = 'test'
        inputs = [to_variable(x) for x in to_list(inputs)]
        outputs = self.model.forward(*inputs)
        if self._nranks > 1:
            outputs = [distributed._all_gather(o, self._nranks) for o in to_list(outputs)]
        return [to_numpy(o) for o in to_list(outputs)]

    def parameters(self, *args, **kwargs):
        return super(Model, self.model).parameters(*args, **kwargs)

    def save(self, path):
        params = self.model.state_dict()
        fluid.save_dygraph(params, path)
        if self.model._optimizer is None:
            return
        if self.model._optimizer.state_dict():
            optim = self.model._optimizer.state_dict()
            fluid.save_dygraph(optim, path)

    def load(self, path):
        params, optim = fluid.load_dygraph(path)
        self.model.set_dict(params)
        if self.model._optimizer is None or optim is None:
            return

        # If optimizer performs set_dict when state vars haven't been created,
        # which would happen when set_dict before minimize, the state would be
        # stored in optimizer._accumulators_holder and loaded lazily.
        # To accommodate this when loading from static-graph saved states, extend
        # the state dict to include keys named according to dygraph naming rules.
        # TODO: if len(self.model._optimizer._accumulators) > 0
        converted_state = dict(optim)
        opt_unq_name = self.model._optimizer._name
        opt_cls_name = self.model._optimizer.__class__.__name__
        opt_name = opt_unq_name[:opt_unq_name.rfind("_")]  # remove suffix idx
        param_names = [param.name for param in self.model.parameters()]
        for var_name, state_var in sorted(
                optim.items(), key=lambda x: len(x[0]), reverse=True):
            if var_name in ["@LR_DECAY_COUNTER@", "global_step"]:
                # NOTE: dygraph saved global_step is 1 larger than that in
                # static-graph, since the time of global_step to increase is
                # different.
                if var_name == "@LR_DECAY_COUNTER@":
                    converted_state["global_step"] = np.array(
                        converted_state.pop("@LR_DECAY_COUNTER@")) + 1
            else:
                # moment and other accumulators
                # extend state dict to include promising dygraph names
                for param_name in param_names:
                    if var_name.startswith(param_name + "_" + opt_name):
                        # when init optimizer with name
                        accum_name = var_name[len(param_name + "_" + opt_name +
                                                  "_"):]
                    elif var_name.startswith(param_name +
                                             "_") and opt_name == opt_cls_name:
                        # when init optimizer without name
                        accum_name = var_name[len(param_name + "_"):]
                    else:
                        continue
                    # remove suffix idx
                    accum_name = accum_name[:accum_name.rfind("_")]
                    # state names always end with "_0" in dygraph because of the
                    # unique optimizer._name
                    dy_state_name = (param_name + "_" + opt_unq_name + "_" +
                                     accum_name + "_0")
                    converted_state[dy_state_name] = state_var

        self.model._optimizer.set_dict(converted_state)


class Model(fluid.dygraph.Layer):
    """
    FIXME: add more comments and usage
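
    A minimal usage sketch (hedged: the subclass, layer, optimizer and data
    shapes below are illustrative assumptions, not part of this module):

        class MNISTModel(Model):
            def __init__(self):
                super(MNISTModel, self).__init__()
                # assumed layer and sizes, for illustration only
                self._fc = fluid.dygraph.Linear(784, 10, act='softmax')

            def forward(self, x):
                return self._fc(x)

        inputs = [Input([None, 784], 'float32', name='image')]
        labels = [Input([None, 1], 'int64', name='label')]
        model = MNISTModel()
        optim = fluid.optimizer.Adam(learning_rate=1e-3,
                                     parameter_list=model.parameters())
        model.prepare(optim, CrossEntropy(), inputs=inputs, labels=labels)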
    """

    def __init__(self):
        super(Model, self).__init__(self.__class__.__name__)
        self.mode = 'train'
        self._inputs = None
        self._labels = None
        self._loss_function = None
        self._loss_weights = None
        self._optimizer = None
        self._device = None
        self._device_ids = None
        self._distributed_sampler = None
        self._test_dataloader = None

        if in_dygraph_mode():
            self._adapter = DynamicGraphAdapter(self)
        else:
            self._adapter = StaticGraphAdapter(self)

    def train(self, *args, **kwargs):
        return self._adapter.train(*args, **kwargs)

    def eval(self, *args, **kwargs):
        return self._adapter.eval(*args, **kwargs)

    def test(self, *args, **kwargs):
        return self._adapter.test(*args, **kwargs)

    def save(self, *args, **kwargs):
        if distributed.get_local_rank() == 0:
            return self._adapter.save(*args, **kwargs)

    def load(self, *args, **kwargs):
        return self._adapter.load(*args, **kwargs)

    def parameters(self, *args, **kwargs):
        return self._adapter.parameters(*args, **kwargs)

706 707 708
    def prepare(self,
                optimizer=None,
                loss_function=None,
                metrics=None,
                inputs=None,
                labels=None,
                device=None,
                device_ids=None):
        """
        FIXME: add comments
        Args:
            optimizer (Optimizer|None): optimizer must be set in training
                and should be an Optimizer instance. It can be None in eval
                and test mode.
            loss_function (Loss|None): loss function must be set in training
                and should be a Loss instance. It can be None when there is
                no loss.
            metrics (Metric|list of Metric|None): if metrics is set, all
                metrics will be calculated and output in train/eval mode.
            inputs (Input|list|dict|None): inputs, entry points of network,
                could be an Input layer, a list of Input layers,
                a dict (name: Input), or None. For static graph,
                inputs must be set. For dynamic graph, it could be None.
            labels (Input|list|None): labels, entry points of network,
                could be an Input layer, a list of Input layers, or None.
                For static graph, if loss_function is set in Model.prepare(),
                labels must be set. Otherwise, it could be None.
            device (str|None): specify device type, 'CPU' or 'GPU'.
                If None, automatically select device according to
                installation package version.
            device_ids (list[int]|None): specify device indices. If None,
                the available devices are obtained from environment variables
                when the model is executed: for GPU, the currently available
                device IDs are read from FLAGS_selected_gpus or
                CUDA_VISIBLE_DEVICES; for CPU, the number of available CPUs is
                read from CPU_NUM (for example, export CPU_NUM=4). If CPU_NUM
                is not set, the executor adds it to the environment variables
                and sets its value to 1. The default is None.
        """
        self._optimizer = optimizer
        if loss_function:
            if not isinstance(loss_function, Loss):
                raise TypeError(
                    "'loss_function' must be sub classes of 'Loss'")
        self._loss_function = loss_function
        if not in_dygraph_mode():
            if not isinstance(inputs, (list, dict, Input)):
                raise TypeError(
                    "'inputs' must be list or dict in static graph mode")
            if loss_function and not isinstance(labels, (list, Input)):
                raise TypeError("'labels' must be list in static graph mode")

        metrics = metrics or []
        for metric in to_list(metrics):
            assert isinstance(metric, Metric), \
                "{} is not sub class of Metric".format(metric.__class__.__name__)
        self._metrics = to_list(metrics)

        self._inputs = inputs
        self._labels = labels
        self._device = device
        
        if device is None:
            self._device = 'GPU' if fluid.is_compiled_with_cuda() else 'CPU'
        self._device_ids = device_ids
        if not in_dygraph_mode():
            self._adapter.prepare()

    def fit(
            self,
            train_loader=None,
            eval_loader=None,
            epochs=1,
            eval_freq=1,
            log_freq=10,
            save_freq=1,
            verbose=2,
            callbacks=None, ):
        """
        FIXME: add more comments and usage
        Args:
            train_loader (DataLoader): an iterable data loader is used for train.
            eval_loader (DataLoader): an iterable data loader is used for
                evaluation at the end of epoch. If None, will not do evaluation.
            epochs (int): number of epochs to train the model.
            eval_freq (int): evaluation frequency in epoch.
            log_freq (int): frequency to print log during training.
            save_freq (int): frequency to save checkpoint during training.
            verbose (int): verbosity mode, should be 0, 1, or 2.
                0 = silent, 1 = progress bar, 2 = one line per epoch.
            callbacks (Callback|None): list of `Callback` instances to apply
                during training.
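
        A hedged example (assumes `model` has already been prepared and that
        `train_loader` / `val_loader` are iterable DataLoaders yielding
        (input, label) batches):

            model.fit(train_loader, val_loader, epochs=10,
                      eval_freq=2, log_freq=20, save_freq=5, verbose=1)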
        """
        do_eval = eval_loader is not None
        self._test_dataloader = eval_loader
        metrics_name = self._metrics_name()
        cbks = config_callbacks(
            callbacks,
            model=self,
            epochs=epochs,
            steps=None,
            log_freq=log_freq,
            save_freq=save_freq,
            verbose=verbose,
            metrics=self._metrics_name(), )

        def _run_one_epoch(data_loader, callbacks, mode):
            size = data_loader.size if hasattr(data_loader, 'size') else None
            logs = {
                'steps': size,
                'metrics_name': metrics_name,
            }
            for step, data in enumerate(data_loader):
                if not fluid.in_dygraph_mode():
                    data = data[0]
                    batch_size = data[0].shape()[0]
                else:
                    batch_size = data[0].shape[0]

Q
qingqing01 已提交
828 829 830 831 832 833
                cbks.on_batch_begin(mode, step, logs)
                if mode == 'train':
                    outs = self.train(*data)
                else:
                    outs = self.eval(*data)

Q
qingqing01 已提交
834 835 836 837 838
                # losses
                loss = outs[0] if self._metrics else outs
                metrics = [[l[0] for l in loss]]

                # metrics
                for metric in self._metrics:
                    res = metric.accumulate()
                    metrics.extend(to_list(res))
                    
                assert len(metrics_name) == len(metrics)
                for k, v in zip(metrics_name, metrics):
                    logs[k] = v

                logs['step'] = step
                logs['batch_size'] = batch_size

                cbks.on_batch_end(mode, step, logs)
            self._reset_metrics()
            return logs

        cbks.on_begin('train')
        for epoch in range(epochs):
            cbks.on_epoch_begin(epoch)
            # FIXME: adapt to DataLoader
            loader = train_loader
            if not isinstance(train_loader, Iterable):
                loader = train_loader()
            logs = _run_one_epoch(loader, cbks, 'train')
            cbks.on_epoch_end(epoch, logs)

            if do_eval and epoch % eval_freq == 0:
                cbks.on_begin('eval', logs)
                # FIXME: adapt to DataLoader
                loader = eval_loader
                if not isinstance(eval_loader, Iterable):
                    loader = eval_loader()
                logs = _run_one_epoch(loader, cbks, 'eval')
                cbks.on_end('eval', logs)

        cbks.on_end('train', logs)

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def _metrics_name(self):
        metrics_name = ['loss']
        for m in self._metrics:
            metrics_name.extend(to_list(m.name()))
        return metrics_name