提交 467aefb0 编写于 作者: L liutuo 提交者: 叶剑武

fix data type convert in onnx-converter

refactor structure of python scripts
上级 bc0116dc
......@@ -17,7 +17,6 @@ import hashlib
import inspect
import re
import os
import six
......
......@@ -25,14 +25,11 @@ import yaml
import sh_commands
from enum import Enum
sys.path.insert(0, "tools/python") # noqa
from common import *
from device import DeviceWrapper, DeviceManager
from utils import config_parser
import convert
import encrypt
from dana.dana_util import DanaUtil
from python.utils import config_parser
from python import convert
from python import encrypt
################################
# set environment
......
......@@ -13,9 +13,7 @@
# limitations under the License.
import argparse
import sys
import numpy as np
import re
import common
import six
......
......@@ -12,19 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import copy
import os
import sys
import yaml
sys.path.insert(0, "tools/python") # noqa
from py_proto import mace_pb2
from transform.base_converter import ConverterUtil
from transform.base_converter import MaceKeyword
from transform.base_converter import MaceOp
from transform.hexagon_converter import HexagonOp
from utils.util import mace_check
from python.py_proto import mace_pb2
from python.transform.base_converter import ConverterUtil
from python.transform.base_converter import MaceKeyword
from python.transform.base_converter import MaceOp
from python.transform.hexagon_converter import HexagonOp
from python.utils.util import mace_check
def normalize_op_name(name):
......
......@@ -22,20 +22,19 @@ from __future__ import print_function
import argparse
import sys
import numpy as np
import shutil
import tempfile
from utils import config_parser
from utils.config_parser import DataFormat
from utils.config_parser import DeviceType
from utils.config_parser import Platform
from utils import util
from utils.util import mace_check
from utils.config_parser import normalize_model_config
from utils.config_parser import ModelKeys
from py_proto import mace_pb2
from transform import base_converter as cvt
from transform import transformer
from visualize import visualize_model
from python.utils import config_parser
from python.utils.config_parser import DataFormat
from python.utils.config_parser import DeviceType
from python.utils.config_parser import Platform
from python.utils import util
from python.utils.util import mace_check
from python.utils.config_parser import normalize_model_config
from python.utils.config_parser import ModelKeys
from python.py_proto import mace_pb2
from python.transform import base_converter as cvt
from python.transform import transformer
from python.visualize import visualize_model
def transpose_shape(shape, dst_order):
......@@ -162,16 +161,16 @@ def convert_model(conf):
print("Transform model to one that can better run on device")
platform = conf[ModelKeys.platform]
if platform == Platform.TENSORFLOW:
from transform import tensorflow_converter
from python.transform import tensorflow_converter
converter = tensorflow_converter.TensorflowConverter(
option, conf["model_file_path"])
elif platform == Platform.CAFFE:
from transform import caffe_converter
from python.transform import caffe_converter
converter = caffe_converter.CaffeConverter(option,
conf["model_file_path"],
conf["weight_file_path"])
elif platform == Platform.ONNX:
from transform import onnx_converter
from python.transform import onnx_converter
converter = onnx_converter.OnnxConverter(option,
conf["model_file_path"])
else:
......@@ -185,14 +184,14 @@ def convert_model(conf):
runtime = conf[ModelKeys.runtime]
if runtime in [DeviceType.HEXAGON,
DeviceType.HTA]:
from transform import hexagon_converter
from python.transform import hexagon_converter
converter = hexagon_converter.HexagonConverter(
option, output_graph_def, quantize_activation_info)
output_graph_def = converter.run()
elif runtime == DeviceType.APU:
mace_check(platform == Platform.TENSORFLOW,
"apu only support model from tensorflow")
from transform import apu_converter
from python.transform import apu_converter
converter = apu_converter.ApuConverter(
option, output_graph_def, quantize_activation_info)
output_graph_def = converter.run()
......
......@@ -21,14 +21,15 @@ import datetime
import os
import hashlib
from jinja2 import Environment, FileSystemLoader
from py_proto import mace_pb2
from utils import device
from utils import util
from utils.util import mace_check
from utils.util import MaceLogger
from utils import config_parser
from utils.config_parser import CPP_KEYWORDS
from utils.config_parser import ModelKeys
from python.py_proto import mace_pb2
from python.utils import device
from python.utils import util
from python.utils.util import mace_check
from python.utils.util import MaceLogger
from python.utils import config_parser
from python.utils.config_parser import CPP_KEYWORDS
from python.utils.config_parser import ModelKeys
GENERATED_NAME = set()
......
......@@ -22,9 +22,9 @@ import os
import struct
import numpy as np
from utils import util
from utils.util import MaceLogger
from utils.util import mace_check
from python.utils import util
from python.utils.util import MaceLogger
from python.utils.util import mace_check
def generate_opencl_code(binary_file_name, load_func_name, size_func_name,
......
......@@ -17,8 +17,8 @@ from __future__ import division
from __future__ import print_function
import os
from utils import device
from utils.util import MaceLogger
from python.utils import device
from python.utils.util import MaceLogger
cwd = os.path.dirname(__file__)
......
......@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
import math
from transform.base_converter import DeviceType
from python.transform.base_converter import DeviceType
class QuantizedData(object):
......
......@@ -14,7 +14,7 @@
import unittest
import numpy as np
import quantize.quantize_util
import quantize_util
class TestQuantize(unittest.TestCase):
......
......@@ -22,17 +22,17 @@ import tempfile
import shutil
import numpy as np
from py_proto import mace_pb2
from utils import util
from utils import device
from utils import config_parser
from utils.config_parser import DeviceType
from utils.target import Target
from utils.config_parser import ModelKeys
from utils.util import MaceLogger
from utils.util import mace_check
import run_target
import validate
from python.py_proto import mace_pb2
from python.utils import util
from python.utils import device
from python.utils import config_parser
from python.utils.config_parser import DeviceType
from python.utils.target import Target
from python.utils.config_parser import ModelKeys
from python.utils.util import MaceLogger
from python.utils.util import mace_check
from python import run_target
from python import validate
"""
Tool for mace_run:
......
......@@ -28,10 +28,10 @@ from __future__ import print_function
import argparse
import os
from utils import device
from utils import target
from utils import config_parser
from utils import util
from python.utils import device
from python.utils import target
from python.utils import config_parser
from python.utils import util
def run_target(target_abi, install_dir, target_obj, device_ids="all"):
......
......@@ -12,23 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from enum import Enum
from operator import mul
from py_proto import mace_pb2
from transform import base_converter
from transform.base_converter import ConverterUtil
from transform.base_converter import EltwiseType
from transform.base_converter import MaceKeyword
from transform.base_converter import MaceOp
from transform.base_converter import PaddingMode
from transform.base_converter import PoolingType
from transform.base_converter import ReduceType
from transform.base_converter import DataFormat
from transform.base_converter import FrameworkType
from utils.util import mace_check
from python.py_proto import mace_pb2
from python.transform import base_converter
from python.transform.base_converter import ConverterUtil
from python.transform.base_converter import EltwiseType
from python.transform.base_converter import MaceKeyword
from python.transform.base_converter import MaceOp
from python.transform.base_converter import PaddingMode
from python.transform.base_converter import PoolingType
from python.transform.base_converter import DataFormat
from python.transform.base_converter import FrameworkType
from python.utils.util import mace_check
ApuSupportedOps = [
......
......@@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from enum import Enum
from py_proto import mace_pb2
from utils.config_parser import DataFormat
from utils.config_parser import DeviceType
from python.py_proto import mace_pb2
from python.utils.config_parser import DataFormat
from python.utils.config_parser import DeviceType
# SAME_LOWER: if the amount of paddings to be added is odd,
......
......@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
......@@ -19,20 +22,18 @@ import numpy as np
import six
import google.protobuf.text_format
from py_proto import mace_pb2
from transform import base_converter
from transform import shape_inference
from transform.base_converter import PoolingType
from transform.base_converter import ActivationType
from transform.base_converter import EltwiseType
from transform.base_converter import FrameworkType
from transform.base_converter import DataFormat
from transform.base_converter import MaceOp
from transform.base_converter import MaceKeyword
from transform.base_converter import ConverterUtil
from utils.util import mace_check
from py_proto import caffe_pb2
from python.py_proto import mace_pb2, caffe_pb2
from python.utils.util import mace_check
from . import base_converter
from . import shape_inference
from .base_converter import PoolingType
from .base_converter import ActivationType
from .base_converter import EltwiseType
from .base_converter import FrameworkType
from .base_converter import DataFormat
from .base_converter import MaceOp
from .base_converter import MaceKeyword
from .base_converter import ConverterUtil
caffe_group_str = 'group'
caffe_kernel_h_str = 'kernel_h'
......
......@@ -16,24 +16,23 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import numpy as np
from enum import Enum
from operator import mul
from functools import reduce
from py_proto import mace_pb2
from transform import base_converter
from transform.base_converter import ConverterUtil
from transform.base_converter import DeviceType
from transform.base_converter import EltwiseType
from transform.base_converter import MaceKeyword
from transform.base_converter import MaceOp
from transform.base_converter import PaddingMode
from transform.base_converter import PoolingType
from transform.base_converter import ReduceType
from utils.util import mace_check
from python.py_proto import mace_pb2
from python.utils.util import mace_check
from . import base_converter
from .base_converter import ConverterUtil
from .base_converter import DeviceType
from .base_converter import EltwiseType
from .base_converter import MaceKeyword
from .base_converter import MaceOp
from .base_converter import PaddingMode
from .base_converter import PoolingType
from .base_converter import ReduceType
HexagonSupportedOps = [
......
......@@ -12,32 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from enum import Enum
import six
from py_proto import mace_pb2
from transform import base_converter
from transform.base_converter import PoolingType
from transform.base_converter import PaddingMode
from transform.base_converter import ActivationType
from transform.base_converter import EltwiseType
from transform.base_converter import ReduceType
from transform.base_converter import FrameworkType
from transform.base_converter import RoundMode
from transform.base_converter import DataFormat
from transform.base_converter import MaceOp
from transform.base_converter import MaceKeyword
from transform.base_converter import ConverterUtil
from utils.util import mace_check
import numpy as np
from numbers import Number
from python.py_proto import mace_pb2
from . import base_converter
from .base_converter import PoolingType
from .base_converter import PaddingMode
from .base_converter import ActivationType
from .base_converter import EltwiseType
from .base_converter import ReduceType
from .base_converter import FrameworkType
from .base_converter import RoundMode
from .base_converter import DataFormat
from .base_converter import MaceOp
from .base_converter import MaceKeyword
from .base_converter import ConverterUtil
from python.utils.util import mace_check
import onnx
import onnx.utils
from onnx import mapping, numpy_helper, TensorProto
from numbers import Number
IS_PYTHON3 = sys.version_info > (3,)
......@@ -191,9 +193,9 @@ OnnxOpType = Enum('OnnxOpType',
onnx_attr_translator = {
"axis": lambda x: int(x),
"axes": lambda x: [int(a) for a in x],
"dtype": lambda x: data_type.onnx2tf(x),
"dtype": lambda x: onnx_dtype(x),
"keepdims": lambda x: bool(x),
"to": lambda x: data_type.onnx2tf(x),
"to": lambda x: onnx_dtype(x),
}
......@@ -567,11 +569,7 @@ class OnnxConverter(base_converter.ConverterInterface):
tensor.data_type = mace_pb2.DT_FLOAT
tensor.float_data.extend(
onnx_tensor.astype(np.float32).flat)
elif data_type == np.int32:
tensor.data_type = mace_pb2.DT_INT32
tensor.int32_data.extend(
onnx_tensor.astype(np.int32).flat)
elif data_type == np.int64:
elif data_type == np.int32 or data_type == np.int64:
tensor.data_type = mace_pb2.DT_INT32
tensor.int32_data.extend(
onnx_tensor.astype(np.int32).flat)
......@@ -668,9 +666,9 @@ class OnnxConverter(base_converter.ConverterInterface):
if 'to' in node.attrs:
dtype = node.attrs['to']
if dtype == TensorProto.FLOAT:
if dtype == np.float32 or dtype == np.float64:
op.output_type.extend([self._option.data_type])
elif dtype == TensorProto.INT:
elif dtype == np.int64 or dtype == np.int32:
op.output_type.extend([mace_pb2.DT_INT32])
else:
mace_check(False, "data type %s not supported" % dtype)
......@@ -959,7 +957,10 @@ class OnnxConverter(base_converter.ConverterInterface):
if len(const_tensor.dims) == 0:
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = const_tensor.float_data[0]
if const_tensor.data_type == mace_pb2.DT_INT32:
value_arg.f = float(const_tensor.int32_data[0])
elif const_tensor.data_type == mace_pb2.DT_FLOAT:
value_arg.f = const_tensor.float_data[0]
value_index_arg = op.arg.add()
value_index_arg.name = \
MaceKeyword.mace_scalar_input_index_str
......@@ -972,7 +973,10 @@ class OnnxConverter(base_converter.ConverterInterface):
if len(const_tensor.dims) == 0:
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = const_tensor.float_data[0]
if const_tensor.data_type == mace_pb2.DT_INT32:
value_arg.f = float(const_tensor.int32_data[0])
elif const_tensor.data_type == mace_pb2.DT_FLOAT:
value_arg.f = const_tensor.float_data[0]
value_index_arg = op.arg.add()
value_index_arg.name = \
MaceKeyword.mace_scalar_input_index_str
......
......@@ -12,18 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import six
from transform.transformer import Transformer
from transform.base_converter import DataFormat
from transform.base_converter import MaceOp
from transform.base_converter import MaceKeyword
from transform.base_converter import ConverterUtil
from utils.util import mace_check
from .transformer import Transformer
from .base_converter import DataFormat
from .base_converter import MaceOp
from .base_converter import MaceKeyword
from .base_converter import ConverterUtil
from python.utils.util import mace_check
class ShapeInference(object):
......@@ -253,7 +255,7 @@ class ShapeInference(object):
aspect_ratio = ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats # noqa
num_prior = len(aspect_ratio) * len(min_size) + len(max_size)
output_shape[2] = num_prior * input_h * input_w * 4
output_shape[2] = int(num_prior * input_h * input_w * 4)
self.add_output_shape(op, [output_shape])
def infer_shape_reshape(self, op):
......@@ -275,7 +277,7 @@ class ShapeInference(object):
output_shape[i] = dim[i]
product *= dim[i]
if idx != -1:
output_shape[idx] = input_size / product
output_shape[idx] = int(input_size / product)
self.add_output_shape(op, [output_shape])
else:
output_shape = []
......
......@@ -12,27 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import six
import tensorflow as tf
from enum import Enum
from py_proto import mace_pb2
from transform import base_converter
from transform.base_converter import PoolingType
from transform.base_converter import PaddingMode
from transform.base_converter import ActivationType
from transform.base_converter import EltwiseType
from transform.base_converter import PadType
from transform.base_converter import FrameworkType
from transform.base_converter import ReduceType
from transform.base_converter import DataFormat
from transform.base_converter import MaceOp
from transform.base_converter import MaceKeyword
from transform.base_converter import ConverterUtil
from utils.util import mace_check
from python.py_proto import mace_pb2
from . import base_converter
from .base_converter import PoolingType
from .base_converter import PaddingMode
from .base_converter import ActivationType
from .base_converter import EltwiseType
from .base_converter import PadType
from .base_converter import FrameworkType
from .base_converter import ReduceType
from .base_converter import DataFormat
from .base_converter import MaceOp
from .base_converter import MaceKeyword
from .base_converter import ConverterUtil
from python.utils.util import mace_check
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.tools.graph_transforms import TransformGraph
......
......@@ -12,28 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import numpy as np
import six
from py_proto import mace_pb2
from transform import base_converter
from transform.base_converter import ConverterUtil
from transform.base_converter import DataFormat
from transform.base_converter import DeviceType
from transform.base_converter import EltwiseType
from transform.base_converter import FrameworkType
from transform.base_converter import MaceKeyword
from transform.base_converter import MaceOp
from transform.base_converter import MaceFixedDataFormatOps # noqa
from transform.base_converter import MaceTransposableDataFormatOps # noqa
from transform.base_converter import PaddingMode
from transform.base_converter import ReduceType
from transform.base_converter import TransformerRule
from quantize import quantize_util
from utils.util import mace_check
from python.py_proto import mace_pb2
from . import base_converter
from .base_converter import ConverterUtil
from .base_converter import DataFormat
from .base_converter import DeviceType
from .base_converter import EltwiseType
from .base_converter import FrameworkType
from .base_converter import MaceKeyword
from .base_converter import MaceOp
from .base_converter import MaceFixedDataFormatOps # noqa
from .base_converter import MaceTransposableDataFormatOps # noqa
from .base_converter import PaddingMode
from .base_converter import ReduceType
from .base_converter import TransformerRule
from python.quantize import quantize_util
from python.utils.util import mace_check
class Transformer(base_converter.ConverterInterface):
......@@ -1440,15 +1443,15 @@ class Transformer(base_converter.ConverterInterface):
arg.i = 1
elif arg.i == 3:
arg.i = 2
producer = self._producer[op.input[0]]
input_shape = producer.output_shape[0].dims
if producer.type == MaceOp.FullyConnected.name and \
len(input_shape) == 2:
axis_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_axis_str)
if axis_arg.i == 1:
axis_arg.i = 3
if op.input[0] in self._producer:
producer = self._producer[op.input[0]]
input_shape = producer.output_shape[0].dims
if producer.type == MaceOp.FullyConnected.name and\
len(input_shape) == 2:
axis_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_axis_str)
if axis_arg.i == 1:
axis_arg.i = 3
elif op.type == MaceOp.Squeeze.name:
for arg in op.arg:
......
......@@ -22,9 +22,9 @@ import copy
import yaml
from enum import Enum
from utils.util import mace_check
from utils.util import MaceLogger
from py_proto import mace_pb2
from python.utils.util import mace_check
from python.utils.util import MaceLogger
from python.py_proto import mace_pb2
CPP_KEYWORDS = [
'alignas', 'alignof', 'and', 'and_eq', 'asm', 'atomic_cancel',
......
......@@ -23,7 +23,7 @@ import subprocess
import random
import tempfile
from utils import util
from python.utils import util
def execute(cmd, verbose=True):
......
......@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
......@@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path
import numpy as np
import six
from py_proto import mace_pb2
from utils import util
from utils.config_parser import DataFormat
from utils.config_parser import Platform
from python.py_proto import mace_pb2
from python.utils import util
from python.utils.config_parser import DataFormat
from python.utils.config_parser import Platform
VALIDATION_MODULE = 'VALIDATION'
......
......@@ -21,7 +21,6 @@ import re
import sh
import struct
import sys
import time
import platform
import six
......
......@@ -13,11 +13,9 @@
# limitations under the License.
import argparse
import sys
import os
import os.path
import numpy as np
import re
import six
import common
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册