Unverified commit cca7b8cc authored by Xiaoyu Zhang, committed by GitHub

Merge pull request #32 from Oneflow-Inc/add_insightface_export_onnx

add test insightface
......@@ -97,4 +97,6 @@ python3 setup.py install
- 2020/4/16 Merged the Expand OP into the main branch and fixed a bug that raised an error when importing oneflow_api. Released the 0.3.1 wheel package.
- 2020/4/16 Resolved the remaining issues of automatic code generation and added the automatic code generation tests to CI. Released the 0.3.2 wheel package.
- 2020/6/21 Added the PReLU/LeakyReLU OPs to ONNX export and fixed an automatic code generation bug. Released the 0.3.3 wheel package.
- 2020/6/23 Added the Constant OP to ONNX export, fixed the export bug for BatchNorm inputs with only the N and C dimensions (InsightFace), and disabled the global function that is enabled by default when exporting ONNX. Released the 0.3.3.20210623 wheel package.
......@@ -97,3 +97,5 @@ This tool is to transform OneFlow models into models that can be used on the Ser
- 2020/4/16 Merged the Expand OP into the main branch and debugged oneflow_api. The 0.3.1 wheel package was released.
- 2020/4/16 Solved the remaining problems of automatic code generation and added the automatic code generation tests to CI. The 0.3.2 wheel package was released.
- 2020/6/21 Added PReLU/LeakyReLU OP ONNX export and fixed an automatic code generation bug. The 0.3.3 wheel package was released.
- 2020/6/23 Added Constant OP ONNX export, fixed a batch-norm export bug for inputs with only the N and C dimensions (hit by InsightFace), and disabled the global function that is enabled by default when exporting ONNX. The 0.3.3.20210623 wheel package was released.
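
For context, a minimal sketch of how this converter is typically driven. The job below is illustrative, not part of this PR; it assumes only the `convert_to_onnx_and_check` helper exported by this repository:

```python
import oneflow as flow
import oneflow.typing as tp
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check

# Illustrative job: a single dense layer, exported to ONNX and
# numerically checked against the OneFlow result.
@flow.global_function()
def dense_job(x: tp.Numpy.Placeholder((1, 10))):
    return flow.layers.dense(
        x, units=4, kernel_initializer=flow.random_normal_initializer()
    )

convert_to_onnx_and_check(dense_job)
```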
......@@ -22,4 +22,4 @@
| 58 | ReduceSum| 59 | ReduceProd | 60 | ArgMax | 61 | ArgMin |
|62 | Reshape | 63 | Squeeze | 64 | Transpose| 65 | Concat |
| 66 | Cast | 67 | Identity | 68 | Mul | 69 | PReLU |
| 70 | LeakyReLU| 71 | Constant |
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
import oneflow.typing as tp
import onnx
import onnxruntime as ort
import numpy as np
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
def _get_initializer():
return flow.random_normal_initializer(mean=0.0, stddev=0.1)
def _get_regularizer(name):
    # No weight regularization is used in this inference-only test.
    return None
def _dropout(input_blob, dropout_prob):
return flow.nn.dropout(input_blob, rate=dropout_prob)
def _prelu(inputs, data_format="NCHW", name=None):
return flow.layers.prelu(
inputs,
alpha_initializer=flow.constant_initializer(0.25),
alpha_regularizer=_get_regularizer("alpha"),
shared_axes=[2, 3] if data_format == "NCHW" else [1, 2],
name=name,
)
def _avg_pool(inputs, pool_size, strides, padding, data_format="NCHW", name=None):
return flow.nn.avg_pool2d(
input=inputs, ksize=pool_size, strides=strides, padding=padding, data_format=data_format, name=name
)
def _batch_norm(
inputs,
epsilon,
center=True,
scale=True,
trainable=True,
is_training=True,
data_format="NCHW",
name=None,
):
return flow.layers.batch_normalization(
inputs=inputs,
        axis=3 if data_format == "NHWC" and len(inputs.shape) == 4 else 1,
momentum=0.9,
epsilon=epsilon,
center=center,
scale=scale,
beta_initializer=flow.zeros_initializer(),
gamma_initializer=flow.ones_initializer(),
beta_regularizer=_get_regularizer("beta"),
gamma_regularizer=_get_regularizer("gamma"),
moving_mean_initializer=flow.zeros_initializer(),
moving_variance_initializer=flow.ones_initializer(),
trainable=trainable,
training=is_training,
name=name,
)
def _conv2d_layer(
name,
input,
filters,
kernel_size=3,
strides=1,
padding="SAME",
group_num=1,
data_format="NCHW",
dilation_rate=1,
activation=None,
use_bias=False,
weight_initializer=_get_initializer(),
bias_initializer=flow.zeros_initializer(),
weight_regularizer=_get_regularizer("weight"),
bias_regularizer=_get_regularizer("bias"),
):
return flow.layers.conv2d(inputs=input, filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, groups=group_num, activation=activation, use_bias=use_bias, kernel_initializer=weight_initializer, bias_initializer=bias_initializer, kernel_regularizer=weight_regularizer, bias_regularizer=bias_regularizer, name=name)
def Linear(
input_blob,
num_filter=1,
kernel=None,
stride=None,
pad="valid",
num_group=1,
bn_is_training=True,
data_format="NCHW",
name=None,
suffix="",
):
conv = _conv2d_layer(
name="%s%s_conv2d" % (name, suffix),
input=input_blob,
filters=num_filter,
kernel_size=kernel,
strides=stride,
padding=pad,
data_format=data_format,
group_num=num_group,
use_bias=False,
dilation_rate=1,
activation=None,
)
bn = _batch_norm(
conv,
epsilon=0.001,
is_training=bn_is_training,
data_format=data_format,
name="%s%s_batchnorm" % (name, suffix),
)
return bn
def get_fc1(last_conv, num_classes, fc_type, input_channel=512):
body = last_conv
if fc_type == "Z":
body = _batch_norm(
body,
epsilon=2e-5,
scale=False,
center=True,
is_training=False,
data_format="NCHW",
name="bn1"
)
body = _dropout(body, 0.4)
fc1 = body
elif fc_type == "E":
body = _batch_norm(
body,
epsilon=2e-5,
is_training=False,
data_format="NCHW",
name="bn1"
)
body = _dropout(body, dropout_prob=0.4)
body = flow.reshape(body, (body.shape[0], -1))
fc1 = flow.layers.dense(
inputs=body,
units=num_classes,
activation=None,
use_bias=True,
kernel_initializer=_get_initializer(),
bias_initializer=flow.zeros_initializer(),
kernel_regularizer=_get_regularizer("weight"),
bias_regularizer=_get_regularizer("bias"),
trainable=True,
name="pre_fc1",
)
fc1 = _batch_norm(
fc1,
epsilon=2e-5,
scale=False,
center=True,
is_training=False,
data_format="NCHW",
name="fc1",
)
elif fc_type == "FC":
body = _batch_norm(
body,
epsilon=2e-5,
is_training=False,
data_format="NCHW",
name="bn1"
)
body = flow.reshape(body, (body.shape[0], -1))
fc1 = flow.layers.dense(
inputs=body,
units=num_classes,
activation=None,
use_bias=True,
kernel_initializer=_get_initializer(),
bias_initializer=flow.zeros_initializer(),
kernel_regularizer=_get_regularizer("weight"),
bias_regularizer=_get_regularizer("bias"),
trainable=True,
name="pre_fc1"
)
fc1 = _batch_norm(
fc1,
epsilon=2e-5,
scale=False,
center=True,
is_training=False,
data_format="NCHW",
name="fc1"
)
elif fc_type == "GDC":
conv_6_dw = Linear(
last_conv,
num_filter=input_channel, # 512
num_group=input_channel, # 512
kernel=7,
pad="valid",
stride=[1, 1],
bn_is_training=False,
data_format="NCHW",
name="conv_6dw7_7",
)
        conv_6_dw = flow.reshape(conv_6_dw, (conv_6_dw.shape[0], -1))
conv_6_f = flow.layers.dense(
inputs=conv_6_dw,
units=num_classes,
activation=None,
use_bias=True,
kernel_initializer=_get_initializer(),
bias_initializer=flow.zeros_initializer(),
kernel_regularizer=_get_regularizer("weight"),
bias_regularizer=_get_regularizer("bias"),
trainable=True,
name="pre_fc1",
)
fc1 = _batch_norm(
conv_6_f,
epsilon=2e-5,
scale=False,
center=True,
is_training=False,
data_format="NCHW",
name="fc1",
)
return fc1
def residual_unit_v3(
in_data, num_filter, stride, dim_match, bn_is_training, data_format, name
):
suffix = ""
use_se = 0
bn1 = _batch_norm(
in_data,
epsilon=2e-5,
is_training=bn_is_training,
data_format=data_format,
name="%s%s_bn1" % (name, suffix),
)
conv1 = _conv2d_layer(
name="%s%s_conv1" % (name, suffix),
input=bn1,
filters=num_filter,
kernel_size=3,
strides=[1, 1],
padding="same",
data_format=data_format,
use_bias=False,
dilation_rate=1,
activation=None,
)
bn2 = _batch_norm(
conv1,
epsilon=2e-5,
is_training=bn_is_training,
data_format=data_format,
name="%s%s_bn2" % (name, suffix),
)
    prelu = _prelu(
        bn2, data_format=data_format, name="%s%s_relu1" % (name, suffix)
    )
conv2 = _conv2d_layer(
name="%s%s_conv2" % (name, suffix),
input=prelu,
filters=num_filter,
kernel_size=3,
strides=stride,
padding="same",
data_format=data_format,
use_bias=False,
dilation_rate=1,
activation=None,
)
bn3 = _batch_norm(
conv2,
epsilon=2e-5,
is_training=bn_is_training,
data_format=data_format,
name="%s%s_bn3" % (name, suffix),
)
if dim_match:
input_blob = in_data
else:
input_blob = _conv2d_layer(
name="%s%s_conv1sc" % (name, suffix),
input=in_data,
filters=num_filter,
kernel_size=1,
strides=stride,
padding="valid",
data_format=data_format,
use_bias=False,
dilation_rate=1,
activation=None,
)
input_blob = _batch_norm(
input_blob,
epsilon=2e-5,
is_training=bn_is_training,
data_format=data_format,
name="%s%s_sc" % (name, suffix),
)
identity = flow.math.add(x=bn3, y=input_blob)
return identity
def get_symbol(input_blob):
    # ResNet-100-style (r100) configuration used by InsightFace.
    filter_list = [64, 64, 128, 256, 512]
    num_stages = 4
    units = [3, 13, 30, 3]
    num_classes = 512  # embedding size of fc1
    fc_type = "E"
    bn_is_training = False
    data_format = "NCHW"
if data_format.upper() == "NCHW":
input_blob = flow.transpose(
input_blob, name="transpose", perm=[0, 3, 1, 2]
)
input_blob = _conv2d_layer(
name="conv0",
input=input_blob,
filters=filter_list[0],
kernel_size=3,
strides=[1, 1],
padding="same",
data_format=data_format,
use_bias=False,
dilation_rate=1,
activation=None,
)
input_blob = _batch_norm(
input_blob, epsilon=2e-5, is_training=bn_is_training, data_format=data_format, name="bn0"
)
input_blob = _prelu(input_blob, data_format=data_format, name="relu0")
for i in range(num_stages):
input_blob = residual_unit_v3(
input_blob,
filter_list[i + 1],
[2, 2],
False,
bn_is_training=bn_is_training,
data_format=data_format,
name="stage%d_unit%d" % (i + 1, 1),
)
for j in range(units[i] - 1):
input_blob = residual_unit_v3(
input_blob,
filter_list[i + 1],
[1, 1],
True,
bn_is_training=bn_is_training,
data_format=data_format,
name="stage%d_unit%d" % (i + 1, j + 2),
)
fc1 = get_fc1(input_blob, num_classes, fc_type)
return fc1
def test_insightface():
@flow.global_function()
    def InferenceNet(images: tp.Numpy.Placeholder((1, 112, 112, 3))):
logits = get_symbol(images)
return logits
convert_to_onnx_and_check(InferenceNet, print_outlier=True, flow_weight_dir=None, onnx_model_path="/tmp")
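
For illustration, the exported model could be sanity-checked with onnxruntime. This is a sketch under the assumption that `convert_to_onnx_and_check` writes `model.onnx` under the `onnx_model_path` directory given above; the filename is an assumption, not confirmed by this PR:

```python
import numpy as np
import onnxruntime as ort

# Assumed output location of the export above ("/tmp"); filename is a guess.
sess = ort.InferenceSession("/tmp/model.onnx")
inp = sess.get_inputs()[0]
x = np.random.rand(1, 112, 112, 3).astype(np.float32)  # NHWC input, as in InferenceNet
(emb,) = sess.run(None, {inp.name: x})
assert emb.shape == (1, 512)  # fc1 embedding size set in get_symbol
```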
......@@ -84,3 +84,36 @@ def test_bn_nhwc():
)
convert_to_onnx_and_check(bn)
def test_bn_nc():
@flow.global_function()
def bn(x: tp.Numpy.Placeholder((3, 4))):
params_shape = (4,)
mean = flow.get_variable(
name="mean",
shape=params_shape,
dtype=flow.float,
initializer=flow.random_uniform_initializer(),
)
variance = flow.get_variable(
name="var",
shape=params_shape,
dtype=flow.float,
initializer=flow.random_uniform_initializer(),
)
gamma = flow.get_variable(
name="gamma",
shape=params_shape,
dtype=flow.float,
initializer=flow.random_uniform_initializer(),
)
beta = flow.get_variable(
name="beta",
shape=params_shape,
dtype=flow.float,
initializer=flow.random_uniform_initializer(),
)
return flow.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-5, axis=1)
convert_to_onnx_and_check(bn)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
import oneflow.typing as tp
from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check
def test_constant_float():
@flow.global_function()
def constant(x: tp.Numpy.Placeholder((3, 5))):
return flow.constant(value=1.5, shape=(1, 3, 3), dtype=flow.float)
convert_to_onnx_and_check(constant)
def test_constant_int():
@flow.global_function()
def constant(x: tp.Numpy.Placeholder((3, 5))):
return flow.constant(value=1, shape=(1, 3, 3), dtype=flow.int)
convert_to_onnx_and_check(constant)
\ No newline at end of file
......@@ -24,3 +24,10 @@ def test_reshape():
return flow.reshape(x, (4, 30))
convert_to_onnx_and_check(reshape)
def test_reshape_negative_dim():
@flow.global_function()
def reshape(x: tp.Numpy.Placeholder((3, 4, 2, 5))):
return flow.reshape(x, (3, -1))
convert_to_onnx_and_check(reshape)
......@@ -208,3 +208,23 @@ class Identity:
@classmethod
def Version_1(cls, ctx, node, **kwargs):
pass
@flow_op("constant", "Constant")
class Constant:
@classmethod
def Version_1(cls, ctx, node, **kwargs):
floating_value = node.attrs.get("floating_value", 0.0)
integer_value = node.attrs.get("integer_value", 0)
is_floating_value = node.attrs.get("is_floating_value", False)
shape = node.attrs.get("shape", None)
        if is_floating_value:
            values = np.full(shape=shape, fill_value=floating_value, dtype=np.float32)
        else:
            # Keep an integer dtype for integer constants instead of float32.
            values = np.full(shape=shape, fill_value=integer_value, dtype=np.int64)
        output_name = node.output_tensor_names[0]
        ctx.RemoveNode(node.name)
        ctx.MakeConst(output_name, values)
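
For reference, the ONNX graph this handler effectively produces contains a single constant tensor. A minimal hand-built equivalent using `onnx.helper` (a standalone sketch, independent of this codebase):

```python
import numpy as np
import onnx
from onnx import helper, numpy_helper

# Equivalent of flow.constant(value=1.5, shape=(1, 3, 3), dtype=flow.float).
values = np.full((1, 3, 3), 1.5, dtype=np.float32)
node = helper.make_node(
    "Constant",
    inputs=[],
    outputs=["out"],
    value=numpy_helper.from_array(values, name="value"),
)
graph = helper.make_graph(
    [node],
    "constant_graph",
    inputs=[],
    outputs=[helper.make_tensor_value_info("out", onnx.TensorProto.FLOAT, [1, 3, 3])],
)
onnx.checker.check_model(helper.make_model(graph))
```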
......@@ -368,7 +368,13 @@ class BatchNorm:
new_output = [node.output_tensor_names[0]]
node.output_tensor_names = new_output
        input_shape = ctx.get_shape(node.input_tensor_names[0])
        if len(input_shape) == 4:
            _ConvConvertInputs(ctx, node, with_kernel=False)
        else:
            # For [N, C] batch norm there is no spatial layout to convert.
            pass
scale_shape = ctx.get_shape(node.input_tensor_names[1])
mean_shape = ctx.get_shape(node.input_tensor_names[3])
......@@ -401,3 +407,4 @@ class BatchNorm:
def Version_9(cls, ctx, node, **kwargs):
# is_test was removed - no change for us
cls.Version_6(ctx, node, **kwargs)
......@@ -24,7 +24,7 @@ long_description += "Email: zhangxiaoyu@oneflow.org"
setuptools.setup(
name="oneflow_onnx",
version="0.3.3",
version="0.3.3.20210623",
author="zhangxiaoyu",
author_email="zhangxiaoyu@oneflow.org",
description="a toolkit for converting trained model of OneFlow to ONNX and ONNX to OneFlow.",
......