# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import inspect
import os
import pickle
from collections import OrderedDict

import numpy as np

from paddle import fluid
from paddle.fluid.framework import in_dygraph_mode, Variable
from paddle.fluid.executor import global_scope
from paddle.fluid.io import is_belong_to_optimizer
from paddle.fluid.dygraph.base import to_variable

__all__ = ['Model', 'shape_hints']

LOSS_DTYPE_MAP = {
    'cross_entropy': 'int64'
}


def to_list(value):
    if isinstance(value, (list, tuple)):
        return value
    return [value]


def to_numpy(var):
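    """Fetch the value of `var` as a numpy array, covering both dygraph
    `VarBase` (via `.numpy()`) and static-graph `Variable` (by looking up
    its tensor in the global scope).
    """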
    assert isinstance(var, (Variable, fluid.core.VarBase)), "not a variable"
    if isinstance(var, fluid.core.VarBase):
        return var.numpy()
    t = global_scope().find_var(var.name).get_tensor()
    return np.array(t)


def extract_args(func):
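    """Return the argument names of `func`, using `getfullargspec` on
    Python 3 and falling back to `getargspec` on Python 2.
    """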
    if hasattr(inspect, 'getfullargspec'):
        return inspect.getfullargspec(func)[0]
    else:
        return inspect.getargspec(func)[0]


def shape_hints(**hints):
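    """Decorator that attaches static-graph shape hints to a `forward` method.

    Keys must match `forward` argument names; each value is a list or tuple
    in which `None` marks a variable dimension (typically the batch).
    Illustrative sketch (`MyModel` is a hypothetical `Model` subclass):

        class MyModel(Model):
            @shape_hints(image=[None, 3, 224, 224])
            def forward(self, image):
                ...
    """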
    assert hints, "hints cannot be empty"
    assert all(isinstance(h, (list, tuple)) for h in hints.values()), \
        "shape hint must be a list or tuple"

    def wrapper(func):
        args = extract_args(func)
        invalid = set(hints.keys()) - set(args)
        assert not invalid, \
            "shape hints given for arguments not present in the forward " \
            "method: ({})".format(", ".join(invalid))
        func.shape_hints = hints
        return func
    return wrapper


class StaticGraphAdapter(object):
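    """Adapter that runs a `Model` as a static (declarative) graph: it
    builds, compiles, and executes per-mode fluid programs behind the
    train/eval/test interface.
    """
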
    def __init__(self, model):
        super(StaticGraphAdapter, self).__init__()
        self.model = model
        # with `_build_once` gone, parameters are now created in `__init__`
        # so we need to keep track of the parameters already created
        self._startup_prog = fluid.default_startup_program()
        self._orig_prog = fluid.default_main_program()
        self._label_vars = {}  # label variables
        self._endpoints = {}
        self._loss_endpoint = None
        self._executor = None
        self._progs = {}
        self._compiled_progs = {}

        # parse shape hints
        self._input_desc = OrderedDict([
            (n, None) for n in extract_args(self.model.forward) if n != 'self'
        ])
        if hasattr(self.model.forward, 'shape_hints'):
            self._input_desc.update(self.model.forward.shape_hints)

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    def train(self, inputs, labels, device='CPU', device_ids=None):
        assert self.model._optimizer and self.model._loss_functions, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'train'
        return self._run(inputs, labels, device, device_ids)

    def eval(self, inputs, labels, device='CPU', device_ids=None):
        assert self.model._loss_functions, \
            "model not ready, please call `model.prepare()` first"
        self.mode = 'eval'
        return self._run(inputs, labels, device, device_ids)

    def test(self, inputs, device='CPU', device_ids=None):
        self.mode = 'test'
        return self._run(inputs, None, device, device_ids)

    def parameters(self, *args, **kwargs):
        return None

    def save(self, path):
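        # parameters and optimizer state are written as two pickle files,
        # `<path>.pdparams` and `<path>.pdopt` respectively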
        def _save(state, path):
            if not state:
                return
            state = {k: to_numpy(v) if isinstance(v, Variable) else v
                     for k, v in state.items()}
            with open(path, 'wb') as f:
                pickle.dump(state, f)

        base = os.path.basename(path)
        assert base != "", "path should be of 'dirname/filename' format"
        param_path = path + ".pdparams"
        _save(self.model.state_dict(), param_path)
        prog = self._progs.get('train', None)
        if prog is None or self.model._optimizer is None:
            return
        # XXX `optimizer.state_dict()` only works in dygraph mode
        optim_path = path + ".pdopt"
        optim = {p.name: p for p in filter(
            is_belong_to_optimizer, prog.list_vars())}
        # HACK this is contrived, optimizer state is not the same for
        # static/dynamic graph mode
        optim['__static_graph_only__'] = True
        _save(optim, optim_path)

    def load(self, path):
        def _load(path):
            if not os.path.exists(path):
                return
            with open(path, 'rb') as f:
                return pickle.load(f)

        def set_var(var, ndarray):
            t = global_scope().find_var(var.name).get_tensor()
            p = t._place()
            if p.is_cpu_place():
                place = fluid.CPUPlace()
            elif p.is_cuda_pinned_place():
                place = fluid.CUDAPinnedPlace()
            else:
                p = fluid.core.Place()
                p.set_place(t._place())
                place = fluid.CUDAPlace(p.gpu_device_id())

            t.set(ndarray, place)

        param_path = path + ".pdparams"
        params = _load(param_path)
        assert params, "failed to load parameters, please check path"
        for key, var in self.model.state_dict().items():
            assert key in params, \
                "parameter [{}] is not found in model file [{}]".format(
                    key, param_path)
            set_var(var, params[key])

        # FIXME what if a different optimizer is used?
        if not self.model._optimizer:
            return
        prog = self._progs.get('train', None)
        if prog is None:  # train program not built yet, nothing to restore
            return
        optim = list(filter(is_belong_to_optimizer, prog.list_vars()))
        if not optim:
            return

        optim_path = path + ".pdopt"
        optim_state = _load(optim_path)
        if optim_state is None:
            return
        assert '__static_graph_only__' in optim_state, \
            "optimizer saved in dygraph mode is not usable in static graph"

        fluid.core._create_loaded_parameter(
            optim, global_scope(), self._executor._default_executor)

        for var in optim:
            assert var.name in optim_state, \
                "variable [{}] is not found in model file [{}]".format(
                    var.name, optim_path)
            set_var(var, optim_state[var.name])

    def _run(self, inputs, labels=None, device='CPU', device_ids=None):
        inputs = to_list(inputs)
        if labels is not None:
            labels = to_list(labels)
        assert len(inputs) == len(self._input_desc), "number of inputs" \
            + " does not match number of arguments of `forward` method"

        if self._progs.get(self.mode, None) is None:
            self._make_program(self._infer_input_vars(inputs))

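        # compiled programs are cached per mode/device-id combination so
        # repeated calls reuse them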
        ids = sorted(str(i) for i in device_ids) if device_ids else []
        prog_hash = '_'.join([self.mode] + ids)
        compiled_prog = self._compiled_progs.get(prog_hash, None)
        if compiled_prog is None:
            compiled_prog = self._compile_and_initialize(
                self._progs[self.mode], device, device_ids)
            self._compiled_progs[prog_hash] = compiled_prog

        feed = {}
        input_names = list(self._input_desc.keys())
        for idx, n in enumerate(input_names):
            # train and test may take different arguments
            if inputs[idx] is not None:
                feed[n] = inputs[idx]
        if labels is not None:
            for idx, v in enumerate(self._label_vars[self.mode]):
                feed[v.name] = labels[idx]

        endpoints = self._endpoints[self.mode]
        fetch_list = endpoints['output'] + endpoints['loss']
        num_output = len(endpoints['output'])
        out = self._executor.run(
            compiled_prog, feed=feed,
            fetch_list=fetch_list)
        if self.mode == 'test':
            return out[:num_output]
        else:
            return out[:num_output], out[num_output:]

    def _make_program(self, inputs):
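        # clone the original program once per mode; non-train modes get
        # inference clones, 'train'/'eval' additionally attach loss layers,
        # and 'train' also wires in the optimizer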
        prog = self._orig_prog.clone(for_test=self.mode != 'train')
        if self.mode == 'train':
            # HACK workaround learning rate map issue
            lr_var = self.model._optimizer._learning_rate_map[self._orig_prog]
            self.model._optimizer._learning_rate_map[prog] = lr_var
        losses = []
        with fluid.program_guard(prog, self._startup_prog):
            outputs = to_list(self.model.forward(*inputs))
            label_vars = []
            if self.mode != 'test':
                loss_weights = self.model._loss_weights
                if loss_weights is None:
                    loss_weights = [1. for _ in self.model._loss_functions]
                for o, l, w in zip(outputs, self.model._loss_functions,
                                   loss_weights):
                    if l is None:
                        continue
                    label_var = self._infer_label_var(o, l)
                    label_vars.append(label_var)
                    loss_fn = getattr(fluid.layers, l)
                    loss = loss_fn(o, label_var)
                    losses.append(fluid.layers.reduce_mean(loss) * w)
                self._label_vars[self.mode] = label_vars
                if self.mode == 'train':
                    self._loss_endpoint = fluid.layers.sum(losses)
                    self.model._optimizer.minimize(self._loss_endpoint)
        self._progs[self.mode] = prog
        self._endpoints[self.mode] = {
            "output": outputs,
            "loss": losses
        }

    def _infer_input_vars(self, inputs):
        input_vars = []
        for idx, i in enumerate(inputs):
            if i is None:  # train and test may take different arguments
                input_vars.append(None)
                continue
            ndarray = np.array(i)
            name = list(self._input_desc.keys())[idx]
            shape = list(self._input_desc.values())[idx]
            if shape is None:
                shape = (None, ) + ndarray.shape[1:]
            input_vars.append(fluid.data(name, shape, ndarray.dtype))
        return input_vars

    # TODO wrap loss in callable classes
    # - same call signature
    # - infer_shape method? or same shape as y_pred (e.g., one hot)
    # - split multiple dtype loss functions (e.g., soft label)
    def _infer_label_var(self, output, loss):
        name = output.name + '.label'
        shape = output.shape
        # XXX could get ugly very quickly
        if loss == 'cross_entropy':
            shape = shape[:-1] + (1, )
        dtype = LOSS_DTYPE_MAP.get(loss, output.dtype)
        return fluid.data(name, shape, dtype)

    def _compile_and_initialize(self, prog, device='CPU', device_ids=None):
        if device.lower() == 'cpu':
            place = fluid.CPUPlace()
        elif device.lower() == 'gpu' and isinstance(device_ids, (list, tuple)):
            place = fluid.CUDAPlace(device_ids[0])
        else:
            raise "device not supported"

        compiled_prog = fluid.CompiledProgram(prog)
        if device.lower() == 'gpu' and len(device_ids) > 0:
            places = [fluid.CUDAPlace(i) for i in device_ids]
            loss_name = None
            if self._loss_endpoint is not None:
                loss_name = self._loss_endpoint.name
            compiled_prog = compiled_prog.with_data_parallel(
                loss_name=loss_name, places=places)

        if self._executor is None:
            self._executor = fluid.Executor(place)
            # XXX only run startup once as *ALL* weights should be initialized
            # upon construction of the model
            # XXX incremental initialization, lifted from GuoSheng code
            uninitialized = []
            for var_py in self._startup_prog.list_vars():
                var = fluid.global_scope().find_var(var_py.name)
                if var and var.get_tensor()._is_initialized():
                    continue
                uninitialized.append(var_py)
            if uninitialized:
                startup_prog = self._startup_prog._prune(uninitialized)
                self._executor.run(startup_prog)

        return compiled_prog


class DynamicGraphAdapter(object):
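    """Adapter that runs a `Model` eagerly in dygraph (imperative) mode,
    exposing the same train/eval/test interface as `StaticGraphAdapter`.
    """
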
    def __init__(self, model):
        super(DynamicGraphAdapter, self).__init__()
        self.model = model

    @property
    def mode(self):
        return self.model.mode

    @mode.setter
    def mode(self, value):
        self.model.mode = value

    # TODO: multi-device execution is not yet implemented in dygraph mode
    def train(self, inputs, labels, device='CPU', device_ids=None):
        assert self.model._optimizer and self.model._loss_functions, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).train()
        self.mode = 'train'
        inputs = to_list(inputs)
        labels = to_list(labels)
        outputs = self.model.forward(*[to_variable(x) for x in inputs])
        losses = self._loss(outputs, labels)
        final_loss = fluid.layers.sum(losses)
        final_loss.backward()
        self.model._optimizer.minimize(final_loss)
        self.model.clear_gradients()
        return [to_numpy(o) for o in to_list(outputs)], \
            [to_numpy(l) for l in losses]

    def eval(self, inputs, labels, device='CPU', device_ids=None):
        assert self.model._loss_functions, \
            "model not ready, please call `model.prepare()` first"
        super(Model, self.model).eval()
        self.mode = 'eval'
        inputs = to_list(inputs)
        labels = to_list(labels)
        outputs = self.model.forward(*[to_variable(x) for x in inputs])
        losses = self._loss(outputs, labels)
        return [to_numpy(o) for o in to_list(outputs)], \
            [to_numpy(l) for l in losses]

    def test(self, inputs, device='CPU', device_ids=None):
        super(Model, self.model).eval()
        self.mode = 'test'
        inputs = [to_variable(x) for x in to_list(inputs)]
        outputs = self.model.forward(*inputs)
        return [to_numpy(o) for o in to_list(outputs)]

    def parameters(self, *args, **kwargs):
        return super(Model, self.model).parameters(*args, **kwargs)

    def save(self, path):
        params = self.model.state_dict()
        fluid.save_dygraph(params, path)
        if self.model._optimizer is None:
            return
        optim = self.model._optimizer.state_dict()
        if optim:
            fluid.save_dygraph(optim, path)

    def load(self, path):
        params, optim = fluid.load_dygraph(path)
        self.model.set_dict(params)
        if self.model._optimizer is None or optim is None:
            return
        self.model._optimizer.set_dict(optim)

    def _loss(self, pred, labels):
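        # mirrors the static-graph loss construction: one optional loss
        # function per output, each reduced to its mean and weighted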
        losses = []
        loss_weights = self.model._loss_weights
        if loss_weights is None:
            loss_weights = [1. for _ in self.model._loss_functions]
        for o, l, w, t in zip(to_list(pred), self.model._loss_functions,
                              loss_weights, labels):
            if l is None:
                continue
            loss_fn = getattr(fluid.layers, l)
            loss = loss_fn(o, to_variable(t))
            losses.append(fluid.layers.reduce_mean(loss) * w)
        return losses


class Model(fluid.dygraph.Layer):
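    """Base class for high-level models; dispatches train/eval/test/save/load
    to a static-graph or dygraph adapter depending on the execution mode at
    construction time.

    Illustrative sketch (`MyModel`, `x_data`, and `y_data` are hypothetical):

        model = MyModel()  # subclass implementing `forward`
        model.prepare(fluid.optimizer.SGD(learning_rate=0.01),
                      'cross_entropy')
        outputs, losses = model.train(x_data, y_data)
        model.save('checkpoints/model')
    """
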
    def __init__(self):
        super(Model, self).__init__(self.__class__.__name__)
        self.mode = 'train'
        self._loss_functions = []
        self._loss_weights = None
        self._optimizer = None
        if in_dygraph_mode():
            self._adapter = DynamicGraphAdapter(self)
        else:
            self._adapter = StaticGraphAdapter(self)

    def train(self, *args, **kwargs):
        return self._adapter.train(*args, **kwargs)

    def eval(self, *args, **kwargs):
        return self._adapter.eval(*args, **kwargs)

    def test(self, *args, **kwargs):
        return self._adapter.test(*args, **kwargs)

    def save(self, *args, **kwargs):
        return self._adapter.save(*args, **kwargs)

    def load(self, *args, **kwargs):
        return self._adapter.load(*args, **kwargs)

    def prepare(self, optimizer, loss_functions, loss_weights=None):
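        """Configure the model for training and evaluation.

        Args:
            optimizer: fluid optimizer instance, used in 'train' mode.
            loss_functions: name or list of names of loss functions under
                `fluid.layers` (e.g. 'cross_entropy'), one per model output;
                a `None` entry skips the corresponding output.
            loss_weights: optional weight or list of weights, one per loss;
                defaults to 1 for every loss.
        """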
        self._optimizer = optimizer
        self._loss_functions = to_list(loss_functions)
        if loss_weights is not None:
            self._loss_weights = to_list(loss_weights)

    def parameters(self, *args, **kwargs):
        return self._adapter.parameters(*args, **kwargs)