未验证 提交 52f7d32a 编写于 作者: P Peihong Liu 提交者: GitHub

Merge pull request #23 from mosout/add_mobilenet_qat

add mobilenet with qat
# CI script: install dependencies, install this package, then run the examples.
# -e aborts on the first failing command; -x echoes every command to the log.
set -ex

# Point pip at the configured PyPI mirror for every following install.
python3 -m pip config set global.index-url https://mirrors.bfsu.edu.cn/pypi/web/simple
python3 -m pip install --user --upgrade pip
python3 -m pip install --user nvidia-pyindex

# Install each dependency set exactly once. Test requirements are resolved with
# the NVIDIA NGC index as an extra source, and every install goes through
# "python3 -m pip" so the same interpreter is targeted throughout.
python3 -m pip install -r test-requirements.txt --user --extra-index-url https://pypi.ngc.nvidia.com
if [ -f requirements.txt ]; then python3 -m pip install -r requirements.txt --user; fi
python3 -m pip install oneflow --user -U -f https://staging.oneflow.info/branch/master/cu110
python3 -m pip install gast==0.3.3 --user
python3 setup.py install

# The QAT tests assert that a trained model directory exists, so each training
# script runs before its pytest counterpart.
python3 examples/tensorrt_qat/test_lenet_qat_train.py
python3 -m pytest -s examples/tensorrt_qat/test_lenet_qat.py
python3 examples/tensorrt_qat/test_mobilenet_qat_train.py
python3 -m pytest -s examples/tensorrt_qat/test_mobilenet_qat.py

# Conversion test suites.
python3 -m pytest examples/oneflow2onnx
python3 -m pytest examples/x2oneflow/pytorch2oneflow/nodes
python3 -m pytest examples/x2oneflow/pytorch2oneflow/models
......
......@@ -271,3 +271,56 @@ def retry(n_retries=3):
return _wrapper
return wrapper
def build_qat_engine_from_onnx(model_file, verbose=False):
    """Parse an ONNX file and build a TensorRT engine with the INT8 flag set.

    Args:
        model_file: Path of the ONNX model; it is opened in binary mode.
        verbose: Use the VERBOSE TensorRT logger severity instead of INFO.

    Returns:
        The result of ``builder.build_engine``, or ``None`` when the ONNX
        parser rejects the model (the parser errors are printed first).
    """
    severity = trt.Logger.VERBOSE if verbose else trt.Logger.INFO
    TRT_LOGGER = trt.Logger(severity)

    # The network is created with both explicit batch and explicit precision.
    creation_flag = trt.NetworkDefinitionCreationFlag
    network_flags = (1 << int(creation_flag.EXPLICIT_BATCH)) | (
        1 << int(creation_flag.EXPLICIT_PRECISION)
    )

    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
        flags=network_flags
    ) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
        with open(model_file, "rb") as model:
            parsed = parser.parse(model.read())
        if not parsed:
            print("ERROR: Failed to parse the ONNX file.")
            for index in range(parser.num_errors):
                print(parser.get_error(index))
            return None

        config = builder.create_builder_config()
        config.max_workspace_size = 1 << 30
        config.flags |= 1 << int(trt.BuilderFlag.INT8)
        return builder.build_engine(network, config)
def run_tensorrt(onnx_path, test_case):
    """Run one batch through a TensorRT engine built from an ONNX model.

    Args:
        onnx_path: Path of the ONNX model handed to
            ``build_qat_engine_from_onnx``.
        test_case: Array whose first dimension is the batch size; it is
            flattened and copied into the first input host buffer.

    Returns:
        The first TensorRT output reshaped to ``(batch_size, -1)``.

    Raises:
        RuntimeError: If no engine could be built for ``onnx_path``.
    """
    engine = build_qat_engine_from_onnx(onnx_path)
    if engine is None:
        # Without this guard a failed build surfaces as an opaque
        # AttributeError from ``with None as engine``.
        raise RuntimeError(
            "Failed to build a TensorRT engine from ONNX model: %s" % onnx_path
        )

    with engine:
        inputs, outputs, bindings, stream = allocate_buffers(engine)
        with engine.create_execution_context() as context:
            batch_size = test_case.shape[0]
            test_case = test_case.reshape(-1)
            np.copyto(inputs[0].host, test_case)
            trt_outputs = do_inference_v2(
                context,
                bindings=bindings,
                inputs=inputs,
                outputs=outputs,
                stream=stream,
            )
            data = trt_outputs[0]
            return data.reshape(batch_size, -1)
def get_onnx_provider(ctx: str = "cpu"):
    """Return the ONNX Runtime execution-provider list for a device type.

    Args:
        ctx: Device type, either ``"cpu"`` or ``"gpu"``.

    Returns:
        A new one-element list holding the provider name.

    Raises:
        NotImplementedError: If ``ctx`` is neither ``"cpu"`` nor ``"gpu"``.
    """
    provider_by_ctx = {
        "gpu": "CUDAExecutionProvider",
        "cpu": "CPUExecutionProvider",
    }
    try:
        provider = provider_by_ctx[ctx]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, so every unsupported value
        # is still reported as NotImplementedError.
        raise NotImplementedError(
            f"Not supported device type: {ctx!r}. Expected 'cpu' or 'gpu'."
        ) from None
    # A fresh list per call, so callers may mutate the result freely.
    return [provider]
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
import oneflow.typing as tp
flow.env.init()
def Lenet(data):
initializer = flow.truncated_normal(0.1)
......@@ -30,21 +47,345 @@ def Lenet(data):
pool2 = flow.nn.max_pool2d(
conv2, ksize=2, strides=2, padding="VALID", name="pool2", data_format="NCHW"
)
# fc is replaced by conv to support tensorrt7
hidden1 = flow.layers.conv2d(
pool2, 512, 7, padding="VALID", name="hidden1", use_bias=False
reshape = flow.reshape(pool2, [pool2.shape[0], -1])
hidden = flow.layers.dense(
reshape,
512,
activation=flow.nn.relu,
kernel_initializer=initializer,
name="dense1",
)
return flow.layers.dense(hidden, 10, kernel_initializer=initializer, name="dense2")
def _get_regularizer(model_name):
    """Return the L2 regularizer applied to every variable.

    ``model_name`` is not consulted: the same decay is returned whatever
    name is passed.
    """
    weight_decay = 0.00004
    return flow.regularizers.l2(weight_decay)
def _get_initializer(model_name):
    """Return the initializer for a variable kind.

    Recognised kinds are ``"weight"``, ``"bias"``, ``"gamma"``, ``"beta"``
    and ``"dense_weight"``; any other name yields ``None``.
    """
    # Bias and beta both start at zero.
    if model_name in ("bias", "beta"):
        return flow.zeros_initializer()
    if model_name == "gamma":
        return flow.ones_initializer()
    if model_name == "weight":
        return flow.variance_scaling_initializer(
            2.0, mode="fan_out", distribution="random_normal", data_format="NCHW"
        )
    if model_name == "dense_weight":
        return flow.random_normal_initializer(0, 0.01)
    return None
def _batch_norm(
    inputs,
    axis,
    momentum,
    epsilon,
    center=True,
    scale=True,
    trainable=True,
    training=True,
    name=None,
):
    """Wrap ``flow.layers.batch_normalization`` with this file's defaults.

    Beta and gamma take their initializers from ``_get_initializer`` and
    their regularizers from ``_get_regularizer``; every other argument is
    forwarded unchanged.
    """
    variable_options = {
        "beta_initializer": _get_initializer("beta"),
        "gamma_initializer": _get_initializer("gamma"),
        "beta_regularizer": _get_regularizer("beta"),
        "gamma_regularizer": _get_regularizer("gamma"),
    }
    return flow.layers.batch_normalization(
        inputs=inputs,
        axis=axis,
        momentum=momentum,
        epsilon=epsilon,
        center=center,
        scale=scale,
        trainable=trainable,
        training=training,
        name=name,
        **variable_options
    )
def _relu6(data, prefix):
    """Clip ``data`` to the range [0, 6]; the op is named ``<prefix>-relu6``."""
    op_name = "%s-relu6" % prefix
    return flow.clip_by_value(data, 0, 6, name=op_name)
def mobilenet_unit(
    data,
    num_filter=1,
    kernel=(1, 1),
    stride=(1, 1),
    pad=(0, 0),
    num_group=1,
    data_format="NCHW",
    if_act=True,
    use_bias=False,
    trainable=True,
    training=True,
    prefix="",
):
    """Apply a 2-D convolution, optionally followed by ``_relu6``.

    Args:
        data: Input blob of the convolution.
        num_filter: Number of output filters.
        kernel: Kernel size.
        stride: Convolution strides.
        pad: Padding forwarded to ``flow.layers.conv2d``.
        num_group: Number of convolution groups.
        data_format: Layout string forwarded to the convolution.
        if_act: When truthy the convolution output goes through ``_relu6``.
        use_bias: Whether the convolution has a bias term.
        trainable: Accepted but not used in this function body.
        training: Accepted but not used in this function body.
        prefix: Name of the convolution, also the prefix of the activation.

    Returns:
        The activated output when ``if_act`` is truthy, otherwise the raw
        convolution output.
    """
    conv_output = flow.layers.conv2d(
        inputs=data,
        filters=num_filter,
        kernel_size=kernel,
        strides=stride,
        padding=pad,
        data_format=data_format,
        dilation_rate=1,
        groups=num_group,
        activation=None,
        use_bias=use_bias,
        kernel_initializer=_get_initializer("weight"),
        bias_initializer=_get_initializer("bias"),
        kernel_regularizer=_get_regularizer("weight"),
        bias_regularizer=_get_regularizer("bias"),
        name=prefix,
    )
    # Linear layers return the convolution output untouched.
    if not if_act:
        return conv_output
    return _relu6(conv_output, prefix)
def shortcut(data_in, data_residual, prefix):
    """Add the block input to its residual; the op is named ``<prefix>-add``."""
    return flow.math.add(data_in, data_residual, f"{prefix}-add")
def inverted_residual_unit(
    data,
    num_in_filter,
    num_filter,
    ifshortcut,
    stride,
    kernel,
    pad,
    expansion_factor,
    prefix,
    trainable=True,
    training=True,
    data_format="NCHW",
    has_expand=1,
):
    """Build one inverted-residual block: expand, depthwise, project.

    Args:
        data: Input blob of the block.
        num_in_filter: Channel count of ``data``; together with
            ``expansion_factor`` it fixes the expanded width.
        num_filter: Channel count produced by the projection convolution.
        ifshortcut: When truthy the block input is added to the projection.
        stride: Strides of the depthwise convolution.
        kernel: Kernel size of the depthwise convolution.
        pad: Padding of the depthwise convolution.
        expansion_factor: Multiplier applied to ``num_in_filter``.
        prefix: Name prefix for every op created here.
        trainable: Forwarded to ``mobilenet_unit``.
        training: Forwarded to ``mobilenet_unit``.
        data_format: Layout string forwarded to ``mobilenet_unit``.
        has_expand: When falsy the 1x1 expansion convolution is skipped.

    Returns:
        The projection output, or its sum with ``data`` when ``ifshortcut``
        is truthy.
    """
    num_expfilter = int(round(num_in_filter * expansion_factor))

    # Stage 1: optional 1x1 expansion.
    if has_expand:
        channel_expand = mobilenet_unit(
            data=data,
            num_filter=num_expfilter,
            kernel=(1, 1),
            stride=(1, 1),
            pad="valid",
            num_group=1,
            data_format=data_format,
            if_act=True,
            trainable=trainable,
            training=training,
            prefix="%s-expand" % prefix,
        )
    else:
        channel_expand = data

    # Stage 2: depthwise convolution (one group per expanded channel).
    bottleneck_conv = mobilenet_unit(
        data=channel_expand,
        num_filter=num_expfilter,
        stride=stride,
        kernel=kernel,
        pad=pad,
        num_group=num_expfilter,
        data_format=data_format,
        if_act=True,
        trainable=trainable,
        training=training,
        prefix="%s-depthwise" % prefix,
    )

    # Stage 3: 1x1 projection without activation.
    linear_out = mobilenet_unit(
        data=bottleneck_conv,
        num_filter=num_filter,
        kernel=(1, 1),
        stride=(1, 1),
        pad="valid",
        num_group=1,
        data_format=data_format,
        if_act=False,
        trainable=trainable,
        training=training,
        prefix="%s-project" % prefix,
    )

    if ifshortcut:
        return shortcut(data_in=data, data_residual=linear_out, prefix=prefix)
    return linear_out
# MobileNetV2 layer configurations, keyed by an input-size tuple.
MNETV2_CONFIGS_MAP = {
(224, 224): {
"firstconv_filter_num": 32,
# Each entry is (t, c, s, sc): expansion factor, output channels
# (before the width multiplier), stride, and whether to add the shortcut.
"bottleneck_params_list": [
(1, 16, 1, False),
(6, 24, 2, False),
(6, 24, 1, True),
(6, 32, 2, False),
(6, 32, 1, True),
(6, 32, 1, True),
(6, 64, 2, False),
(6, 64, 1, True),
(6, 64, 1, True),
(6, 64, 1, True),
(6, 96, 1, False),
(6, 96, 1, True),
(6, 96, 1, True),
(6, 160, 2, False),
(6, 160, 1, True),
(6, 160, 1, True),
(6, 320, 1, False),
],
"filter_num_before_gp": 1280,
}
}
class MobileNetV2(object):
    """Graph builder for a MobileNetV2 classifier.

    The layer layout comes from ``MNETV2_CONFIGS_MAP``; input sizes without
    an entry fall back to the ``(224, 224)`` configuration.
    """

    def __init__(self, data_wh, multiplier, trainable=True, training=True, **kargs):
        """Store the build options and take a private copy of the config.

        Args:
            data_wh: Input size tuple used to look up the configuration; its
                first element also sizes the final pooling window.
            multiplier: Width multiplier applied to the channel counts.
            trainable: Forwarded to every layer helper.
            training: Forwarded to every layer helper.
            **kargs: Accepted but not used.
        """
        super(MobileNetV2, self).__init__()
        self.data_wh = data_wh
        self.multiplier = multiplier
        self.trainable = trainable
        self.training = training
        if self.data_wh in MNETV2_CONFIGS_MAP:
            base_config = MNETV2_CONFIGS_MAP[self.data_wh]
        else:
            base_config = MNETV2_CONFIGS_MAP[(224, 224)]
        # Copy the entry: build_network() updates config_map in place and
        # must not leak overrides into the module-level table shared by
        # every other instance.
        self.config_map = dict(base_config)

    def build_network(
        self, input_data, data_format, class_num=1000, prefix="", **configs
    ):
        """Build the network and return the output of the final dense layer.

        Args:
            input_data: Input blob of the first convolution.
            data_format: Layout string forwarded to the convolution helpers.
            class_num: Number of units of the final dense layer.
            prefix: Name prefix for every op created here.
            **configs: Overrides merged into this instance's ``config_map``.
        """
        self.config_map.update(configs)
        first_c = int(round(self.config_map["firstconv_filter_num"] * self.multiplier))
        first_layer = mobilenet_unit(
            data=input_data,
            num_filter=first_c,
            kernel=(3, 3),
            stride=(2, 2),
            pad="same",
            data_format=data_format,
            if_act=True,
            trainable=self.trainable,
            training=self.training,
            prefix=prefix + "-Conv",
        )

        last_bottleneck_layer = first_layer
        in_c = first_c
        for i, layer_setting in enumerate(self.config_map["bottleneck_params_list"]):
            t, c, s, sc = layer_setting
            out_c = int(round(c * self.multiplier))
            # The first block has no expansion convolution and no index in
            # its name; every later block is numbered.
            is_first = i == 0
            if is_first:
                block_prefix = prefix + "-expanded_conv"
            else:
                block_prefix = prefix + "-expanded_conv_%d" % i
            last_bottleneck_layer = inverted_residual_unit(
                data=last_bottleneck_layer,
                num_in_filter=in_c,
                num_filter=out_c,
                ifshortcut=sc,
                stride=(s, s),
                kernel=(3, 3),
                pad="same",
                expansion_factor=t,
                prefix=block_prefix,
                trainable=self.trainable,
                training=self.training,
                data_format=data_format,
                has_expand=0 if is_first else 1,
            )
            in_c = out_c

        # The last 1x1 convolution is only widened for multipliers above 1.
        last_c = self.config_map["filter_num_before_gp"]
        last_fm = mobilenet_unit(
            data=last_bottleneck_layer,
            num_filter=int(last_c * self.multiplier) if self.multiplier > 1.0 else last_c,
            kernel=(1, 1),
            stride=(1, 1),
            pad="valid",
            data_format=data_format,
            if_act=True,
            trainable=self.trainable,
            training=self.training,
            prefix=prefix + "-Conv_1",
        )

        # global average pooling
        pool_size = int(self.data_wh[0] / 32)
        pool = flow.nn.avg_pool2d(
            last_fm,
            ksize=pool_size,
            strides=1,
            padding="VALID",
            data_format="NCHW",
            name="pool5",
        )
        fc = flow.layers.dense(
            flow.reshape(pool, (pool.shape[0], -1)),
            units=class_num,
            use_bias=False,
            kernel_initializer=_get_initializer("dense_weight"),
            bias_initializer=_get_initializer("bias"),
            kernel_regularizer=_get_regularizer("dense_weight"),
            bias_regularizer=_get_regularizer("bias"),
            trainable=self.trainable,
            name=prefix + "-fc",
        )
        return fc

    def __call__(self, input_data, class_num=1000, prefix="", **configs):
        """Call ``build_network``; ``data_format`` must arrive via ``configs``."""
        sym = self.build_network(
            input_data, class_num=class_num, prefix=prefix, **configs
        )
        return sym
def Mobilenet(
    input_data,
    channel_last=False,
    trainable=True,
    training=True,
    num_classes=1000,
    multiplier=1.0,
    prefix="",
):
    """Build a MobileNetV2 graph on ``input_data`` and return its output.

    Args:
        input_data: Input blob handed to the ``MobileNetV2`` builder.
        channel_last: Must equal ``False``; any other value fails the
            assertion below.
        trainable: Forwarded to ``MobileNetV2``.
        training: Forwarded to ``MobileNetV2``.
        num_classes: Number of output classes.
        multiplier: Width multiplier forwarded to ``MobileNetV2``.
        prefix: Not forwarded; op names always use the fixed prefix
            ``"MobilenetV2"``.
    """
    assert (
        channel_last == False
    ), "Mobilenet does not support channel_last mode, set channel_last=False will be right!"
    builder = MobileNetV2(
        (224, 224), multiplier=multiplier, trainable=trainable, training=training
    )
    return builder(
        input_data, data_format="NCHW", class_num=num_classes, prefix="MobilenetV2"
    )
def get_lenet_job_function(
func_type: str = "train", enable_qat: bool = True, batch_size: int = 100
func_type: str = "train",
enable_qat: bool = True,
batch_size: int = 100,
ctx: str = "gpu",
):
func_config = flow.FunctionConfig()
func_config.default_placement_scope(flow.scope.placement(ctx, "0:0"))
func_config.cudnn_conv_force_fwd_algo(1)
if enable_qat:
func_config.enable_qat(True)
......@@ -59,7 +400,7 @@ def get_lenet_job_function(
images: tp.Numpy.Placeholder((batch_size, 1, 28, 28), dtype=flow.float),
labels: tp.Numpy.Placeholder((batch_size,), dtype=flow.int32),
) -> tp.Numpy:
with flow.scope.placement("gpu", "0:0"):
with flow.scope.placement(ctx, "0:0"):
logits = Lenet(images)
loss = flow.nn.sparse_softmax_cross_entropy_with_logits(
labels, logits, name="softmax_loss"
......@@ -75,7 +416,7 @@ def get_lenet_job_function(
def eval_job(
images: tp.Numpy.Placeholder((batch_size, 1, 28, 28), dtype=flow.float),
) -> tp.Numpy:
with flow.scope.placement("gpu", "0:0"):
with flow.scope.placement(ctx, "0:0"):
logits = Lenet(images)
return logits
......@@ -83,4 +424,53 @@ def get_lenet_job_function(
return eval_job
def get_mobilenet_job_function(
    func_type: str = "train",
    enable_qat: bool = True,
    batch_size: int = 100,
    ctx: str = "gpu",
):
    """Create a OneFlow global function that trains or evaluates Mobilenet.

    Args:
        func_type: ``"train"`` returns the training job; any other value
            returns the prediction job.
        enable_qat: When truthy the function config turns on QAT with
            symmetric, per-tensor weight quantization targeting
            ``"tensorrt7"``.
        batch_size: Leading dimension of the image and label placeholders.
        ctx: Device tag used for the placement scopes.

    Returns:
        The decorated job function.
    """
    func_config = flow.FunctionConfig()
    func_config.cudnn_conv_force_fwd_algo(1)
    func_config.default_placement_scope(flow.scope.placement(ctx, "0:0"))
    if enable_qat:
        func_config.enable_qat(True)
        func_config.qat.symmetric(True)
        func_config.qat.per_channel_weight_quantization(False)
        func_config.qat.moving_min_max_stop_update_after_iters(1000)
        func_config.qat.target_backend("tensorrt7")

    # Both jobs take single-channel 224x224 images.
    image_shape = (batch_size, 1, 224, 224)

    if func_type != "train":

        @flow.global_function(type="predict", function_config=func_config)
        def eval_job(
            images: tp.Numpy.Placeholder(image_shape, dtype=flow.float),
        ) -> tp.Numpy:
            with flow.scope.placement(ctx, "0:0"):
                logits = Mobilenet(
                    images, trainable=False, training=False, num_classes=10
                )
            return logits

        return eval_job

    @flow.global_function(type="train", function_config=func_config)
    def train_job(
        images: tp.Numpy.Placeholder(image_shape, dtype=flow.float),
        labels: tp.Numpy.Placeholder((batch_size,), dtype=flow.int32),
    ) -> tp.Numpy:
        with flow.scope.placement(ctx, "0:0"):
            logits = Mobilenet(images, num_classes=10)
            loss = flow.nn.sparse_softmax_cross_entropy_with_logits(
                labels, logits, name="softmax_loss"
            )
        # Constant learning rate of 0.1, plain SGD without momentum.
        lr_scheduler = flow.optimizer.PiecewiseConstantScheduler([], [0.1])
        flow.optimizer.SGD(lr_scheduler, momentum=0).minimize(loss)
        return loss

    return train_job
LENET_MODEL_QAT_DIR = "./lenet_model_qat_dir"
MOBILENET_MODEL_QAT_DIR = "./mobilenet_model_qat_dir"
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import common
import argparse
import shutil
import numpy as np
import tensorrt as trt
import oneflow as flow
from common import run_tensorrt, get_onnx_provider
from models import get_lenet_job_function, LENET_MODEL_QAT_DIR
from oneflow_onnx.oneflow2onnx.util import export_onnx_model, run_onnx, compare_result
def build_engine_onnx(model_file, verbose=False):
    """Parse an ONNX file and build a TensorRT engine with the INT8 flag set.

    Args:
        model_file: Path of the ONNX model; it is opened in binary mode.
        verbose: Use the VERBOSE TensorRT logger severity instead of INFO.

    Returns:
        The result of ``builder.build_engine``, or ``None`` when the ONNX
        parser rejects the model (the parser errors are printed first).
    """
    TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE if verbose else trt.Logger.INFO)

    # Explicit batch and explicit precision are both requested.
    explicit_batch = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    explicit_precision = 1 << int(
        trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION
    )
    network_flags = explicit_batch | explicit_precision

    with trt.Builder(TRT_LOGGER) as builder, builder.create_network(
        flags=network_flags
    ) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
        with open(model_file, "rb") as model:
            parse_ok = parser.parse(model.read())
        if not parse_ok:
            print("ERROR: Failed to parse the ONNX file.")
            for error_index in range(parser.num_errors):
                print(parser.get_error(error_index))
            return None

        config = builder.create_builder_config()
        config.max_workspace_size = 1 << 30
        config.flags |= 1 << int(trt.BuilderFlag.INT8)
        return builder.build_engine(network, config)
def run_tensorrt(onnx_path, test_case):
    """Run one batch through a TensorRT engine built from an ONNX model.

    Args:
        onnx_path: Path of the ONNX model handed to ``build_engine_onnx``.
        test_case: Array whose first dimension is the batch size; it is
            flattened and copied into the first input host buffer.

    Returns:
        The first TensorRT output reshaped to ``(batch_size, -1)``.

    Raises:
        RuntimeError: If no engine could be built for ``onnx_path``.
    """
    engine = build_engine_onnx(onnx_path)
    if engine is None:
        # Without this guard a failed build surfaces as an opaque
        # AttributeError from ``with None as engine``.
        raise RuntimeError(
            "Failed to build a TensorRT engine from ONNX model: %s" % onnx_path
        )

    with engine:
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)
        with engine.create_execution_context() as context:
            batch_size = test_case.shape[0]
            test_case = test_case.reshape(-1)
            np.copyto(inputs[0].host, test_case)
            trt_outputs = common.do_inference_v2(
                context,
                bindings=bindings,
                inputs=inputs,
                outputs=outputs,
                stream=stream,
            )
            data = trt_outputs[0]
            return data.reshape(batch_size, -1)
def test_lenet_qat():
model_existed = os.path.exists(LENET_MODEL_QAT_DIR)
assert model_existed
# Without the following 'print' CI won't pass, but I have no idea why.
print(
"Model exists. "
if os.path.exists(LENET_MODEL_QAT_DIR)
else "Model does not exist. "
)
batch_size = 16
print("Model exists. " if model_existed else "Model does not exist. ")
batch_size = 32
predict_job = get_lenet_job_function("predict", batch_size=batch_size)
flow.load_variables(flow.checkpoint.get(LENET_MODEL_QAT_DIR))
onnx_model_path, cleanup = export_onnx_model(predict_job, opset=10)
ipt_dict, onnx_res = run_onnx(onnx_model_path)
ipt_dict, onnx_res = run_onnx(onnx_model_path, get_onnx_provider("cpu"))
oneflow_res = predict_job(*ipt_dict.values())
compare_result(oneflow_res, onnx_res)
compare_result(oneflow_res, onnx_res, print_outlier=True)
trt_res = run_tensorrt(onnx_model_path, ipt_dict[list(ipt_dict.keys())[0]])
compare_result(oneflow_res, trt_res, True)
compare_result(oneflow_res, trt_res, print_outlier=True)
flow.clear_default_session()
cleanup()
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import shutil
import uuid
import argparse
import oneflow as flow
from models import get_lenet_job_function, LENET_MODEL_QAT_DIR
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import shutil
import oneflow as flow
from common import run_tensorrt, get_onnx_provider
from models import get_mobilenet_job_function, MOBILENET_MODEL_QAT_DIR
from oneflow_onnx.oneflow2onnx.util import export_onnx_model, run_onnx, compare_result
def test_mobilenet_qat():
    """Check OneFlow, ONNX Runtime and TensorRT agree on the QAT Mobilenet.

    The saved model directory must already exist; it is removed once the
    comparisons have passed.
    """
    model_existed = os.path.exists(MOBILENET_MODEL_QAT_DIR)
    assert model_existed
    # Without the following 'print' CI won't pass, but I have no idea why.
    print("Model exists. " if model_existed else "Model does not exist. ")

    batch_size = 32
    predict_job = get_mobilenet_job_function("predict", batch_size=batch_size)
    flow.load_variables(flow.checkpoint.get(MOBILENET_MODEL_QAT_DIR))
    onnx_model_path, cleanup = export_onnx_model(predict_job, opset=10)

    # ONNX Runtime generates the inputs that OneFlow and TensorRT then reuse.
    providers = get_onnx_provider("gpu")
    ipt_dict, onnx_res = run_onnx(onnx_model_path, providers)
    oneflow_res = predict_job(*ipt_dict.values())
    compare_result(oneflow_res, onnx_res, print_outlier=True)

    first_input_name = list(ipt_dict.keys())[0]
    trt_res = run_tensorrt(onnx_model_path, ipt_dict[first_input_name])
    compare_result(oneflow_res, trt_res, print_outlier=True)

    flow.clear_default_session()
    cleanup()
    if os.path.exists(MOBILENET_MODEL_QAT_DIR):
        shutil.rmtree(MOBILENET_MODEL_QAT_DIR)
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import shutil
import cv2
import numpy as np
import oneflow as flow
from models import get_mobilenet_job_function, MOBILENET_MODEL_QAT_DIR
def resize(images):
    """Resize every image of a batch to 224x224.

    Each image is transposed with axes ``(1, 2, 0)``, resized by OpenCV and
    given a new leading axis; the results are stacked into one array.
    Assumes each image is a rank-3, channel-first array - TODO confirm
    against the MNIST loader.
    """

    def _resize_one(image):
        channels_last = np.transpose(image, (1, 2, 0))
        resized = cv2.resize(channels_last, (224, 224))
        return resized[None, :, :]

    return np.array([_resize_one(image) for image in images])
if __name__ == "__main__":
    # Briefly train the QAT Mobilenet on resized MNIST, then save a checkpoint.
    batch_size = 16
    train_set, test_set = flow.data.load_mnist(batch_size, batch_size)
    train_images, train_labels = train_set
    test_images, test_labels = test_set

    # train
    train_job = get_mobilenet_job_function("train", batch_size=batch_size)
    for epoch in range(1):
        for i, (images, labels) in enumerate(zip(train_images, train_labels)):
            loss = train_job(resize(images), labels)
            if i % 20 == 0:
                print(loss.mean())
            # Stop after batch index 100 to keep the CI run short.
            if i == 100:
                break

    # Replace any checkpoint left over from a previous run.
    if os.path.exists(MOBILENET_MODEL_QAT_DIR):
        shutil.rmtree(MOBILENET_MODEL_QAT_DIR)
    flow.checkpoint.save(MOBILENET_MODEL_QAT_DIR)

    # Without the following 'print' CI won't pass, but I have no idea why.
    saved = str(os.path.exists(MOBILENET_MODEL_QAT_DIR))
    print("Model was saved at " + MOBILENET_MODEL_QAT_DIR + ". Status : " + saved)
......@@ -18,28 +18,43 @@ import tempfile
import numpy as np
import oneflow as flow
import onnxruntime as ort
from typing import Optional, Union, Tuple, List
from collections import OrderedDict
from oneflow_onnx.oneflow2onnx.flow2onnx import Export
def run_onnx(
    onnx_model_path: str,
    providers: List[str],
    ipt_dict: Optional[OrderedDict] = None,
    ort_optimize: bool = True,
) -> Union[Tuple[OrderedDict, np.ndarray], np.ndarray]:
    """Run an ONNX model with ONNX Runtime and return its single output.

    Args:
        onnx_model_path: Path of the ONNX model to load.
        providers: Execution providers handed to the inference session.
        ipt_dict: Inputs keyed by input name. When ``None``, uniform random
            float32 data in [-10, 10) is generated for each model input.
        ort_optimize: Use the ORT_ENABLE_EXTENDED optimization level when
            true, ORT_DISABLE_ALL otherwise.

    Returns:
        ``(ipt_dict, result)`` when the inputs were generated here, or only
        ``result`` when the caller supplied ``ipt_dict``.
    """
    ort_sess_opt = ort.SessionOptions()
    ort_sess_opt.graph_optimization_level = (
        ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        if ort_optimize
        else ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    )
    sess = ort.InferenceSession(
        onnx_model_path, sess_options=ort_sess_opt, providers=providers
    )
    # Only models with exactly one output and at most one input are handled.
    assert len(sess.get_outputs()) == 1
    assert len(sess.get_inputs()) <= 1

    # Record who owns the inputs before ipt_dict is possibly replaced below.
    only_return_result = ipt_dict is not None
    if ipt_dict is None:
        ipt_dict = OrderedDict()
        for ipt in sess.get_inputs():
            ipt_data = np.random.uniform(low=-10, high=10, size=ipt.shape).astype(
                np.float32
            )
            ipt_dict[ipt.name] = ipt_data

    onnx_res = sess.run([], ipt_dict)[0]
    if only_return_result:
        return onnx_res
    return ipt_dict, onnx_res
......@@ -86,8 +101,13 @@ def export_onnx_model(
return onnx_model_path, cleanup
def compare_result(a, b, print_outlier=False):
rtol, atol = 1e-2, 1e-5
def compare_result(
a: np.ndarray,
b: np.ndarray,
rtol: float = 1e-2,
atol: float = 1e-5,
print_outlier: bool = False,
):
if print_outlier:
a = a.flatten()
b = b.flatten()
......@@ -124,12 +144,14 @@ def convert_to_onnx_and_check(
job_func, external_data, opset, flow_weight_dir, onnx_model_path
)
ipt_dict, onnx_res = run_onnx(onnx_model_path, ort_optimize)
ipt_dict, onnx_res = run_onnx(
onnx_model_path, ["CPUExecutionProvider"], ort_optimize=ort_optimize
)
oneflow_res = job_func(*ipt_dict.values())
if not isinstance(oneflow_res, np.ndarray):
oneflow_res = oneflow_res.get().numpy()
compare_result(oneflow_res, onnx_res, print_outlier)
compare_result(oneflow_res, onnx_res, print_outlier=print_outlier)
flow.clear_default_session()
# cleanup()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册