def replace_param(self,
                  params: dict,
                  start_pos: int,
                  seen: Optional[Set[int]] = None):
    """Swap parameters of this module tree for the ones given in ``params``.

    The attributes of ``self`` are visited in sorted-name order. Every
    :class:`Parameter` seen for the first time takes the next consecutive
    position, counting from ``start_pos``; attributes that are
    :class:`Module` instances are descended into recursively and continue
    the same numbering.

    :param params: mapping from parameter position to its replacement.
        Positions absent from the mapping are left untouched but are still
        counted.
    :param start_pos: position given to the first parameter encountered.
    :param seen: ``id`` values of attribute objects already visited, used
        to skip objects reachable more than once. Created (containing
        ``id(self)``) on the outermost call and shared with recursive calls.
    :return: number of parameter positions consumed by this module and its
        sub-modules.
    :raises AssertionError: if a replacement's shape differs from the shape
        of the parameter it replaces.
    """
    if seen is None:
        seen = {id(self)}
    # Replacements are written straight into the instance dict, so any
    # custom ``__setattr__`` of the module is not involved.
    module_dict = vars(self)
    offset = 0
    for key in sorted(module_dict):
        value = module_dict[key]
        value_id = id(value)
        if value_id in seen:
            # NOTE(review): an object shared by several attributes is only
            # handled at its first occurrence; later references keep the
            # old object -- confirm this is intended for tied parameters.
            continue
        seen.add(value_id)
        if isinstance(value, Parameter):
            pos = start_pos + offset
            if pos in params:
                replacement = params[pos]
                # Explicit raise instead of ``assert`` so the check is not
                # stripped when Python runs with -O.
                if not (value.shape == replacement.shape):
                    raise AssertionError(
                        "shape mismatch for parameter '{}' at position {}: "
                        "{} vs {}".format(
                            key, pos, value.shape, replacement.shape
                        )
                    )
                module_dict[key] = replacement
                value = replacement
            offset += 1
        if isinstance(value, Module):
            offset += value.replace_param(params, start_pos + offset, seen)
    return offset
import collections
from typing import Iterable, Optional

import numpy as np

from ..core import Parameter, Tensor
from .module import Module
from .._internal.opr import param_pack_split


class ParamPack(Module):
    """Wrap ``model`` so its parameters are stored in a few packed buffers.

    At construction the parameters of ``model`` are grouped by
    ``(dtype, device, requires_grad)`` and the values of each group are
    copied into flat, aligned :class:`Parameter` buffers. ``parameters()``
    yields those packed buffers. ``forward`` splits every buffer with
    ``param_pack_split``, installs the pieces into the wrapped model through
    ``Module.replace_param`` and then runs the wrapped model.

    :param model: module whose parameters are packed.
    :param nr_ignore_first: number of leading parameters (in the order of
        ``model.parameters()``) that are kept as they are.
    :param max_size_per_group: size threshold of one pack in MiB. It is
        checked after a member has been added, so a pack may exceed it by
        its last member.
    :param max_nr_params_per_group: maximum number of parameters per pack.
    """

    def __init__(self,
                 model: Module,
                 nr_ignore_first: int = 8,
                 max_size_per_group: int = 10,
                 max_nr_params_per_group: int = 100):
        super().__init__()
        self._model = model
        self._nr_ignore_first = nr_ignore_first
        self._max_size_per_group = max_size_per_group
        self._max_nr_params_per_group = max_nr_params_per_group
        # Parallel lists: _packed_params[k] is the Parameter that stores the
        # members listed in _grouped_params[k]; each member is a dict
        # {'tensor': original Parameter, 'id': position in model.parameters()}.
        self._grouped_params = []
        self._packed_params = []

        self._pack_params(model.parameters())

    def parameters(self, requires_grad: Optional[bool] = None) -> Iterable[Parameter]:
        """Yield the packed parameters.

        :param requires_grad: when not ``None``, only parameters whose
            ``requires_grad`` equals this value are yielded.
        """
        for param in self._packed_params:
            if requires_grad is None or param.requires_grad == requires_grad:
                yield param

    def _pack_params(self, params: Iterable[Parameter]):
        """Fill ``_packed_params`` / ``_grouped_params`` from ``params``."""
        groups = collections.defaultdict(list)
        for param_id, param in enumerate(params):
            entry = {'tensor': param, 'id': param_id}
            if param_id < self._nr_ignore_first:
                # The first few parameters stay unpacked, one per "group".
                self._grouped_params.append([entry])
                self._packed_params.append(param)
            else:
                key = (param.dtype, param.device, param.requires_grad)
                groups[key].append(entry)

        max_bytes = self._max_size_per_group * 1024 * 1024
        for (dtype, device, requires_grad), group in groups.items():
            dtype_sz = np.dtype(dtype).itemsize
            # Presumably ``mem_align`` is the device alignment in bytes; it
            # is converted here to an alignment counted in elements.
            align = device.mem_align
            if align < dtype_sz:
                align = 1
            else:
                assert align % dtype_sz == 0
                align //= dtype_sz

            while group:
                aligned_pos = []
                members = []
                offset = 0
                for entry in group:
                    assert entry['tensor'].device == device
                    # Round ``offset`` up to the next multiple of ``align``
                    # (valid for any positive alignment, not only powers
                    # of two).
                    offset += (-offset) % align
                    aligned_pos.append(offset)
                    members.append(entry)
                    offset += int(np.prod(entry['tensor'].shape))
                    if (offset * dtype_sz >= max_bytes
                            or len(members) >= self._max_nr_params_per_group):
                        break
                group = group[len(members):]

                if len(members) == 1:
                    # A pack with a single member is pointless: keep the
                    # original Parameter itself (not its bookkeeping dict)
                    # so that parameters() keeps yielding Parameter objects.
                    self._packed_params.append(members[0]['tensor'])
                    self._grouped_params.append(members)
                    continue

                packed_value = np.zeros((offset, ), dtype=dtype)
                for entry, pos in zip(members, aligned_pos):
                    val = entry['tensor'].numpy()
                    packed_value[pos:pos + val.size] = val.flatten()
                new_param = Parameter(value=packed_value,
                                      device=device,
                                      dtype=dtype,
                                      requires_grad=requires_grad)
                self._packed_params.append(new_param)
                self._grouped_params.append(members)

    def forward(self, *args, **kwargs):
        """Unpack the buffers into the wrapped model and run its forward."""
        replace_param = dict()
        for packed_param, grouped_params in zip(self._packed_params,
                                                self._grouped_params):
            if len(grouped_params) == 1:
                # Unpacked parameter: the wrapped model already uses it.
                continue
            pieces = param_pack_split(
                packed_param._symvar,
                [entry['tensor'].shape for entry in grouped_params])
            for entry, piece in zip(grouped_params, pieces):
                replace_param[entry['id']] = Parameter(
                    Tensor(piece, requires_grad=packed_param.requires_grad))
        self._model.replace_param(replace_param, 0)

        return self._model.forward(*args, **kwargs)
def calculate_precision(data: np.ndarray, pred: np.ndarray) -> float:
    """Return the fraction of samples whose predicted label is correct.

    The reference label of a sample is ``0`` when the product of its
    entries is negative and ``1`` otherwise. The predicted label is the
    index of the largest entry in the matching row of ``pred``.

    :type data: [[x, y], ...]
    :param data: Input data
    :type pred: [[score_0, score_1], ...]
    :param pred: Network output data, one row per input sample
    """
    assert len(data) == len(pred)
    hits = sum(
        1
        for sample, scores in zip(data, pred)
        if np.argmax(scores) == (0 if np.prod(sample) < 0 else 1)
    )
    return float(hits) / len(data)
@pytest.mark.slow
def test_correctness_parampack():
    """Train a packed and a plain XORNet in lockstep and compare outputs.

    Both networks start from the same weights, see the same mini-batches
    and use identically configured SGD optimizers; afterwards their
    predictions on a fresh batch must agree within ``np.allclose``
    tolerance.
    """
    packed_net = XORNet()
    plain_net = XORNet()
    # Start both networks from identical weights.
    for dst, src in zip(packed_net.parameters(), plain_net.parameters()):
        dst.set_value(src.numpy())
    packed_net = ParamPack(packed_net,
                           nr_ignore_first=0,
                           max_size_per_group=10,
                           max_nr_params_per_group=100)

    sgd_settings = dict(lr=0.01, momentum=0.9, weight_decay=5e-4)
    packed_opt = SGD(packed_net.parameters(requires_grad=True), **sgd_settings)
    plain_opt = SGD(plain_net.parameters(requires_grad=True), **sgd_settings)

    @trace(symbolic=False)
    def train1(data, label):
        pred = packed_net(data)
        packed_opt.zero_grad()
        loss = cross_entropy_with_softmax(pred, label)
        packed_opt.backward(loss)
        return loss

    @trace(symbolic=False)
    def train2(data, label):
        pred = plain_net(data)
        plain_opt.zero_grad()
        loss = cross_entropy_with_softmax(pred, label)
        plain_opt.backward(loss)
        return loss

    @trace(symbolic=False)
    def infer1(data):
        return packed_net(data)

    @trace(symbolic=False)
    def infer2(data):
        return plain_net(data)

    batches = minibatch_generator()

    for data, label in itertools.islice(batches, 2000):
        train1(data, label)
        packed_opt.step()

        train2(data, label)
        plain_opt.step()

    data, _ = next(batches)
    packed_pred = infer1(data).numpy()
    plain_pred = infer2(data).numpy()
    assert np.allclose(packed_pred, plain_pred)