提交 80a017ce 编写于 作者: L liyin 提交者: liutuo

Add CMake and ZH docs

Revert "Merge branch 'refactor-python-scripts' into 'master'"

This reverts merge request !1200
上级 c1ae5dd3
...@@ -17,6 +17,7 @@ import hashlib ...@@ -17,6 +17,7 @@ import hashlib
import inspect import inspect
import re import re
import os import os
import six import six
......
...@@ -25,11 +25,14 @@ import yaml ...@@ -25,11 +25,14 @@ import yaml
import sh_commands import sh_commands
from enum import Enum from enum import Enum
sys.path.insert(0, "tools/python") # noqa
from common import * from common import *
from device import DeviceWrapper, DeviceManager from device import DeviceWrapper, DeviceManager
from python.utils import config_parser from utils import config_parser
from python import convert import convert
from python import encrypt import encrypt
from dana.dana_util import DanaUtil
################################ ################################
# set environment # set environment
......
...@@ -13,7 +13,9 @@ ...@@ -13,7 +13,9 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import sys
import numpy as np import numpy as np
import re
import common import common
import six import six
......
...@@ -12,21 +12,19 @@ ...@@ -12,21 +12,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse import argparse
import copy import copy
import os import os
import sys
import yaml import yaml
from python.py_proto import mace_pb2 sys.path.insert(0, "tools/python") # noqa
from python.transform.base_converter import ConverterUtil from py_proto import mace_pb2
from python.transform.base_converter import MaceKeyword from transform.base_converter import ConverterUtil
from python.transform.base_converter import MaceOp from transform.base_converter import MaceKeyword
from python.transform.hexagon_converter import HexagonOp from transform.base_converter import MaceOp
from python.utils.util import mace_check from transform.hexagon_converter import HexagonOp
from utils.util import mace_check
def normalize_op_name(name): def normalize_op_name(name):
......
...@@ -22,19 +22,20 @@ from __future__ import print_function ...@@ -22,19 +22,20 @@ from __future__ import print_function
import argparse import argparse
import sys import sys
import numpy as np import numpy as np
import shutil
from python.utils import config_parser import tempfile
from python.utils.config_parser import DataFormat from utils import config_parser
from python.utils.config_parser import DeviceType from utils.config_parser import DataFormat
from python.utils.config_parser import Platform from utils.config_parser import DeviceType
from python.utils import util from utils.config_parser import Platform
from python.utils.util import mace_check from utils import util
from python.utils.config_parser import normalize_model_config from utils.util import mace_check
from python.utils.config_parser import ModelKeys from utils.config_parser import normalize_model_config
from python.py_proto import mace_pb2 from utils.config_parser import ModelKeys
from python.transform import base_converter as cvt from py_proto import mace_pb2
from python.transform import transformer from transform import base_converter as cvt
from python.visualize import visualize_model from transform import transformer
from visualize import visualize_model
def transpose_shape(shape, dst_order): def transpose_shape(shape, dst_order):
...@@ -161,16 +162,16 @@ def convert_model(conf): ...@@ -161,16 +162,16 @@ def convert_model(conf):
print("Transform model to one that can better run on device") print("Transform model to one that can better run on device")
platform = conf[ModelKeys.platform] platform = conf[ModelKeys.platform]
if platform == Platform.TENSORFLOW: if platform == Platform.TENSORFLOW:
from python.transform import tensorflow_converter from transform import tensorflow_converter
converter = tensorflow_converter.TensorflowConverter( converter = tensorflow_converter.TensorflowConverter(
option, conf["model_file_path"]) option, conf["model_file_path"])
elif platform == Platform.CAFFE: elif platform == Platform.CAFFE:
from python.transform import caffe_converter from transform import caffe_converter
converter = caffe_converter.CaffeConverter(option, converter = caffe_converter.CaffeConverter(option,
conf["model_file_path"], conf["model_file_path"],
conf["weight_file_path"]) conf["weight_file_path"])
elif platform == Platform.ONNX: elif platform == Platform.ONNX:
from python.transform import onnx_converter from transform import onnx_converter
converter = onnx_converter.OnnxConverter(option, converter = onnx_converter.OnnxConverter(option,
conf["model_file_path"]) conf["model_file_path"])
else: else:
...@@ -184,14 +185,14 @@ def convert_model(conf): ...@@ -184,14 +185,14 @@ def convert_model(conf):
runtime = conf[ModelKeys.runtime] runtime = conf[ModelKeys.runtime]
if runtime in [DeviceType.HEXAGON, if runtime in [DeviceType.HEXAGON,
DeviceType.HTA]: DeviceType.HTA]:
from python.transform import hexagon_converter from transform import hexagon_converter
converter = hexagon_converter.HexagonConverter( converter = hexagon_converter.HexagonConverter(
option, output_graph_def, quantize_activation_info) option, output_graph_def, quantize_activation_info)
output_graph_def = converter.run() output_graph_def = converter.run()
elif runtime == DeviceType.APU: elif runtime == DeviceType.APU:
mace_check(platform == Platform.TENSORFLOW, mace_check(platform == Platform.TENSORFLOW,
"apu only support model from tensorflow") "apu only support model from tensorflow")
from python.transform import apu_converter from transform import apu_converter
converter = apu_converter.ApuConverter( converter = apu_converter.ApuConverter(
option, output_graph_def, quantize_activation_info) option, output_graph_def, quantize_activation_info)
output_graph_def = converter.run() output_graph_def = converter.run()
......
...@@ -21,15 +21,14 @@ import datetime ...@@ -21,15 +21,14 @@ import datetime
import os import os
import hashlib import hashlib
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from py_proto import mace_pb2
from python.py_proto import mace_pb2 from utils import device
from python.utils import device from utils import util
from python.utils import util from utils.util import mace_check
from python.utils.util import mace_check from utils.util import MaceLogger
from python.utils.util import MaceLogger from utils import config_parser
from python.utils import config_parser from utils.config_parser import CPP_KEYWORDS
from python.utils.config_parser import CPP_KEYWORDS from utils.config_parser import ModelKeys
from python.utils.config_parser import ModelKeys
GENERATED_NAME = set() GENERATED_NAME = set()
......
...@@ -22,9 +22,9 @@ import os ...@@ -22,9 +22,9 @@ import os
import struct import struct
import numpy as np import numpy as np
from python.utils import util from utils import util
from python.utils.util import MaceLogger from utils.util import MaceLogger
from python.utils.util import mace_check from utils.util import mace_check
def generate_opencl_code(binary_file_name, load_func_name, size_func_name, def generate_opencl_code(binary_file_name, load_func_name, size_func_name,
......
...@@ -17,8 +17,8 @@ from __future__ import division ...@@ -17,8 +17,8 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
from python.utils import device from utils import device
from python.utils.util import MaceLogger from utils.util import MaceLogger
cwd = os.path.dirname(__file__) cwd = os.path.dirname(__file__)
......
...@@ -19,7 +19,7 @@ from __future__ import print_function ...@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import math import math
from python.transform.base_converter import DeviceType from transform.base_converter import DeviceType
class QuantizedData(object): class QuantizedData(object):
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import unittest import unittest
import numpy as np import numpy as np
import quantize_util import quantize.quantize_util
class TestQuantize(unittest.TestCase): class TestQuantize(unittest.TestCase):
......
...@@ -22,17 +22,17 @@ import tempfile ...@@ -22,17 +22,17 @@ import tempfile
import shutil import shutil
import numpy as np import numpy as np
from python.py_proto import mace_pb2 from py_proto import mace_pb2
from python.utils import util from utils import util
from python.utils import device from utils import device
from python.utils import config_parser from utils import config_parser
from python.utils.config_parser import DeviceType from utils.config_parser import DeviceType
from python.utils.target import Target from utils.target import Target
from python.utils.config_parser import ModelKeys from utils.config_parser import ModelKeys
from python.utils.util import MaceLogger from utils.util import MaceLogger
from python.utils.util import mace_check from utils.util import mace_check
from python import run_target import run_target
from python import validate import validate
""" """
Tool for mace_run: Tool for mace_run:
......
...@@ -28,10 +28,10 @@ from __future__ import print_function ...@@ -28,10 +28,10 @@ from __future__ import print_function
import argparse import argparse
import os import os
from python.utils import device from utils import device
from python.utils import target from utils import target
from python.utils import config_parser from utils import config_parser
from python.utils import util from utils import util
def run_target(target_abi, install_dir, target_obj, device_ids="all"): def run_target(target_abi, install_dir, target_obj, device_ids="all"):
......
...@@ -12,24 +12,23 @@ ...@@ -12,24 +12,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import import copy
from __future__ import division
from __future__ import print_function
import numpy as np import numpy as np
from enum import Enum from enum import Enum
from operator import mul
from python.py_proto import mace_pb2 from py_proto import mace_pb2
from python.transform import base_converter from transform import base_converter
from python.transform.base_converter import ConverterUtil from transform.base_converter import ConverterUtil
from python.transform.base_converter import EltwiseType from transform.base_converter import EltwiseType
from python.transform.base_converter import MaceKeyword from transform.base_converter import MaceKeyword
from python.transform.base_converter import MaceOp from transform.base_converter import MaceOp
from python.transform.base_converter import PaddingMode from transform.base_converter import PaddingMode
from python.transform.base_converter import PoolingType from transform.base_converter import PoolingType
from python.transform.base_converter import DataFormat from transform.base_converter import ReduceType
from python.transform.base_converter import FrameworkType from transform.base_converter import DataFormat
from python.utils.util import mace_check from transform.base_converter import FrameworkType
from utils.util import mace_check
ApuSupportedOps = [ ApuSupportedOps = [
......
...@@ -12,15 +12,13 @@ ...@@ -12,15 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from enum import Enum from enum import Enum
from python.py_proto import mace_pb2 from py_proto import mace_pb2
from python.utils.config_parser import DataFormat
from python.utils.config_parser import DeviceType from utils.config_parser import DataFormat
from utils.config_parser import DeviceType
# SAME_LOWER: if the amount of paddings to be added is odd, # SAME_LOWER: if the amount of paddings to be added is odd,
......
...@@ -12,9 +12,6 @@ ...@@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math import math
...@@ -22,18 +19,20 @@ import numpy as np ...@@ -22,18 +19,20 @@ import numpy as np
import six import six
import google.protobuf.text_format import google.protobuf.text_format
from python.py_proto import mace_pb2, caffe_pb2 from py_proto import mace_pb2
from python.utils.util import mace_check from transform import base_converter
from . import base_converter from transform import shape_inference
from . import shape_inference from transform.base_converter import PoolingType
from .base_converter import PoolingType from transform.base_converter import ActivationType
from .base_converter import ActivationType from transform.base_converter import EltwiseType
from .base_converter import EltwiseType from transform.base_converter import FrameworkType
from .base_converter import FrameworkType from transform.base_converter import DataFormat
from .base_converter import DataFormat from transform.base_converter import MaceOp
from .base_converter import MaceOp from transform.base_converter import MaceKeyword
from .base_converter import MaceKeyword from transform.base_converter import ConverterUtil
from .base_converter import ConverterUtil from utils.util import mace_check
from py_proto import caffe_pb2
caffe_group_str = 'group' caffe_group_str = 'group'
caffe_kernel_h_str = 'kernel_h' caffe_kernel_h_str = 'kernel_h'
......
...@@ -16,23 +16,24 @@ from __future__ import absolute_import ...@@ -16,23 +16,24 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import copy import copy
import numpy as np import numpy as np
from enum import Enum from enum import Enum
from operator import mul from operator import mul
from functools import reduce from functools import reduce
from python.py_proto import mace_pb2 from py_proto import mace_pb2
from python.utils.util import mace_check from transform import base_converter
from . import base_converter from transform.base_converter import ConverterUtil
from .base_converter import ConverterUtil from transform.base_converter import DeviceType
from .base_converter import DeviceType from transform.base_converter import EltwiseType
from .base_converter import EltwiseType from transform.base_converter import MaceKeyword
from .base_converter import MaceKeyword from transform.base_converter import MaceOp
from .base_converter import MaceOp from transform.base_converter import PaddingMode
from .base_converter import PaddingMode from transform.base_converter import PoolingType
from .base_converter import PoolingType from transform.base_converter import ReduceType
from .base_converter import ReduceType from utils.util import mace_check
HexagonSupportedOps = [ HexagonSupportedOps = [
......
...@@ -12,34 +12,32 @@ ...@@ -12,34 +12,32 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys import sys
from enum import Enum from enum import Enum
import six import six
import numpy as np
from numbers import Number
from python.py_proto import mace_pb2 from py_proto import mace_pb2
from . import base_converter from transform import base_converter
from .base_converter import PoolingType from transform.base_converter import PoolingType
from .base_converter import PaddingMode from transform.base_converter import PaddingMode
from .base_converter import ActivationType from transform.base_converter import ActivationType
from .base_converter import EltwiseType from transform.base_converter import EltwiseType
from .base_converter import ReduceType from transform.base_converter import ReduceType
from .base_converter import FrameworkType from transform.base_converter import FrameworkType
from .base_converter import RoundMode from transform.base_converter import RoundMode
from .base_converter import DataFormat from transform.base_converter import DataFormat
from .base_converter import MaceOp from transform.base_converter import MaceOp
from .base_converter import MaceKeyword from transform.base_converter import MaceKeyword
from .base_converter import ConverterUtil from transform.base_converter import ConverterUtil
from python.utils.util import mace_check from utils.util import mace_check
import numpy as np
import onnx import onnx
import onnx.utils import onnx.utils
from onnx import mapping, numpy_helper, TensorProto from onnx import mapping, numpy_helper, TensorProto
from numbers import Number
IS_PYTHON3 = sys.version_info > (3,) IS_PYTHON3 = sys.version_info > (3,)
...@@ -193,9 +191,9 @@ OnnxOpType = Enum('OnnxOpType', ...@@ -193,9 +191,9 @@ OnnxOpType = Enum('OnnxOpType',
onnx_attr_translator = { onnx_attr_translator = {
"axis": lambda x: int(x), "axis": lambda x: int(x),
"axes": lambda x: [int(a) for a in x], "axes": lambda x: [int(a) for a in x],
"dtype": lambda x: onnx_dtype(x), "dtype": lambda x: data_type.onnx2tf(x),
"keepdims": lambda x: bool(x), "keepdims": lambda x: bool(x),
"to": lambda x: onnx_dtype(x), "to": lambda x: data_type.onnx2tf(x),
} }
...@@ -569,7 +567,11 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -569,7 +567,11 @@ class OnnxConverter(base_converter.ConverterInterface):
tensor.data_type = mace_pb2.DT_FLOAT tensor.data_type = mace_pb2.DT_FLOAT
tensor.float_data.extend( tensor.float_data.extend(
onnx_tensor.astype(np.float32).flat) onnx_tensor.astype(np.float32).flat)
elif data_type == np.int32 or data_type == np.int64: elif data_type == np.int32:
tensor.data_type = mace_pb2.DT_INT32
tensor.int32_data.extend(
onnx_tensor.astype(np.int32).flat)
elif data_type == np.int64:
tensor.data_type = mace_pb2.DT_INT32 tensor.data_type = mace_pb2.DT_INT32
tensor.int32_data.extend( tensor.int32_data.extend(
onnx_tensor.astype(np.int32).flat) onnx_tensor.astype(np.int32).flat)
...@@ -666,9 +668,9 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -666,9 +668,9 @@ class OnnxConverter(base_converter.ConverterInterface):
if 'to' in node.attrs: if 'to' in node.attrs:
dtype = node.attrs['to'] dtype = node.attrs['to']
if dtype == np.float32 or dtype == np.float64: if dtype == TensorProto.FLOAT:
op.output_type.extend([self._option.data_type]) op.output_type.extend([self._option.data_type])
elif dtype == np.int64 or dtype == np.int32: elif dtype == TensorProto.INT:
op.output_type.extend([mace_pb2.DT_INT32]) op.output_type.extend([mace_pb2.DT_INT32])
else: else:
mace_check(False, "data type %s not supported" % dtype) mace_check(False, "data type %s not supported" % dtype)
...@@ -957,10 +959,7 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -957,10 +959,7 @@ class OnnxConverter(base_converter.ConverterInterface):
if len(const_tensor.dims) == 0: if len(const_tensor.dims) == 0:
value_arg = op.arg.add() value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str value_arg.name = MaceKeyword.mace_scalar_input_str
if const_tensor.data_type == mace_pb2.DT_INT32: value_arg.f = const_tensor.float_data[0]
value_arg.f = float(const_tensor.int32_data[0])
elif const_tensor.data_type == mace_pb2.DT_FLOAT:
value_arg.f = const_tensor.float_data[0]
value_index_arg = op.arg.add() value_index_arg = op.arg.add()
value_index_arg.name = \ value_index_arg.name = \
MaceKeyword.mace_scalar_input_index_str MaceKeyword.mace_scalar_input_index_str
...@@ -973,10 +972,7 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -973,10 +972,7 @@ class OnnxConverter(base_converter.ConverterInterface):
if len(const_tensor.dims) == 0: if len(const_tensor.dims) == 0:
value_arg = op.arg.add() value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str value_arg.name = MaceKeyword.mace_scalar_input_str
if const_tensor.data_type == mace_pb2.DT_INT32: value_arg.f = const_tensor.float_data[0]
value_arg.f = float(const_tensor.int32_data[0])
elif const_tensor.data_type == mace_pb2.DT_FLOAT:
value_arg.f = const_tensor.float_data[0]
value_index_arg = op.arg.add() value_index_arg = op.arg.add()
value_index_arg.name = \ value_index_arg.name = \
MaceKeyword.mace_scalar_input_index_str MaceKeyword.mace_scalar_input_index_str
......
...@@ -12,20 +12,18 @@ ...@@ -12,20 +12,18 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math import math
import numpy as np import numpy as np
import six import six
from .transformer import Transformer from transform.transformer import Transformer
from .base_converter import DataFormat from transform.base_converter import DataFormat
from .base_converter import MaceOp from transform.base_converter import MaceOp
from .base_converter import MaceKeyword from transform.base_converter import MaceKeyword
from .base_converter import ConverterUtil from transform.base_converter import ConverterUtil
from python.utils.util import mace_check from utils.util import mace_check
class ShapeInference(object): class ShapeInference(object):
...@@ -255,7 +253,7 @@ class ShapeInference(object): ...@@ -255,7 +253,7 @@ class ShapeInference(object):
aspect_ratio = ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats # noqa aspect_ratio = ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats # noqa
num_prior = len(aspect_ratio) * len(min_size) + len(max_size) num_prior = len(aspect_ratio) * len(min_size) + len(max_size)
output_shape[2] = int(num_prior * input_h * input_w * 4) output_shape[2] = num_prior * input_h * input_w * 4
self.add_output_shape(op, [output_shape]) self.add_output_shape(op, [output_shape])
def infer_shape_reshape(self, op): def infer_shape_reshape(self, op):
...@@ -277,7 +275,7 @@ class ShapeInference(object): ...@@ -277,7 +275,7 @@ class ShapeInference(object):
output_shape[i] = dim[i] output_shape[i] = dim[i]
product *= dim[i] product *= dim[i]
if idx != -1: if idx != -1:
output_shape[idx] = int(input_size / product) output_shape[idx] = input_size / product
self.add_output_shape(op, [output_shape]) self.add_output_shape(op, [output_shape])
else: else:
output_shape = [] output_shape = []
......
...@@ -12,30 +12,27 @@ ...@@ -12,30 +12,27 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import import os
from __future__ import division
from __future__ import print_function
import math import math
import numpy as np import numpy as np
import six import six
import tensorflow as tf import tensorflow as tf
from enum import Enum from enum import Enum
from python.py_proto import mace_pb2 from py_proto import mace_pb2
from . import base_converter from transform import base_converter
from .base_converter import PoolingType from transform.base_converter import PoolingType
from .base_converter import PaddingMode from transform.base_converter import PaddingMode
from .base_converter import ActivationType from transform.base_converter import ActivationType
from .base_converter import EltwiseType from transform.base_converter import EltwiseType
from .base_converter import PadType from transform.base_converter import PadType
from .base_converter import FrameworkType from transform.base_converter import FrameworkType
from .base_converter import ReduceType from transform.base_converter import ReduceType
from .base_converter import DataFormat from transform.base_converter import DataFormat
from .base_converter import MaceOp from transform.base_converter import MaceOp
from .base_converter import MaceKeyword from transform.base_converter import MaceKeyword
from .base_converter import ConverterUtil from transform.base_converter import ConverterUtil
from python.utils.util import mace_check from utils.util import mace_check
from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.tools.graph_transforms import TransformGraph from tensorflow.tools.graph_transforms import TransformGraph
......
...@@ -12,31 +12,28 @@ ...@@ -12,31 +12,28 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re import re
import numpy as np import numpy as np
import six import six
from python.py_proto import mace_pb2 from py_proto import mace_pb2
from . import base_converter from transform import base_converter
from .base_converter import ConverterUtil from transform.base_converter import ConverterUtil
from .base_converter import DataFormat from transform.base_converter import DataFormat
from .base_converter import DeviceType from transform.base_converter import DeviceType
from .base_converter import EltwiseType from transform.base_converter import EltwiseType
from .base_converter import FrameworkType from transform.base_converter import FrameworkType
from .base_converter import MaceKeyword from transform.base_converter import MaceKeyword
from .base_converter import MaceOp from transform.base_converter import MaceOp
from .base_converter import MaceFixedDataFormatOps # noqa from transform.base_converter import MaceFixedDataFormatOps # noqa
from .base_converter import MaceTransposableDataFormatOps # noqa from transform.base_converter import MaceTransposableDataFormatOps # noqa
from .base_converter import PaddingMode from transform.base_converter import PaddingMode
from .base_converter import ReduceType from transform.base_converter import ReduceType
from .base_converter import TransformerRule from transform.base_converter import TransformerRule
from python.quantize import quantize_util from quantize import quantize_util
from python.utils.util import mace_check from utils.util import mace_check
class Transformer(base_converter.ConverterInterface): class Transformer(base_converter.ConverterInterface):
...@@ -1443,15 +1440,15 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1443,15 +1440,15 @@ class Transformer(base_converter.ConverterInterface):
arg.i = 1 arg.i = 1
elif arg.i == 3: elif arg.i == 3:
arg.i = 2 arg.i = 2
if op.input[0] in self._producer:
producer = self._producer[op.input[0]] producer = self._producer[op.input[0]]
input_shape = producer.output_shape[0].dims input_shape = producer.output_shape[0].dims
if producer.type == MaceOp.FullyConnected.name and\ if producer.type == MaceOp.FullyConnected.name and \
len(input_shape) == 2: len(input_shape) == 2:
axis_arg = ConverterUtil.get_arg( axis_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_axis_str) op, MaceKeyword.mace_axis_str)
if axis_arg.i == 1: if axis_arg.i == 1:
axis_arg.i = 3 axis_arg.i = 3
elif op.type == MaceOp.Squeeze.name: elif op.type == MaceOp.Squeeze.name:
for arg in op.arg: for arg in op.arg:
......
...@@ -22,9 +22,9 @@ import copy ...@@ -22,9 +22,9 @@ import copy
import yaml import yaml
from enum import Enum from enum import Enum
from python.utils.util import mace_check from utils.util import mace_check
from python.utils.util import MaceLogger from utils.util import MaceLogger
from python.py_proto import mace_pb2 from py_proto import mace_pb2
CPP_KEYWORDS = [ CPP_KEYWORDS = [
'alignas', 'alignof', 'and', 'and_eq', 'asm', 'atomic_cancel', 'alignas', 'alignof', 'and', 'and_eq', 'asm', 'atomic_cancel',
......
...@@ -23,7 +23,7 @@ import subprocess ...@@ -23,7 +23,7 @@ import subprocess
import random import random
import tempfile import tempfile
from python.utils import util from utils import util
def execute(cmd, verbose=True): def execute(cmd, verbose=True):
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function from __future__ import print_function
......
...@@ -12,19 +12,15 @@ ...@@ -12,19 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os import os
import os.path import os.path
import numpy as np import numpy as np
import six import six
from python.py_proto import mace_pb2 from py_proto import mace_pb2
from python.utils import util from utils import util
from python.utils.config_parser import DataFormat from utils.config_parser import DataFormat
from python.utils.config_parser import Platform from utils.config_parser import Platform
VALIDATION_MODULE = 'VALIDATION' VALIDATION_MODULE = 'VALIDATION'
......
...@@ -21,6 +21,7 @@ import re ...@@ -21,6 +21,7 @@ import re
import sh import sh
import struct import struct
import sys import sys
import time
import platform import platform
import six import six
......
...@@ -13,9 +13,11 @@ ...@@ -13,9 +13,11 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import sys
import os import os
import os.path import os.path
import numpy as np import numpy as np
import re
import six import six
import common import common
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册