model.py 19.4 KB
Newer Older
Y
Yang Zhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Y
Yang Zhang 已提交
15 16 17
from __future__ import absolute_import

import inspect
Y
Yang Zhang 已提交
18 19
import os
import pickle
Y
Yang Zhang 已提交
20 21 22 23 24
from collections import OrderedDict

import numpy as np

from paddle import fluid
Y
Yang Zhang 已提交
25 26 27
from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
Y
Yang Zhang 已提交
28
from paddle.fluid.dygraph.base import to_variable
29
from metrics import Metric
Y
Yang Zhang 已提交
30

Y
Yang Zhang 已提交
31
__all__ = ['shape_hints', 'Model', 'Loss', 'CrossEntropy']
Y
Yang Zhang 已提交
32 33 34 35 36 37 38 39


def to_list(value):
    """Wrap ``value`` in a list unless it already is a list or tuple.

    Lists and tuples are returned as-is (same object, no copy); any other
    value, including ``None``, becomes a one-element list.
    """
    if not isinstance(value, (list, tuple)):
        value = [value]
    return value


40 41 42 43 44 45 46 47
def to_numpy(var):
    """Return the value held by a fluid variable as a numpy array.

    Dygraph ``VarBase`` objects are converted with ``.numpy()``. Static-graph
    ``Variable`` objects are looked up by name in the global scope, and
    their tensor is converted with ``np.array``.
    """
    assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable"
    if not isinstance(var, fluid.core.VarBase):
        tensor = global_scope().find_var(var.name).get_tensor()
        return np.array(tensor)
    return var.numpy()


D
dengkaipeng 已提交
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
def flatten_list(l):
    """Concatenate a list of lists into a single flat list.

    Returns ``(flat, splits)``, where ``splits[i]`` is the length of the
    i-th sub-list. ``restore_flatten_list(flat, splits)`` reverses the
    operation.
    """
    assert isinstance(l, list), "not a list"
    flat, splits = [], []
    for sub in l:
        assert isinstance(sub, list), "sub content not a list"
        flat.extend(sub)
        splits.append(len(sub))
    return flat, splits


def restore_flatten_list(l, splits):
    """Split the flat sequence ``l`` back into chunks of the given sizes.

    This is the inverse of ``flatten_list``: the first ``splits[0]`` items
    form the first chunk, the next ``splits[1]`` items the second, and so
    on. Items left over after the last split are ignored.
    """
    pieces = []
    rest = l
    for size in splits:
        assert len(rest) >= size, "list length invalid"
        pieces.append(rest[:size])
        rest = rest[size:]
    return pieces


Y
Yang Zhang 已提交
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
def extract_args(func):
    """Return the names of ``func``'s positional parameters, in order.

    Prefers ``inspect.getfullargspec`` and falls back to the legacy
    ``inspect.getargspec`` on interpreters that lack it.
    """
    getspec = getattr(inspect, 'getfullargspec', None)
    if getspec is None:
        getspec = inspect.getargspec
    return getspec(func)[0]


def shape_hints(**hints):
    """Decorator factory that attaches input shape hints to a function.

    Each keyword names a parameter of the decorated function and maps it to
    a list/tuple shape. The mapping is stored as ``func.shape_hints``. Naming
    a parameter the function does not declare fails an assertion when the
    decorator is applied.
    """
    assert hints, "hints can not be empty"
    valid = all(isinstance(shape, (list, tuple)) for shape in hints.values())
    assert valid, "shape hint must be a list or tuple"

    def decorator(func):
        unknown = set(hints) - set(extract_args(func))
        assert not unknown, \
            "shape hint for arguments that are not present in forward method" \
            + ": ({})".format(", ".join(unknown))
        func.shape_hints = hints
        return func

    return decorator


Y
Yang Zhang 已提交
91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
class Loss(object):
    """Base class for losses passed to ``Model.prepare``.

    Subclasses implement ``forward``. ``infer_shape`` and ``infer_dtype``
    return, for a list of model outputs, the shape and dtype of the matching
    label variables.

    Args:
        average (bool): if True, ``__call__`` reduces every loss returned by
            ``forward`` with ``fluid.layers.reduce_mean``.
    """

    def __init__(self, average=True):
        super(Loss, self).__init__()
        self.average = average

    def infer_shape(self, outputs):
        """Default label shapes: the shape of each output."""
        return [out.shape for out in outputs]

    def infer_dtype(self, outputs):
        """Default label dtypes: the dtype of each output."""
        return [out.dtype for out in outputs]

    def forward(self, outputs, labels):
        """Compute the loss(es); subclasses must override this."""
        raise NotImplementedError()

    def __call__(self, outputs, labels):
        """Normalize inputs to lists, run ``forward`` and optionally average.

        In dygraph mode the labels are first converted with ``to_variable``.
        """
        label_list = to_list(labels)
        if in_dygraph_mode():
            label_list = list(map(to_variable, label_list))
        losses = to_list(self.forward(to_list(outputs), label_list))
        if self.average:
            losses = [fluid.layers.reduce_mean(loss) for loss in losses]
        return losses


class CrossEntropy(Loss):
    """Cross-entropy loss applied to each model output.

    The label for each output is int64. Its shape is the output's shape with
    the last dimension replaced by 1.
    """

    def __init__(self):
        super(CrossEntropy, self).__init__()

    def infer_shape(self, outputs):
        """Label shape per output: all dims but the last, followed by 1."""
        shapes = []
        for out in outputs:
            shapes.append(out.shape[:-1] + (1, ))
        return shapes

    def infer_dtype(self, outputs):
        """Every label is an int64 tensor."""
        return ['int64' for _out in outputs]

    def forward(self, outputs, labels):
        """Apply ``fluid.layers.cross_entropy`` to each (output, label) pair."""
        losses = []
        for out, label in zip(outputs, labels):
            losses.append(fluid.layers.cross_entropy(out, label))
        return losses


Y
Yang Zhang 已提交
130 131 132 133 134 135 136
class StaticGraphAdapter(object):
    """Runs a ``Model`` in static-graph mode.

    A separate program is built lazily for each mode ('train', 'eval',
    'test') by cloning the default main program and tracing
    ``model.forward`` into it. Compiled programs are cached per mode.
    """

    def __init__(self, model):
        super(StaticGraphAdapter, self).__init__()
        self.model = model
        # with `_build_once` gone, parameters are now created in `__init__`
        # so we need to keep track of the parameters already created
        self._startup_prog = fluid.default_startup_program()
        self._orig_prog = fluid.default_main_program()

        self._label_vars = {}  # label variables, keyed by mode
        self._endpoints = {}  # fetch targets ('output'/'loss'/'metric') by mode
        self._loss_endpoint = None
        self._executor = None
        self._progs = {}
        self._compiled_progs = {}

        # optimizer state loaded before it could be applied; consumed the
        # first time the train program is compiled
        self._lazy_load_optimizer = None

        # parse shape hints
        self._input_desc = OrderedDict([
            (n, None) for n in extract_args(self.model.forward) if n != 'self'
        ])
        if hasattr(self.model.forward, 'shape_hints'):
            self._input_desc.update(self.model.forward.shape_hints)

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    def train(self, inputs, labels, device='CPU', device_ids=None):
        """Run one training step; returns ``(losses, metrics)``."""
        assert self.model._optimizer and self.model._loss_function, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'train'
        return self._run(inputs, labels, device, device_ids)

    def eval(self, inputs, labels, device='CPU', device_ids=None):
        """Run one evaluation step; returns ``(losses, metrics)``."""
        assert self.model._loss_function, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'eval'
        return self._run(inputs, labels, device, device_ids)

    def test(self, inputs, device='CPU', device_ids=None):
        """Run inference; returns the fetched outputs as numpy arrays."""
        self.mode = 'test'
        return self._run(inputs, None, device, device_ids)

    def parameters(self, *args, **kwargs):
        return None

    def save(self, path):
        """Pickle the model state dict to ``path + '.pdparams'``.

        If a train program and an optimizer exist, the optimizer variables
        are also pickled to ``path + '.pdopt'``.
        """
        def _save(state, path):
            if not state:
                return
            state = {k: to_numpy(v) if isinstance(v, Variable) else v
                     for k, v in state.items()}
            with open(path, 'wb') as f:
                pickle.dump(state, f)

        base = os.path.basename(path)
        assert base != "", "path should be of 'dirname/filename' format"
        param_path = path + ".pdparams"
        _save(self.model.state_dict(), param_path)
        prog = self._progs.get('train', None)
        if prog is None or self.model._optimizer is None:
            return
        # XXX `optimizer.state_dict()` only work in dygraph mode
        optim_path = path + ".pdopt"
        optim = {p.name: p for p in filter(
            is_belong_to_optimizer, prog.list_vars())}
        if not optim:
            return
        # HACK this is contrived, optimizer state is not the same for
        # static/dynamic graph mode
        optim['__static_graph_only__'] = True
        _save(optim, optim_path)

    def load(self, path):
        """Load parameters (and optimizer state, if any) written by ``save``.

        If the train program does not exist yet, the optimizer state is
        kept and applied when that program is first compiled.
        """
        def _load(path):
            if not os.path.exists(path):
                return
            with open(path, 'rb') as f:
                # NOTE: unpickling can execute arbitrary code; only load
                # model files from trusted sources
                return pickle.load(f)

        param_path = path + ".pdparams"
        param_state = _load(param_path)
        assert param_state, "failed to load parameters, please check path"

        if self._executor is None:
            executor = fluid.Executor(fluid.CPUPlace())._default_executor
        else:
            executor = self._executor._default_executor

        fluid.core._create_loaded_parameter(
            list(self.model.state_dict().values()), global_scope(), executor)

        for key, var in self.model.state_dict().items():
            assert key in param_state, \
                "parameter [{}] is not found in model file [{}]".format(
                    key, param_path)
            self._set_var(var, param_state[key])

        # FIXME what if a different optimizer is used?
        if not self.model._optimizer:
            return
        optim_path = path + ".pdopt"
        optim_state = _load(optim_path)
        if optim_state is None:
            return
        assert '__static_graph_only__' in optim_state, \
            "optimizer saved in dygraph mode is not usable in static graph"

        # optimizer variables only exist once the train program is built;
        # an executor alone (e.g. after an eval/test run) is not enough
        if self._executor is not None and 'train' in self._progs:
            self._load_optimizer(optim_state)
        else:
            self._lazy_load_optimizer = optim_state

    def _load_optimizer(self, state):
        """Copy saved optimizer variables into the global scope."""
        prog = self._progs.get('train', None)
        if prog is None:
            # no train program yet: defer until it is compiled
            self._lazy_load_optimizer = state
            return
        optim = list(filter(is_belong_to_optimizer, prog.list_vars()))
        if not optim:
            return

        fluid.core._create_loaded_parameter(
            optim, global_scope(), self._executor._default_executor)

        for var in optim:
            assert var.name in state, \
                "variable [{}] is not in optimizer state file".format(var.name)
            self._set_var(var, state[var.name])

    def _set_var(self, var, ndarray):
        """Write ``ndarray`` into ``var``'s tensor on its current place."""
        t = global_scope().find_var(var.name).get_tensor()
        p = t._place()
        if p.is_cpu_place():
            place = fluid.CPUPlace()
        elif p.is_cuda_pinned_place():
            place = fluid.CUDAPinnedPlace()
        else:
            p = fluid.core.Place()
            p.set_place(t._place())
            place = fluid.CUDAPlace(p.gpu_device_id())

        t.set(ndarray, place)

    def _run(self, inputs, labels=None, device='CPU', device_ids=None):
        """Build/compile the program for the current mode, feed and run it.

        Returns the fetched outputs in test mode. Otherwise returns
        ``(losses, metrics)``, where ``metrics`` holds the result of
        ``metric.update`` for each registered metric.
        """
        inputs = to_list(inputs)
        if labels is not None:
            labels = to_list(labels)
        assert len(inputs) == len(self._input_desc), "number of inputs" \
            + " does not match number of arguments of `forward` method"

        if self._progs.get(self.mode, None) is None:
            self._make_program(self._infer_input_vars(inputs))

        compiled_prog = self._compile_and_initialize(
            self._progs[self.mode], device, device_ids)

        feed = {}
        input_names = list(self._input_desc.keys())
        for idx, n in enumerate(input_names):
            # train and test may take different arguments
            if inputs[idx] is not None:
                feed[n] = inputs[idx]
        if labels is not None:
            for idx, v in enumerate(self._label_vars[self.mode]):
                feed[v.name] = labels[idx]

        endpoints = self._endpoints[self.mode]
        if self.mode == 'test':
            fetch_list = endpoints['output']
        else:
            metric_list, metric_splits = flatten_list(endpoints['metric'])
            fetch_list = list(endpoints['loss']) + metric_list
            num_loss = len(endpoints['loss'])
        rets = self._executor.run(
            compiled_prog, feed=feed,
            fetch_list=fetch_list,
            return_numpy=False)
        # LoDTensor cannot be fetched as numpy directly
        rets = [np.array(v) for v in rets]
        if self.mode == 'test':
            return rets[:]
        losses = rets[:num_loss]
        metric_states = restore_flatten_list(rets[num_loss:], metric_splits)
        metrics = []
        for metric, state in zip(self.model._metrics, metric_states):
            metrics.append(metric.update(*state))
        return losses, metrics

    def _make_program(self, inputs):
        """Build and cache the program and fetch endpoints for the mode.

        Clones the original main program and traces ``model.forward`` into
        it. In train/eval mode it also creates the label variables, losses
        and metric ops; in train mode it appends the optimizer's
        ``minimize`` ops. Non-train programs are re-cloned with
        ``for_test=True``.
        """
        prog = self._orig_prog.clone()
        if self.mode == 'train' and self.model._optimizer._learning_rate_map:
            # HACK workaround learning rate map issue
            lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
            self.model._optimizer._learning_rate_map[prog] = lr_var
        # test mode has no labels, so losses/metrics stay empty; both must be
        # bound here because they are stored in `_endpoints` below
        losses = []
        metrics = []
        with fluid.program_guard(prog, self._startup_prog):
            outputs = to_list(self.model.forward(*inputs))
            if self.mode != 'test':
                label_vars = self._infer_label_vars(outputs)
                self._label_vars[self.mode] = label_vars
                losses = self.model._loss_function(outputs, label_vars)
                for metric in self.model._metrics:
                    metrics.append(
                        to_list(metric.add_metric_op(outputs, label_vars)))
                if self.mode == 'train':
                    self._loss_endpoint = fluid.layers.sum(losses)
                    self.model._optimizer.minimize(self._loss_endpoint)
        if self.mode != 'train':  # clone again to put it in test mode
            prog = prog.clone(for_test=True)
        self._progs[self.mode] = prog
        self._endpoints[self.mode] = {
            "output": outputs,
            "loss": losses,
            "metric": metrics,
        }

    def _infer_input_vars(self, inputs):
        """Create one ``fluid.data`` variable per non-None input.

        The name comes from ``forward``'s argument list. The shape comes
        from the shape hint if present, otherwise from the sample input with
        the first dimension left as None.
        """
        input_names = list(self._input_desc.keys())
        input_shapes = list(self._input_desc.values())
        input_vars = []
        for idx, i in enumerate(inputs):
            if i is None:  # train and test may take different arguments
                input_vars.append(None)
                continue
            ndarray = np.array(i)
            shape = input_shapes[idx]
            if shape is None:
                shape = (None, ) + ndarray.shape[1:]
            input_vars.append(
                fluid.data(input_names[idx], shape, ndarray.dtype))
        return input_vars

    def _infer_label_vars(self, outputs):
        """Create label variables using the loss function's shape/dtype."""
        shapes = self.model._loss_function.infer_shape(outputs)
        dtypes = self.model._loss_function.infer_dtype(outputs)
        label_vars = []
        for idx, (shape, dtype) in enumerate(zip(shapes, dtypes)):
            name = '__label{}'.format(idx)
            label_vars.append(fluid.data(name, shape, dtype))
        return label_vars

    def _compile_and_initialize(self, prog, device='CPU', device_ids=None):
        """Return the cached compiled program for the current mode.

        On the first call for a mode, this also creates the executor if
        needed, initializes startup variables not yet initialized, applies
        pending optimizer state (train mode), and enables data parallelism
        when more than one device id is given.

        Args:
            prog: program built by ``_make_program`` for the current mode.
            device: 'CPU' or 'GPU' (case-insensitive).
            device_ids: device ordinals; ``None`` or empty means ``[0]``.
        """
        compiled_prog = self._compiled_progs.get(self.mode, None)
        if compiled_prog is not None:
            return compiled_prog

        # `train`/`eval`/`test` default `device_ids` to None; use a single
        # device instead of failing when iterating over None
        if not device_ids:
            device_ids = [0]
        use_gpu = device.lower() == 'gpu'
        places = [fluid.CUDAPlace(i) if use_gpu else fluid.CPUPlace()
                  for i in device_ids]

        if self._executor is None:
            self._executor = fluid.Executor(places[0])

        # XXX *ALL WEIGHTS* should be initialized upon model construction
        # even if `forward()` may run different code path for different mode.
        # Variables can still be added to the startup program later, e.g.
        # optimizer accumulators created by `minimize()` when the train
        # program is built after the eval/test one, so initialize whatever
        # is missing every time a new program is compiled.
        self._init_uninitialized_vars()

        if self.mode == 'train' and self._lazy_load_optimizer:
            self._load_optimizer(self._lazy_load_optimizer)
            self._lazy_load_optimizer = None

        compiled_prog = fluid.CompiledProgram(prog)
        if len(device_ids) > 1:
            loss_name = None
            if self.mode == 'train' and self._loss_endpoint is not None:
                loss_name = self._loss_endpoint.name

            share_vars_from = None
            if self.mode == 'eval' and 'train' in self._compiled_progs:
                share_vars_from = self._compiled_progs['train']
            # HACK invalidate eval program if is compiled before train program
            # quite hackish, OTOH, it is generally uncommon that the eval
            # program will be run before the train program
            if self.mode == 'train' and 'eval' in self._compiled_progs:
                del self._compiled_progs['eval']

            compiled_prog = compiled_prog.with_data_parallel(
                loss_name=loss_name, places=places,
                share_vars_from=share_vars_from)

        self._compiled_progs[self.mode] = compiled_prog
        return compiled_prog

    def _init_uninitialized_vars(self):
        """Run the startup program only for variables that are not yet
        initialized in the global scope (incremental initialization)."""
        uninitialized = []
        for var_py in self._startup_prog.list_vars():
            var = fluid.global_scope().find_var(var_py.name)
            if var and var.get_tensor()._is_initialized():
                continue
            uninitialized.append(var_py)
        if uninitialized:
            startup_prog = self._startup_prog._prune(uninitialized)
            self._executor.run(startup_prog)


class DynamicGraphAdapter(object):
    """Runs a ``Model`` in dygraph mode.

    ``device``/``device_ids`` are accepted for parity with
    ``StaticGraphAdapter`` but are not used.
    """

    def __init__(self, model):
        super(DynamicGraphAdapter, self).__init__()
        self.model = model

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    # TODO multi device in dygraph mode not implemented at present time
    def train(self, inputs, labels, device='CPU', device_ids=None):
        """Run forward, backward and an optimizer step on one batch.

        Returns ``(losses, metrics)``: the losses as numpy arrays and the
        result of ``metric.update`` for each registered metric.
        """
        assert self.model._optimizer and self.model._loss_function, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).train()
        self.mode = 'train'
        inputs = to_list(inputs)
        labels = to_list(labels)
        outputs = to_list(self.model.forward(*[to_variable(x) for x in inputs]))
        losses = self.model._loss_function(outputs, labels)
        final_loss = fluid.layers.sum(losses)
        final_loss.backward()
        self.model._optimizer.minimize(final_loss)
        self.model.clear_gradients()
        metrics = self._update_metrics(outputs, labels)
        return [to_numpy(l) for l in losses], metrics

    def eval(self, inputs, labels, device='CPU', device_ids=None):
        """Run forward on one batch without updating parameters.

        Returns ``(losses, metrics)`` in the same form as ``train``.
        """
        assert self.model._loss_function, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).eval()
        self.mode = 'eval'
        inputs = to_list(inputs)
        labels = to_list(labels)
        outputs = to_list(self.model.forward(*[to_variable(x) for x in inputs]))
        losses = self.model._loss_function(outputs, labels)
        metrics = self._update_metrics(outputs, labels)
        return [to_numpy(l) for l in losses], metrics

    def test(self, inputs, device='CPU', device_ids=None):
        """Run forward on one batch; returns the outputs as numpy arrays."""
        super(Model, self.model).eval()
        self.mode = 'test'
        inputs = [to_variable(x) for x in to_list(inputs)]
        outputs = self.model.forward(*inputs)
        return [to_numpy(o) for o in to_list(outputs)]

    def _update_metrics(self, outputs, labels):
        """Feed ``outputs``/``labels`` to every registered metric.

        Returns the list of ``metric.update(...)`` results, one per metric,
        in registration order.
        """
        metrics = self.model._metrics
        if not metrics:
            return []
        label_vars = [to_variable(l) for l in labels]
        results = []
        for metric in metrics:
            metric_outs = to_list(metric.add_metric_op(outputs, label_vars))
            results.append(metric.update(*[to_numpy(o) for o in metric_outs]))
        return results

    def parameters(self, *args, **kwargs):
        return super(Model, self.model).parameters(*args, **kwargs)

    def save(self, path):
        """Save parameters (and optimizer state, if any) via ``save_dygraph``."""
        params = self.model.state_dict()
        fluid.save_dygraph(params, path)
        if self.model._optimizer is None:
            return
        if self.model._optimizer.state_dict():
            optim = self.model._optimizer.state_dict()
            fluid.save_dygraph(optim, path)

    def load(self, path):
        """Restore parameters and, when available, optimizer state."""
        params, optim = fluid.load_dygraph(path)
        self.model.set_dict(params)
        if self.model._optimizer is None or optim is None:
            return
        self.model._optimizer.set_dict(optim)


class Model(fluid.dygraph.Layer):
    """Base class for user models, backed by a mode-specific adapter.

    The constructor picks ``DynamicGraphAdapter`` when dygraph mode is
    active and ``StaticGraphAdapter`` otherwise. ``train``, ``eval``,
    ``test``, ``save``, ``load`` and ``parameters`` all delegate to the
    chosen adapter. Call ``prepare`` before ``train``/``eval``.
    """

    def __init__(self):
        super(Model, self).__init__(self.__class__.__name__)
        self.mode = 'train'
        self._loss_function = None
        self._loss_weights = None
        self._optimizer = None
        # set by `prepare()`; initialized here so the attribute always exists
        self._metrics = []
        if in_dygraph_mode():
            self._adapter = DynamicGraphAdapter(self)
        else:
            self._adapter = StaticGraphAdapter(self)

    def train(self, *args, **kwargs):
        return self._adapter.train(*args, **kwargs)

    def eval(self, *args, **kwargs):
        return self._adapter.eval(*args, **kwargs)

    def test(self, *args, **kwargs):
        return self._adapter.test(*args, **kwargs)

    def save(self, *args, **kwargs):
        return self._adapter.save(*args, **kwargs)

    def load(self, *args, **kwargs):
        return self._adapter.load(*args, **kwargs)

    def prepare(self, optimizer, loss_function, metrics=None):
        """Configure the optimizer, loss function and metrics.

        Args:
            optimizer: optimizer used by ``train``.
            loss_function (Loss): must be an instance of a ``Loss`` subclass.
            metrics: a ``Metric`` or a list/tuple of them. ``None`` (the
                default) means no metrics.

        Raises:
            AssertionError: if ``loss_function`` is not a ``Loss`` or any
                metric is not a ``Metric``.
        """
        self._optimizer = optimizer
        assert isinstance(loss_function, Loss), \
            "'loss_function' must be sub classes of 'Loss'"
        self._loss_function = loss_function
        # a fresh list per call: a mutable default would be shared between
        # every model that omits `metrics`
        metrics = to_list(metrics) if metrics is not None else []
        for metric in metrics:
            assert isinstance(metric, Metric), \
                "{} is not sub class of Metric".format(metric.__class__.__name__)
        self._metrics = metrics

    def parameters(self, *args, **kwargs):
        return self._adapter.parameters(*args, **kwargs)