Commit 7acdb671 authored by niuyazhe

polish(nyz): add torch1.1.0 compatibility for nn.Flatten

Parent 171dddc4
import os
import torch
__TITLE__ = 'DI-engine'
__VERSION__ = 'v0.2.0'
@@ -10,3 +11,7 @@ __version__ = __VERSION__
enable_hpc_rl = False
enable_linklink = os.environ.get('ENABLE_LINKLINK', 'false').lower() == 'true'
enable_numba = True
def torch_gt_131():
    # True when the installed torch is at least version 1.3.1, compared by
    # concatenating the digits of ``torch.__version__`` (e.g. '1.3.1' -> 131).
    return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 131
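For reference, a minimal sketch of how this digit-concatenation check behaves on a few illustrative version strings (the strings are assumptions, not values from the repository):

for v in ['1.1.0', '1.3.1', '1.10.0']:
    print(v, int("".join(filter(str.isdigit, v))) >= 131)
# '1.1.0'  -> 110  -> False (falls back to the NaiveFlatten module added below)
# '1.3.1'  -> 131  -> True  (uses the built-in nn.Flatten)
# '1.10.0' -> 1100 -> True  (multi-digit components still compare as >= 131 here)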
@@ -2,7 +2,7 @@ from typing import Optional
import torch
import torch.nn as nn
-from ding.torch_utils import ResFCBlock, ResBlock
+from ding.torch_utils import ResFCBlock, ResBlock, Flatten
from ding.utils import SequenceType
@@ -49,7 +49,7 @@ class ConvEncoder(nn.Module):
        assert len(set(hidden_size_list[3:-1])) <= 1, "Please indicate the same hidden size for res block parts"
        for i in range(3, len(self.hidden_size_list) - 1):
            layers.append(ResBlock(self.hidden_size_list[i], activation=self.act, norm_type=norm_type))
-        layers.append(nn.Flatten())
+        layers.append(Flatten())
        self.main = nn.Sequential(*layers)
        flatten_size = self._get_flatten_size()
......
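In ConvEncoder, the swapped-in Flatten sits between the conv/res-block stack and the fully connected head; a minimal shape sketch (the concrete sizes here are assumed for illustration):

encoder_out = torch.randn(4, 64, 6, 6)   # assumed output of the conv/res-block stack: (B, C, H, W)
flat = Flatten()(encoder_out)            # (4, 64 * 6 * 6) == (4, 2304), consumed by the FC layers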
from .activation import build_activation, Swish
from .res_block import ResBlock, ResFCBlock
from .nn_module import fc_block, conv2d_block, one_hot, deconv2d_block, BilinearUpsample, NearestUpsample, \
-    binary_encode, NoiseLinearLayer, noise_block, MLP
+    binary_encode, NoiseLinearLayer, noise_block, MLP, Flatten
from .normalization import build_normalization
from .rnn import get_lstm, sequence_mask
from .soft_argmax import SoftArgmax
......
@@ -4,6 +4,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_normal_, kaiming_normal_, orthogonal_
from typing import Union, Tuple, List, Callable
from ding import torch_gt_131
from .normalization import build_normalization
@@ -577,3 +578,23 @@ def noise_block(
    if use_dropout:
        block.append(nn.Dropout(dropout_probability))
    return sequential_pack(block)
class NaiveFlatten(nn.Module):
    """Fallback for ``nn.Flatten``, which is unavailable in older torch versions such as 1.1.0."""

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super(NaiveFlatten, self).__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Collapse dimensions [start_dim, end_dim] into a single one, keeping any
        # trailing dimensions when end_dim is not the last.
        if self.end_dim != -1:
            return x.view(*x.shape[:self.start_dim], -1, *x.shape[self.end_dim + 1:])
        else:
            return x.view(*x.shape[:self.start_dim], -1)


if torch_gt_131():
    Flatten = nn.Flatten
else:
    Flatten = NaiveFlatten
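Whichever implementation the alias resolves to, it is used identically; a minimal usage sketch (shapes assumed for illustration):

x = torch.randn(4, 3, 8, 8)
y1 = Flatten()(x)       # default start_dim=1: (4, 3 * 8 * 8) == (4, 192)
y2 = Flatten(1, 2)(x)   # flatten dims 1..2 only: (4, 3 * 8, 8) == (4, 24, 8)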
@@ -6,6 +6,7 @@ import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from .nn_module import Flatten
def to_2tuple(item):
@@ -94,7 +95,7 @@ class ClassifierHead(nn.Module):
        self.drop_rate = drop_rate
        self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
-        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
+        self.flatten = Flatten(1) if use_conv and pool_type else nn.Identity()
    def forward(self, x):
        x = self.global_pool(x)
......
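Why Flatten(1) in this head: with a convolutional fc, the logits keep trailing 1x1 spatial dimensions that the flatten strips; a sketch under that assumption:

logits = torch.randn(4, 10, 1, 1)            # assumed conv-fc output: (B, num_classes, 1, 1)
assert Flatten(1)(logits).shape == (4, 10)   # Flatten(1) == Flatten(start_dim=1, end_dim=-1)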
@@ -2,7 +2,7 @@ import torch
import pytest
from ding.torch_utils import build_activation, build_normalization
from ding.torch_utils.network.nn_module import conv1d_block, conv2d_block, fc_block, deconv2d_block, ChannelShuffle, \
-    one_hot, NearestUpsample, BilinearUpsample, binary_encode, weight_init_
+    one_hot, NearestUpsample, BilinearUpsample, binary_encode, weight_init_, NaiveFlatten
batch_size = 2
in_channels = 2
@@ -148,3 +148,16 @@ class TestNnModule:
        max_val = torch.tensor(8)
        output = binary_encode(input, max_val)
        assert torch.equal(output, torch.tensor([[0, 1, 0, 0]]))
    @pytest.mark.tmp
    def test_flatten(self):
        inputs = torch.randn(4, 3, 8, 8)
        model1 = NaiveFlatten()
        output1 = model1(inputs)
        assert output1.shape == (4, 3 * 8 * 8)
        model2 = NaiveFlatten(1, 2)
        output2 = model2(inputs)
        assert output2.shape == (4, 3 * 8, 8)
        model3 = NaiveFlatten(1, 3)
        output3 = model3(inputs)
        assert output3.shape == (4, 3 * 8 * 8)
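On torch >= 1.3.1 the same assertions hold for the built-in module that the Flatten alias resolves to, e.g.:

import torch.nn as nn  # for a direct comparison; the test itself exercises NaiveFlatten
inputs = torch.randn(4, 3, 8, 8)
assert nn.Flatten()(inputs).shape == (4, 3 * 8 * 8)
assert nn.Flatten(1, 2)(inputs).shape == (4, 3 * 8, 8)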
@@ -5,6 +5,7 @@ import torch
import re
from torch._six import string_classes
import collections.abc as container_abcs
from ding import torch_gt_131
int_classes = int
np_str_obj_array_pattern = re.compile(r'[SaUO]')
@@ -15,10 +16,6 @@ default_collate_err_msg_format = (
)
-def torch_gt_131():
-    return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 131
def default_collate(batch: Sequence,
                    cat_1dim: bool = True,
                    ignore_prefix: list = ['collate_ignore']) -> Union[torch.Tensor, Mapping, Sequence]:
......