提交 eb340bd4 编写于 作者: B BBuf

add conv2d op

上级 96a344ae
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
import oneflow.typing as tp
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
def test_bias_add_nchw():
@flow.global_function()
def bias_add_nchw(x: tp.Numpy.Placeholder((3, 4, 2, 5))):
y = flow.get_variable(
name="y",
shape=(4,),
dtype=flow.float,
initializer=flow.random_uniform_initializer(),
)
return flow.nn.bias_add(x, y, "NCHW")
convert_to_onnx_and_check(bias_add_nchw)
def test_bias_add_nhwc():
@flow.global_function()
def bias_add_nhwc(x: tp.Numpy.Placeholder((3, 4, 2, 5))):
y = flow.get_variable(
name="y",
shape=(5,),
dtype=flow.float,
initializer=flow.random_uniform_initializer(),
)
return flow.nn.bias_add(x, y, "NHWC")
convert_to_onnx_and_check(bias_add_nhwc)
...@@ -13,42 +13,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,42 +13,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import tempfile
import oneflow as flow import oneflow as flow
import oneflow.typing as tp
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
class ConCat(flow.nn.Module):
def __init__(self) -> None:
super(ConCat, self).__init__()
def forward(self, x: flow.Tensor) -> flow.Tensor:
return flow.cat([x, x, x], dim=1)
def test_concat_axis0(): concat = ConCat()
@flow.global_function() class ConCatOpGraph(flow.nn.Graph):
def concat(): def __init__(self):
variables = [] super().__init__()
for i in range(4): self.m = concat
variables.append(
flow.get_variable( def build(self, x):
name=str(i), out = self.m(x)
shape=(2, 3), return out
dtype=flow.float,
initializer=flow.random_uniform_initializer(),
) def test_concat():
)
return flow.concat(variables, axis=0) concat_graph = ConCatOpGraph()
concat_graph._compile(flow.randn(1, 3, 224, 224))
convert_to_onnx_and_check(concat)
with tempfile.TemporaryDirectory() as tmpdirname:
flow.save(concat.state_dict(), tmpdirname)
def test_concat_axis1(): convert_to_onnx_and_check(concat_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp")
@flow.global_function()
def concat(): test_concat()
variables = []
for i in range(4):
variables.append(
flow.get_variable(
name=str(i),
shape=(2, 3),
dtype=flow.float,
initializer=flow.random_uniform_initializer(),
)
)
return flow.concat(variables, axis=1)
convert_to_onnx_and_check(concat)
...@@ -13,21 +13,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,21 +13,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import tempfile
import oneflow as flow import oneflow as flow
import oneflow.typing as tp
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
class Constant(flow.nn.Module):
def __init__(self) -> None:
super(Constant, self).__init__()
def forward(self, x: flow.Tensor) -> flow.Tensor:
return flow.ones((2, 3)) + flow.zeros((2, 3))
def test_constant_float(): constant = Constant()
@flow.global_function() class constantOpGraph(flow.nn.Graph):
def constant(x: tp.Numpy.Placeholder((3, 5))): def __init__(self):
return flow.constant(value=1.5, shape=(1, 3, 3), dtype=flow.float) super().__init__()
self.m = constant
convert_to_onnx_and_check(constant) def build(self, x):
out = self.m(x)
return out
def test_constant_int():
@flow.global_function()
def constant(x: tp.Numpy.Placeholder((3, 5))):
return flow.constant(value=1, shape=(1, 3, 3), dtype=flow.int)
convert_to_onnx_and_check(constant) def test_constant():
\ No newline at end of file
constant_graph = constantOpGraph()
constant_graph._compile(flow.randn(1, 3, 224, 224))
with tempfile.TemporaryDirectory() as tmpdirname:
flow.save(constant.state_dict(), tmpdirname)
convert_to_onnx_and_check(constant_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp")
test_constant()
...@@ -13,106 +13,36 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,106 +13,36 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import tempfile
import oneflow as flow import oneflow as flow
import oneflow.typing as tp
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
initializer = flow.random_uniform_initializer() class Conv2d(flow.nn.Module):
initer_args = {"kernel_initializer": initializer, "bias_initializer": initializer} def __init__(self) -> None:
super(Conv2d, self).__init__()
self.conv = flow.nn.Conv2d(3, 16, 3)
def forward(self, x: flow.Tensor) -> flow.Tensor:
return self.conv(x)
def test_conv2d_k2d1_valid(): conv = Conv2d()
@flow.global_function() class convOpGraph(flow.nn.Graph):
def conv2d_k3s1_valid(x: tp.Numpy.Placeholder((2, 4, 3, 5))): def __init__(self):
return flow.layers.conv2d( super().__init__()
x, 6, kernel_size=3, strides=1, padding="VALID", **initer_args self.m = conv
)
convert_to_onnx_and_check(conv2d_k3s1_valid) def build(self, x):
out = self.m(x)
return out
def test_conv2d_s2_valid(): def test_conv():
@flow.global_function()
def conv2d_s2_valid(x: tp.Numpy.Placeholder((2, 4, 3, 5))): conv_graph = convOpGraph()
return flow.layers.conv2d( conv_graph._compile(flow.randn(1, 3, 224, 224))
x, 6, kernel_size=1, strides=2, padding="VALID", **initer_args
)
convert_to_onnx_and_check(conv2d_s2_valid) with tempfile.TemporaryDirectory() as tmpdirname:
flow.save(conv.state_dict(), tmpdirname)
convert_to_onnx_and_check(conv_graph, flow_weight_dir=tmpdirname, onnx_model_path="/tmp")
test_conv()
def test_conv2d_s2_same():
@flow.global_function()
def conv2d_s2_same(x: tp.Numpy.Placeholder((2, 4, 3, 5))):
return flow.layers.conv2d(
x, 6, kernel_size=3, strides=2, padding="SAME", **initer_args
)
convert_to_onnx_and_check(conv2d_s2_same)
def test_conv2d_k3s1_nhwc_valid():
@flow.global_function()
def conv2d_k3s1_nhwc_valid(x: tp.Numpy.Placeholder((2, 3, 5, 4))):
return flow.layers.conv2d(
x,
6,
kernel_size=3,
strides=1,
padding="VALID",
data_format="NHWC",
**initer_args
)
convert_to_onnx_and_check(conv2d_k3s1_nhwc_valid)
def test_conv2d_k3s1_nhwc_same_d2():
@flow.global_function()
def conv2d(x: tp.Numpy.Placeholder((2, 7, 5, 4))):
return flow.layers.conv2d(
x,
6,
kernel_size=3,
strides=1,
dilation_rate=2,
padding="SAME",
data_format="NHWC",
**initer_args
)
convert_to_onnx_and_check(conv2d)
def test_conv2d_k3s1_nchw_same_g2():
@flow.global_function()
def conv2d(x: tp.Numpy.Placeholder((2, 4, 5, 3))):
return flow.layers.conv2d(
x,
6,
kernel_size=3,
strides=1,
groups=2,
padding="SAME",
data_format="NCHW",
**initer_args
)
convert_to_onnx_and_check(conv2d)
def test_conv2d_k3s1_nchw_same_depthwise():
@flow.global_function()
def conv2d(x: tp.Numpy.Placeholder((2, 4, 5, 3))):
return flow.layers.conv2d(
x,
4,
kernel_size=3,
strides=1,
groups=4,
padding="SAME",
data_format="NCHW",
**initer_args
)
convert_to_onnx_and_check(conv2d)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册