Commit e99bfb25 authored by BBuf

support resnet50

Parent ceea2420
@@ -13,182 +13,305 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
import oneflow.nn as nn
from oneflow import Tensor
from typing import Type, Any, Callable, Union, List, Optional

from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
import tempfile


def conv3x3(
    in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1
) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
        )
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
        )
        self.avgpool = nn.AvgPool2d((7, 7))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = flow.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    **kwargs: Any
) -> ResNet:
    model = ResNet(block, layers, **kwargs)
    return model


def resnet50(**kwargs: Any) -> ResNet:
    r"""ResNet-50 from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
    """
    return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], **kwargs)


resnet = resnet50()
resnet = resnet.to("cuda")
resnet.eval()


class ResNetGraph(flow.nn.Graph):
    def __init__(self):
        super().__init__()
        self.m = resnet

    def build(self, x):
        out = self.m(x)
        return out


def test_resnet():
    resnet_graph = ResNetGraph()
    resnet_graph._compile(flow.randn(1, 3, 224, 224).to("cuda"))

    with tempfile.TemporaryDirectory() as tmpdirname:
        flow.save(resnet.state_dict(), tmpdirname)
        convert_to_onnx_and_check(
            resnet_graph,
            flow_weight_dir=tmpdirname,
            onnx_model_path="/tmp",
            print_outlier=False,
        )


test_resnet()
@@ -86,7 +86,9 @@ def FlowToOnnxNaive(graph, shape_override):
         for order in node.user_conf.input_order:
             for key, val in node.user_conf.input.items():
                 if key == order:
-                    res.append(val.s[0])
+                    for _ in range(len(val.s)):
+                        res.append(val.s[_])
         return res

     ipts = []
     for ibn in ibns:
@@ -122,7 +124,8 @@ def FlowToOnnxNaive(graph, shape_override):
         for order in node.user_conf.output_order:
             for key, val in node.user_conf.output.items():
                 if key == order:
-                    res.append(val.s[0])
+                    for _ in range(len(val.s)):
+                        res.append(val.s[_])
         return res

     outputs = []
     for obn in obns:
@@ -169,8 +172,6 @@ def FlowToOnnxNaive(graph, shape_override):
     op_type = get_op_type(node)
     input_names = get_inputs(node)
     output_names = get_outputs(node)
-    input_order = node.user_conf.input_order
-    output_order = node.user_conf.output_order
     onnx_node = helper.make_node(
         op_type, input_names, output_names, name=node.name, **attr
     )
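The first two hunks make the same fix on the input and output sides: val.s is a repeated field holding every logical blob name bound to one input/output key, and the old code silently kept only the first entry. Keeping the whole list matters for ops that bind several tensors to a single name (concat-style ops, for instance). An illustrative sketch with hypothetical blob names, not the library's code:

# val.s behaves like a list of blob names, e.g. for a two-way concat.
val_s = ["concat_in_0/out_0", "concat_in_1/out_0"]  # hypothetical names

res = []
for _ in range(len(val_s)):  # the pattern used in the commit
    res.append(val_s[_])

assert res == val_s  # every bound tensor is kept, not just val_s[0]

An equivalent and slightly more idiomatic form of the loop would be res.extend(val.s).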