提交 6aff0f70 编写于 作者: B BBuf

support slice op

上级 d4352a5e
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from oneflow import Tensor
import oneflow.nn as nn
from typing import Callable, Any, List
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
import tempfile
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
batchsize, num_channels, height, width = x.size()
channels_per_group = num_channels // groups
# reshape
x = flow.reshape(x, [batchsize, groups, channels_per_group, height, width])
x = flow.transpose(x, 1, 2)
# flatten
x = flow.reshape(x, [batchsize, -1, height, width])
return x
class InvertedResidual(nn.Module):
def __init__(self, inp: int, oup: int, stride: int) -> None:
super().__init__()
if not (1 <= stride <= 3):
raise ValueError("illegal stride value")
self.stride = stride
branch_features = oup // 2
assert (self.stride != 1) or (inp == branch_features << 1)
if self.stride > 1:
self.branch1 = nn.Sequential(
self.depthwise_conv(
inp, inp, kernel_size=3, stride=self.stride, padding=1
),
nn.BatchNorm2d(inp),
nn.Conv2d(
inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False
),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
else:
self.branch1 = nn.Sequential()
self.branch2 = nn.Sequential(
nn.Conv2d(
inp if (self.stride > 1) else branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
self.depthwise_conv(
branch_features,
branch_features,
kernel_size=3,
stride=self.stride,
padding=1,
),
nn.BatchNorm2d(branch_features),
nn.Conv2d(
branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False,
),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
@staticmethod
def depthwise_conv(
i: int,
o: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
bias: bool = False,
) -> nn.Conv2d:
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
def forward(self, x: Tensor) -> Tensor:
if self.stride == 1:
cnt_at_dim1 = int(x.shape[1] / 2)
x1 = x[:, 0:cnt_at_dim1, ::]
x2 = x[:, cnt_at_dim1:, ::]
out = flow.cat((x1, self.branch2(x2)), dim=1)
else:
out = flow.cat((self.branch1(x), self.branch2(x)), dim=1)
out = channel_shuffle(out, 2)
return out
class ShuffleNetV2(nn.Module):
def __init__(
self,
stages_repeats: List[int],
stages_out_channels: List[int],
num_classes: int = 1000,
inverted_residual: Callable[..., nn.Module] = InvertedResidual,
) -> None:
super().__init__()
if len(stages_repeats) != 3:
raise ValueError("expected stages_repeats as list of 3 positive ints")
if len(stages_out_channels) != 5:
raise ValueError("expected stages_out_channels as list of 5 positive ints")
self._stage_out_channels = stages_out_channels
input_channels = 3
output_channels = self._stage_out_channels[0]
self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True),
)
input_channels = output_channels
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Static annotations for mypy
self.stage2: nn.Sequential
self.stage3: nn.Sequential
self.stage4: nn.Sequential
stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
for name, repeats, output_channels in zip(
stage_names, stages_repeats, self._stage_out_channels[1:]
):
seq = [inverted_residual(input_channels, output_channels, 2)]
for i in range(repeats - 1):
seq.append(inverted_residual(output_channels, output_channels, 1))
setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels
output_channels = self._stage_out_channels[-1]
self.conv5 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True),
)
self.fc = nn.Linear(output_channels, num_classes)
def _forward_impl(self, x: Tensor) -> Tensor:
x = self.conv1(x)
x = self.maxpool(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.conv5(x)
x = x.mean([2, 3]) # globalpool
x = self.fc(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
def _shufflenetv2(arch: str, *args: Any, **kwargs: Any):
model = ShuffleNetV2(*args, **kwargs)
return model
def shufflenetv2_x0dot5():
return ShuffleNetV2([4, 8, 4], [24, 48, 96, 192, 1024])
def shufflenetv2_x1():
return ShuffleNetV2([4, 8, 4], [24, 116, 232, 464, 1024])
shufflenet = shufflenetv2_x0dot5()
shufflenet.eval()
shufflenet = shufflenet.to("cuda")
class shufflenetGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.m = shufflenet
def build(self, x):
out = self.m(x)
return out
def test_shufflenet():
shufflenet_graph = shufflenetGraph()
shufflenet_graph._compile(flow.randn(1, 3, 224, 224).to("cuda"))
with tempfile.TemporaryDirectory() as tmpdirname:
flow.save(shufflenet.state_dict(), tmpdirname)
convert_to_onnx_and_check(shufflenet_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp")
test_shufflenet()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import tempfile
import oneflow as flow
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
class Slice(flow.nn.Module):
def __init__(self) -> None:
super(Slice, self).__init__()
def forward(self, x: flow.Tensor) -> flow.Tensor:
return x[:, :1, :, :]
slice = Slice()
class sliceOpGraph(flow.nn.Graph):
def __init__(self):
super().__init__()
self.m = slice
def build(self, x):
out = self.m(x)
return out
def test_slice():
slice_graph = sliceOpGraph()
slice_graph._compile(flow.randn(1, 3, 224, 224))
# print(slice_graph._full_graph_proto)
with tempfile.TemporaryDirectory() as tmpdirname:
flow.save(slice.state_dict(), tmpdirname)
convert_to_onnx_and_check(slice_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp")
test_slice()
...@@ -207,6 +207,29 @@ class Concat: ...@@ -207,6 +207,29 @@ class Concat:
cls.Version_1(ctx, node, **kwargs) cls.Version_1(ctx, node, **kwargs)
@flow_op("slice", "Slice")
class Slice:
@classmethod
def Version_1(cls, ctx, node, **kwargs):
starts = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("start"), np.array(node.attrs["start"]).astype(np.int64))
node.input_tensor_names.append(starts.output_tensor_names[0])
ends = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("stop"), np.array(node.attrs["stop"]).astype(np.int64))
node.input_tensor_names.append(ends.output_tensor_names[0])
slice_axes = []
input_shape = ctx.get_shape(node.input_tensor_names[0])
for i in range(len(input_shape)):
slice_axes.append(i)
axes = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("axes"), np.array(slice_axes).astype(np.int64))
node.input_tensor_names.append(axes.output_tensor_names[0])
steps = ctx.MakeConst(oneflow._oneflow_internal.UniqueStr("steps"), np.array(node.attrs["step"]).astype(np.int64))
node.input_tensor_names.append(steps.output_tensor_names[0])
@classmethod
def Version_11(cls, ctx, node, **kwargs):
cls.Version_1(ctx, node, **kwargs)
@flow_op("gather_nd", onnx_op="GatherND", flow_ibns=["params", "indices"]) @flow_op("gather_nd", onnx_op="GatherND", flow_ibns=["params", "indices"])
class GatherND: class GatherND:
@classmethod @classmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册