model.py 18.6 KB
Newer Older
Y
Yang Zhang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Y
Yang Zhang 已提交
15 16 17
from __future__ import absolute_import

import inspect
Y
Yang Zhang 已提交
18 19
import os
import pickle
Y
Yang Zhang 已提交
20 21 22 23 24
from collections import OrderedDict

import numpy as np

from paddle import fluid
Y
Yang Zhang 已提交
25 26 27
from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
Y
Yang Zhang 已提交
28
from paddle.fluid.dygraph.base import to_variable
D
dengkaipeng 已提交
29
from metrics.metric import Metric
Y
Yang Zhang 已提交
30

Y
Yang Zhang 已提交
31
__all__ = ['shape_hints', 'Model', 'Loss', 'CrossEntropy']
Y
Yang Zhang 已提交
32 33 34 35 36 37 38 39


def to_list(value):
    """Wrap `value` in a list unless it is already a list or tuple.

    Lists and tuples are returned as-is (same object, not copied); any other
    value, including `None`, becomes a one-element list.
    """
    return value if isinstance(value, (list, tuple)) else [value]


40 41 42 43 44 45 46 47
def to_numpy(var):
    """Return the value held by a static `Variable` or dygraph `VarBase`.

    Dygraph variables are converted with `VarBase.numpy()`; static-graph
    variables are read from the tensor of the same name in the global scope.
    """
    assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable"
    if not isinstance(var, fluid.core.VarBase):
        # static graph: the value lives in the global scope, looked up by name
        tensor = global_scope().find_var(var.name).get_tensor()
        return np.array(tensor)
    return var.numpy()


Y
Yang Zhang 已提交
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
def extract_args(func):
    """Return the list of positional parameter names of `func`.

    Uses `inspect.getfullargspec` when available and falls back to the
    legacy `inspect.getargspec` otherwise (Python 2).
    """
    spec_fn = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    return spec_fn(func)[0]


def shape_hints(**hints):
    """Decorator factory attaching input shape hints to a `forward` method.

    Each keyword names an argument of the decorated function and maps it to
    a list/tuple shape. The hints are stored on the function as the
    `shape_hints` attribute and the function itself is returned unchanged.
    """
    assert hints, "hints can not be empty"
    assert all(isinstance(h, (list, tuple)) for h in hints.values()), \
        "shape hint must be a list or tuple"

    def decorate(func):
        # every hinted name must be a real parameter of `func`
        unknown = set(hints) - set(extract_args(func))
        assert not unknown, \
            "shape hint for arguments that are not present in forward method" \
            + ": ({})".format(", ".join(unknown))
        func.shape_hints = hints
        return func

    return decorate


Y
Yang Zhang 已提交
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
class Loss(object):
    """Base class for loss functions passed to `Model.prepare`.

    Subclasses implement `forward(outputs, labels)` returning one loss or a
    list of losses; `infer_shape`/`infer_dtype` describe the label variables
    a static graph needs to create for the given outputs.
    """

    def __init__(self, average=True):
        super(Loss, self).__init__()
        # when True, `__call__` reduces every loss with `reduce_mean`
        self.average = average

    def infer_shape(self, outputs):
        """Label shapes; by default identical to each output's shape."""
        return [out.shape for out in outputs]

    def infer_dtype(self, outputs):
        """Label dtypes; by default identical to each output's dtype."""
        return [out.dtype for out in outputs]

    def forward(self, outputs, labels):
        raise NotImplementedError()

    def __call__(self, outputs, labels):
        """Compute the losses as a list, averaged when `self.average`."""
        label_list = to_list(labels)
        if in_dygraph_mode():
            # dygraph labels arrive as numpy data and must become variables
            label_list = list(map(to_variable, label_list))
        raw_losses = to_list(self.forward(to_list(outputs), label_list))
        if self.average:
            return [fluid.layers.reduce_mean(item) for item in raw_losses]
        return raw_losses


class CrossEntropy(Loss):
    """Cross-entropy between each output and its integer class label.

    Labels are int64 with the output's shape minus its last dimension plus
    a trailing dimension of 1 (see `infer_shape`/`infer_dtype`).

    Args:
        average (bool): forwarded to `Loss`; when True (the default) each
            loss is reduced with `reduce_mean`.
    """

    def __init__(self, average=True):
        super(CrossEntropy, self).__init__(average=average)

    def infer_shape(self, outputs):
        # normalize to a tuple in case `shape` is a list, so that the
        # trailing label dimension can be appended with `+ (1, )`
        return [tuple(o.shape[:-1]) + (1, ) for o in outputs]

    def infer_dtype(self, outputs):
        return ['int64' for _ in outputs]

    def forward(self, outputs, labels):
        return [fluid.layers.cross_entropy(o, l) for o, l in zip(
            outputs, labels)]


Y
Yang Zhang 已提交
110 111 112 113 114 115 116
class StaticGraphAdapter(object):
    """Execute a `Model` in static-graph mode.

    For each mode ('train', 'eval', 'test') a program is built lazily the
    first time that mode is run. The program is built by cloning the default
    main program and calling `model.forward` under `program_guard`. The
    program and its compiled form are then cached per mode. All modes share
    the default startup program and the global scope.
    """

    def __init__(self, model):
        super(StaticGraphAdapter, self).__init__()
        self.model = model
        # with `_build_once` gone, parameters are now created in `__init__`
        # so we need to keep track of the parameters already created
        self._startup_prog = fluid.default_startup_program()
        self._orig_prog = fluid.default_main_program()

        self._label_vars = {}  # label variables, keyed by mode
        self._endpoints = {}  # output/label/loss variables, keyed by mode
        self._loss_endpoint = None  # summed loss of the train program
        self._executor = None
        self._progs = {}
        self._compiled_progs = {}

        # optimizer state read by `load()` before the train program existed;
        # applied when the train program is compiled
        self._lazy_load_optimizer = None

        # parse shape hints: one entry per `forward` argument (except `self`),
        # `None` meaning "infer the shape from the first batch"
        self._input_desc = OrderedDict([
            (n, None) for n in extract_args(self.model.forward) if n != 'self'
        ])
        if hasattr(self.model.forward, 'shape_hints'):
            self._input_desc.update(self.model.forward.shape_hints)

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    def train(self, inputs, labels, device='CPU', device_ids=None):
        """Run one training step and return `(outputs, losses)`."""
        assert self.model._optimizer and self.model._loss_function, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'train'
        return self._run(inputs, labels, device, device_ids)

    def eval(self, inputs, labels, device='CPU', device_ids=None):
        """Run one evaluation step, update metrics, return `(outputs, losses)`."""
        assert self.model._loss_function, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'eval'
        return self._run(inputs, labels, device, device_ids)

    def test(self, inputs, device='CPU', device_ids=None):
        """Run inference and return the outputs."""
        self.mode = 'test'
        return self._run(inputs, None, device, device_ids)

    def parameters(self, *args, **kwargs):
        # not exposed by the static-graph adapter
        return None

    def save(self, path):
        """Pickle parameters to `path.pdparams` and, once a train program
        exists, optimizer variables to `path.pdopt`."""
        def _save(state, path):
            if not state:
                return
            state = {k: to_numpy(v) if isinstance(v, Variable) else v
                     for k, v in state.items()}
            with open(path, 'wb') as f:
                pickle.dump(state, f)

        base = os.path.basename(path)
        assert base != "", "path should be of 'dirname/filename' format"
        param_path = path + ".pdparams"
        _save(self.model.state_dict(), param_path)
        prog = self._progs.get('train', None)
        if prog is None or self.model._optimizer is None:
            return
        # XXX `optimizer.state_dict()` only work in dygraph mode
        optim_path = path + ".pdopt"
        optim = {p.name: p for p in filter(
            is_belong_to_optimizer, prog.list_vars())}
        if not optim:
            return
        # HACK this is contrived, optimizer state is not the same for
        # static/dynamic graph mode
        optim['__static_graph_only__'] = True
        _save(optim, optim_path)

    def load(self, path):
        """Restore parameters (required) and optimizer state (optional).

        Optimizer state is applied immediately if the train program has
        already been built; otherwise it is kept until the train program is
        compiled.
        """
        def _load(path):
            if not os.path.exists(path):
                return
            # NOTE: unpickling can execute arbitrary code; only load
            # checkpoint files from trusted sources
            with open(path, 'rb') as f:
                return pickle.load(f)

        param_path = path + ".pdparams"
        param_state = _load(param_path)
        assert param_state, "failed to load parameters, please check path"

        if self._executor is None:
            executor = fluid.Executor(fluid.CPUPlace())._default_executor
        else:
            executor = self._executor._default_executor

        fluid.core._create_loaded_parameter(
            list(self.model.state_dict().values()), global_scope(), executor)

        for key, var in self.model.state_dict().items():
            assert key in param_state, \
                "parameter [{}] is not found in model file [{}]".format(
                    key, param_path)
            self._set_var(var, param_state[key])

        # FIXME what if a different optimizer is used?
        if not self.model._optimizer:
            return
        optim_path = path + ".pdopt"
        optim_state = _load(optim_path)
        if optim_state is None:
            return
        assert '__static_graph_only__' in optim_state, \
            "optimizer saved in dygraph mode is not usable in static graph"

        # optimizer variables only exist once the train program is built
        # (e.g. an executor may already exist from running eval/test first)
        if self._executor is not None and 'train' in self._progs:
            self._load_optimizer(optim_state)
        else:
            self._lazy_load_optimizer = optim_state

    def _load_optimizer(self, state):
        """Copy saved values into the train program's optimizer variables."""
        prog = self._progs.get('train', None)
        optim = list(filter(is_belong_to_optimizer, prog.list_vars()))
        if not optim:
            return

        fluid.core._create_loaded_parameter(
            optim, global_scope(), self._executor._default_executor)

        for var in optim:
            assert var.name in state, \
                "variable [{}] is not in optimizer state file".format(var.name)
            self._set_var(var, state[var.name])

    def _set_var(self, var, ndarray):
        """Write `ndarray` into `var`'s tensor in the global scope, keeping
        the place (CPU / pinned / GPU) the tensor currently lives on."""
        t = global_scope().find_var(var.name).get_tensor()
        p = t._place()
        if p.is_cpu_place():
            place = fluid.CPUPlace()
        elif p.is_cuda_pinned_place():
            place = fluid.CUDAPinnedPlace()
        else:
            p = fluid.core.Place()
            p.set_place(t._place())
            place = fluid.CUDAPlace(p.gpu_device_id())

        t.set(ndarray, place)

    def _run(self, inputs, labels=None, device='CPU', device_ids=None):
        """Build/compile the program for the current mode if needed, feed
        `inputs`/`labels`, and return numpy results.

        LoD results are returned as `(ndarray, recursive_sequence_lengths)`.
        Returns `outputs` in test mode and `(outputs, losses)` otherwise.
        """
        inputs = to_list(inputs)
        if labels is not None:
            labels = to_list(labels)
        assert len(inputs) == len(self._input_desc), "number of inputs" \
            + " does not match number of arguments of `forward` method"

        if self._progs.get(self.mode, None) is None:
            self._make_program(self._infer_input_vars(inputs))

        compiled_prog = self._compile_and_initialize(
            self._progs[self.mode], device, device_ids)

        feed = {}
        for name, value in zip(self._input_desc.keys(), inputs):
            # train and test may take different arguments
            if value is not None:
                feed[name] = value
        if labels is not None:
            for idx, v in enumerate(self._label_vars[self.mode]):
                feed[v.name] = labels[idx]

        endpoints = self._endpoints[self.mode]
        fetch_list = endpoints['output'] + endpoints['label'] + endpoints['loss']
        num_output = len(endpoints['output'])
        num_label = len(endpoints['label'])
        rets = self._executor.run(
            compiled_prog, feed=feed,
            fetch_list=fetch_list,
            return_numpy=False)
        np_rets = []
        for ret in rets:
            seq_len = ret.recursive_sequence_lengths()
            if len(seq_len) == 0:
                np_rets.append(np.array(ret))
            else:
                np_rets.append((np.array(ret), seq_len))
        outputs = np_rets[:num_output]
        labels = np_rets[num_output:num_output + num_label]
        losses = np_rets[num_output + num_label:]
        if self.mode == 'test':
            return outputs
        elif self.mode == 'eval':
            for metric in self.model._metrics:
                metric.update(outputs, labels)
            return outputs, losses
        else:  # train
            return outputs, losses

    def _make_program(self, inputs):
        """Build and cache the program and fetch endpoints for this mode.

        By convention `forward`'s first return value feeds the loss and the
        remaining ones are the fetched outputs.
        """
        prog = self._orig_prog.clone()
        if self.mode == 'train' and self.model._optimizer._learning_rate_map:
            # HACK workaround learning rate map issue
            lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
            self.model._optimizer._learning_rate_map[prog] = lr_var
        losses = []
        label_vars = []  # test mode has no labels
        with fluid.program_guard(prog, self._startup_prog):
            # lists (not tuples) so the endpoints can be concatenated in `_run`
            outputs = list(to_list(self.model.forward(*inputs)))
            if self.mode != 'test':
                label_vars = self._infer_label_vars(outputs)
                self._label_vars[self.mode] = label_vars
                losses = list(to_list(
                    self.model._loss_function(outputs[0], label_vars)))
                if self.mode == 'train':
                    self._loss_endpoint = fluid.layers.sum(losses)
                    self.model._optimizer.minimize(self._loss_endpoint)
        if self.mode != 'train':  # clone again to put it in test mode
            prog = prog.clone(for_test=True)
        self._progs[self.mode] = prog
        self._endpoints[self.mode] = {
            "output": outputs[1:],
            "label": label_vars,
            "loss": losses,
        }

    def _infer_input_vars(self, inputs):
        """Create one `fluid.data` per non-None input, using the shape hint
        or else the batch's shape with a variable leading dimension."""
        input_vars = []
        for (name, shape), value in zip(self._input_desc.items(), inputs):
            if value is None:  # train and test may take different arguments
                input_vars.append(None)
                continue
            ndarray = np.array(value)
            if shape is None:
                shape = (None, ) + ndarray.shape[1:]
            input_vars.append(fluid.data(name, shape, ndarray.dtype))
        return input_vars

    def _infer_label_vars(self, outputs):
        """Create label `fluid.data` variables described by the loss."""
        shapes = self.model._loss_function.infer_shape(outputs)
        dtypes = self.model._loss_function.infer_dtype(outputs)
        label_vars = []
        for idx, (shape, dtype) in enumerate(zip(shapes, dtypes)):
            name = '__label{}'.format(idx)
            label_vars.append(fluid.data(name, shape, dtype))
        return label_vars

    def _initialize_new_vars(self):
        """Run the startup ops of startup-program variables that are not yet
        initialized in the global scope (incremental initialization)."""
        uninitialized = []
        for var_py in self._startup_prog.list_vars():
            var = fluid.global_scope().find_var(var_py.name)
            if var and var.get_tensor()._is_initialized():
                continue
            uninitialized.append(var_py)
        if uninitialized:
            startup_prog = self._startup_prog._prune(uninitialized)
            self._executor.run(startup_prog)

    def _compile_and_initialize(self, prog, device='CPU', device_ids=None):
        """Return the (cached) compiled program for the current mode.

        On first use for a mode:
        - create the executor if needed;
        - initialize any still-uninitialized startup variables;
        - in train mode, apply deferred optimizer state;
        - wrap `prog` in a `CompiledProgram`, data parallel when more than
          one device id is given.
        """
        compiled_prog = self._compiled_progs.get(self.mode, None)
        if compiled_prog is not None:
            return compiled_prog

        # default to a single device (id 0); `None` cannot be iterated
        if device_ids is None:
            device_ids = [0]
        use_gpu = device.lower() == 'gpu'
        places = [fluid.CUDAPlace(i) if use_gpu else fluid.CPUPlace()
                  for i in device_ids]

        if self._executor is None:
            self._executor = fluid.Executor(places[0])

        # building a program can add startup ops (e.g. optimizer accumulators
        # created by `minimize()` in train mode), so initialize whatever is
        # still uninitialized each time a new program is compiled, not only
        # when the executor is first created
        self._initialize_new_vars()

        if self.mode == 'train' and self._lazy_load_optimizer:
            self._load_optimizer(self._lazy_load_optimizer)
            self._lazy_load_optimizer = None

        compiled_prog = fluid.CompiledProgram(prog)
        if len(device_ids) > 1:
            loss_name = None
            if self.mode == 'train' and self._loss_endpoint is not None:
                loss_name = self._loss_endpoint.name

            share_vars_from = None
            if self.mode == 'eval' and 'train' in self._compiled_progs:
                share_vars_from = self._compiled_progs['train']
            # HACK invalidate eval program if is compiled before train program
            # quite hackish, OTOH, it is generally uncommon that the eval
            # program will be run before the train program
            if self.mode == 'train' and 'eval' in self._compiled_progs:
                del self._compiled_progs['eval']

            compiled_prog = compiled_prog.with_data_parallel(
                loss_name=loss_name, places=places,
                share_vars_from=share_vars_from)

        self._compiled_progs[self.mode] = compiled_prog
        return compiled_prog


class DynamicGraphAdapter(object):
    """Execute a `Model` eagerly in dygraph (imperative) mode.

    By convention `forward`'s first return value feeds the loss. In test
    mode, the remaining return values are the outputs, matching the
    static-graph adapter.
    """

    def __init__(self, model):
        super(DynamicGraphAdapter, self).__init__()
        self.model = model

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    # TODO multi device in dygraph mode not implemented at present time
    def train(self, inputs, labels, device='CPU', device_ids=None):
        """Run one optimization step; return `(outputs, losses)` as numpy."""
        assert self.model._optimizer and self.model._loss_function, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).train()
        self.mode = 'train'
        inputs = to_list(inputs)
        labels = to_list(labels)
        # wrap a single returned variable like the static adapter does,
        # instead of indexing into it
        outputs = to_list(
            self.model.forward(*[to_variable(x) for x in inputs]))[0]
        losses = self.model._loss_function(outputs, labels)
        final_loss = fluid.layers.sum(losses)
        final_loss.backward()
        self.model._optimizer.minimize(final_loss)
        self.model.clear_gradients()
        return [to_numpy(o) for o in to_list(outputs)], \
            [to_numpy(l) for l in losses]

    def eval(self, inputs, labels, device='CPU', device_ids=None):
        """Compute losses without an update; return `(outputs, losses)`."""
        assert self.model._loss_function, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).eval()
        self.mode = 'eval'
        inputs = to_list(inputs)
        labels = to_list(labels)
        outputs = to_list(
            self.model.forward(*[to_variable(x) for x in inputs]))
        losses = self.model._loss_function(outputs[0], labels)
        return [to_numpy(o) for o in to_list(outputs[0])], \
            [to_numpy(l) for l in losses]

    def test(self, inputs, device='CPU', device_ids=None):
        """Run inference and return every output after the first as numpy."""
        super(Model, self.model).eval()
        self.mode = 'test'
        inputs = [to_variable(x) for x in to_list(inputs)]
        # skip only the loss input (index 0); slicing twice dropped an
        # extra output
        outputs = to_list(self.model.forward(*inputs))[1:]
        return [to_numpy(o) for o in outputs]

    def parameters(self, *args, **kwargs):
        return super(Model, self.model).parameters(*args, **kwargs)

    def save(self, path):
        """Save parameters and, if non-empty, optimizer state with
        `fluid.save_dygraph` under `path`."""
        params = self.model.state_dict()
        fluid.save_dygraph(params, path)
        if self.model._optimizer is None:
            return
        optim = self.model._optimizer.state_dict()
        if optim:
            fluid.save_dygraph(optim, path)

    def load(self, path):
        """Restore parameters and, when available, optimizer state."""
        params, optim = fluid.load_dygraph(path)
        self.model.set_dict(params)
        if self.model._optimizer is None or optim is None:
            return
        self.model._optimizer.set_dict(optim)


class Model(fluid.dygraph.Layer):
    """High-level model base class.

    Subclasses define `forward`. Training, evaluation, inference and
    checkpointing are delegated to a dygraph or static-graph adapter,
    chosen by `in_dygraph_mode()` at construction time.
    """

    def __init__(self):
        super(Model, self).__init__(self.__class__.__name__)
        self.mode = 'train'
        self._loss_function = None
        self._loss_weights = None
        self._optimizer = None
        self._metrics = []  # set by `prepare()`; read by the eval step
        if in_dygraph_mode():
            self._adapter = DynamicGraphAdapter(self)
        else:
            self._adapter = StaticGraphAdapter(self)

    def train(self, *args, **kwargs):
        return self._adapter.train(*args, **kwargs)

    def eval(self, *args, **kwargs):
        return self._adapter.eval(*args, **kwargs)

    def test(self, *args, **kwargs):
        return self._adapter.test(*args, **kwargs)

    def save(self, *args, **kwargs):
        return self._adapter.save(*args, **kwargs)

    def load(self, *args, **kwargs):
        return self._adapter.load(*args, **kwargs)

    def prepare(self, optimizer, loss_function, metrics=None):
        """Configure the optimizer, loss and metrics.

        Args:
            optimizer: optimizer used by `train`.
            loss_function (Loss): loss instance.
            metrics: a `Metric`, a list/tuple of them, or None for none.
        """
        self._optimizer = optimizer
        assert isinstance(loss_function, Loss), \
            "'loss_function' must be sub classes of 'Loss'"
        self._loss_function = loss_function
        # fresh list per call: never share a mutable default between models
        metrics = [] if metrics is None else list(to_list(metrics))
        for metric in metrics:
            assert isinstance(metric, Metric), \
                "{} is not sub class of Metric".format(metric.__class__.__name__)
        self._metrics = metrics

    def parameters(self, *args, **kwargs):
        return self._adapter.parameters(*args, **kwargs)