Commit fbfe9d89 authored by BBuf

support inceptionv3

Parent 7d06900d
@@ -116,20 +116,20 @@ class Inception3(nn.Module):
         self.Mixed_7a = inception_d(768)
         self.Mixed_7b = inception_e(1280)
         self.Mixed_7c = inception_e(2048)
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.avgpool = nn.AvgPool2d((8, 8))
         self.dropout = nn.Dropout()
         self.fc = nn.Linear(2048, num_classes)
-        if init_weights:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
-                    stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1  # type: ignore
-                    flow.nn.init.trunc_normal_(
-                        m.weight, mean=0.0, std=stddev, a=-2, b=2
-                    )
-                elif isinstance(m, nn.BatchNorm2d):
-                    nn.init.constant_(m.weight, 1)
-                    nn.init.constant_(m.bias, 0)
+        # if init_weights:
+        #     for m in self.modules():
+        #         if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+        #             stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1  # type: ignore
+        #             flow.nn.init.trunc_normal_(
+        #                 m.weight, mean=0.0, std=stddev, a=-2, b=2
+        #             )
+        #         elif isinstance(m, nn.BatchNorm2d):
+        #             nn.init.constant_(m.weight, 1)
+        #             nn.init.constant_(m.bias, 0)

     def _transform_input(self, x: Tensor) -> Tensor:
         if self.transform_input:
@@ -191,7 +191,7 @@ class Inception3(nn.Module):
         # N x 2048
         x = self.fc(x)
         # N x 1000 (num_classes)
-        return x, aux
+        return x

 class InceptionA(nn.Module):
@@ -418,7 +418,7 @@ class InceptionAux(nn.Module):
         self.fc = nn.Linear(768, num_classes)
         self.fc.stddev = 0.001  # type: ignore[assignment]
         self.avg_pool = nn.AvgPool2d(kernel_size=5, stride=3)
-        self.adaptive_avp_pool = nn.AdaptiveAvgPool2d((1, 1))
+        self.adaptive_avp_pool = nn.AvgPool2d((1, 1))

     def forward(self, x: Tensor) -> Tensor:
         # N x 768 x 17 x 17
@@ -453,22 +453,22 @@ class BasicConv2d(nn.Module):
 inceptionv3 = inception_v3()
 inceptionv3.eval()

-# class inceptionv3Graph(flow.nn.Graph):
-#     def __init__(self):
-#         super().__init__()
-#         self.m = inceptionv3
-
-#     def build(self, x):
-#         out = self.m(x)
-#         return out
-
-# def test_inceptionv3():
-#     inceptionv3_graph = inceptionv3Graph()
-#     inceptionv3_graph._compile(flow.randn(1, 3, 299, 299))
-
-#     with tempfile.TemporaryDirectory() as tmpdirname:
-#         flow.save(inceptionv3.state_dict(), tmpdirname)
-#         convert_to_onnx_and_check(inceptionv3_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp")
-
-# test_inceptionv3()
+class inceptionv3Graph(flow.nn.Graph):
+    def __init__(self):
+        super().__init__()
+        self.m = inceptionv3
+
+    def build(self, x):
+        out = self.m(x)
+        return out
+
+def test_inceptionv3():
+    inceptionv3_graph = inceptionv3Graph()
+    inceptionv3_graph._compile(flow.randn(1, 3, 299, 299))
+
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        flow.save(inceptionv3.state_dict(), tmpdirname)
+        convert_to_onnx_and_check(inceptionv3_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp")
+
+test_inceptionv3()
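Two of the hunks above swap nn.AdaptiveAvgPool2d for a fixed-size nn.AvgPool2d; adaptive pooling is commonly the operator that blocks ONNX export with static shapes, which is likely the motivation here. For the stock 299x299 InceptionV3 input, the feature map entering self.avgpool is 8x8, so AvgPool2d((8, 8)) is numerically equivalent to AdaptiveAvgPool2d((1, 1)). A minimal check of that equivalence (ours, not part of the commit):

import numpy as np
import oneflow as flow
from oneflow import nn

x = flow.randn(1, 2048, 8, 8)  # shape of the map entering self.avgpool for a 299x299 input
out_adaptive = nn.AdaptiveAvgPool2d((1, 1))(x)
out_fixed = nn.AvgPool2d((8, 8))(x)
print(np.allclose(out_adaptive.numpy(), out_fixed.numpy()))  # True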
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from oneflow import nn, Tensor
import oneflow as flow
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
import tempfile
def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
result = nn.Sequential()
result.add_module(
"conv",
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
bias=False,
),
)
result.add_module("bn", nn.BatchNorm2d(num_features=out_channels))
return result
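# Note (a sketch of ours, not part of this file): conv_bn keeps "conv" and
# "bn" as named submodules so deploy-time code can fold the BatchNorm
# statistics into the convolution:
#     W' = W * gamma / sqrt(running_var + eps)
#     b' = beta - running_mean * gamma / sqrt(running_var + eps)
def _fuse_conv_bn_sketch(branch):
    conv, bn = branch.conv, branch.bn
    std = (bn.running_var + bn.eps).sqrt()
    # scale each output channel's kernel by gamma / std
    fused_weight = conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1)
    fused_bias = bn.bias - bn.running_mean * bn.weight / std
    return fused_weight, fused_bias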
class SEBlock(nn.Module):
def __init__(self, input_channels: int, internal_neurons: int):
super(SEBlock, self).__init__()
self.down = nn.Conv2d(
in_channels=input_channels,
out_channels=internal_neurons,
kernel_size=1,
stride=1,
bias=True,
)
self.up = nn.Conv2d(
in_channels=internal_neurons,
out_channels=input_channels,
kernel_size=1,
stride=1,
bias=True,
)
self.relu = nn.ReLU()
self.adaptive_avg_pool2d = nn.AdaptiveAvgPool2d(output_size=1)
def forward(self, inputs: Tensor):
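        # squeeze: global average pool to 1x1; excite: 1x1 bottleneck convs;
        # the sigmoid output rescales each channel of the original input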
x = self.adaptive_avg_pool2d(inputs)
x = self.down(x)
x = self.relu(x)
x = self.up(x)
x = flow.sigmoid(x)
return inputs * x
class RepVGGBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode="zeros",
deploy=False,
use_se=False,
) -> None:
super(RepVGGBlock, self).__init__()
self.deploy = deploy
self.groups = groups
self.in_channels = in_channels
assert kernel_size == 3
assert padding == 1
padding_11 = padding - kernel_size // 2
self.nonlinearity = nn.ReLU()
if use_se:
self.se = SEBlock(out_channels, internal_neurons=out_channels // 16)
else:
self.se = nn.Identity()
if deploy:
self.rbr_reparam = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=True,
padding_mode=padding_mode,
)
else:
self.rbr_identity = (
nn.BatchNorm2d(num_features=in_channels)
if out_channels == in_channels and stride == 1
else None
)
self.rbr_dense = conv_bn(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
)
self.rbr_1x1 = conv_bn(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=padding_11,
groups=groups,
)
def forward(self, inputs):
if hasattr(self, "rbr_reparam"):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))
if self.rbr_identity is None:
id_out = 0
else:
id_out = self.rbr_identity(inputs)
return self.nonlinearity(
self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
)
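# The three branches above exist only for training. A sketch (ours; the
# reference RepVGG implementation calls this step switch_to_deploy, which
# this file omits) of how they collapse into the single rbr_reparam conv:
# fuse each conv+BN pair as in _fuse_conv_bn_sketch above, zero-pad the
# fused 1x1 kernel to 3x3, treat the identity branch as a 3x3 kernel with a
# centered 1 per channel, then sum all kernels and all biases.
def _merge_branches_sketch(k3, b3, k1, b1):
    # pad the 1x1 kernel by one pixel on each spatial side to make it 3x3
    return k3 + flow.nn.functional.pad(k1, [1, 1, 1, 1]), b3 + b1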
class RepVGG(nn.Module):
def __init__(
self,
num_blocks,
num_classes=1000,
width_multiplier=None,
override_groups_map=None,
deploy=False,
use_se=False,
):
super(RepVGG, self).__init__()
assert len(width_multiplier) == 4
self.deploy = deploy
self.override_groups_map = override_groups_map or dict()
self.use_se = use_se
assert 0 not in self.override_groups_map
self.in_planes = min(64, int(64 * width_multiplier[0]))
self.stage0 = RepVGGBlock(
in_channels=3,
out_channels=self.in_planes,
kernel_size=3,
stride=2,
padding=1,
deploy=self.deploy,
use_se=self.use_se,
)
self.cur_layer_idx = 1
self.stage1 = self._make_stage(
int(64 * width_multiplier[0]), num_blocks[0], stride=2
)
self.stage2 = self._make_stage(
int(128 * width_multiplier[1]), num_blocks[1], stride=2
)
self.stage3 = self._make_stage(
int(256 * width_multiplier[2]), num_blocks[2], stride=2
)
self.stage4 = self._make_stage(
int(512 * width_multiplier[3]), num_blocks[3], stride=2
)
self.gap = nn.AdaptiveAvgPool2d(output_size=1)
self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)
def _make_stage(self, planes, num_blocks, stride):
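        # the first block of each stage downsamples (caller-supplied stride);
        # cur_layer_idx numbers blocks globally so override_groups_map can
        # assign grouped 3x3 convs to specific layers (see g2_map/g4_map below)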
strides = [stride] + [1] * (num_blocks - 1)
blocks = []
for stride in strides:
cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
blocks.append(
RepVGGBlock(
in_channels=self.in_planes,
out_channels=planes,
kernel_size=3,
stride=stride,
padding=1,
groups=cur_groups,
deploy=self.deploy,
use_se=self.use_se,
)
)
self.in_planes = planes
self.cur_layer_idx += 1
return nn.Sequential(*blocks)
def forward(self, x):
out = self.stage0(x)
out = self.stage1(out)
out = self.stage2(out)
out = self.stage3(out)
out = self.stage4(out)
out = self.gap(out)
out = flow.flatten(out, 1)
out = self.linear(out)
return out
optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
g2_map = {l: 2 for l in optional_groupwise_layers}
g4_map = {l: 4 for l in optional_groupwise_layers}
def create_RepVGG_A0(deploy=False):
return RepVGG(
num_blocks=[2, 4, 14, 1],
num_classes=1000,
width_multiplier=[0.75, 0.75, 0.75, 2.5],
override_groups_map=None,
deploy=deploy,
)
def create_RepVGG_A1(deploy=False):
return RepVGG(
num_blocks=[2, 4, 14, 1],
num_classes=1000,
width_multiplier=[1, 1, 1, 2.5],
override_groups_map=None,
deploy=deploy,
)
def create_RepVGG_A2(deploy=False):
return RepVGG(
num_blocks=[2, 4, 14, 1],
num_classes=1000,
width_multiplier=[1.5, 1.5, 1.5, 2.75],
override_groups_map=None,
deploy=deploy,
)
def create_RepVGG_B0(deploy=False):
return RepVGG(
num_blocks=[4, 6, 16, 1],
num_classes=1000,
width_multiplier=[1, 1, 1, 2.5],
override_groups_map=None,
deploy=deploy,
)
def create_RepVGG_B1(deploy=False):
return RepVGG(
num_blocks=[4, 6, 16, 1],
num_classes=1000,
width_multiplier=[2, 2, 2, 4],
override_groups_map=None,
deploy=deploy,
)
def create_RepVGG_B1g2(deploy=False):
return RepVGG(
num_blocks=[4, 6, 16, 1],
num_classes=1000,
width_multiplier=[2, 2, 2, 4],
override_groups_map=g2_map,
deploy=deploy,
)
def create_RepVGG_B1g4(deploy=False):
return RepVGG(
num_blocks=[4, 6, 16, 1],
num_classes=1000,
width_multiplier=[2, 2, 2, 4],
override_groups_map=g4_map,
deploy=deploy,
)
def create_RepVGG_B2(deploy=False):
return RepVGG(
num_blocks=[4, 6, 16, 1],
num_classes=1000,
width_multiplier=[2.5, 2.5, 2.5, 5],
override_groups_map=None,
deploy=deploy,
)
def create_RepVGG_B2g2(deploy=False):
return RepVGG(
num_blocks=[4, 6, 16, 1],
num_classes=1000,
width_multiplier=[2.5, 2.5, 2.5, 5],
override_groups_map=g2_map,
deploy=deploy,
)
def create_RepVGG_B2g4(deploy=False):
return RepVGG(
num_blocks=[4, 6, 16, 1],
num_classes=1000,
width_multiplier=[2.5, 2.5, 2.5, 5],
override_groups_map=g4_map,
deploy=deploy,
)
def create_RepVGG_B3(deploy=False):
return RepVGG(
num_blocks=[4, 6, 16, 1],
num_classes=1000,
width_multiplier=[3, 3, 3, 5],
override_groups_map=None,
deploy=deploy,
)
def create_RepVGG_B3g2(deploy=False):
return RepVGG(
num_blocks=[4, 6, 16, 1],
num_classes=1000,
width_multiplier=[3, 3, 3, 5],
override_groups_map=g2_map,
deploy=deploy,
)
def create_RepVGG_B3g4(deploy=False):
return RepVGG(
num_blocks=[4, 6, 16, 1],
num_classes=1000,
width_multiplier=[3, 3, 3, 5],
override_groups_map=g4_map,
deploy=deploy,
)
def create_RepVGG_D2se(deploy=False):
return RepVGG(
num_blocks=[8, 14, 24, 1],
num_classes=1000,
width_multiplier=[2.5, 2.5, 2.5, 5],
override_groups_map=None,
deploy=deploy,
use_se=True,
)
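# Note (ours): deploy=True would build the single-branch rbr_reparam
# architecture, which only makes sense with already re-parameterized weights;
# the export test below uses the default multi-branch deploy=False structure.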
repvgg = create_RepVGG_B2g4()
repvgg.eval()
repvgg = repvgg.to("cuda")
class RepVGGGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.m = repvgg
def build(self, x):
out = self.m(x)
return out
def test_repvgg():
repvgg_graph = RepVGGGraph()
repvgg_graph._compile(flow.randn(1, 3, 224, 224).to("cuda"))
with tempfile.TemporaryDirectory() as tmpdirname:
flow.save(repvgg.state_dict(), tmpdirname)
convert_to_onnx_and_check(repvgg_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp")
test_repvgg()