提交 467aefb0 编写于 作者: L liutuo 提交者: 叶剑武

fix data type convert in onnx-converter

refactor structure of python scripts
上级 bc0116dc
...@@ -17,7 +17,6 @@ import hashlib ...@@ -17,7 +17,6 @@ import hashlib
import inspect import inspect
import re import re
import os import os
import six import six
......
...@@ -25,14 +25,11 @@ import yaml ...@@ -25,14 +25,11 @@ import yaml
import sh_commands import sh_commands
from enum import Enum from enum import Enum
sys.path.insert(0, "tools/python") # noqa
from common import * from common import *
from device import DeviceWrapper, DeviceManager from device import DeviceWrapper, DeviceManager
from utils import config_parser from python.utils import config_parser
import convert from python import convert
import encrypt from python import encrypt
from dana.dana_util import DanaUtil
################################ ################################
# set environment # set environment
......
...@@ -13,9 +13,7 @@ ...@@ -13,9 +13,7 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import sys
import numpy as np import numpy as np
import re
import common import common
import six import six
......
...@@ -12,19 +12,21 @@ ...@@ -12,19 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse import argparse
import copy import copy
import os import os
import sys
import yaml import yaml
sys.path.insert(0, "tools/python") # noqa from python.py_proto import mace_pb2
from py_proto import mace_pb2 from python.transform.base_converter import ConverterUtil
from transform.base_converter import ConverterUtil from python.transform.base_converter import MaceKeyword
from transform.base_converter import MaceKeyword from python.transform.base_converter import MaceOp
from transform.base_converter import MaceOp from python.transform.hexagon_converter import HexagonOp
from transform.hexagon_converter import HexagonOp from python.utils.util import mace_check
from utils.util import mace_check
def normalize_op_name(name): def normalize_op_name(name):
......
...@@ -22,20 +22,19 @@ from __future__ import print_function ...@@ -22,20 +22,19 @@ from __future__ import print_function
import argparse import argparse
import sys import sys
import numpy as np import numpy as np
import shutil
import tempfile from python.utils import config_parser
from utils import config_parser from python.utils.config_parser import DataFormat
from utils.config_parser import DataFormat from python.utils.config_parser import DeviceType
from utils.config_parser import DeviceType from python.utils.config_parser import Platform
from utils.config_parser import Platform from python.utils import util
from utils import util from python.utils.util import mace_check
from utils.util import mace_check from python.utils.config_parser import normalize_model_config
from utils.config_parser import normalize_model_config from python.utils.config_parser import ModelKeys
from utils.config_parser import ModelKeys from python.py_proto import mace_pb2
from py_proto import mace_pb2 from python.transform import base_converter as cvt
from transform import base_converter as cvt from python.transform import transformer
from transform import transformer from python.visualize import visualize_model
from visualize import visualize_model
def transpose_shape(shape, dst_order): def transpose_shape(shape, dst_order):
...@@ -162,16 +161,16 @@ def convert_model(conf): ...@@ -162,16 +161,16 @@ def convert_model(conf):
print("Transform model to one that can better run on device") print("Transform model to one that can better run on device")
platform = conf[ModelKeys.platform] platform = conf[ModelKeys.platform]
if platform == Platform.TENSORFLOW: if platform == Platform.TENSORFLOW:
from transform import tensorflow_converter from python.transform import tensorflow_converter
converter = tensorflow_converter.TensorflowConverter( converter = tensorflow_converter.TensorflowConverter(
option, conf["model_file_path"]) option, conf["model_file_path"])
elif platform == Platform.CAFFE: elif platform == Platform.CAFFE:
from transform import caffe_converter from python.transform import caffe_converter
converter = caffe_converter.CaffeConverter(option, converter = caffe_converter.CaffeConverter(option,
conf["model_file_path"], conf["model_file_path"],
conf["weight_file_path"]) conf["weight_file_path"])
elif platform == Platform.ONNX: elif platform == Platform.ONNX:
from transform import onnx_converter from python.transform import onnx_converter
converter = onnx_converter.OnnxConverter(option, converter = onnx_converter.OnnxConverter(option,
conf["model_file_path"]) conf["model_file_path"])
else: else:
...@@ -185,14 +184,14 @@ def convert_model(conf): ...@@ -185,14 +184,14 @@ def convert_model(conf):
runtime = conf[ModelKeys.runtime] runtime = conf[ModelKeys.runtime]
if runtime in [DeviceType.HEXAGON, if runtime in [DeviceType.HEXAGON,
DeviceType.HTA]: DeviceType.HTA]:
from transform import hexagon_converter from python.transform import hexagon_converter
converter = hexagon_converter.HexagonConverter( converter = hexagon_converter.HexagonConverter(
option, output_graph_def, quantize_activation_info) option, output_graph_def, quantize_activation_info)
output_graph_def = converter.run() output_graph_def = converter.run()
elif runtime == DeviceType.APU: elif runtime == DeviceType.APU:
mace_check(platform == Platform.TENSORFLOW, mace_check(platform == Platform.TENSORFLOW,
"apu only support model from tensorflow") "apu only support model from tensorflow")
from transform import apu_converter from python.transform import apu_converter
converter = apu_converter.ApuConverter( converter = apu_converter.ApuConverter(
option, output_graph_def, quantize_activation_info) option, output_graph_def, quantize_activation_info)
output_graph_def = converter.run() output_graph_def = converter.run()
......
...@@ -21,14 +21,15 @@ import datetime ...@@ -21,14 +21,15 @@ import datetime
import os import os
import hashlib import hashlib
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from py_proto import mace_pb2
from utils import device from python.py_proto import mace_pb2
from utils import util from python.utils import device
from utils.util import mace_check from python.utils import util
from utils.util import MaceLogger from python.utils.util import mace_check
from utils import config_parser from python.utils.util import MaceLogger
from utils.config_parser import CPP_KEYWORDS from python.utils import config_parser
from utils.config_parser import ModelKeys from python.utils.config_parser import CPP_KEYWORDS
from python.utils.config_parser import ModelKeys
GENERATED_NAME = set() GENERATED_NAME = set()
......
...@@ -22,9 +22,9 @@ import os ...@@ -22,9 +22,9 @@ import os
import struct import struct
import numpy as np import numpy as np
from utils import util from python.utils import util
from utils.util import MaceLogger from python.utils.util import MaceLogger
from utils.util import mace_check from python.utils.util import mace_check
def generate_opencl_code(binary_file_name, load_func_name, size_func_name, def generate_opencl_code(binary_file_name, load_func_name, size_func_name,
......
...@@ -17,8 +17,8 @@ from __future__ import division ...@@ -17,8 +17,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
from utils import device from python.utils import device
from utils.util import MaceLogger from python.utils.util import MaceLogger
cwd = os.path.dirname(__file__) cwd = os.path.dirname(__file__)
......
...@@ -19,7 +19,7 @@ from __future__ import print_function ...@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import math import math
from transform.base_converter import DeviceType from python.transform.base_converter import DeviceType
class QuantizedData(object): class QuantizedData(object):
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import unittest import unittest
import numpy as np import numpy as np
import quantize.quantize_util import quantize_util
class TestQuantize(unittest.TestCase): class TestQuantize(unittest.TestCase):
......
...@@ -22,17 +22,17 @@ import tempfile ...@@ -22,17 +22,17 @@ import tempfile
import shutil import shutil
import numpy as np import numpy as np
from py_proto import mace_pb2 from python.py_proto import mace_pb2
from utils import util from python.utils import util
from utils import device from python.utils import device
from utils import config_parser from python.utils import config_parser
from utils.config_parser import DeviceType from python.utils.config_parser import DeviceType
from utils.target import Target from python.utils.target import Target
from utils.config_parser import ModelKeys from python.utils.config_parser import ModelKeys
from utils.util import MaceLogger from python.utils.util import MaceLogger
from utils.util import mace_check from python.utils.util import mace_check
import run_target from python import run_target
import validate from python import validate
""" """
Tool for mace_run: Tool for mace_run:
......
...@@ -28,10 +28,10 @@ from __future__ import print_function ...@@ -28,10 +28,10 @@ from __future__ import print_function
import argparse import argparse
import os import os
from utils import device from python.utils import device
from utils import target from python.utils import target
from utils import config_parser from python.utils import config_parser
from utils import util from python.utils import util
def run_target(target_abi, install_dir, target_obj, device_ids="all"): def run_target(target_abi, install_dir, target_obj, device_ids="all"):
......
...@@ -12,23 +12,24 @@ ...@@ -12,23 +12,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np import numpy as np
from enum import Enum from enum import Enum
from operator import mul
from py_proto import mace_pb2 from python.py_proto import mace_pb2
from transform import base_converter from python.transform import base_converter
from transform.base_converter import ConverterUtil from python.transform.base_converter import ConverterUtil
from transform.base_converter import EltwiseType from python.transform.base_converter import EltwiseType
from transform.base_converter import MaceKeyword from python.transform.base_converter import MaceKeyword
from transform.base_converter import MaceOp from python.transform.base_converter import MaceOp
from transform.base_converter import PaddingMode from python.transform.base_converter import PaddingMode
from transform.base_converter import PoolingType from python.transform.base_converter import PoolingType
from transform.base_converter import ReduceType from python.transform.base_converter import DataFormat
from transform.base_converter import DataFormat from python.transform.base_converter import FrameworkType
from transform.base_converter import FrameworkType from python.utils.util import mace_check
from utils.util import mace_check
ApuSupportedOps = [ ApuSupportedOps = [
......
...@@ -12,13 +12,15 @@ ...@@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from enum import Enum from enum import Enum
from py_proto import mace_pb2 from python.py_proto import mace_pb2
from python.utils.config_parser import DataFormat
from utils.config_parser import DataFormat from python.utils.config_parser import DeviceType
from utils.config_parser import DeviceType
# SAME_LOWER: if the amount of paddings to be added is odd, # SAME_LOWER: if the amount of paddings to be added is odd,
......
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math import math
...@@ -19,20 +22,18 @@ import numpy as np ...@@ -19,20 +22,18 @@ import numpy as np
import six import six
import google.protobuf.text_format import google.protobuf.text_format
from py_proto import mace_pb2 from python.py_proto import mace_pb2, caffe_pb2
from transform import base_converter from python.utils.util import mace_check
from transform import shape_inference from . import base_converter
from transform.base_converter import PoolingType from . import shape_inference
from transform.base_converter import ActivationType from .base_converter import PoolingType
from transform.base_converter import EltwiseType from .base_converter import ActivationType
from transform.base_converter import FrameworkType from .base_converter import EltwiseType
from transform.base_converter import DataFormat from .base_converter import FrameworkType
from transform.base_converter import MaceOp from .base_converter import DataFormat
from transform.base_converter import MaceKeyword from .base_converter import MaceOp
from transform.base_converter import ConverterUtil from .base_converter import MaceKeyword
from utils.util import mace_check from .base_converter import ConverterUtil
from py_proto import caffe_pb2
caffe_group_str = 'group' caffe_group_str = 'group'
caffe_kernel_h_str = 'kernel_h' caffe_kernel_h_str = 'kernel_h'
......
...@@ -16,24 +16,23 @@ from __future__ import absolute_import ...@@ -16,24 +16,23 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import copy import copy
import numpy as np import numpy as np
from enum import Enum from enum import Enum
from operator import mul from operator import mul
from functools import reduce from functools import reduce
from py_proto import mace_pb2 from python.py_proto import mace_pb2
from transform import base_converter from python.utils.util import mace_check
from transform.base_converter import ConverterUtil from . import base_converter
from transform.base_converter import DeviceType from .base_converter import ConverterUtil
from transform.base_converter import EltwiseType from .base_converter import DeviceType
from transform.base_converter import MaceKeyword from .base_converter import EltwiseType
from transform.base_converter import MaceOp from .base_converter import MaceKeyword
from transform.base_converter import PaddingMode from .base_converter import MaceOp
from transform.base_converter import PoolingType from .base_converter import PaddingMode
from transform.base_converter import ReduceType from .base_converter import PoolingType
from utils.util import mace_check from .base_converter import ReduceType
HexagonSupportedOps = [ HexagonSupportedOps = [
......
...@@ -12,32 +12,34 @@ ...@@ -12,32 +12,34 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys import sys
from enum import Enum from enum import Enum
import six import six
from py_proto import mace_pb2
from transform import base_converter
from transform.base_converter import PoolingType
from transform.base_converter import PaddingMode
from transform.base_converter import ActivationType
from transform.base_converter import EltwiseType
from transform.base_converter import ReduceType
from transform.base_converter import FrameworkType
from transform.base_converter import RoundMode
from transform.base_converter import DataFormat
from transform.base_converter import MaceOp
from transform.base_converter import MaceKeyword
from transform.base_converter import ConverterUtil
from utils.util import mace_check
import numpy as np import numpy as np
from numbers import Number
from python.py_proto import mace_pb2
from . import base_converter
from .base_converter import PoolingType
from .base_converter import PaddingMode
from .base_converter import ActivationType
from .base_converter import EltwiseType
from .base_converter import ReduceType
from .base_converter import FrameworkType
from .base_converter import RoundMode
from .base_converter import DataFormat
from .base_converter import MaceOp
from .base_converter import MaceKeyword
from .base_converter import ConverterUtil
from python.utils.util import mace_check
import onnx import onnx
import onnx.utils import onnx.utils
from onnx import mapping, numpy_helper, TensorProto from onnx import mapping, numpy_helper, TensorProto
from numbers import Number
IS_PYTHON3 = sys.version_info > (3,) IS_PYTHON3 = sys.version_info > (3,)
...@@ -191,9 +193,9 @@ OnnxOpType = Enum('OnnxOpType', ...@@ -191,9 +193,9 @@ OnnxOpType = Enum('OnnxOpType',
onnx_attr_translator = { onnx_attr_translator = {
"axis": lambda x: int(x), "axis": lambda x: int(x),
"axes": lambda x: [int(a) for a in x], "axes": lambda x: [int(a) for a in x],
"dtype": lambda x: data_type.onnx2tf(x), "dtype": lambda x: onnx_dtype(x),
"keepdims": lambda x: bool(x), "keepdims": lambda x: bool(x),
"to": lambda x: data_type.onnx2tf(x), "to": lambda x: onnx_dtype(x),
} }
...@@ -567,11 +569,7 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -567,11 +569,7 @@ class OnnxConverter(base_converter.ConverterInterface):
tensor.data_type = mace_pb2.DT_FLOAT tensor.data_type = mace_pb2.DT_FLOAT
tensor.float_data.extend( tensor.float_data.extend(
onnx_tensor.astype(np.float32).flat) onnx_tensor.astype(np.float32).flat)
elif data_type == np.int32: elif data_type == np.int32 or data_type == np.int64:
tensor.data_type = mace_pb2.DT_INT32
tensor.int32_data.extend(
onnx_tensor.astype(np.int32).flat)
elif data_type == np.int64:
tensor.data_type = mace_pb2.DT_INT32 tensor.data_type = mace_pb2.DT_INT32
tensor.int32_data.extend( tensor.int32_data.extend(
onnx_tensor.astype(np.int32).flat) onnx_tensor.astype(np.int32).flat)
...@@ -668,9 +666,9 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -668,9 +666,9 @@ class OnnxConverter(base_converter.ConverterInterface):
if 'to' in node.attrs: if 'to' in node.attrs:
dtype = node.attrs['to'] dtype = node.attrs['to']
if dtype == TensorProto.FLOAT: if dtype == np.float32 or dtype == np.float64:
op.output_type.extend([self._option.data_type]) op.output_type.extend([self._option.data_type])
elif dtype == TensorProto.INT: elif dtype == np.int64 or dtype == np.int32:
op.output_type.extend([mace_pb2.DT_INT32]) op.output_type.extend([mace_pb2.DT_INT32])
else: else:
mace_check(False, "data type %s not supported" % dtype) mace_check(False, "data type %s not supported" % dtype)
...@@ -959,6 +957,9 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -959,6 +957,9 @@ class OnnxConverter(base_converter.ConverterInterface):
if len(const_tensor.dims) == 0: if len(const_tensor.dims) == 0:
value_arg = op.arg.add() value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str value_arg.name = MaceKeyword.mace_scalar_input_str
if const_tensor.data_type == mace_pb2.DT_INT32:
value_arg.f = float(const_tensor.int32_data[0])
elif const_tensor.data_type == mace_pb2.DT_FLOAT:
value_arg.f = const_tensor.float_data[0] value_arg.f = const_tensor.float_data[0]
value_index_arg = op.arg.add() value_index_arg = op.arg.add()
value_index_arg.name = \ value_index_arg.name = \
...@@ -972,6 +973,9 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -972,6 +973,9 @@ class OnnxConverter(base_converter.ConverterInterface):
if len(const_tensor.dims) == 0: if len(const_tensor.dims) == 0:
value_arg = op.arg.add() value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str value_arg.name = MaceKeyword.mace_scalar_input_str
if const_tensor.data_type == mace_pb2.DT_INT32:
value_arg.f = float(const_tensor.int32_data[0])
elif const_tensor.data_type == mace_pb2.DT_FLOAT:
value_arg.f = const_tensor.float_data[0] value_arg.f = const_tensor.float_data[0]
value_index_arg = op.arg.add() value_index_arg = op.arg.add()
value_index_arg.name = \ value_index_arg.name = \
......
...@@ -12,18 +12,20 @@ ...@@ -12,18 +12,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math import math
import numpy as np import numpy as np
import six import six
from transform.transformer import Transformer from .transformer import Transformer
from transform.base_converter import DataFormat from .base_converter import DataFormat
from transform.base_converter import MaceOp from .base_converter import MaceOp
from transform.base_converter import MaceKeyword from .base_converter import MaceKeyword
from transform.base_converter import ConverterUtil from .base_converter import ConverterUtil
from utils.util import mace_check from python.utils.util import mace_check
class ShapeInference(object): class ShapeInference(object):
...@@ -253,7 +255,7 @@ class ShapeInference(object): ...@@ -253,7 +255,7 @@ class ShapeInference(object):
aspect_ratio = ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats # noqa aspect_ratio = ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats # noqa
num_prior = len(aspect_ratio) * len(min_size) + len(max_size) num_prior = len(aspect_ratio) * len(min_size) + len(max_size)
output_shape[2] = num_prior * input_h * input_w * 4 output_shape[2] = int(num_prior * input_h * input_w * 4)
self.add_output_shape(op, [output_shape]) self.add_output_shape(op, [output_shape])
def infer_shape_reshape(self, op): def infer_shape_reshape(self, op):
...@@ -275,7 +277,7 @@ class ShapeInference(object): ...@@ -275,7 +277,7 @@ class ShapeInference(object):
output_shape[i] = dim[i] output_shape[i] = dim[i]
product *= dim[i] product *= dim[i]
if idx != -1: if idx != -1:
output_shape[idx] = input_size / product output_shape[idx] = int(input_size / product)
self.add_output_shape(op, [output_shape]) self.add_output_shape(op, [output_shape])
else: else:
output_shape = [] output_shape = []
......
...@@ -12,27 +12,30 @@ ...@@ -12,27 +12,30 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math import math
import numpy as np import numpy as np
import six import six
import tensorflow as tf import tensorflow as tf
from enum import Enum from enum import Enum
from py_proto import mace_pb2 from python.py_proto import mace_pb2
from transform import base_converter from . import base_converter
from transform.base_converter import PoolingType from .base_converter import PoolingType
from transform.base_converter import PaddingMode from .base_converter import PaddingMode
from transform.base_converter import ActivationType from .base_converter import ActivationType
from transform.base_converter import EltwiseType from .base_converter import EltwiseType
from transform.base_converter import PadType from .base_converter import PadType
from transform.base_converter import FrameworkType from .base_converter import FrameworkType
from transform.base_converter import ReduceType from .base_converter import ReduceType
from transform.base_converter import DataFormat from .base_converter import DataFormat
from transform.base_converter import MaceOp from .base_converter import MaceOp
from transform.base_converter import MaceKeyword from .base_converter import MaceKeyword
from transform.base_converter import ConverterUtil from .base_converter import ConverterUtil
from utils.util import mace_check from python.utils.util import mace_check
from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.tools.graph_transforms import TransformGraph from tensorflow.tools.graph_transforms import TransformGraph
......
...@@ -12,28 +12,31 @@ ...@@ -12,28 +12,31 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re import re
import numpy as np import numpy as np
import six import six
from py_proto import mace_pb2 from python.py_proto import mace_pb2
from transform import base_converter from . import base_converter
from transform.base_converter import ConverterUtil from .base_converter import ConverterUtil
from transform.base_converter import DataFormat from .base_converter import DataFormat
from transform.base_converter import DeviceType from .base_converter import DeviceType
from transform.base_converter import EltwiseType from .base_converter import EltwiseType
from transform.base_converter import FrameworkType from .base_converter import FrameworkType
from transform.base_converter import MaceKeyword from .base_converter import MaceKeyword
from transform.base_converter import MaceOp from .base_converter import MaceOp
from transform.base_converter import MaceFixedDataFormatOps # noqa from .base_converter import MaceFixedDataFormatOps # noqa
from transform.base_converter import MaceTransposableDataFormatOps # noqa from .base_converter import MaceTransposableDataFormatOps # noqa
from transform.base_converter import PaddingMode from .base_converter import PaddingMode
from transform.base_converter import ReduceType from .base_converter import ReduceType
from transform.base_converter import TransformerRule from .base_converter import TransformerRule
from quantize import quantize_util from python.quantize import quantize_util
from utils.util import mace_check from python.utils.util import mace_check
class Transformer(base_converter.ConverterInterface): class Transformer(base_converter.ConverterInterface):
...@@ -1440,10 +1443,10 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1440,10 +1443,10 @@ class Transformer(base_converter.ConverterInterface):
arg.i = 1 arg.i = 1
elif arg.i == 3: elif arg.i == 3:
arg.i = 2 arg.i = 2
if op.input[0] in self._producer:
producer = self._producer[op.input[0]] producer = self._producer[op.input[0]]
input_shape = producer.output_shape[0].dims input_shape = producer.output_shape[0].dims
if producer.type == MaceOp.FullyConnected.name and \ if producer.type == MaceOp.FullyConnected.name and\
len(input_shape) == 2: len(input_shape) == 2:
axis_arg = ConverterUtil.get_arg( axis_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_axis_str) op, MaceKeyword.mace_axis_str)
......
...@@ -22,9 +22,9 @@ import copy ...@@ -22,9 +22,9 @@ import copy
import yaml import yaml
from enum import Enum from enum import Enum
from utils.util import mace_check from python.utils.util import mace_check
from utils.util import MaceLogger from python.utils.util import MaceLogger
from py_proto import mace_pb2 from python.py_proto import mace_pb2
CPP_KEYWORDS = [ CPP_KEYWORDS = [
'alignas', 'alignof', 'and', 'and_eq', 'asm', 'atomic_cancel', 'alignas', 'alignof', 'and', 'and_eq', 'asm', 'atomic_cancel',
......
...@@ -23,7 +23,7 @@ import subprocess ...@@ -23,7 +23,7 @@ import subprocess
import random import random
import tempfile import tempfile
from utils import util from python.utils import util
def execute(cmd, verbose=True): def execute(cmd, verbose=True):
......
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -12,15 +12,19 @@ ...@@ -12,15 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os import os
import os.path import os.path
import numpy as np import numpy as np
import six import six
from py_proto import mace_pb2 from python.py_proto import mace_pb2
from utils import util from python.utils import util
from utils.config_parser import DataFormat from python.utils.config_parser import DataFormat
from utils.config_parser import Platform from python.utils.config_parser import Platform
VALIDATION_MODULE = 'VALIDATION' VALIDATION_MODULE = 'VALIDATION'
......
...@@ -21,7 +21,6 @@ import re ...@@ -21,7 +21,6 @@ import re
import sh import sh
import struct import struct
import sys import sys
import time
import platform import platform
import six import six
......
...@@ -13,11 +13,9 @@ ...@@ -13,11 +13,9 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import sys
import os import os
import os.path import os.path
import numpy as np import numpy as np
import re
import six import six
import common import common
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册