What's Changed

In the new version (v1.4.12), support for torch >= 2.0 has been added, including torch.compile for faster inference and backpropagation. Here's an example:

from typing import Tuple, Mapping

import torch
from torch import nn

from treevalue import FastTreeValue


# A simple MLP
class MLP(nn.Module):
    def __init__(self, in_features: int, out_features: int, layers: Tuple[int, ...] = (1024,)):
        nn.Module.__init__(self)
        self.in_features = in_features
        self.out_features = out_features
        self.layers = layers
        ios = [self.in_features, *self.layers, self.out_features]
        self.mlp = nn.Sequential(
            *(
                nn.Linear(in_, out_, bias=True)
                for in_, out_ in zip(ios[:-1], ios[1:])
            )
        )

    def forward(self, x):
        return self.mlp(x)


# Multiple headed MLP
class MultiHeadMLP(nn.Module):
    def __init__(self, in_features: int, out_features: Mapping[str, int], layers: Tuple[int, ...] = (1024,)):
        nn.Module.__init__(self)
        self.in_features = in_features
        self.out_features = out_features
        self.layers = layers
        _networks = {
            o_name: MLP(in_features, o_feat, layers)
            for o_name, o_feat in self.out_features.items()
        }
        self.mlps = nn.ModuleDict(_networks)  # use nn.ModuleDict to register child MLPs
        self._t_mlps = FastTreeValue(_networks)  # use FastTreeValue for batched inference across all heads

    def forward(self, x):
        return self._t_mlps(x)  # call every head with the same input and return a tree of outputs


if __name__ == '__main__':
    net = MultiHeadMLP(
        20,
        {'a': 10, 'b': 20, 'c': 14, 'd': 3},
    )
    net = torch.compile(net)  # requires torch >= 2.0
    print(net)

    input_ = torch.randn(1, 10, 20)  # last dimension matches in_features=20
    output = net(input_)
    print(output.shape)

Printing the compiled MultiHeadMLP above shows the following network structure:

OptimizedModule(
  (_orig_mod): MultiHeadMLP(
    (mlps): ModuleDict(
      (a): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=20, out_features=1024, bias=True)
          (1): Linear(in_features=1024, out_features=10, bias=True)
        )
      )
      (b): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=20, out_features=1024, bias=True)
          (1): Linear(in_features=1024, out_features=20, bias=True)
        )
      )
      (c): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=20, out_features=1024, bias=True)
          (1): Linear(in_features=1024, out_features=14, bias=True)
        )
      )
      (d): MLP(
        (mlp): Sequential(
          (0): Linear(in_features=20, out_features=1024, bias=True)
          (1): Linear(in_features=1024, out_features=3, bias=True)
        )
      )
    )
  )
)
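
As the structure above shows, torch.compile wraps the original module in an OptimizedModule and keeps it reachable through the _orig_mod attribute. A minimal check, not part of the release notes:

assert net._orig_mod.mlps['a'].mlp[1].out_features == 10  # the uncompiled MultiHeadMLP stays accessible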

And the inference output for a float32[1, 10, 20] input is a tree with the following per-head shapes:

<FastTreeValue 0x7fe9b197e6a0>
├── 'a' --> torch.Size([1, 10, 10])
├── 'b' --> torch.Size([1, 10, 20])
├── 'c' --> torch.Size([1, 10, 14])
└── 'd' --> torch.Size([1, 10, 3])
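
Backpropagation works through the compiled model as well. Continuing the example above, here is a minimal sketch of a backward pass; the per-head targets and the squared-error loss below are illustrative assumptions, not part of the release:

# Illustrative targets with the same tree structure and shapes as the output (assumption for this sketch)
targets = FastTreeValue({
    'a': torch.randn(1, 10, 10),
    'b': torch.randn(1, 10, 20),
    'c': torch.randn(1, 10, 14),
    'd': torch.randn(1, 10, 3),
})
sq_err = (output - targets) ** 2  # element-wise operators are applied leaf-by-leaf across both trees
loss = sq_err.a.mean() + sq_err.b.mean() + sq_err.c.mean() + sq_err.d.mean()
loss.backward()  # gradients flow back into the parameters of every head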

Full Changelog: https://github.com/opendilab/treevalue/compare/v1.4.11...v1.4.12

Project Introduction

Here are the most awesome tree structure computing solutions to make your life easier (currently the best-performing tree structure computing solution).

Source repository: https://github.com/opendilab/treevalue
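
As a quick illustration of the tree structure computing this project provides, here is a minimal, self-contained sketch (the tree contents below are made up for demonstration):

from treevalue import FastTreeValue

# Build a small nested tree of plain values
t = FastTreeValue({'a': 1, 'b': {'x': 2, 'y': 3}})

# Ordinary operators are applied to every leaf, preserving the tree structure
print(t + 10)  # leaves become 11, 12 and 13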
