From fbfe9d893545da2c66aa9332af30d5f4444425da Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Fri, 24 Sep 2021 17:38:55 +0800 Subject: [PATCH] support inceptionv3 --- .../oneflow2onnx/models/test_inceptionv3.py | 54 +-- examples/oneflow2onnx/models/test_repvgg.py | 396 ++++++++++++++++++ 2 files changed, 423 insertions(+), 27 deletions(-) create mode 100644 examples/oneflow2onnx/models/test_repvgg.py diff --git a/examples/oneflow2onnx/models/test_inceptionv3.py b/examples/oneflow2onnx/models/test_inceptionv3.py index 8afcb82..e03e0e9 100644 --- a/examples/oneflow2onnx/models/test_inceptionv3.py +++ b/examples/oneflow2onnx/models/test_inceptionv3.py @@ -116,20 +116,20 @@ class Inception3(nn.Module): self.Mixed_7a = inception_d(768) self.Mixed_7b = inception_e(1280) self.Mixed_7c = inception_e(2048) - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.avgpool = nn.AvgPool2d((8, 8)) self.dropout = nn.Dropout() self.fc = nn.Linear(2048, num_classes) - if init_weights: - for m in self.modules(): - if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): - stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1 # type: ignore - flow.nn.init.trunc_normal_( - m.weight, mean=0.0, std=stddev, a=-2, b=2 - ) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) + # if init_weights: + # for m in self.modules(): + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + # stddev = float(m.stddev) if hasattr(m, "stddev") else 0.1 # type: ignore + # flow.nn.init.trunc_normal_( + # m.weight, mean=0.0, std=stddev, a=-2, b=2 + # ) + # elif isinstance(m, nn.BatchNorm2d): + # nn.init.constant_(m.weight, 1) + # nn.init.constant_(m.bias, 0) def _transform_input(self, x: Tensor) -> Tensor: if self.transform_input: @@ -191,7 +191,7 @@ class Inception3(nn.Module): # N x 2048 x = self.fc(x) # N x 1000 (num_classes) - return x, aux + return x class InceptionA(nn.Module): @@ -418,7 +418,7 @@ class InceptionAux(nn.Module): 
self.fc = nn.Linear(768, num_classes) self.fc.stddev = 0.001 # type: ignore[assignment] self.avg_pool = nn.AvgPool2d(kernel_size=5, stride=3) - self.adaptive_avp_pool = nn.AdaptiveAvgPool2d((1, 1)) + self.adaptive_avp_pool = nn.AvgPool2d((1, 1)) def forward(self, x: Tensor) -> Tensor: # N x 768 x 17 x 17 @@ -453,22 +453,22 @@ class BasicConv2d(nn.Module): inceptionv3 = inception_v3() inceptionv3.eval() -# class inceptionv3Graph(flow.nn.Graph): -# def __init__(self): -# super().__init__() -# self.m = inceptionv3 +class inceptionv3Graph(flow.nn.Graph): + def __init__(self): + super().__init__() + self.m = inceptionv3 -# def build(self, x): -# out = self.m(x) -# return out + def build(self, x): + out = self.m(x) + return out -# def test_inceptionv3(): +def test_inceptionv3(): -# inceptionv3_graph = inceptionv3Graph() -# inceptionv3_graph._compile(flow.randn(1, 3, 299, 299)) + inceptionv3_graph = inceptionv3Graph() + inceptionv3_graph._compile(flow.randn(1, 3, 299, 299)) -# with tempfile.TemporaryDirectory() as tmpdirname: -# flow.save(inceptionv3.state_dict(), tmpdirname) -# convert_to_onnx_and_check(inceptionv3_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp") + with tempfile.TemporaryDirectory() as tmpdirname: + flow.save(inceptionv3.state_dict(), tmpdirname) + convert_to_onnx_and_check(inceptionv3_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp") -# test_inceptionv3() +test_inceptionv3() diff --git a/examples/oneflow2onnx/models/test_repvgg.py b/examples/oneflow2onnx/models/test_repvgg.py new file mode 100644 index 0000000..cb0b05a --- /dev/null +++ b/examples/oneflow2onnx/models/test_repvgg.py @@ -0,0 +1,396 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from oneflow import nn, Tensor +import oneflow as flow +from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check +import tempfile + +def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): + result = nn.Sequential() + result.add_module( + "conv", + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + bias=False, + ), + ) + result.add_module("bn", nn.BatchNorm2d(num_features=out_channels)) + return result + + +class SEBlock(nn.Module): + def __init__(self, input_channels: int, internal_neurons: int): + super(SEBlock, self).__init__() + self.down = nn.Conv2d( + in_channels=input_channels, + out_channels=internal_neurons, + kernel_size=1, + stride=1, + bias=True, + ) + self.up = nn.Conv2d( + in_channels=internal_neurons, + out_channels=input_channels, + kernel_size=1, + stride=1, + bias=True, + ) + self.relu = nn.ReLU() + self.adaptive_avg_pool2d = nn.AdaptiveAvgPool2d(output_size=1) + + def forward(self, inputs: Tensor): + x = self.adaptive_avg_pool2d(inputs) + x = self.down(x) + x = self.relu(x) + x = self.up(x) + x = flow.sigmoid(x) + return inputs * x + + +class RepVGGBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + padding_mode="zeros", + deploy=False, + use_se=False, + ) -> None: + super(RepVGGBlock, self).__init__() + self.deploy = deploy + self.groups = groups + self.in_channels = in_channels + + assert kernel_size == 3 + assert padding == 
1 + + padding_11 = padding - kernel_size // 2 + + self.nonlinearity = nn.ReLU() + + if use_se: + self.se = SEBlock(out_channels, internal_neurons=out_channels // 16) + else: + self.se = nn.Identity() + + if deploy: + self.rbr_reparam = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=True, + padding_mode=padding_mode, + ) + else: + self.rbr_identity = ( + nn.BatchNorm2d(num_features=in_channels) + if out_channels == in_channels and stride == 1 + else None + ) + self.rbr_dense = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + groups=groups, + ) + self.rbr_1x1 = conv_bn( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + padding=padding_11, + groups=groups, + ) + + def forward(self, inputs): + if hasattr(self, "rbr_reparam"): + return self.nonlinearity(self.se(self.rbr_reparam(inputs))) + + if self.rbr_identity is None: + id_out = 0 + else: + id_out = self.rbr_identity(inputs) + + return self.nonlinearity( + self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out) + ) + + +class RepVGG(nn.Module): + def __init__( + self, + num_blocks, + num_classes=1000, + width_multiplier=None, + override_groups_map=None, + deploy=False, + use_se=False, + ): + super(RepVGG, self).__init__() + + assert len(width_multiplier) == 4 + + self.deploy = deploy + self.override_groups_map = override_groups_map or dict() + self.use_se = use_se + + assert 0 not in self.override_groups_map + + self.in_planes = min(64, int(64 * width_multiplier[0])) + + self.stage0 = RepVGGBlock( + in_channels=3, + out_channels=self.in_planes, + kernel_size=3, + stride=2, + padding=1, + deploy=self.deploy, + use_se=self.use_se, + ) + self.cur_layer_idx = 1 + self.stage1 = self._make_stage( + int(64 * width_multiplier[0]), num_blocks[0], stride=2 + ) + 
self.stage2 = self._make_stage( + int(128 * width_multiplier[1]), num_blocks[1], stride=2 + ) + self.stage3 = self._make_stage( + int(256 * width_multiplier[2]), num_blocks[2], stride=2 + ) + self.stage4 = self._make_stage( + int(512 * width_multiplier[3]), num_blocks[3], stride=2 + ) + self.gap = nn.AdaptiveAvgPool2d(output_size=1) + self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes) + + def _make_stage(self, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + blocks = [] + for stride in strides: + cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1) + blocks.append( + RepVGGBlock( + in_channels=self.in_planes, + out_channels=planes, + kernel_size=3, + stride=stride, + padding=1, + groups=cur_groups, + deploy=self.deploy, + use_se=self.use_se, + ) + ) + self.in_planes = planes + self.cur_layer_idx += 1 + return nn.Sequential(*blocks) + + def forward(self, x): + out = self.stage0(x) + out = self.stage1(out) + out = self.stage2(out) + out = self.stage3(out) + out = self.stage4(out) + out = self.gap(out) + out = flow.flatten(out, 1) + out = self.linear(out) + return out + + +optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26] +g2_map = {l: 2 for l in optional_groupwise_layers} +g4_map = {l: 4 for l in optional_groupwise_layers} + + +def create_RepVGG_A0(deploy=False): + return RepVGG( + num_blocks=[2, 4, 14, 1], + num_classes=1000, + width_multiplier=[0.75, 0.75, 0.75, 2.5], + override_groups_map=None, + deploy=deploy, + ) + + +def create_RepVGG_A1(deploy=False): + return RepVGG( + num_blocks=[2, 4, 14, 1], + num_classes=1000, + width_multiplier=[1, 1, 1, 2.5], + override_groups_map=None, + deploy=deploy, + ) + + +def create_RepVGG_A2(deploy=False): + return RepVGG( + num_blocks=[2, 4, 14, 1], + num_classes=1000, + width_multiplier=[1.5, 1.5, 1.5, 2.75], + override_groups_map=None, + deploy=deploy, + ) + + +def create_RepVGG_B0(deploy=False): + return RepVGG( + num_blocks=[4, 6, 16, 1], + 
num_classes=1000, + width_multiplier=[1, 1, 1, 2.5], + override_groups_map=None, + deploy=deploy, + ) + + +def create_RepVGG_B1(deploy=False): + return RepVGG( + num_blocks=[4, 6, 16, 1], + num_classes=1000, + width_multiplier=[2, 2, 2, 4], + override_groups_map=None, + deploy=deploy, + ) + + +def create_RepVGG_B1g2(deploy=False): + return RepVGG( + num_blocks=[4, 6, 16, 1], + num_classes=1000, + width_multiplier=[2, 2, 2, 4], + override_groups_map=g2_map, + deploy=deploy, + ) + + +def create_RepVGG_B1g4(deploy=False): + return RepVGG( + num_blocks=[4, 6, 16, 1], + num_classes=1000, + width_multiplier=[2, 2, 2, 4], + override_groups_map=g4_map, + deploy=deploy, + ) + + +def create_RepVGG_B2(deploy=False): + return RepVGG( + num_blocks=[4, 6, 16, 1], + num_classes=1000, + width_multiplier=[2.5, 2.5, 2.5, 5], + override_groups_map=None, + deploy=deploy, + ) + + +def create_RepVGG_B2g2(deploy=False): + return RepVGG( + num_blocks=[4, 6, 16, 1], + num_classes=1000, + width_multiplier=[2.5, 2.5, 2.5, 5], + override_groups_map=g2_map, + deploy=deploy, + ) + + +def create_RepVGG_B2g4(deploy=False): + return RepVGG( + num_blocks=[4, 6, 16, 1], + num_classes=1000, + width_multiplier=[2.5, 2.5, 2.5, 5], + override_groups_map=g4_map, + deploy=deploy, + ) + + +def create_RepVGG_B3(deploy=False): + return RepVGG( + num_blocks=[4, 6, 16, 1], + num_classes=1000, + width_multiplier=[3, 3, 3, 5], + override_groups_map=None, + deploy=deploy, + ) + + +def create_RepVGG_B3g2(deploy=False): + return RepVGG( + num_blocks=[4, 6, 16, 1], + num_classes=1000, + width_multiplier=[3, 3, 3, 5], + override_groups_map=g2_map, + deploy=deploy, + ) + + +def create_RepVGG_B3g4(deploy=False): + return RepVGG( + num_blocks=[4, 6, 16, 1], + num_classes=1000, + width_multiplier=[3, 3, 3, 5], + override_groups_map=g4_map, + deploy=deploy, + ) + + +def create_RepVGG_D2se(deploy=False): + return RepVGG( + num_blocks=[8, 14, 24, 1], + num_classes=1000, + width_multiplier=[2.5, 2.5, 2.5, 5], + 
override_groups_map=None, + deploy=deploy, + use_se=True, + ) + +repvgg = create_RepVGG_B2g4() +repvgg.eval() +repvgg = repvgg.to("cuda") + +class RepVGGGraph(flow.nn.Graph): + def __init__(self): + super().__init__() + self.m = repvgg + + def build(self, x): + out = self.m(x) + return out + +def test_repvgg(): + + repvgg_graph = RepVGGGraph() + repvgg_graph._compile(flow.randn(1, 3, 224, 224).to("cuda")) + + with tempfile.TemporaryDirectory() as tmpdirname: + flow.save(repvgg.state_dict(), tmpdirname) + convert_to_onnx_and_check(repvgg_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp") + +test_repvgg() -- GitLab