提交 fa26d01c 编写于 作者: J jiangjiajun

modify code format

上级 73830eb2
- repo: local - repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks: hooks:
- id: yapf - id: yapf
name: yapf
entry: yapf
language: system
args: [-i, --style .style.yapf]
files: \.py$ files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464 sha: a11d9314b22d8f8c7556443875b731ef05965464
hooks: hooks:
...@@ -18,6 +14,7 @@ ...@@ -18,6 +14,7 @@
- id: check-symlinks - id: check-symlinks
- id: check-added-large-files - id: check-added-large-files
- repo: local - repo: local
hooks: hooks:
- id: copyright_checker - id: copyright_checker
name: copyright_checker name: copyright_checker
......
...@@ -11,8 +11,7 @@ setuptools.setup( ...@@ -11,8 +11,7 @@ setuptools.setup(
version=x2paddle.__version__, version=x2paddle.__version__,
author="dltp-sz", author="dltp-sz",
author_email="dltp-sz@baidu.com", author_email="dltp-sz@baidu.com",
description= description="a toolkit for converting trained model to PaddlePaddle from other deep learning frameworks.",
"a toolkit for converting trained model to PaddlePaddle from other deep learning frameworks.",
long_description=long_description, long_description=long_description,
long_description_content_type="text/plain", long_description_content_type="text/plain",
url="https://github.com/PaddlePaddle/x2paddle", url="https://github.com/PaddlePaddle/x2paddle",
...@@ -23,6 +22,4 @@ setuptools.setup( ...@@ -23,6 +22,4 @@ setuptools.setup(
"Operating System :: OS Independent", "Operating System :: OS Independent",
], ],
license='Apache 2.0', license='Apache 2.0',
entry_points={'console_scripts': [ entry_points={'console_scripts': ['x2paddle=x2paddle.convert:main', ]})
'x2paddle=x2paddle.convert:main',
]})
...@@ -48,8 +48,7 @@ def arg_parser(): ...@@ -48,8 +48,7 @@ def arg_parser():
"-f", "-f",
type=_text_type, type=_text_type,
default=None, default=None,
help= help="define which deeplearning framework(tensorflow/caffe/onnx/paddle2onnx)"
"define which deeplearning framework(tensorflow/caffe/onnx/paddle2onnx)"
) )
parser.add_argument( parser.add_argument(
"--caffe_proto", "--caffe_proto",
...@@ -126,7 +125,6 @@ def tf2paddle(model_path, ...@@ -126,7 +125,6 @@ def tf2paddle(model_path,
optimizer.merge_bias() optimizer.merge_bias()
optimizer.optimize_sub_graph() optimizer.optimize_sub_graph()
# optimizer.merge_batch_norm() # optimizer.merge_batch_norm()
# optimizer.merge_prelu() # optimizer.merge_prelu()
else: else:
......
...@@ -46,8 +46,9 @@ class Layer(object): ...@@ -46,8 +46,9 @@ class Layer(object):
for input in self.inputs: for input in self.inputs:
if isinstance(input, GraphNode): if isinstance(input, GraphNode):
if hasattr(input, "index"): if hasattr(input, "index"):
in_list += (input.layer_name + "[{}]".format( in_list += (
input.index) + ", ") input.layer_name + "[{}]".format(input.index) + ", "
)
else: else:
in_list += (input.layer_name + ", ") in_list += (input.layer_name + ", ")
elif isinstance(input, six.string_types): elif isinstance(input, six.string_types):
......
...@@ -34,8 +34,8 @@ class CaffeResolver(object): ...@@ -34,8 +34,8 @@ class CaffeResolver(object):
if not os.path.isfile(self.caffe_proto): if not os.path.isfile(self.caffe_proto):
raise Exception( raise Exception(
"The .py file compiled by caffe.proto is not exist.") "The .py file compiled by caffe.proto is not exist.")
(filepath, tempfilename) = os.path.split( (filepath,
os.path.abspath(self.caffe_proto)) tempfilename) = os.path.split(os.path.abspath(self.caffe_proto))
(filename, extension) = os.path.splitext(tempfilename) (filename, extension) = os.path.splitext(tempfilename)
sys.path.append(filepath) sys.path.append(filepath)
out = __import__(filename) out = __import__(filename)
...@@ -50,12 +50,10 @@ class CaffeGraphNode(GraphNode): ...@@ -50,12 +50,10 @@ class CaffeGraphNode(GraphNode):
def __init__(self, layer, type_str, layer_name=None): def __init__(self, layer, type_str, layer_name=None):
if layer_name is None: if layer_name is None:
super(CaffeGraphNode, self).__init__( super(CaffeGraphNode, self).__init__(
layer, layer, layer.name.replace('/', '_').replace('-', '_'))
layer.name.replace('/', '_').replace('-', '_'))
else: else:
super(CaffeGraphNode, self).__init__( super(CaffeGraphNode, self).__init__(
layer, layer, layer_name.replace('/', '_').replace('-', '_'))
layer_name.replace('/', '_').replace('-', '_'))
self.layer_type = type_str self.layer_type = type_str
self.fluid_code = FluidCode() self.fluid_code = FluidCode()
self.data = None self.data = None
......
...@@ -36,8 +36,7 @@ _PHASE = _descriptor.EnumDescriptor( ...@@ -36,8 +36,7 @@ _PHASE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=21908, serialized_start=21908,
serialized_end=21936, serialized_end=21936, )
)
_sym_db.RegisterEnumDescriptor(_PHASE) _sym_db.RegisterEnumDescriptor(_PHASE)
Phase = enum_type_wrapper.EnumTypeWrapper(_PHASE) Phase = enum_type_wrapper.EnumTypeWrapper(_PHASE)
...@@ -66,8 +65,7 @@ _EMITCONSTRAINT_EMITTYPE = _descriptor.EnumDescriptor( ...@@ -66,8 +65,7 @@ _EMITCONSTRAINT_EMITTYPE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=1146, serialized_start=1146,
serialized_end=1185, serialized_end=1185, )
)
_sym_db.RegisterEnumDescriptor(_EMITCONSTRAINT_EMITTYPE) _sym_db.RegisterEnumDescriptor(_EMITCONSTRAINT_EMITTYPE)
_ANNOTATEDDATUM_ANNOTATIONTYPE = _descriptor.EnumDescriptor( _ANNOTATEDDATUM_ANNOTATIONTYPE = _descriptor.EnumDescriptor(
...@@ -82,8 +80,7 @@ _ANNOTATEDDATUM_ANNOTATIONTYPE = _descriptor.EnumDescriptor( ...@@ -82,8 +80,7 @@ _ANNOTATEDDATUM_ANNOTATIONTYPE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=1629, serialized_start=1629,
serialized_end=1655, serialized_end=1655, )
)
_sym_db.RegisterEnumDescriptor(_ANNOTATEDDATUM_ANNOTATIONTYPE) _sym_db.RegisterEnumDescriptor(_ANNOTATEDDATUM_ANNOTATIONTYPE)
_FILLERPARAMETER_VARIANCENORM = _descriptor.EnumDescriptor( _FILLERPARAMETER_VARIANCENORM = _descriptor.EnumDescriptor(
...@@ -114,8 +111,7 @@ _FILLERPARAMETER_VARIANCENORM = _descriptor.EnumDescriptor( ...@@ -114,8 +111,7 @@ _FILLERPARAMETER_VARIANCENORM = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=1872, serialized_start=1872,
serialized_end=1924, serialized_end=1924, )
)
_sym_db.RegisterEnumDescriptor(_FILLERPARAMETER_VARIANCENORM) _sym_db.RegisterEnumDescriptor(_FILLERPARAMETER_VARIANCENORM)
_SOLVERPARAMETER_SNAPSHOTFORMAT = _descriptor.EnumDescriptor( _SOLVERPARAMETER_SNAPSHOTFORMAT = _descriptor.EnumDescriptor(
...@@ -136,8 +132,7 @@ _SOLVERPARAMETER_SNAPSHOTFORMAT = _descriptor.EnumDescriptor( ...@@ -136,8 +132,7 @@ _SOLVERPARAMETER_SNAPSHOTFORMAT = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=3480, serialized_start=3480,
serialized_end=3523, serialized_end=3523, )
)
_sym_db.RegisterEnumDescriptor(_SOLVERPARAMETER_SNAPSHOTFORMAT) _sym_db.RegisterEnumDescriptor(_SOLVERPARAMETER_SNAPSHOTFORMAT)
_SOLVERPARAMETER_SOLVERMODE = _descriptor.EnumDescriptor( _SOLVERPARAMETER_SOLVERMODE = _descriptor.EnumDescriptor(
...@@ -154,8 +149,7 @@ _SOLVERPARAMETER_SOLVERMODE = _descriptor.EnumDescriptor( ...@@ -154,8 +149,7 @@ _SOLVERPARAMETER_SOLVERMODE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=3525, serialized_start=3525,
serialized_end=3555, serialized_end=3555, )
)
_sym_db.RegisterEnumDescriptor(_SOLVERPARAMETER_SOLVERMODE) _sym_db.RegisterEnumDescriptor(_SOLVERPARAMETER_SOLVERMODE)
_SOLVERPARAMETER_SOLVERTYPE = _descriptor.EnumDescriptor( _SOLVERPARAMETER_SOLVERTYPE = _descriptor.EnumDescriptor(
...@@ -196,8 +190,7 @@ _SOLVERPARAMETER_SOLVERTYPE = _descriptor.EnumDescriptor( ...@@ -196,8 +190,7 @@ _SOLVERPARAMETER_SOLVERTYPE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=3557, serialized_start=3557,
serialized_end=3642, serialized_end=3642, )
)
_sym_db.RegisterEnumDescriptor(_SOLVERPARAMETER_SOLVERTYPE) _sym_db.RegisterEnumDescriptor(_SOLVERPARAMETER_SOLVERTYPE)
_PARAMSPEC_DIMCHECKMODE = _descriptor.EnumDescriptor( _PARAMSPEC_DIMCHECKMODE = _descriptor.EnumDescriptor(
...@@ -222,8 +215,7 @@ _PARAMSPEC_DIMCHECKMODE = _descriptor.EnumDescriptor( ...@@ -222,8 +215,7 @@ _PARAMSPEC_DIMCHECKMODE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=4131, serialized_start=4131,
serialized_end=4173, serialized_end=4173, )
)
_sym_db.RegisterEnumDescriptor(_PARAMSPEC_DIMCHECKMODE) _sym_db.RegisterEnumDescriptor(_PARAMSPEC_DIMCHECKMODE)
_RESIZEPARAMETER_RESIZE_MODE = _descriptor.EnumDescriptor( _RESIZEPARAMETER_RESIZE_MODE = _descriptor.EnumDescriptor(
...@@ -250,8 +242,7 @@ _RESIZEPARAMETER_RESIZE_MODE = _descriptor.EnumDescriptor( ...@@ -250,8 +242,7 @@ _RESIZEPARAMETER_RESIZE_MODE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=8049, serialized_start=8049,
serialized_end=8120, serialized_end=8120, )
)
_sym_db.RegisterEnumDescriptor(_RESIZEPARAMETER_RESIZE_MODE) _sym_db.RegisterEnumDescriptor(_RESIZEPARAMETER_RESIZE_MODE)
_RESIZEPARAMETER_PAD_MODE = _descriptor.EnumDescriptor( _RESIZEPARAMETER_PAD_MODE = _descriptor.EnumDescriptor(
...@@ -282,8 +273,7 @@ _RESIZEPARAMETER_PAD_MODE = _descriptor.EnumDescriptor( ...@@ -282,8 +273,7 @@ _RESIZEPARAMETER_PAD_MODE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=8122, serialized_start=8122,
serialized_end=8180, serialized_end=8180, )
)
_sym_db.RegisterEnumDescriptor(_RESIZEPARAMETER_PAD_MODE) _sym_db.RegisterEnumDescriptor(_RESIZEPARAMETER_PAD_MODE)
_RESIZEPARAMETER_INTERP_MODE = _descriptor.EnumDescriptor( _RESIZEPARAMETER_INTERP_MODE = _descriptor.EnumDescriptor(
...@@ -319,8 +309,7 @@ _RESIZEPARAMETER_INTERP_MODE = _descriptor.EnumDescriptor( ...@@ -319,8 +309,7 @@ _RESIZEPARAMETER_INTERP_MODE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=8182, serialized_start=8182,
serialized_end=8255, serialized_end=8255, )
)
_sym_db.RegisterEnumDescriptor(_RESIZEPARAMETER_INTERP_MODE) _sym_db.RegisterEnumDescriptor(_RESIZEPARAMETER_INTERP_MODE)
_LOSSPARAMETER_NORMALIZATIONMODE = _descriptor.EnumDescriptor( _LOSSPARAMETER_NORMALIZATIONMODE = _descriptor.EnumDescriptor(
...@@ -346,8 +335,7 @@ _LOSSPARAMETER_NORMALIZATIONMODE = _descriptor.EnumDescriptor( ...@@ -346,8 +335,7 @@ _LOSSPARAMETER_NORMALIZATIONMODE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=9202, serialized_start=9202,
serialized_end=9268, serialized_end=9268, )
)
_sym_db.RegisterEnumDescriptor(_LOSSPARAMETER_NORMALIZATIONMODE) _sym_db.RegisterEnumDescriptor(_LOSSPARAMETER_NORMALIZATIONMODE)
_CONVOLUTIONPARAMETER_ENGINE = _descriptor.EnumDescriptor( _CONVOLUTIONPARAMETER_ENGINE = _descriptor.EnumDescriptor(
...@@ -372,8 +360,7 @@ _CONVOLUTIONPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -372,8 +360,7 @@ _CONVOLUTIONPARAMETER_ENGINE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=10385, serialized_start=10385,
serialized_end=10428, serialized_end=10428, )
)
_sym_db.RegisterEnumDescriptor(_CONVOLUTIONPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_CONVOLUTIONPARAMETER_ENGINE)
_DATAPARAMETER_DB = _descriptor.EnumDescriptor( _DATAPARAMETER_DB = _descriptor.EnumDescriptor(
...@@ -394,8 +381,7 @@ _DATAPARAMETER_DB = _descriptor.EnumDescriptor( ...@@ -394,8 +381,7 @@ _DATAPARAMETER_DB = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=10746, serialized_start=10746,
serialized_end=10773, serialized_end=10773, )
)
_sym_db.RegisterEnumDescriptor(_DATAPARAMETER_DB) _sym_db.RegisterEnumDescriptor(_DATAPARAMETER_DB)
_ELTWISEPARAMETER_ELTWISEOP = _descriptor.EnumDescriptor( _ELTWISEPARAMETER_ELTWISEOP = _descriptor.EnumDescriptor(
...@@ -414,8 +400,7 @@ _ELTWISEPARAMETER_ELTWISEOP = _descriptor.EnumDescriptor( ...@@ -414,8 +400,7 @@ _ELTWISEPARAMETER_ELTWISEOP = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=12106, serialized_start=12106,
serialized_end=12145, serialized_end=12145, )
)
_sym_db.RegisterEnumDescriptor(_ELTWISEPARAMETER_ELTWISEOP) _sym_db.RegisterEnumDescriptor(_ELTWISEPARAMETER_ELTWISEOP)
_HINGELOSSPARAMETER_NORM = _descriptor.EnumDescriptor( _HINGELOSSPARAMETER_NORM = _descriptor.EnumDescriptor(
...@@ -432,8 +417,7 @@ _HINGELOSSPARAMETER_NORM = _descriptor.EnumDescriptor( ...@@ -432,8 +417,7 @@ _HINGELOSSPARAMETER_NORM = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=12680, serialized_start=12680,
serialized_end=12702, serialized_end=12702, )
)
_sym_db.RegisterEnumDescriptor(_HINGELOSSPARAMETER_NORM) _sym_db.RegisterEnumDescriptor(_HINGELOSSPARAMETER_NORM)
_LRNPARAMETER_NORMREGION = _descriptor.EnumDescriptor( _LRNPARAMETER_NORMREGION = _descriptor.EnumDescriptor(
...@@ -458,8 +442,7 @@ _LRNPARAMETER_NORMREGION = _descriptor.EnumDescriptor( ...@@ -458,8 +442,7 @@ _LRNPARAMETER_NORMREGION = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=13569, serialized_start=13569,
serialized_end=13622, serialized_end=13622, )
)
_sym_db.RegisterEnumDescriptor(_LRNPARAMETER_NORMREGION) _sym_db.RegisterEnumDescriptor(_LRNPARAMETER_NORMREGION)
_LRNPARAMETER_ENGINE = _descriptor.EnumDescriptor( _LRNPARAMETER_ENGINE = _descriptor.EnumDescriptor(
...@@ -484,8 +467,7 @@ _LRNPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -484,8 +467,7 @@ _LRNPARAMETER_ENGINE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=10385, serialized_start=10385,
serialized_end=10428, serialized_end=10428, )
)
_sym_db.RegisterEnumDescriptor(_LRNPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_LRNPARAMETER_ENGINE)
_MULTIBOXLOSSPARAMETER_LOCLOSSTYPE = _descriptor.EnumDescriptor( _MULTIBOXLOSSPARAMETER_LOCLOSSTYPE = _descriptor.EnumDescriptor(
...@@ -506,8 +488,7 @@ _MULTIBOXLOSSPARAMETER_LOCLOSSTYPE = _descriptor.EnumDescriptor( ...@@ -506,8 +488,7 @@ _MULTIBOXLOSSPARAMETER_LOCLOSSTYPE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=14703, serialized_start=14703,
serialized_end=14739, serialized_end=14739, )
)
_sym_db.RegisterEnumDescriptor(_MULTIBOXLOSSPARAMETER_LOCLOSSTYPE) _sym_db.RegisterEnumDescriptor(_MULTIBOXLOSSPARAMETER_LOCLOSSTYPE)
_MULTIBOXLOSSPARAMETER_CONFLOSSTYPE = _descriptor.EnumDescriptor( _MULTIBOXLOSSPARAMETER_CONFLOSSTYPE = _descriptor.EnumDescriptor(
...@@ -532,8 +513,7 @@ _MULTIBOXLOSSPARAMETER_CONFLOSSTYPE = _descriptor.EnumDescriptor( ...@@ -532,8 +513,7 @@ _MULTIBOXLOSSPARAMETER_CONFLOSSTYPE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=14741, serialized_start=14741,
serialized_end=14782, serialized_end=14782, )
)
_sym_db.RegisterEnumDescriptor(_MULTIBOXLOSSPARAMETER_CONFLOSSTYPE) _sym_db.RegisterEnumDescriptor(_MULTIBOXLOSSPARAMETER_CONFLOSSTYPE)
_MULTIBOXLOSSPARAMETER_MATCHTYPE = _descriptor.EnumDescriptor( _MULTIBOXLOSSPARAMETER_MATCHTYPE = _descriptor.EnumDescriptor(
...@@ -558,8 +538,7 @@ _MULTIBOXLOSSPARAMETER_MATCHTYPE = _descriptor.EnumDescriptor( ...@@ -558,8 +538,7 @@ _MULTIBOXLOSSPARAMETER_MATCHTYPE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=14784, serialized_start=14784,
serialized_end=14830, serialized_end=14830, )
)
_sym_db.RegisterEnumDescriptor(_MULTIBOXLOSSPARAMETER_MATCHTYPE) _sym_db.RegisterEnumDescriptor(_MULTIBOXLOSSPARAMETER_MATCHTYPE)
_MULTIBOXLOSSPARAMETER_MININGTYPE = _descriptor.EnumDescriptor( _MULTIBOXLOSSPARAMETER_MININGTYPE = _descriptor.EnumDescriptor(
...@@ -586,8 +565,7 @@ _MULTIBOXLOSSPARAMETER_MININGTYPE = _descriptor.EnumDescriptor( ...@@ -586,8 +565,7 @@ _MULTIBOXLOSSPARAMETER_MININGTYPE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=14832, serialized_start=14832,
serialized_end=14890, serialized_end=14890, )
)
_sym_db.RegisterEnumDescriptor(_MULTIBOXLOSSPARAMETER_MININGTYPE) _sym_db.RegisterEnumDescriptor(_MULTIBOXLOSSPARAMETER_MININGTYPE)
_POOLINGPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor( _POOLINGPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor(
...@@ -610,8 +588,7 @@ _POOLINGPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor( ...@@ -610,8 +588,7 @@ _POOLINGPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=15561, serialized_start=15561,
serialized_end=15607, serialized_end=15607, )
)
_sym_db.RegisterEnumDescriptor(_POOLINGPARAMETER_POOLMETHOD) _sym_db.RegisterEnumDescriptor(_POOLINGPARAMETER_POOLMETHOD)
_POOLINGPARAMETER_ENGINE = _descriptor.EnumDescriptor( _POOLINGPARAMETER_ENGINE = _descriptor.EnumDescriptor(
...@@ -636,8 +613,7 @@ _POOLINGPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -636,8 +613,7 @@ _POOLINGPARAMETER_ENGINE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=10385, serialized_start=10385,
serialized_end=10428, serialized_end=10428, )
)
_sym_db.RegisterEnumDescriptor(_POOLINGPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_POOLINGPARAMETER_ENGINE)
_PRIORBOXPARAMETER_CODETYPE = _descriptor.EnumDescriptor( _PRIORBOXPARAMETER_CODETYPE = _descriptor.EnumDescriptor(
...@@ -668,8 +644,7 @@ _PRIORBOXPARAMETER_CODETYPE = _descriptor.EnumDescriptor( ...@@ -668,8 +644,7 @@ _PRIORBOXPARAMETER_CODETYPE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=15980, serialized_start=15980,
serialized_end=16036, serialized_end=16036, )
)
_sym_db.RegisterEnumDescriptor(_PRIORBOXPARAMETER_CODETYPE) _sym_db.RegisterEnumDescriptor(_PRIORBOXPARAMETER_CODETYPE)
_REDUCTIONPARAMETER_REDUCTIONOP = _descriptor.EnumDescriptor( _REDUCTIONPARAMETER_REDUCTIONOP = _descriptor.EnumDescriptor(
...@@ -691,8 +666,7 @@ _REDUCTIONPARAMETER_REDUCTIONOP = _descriptor.EnumDescriptor( ...@@ -691,8 +666,7 @@ _REDUCTIONPARAMETER_REDUCTIONOP = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=16459, serialized_start=16459,
serialized_end=16512, serialized_end=16512, )
)
_sym_db.RegisterEnumDescriptor(_REDUCTIONPARAMETER_REDUCTIONOP) _sym_db.RegisterEnumDescriptor(_REDUCTIONPARAMETER_REDUCTIONOP)
_RELUPARAMETER_ENGINE = _descriptor.EnumDescriptor( _RELUPARAMETER_ENGINE = _descriptor.EnumDescriptor(
...@@ -717,8 +691,7 @@ _RELUPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -717,8 +691,7 @@ _RELUPARAMETER_ENGINE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=10385, serialized_start=10385,
serialized_end=10428, serialized_end=10428, )
)
_sym_db.RegisterEnumDescriptor(_RELUPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_RELUPARAMETER_ENGINE)
_SIGMOIDPARAMETER_ENGINE = _descriptor.EnumDescriptor( _SIGMOIDPARAMETER_ENGINE = _descriptor.EnumDescriptor(
...@@ -743,8 +716,7 @@ _SIGMOIDPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -743,8 +716,7 @@ _SIGMOIDPARAMETER_ENGINE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=10385, serialized_start=10385,
serialized_end=10428, serialized_end=10428, )
)
_sym_db.RegisterEnumDescriptor(_SIGMOIDPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_SIGMOIDPARAMETER_ENGINE)
_SOFTMAXPARAMETER_ENGINE = _descriptor.EnumDescriptor( _SOFTMAXPARAMETER_ENGINE = _descriptor.EnumDescriptor(
...@@ -769,8 +741,7 @@ _SOFTMAXPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -769,8 +741,7 @@ _SOFTMAXPARAMETER_ENGINE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=10385, serialized_start=10385,
serialized_end=10428, serialized_end=10428, )
)
_sym_db.RegisterEnumDescriptor(_SOFTMAXPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_SOFTMAXPARAMETER_ENGINE)
_TANHPARAMETER_ENGINE = _descriptor.EnumDescriptor( _TANHPARAMETER_ENGINE = _descriptor.EnumDescriptor(
...@@ -795,8 +766,7 @@ _TANHPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -795,8 +766,7 @@ _TANHPARAMETER_ENGINE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=10385, serialized_start=10385,
serialized_end=10428, serialized_end=10428, )
)
_sym_db.RegisterEnumDescriptor(_TANHPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_TANHPARAMETER_ENGINE)
_VIDEODATAPARAMETER_VIDEOTYPE = _descriptor.EnumDescriptor( _VIDEODATAPARAMETER_VIDEOTYPE = _descriptor.EnumDescriptor(
...@@ -818,8 +788,7 @@ _VIDEODATAPARAMETER_VIDEOTYPE = _descriptor.EnumDescriptor( ...@@ -818,8 +788,7 @@ _VIDEODATAPARAMETER_VIDEOTYPE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=17621, serialized_start=17621,
serialized_end=17655, serialized_end=17655, )
)
_sym_db.RegisterEnumDescriptor(_VIDEODATAPARAMETER_VIDEOTYPE) _sym_db.RegisterEnumDescriptor(_VIDEODATAPARAMETER_VIDEOTYPE)
_SPPPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor( _SPPPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor(
...@@ -842,8 +811,7 @@ _SPPPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor( ...@@ -842,8 +811,7 @@ _SPPPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=15561, serialized_start=15561,
serialized_end=15607, serialized_end=15607, )
)
_sym_db.RegisterEnumDescriptor(_SPPPARAMETER_POOLMETHOD) _sym_db.RegisterEnumDescriptor(_SPPPARAMETER_POOLMETHOD)
_SPPPARAMETER_ENGINE = _descriptor.EnumDescriptor( _SPPPARAMETER_ENGINE = _descriptor.EnumDescriptor(
...@@ -868,8 +836,7 @@ _SPPPARAMETER_ENGINE = _descriptor.EnumDescriptor( ...@@ -868,8 +836,7 @@ _SPPPARAMETER_ENGINE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=10385, serialized_start=10385,
serialized_end=10428, serialized_end=10428, )
)
_sym_db.RegisterEnumDescriptor(_SPPPARAMETER_ENGINE) _sym_db.RegisterEnumDescriptor(_SPPPARAMETER_ENGINE)
_V1LAYERPARAMETER_LAYERTYPE = _descriptor.EnumDescriptor( _V1LAYERPARAMETER_LAYERTYPE = _descriptor.EnumDescriptor(
...@@ -1101,8 +1068,7 @@ _V1LAYERPARAMETER_LAYERTYPE = _descriptor.EnumDescriptor( ...@@ -1101,8 +1068,7 @@ _V1LAYERPARAMETER_LAYERTYPE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=20104, serialized_start=20104,
serialized_end=20704, serialized_end=20704, )
)
_sym_db.RegisterEnumDescriptor(_V1LAYERPARAMETER_LAYERTYPE) _sym_db.RegisterEnumDescriptor(_V1LAYERPARAMETER_LAYERTYPE)
_V1LAYERPARAMETER_DIMCHECKMODE = _descriptor.EnumDescriptor( _V1LAYERPARAMETER_DIMCHECKMODE = _descriptor.EnumDescriptor(
...@@ -1127,8 +1093,7 @@ _V1LAYERPARAMETER_DIMCHECKMODE = _descriptor.EnumDescriptor( ...@@ -1127,8 +1093,7 @@ _V1LAYERPARAMETER_DIMCHECKMODE = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=4131, serialized_start=4131,
serialized_end=4173, serialized_end=4173, )
)
_sym_db.RegisterEnumDescriptor(_V1LAYERPARAMETER_DIMCHECKMODE) _sym_db.RegisterEnumDescriptor(_V1LAYERPARAMETER_DIMCHECKMODE)
_V0LAYERPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor( _V0LAYERPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor(
...@@ -1151,8 +1116,7 @@ _V0LAYERPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor( ...@@ -1151,8 +1116,7 @@ _V0LAYERPARAMETER_POOLMETHOD = _descriptor.EnumDescriptor(
containing_type=None, containing_type=None,
serialized_options=None, serialized_options=None,
serialized_start=15561, serialized_start=15561,
serialized_end=15607, serialized_end=15607, )
)
_sym_db.RegisterEnumDescriptor(_V0LAYERPARAMETER_POOLMETHOD) _sym_db.RegisterEnumDescriptor(_V0LAYERPARAMETER_POOLMETHOD)
_BLOBSHAPE = _descriptor.Descriptor( _BLOBSHAPE = _descriptor.Descriptor(
...@@ -1189,8 +1153,7 @@ _BLOBSHAPE = _descriptor.Descriptor( ...@@ -1189,8 +1153,7 @@ _BLOBSHAPE = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=22, serialized_start=22,
serialized_end=50, serialized_end=50, )
)
_BLOBPROTO = _descriptor.Descriptor( _BLOBPROTO = _descriptor.Descriptor(
name='BlobProto', name='BlobProto',
...@@ -1362,8 +1325,7 @@ _BLOBPROTO = _descriptor.Descriptor( ...@@ -1362,8 +1325,7 @@ _BLOBPROTO = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=53, serialized_start=53,
serialized_end=257, serialized_end=257, )
)
_BLOBPROTOVECTOR = _descriptor.Descriptor( _BLOBPROTOVECTOR = _descriptor.Descriptor(
name='BlobProtoVector', name='BlobProtoVector',
...@@ -1399,8 +1361,7 @@ _BLOBPROTOVECTOR = _descriptor.Descriptor( ...@@ -1399,8 +1361,7 @@ _BLOBPROTOVECTOR = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=259, serialized_start=259,
serialized_end=309, serialized_end=309, )
)
_DATUM = _descriptor.Descriptor( _DATUM = _descriptor.Descriptor(
name='Datum', name='Datum',
...@@ -1538,8 +1499,7 @@ _DATUM = _descriptor.Descriptor( ...@@ -1538,8 +1499,7 @@ _DATUM = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=312, serialized_start=312,
serialized_end=441, serialized_end=441, )
)
_LABELMAPITEM = _descriptor.Descriptor( _LABELMAPITEM = _descriptor.Descriptor(
name='LabelMapItem', name='LabelMapItem',
...@@ -1609,8 +1569,7 @@ _LABELMAPITEM = _descriptor.Descriptor( ...@@ -1609,8 +1569,7 @@ _LABELMAPITEM = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=443, serialized_start=443,
serialized_end=508, serialized_end=508, )
)
_LABELMAP = _descriptor.Descriptor( _LABELMAP = _descriptor.Descriptor(
name='LabelMap', name='LabelMap',
...@@ -1646,8 +1605,7 @@ _LABELMAP = _descriptor.Descriptor( ...@@ -1646,8 +1605,7 @@ _LABELMAP = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=510, serialized_start=510,
serialized_end=555, serialized_end=555, )
)
_SAMPLER = _descriptor.Descriptor( _SAMPLER = _descriptor.Descriptor(
name='Sampler', name='Sampler',
...@@ -1734,8 +1692,7 @@ _SAMPLER = _descriptor.Descriptor( ...@@ -1734,8 +1692,7 @@ _SAMPLER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=557, serialized_start=557,
serialized_end=668, serialized_end=668, )
)
_SAMPLECONSTRAINT = _descriptor.Descriptor( _SAMPLECONSTRAINT = _descriptor.Descriptor(
name='SampleConstraint', name='SampleConstraint',
...@@ -1856,8 +1813,7 @@ _SAMPLECONSTRAINT = _descriptor.Descriptor( ...@@ -1856,8 +1813,7 @@ _SAMPLECONSTRAINT = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=671, serialized_start=671,
serialized_end=863, serialized_end=863, )
)
_BATCHSAMPLER = _descriptor.Descriptor( _BATCHSAMPLER = _descriptor.Descriptor(
name='BatchSampler', name='BatchSampler',
...@@ -1961,8 +1917,7 @@ _BATCHSAMPLER = _descriptor.Descriptor( ...@@ -1961,8 +1917,7 @@ _BATCHSAMPLER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=866, serialized_start=866,
serialized_end=1044, serialized_end=1044, )
)
_EMITCONSTRAINT = _descriptor.Descriptor( _EMITCONSTRAINT = _descriptor.Descriptor(
name='EmitConstraint', name='EmitConstraint',
...@@ -2008,17 +1963,14 @@ _EMITCONSTRAINT = _descriptor.Descriptor( ...@@ -2008,17 +1963,14 @@ _EMITCONSTRAINT = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_EMITCONSTRAINT_EMITTYPE, ],
_EMITCONSTRAINT_EMITTYPE,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=1047, serialized_start=1047,
serialized_end=1185, serialized_end=1185, )
)
_NORMALIZEDBBOX = _descriptor.Descriptor( _NORMALIZEDBBOX = _descriptor.Descriptor(
name='NormalizedBBox', name='NormalizedBBox',
...@@ -2173,8 +2125,7 @@ _NORMALIZEDBBOX = _descriptor.Descriptor( ...@@ -2173,8 +2125,7 @@ _NORMALIZEDBBOX = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=1188, serialized_start=1188,
serialized_end=1323, serialized_end=1323, )
)
_ANNOTATION = _descriptor.Descriptor( _ANNOTATION = _descriptor.Descriptor(
name='Annotation', name='Annotation',
...@@ -2227,8 +2178,7 @@ _ANNOTATION = _descriptor.Descriptor( ...@@ -2227,8 +2178,7 @@ _ANNOTATION = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=1325, serialized_start=1325,
serialized_end=1398, serialized_end=1398, )
)
_ANNOTATIONGROUP = _descriptor.Descriptor( _ANNOTATIONGROUP = _descriptor.Descriptor(
name='AnnotationGroup', name='AnnotationGroup',
...@@ -2281,8 +2231,7 @@ _ANNOTATIONGROUP = _descriptor.Descriptor( ...@@ -2281,8 +2231,7 @@ _ANNOTATIONGROUP = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=1400, serialized_start=1400,
serialized_end=1477, serialized_end=1477, )
)
_ANNOTATEDDATUM = _descriptor.Descriptor( _ANNOTATEDDATUM = _descriptor.Descriptor(
name='AnnotatedDatum', name='AnnotatedDatum',
...@@ -2345,17 +2294,14 @@ _ANNOTATEDDATUM = _descriptor.Descriptor( ...@@ -2345,17 +2294,14 @@ _ANNOTATEDDATUM = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_ANNOTATEDDATUM_ANNOTATIONTYPE, ],
_ANNOTATEDDATUM_ANNOTATIONTYPE,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=1480, serialized_start=1480,
serialized_end=1655, serialized_end=1655, )
)
_FILLERPARAMETER = _descriptor.Descriptor( _FILLERPARAMETER = _descriptor.Descriptor(
name='FillerParameter', name='FillerParameter',
...@@ -2503,17 +2449,14 @@ _FILLERPARAMETER = _descriptor.Descriptor( ...@@ -2503,17 +2449,14 @@ _FILLERPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_FILLERPARAMETER_VARIANCENORM, ],
_FILLERPARAMETER_VARIANCENORM,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=1658, serialized_start=1658,
serialized_end=1924, serialized_end=1924, )
)
_NETPARAMETER = _descriptor.Descriptor( _NETPARAMETER = _descriptor.Descriptor(
name='NetParameter', name='NetParameter',
...@@ -2685,8 +2628,7 @@ _NETPARAMETER = _descriptor.Descriptor( ...@@ -2685,8 +2628,7 @@ _NETPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=1927, serialized_start=1927,
serialized_end=2197, serialized_end=2197, )
)
_SOLVERPARAMETER = _descriptor.Descriptor( _SOLVERPARAMETER = _descriptor.Descriptor(
name='SolverParameter', name='SolverParameter',
...@@ -3457,8 +3399,7 @@ _SOLVERPARAMETER = _descriptor.Descriptor( ...@@ -3457,8 +3399,7 @@ _SOLVERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=2200, serialized_start=2200,
serialized_end=3642, serialized_end=3642, )
)
_SOLVERSTATE = _descriptor.Descriptor( _SOLVERSTATE = _descriptor.Descriptor(
name='SolverState', name='SolverState',
...@@ -3579,8 +3520,7 @@ _SOLVERSTATE = _descriptor.Descriptor( ...@@ -3579,8 +3520,7 @@ _SOLVERSTATE = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=3645, serialized_start=3645,
serialized_end=3810, serialized_end=3810, )
)
_NETSTATE = _descriptor.Descriptor( _NETSTATE = _descriptor.Descriptor(
name='NetState', name='NetState',
...@@ -3650,8 +3590,7 @@ _NETSTATE = _descriptor.Descriptor( ...@@ -3650,8 +3590,7 @@ _NETSTATE = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=3812, serialized_start=3812,
serialized_end=3890, serialized_end=3890, )
)
_NETSTATERULE = _descriptor.Descriptor( _NETSTATERULE = _descriptor.Descriptor(
name='NetStateRule', name='NetStateRule',
...@@ -3755,8 +3694,7 @@ _NETSTATERULE = _descriptor.Descriptor( ...@@ -3755,8 +3694,7 @@ _NETSTATERULE = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=3892, serialized_start=3892,
serialized_end=4007, serialized_end=4007, )
)
_PARAMSPEC = _descriptor.Descriptor( _PARAMSPEC = _descriptor.Descriptor(
name='ParamSpec', name='ParamSpec',
...@@ -3836,17 +3774,14 @@ _PARAMSPEC = _descriptor.Descriptor( ...@@ -3836,17 +3774,14 @@ _PARAMSPEC = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_PARAMSPEC_DIMCHECKMODE, ],
_PARAMSPEC_DIMCHECKMODE,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=4010, serialized_start=4010,
serialized_end=4173, serialized_end=4173, )
)
_LAYERPARAMETER = _descriptor.Descriptor( _LAYERPARAMETER = _descriptor.Descriptor(
name='LayerParameter', name='LayerParameter',
...@@ -5004,8 +4939,7 @@ _LAYERPARAMETER = _descriptor.Descriptor( ...@@ -5004,8 +4939,7 @@ _LAYERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=4176, serialized_start=4176,
serialized_end=7263, serialized_end=7263, )
)
_TRANSFORMATIONPARAMETER = _descriptor.Descriptor( _TRANSFORMATIONPARAMETER = _descriptor.Descriptor(
name='TransformationParameter', name='TransformationParameter',
...@@ -5262,8 +5196,7 @@ _TRANSFORMATIONPARAMETER = _descriptor.Descriptor( ...@@ -5262,8 +5196,7 @@ _TRANSFORMATIONPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=7266, serialized_start=7266,
serialized_end=7724, serialized_end=7724, )
)
_RESIZEPARAMETER = _descriptor.Descriptor( _RESIZEPARAMETER = _descriptor.Descriptor(
name='ResizeParameter', name='ResizeParameter',
...@@ -5439,8 +5372,7 @@ _RESIZEPARAMETER = _descriptor.Descriptor( ...@@ -5439,8 +5372,7 @@ _RESIZEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=7727, serialized_start=7727,
serialized_end=8255, serialized_end=8255, )
)
_SALTPEPPERPARAMETER = _descriptor.Descriptor( _SALTPEPPERPARAMETER = _descriptor.Descriptor(
name='SaltPepperParameter', name='SaltPepperParameter',
...@@ -5493,8 +5425,7 @@ _SALTPEPPERPARAMETER = _descriptor.Descriptor( ...@@ -5493,8 +5425,7 @@ _SALTPEPPERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=8257, serialized_start=8257,
serialized_end=8314, serialized_end=8314, )
)
_NOISEPARAMETER = _descriptor.Descriptor( _NOISEPARAMETER = _descriptor.Descriptor(
name='NoiseParameter', name='NoiseParameter',
...@@ -5734,8 +5665,7 @@ _NOISEPARAMETER = _descriptor.Descriptor( ...@@ -5734,8 +5665,7 @@ _NOISEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=8317, serialized_start=8317,
serialized_end=8683, serialized_end=8683, )
)
_DISTORTIONPARAMETER = _descriptor.Descriptor( _DISTORTIONPARAMETER = _descriptor.Descriptor(
name='DistortionParameter', name='DistortionParameter',
...@@ -5941,8 +5871,7 @@ _DISTORTIONPARAMETER = _descriptor.Descriptor( ...@@ -5941,8 +5871,7 @@ _DISTORTIONPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=8686, serialized_start=8686,
serialized_end=9003, serialized_end=9003, )
)
_EXPANSIONPARAMETER = _descriptor.Descriptor( _EXPANSIONPARAMETER = _descriptor.Descriptor(
name='ExpansionParameter', name='ExpansionParameter',
...@@ -5995,8 +5924,7 @@ _EXPANSIONPARAMETER = _descriptor.Descriptor( ...@@ -5995,8 +5924,7 @@ _EXPANSIONPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=9005, serialized_start=9005,
serialized_end=9071, serialized_end=9071, )
)
_LOSSPARAMETER = _descriptor.Descriptor( _LOSSPARAMETER = _descriptor.Descriptor(
name='LossParameter', name='LossParameter',
...@@ -6059,17 +5987,14 @@ _LOSSPARAMETER = _descriptor.Descriptor( ...@@ -6059,17 +5987,14 @@ _LOSSPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_LOSSPARAMETER_NORMALIZATIONMODE, ],
_LOSSPARAMETER_NORMALIZATIONMODE,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=9074, serialized_start=9074,
serialized_end=9268, serialized_end=9268, )
)
_ACCURACYPARAMETER = _descriptor.Descriptor( _ACCURACYPARAMETER = _descriptor.Descriptor(
name='AccuracyParameter', name='AccuracyParameter',
...@@ -6139,8 +6064,7 @@ _ACCURACYPARAMETER = _descriptor.Descriptor( ...@@ -6139,8 +6064,7 @@ _ACCURACYPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=9270, serialized_start=9270,
serialized_end=9346, serialized_end=9346, )
)
_ANNOTATEDDATAPARAMETER = _descriptor.Descriptor( _ANNOTATEDDATAPARAMETER = _descriptor.Descriptor(
name='AnnotatedDataParameter', name='AnnotatedDataParameter',
...@@ -6210,8 +6134,7 @@ _ANNOTATEDDATAPARAMETER = _descriptor.Descriptor( ...@@ -6210,8 +6134,7 @@ _ANNOTATEDDATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=9349, serialized_start=9349,
serialized_end=9498, serialized_end=9498, )
)
_ARGMAXPARAMETER = _descriptor.Descriptor( _ARGMAXPARAMETER = _descriptor.Descriptor(
name='ArgMaxParameter', name='ArgMaxParameter',
...@@ -6281,8 +6204,7 @@ _ARGMAXPARAMETER = _descriptor.Descriptor( ...@@ -6281,8 +6204,7 @@ _ARGMAXPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=9500, serialized_start=9500,
serialized_end=9577, serialized_end=9577, )
)
_CONCATPARAMETER = _descriptor.Descriptor( _CONCATPARAMETER = _descriptor.Descriptor(
name='ConcatParameter', name='ConcatParameter',
...@@ -6335,8 +6257,7 @@ _CONCATPARAMETER = _descriptor.Descriptor( ...@@ -6335,8 +6257,7 @@ _CONCATPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=9579, serialized_start=9579,
serialized_end=9636, serialized_end=9636, )
)
_BATCHNORMPARAMETER = _descriptor.Descriptor( _BATCHNORMPARAMETER = _descriptor.Descriptor(
name='BatchNormParameter', name='BatchNormParameter',
...@@ -6406,8 +6327,7 @@ _BATCHNORMPARAMETER = _descriptor.Descriptor( ...@@ -6406,8 +6327,7 @@ _BATCHNORMPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=9638, serialized_start=9638,
serialized_end=9744, serialized_end=9744, )
)
_BIASPARAMETER = _descriptor.Descriptor( _BIASPARAMETER = _descriptor.Descriptor(
name='BiasParameter', name='BiasParameter',
...@@ -6477,8 +6397,7 @@ _BIASPARAMETER = _descriptor.Descriptor( ...@@ -6477,8 +6397,7 @@ _BIASPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=9746, serialized_start=9746,
serialized_end=9839, serialized_end=9839, )
)
_CONTRASTIVELOSSPARAMETER = _descriptor.Descriptor( _CONTRASTIVELOSSPARAMETER = _descriptor.Descriptor(
name='ContrastiveLossParameter', name='ContrastiveLossParameter',
...@@ -6531,8 +6450,7 @@ _CONTRASTIVELOSSPARAMETER = _descriptor.Descriptor( ...@@ -6531,8 +6450,7 @@ _CONTRASTIVELOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=9841, serialized_start=9841,
serialized_end=9917, serialized_end=9917, )
)
_CONVOLUTIONPARAMETER = _descriptor.Descriptor( _CONVOLUTIONPARAMETER = _descriptor.Descriptor(
name='ConvolutionParameter', name='ConvolutionParameter',
...@@ -6850,17 +6768,14 @@ _CONVOLUTIONPARAMETER = _descriptor.Descriptor( ...@@ -6850,17 +6768,14 @@ _CONVOLUTIONPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_CONVOLUTIONPARAMETER_ENGINE, ],
_CONVOLUTIONPARAMETER_ENGINE,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=9920, serialized_start=9920,
serialized_end=10428, serialized_end=10428, )
)
_CROPPARAMETER = _descriptor.Descriptor( _CROPPARAMETER = _descriptor.Descriptor(
name='CropParameter', name='CropParameter',
...@@ -6913,8 +6828,7 @@ _CROPPARAMETER = _descriptor.Descriptor( ...@@ -6913,8 +6828,7 @@ _CROPPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=10430, serialized_start=10430,
serialized_end=10478, serialized_end=10478, )
)
_DATAPARAMETER = _descriptor.Descriptor( _DATAPARAMETER = _descriptor.Descriptor(
name='DataParameter', name='DataParameter',
...@@ -7096,17 +7010,14 @@ _DATAPARAMETER = _descriptor.Descriptor( ...@@ -7096,17 +7010,14 @@ _DATAPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_DATAPARAMETER_DB, ],
_DATAPARAMETER_DB,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=10481, serialized_start=10481,
serialized_end=10773, serialized_end=10773, )
)
_DETECTIONEVALUATEPARAMETER = _descriptor.Descriptor( _DETECTIONEVALUATEPARAMETER = _descriptor.Descriptor(
name='DetectionEvaluateParameter', name='DetectionEvaluateParameter',
...@@ -7227,8 +7138,7 @@ _DETECTIONEVALUATEPARAMETER = _descriptor.Descriptor( ...@@ -7227,8 +7138,7 @@ _DETECTIONEVALUATEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=10776, serialized_start=10776,
serialized_end=10996, serialized_end=10996, )
)
_NONMAXIMUMSUPPRESSIONPARAMETER = _descriptor.Descriptor( _NONMAXIMUMSUPPRESSIONPARAMETER = _descriptor.Descriptor(
name='NonMaximumSuppressionParameter', name='NonMaximumSuppressionParameter',
...@@ -7298,8 +7208,7 @@ _NONMAXIMUMSUPPRESSIONPARAMETER = _descriptor.Descriptor( ...@@ -7298,8 +7208,7 @@ _NONMAXIMUMSUPPRESSIONPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=10998, serialized_start=10998,
serialized_end=11089, serialized_end=11089, )
)
_SAVEOUTPUTPARAMETER = _descriptor.Descriptor( _SAVEOUTPUTPARAMETER = _descriptor.Descriptor(
name='SaveOutputParameter', name='SaveOutputParameter',
...@@ -7437,8 +7346,7 @@ _SAVEOUTPUTPARAMETER = _descriptor.Descriptor( ...@@ -7437,8 +7346,7 @@ _SAVEOUTPUTPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=11092, serialized_start=11092,
serialized_end=11308, serialized_end=11308, )
)
_DETECTIONOUTPUTPARAMETER = _descriptor.Descriptor( _DETECTIONOUTPUTPARAMETER = _descriptor.Descriptor(
name='DetectionOutputParameter', name='DetectionOutputParameter',
...@@ -7551,8 +7459,7 @@ _DETECTIONOUTPUTPARAMETER = _descriptor.Descriptor( ...@@ -7551,8 +7459,7 @@ _DETECTIONOUTPUTPARAMETER = _descriptor.Descriptor(
file=DESCRIPTOR), file=DESCRIPTOR),
_descriptor.FieldDescriptor( _descriptor.FieldDescriptor(
name='variance_encoded_in_target', name='variance_encoded_in_target',
full_name= full_name='caffe.DetectionOutputParameter.variance_encoded_in_target',
'caffe.DetectionOutputParameter.variance_encoded_in_target',
index=6, index=6,
number=8, number=8,
type=8, type=8,
...@@ -7662,8 +7569,7 @@ _DETECTIONOUTPUTPARAMETER = _descriptor.Descriptor( ...@@ -7662,8 +7569,7 @@ _DETECTIONOUTPUTPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=11311, serialized_start=11311,
serialized_end=11766, serialized_end=11766, )
)
_DROPOUTPARAMETER = _descriptor.Descriptor( _DROPOUTPARAMETER = _descriptor.Descriptor(
name='DropoutParameter', name='DropoutParameter',
...@@ -7699,8 +7605,7 @@ _DROPOUTPARAMETER = _descriptor.Descriptor( ...@@ -7699,8 +7605,7 @@ _DROPOUTPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=11768, serialized_start=11768,
serialized_end=11814, serialized_end=11814, )
)
_DUMMYDATAPARAMETER = _descriptor.Descriptor( _DUMMYDATAPARAMETER = _descriptor.Descriptor(
name='DummyDataParameter', name='DummyDataParameter',
...@@ -7821,8 +7726,7 @@ _DUMMYDATAPARAMETER = _descriptor.Descriptor( ...@@ -7821,8 +7726,7 @@ _DUMMYDATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=11817, serialized_start=11817,
serialized_end=11977, serialized_end=11977, )
)
_ELTWISEPARAMETER = _descriptor.Descriptor( _ELTWISEPARAMETER = _descriptor.Descriptor(
name='EltwiseParameter', name='EltwiseParameter',
...@@ -7885,17 +7789,14 @@ _ELTWISEPARAMETER = _descriptor.Descriptor( ...@@ -7885,17 +7789,14 @@ _ELTWISEPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_ELTWISEPARAMETER_ELTWISEOP, ],
_ELTWISEPARAMETER_ELTWISEOP,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=11980, serialized_start=11980,
serialized_end=12145, serialized_end=12145, )
)
_ELUPARAMETER = _descriptor.Descriptor( _ELUPARAMETER = _descriptor.Descriptor(
name='ELUParameter', name='ELUParameter',
...@@ -7931,8 +7832,7 @@ _ELUPARAMETER = _descriptor.Descriptor( ...@@ -7931,8 +7832,7 @@ _ELUPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=12147, serialized_start=12147,
serialized_end=12179, serialized_end=12179, )
)
_EMBEDPARAMETER = _descriptor.Descriptor( _EMBEDPARAMETER = _descriptor.Descriptor(
name='EmbedParameter', name='EmbedParameter',
...@@ -8036,8 +7936,7 @@ _EMBEDPARAMETER = _descriptor.Descriptor( ...@@ -8036,8 +7936,7 @@ _EMBEDPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=12182, serialized_start=12182,
serialized_end=12354, serialized_end=12354, )
)
_EXPPARAMETER = _descriptor.Descriptor( _EXPPARAMETER = _descriptor.Descriptor(
name='ExpParameter', name='ExpParameter',
...@@ -8107,8 +8006,7 @@ _EXPPARAMETER = _descriptor.Descriptor( ...@@ -8107,8 +8006,7 @@ _EXPPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=12356, serialized_start=12356,
serialized_end=12424, serialized_end=12424, )
)
_FLATTENPARAMETER = _descriptor.Descriptor( _FLATTENPARAMETER = _descriptor.Descriptor(
name='FlattenParameter', name='FlattenParameter',
...@@ -8161,8 +8059,7 @@ _FLATTENPARAMETER = _descriptor.Descriptor( ...@@ -8161,8 +8059,7 @@ _FLATTENPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=12426, serialized_start=12426,
serialized_end=12483, serialized_end=12483, )
)
_HDF5DATAPARAMETER = _descriptor.Descriptor( _HDF5DATAPARAMETER = _descriptor.Descriptor(
name='HDF5DataParameter', name='HDF5DataParameter',
...@@ -8232,8 +8129,7 @@ _HDF5DATAPARAMETER = _descriptor.Descriptor( ...@@ -8232,8 +8129,7 @@ _HDF5DATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=12485, serialized_start=12485,
serialized_end=12564, serialized_end=12564, )
)
_HDF5OUTPUTPARAMETER = _descriptor.Descriptor( _HDF5OUTPUTPARAMETER = _descriptor.Descriptor(
name='HDF5OutputParameter', name='HDF5OutputParameter',
...@@ -8269,8 +8165,7 @@ _HDF5OUTPUTPARAMETER = _descriptor.Descriptor( ...@@ -8269,8 +8165,7 @@ _HDF5OUTPUTPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=12566, serialized_start=12566,
serialized_end=12606, serialized_end=12606, )
)
_HINGELOSSPARAMETER = _descriptor.Descriptor( _HINGELOSSPARAMETER = _descriptor.Descriptor(
name='HingeLossParameter', name='HingeLossParameter',
...@@ -8299,17 +8194,14 @@ _HINGELOSSPARAMETER = _descriptor.Descriptor( ...@@ -8299,17 +8194,14 @@ _HINGELOSSPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_HINGELOSSPARAMETER_NORM, ],
_HINGELOSSPARAMETER_NORM,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=12608, serialized_start=12608,
serialized_end=12702, serialized_end=12702, )
)
_IMAGEDATAPARAMETER = _descriptor.Descriptor( _IMAGEDATAPARAMETER = _descriptor.Descriptor(
name='ImageDataParameter', name='ImageDataParameter',
...@@ -8532,8 +8424,7 @@ _IMAGEDATAPARAMETER = _descriptor.Descriptor( ...@@ -8532,8 +8424,7 @@ _IMAGEDATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=12705, serialized_start=12705,
serialized_end=12984, serialized_end=12984, )
)
_INFOGAINLOSSPARAMETER = _descriptor.Descriptor( _INFOGAINLOSSPARAMETER = _descriptor.Descriptor(
name='InfogainLossParameter', name='InfogainLossParameter',
...@@ -8569,8 +8460,7 @@ _INFOGAINLOSSPARAMETER = _descriptor.Descriptor( ...@@ -8569,8 +8460,7 @@ _INFOGAINLOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=12986, serialized_start=12986,
serialized_end=13025, serialized_end=13025, )
)
_INNERPRODUCTPARAMETER = _descriptor.Descriptor( _INNERPRODUCTPARAMETER = _descriptor.Descriptor(
name='InnerProductParameter', name='InnerProductParameter',
...@@ -8691,8 +8581,7 @@ _INNERPRODUCTPARAMETER = _descriptor.Descriptor( ...@@ -8691,8 +8581,7 @@ _INNERPRODUCTPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=13028, serialized_start=13028,
serialized_end=13231, serialized_end=13231, )
)
_INPUTPARAMETER = _descriptor.Descriptor( _INPUTPARAMETER = _descriptor.Descriptor(
name='InputParameter', name='InputParameter',
...@@ -8728,8 +8617,7 @@ _INPUTPARAMETER = _descriptor.Descriptor( ...@@ -8728,8 +8617,7 @@ _INPUTPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=13233, serialized_start=13233,
serialized_end=13282, serialized_end=13282, )
)
_LOGPARAMETER = _descriptor.Descriptor( _LOGPARAMETER = _descriptor.Descriptor(
name='LogParameter', name='LogParameter',
...@@ -8799,8 +8687,7 @@ _LOGPARAMETER = _descriptor.Descriptor( ...@@ -8799,8 +8687,7 @@ _LOGPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=13284, serialized_start=13284,
serialized_end=13352, serialized_end=13352, )
)
_LRNPARAMETER = _descriptor.Descriptor( _LRNPARAMETER = _descriptor.Descriptor(
name='LRNParameter', name='LRNParameter',
...@@ -8924,8 +8811,7 @@ _LRNPARAMETER = _descriptor.Descriptor( ...@@ -8924,8 +8811,7 @@ _LRNPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=13355, serialized_start=13355,
serialized_end=13667, serialized_end=13667, )
)
_MEMORYDATAPARAMETER = _descriptor.Descriptor( _MEMORYDATAPARAMETER = _descriptor.Descriptor(
name='MemoryDataParameter', name='MemoryDataParameter',
...@@ -9012,8 +8898,7 @@ _MEMORYDATAPARAMETER = _descriptor.Descriptor( ...@@ -9012,8 +8898,7 @@ _MEMORYDATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=13669, serialized_start=13669,
serialized_end=13759, serialized_end=13759, )
)
_MULTIBOXLOSSPARAMETER = _descriptor.Descriptor( _MULTIBOXLOSSPARAMETER = _descriptor.Descriptor(
name='MultiBoxLossParameter', name='MultiBoxLossParameter',
...@@ -9411,8 +9296,7 @@ _MULTIBOXLOSSPARAMETER = _descriptor.Descriptor( ...@@ -9411,8 +9296,7 @@ _MULTIBOXLOSSPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=13762, serialized_start=13762,
serialized_end=14890, serialized_end=14890, )
)
_MVNPARAMETER = _descriptor.Descriptor( _MVNPARAMETER = _descriptor.Descriptor(
name='MVNParameter', name='MVNParameter',
...@@ -9482,8 +9366,7 @@ _MVNPARAMETER = _descriptor.Descriptor( ...@@ -9482,8 +9366,7 @@ _MVNPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=14892, serialized_start=14892,
serialized_end=14992, serialized_end=14992, )
)
_NORMALIZEPARAMETER = _descriptor.Descriptor( _NORMALIZEPARAMETER = _descriptor.Descriptor(
name='NormalizeParameter', name='NormalizeParameter',
...@@ -9570,8 +9453,7 @@ _NORMALIZEPARAMETER = _descriptor.Descriptor( ...@@ -9570,8 +9453,7 @@ _NORMALIZEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=14995, serialized_start=14995,
serialized_end=15141, serialized_end=15141, )
)
_PARAMETERPARAMETER = _descriptor.Descriptor( _PARAMETERPARAMETER = _descriptor.Descriptor(
name='ParameterParameter', name='ParameterParameter',
...@@ -9607,8 +9489,7 @@ _PARAMETERPARAMETER = _descriptor.Descriptor( ...@@ -9607,8 +9489,7 @@ _PARAMETERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=15143, serialized_start=15143,
serialized_end=15196, serialized_end=15196, )
)
_PERMUTEPARAMETER = _descriptor.Descriptor( _PERMUTEPARAMETER = _descriptor.Descriptor(
name='PermuteParameter', name='PermuteParameter',
...@@ -9644,8 +9525,7 @@ _PERMUTEPARAMETER = _descriptor.Descriptor( ...@@ -9644,8 +9525,7 @@ _PERMUTEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=15198, serialized_start=15198,
serialized_end=15231, serialized_end=15231, )
)
_POOLINGPARAMETER = _descriptor.Descriptor( _POOLINGPARAMETER = _descriptor.Descriptor(
name='PoolingParameter', name='PoolingParameter',
...@@ -9871,8 +9751,7 @@ _POOLINGPARAMETER = _descriptor.Descriptor( ...@@ -9871,8 +9751,7 @@ _POOLINGPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=15234, serialized_start=15234,
serialized_end=15652, serialized_end=15652, )
)
_POWERPARAMETER = _descriptor.Descriptor( _POWERPARAMETER = _descriptor.Descriptor(
name='PowerParameter', name='PowerParameter',
...@@ -9942,8 +9821,7 @@ _POWERPARAMETER = _descriptor.Descriptor( ...@@ -9942,8 +9821,7 @@ _POWERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=15654, serialized_start=15654,
serialized_end=15724, serialized_end=15724, )
)
_PRIORBOXPARAMETER = _descriptor.Descriptor( _PRIORBOXPARAMETER = _descriptor.Descriptor(
name='PriorBoxParameter', name='PriorBoxParameter',
...@@ -10176,17 +10054,14 @@ _PRIORBOXPARAMETER = _descriptor.Descriptor( ...@@ -10176,17 +10054,14 @@ _PRIORBOXPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_PRIORBOXPARAMETER_CODETYPE, ],
_PRIORBOXPARAMETER_CODETYPE,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=15727, serialized_start=15727,
serialized_end=16036, serialized_end=16036, )
)
_PYTHONPARAMETER = _descriptor.Descriptor( _PYTHONPARAMETER = _descriptor.Descriptor(
name='PythonParameter', name='PythonParameter',
...@@ -10273,8 +10148,7 @@ _PYTHONPARAMETER = _descriptor.Descriptor( ...@@ -10273,8 +10148,7 @@ _PYTHONPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=16038, serialized_start=16038,
serialized_end=16141, serialized_end=16141, )
)
_RECURRENTPARAMETER = _descriptor.Descriptor( _RECURRENTPARAMETER = _descriptor.Descriptor(
name='RecurrentParameter', name='RecurrentParameter',
...@@ -10378,8 +10252,7 @@ _RECURRENTPARAMETER = _descriptor.Descriptor( ...@@ -10378,8 +10252,7 @@ _RECURRENTPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=16144, serialized_start=16144,
serialized_end=16336, serialized_end=16336, )
)
_REDUCTIONPARAMETER = _descriptor.Descriptor( _REDUCTIONPARAMETER = _descriptor.Descriptor(
name='ReductionParameter', name='ReductionParameter',
...@@ -10442,17 +10315,14 @@ _REDUCTIONPARAMETER = _descriptor.Descriptor( ...@@ -10442,17 +10315,14 @@ _REDUCTIONPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_REDUCTIONPARAMETER_REDUCTIONOP, ],
_REDUCTIONPARAMETER_REDUCTIONOP,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=16339, serialized_start=16339,
serialized_end=16512, serialized_end=16512, )
)
_RELUPARAMETER = _descriptor.Descriptor( _RELUPARAMETER = _descriptor.Descriptor(
name='ReLUParameter', name='ReLUParameter',
...@@ -10498,17 +10368,14 @@ _RELUPARAMETER = _descriptor.Descriptor( ...@@ -10498,17 +10368,14 @@ _RELUPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_RELUPARAMETER_ENGINE, ],
_RELUPARAMETER_ENGINE,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=16515, serialized_start=16515,
serialized_end=16656, serialized_end=16656, )
)
_RESHAPEPARAMETER = _descriptor.Descriptor( _RESHAPEPARAMETER = _descriptor.Descriptor(
name='ReshapeParameter', name='ReshapeParameter',
...@@ -10578,8 +10445,7 @@ _RESHAPEPARAMETER = _descriptor.Descriptor( ...@@ -10578,8 +10445,7 @@ _RESHAPEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=16658, serialized_start=16658,
serialized_end=16748, serialized_end=16748, )
)
_SCALEPARAMETER = _descriptor.Descriptor( _SCALEPARAMETER = _descriptor.Descriptor(
name='ScaleParameter', name='ScaleParameter',
...@@ -10683,8 +10549,7 @@ _SCALEPARAMETER = _descriptor.Descriptor( ...@@ -10683,8 +10549,7 @@ _SCALEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=16751, serialized_start=16751,
serialized_end=16916, serialized_end=16916, )
)
_SIGMOIDPARAMETER = _descriptor.Descriptor( _SIGMOIDPARAMETER = _descriptor.Descriptor(
name='SigmoidParameter', name='SigmoidParameter',
...@@ -10713,17 +10578,14 @@ _SIGMOIDPARAMETER = _descriptor.Descriptor( ...@@ -10713,17 +10578,14 @@ _SIGMOIDPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_SIGMOIDPARAMETER_ENGINE, ],
_SIGMOIDPARAMETER_ENGINE,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=16918, serialized_start=16918,
serialized_end=17038, serialized_end=17038, )
)
_SLICEPARAMETER = _descriptor.Descriptor( _SLICEPARAMETER = _descriptor.Descriptor(
name='SliceParameter', name='SliceParameter',
...@@ -10793,8 +10655,7 @@ _SLICEPARAMETER = _descriptor.Descriptor( ...@@ -10793,8 +10655,7 @@ _SLICEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=17040, serialized_start=17040,
serialized_end=17116, serialized_end=17116, )
)
_SOFTMAXPARAMETER = _descriptor.Descriptor( _SOFTMAXPARAMETER = _descriptor.Descriptor(
name='SoftmaxParameter', name='SoftmaxParameter',
...@@ -10840,17 +10701,14 @@ _SOFTMAXPARAMETER = _descriptor.Descriptor( ...@@ -10840,17 +10701,14 @@ _SOFTMAXPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_SOFTMAXPARAMETER_ENGINE, ],
_SOFTMAXPARAMETER_ENGINE,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=17119, serialized_start=17119,
serialized_end=17256, serialized_end=17256, )
)
_TANHPARAMETER = _descriptor.Descriptor( _TANHPARAMETER = _descriptor.Descriptor(
name='TanHParameter', name='TanHParameter',
...@@ -10879,17 +10737,14 @@ _TANHPARAMETER = _descriptor.Descriptor( ...@@ -10879,17 +10737,14 @@ _TANHPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_TANHPARAMETER_ENGINE, ],
_TANHPARAMETER_ENGINE,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=17258, serialized_start=17258,
serialized_end=17372, serialized_end=17372, )
)
_TILEPARAMETER = _descriptor.Descriptor( _TILEPARAMETER = _descriptor.Descriptor(
name='TileParameter', name='TileParameter',
...@@ -10942,8 +10797,7 @@ _TILEPARAMETER = _descriptor.Descriptor( ...@@ -10942,8 +10797,7 @@ _TILEPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=17374, serialized_start=17374,
serialized_end=17421, serialized_end=17421, )
)
_THRESHOLDPARAMETER = _descriptor.Descriptor( _THRESHOLDPARAMETER = _descriptor.Descriptor(
name='ThresholdParameter', name='ThresholdParameter',
...@@ -10979,8 +10833,7 @@ _THRESHOLDPARAMETER = _descriptor.Descriptor( ...@@ -10979,8 +10833,7 @@ _THRESHOLDPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=17423, serialized_start=17423,
serialized_end=17465, serialized_end=17465, )
)
_VIDEODATAPARAMETER = _descriptor.Descriptor( _VIDEODATAPARAMETER = _descriptor.Descriptor(
name='VideoDataParameter', name='VideoDataParameter',
...@@ -11060,17 +10913,14 @@ _VIDEODATAPARAMETER = _descriptor.Descriptor( ...@@ -11060,17 +10913,14 @@ _VIDEODATAPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_VIDEODATAPARAMETER_VIDEOTYPE, ],
_VIDEODATAPARAMETER_VIDEOTYPE,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=17468, serialized_start=17468,
serialized_end=17655, serialized_end=17655, )
)
_WINDOWDATAPARAMETER = _descriptor.Descriptor( _WINDOWDATAPARAMETER = _descriptor.Descriptor(
name='WindowDataParameter', name='WindowDataParameter',
...@@ -11310,8 +11160,7 @@ _WINDOWDATAPARAMETER = _descriptor.Descriptor( ...@@ -11310,8 +11160,7 @@ _WINDOWDATAPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=17658, serialized_start=17658,
serialized_end=17979, serialized_end=17979, )
)
_SPPPARAMETER = _descriptor.Descriptor( _SPPPARAMETER = _descriptor.Descriptor(
name='SPPParameter', name='SPPParameter',
...@@ -11384,8 +11233,7 @@ _SPPPARAMETER = _descriptor.Descriptor( ...@@ -11384,8 +11233,7 @@ _SPPPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=17982, serialized_start=17982,
serialized_end=18217, serialized_end=18217, )
)
_V1LAYERPARAMETER = _descriptor.Descriptor( _V1LAYERPARAMETER = _descriptor.Descriptor(
name='V1LayerParameter', name='V1LayerParameter',
...@@ -12138,8 +11986,7 @@ _V1LAYERPARAMETER = _descriptor.Descriptor( ...@@ -12138,8 +11986,7 @@ _V1LAYERPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=18220, serialized_start=18220,
serialized_end=20748, serialized_end=20748, )
)
_V0LAYERPARAMETER = _descriptor.Descriptor( _V0LAYERPARAMETER = _descriptor.Descriptor(
name='V0LayerParameter', name='V0LayerParameter',
...@@ -12797,17 +12644,14 @@ _V0LAYERPARAMETER = _descriptor.Descriptor( ...@@ -12797,17 +12644,14 @@ _V0LAYERPARAMETER = _descriptor.Descriptor(
], ],
extensions=[], extensions=[],
nested_types=[], nested_types=[],
enum_types=[ enum_types=[_V0LAYERPARAMETER_POOLMETHOD, ],
_V0LAYERPARAMETER_POOLMETHOD,
],
serialized_options=None, serialized_options=None,
is_extendable=False, is_extendable=False,
syntax='proto2', syntax='proto2',
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=20751, serialized_start=20751,
serialized_end=21772, serialized_end=21772, )
)
_PRELUPARAMETER = _descriptor.Descriptor( _PRELUPARAMETER = _descriptor.Descriptor(
name='PReLUParameter', name='PReLUParameter',
...@@ -12860,8 +12704,7 @@ _PRELUPARAMETER = _descriptor.Descriptor( ...@@ -12860,8 +12704,7 @@ _PRELUPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=21774, serialized_start=21774,
serialized_end=21861, serialized_end=21861, )
)
_SHUFFLECHANNELPARAMETER = _descriptor.Descriptor( _SHUFFLECHANNELPARAMETER = _descriptor.Descriptor(
name='ShuffleChannelParameter', name='ShuffleChannelParameter',
...@@ -12897,8 +12740,7 @@ _SHUFFLECHANNELPARAMETER = _descriptor.Descriptor( ...@@ -12897,8 +12740,7 @@ _SHUFFLECHANNELPARAMETER = _descriptor.Descriptor(
extension_ranges=[], extension_ranges=[],
oneofs=[], oneofs=[],
serialized_start=21863, serialized_start=21863,
serialized_end=21906, serialized_end=21906, )
)
_BLOBPROTO.fields_by_name['shape'].message_type = _BLOBSHAPE _BLOBPROTO.fields_by_name['shape'].message_type = _BLOBSHAPE
_BLOBPROTOVECTOR.fields_by_name['blobs'].message_type = _BLOBPROTO _BLOBPROTOVECTOR.fields_by_name['blobs'].message_type = _BLOBPROTO
......
...@@ -492,8 +492,8 @@ class ONNXDecoder(object): ...@@ -492,8 +492,8 @@ class ONNXDecoder(object):
sess = rt.InferenceSession(model_path) sess = rt.InferenceSession(model_path)
for ipt in sess.get_inputs(): for ipt in sess.get_inputs():
datatype = datatype_map[ipt.type] datatype = datatype_map[ipt.type]
input_dict[ipt.name] = np.random.random( input_dict[ipt.name] = np.random.random(ipt.shape).astype(
ipt.shape).astype(datatype) datatype)
res = sess.run(None, input_feed=input_dict) res = sess.run(None, input_feed=input_dict)
except: except:
......
...@@ -120,8 +120,8 @@ def convolutiondepthwise_layer(inputs, ...@@ -120,8 +120,8 @@ def convolutiondepthwise_layer(inputs,
dila_len) dila_len)
c_in = input_shape[0][1] c_in = input_shape[0][1]
c_out = num_output if num_output is not None else input_shape[0][1] c_out = num_output if num_output is not None else input_shape[0][1]
group = int(c_in / (c_in / c_out)) if c_in > c_out else int( group = int(c_in / (c_in / c_out)) if c_in > c_out else int(c_in /
c_in / (c_out / c_in)) (c_out / c_in))
out = fluid.layers.conv2d( out = fluid.layers.conv2d(
input, input,
dilation=[dila_h, dila_w], dilation=[dila_h, dila_w],
......
...@@ -23,8 +23,7 @@ def register(kind, shape, layer, weights): ...@@ -23,8 +23,7 @@ def register(kind, shape, layer, weights):
kind = [kind] kind = [kind]
else: else:
assert type( assert type(
kind kind) is list, 'invalid param "kind" for register, not a list or str'
) is list, 'invalid param "kind" for register, not a list or str'
for k in kind: for k in kind:
assert type( assert type(
......
...@@ -144,8 +144,8 @@ class CaffeOpMapper(OpMapper): ...@@ -144,8 +144,8 @@ class CaffeOpMapper(OpMapper):
[s_h, s_w] = [params.stride] * 2 [s_h, s_w] = [params.stride] * 2
elif len(params.stride) > 0: elif len(params.stride) > 0:
s_h = params.stride_h if params.stride_h > 0 else params.stride[0] s_h = params.stride_h if params.stride_h > 0 else params.stride[0]
s_w = params.stride_w if params.stride_w > 0 else params.stride[ s_w = params.stride_w if params.stride_w > 0 else params.stride[len(
len(params.stride) - 1] params.stride) - 1]
elif params.stride_h > 0 or params.stride_w > 0: elif params.stride_h > 0 or params.stride_w > 0:
s_h = params.stride_h s_h = params.stride_h
s_w = params.stride_w s_w = params.stride_w
...@@ -154,8 +154,8 @@ class CaffeOpMapper(OpMapper): ...@@ -154,8 +154,8 @@ class CaffeOpMapper(OpMapper):
[p_h, p_w] = [params.pad] * 2 [p_h, p_w] = [params.pad] * 2
elif len(params.pad) > 0: elif len(params.pad) > 0:
p_h = params.pad_h if params.pad_h > 0 else params.pad[0] p_h = params.pad_h if params.pad_h > 0 else params.pad[0]
p_w = params.pad_w if params.pad_w > 0 else params.pad[ p_w = params.pad_w if params.pad_w > 0 else params.pad[len(
len(params.pad) - 1] params.pad) - 1]
elif params.pad_h > 0 or params.pad_w > 0: elif params.pad_h > 0 or params.pad_w > 0:
p_h = params.pad_h p_h = params.pad_h
p_w = params.pad_w p_w = params.pad_w
...@@ -225,11 +225,9 @@ class CaffeOpMapper(OpMapper): ...@@ -225,11 +225,9 @@ class CaffeOpMapper(OpMapper):
input_c = node.input_shape[0][1] input_c = node.input_shape[0][1]
output_c = channel output_c = channel
data.append( data.append(
np.zeros([output_c, input_c, kernel[0], np.zeros([output_c, input_c, kernel[0], kernel[1]]).astype(
kernel[1]]).astype('float32')) 'float32'))
data.append(np.zeros([ data.append(np.zeros([output_c, ])).astype('float32')
output_c,
])).astype('float32')
else: else:
data = self.adjust_parameters(node) data = self.adjust_parameters(node)
self.weights[node.layer_name + '_weights'] = data[0] self.weights[node.layer_name + '_weights'] = data[0]
...@@ -240,24 +238,16 @@ class CaffeOpMapper(OpMapper): ...@@ -240,24 +238,16 @@ class CaffeOpMapper(OpMapper):
input = self.graph.get_bottom_node(node, idx=0, copy=True) input = self.graph.get_bottom_node(node, idx=0, copy=True)
attr = { attr = {
'filter_size': 'filter_size': kernel,
kernel, 'num_filters': channel,
'num_filters': 'stride': stride,
channel, 'padding': pad,
'stride': 'dilation': dilation,
stride, 'groups': group,
'padding': 'name': string(node.layer_name),
pad, 'param_attr': string(node.layer_name + '_weights'),
'dilation': 'bias_attr': False
dilation, if len(data) == 1 else string(node.layer_name + '_bias'),
'groups':
group,
'name':
string(node.layer_name),
'param_attr':
string(node.layer_name + '_weights'),
'bias_attr':
False if len(data) == 1 else string(node.layer_name + '_bias'),
} }
node.fluid_code.add_layer( node.fluid_code.add_layer(
"conv2d", inputs=input, output=node, param_attr=attr) "conv2d", inputs=input, output=node, param_attr=attr)
...@@ -275,11 +265,9 @@ class CaffeOpMapper(OpMapper): ...@@ -275,11 +265,9 @@ class CaffeOpMapper(OpMapper):
input_c = node.input_shape[0][1] input_c = node.input_shape[0][1]
output_c = channel output_c = channel
data.append( data.append(
np.zeros([output_c, input_c, kernel[0], np.zeros([output_c, input_c, kernel[0], kernel[1]]).astype(
kernel[1]]).astype('float32')) 'float32'))
data.append(np.zeros([ data.append(np.zeros([output_c, ]).astype('float32'))
output_c,
]).astype('float32'))
else: else:
data = self.adjust_parameters(node) data = self.adjust_parameters(node)
self.weights[node.layer_name + '_weights'] = data[0] self.weights[node.layer_name + '_weights'] = data[0]
...@@ -289,26 +277,17 @@ class CaffeOpMapper(OpMapper): ...@@ -289,26 +277,17 @@ class CaffeOpMapper(OpMapper):
) == 1, 'The count of Deconvolution node\'s input is not 1.' ) == 1, 'The count of Deconvolution node\'s input is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True) input = self.graph.get_bottom_node(node, idx=0, copy=True)
attr = { attr = {
'output_size': 'output_size': None,
None, 'filter_size': kernel,
'filter_size': 'num_filters': channel,
kernel, 'stride': stride,
'num_filters': 'padding': pad,
channel, 'dilation': dilation,
'stride': 'groups': group,
stride, 'name': string(node.layer_name),
'padding': 'param_attr': string(node.layer_name + '_weights'),
pad, 'bias_attr': False
'dilation': if len(data) == 1 else string(node.layer_name + '_bias')
dilation,
'groups':
group,
'name':
string(node.layer_name),
'param_attr':
string(node.layer_name + '_weights'),
'bias_attr':
False if len(data) == 1 else string(node.layer_name + '_bias')
} }
node.fluid_code.add_layer( node.fluid_code.add_layer(
"conv2d_transpose", inputs=input, output=node, param_attr=attr) "conv2d_transpose", inputs=input, output=node, param_attr=attr)
...@@ -372,8 +351,8 @@ class CaffeOpMapper(OpMapper): ...@@ -372,8 +351,8 @@ class CaffeOpMapper(OpMapper):
output_c = params.num_output output_c = params.num_output
data = [] data = []
data.append( data.append(
np.zeros([input_c, np.zeros([input_c, output_c]).astype('float32').astype(
output_c]).astype('float32').astype('float32')) 'float32'))
data.append( data.append(
np.zeros([output_c]).astype('float32').astype('float32')) np.zeros([output_c]).astype('float32').astype('float32'))
else: else:
...@@ -397,16 +376,12 @@ class CaffeOpMapper(OpMapper): ...@@ -397,16 +376,12 @@ class CaffeOpMapper(OpMapper):
assert params.bias_term == True assert params.bias_term == True
input = self.graph.get_bottom_node(node, idx=0, copy=True) input = self.graph.get_bottom_node(node, idx=0, copy=True)
attr = { attr = {
'size': 'size': params.num_output,
params.num_output, 'name': string(node.layer_name),
'name': 'act': None,
string(node.layer_name), 'param_attr': string(node.layer_name + '_weights'),
'act': 'bias_attr': False
None, if len(data) == 1 else string(node.layer_name + '_bias')
'param_attr':
string(node.layer_name + '_weights'),
'bias_attr':
False if len(data) == 1 else string(node.layer_name + '_bias')
} }
node.fluid_code.add_layer( node.fluid_code.add_layer(
"fc", inputs=input, output=node, param_attr=attr) "fc", inputs=input, output=node, param_attr=attr)
...@@ -607,12 +582,8 @@ class CaffeOpMapper(OpMapper): ...@@ -607,12 +582,8 @@ class CaffeOpMapper(OpMapper):
'The parameter of {} (type is {}) is not set. So we set the parameters as 0' 'The parameter of {} (type is {}) is not set. So we set the parameters as 0'
.format(node.layer_name, node.layer_type)) .format(node.layer_name, node.layer_type))
input_c = node.input_shape[0][1] input_c = node.input_shape[0][1]
mean = np.zeros([ mean = np.zeros([input_c, ]).astype('float32')
input_c, variance = np.zeros([input_c, ]).astype('float32')
]).astype('float32')
variance = np.zeros([
input_c,
]).astype('float32')
scale = 0 scale = 0
else: else:
...@@ -649,10 +620,10 @@ class CaffeOpMapper(OpMapper): ...@@ -649,10 +620,10 @@ class CaffeOpMapper(OpMapper):
input_c, input_c,
]).astype('float32') ]).astype('float32')
else: else:
self.weights[node.layer_name + '_scale'] = np.squeeze( self.weights[node.layer_name + '_scale'] = np.squeeze(node.data[
node.data[0]).astype('float32') 0]).astype('float32')
self.weights[node.layer_name + '_offset'] = np.squeeze( self.weights[node.layer_name + '_offset'] = np.squeeze(node.data[
node.data[1]).astype('float32') 1]).astype('float32')
params = node.layer.scale_param params = node.layer.scale_param
axis = params.axis axis = params.axis
num_axes = params.num_axes num_axes = params.num_axes
...@@ -750,8 +721,8 @@ class CaffeOpMapper(OpMapper): ...@@ -750,8 +721,8 @@ class CaffeOpMapper(OpMapper):
node.fluid_code.add_layer( node.fluid_code.add_layer(
"topk", "topk",
inputs=input, inputs=input,
output='{}_topk_var, {}_index_var'.format( output='{}_topk_var, {}_index_var'.format(node.layer_name,
node.layer_name, node.layer_name), node.layer_name),
param_attr=attr) param_attr=attr)
attr = {'dtype': '{}_topk_var.dtype'.format(node.layer_name)} attr = {'dtype': '{}_topk_var.dtype'.format(node.layer_name)}
node.fluid_code.add_layer( node.fluid_code.add_layer(
...@@ -762,8 +733,8 @@ class CaffeOpMapper(OpMapper): ...@@ -762,8 +733,8 @@ class CaffeOpMapper(OpMapper):
attr = {'axis': axis, 'name': string(node.layer_name)} attr = {'axis': axis, 'name': string(node.layer_name)}
node.fluid_code.add_layer( node.fluid_code.add_layer(
"concat", "concat",
inputs='{}_topk_var, {}_index_var'.format( inputs='{}_topk_var, {}_index_var'.format(node.layer_name,
node.layer_name, node.layer_name), node.layer_name),
output=node, output=node,
param_attr=attr) param_attr=attr)
else: else:
...@@ -787,23 +758,22 @@ class CaffeOpMapper(OpMapper): ...@@ -787,23 +758,22 @@ class CaffeOpMapper(OpMapper):
offset_real = [0] * len(input_shape) offset_real = [0] * len(input_shape)
if hasattr(params, "offset") and len(params.offset) > 0: if hasattr(params, "offset") and len(params.offset) > 0:
offset = list(params.offset) offset = list(params.offset)
assert (len(input_shape) - axis) == len( assert (len(input_shape) - axis
offset), "invalid offset[%s] in crop layer" % (str(offset)) ) == len(offset), "invalid offset[%s] in crop layer" % (
str(offset))
offset_real = [0] * axis + offset offset_real = [0] * axis + offset
attr = {'offsets': list(offset_real), 'name': string(node.layer_name)} attr = {'offsets': list(offset_real), 'name': string(node.layer_name)}
node.fluid_code.add_layer( node.fluid_code.add_layer(
"crop", "crop",
inputs={ inputs={'x': input,
'x': input, 'shape': node.input_shape[1]},
'shape': node.input_shape[1]
},
output=node, output=node,
param_attr=attr) param_attr=attr)
def Flatten(self, node): def Flatten(self, node):
assert len( assert len(
node.inputs node.
) == 1, 'The count of DetectionOutput node\'s input is not 1.' inputs) == 1, 'The count of DetectionOutput node\'s input is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True) input = self.graph.get_bottom_node(node, idx=0, copy=True)
shape = node.output_shape[0] shape = node.output_shape[0]
attr = {'shape': shape, 'name': string(node.layer_name)} attr = {'shape': shape, 'name': string(node.layer_name)}
......
...@@ -33,8 +33,8 @@ def get_kernel_parameters(params): ...@@ -33,8 +33,8 @@ def get_kernel_parameters(params):
[s_h, s_w] = [params.stride] * 2 [s_h, s_w] = [params.stride] * 2
elif len(params.stride) > 0: elif len(params.stride) > 0:
s_h = params.stride_h if params.stride_h > 0 else params.stride[0] s_h = params.stride_h if params.stride_h > 0 else params.stride[0]
s_w = params.stride_w if params.stride_w > 0 else params.stride[ s_w = params.stride_w if params.stride_w > 0 else params.stride[len(
len(params.stride) - 1] params.stride) - 1]
elif params.stride_h > 0 or params.stride_w > 0: elif params.stride_h > 0 or params.stride_w > 0:
s_h = params.stride_h s_h = params.stride_h
s_w = params.stride_w s_w = params.stride_w
...@@ -67,10 +67,10 @@ def get_strided_kernel_output_shape(params, input_shape, round_func): ...@@ -67,10 +67,10 @@ def get_strided_kernel_output_shape(params, input_shape, round_func):
i_w = input_shape[3] i_w = input_shape[3]
dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_kernel_parameters( dila_h, dila_w, pad_h, pad_w, kernel_h, kernel_w, stride_h, stride_w = get_kernel_parameters(
params) params)
o_h = (i_h + 2 * pad_h - o_h = (i_h + 2 * pad_h - (dila_h *
(dila_h * (kernel_h - 1) + 1)) / float(stride_h) + 1 (kernel_h - 1) + 1)) / float(stride_h) + 1
o_w = (i_w + 2 * pad_w - o_w = (i_w + 2 * pad_w - (dila_w *
(dila_w * (kernel_w - 1) + 1)) / float(stride_w) + 1 (kernel_w - 1) + 1)) / float(stride_w) + 1
o_h = int(round_func(o_h)) o_h = int(round_func(o_h))
o_w = int(round_func(o_w)) o_w = int(round_func(o_w))
has_c_o = hasattr(params, 'num_output') has_c_o = hasattr(params, 'num_output')
......
...@@ -36,8 +36,7 @@ def register(kind, shape, layer, child_func, weights): ...@@ -36,8 +36,7 @@ def register(kind, shape, layer, child_func, weights):
kind = [kind] kind = [kind]
else: else:
assert type( assert type(
kind kind) is list, 'invalid param "kind" for register, not a list or str'
) is list, 'invalid param "kind" for register, not a list or str'
for k in kind: for k in kind:
assert type( assert type(
......
...@@ -28,60 +28,49 @@ default_op_mapping_field_values['FILL_NAME_FIELD'] = True ...@@ -28,60 +28,49 @@ default_op_mapping_field_values['FILL_NAME_FIELD'] = True
default_op_mapping = { default_op_mapping = {
'Shape': ['shape', ['X'], ['Out']], 'Shape': ['shape', ['X'], ['Out']],
'Clip': [ 'Clip': [
'clip', ['X'], ['Out'], 'clip', ['X'], ['Out'], dict(), dict(
dict(), min=(_np.asarray(
dict( [255, 255, 127, 255], dtype=_np.uint8).view(_np.float32)[0]),
min=(_np.asarray([255, 255, 127, 255], max=(_np.asarray(
dtype=_np.uint8).view(_np.float32)[0]), [255, 255, 127, 127], dtype=_np.uint8).view(_np.float32)[0]), )
max=(_np.asarray([255, 255, 127, 127],
dtype=_np.uint8).view(_np.float32)[0]),
)
], ],
'Erf': ['erf', ['X'], ['Out']], 'Erf': ['erf', ['X'], ['Out']],
'Ceil': ['ceil', ['X'], ['Out']], 'Ceil': ['ceil', ['X'], ['Out']],
'ReduceMean': [ 'ReduceMean': [
'reduce_mean', ['X'], ['Out'], 'reduce_mean', ['X'], ['Out'], dict(
dict(axes='dim', keepdims='keep_dim'), axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
dict(keep_dim=1)
], ],
'ReduceSum': [ 'ReduceSum': [
'reduce_sum', ['X'], ['Out'], 'reduce_sum', ['X'], ['Out'], dict(
dict(axes='dim', keepdims='keep_dim'), axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
dict(keep_dim=1)
], ],
'ReduceMin': [ 'ReduceMin': [
'reduce_min', ['X'], ['Out'], 'reduce_min', ['X'], ['Out'], dict(
dict(axes='dim', keepdims='keep_dim'), axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
dict(keep_dim=1)
], ],
'ReduceMax': [ 'ReduceMax': [
'reduce_max', ['X'], ['Out'], 'reduce_max', ['X'], ['Out'], dict(
dict(axes='dim', keepdims='keep_dim'), axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
dict(keep_dim=1)
], ],
#active function #active function
'Relu': ['relu', ['X'], ['Out']], 'Relu': ['relu', ['X'], ['Out']],
'LeakyRelu': ['leaky_relu', ['X'], ['Out'], 'LeakyRelu': ['leaky_relu', ['X'], ['Out'], dict(), dict(alpha=.01)],
dict(), dict(alpha=.01)], 'Elu': ['elu', ['X'], ['Out'], dict(), dict(alpha=1.)],
'Elu': ['elu', ['X'], ['Out'],
dict(), dict(alpha=1.)],
'ThresholdedRelu': [ 'ThresholdedRelu': [
'thresholded_relu', ['X'], ['Out'], 'thresholded_relu', ['X'], ['Out'], dict(alpha='threshold'),
dict(alpha='threshold'),
dict(alpha=1.) dict(alpha=1.)
], ],
'Tanh': ['tanh', ['X'], ['Out']], 'Tanh': ['tanh', ['X'], ['Out']],
'Sigmoid': ['sigmoid', ['X'], ['Out']], 'Sigmoid': ['sigmoid', ['X'], ['Out']],
'HardSigmoid': [ 'HardSigmoid': [
'hard_sigmoid', ['X'], ['Out'], 'hard_sigmoid', ['X'], ['Out'], dict(
dict(alpha='slope', beta='offset'), alpha='slope', beta='offset'), dict(
dict(slope=.2, offset=.5) slope=.2, offset=.5)
], ],
'Softsign': ['softsign', ['X'], ['Out']], 'Softsign': ['softsign', ['X'], ['Out']],
'Softplus': ['softplus', ['X'], ['Out']], 'Softplus': ['softplus', ['X'], ['Out']],
'Exp': ['exp', ['X'], ['Out']], 'Exp': ['exp', ['X'], ['Out']],
'Softmax': ['softmax', ['X'], ['Out'], 'Softmax': ['softmax', ['X'], ['Out'], dict(), dict(axis=1)],
dict(), dict(axis=1)],
'Sqrt': ['sqrt', ['X'], ['Out']], 'Sqrt': ['sqrt', ['X'], ['Out']],
'Floor': ['floor', ['X'], ['Out']], 'Floor': ['floor', ['X'], ['Out']],
'Abs': ['abs', ['X'], ['Out']], 'Abs': ['abs', ['X'], ['Out']],
......
...@@ -140,8 +140,8 @@ class ONNXOpMapper(OpMapper): ...@@ -140,8 +140,8 @@ class ONNXOpMapper(OpMapper):
model.graph.ClearField('output') model.graph.ClearField('output')
model.graph.output.MergeFrom(model.graph.value_info) model.graph.output.MergeFrom(model.graph.value_info)
onnx.save(model, os.path.join(self.tmp_data_dir, onnx.save(model,
'onnx_model_infer.onnx')) os.path.join(self.tmp_data_dir, 'onnx_model_infer.onnx'))
sess = rt.InferenceSession( sess = rt.InferenceSession(
os.path.join(self.tmp_data_dir, 'onnx_model_infer.onnx')) os.path.join(self.tmp_data_dir, 'onnx_model_infer.onnx'))
res = sess.run(None, input_feed=inputs_dict) res = sess.run(None, input_feed=inputs_dict)
...@@ -217,8 +217,7 @@ class ONNXOpMapper(OpMapper): ...@@ -217,8 +217,7 @@ class ONNXOpMapper(OpMapper):
default_attrs, default_attrs,
input_perm, input_perm,
output_perm, output_perm,
fill_name_field, fill_name_field, ) = info
) = info
if fluid_op in default_ioa_constraint: if fluid_op in default_ioa_constraint:
for predicate, message in default_ioa_constraint[fluid_op]: for predicate, message in default_ioa_constraint[fluid_op]:
...@@ -429,10 +428,8 @@ class ONNXOpMapper(OpMapper): ...@@ -429,10 +428,8 @@ class ONNXOpMapper(OpMapper):
} }
node.fluid_code.add_layer( node.fluid_code.add_layer(
'roi_align', 'roi_align',
inputs={ inputs={'input': val_x,
'input': val_x, 'rois': val_rois},
'rois': val_rois
},
output=node, output=node,
param_attr=attr) param_attr=attr)
...@@ -449,10 +446,8 @@ class ONNXOpMapper(OpMapper): ...@@ -449,10 +446,8 @@ class ONNXOpMapper(OpMapper):
} }
node.fluid_code.add_layer( node.fluid_code.add_layer(
'roi_pool', 'roi_pool',
inputs={ inputs={'input': val_x,
'input': val_x, 'rois': val_rois},
'rois': val_rois
},
output=node, output=node,
param_attr=attr) param_attr=attr)
...@@ -527,10 +522,8 @@ class ONNXOpMapper(OpMapper): ...@@ -527,10 +522,8 @@ class ONNXOpMapper(OpMapper):
val_y = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer( node.fluid_code.add_layer(
'greater_than', 'greater_than',
inputs={ inputs={'x': val_x,
'x': val_x, 'y': val_y},
'y': val_y
},
output=node, output=node,
param_attr=None) param_attr=None)
...@@ -549,8 +542,7 @@ class ONNXOpMapper(OpMapper): ...@@ -549,8 +542,7 @@ class ONNXOpMapper(OpMapper):
shape = val_output.out_shapes[0] shape = val_output.out_shapes[0]
if shape is None: if shape is None:
shape = list(value.shape) shape = list(value.shape)
_logger.warning( _logger.warning('in (Constant -> %s): '
'in (Constant -> %s): '
'attribute "shape" of %s not inferred, ' 'attribute "shape" of %s not inferred, '
'using value as 1-D tensor may lead to fails', 'using value as 1-D tensor may lead to fails',
val_output.layer_name, val_output.layer_name) val_output.layer_name, val_output.layer_name)
...@@ -616,10 +608,8 @@ class ONNXOpMapper(OpMapper): ...@@ -616,10 +608,8 @@ class ONNXOpMapper(OpMapper):
if axis == 0 and len(indices_shape) <= 1: if axis == 0 and len(indices_shape) <= 1:
node.fluid_code.add_layer( node.fluid_code.add_layer(
'gather', 'gather',
inputs={ inputs={'input': val_x,
'input': val_x, 'index': indices},
'index': indices
},
output=node, output=node,
param_attr=None) param_attr=None)
elif axis > 0 and len(indices_shape) <= 1: elif axis > 0 and len(indices_shape) <= 1:
...@@ -634,10 +624,8 @@ class ONNXOpMapper(OpMapper): ...@@ -634,10 +624,8 @@ class ONNXOpMapper(OpMapper):
param_attr=attr_trans) param_attr=attr_trans)
node.fluid_code.add_layer( node.fluid_code.add_layer(
'gather', 'gather',
inputs={ inputs={'input': name_trans,
'input': name_trans, 'index': indices},
'index': indices
},
output=node, output=node,
param_attr=None) param_attr=None)
node.fluid_code.add_layer( node.fluid_code.add_layer(
...@@ -649,9 +637,7 @@ class ONNXOpMapper(OpMapper): ...@@ -649,9 +637,7 @@ class ONNXOpMapper(OpMapper):
'reshape', 'reshape',
inputs=indices, inputs=indices,
output=indices, output=indices,
param_attr={'shape': [ param_attr={'shape': [reshape_shape, ]})
reshape_shape,
]})
perm = list(range(len(val_x.out_shapes[0]))) perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:] perm = [axis] + perm[:axis] + perm[axis + 1:]
...@@ -664,10 +650,8 @@ class ONNXOpMapper(OpMapper): ...@@ -664,10 +650,8 @@ class ONNXOpMapper(OpMapper):
param_attr=attr_trans) param_attr=attr_trans)
node.fluid_code.add_layer( node.fluid_code.add_layer(
'gather', 'gather',
inputs={ inputs={'input': name_trans,
'input': name_trans, 'index': indices},
'index': indices
},
output=node, output=node,
param_attr=None) param_attr=None)
node.fluid_code.add_layer( node.fluid_code.add_layer(
...@@ -926,8 +910,10 @@ class ONNXOpMapper(OpMapper): ...@@ -926,8 +910,10 @@ class ONNXOpMapper(OpMapper):
def Sum(self, node): def Sum(self, node):
val_inps = node.layer.input val_inps = node.layer.input
inputs = { inputs = {
"x": self.graph.get_input_node(node, idx=0, copy=True), "x": self.graph.get_input_node(
"y": self.graph.get_input_node(node, idx=1, copy=True), node, idx=0, copy=True),
"y": self.graph.get_input_node(
node, idx=1, copy=True),
} }
node.fluid_code.add_layer("elementwise_add", inputs=inputs, output=node) node.fluid_code.add_layer("elementwise_add", inputs=inputs, output=node)
...@@ -1022,10 +1008,8 @@ class ONNXOpMapper(OpMapper): ...@@ -1022,10 +1008,8 @@ class ONNXOpMapper(OpMapper):
val_y = self.graph.get_input_node(node, idx=1, copy=True) val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer( node.fluid_code.add_layer(
"equal", "equal",
inputs={ inputs={'x': val_x,
'x': val_x, 'y': val_y},
'y': val_y
},
output=node, output=node,
param_attr=None) param_attr=None)
...@@ -1055,29 +1039,23 @@ class ONNXOpMapper(OpMapper): ...@@ -1055,29 +1039,23 @@ class ONNXOpMapper(OpMapper):
mul_val_x = val_x.layer_name + '_mul' mul_val_x = val_x.layer_name + '_mul'
node.fluid_code.add_layer( node.fluid_code.add_layer(
"elementwise_mul", "elementwise_mul",
inputs={ inputs={'x': val_x,
'x': val_x, 'y': cast_condition},
'y': cast_condition
},
output=mul_val_x, output=mul_val_x,
param_attr=None) param_attr=None)
mul_val_y = val_y.layer_name + '_mul' mul_val_y = val_y.layer_name + '_mul'
node.fluid_code.add_layer( node.fluid_code.add_layer(
"elementwise_mul", "elementwise_mul",
inputs={ inputs={'x': val_y,
'x': val_y, 'y': cast_not_condition},
'y': cast_not_condition
},
output=mul_val_y, output=mul_val_y,
param_attr=None) param_attr=None)
node.fluid_code.add_layer( node.fluid_code.add_layer(
"elementwise_add", "elementwise_add",
inputs={ inputs={'x': mul_val_x,
'x': mul_val_x, 'y': mul_val_y},
'y': mul_val_y
},
output=node, output=node,
param_attr=None) param_attr=None)
...@@ -1106,7 +1084,8 @@ class ONNXOpMapper(OpMapper): ...@@ -1106,7 +1084,8 @@ class ONNXOpMapper(OpMapper):
output=flatten_name, output=flatten_name,
param_attr={'axis': 0}) param_attr={'axis': 0})
node.fluid_code.add_layer( node.fluid_code.add_layer(
"concat", inputs=flatten_names, output=node, param_attr={'axis': 0}) "concat", inputs=flatten_names, output=node,
param_attr={'axis': 0})
def Identity(self, node): def Identity(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
...@@ -1280,11 +1259,11 @@ class ONNXOpMapper(OpMapper): ...@@ -1280,11 +1259,11 @@ class ONNXOpMapper(OpMapper):
output_size = [0, 0] output_size = [0, 0]
output_size[0] = (val_x.out_shapes[0][2] - output_size[0] = (val_x.out_shapes[0][2] - 1
1) * strides[0] - 2 * paddings[0] + dilations[0] * ( ) * strides[0] - 2 * paddings[0] + dilations[0] * (
kernel_shape[0] - 1) + 1 + out_padding[0] kernel_shape[0] - 1) + 1 + out_padding[0]
output_size[1] = (val_x.out_shapes[0][3] - output_size[1] = (val_x.out_shapes[0][3] - 1
1) * strides[1] - 2 * paddings[1] + dilations[1] * ( ) * strides[1] - 2 * paddings[1] + dilations[1] * (
kernel_shape[1] - 1) + 1 + out_padding[1] kernel_shape[1] - 1) + 1 + out_padding[1]
attr = { attr = {
'num_filters': num_out_channels, 'num_filters': num_out_channels,
...@@ -1367,29 +1346,23 @@ class ONNXOpMapper(OpMapper): ...@@ -1367,29 +1346,23 @@ class ONNXOpMapper(OpMapper):
'squeeze', 'squeeze',
inputs=val_x, inputs=val_x,
output=var_x0, output=var_x0,
param_attr={ param_attr={'axes': [1],
'axes': [1], 'name': string(var_x0)})
'name': string(var_x0)
})
var_w0 = node.layer_name + '_w0' var_w0 = node.layer_name + '_w0'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'squeeze', 'squeeze',
inputs=val_w, inputs=val_w,
output=var_w0, output=var_w0,
param_attr={ param_attr={'axes': [0],
'axes': [0], 'name': string(var_w0)})
'name': string(var_w0)
})
var_fc = node.layer_name + '_fc' var_fc = node.layer_name + '_fc'
var_mm = (node.layer_name + '_mm') if val_b else var_fc var_mm = (node.layer_name + '_mm') if val_b else var_fc
node.fluid_code.add_layer( node.fluid_code.add_layer(
'matmul', 'matmul',
inputs={ inputs={'x': var_x0,
'x': var_x0, 'y': var_w0},
'y': var_w0
},
output=var_mm, output=var_mm,
param_attr={ param_attr={
'transpose_x': 0, 'transpose_x': 0,
...@@ -1402,10 +1375,8 @@ class ONNXOpMapper(OpMapper): ...@@ -1402,10 +1375,8 @@ class ONNXOpMapper(OpMapper):
'squeeze', 'squeeze',
inputs=val_r, inputs=val_r,
output=var_r0, output=var_r0,
param_attr={ param_attr={'axes': [0],
'axes': [0], 'name': string(var_r0)})
'name': string(var_r0)
})
var_r0t = node.layer_name + '_r0t' var_r0t = node.layer_name + '_r0t'
...@@ -1413,10 +1384,8 @@ class ONNXOpMapper(OpMapper): ...@@ -1413,10 +1384,8 @@ class ONNXOpMapper(OpMapper):
'transpose', 'transpose',
inputs=var_r0, inputs=var_r0,
output=var_r0t, output=var_r0t,
param_attr={ param_attr={'perm': [1, 0],
'perm': [1, 0], 'name': string(var_r0t)})
'name': string(var_r0t)
})
if val_b: if val_b:
var_bi = node.layer_name + '_bi' var_bi = node.layer_name + '_bi'
var_bh = node.layer_name + '_bh' var_bh = node.layer_name + '_bh'
...@@ -1434,10 +1403,8 @@ class ONNXOpMapper(OpMapper): ...@@ -1434,10 +1403,8 @@ class ONNXOpMapper(OpMapper):
'squeeze', 'squeeze',
inputs=var_bi, inputs=var_bi,
output=var_bi0, output=var_bi0,
param_attr={ param_attr={'axes': [0],
'axes': [0], 'name': string(var_bi0)})
'name': string(var_bi0)
})
node.fluid_code.add_layer( node.fluid_code.add_layer(
'elmentwise_add', 'elmentwise_add',
...@@ -1454,10 +1421,8 @@ class ONNXOpMapper(OpMapper): ...@@ -1454,10 +1421,8 @@ class ONNXOpMapper(OpMapper):
'squeeze', 'squeeze',
inputs=val_xh, inputs=val_xh,
output=var_xh0, output=var_xh0,
param_attr={ param_attr={'axes': [1],
'axes': [1], 'name': string(var_xh0)})
'name': string(var_xh0)
})
var_y00 = node.layer_name + '_y00' var_y00 = node.layer_name + '_y00'
attr = { attr = {
......
...@@ -30,8 +30,8 @@ def im2sequence(op, block): ...@@ -30,8 +30,8 @@ def im2sequence(op, block):
slice_blocks = list() slice_blocks = list()
for i in range(out_h): for i in range(out_h):
for j in range(out_w): for j in range(out_w):
starts_name = "im2sequence.starts.{}.{}.{}".format( starts_name = "im2sequence.starts.{}.{}.{}".format(im2seq_counter,
im2seq_counter, i, j) i, j)
starts_tensor = helper.make_tensor( starts_tensor = helper.make_tensor(
name=starts_name, name=starts_name,
data_type=onnx_pb.TensorProto.INT64, data_type=onnx_pb.TensorProto.INT64,
......
...@@ -44,8 +44,7 @@ def multiclass_nms(op, block): ...@@ -44,8 +44,7 @@ def multiclass_nms(op, block):
if normalized == False: if normalized == False:
warnings.warn( warnings.warn(
'The parameter normalized of multiclass_nms OP of Paddle is False, which has diff with ONNX. \ 'The parameter normalized of multiclass_nms OP of Paddle is False, which has diff with ONNX. \
Please set normalized=True in multiclass_nms of Paddle' Please set normalized=True in multiclass_nms of Paddle')
)
#convert the paddle attribute to onnx tensor #convert the paddle attribute to onnx tensor
name_score_threshold = [outputs['Out'][0] + "@score_threshold"] name_score_threshold = [outputs['Out'][0] + "@score_threshold"]
...@@ -353,7 +352,8 @@ def multiclass_nms(op, block): ...@@ -353,7 +352,8 @@ def multiclass_nms(op, block):
outputs_gather_topk_class = [result_name + "@gather_topk_class"] outputs_gather_topk_class = [result_name + "@gather_topk_class"]
node_gather_topk_class = onnx.helper.make_node( node_gather_topk_class = onnx.helper.make_node(
'Gather', 'Gather',
inputs=outputs_gather_1_nonzero + [outputs_topk_select_topk_indices[1]], inputs=outputs_gather_1_nonzero +
[outputs_topk_select_topk_indices[1]],
outputs=outputs_gather_topk_class, outputs=outputs_gather_topk_class,
axis=1) axis=1)
node_list.append(node_gather_topk_class) node_list.append(node_gather_topk_class)
...@@ -362,7 +362,8 @@ def multiclass_nms(op, block): ...@@ -362,7 +362,8 @@ def multiclass_nms(op, block):
outputs_gather_topk_boxes_id = [result_name + "@gather_topk_boxes_id"] outputs_gather_topk_boxes_id = [result_name + "@gather_topk_boxes_id"]
node_gather_topk_boxes_id = onnx.helper.make_node( node_gather_topk_boxes_id = onnx.helper.make_node(
'Gather', 'Gather',
inputs=outputs_gather_2_nonzero + [outputs_topk_select_topk_indices[1]], inputs=outputs_gather_2_nonzero +
[outputs_topk_select_topk_indices[1]],
outputs=outputs_gather_topk_boxes_id, outputs=outputs_gather_topk_boxes_id,
axis=1) axis=1)
node_list.append(node_gather_topk_boxes_id) node_list.append(node_gather_topk_boxes_id)
......
...@@ -38,8 +38,8 @@ def yolo_box(op, block): ...@@ -38,8 +38,8 @@ def yolo_box(op, block):
downsample_ratio = attrs['downsample_ratio'] downsample_ratio = attrs['downsample_ratio']
input_size = input_height * downsample_ratio input_size = input_height * downsample_ratio
conf_thresh = attrs['conf_thresh'] conf_thresh = attrs['conf_thresh']
conf_thresh_mat = np.ones([num_anchors * input_height * input_width conf_thresh_mat = np.ones([num_anchors * input_height *
]) * conf_thresh input_width]) * conf_thresh
node_list = [] node_list = []
im_outputs = [] im_outputs = []
......
...@@ -250,8 +250,7 @@ class PaddleOpMapper(object): ...@@ -250,8 +250,7 @@ class PaddleOpMapper(object):
node = helper.make_node( node = helper.make_node(
pool_type[op.attr('pooling_type')][1], pool_type[op.attr('pooling_type')][1],
inputs=op.input('X'), inputs=op.input('X'),
outputs=op.output('Out'), outputs=op.output('Out'), )
)
else: else:
input_shape = block.var(op.input('X')[0]).shape input_shape = block.var(op.input('X')[0]).shape
k_size = op.attr('ksize') k_size = op.attr('ksize')
...@@ -407,8 +406,7 @@ class PaddleOpMapper(object): ...@@ -407,8 +406,7 @@ class PaddleOpMapper(object):
node = helper.make_node( node = helper.make_node(
'Clip', 'Clip',
inputs=[op.input('X')[0], min_name, max_name], inputs=[op.input('X')[0], min_name, max_name],
outputs=op.output('Out'), outputs=op.output('Out'), )
)
return [min_node, max_node, node] return [min_node, max_node, node]
def shape(self, op, block): def shape(self, op, block):
...@@ -450,8 +448,7 @@ class PaddleOpMapper(object): ...@@ -450,8 +448,7 @@ class PaddleOpMapper(object):
node = helper.make_node( node = helper.make_node(
"Slice", "Slice",
inputs=[op.input('Input')[0], starts_name, ends_name, axes_name], inputs=[op.input('Input')[0], starts_name, ends_name, axes_name],
outputs=op.output('Out'), outputs=op.output('Out'), )
)
return [starts_node, ends_node, axes_node, node] return [starts_node, ends_node, axes_node, node]
def fill_constant(self, op, block): def fill_constant(self, op, block):
...@@ -551,8 +548,8 @@ class PaddleOpMapper(object): ...@@ -551,8 +548,8 @@ class PaddleOpMapper(object):
if op.attr('align_corners'): if op.attr('align_corners'):
coordinate_transformation_mode = 'align_corners' coordinate_transformation_mode = 'align_corners'
if ('OutSize' in input_names and len(op.input('OutSize')) > 0) or ( if ('OutSize' in input_names and len(op.input('OutSize')) > 0) or (
'SizeTensor' in input_names 'SizeTensor' in input_names and
and len(op.input('SizeTensor')) > 0): len(op.input('SizeTensor')) > 0):
node_list = list() node_list = list()
roi_node = self.make_constant_node( roi_node = self.make_constant_node(
self.get_name(op.type, 'roi'), onnx_pb.TensorProto.FLOAT, self.get_name(op.type, 'roi'), onnx_pb.TensorProto.FLOAT,
...@@ -631,8 +628,7 @@ class PaddleOpMapper(object): ...@@ -631,8 +628,7 @@ class PaddleOpMapper(object):
elif 'Scale' in input_names and len(op.input('Scale')) > 0: elif 'Scale' in input_names and len(op.input('Scale')) > 0:
node = helper.make_node( node = helper.make_node(
'Resize', 'Resize',
inputs=[op.input('X')[0], inputs=[op.input('X')[0], op.input('Scale')[0]],
op.input('Scale')[0]],
outputs=op.output('Out'), outputs=op.output('Out'),
mode='linear', mode='linear',
coordinate_transformation_mode=coordinate_transformation_mode) coordinate_transformation_mode=coordinate_transformation_mode)
...@@ -641,8 +637,9 @@ class PaddleOpMapper(object): ...@@ -641,8 +637,9 @@ class PaddleOpMapper(object):
scale = op.attr('scale') scale = op.attr('scale')
if out_shape.count(-1) > 0: if out_shape.count(-1) > 0:
scale_name = self.get_name(op.type, 'scale') scale_name = self.get_name(op.type, 'scale')
scale_node = self.make_constant_node( scale_node = self.make_constant_node(scale_name,
scale_name, onnx_pb.TensorProto.FLOAT, [1, 1, scale, scale]) onnx_pb.TensorProto.FLOAT,
[1, 1, scale, scale])
roi_name = self.get_name(op.type, 'roi') roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(roi_name, roi_node = self.make_constant_node(roi_name,
onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.FLOAT,
...@@ -667,16 +664,14 @@ class PaddleOpMapper(object): ...@@ -667,16 +664,14 @@ class PaddleOpMapper(object):
if 'OutSize' in input_names and len(op.input('OutSize')) > 0: if 'OutSize' in input_names and len(op.input('OutSize')) > 0:
node = helper.make_node( node = helper.make_node(
'Resize', 'Resize',
inputs=[op.input('X')[0], '', inputs=[op.input('X')[0], '', op.input('OutSize')[0]],
op.input('OutSize')[0]],
outputs=op.output('Out'), outputs=op.output('Out'),
mode='nearest', mode='nearest',
coordinate_transformation_mode=coordinate_transformation_mode) coordinate_transformation_mode=coordinate_transformation_mode)
elif 'Scale' in input_names and len(op.input('Scale')) > 0: elif 'Scale' in input_names and len(op.input('Scale')) > 0:
node = helper.make_node( node = helper.make_node(
'Resize', 'Resize',
inputs=[op.input('X')[0], inputs=[op.input('X')[0], op.input('Scale')[0]],
op.input('Scale')[0]],
outputs=op.output('Out'), outputs=op.output('Out'),
mode='nearest', mode='nearest',
coordinate_transformation_mode=coordinate_transformation_mode) coordinate_transformation_mode=coordinate_transformation_mode)
...@@ -685,8 +680,9 @@ class PaddleOpMapper(object): ...@@ -685,8 +680,9 @@ class PaddleOpMapper(object):
scale = op.attr('scale') scale = op.attr('scale')
if out_shape.count(-1) > 0: if out_shape.count(-1) > 0:
scale_name = self.get_name(op.type, 'scale') scale_name = self.get_name(op.type, 'scale')
scale_node = self.make_constant_node( scale_node = self.make_constant_node(scale_name,
scale_name, onnx_pb.TensorProto.FLOAT, [1, 1, scale, scale]) onnx_pb.TensorProto.FLOAT,
[1, 1, scale, scale])
roi_name = self.get_name(op.type, 'roi') roi_name = self.get_name(op.type, 'roi')
roi_node = self.make_constant_node(roi_name, roi_node = self.make_constant_node(roi_name,
onnx_pb.TensorProto.FLOAT, onnx_pb.TensorProto.FLOAT,
...@@ -737,8 +733,7 @@ class PaddleOpMapper(object): ...@@ -737,8 +733,7 @@ class PaddleOpMapper(object):
node1 = helper.make_node( node1 = helper.make_node(
'Clip', 'Clip',
inputs=[name0, min_name, max_name], inputs=[name0, min_name, max_name],
outputs=[name1], outputs=[name1], )
)
name2 = self.get_name(op.type, 'mul') name2 = self.get_name(op.type, 'mul')
node2 = helper.make_node( node2 = helper.make_node(
'Mul', inputs=[op.input('X')[0], name1], outputs=[name2]) 'Mul', inputs=[op.input('X')[0], name1], outputs=[name2])
......
...@@ -114,9 +114,8 @@ class TFOpMapper(OpMapper): ...@@ -114,9 +114,8 @@ class TFOpMapper(OpMapper):
else: else:
unsupported_ops.add(op) unsupported_ops.add(op)
if len(unsupported_ops) > 0: if len(unsupported_ops) > 0:
sys.stderr.write( sys.stderr.write("=========={} Ops are not supported yet======\n".
"=========={} Ops are not supported yet======\n".format( format(len(unsupported_ops)))
len(unsupported_ops)))
for op in unsupported_ops: for op in unsupported_ops:
sys.stderr.write("========== {} ==========\n".format(op)) sys.stderr.write("========== {} ==========\n".format(op))
sys.exit(-1) sys.exit(-1)
...@@ -296,8 +295,8 @@ class TFOpMapper(OpMapper): ...@@ -296,8 +295,8 @@ class TFOpMapper(OpMapper):
shape = [shape[i] for i in [0, 3, 1, 2]] shape = [shape[i] for i in [0, 3, 1, 2]]
if len(shape) == 3: if len(shape) == 3:
shape = [shape[i] for i in [2, 0, 1]] shape = [shape[i] for i in [2, 0, 1]]
self.weights[node.layer_name] = numpy.transpose( self.weights[node.layer_name] = numpy.transpose(node.value,
node.value, (2, 0, 1)) (2, 0, 1))
elif node.tf_data_format == "NCHW": elif node.tf_data_format == "NCHW":
if len(shape) == 4: if len(shape) == 4:
self.graph.data_format_propagation(node) self.graph.data_format_propagation(node)
...@@ -534,8 +533,8 @@ class TFOpMapper(OpMapper): ...@@ -534,8 +533,8 @@ class TFOpMapper(OpMapper):
attr = {"shape": shape} attr = {"shape": shape}
self.add_omit_nodes(param.layer_name, node.layer_name) self.add_omit_nodes(param.layer_name, node.layer_name)
else: else:
assert len(param.out_shapes[0] assert len(param.out_shapes[
) == 1, "Unexpected situation of shape parameter" 0]) == 1, "Unexpected situation of shape parameter"
attr = {"shape": [-1]} attr = {"shape": [-1]}
node.fluid_code.add_layer( node.fluid_code.add_layer(
"reshape", "reshape",
...@@ -647,15 +646,15 @@ class TFOpMapper(OpMapper): ...@@ -647,15 +646,15 @@ class TFOpMapper(OpMapper):
def ConcatV2(self, node): def ConcatV2(self, node):
inputs = [ inputs = [
self.graph.get_node(name, copy=True) self.graph.get_node(
for name in node.layer.input[:-1] name, copy=True) for name in node.layer.input[:-1]
] ]
axis = self.graph.get_node(node.layer.input[-1], copy=True) axis = self.graph.get_node(node.layer.input[-1], copy=True)
assert axis.layer_type == "Const" assert axis.layer_type == "Const"
self.add_omit_nodes(axis.layer_name, node.layer_name) self.add_omit_nodes(axis.layer_name, node.layer_name)
axis = axis.value axis = axis.value
if inputs[0].tf_data_format == "NHWC" and len( if inputs[0].tf_data_format == "NHWC" and len(inputs[0].out_shapes[
inputs[0].out_shapes[0]) == 4: 0]) == 4:
axis = nhwc_dim_to_nchw(inputs[0], axis) axis = nhwc_dim_to_nchw(inputs[0], axis)
attr = {"axis": axis} attr = {"axis": axis}
node.fluid_code.add_layer( node.fluid_code.add_layer(
...@@ -684,11 +683,12 @@ class TFOpMapper(OpMapper): ...@@ -684,11 +683,12 @@ class TFOpMapper(OpMapper):
def Pack(self, node): def Pack(self, node):
inputs = [ inputs = [
self.graph.get_node(name, copy=True) for name in node.layer.input self.graph.get_node(
name, copy=True) for name in node.layer.input
] ]
axis = node.get_attr("axis") axis = node.get_attr("axis")
if inputs[0].tf_data_format == "NHWC" and len( if inputs[0].tf_data_format == "NHWC" and len(inputs[0].out_shapes[
inputs[0].out_shapes[0]) == 4: 0]) == 4:
tf_data_format = list(inputs[0].tf_data_format) tf_data_format = list(inputs[0].tf_data_format)
tf_data_format.insert(axis, str(len(tf_data_format))) tf_data_format.insert(axis, str(len(tf_data_format)))
axis = nhwc_dim_to_nchw(inputs[0], axis) axis = nhwc_dim_to_nchw(inputs[0], axis)
...@@ -1010,8 +1010,8 @@ class TFOpMapper(OpMapper): ...@@ -1010,8 +1010,8 @@ class TFOpMapper(OpMapper):
if resize_shape.layer_type == "Const": if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist() resize_shape = resize_shape.value.tolist()
else: else:
resize_shape = self.decoder.infer_shape_tensor( resize_shape = self.decoder.infer_shape_tensor(resize_shape,
resize_shape, node.out_shapes[0]) node.out_shapes[0])
align_corners = node.get_attr("align_corners") align_corners = node.get_attr("align_corners")
attr = {"align_corners": align_corners, "out_shape": resize_shape} attr = {"align_corners": align_corners, "out_shape": resize_shape}
node.fluid_code.add_layer( node.fluid_code.add_layer(
...@@ -1024,8 +1024,8 @@ class TFOpMapper(OpMapper): ...@@ -1024,8 +1024,8 @@ class TFOpMapper(OpMapper):
if resize_shape.layer_type == "Const": if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist() resize_shape = resize_shape.value.tolist()
else: else:
resize_shape = self.decoder.infer_shape_tensor( resize_shape = self.decoder.infer_shape_tensor(resize_shape,
resize_shape, node.out_shapes[0]) node.out_shapes[0])
align_corners = node.get_attr("align_corners") align_corners = node.get_attr("align_corners")
attr = { attr = {
"align_corners": align_corners, "align_corners": align_corners,
......
...@@ -486,8 +486,8 @@ class TFOpMapperNHWC(OpMapper): ...@@ -486,8 +486,8 @@ class TFOpMapperNHWC(OpMapper):
attr = {"shape": shape} attr = {"shape": shape}
self.add_omit_nodes(param.layer_name, node.layer_name) self.add_omit_nodes(param.layer_name, node.layer_name)
else: else:
assert len(param.out_shapes[0] assert len(param.out_shapes[
) == 1, "Unexpected situation of shape parameter" 0]) == 1, "Unexpected situation of shape parameter"
attr = {"shape": [-1]} attr = {"shape": [-1]}
node.fluid_code.add_layer( node.fluid_code.add_layer(
"reshape", "reshape",
...@@ -577,8 +577,8 @@ class TFOpMapperNHWC(OpMapper): ...@@ -577,8 +577,8 @@ class TFOpMapperNHWC(OpMapper):
def ConcatV2(self, node): def ConcatV2(self, node):
inputs = [ inputs = [
self.graph.get_node(name, copy=True) self.graph.get_node(
for name in node.layer.input[:-1] name, copy=True) for name in node.layer.input[:-1]
] ]
axis = self.graph.get_node(node.layer.input[-1], copy=True) axis = self.graph.get_node(node.layer.input[-1], copy=True)
assert axis.layer_type == "Const" assert axis.layer_type == "Const"
...@@ -608,7 +608,8 @@ class TFOpMapperNHWC(OpMapper): ...@@ -608,7 +608,8 @@ class TFOpMapperNHWC(OpMapper):
def Pack(self, node): def Pack(self, node):
inputs = [ inputs = [
self.graph.get_node(name, copy=True) for name in node.layer.input self.graph.get_node(
name, copy=True) for name in node.layer.input
] ]
axis = node.get_attr("axis") axis = node.get_attr("axis")
attr = {"axis": axis} attr = {"axis": axis}
...@@ -949,8 +950,8 @@ class TFOpMapperNHWC(OpMapper): ...@@ -949,8 +950,8 @@ class TFOpMapperNHWC(OpMapper):
if resize_shape.layer_type == "Const": if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist() resize_shape = resize_shape.value.tolist()
else: else:
resize_shape = self.decoder.infer_shape_tensor( resize_shape = self.decoder.infer_shape_tensor(resize_shape,
resize_shape, node.out_shapes[0]) node.out_shapes[0])
align_corners = node.get_attr("align_corners") align_corners = node.get_attr("align_corners")
attr = {"perm": [0, 3, 1, 2]} attr = {"perm": [0, 3, 1, 2]}
node.fluid_code.add_layer( node.fluid_code.add_layer(
...@@ -969,8 +970,8 @@ class TFOpMapperNHWC(OpMapper): ...@@ -969,8 +970,8 @@ class TFOpMapperNHWC(OpMapper):
if resize_shape.layer_type == "Const": if resize_shape.layer_type == "Const":
resize_shape = resize_shape.value.tolist() resize_shape = resize_shape.value.tolist()
else: else:
resize_shape = self.decoder.infer_shape_tensor( resize_shape = self.decoder.infer_shape_tensor(resize_shape,
resize_shape, node.out_shapes[0]) node.out_shapes[0])
align_corners = node.get_attr("align_corners") align_corners = node.get_attr("align_corners")
attr = {"perm": [0, 3, 1, 2]} attr = {"perm": [0, 3, 1, 2]}
node.fluid_code.add_layer( node.fluid_code.add_layer(
......
...@@ -768,8 +768,8 @@ class TFOptimizer(object): ...@@ -768,8 +768,8 @@ class TFOptimizer(object):
is_prelu = False is_prelu = False
continue continue
if len(in_nodes0[0].outputs) != 1 or len( if len(in_nodes0[0].outputs) != 1 or len(in_nodes0[1]
in_nodes0[1].outputs) != 1: .outputs) != 1:
is_prelu = False is_prelu = False
continue continue
...@@ -778,8 +778,8 @@ class TFOptimizer(object): ...@@ -778,8 +778,8 @@ class TFOptimizer(object):
self.graph.get_node(in_name) self.graph.get_node(in_name)
for in_name in in_nodes0[1].inputs for in_name in in_nodes0[1].inputs
] ]
if in_nodes2[1].layer_type != "Const" or numpy.fabs( if in_nodes2[1].layer_type != "Const" or numpy.fabs(in_nodes2[
in_nodes2[1].value - 0.5) > 1e-06: 1].value - 0.5) > 1e-06:
is_prelu = False is_prelu = False
continue continue
if in_nodes2[0].layer_type != "Mul": if in_nodes2[0].layer_type != "Mul":
...@@ -788,8 +788,8 @@ class TFOptimizer(object): ...@@ -788,8 +788,8 @@ class TFOptimizer(object):
if exist_act(in_nodes2[0]): if exist_act(in_nodes2[0]):
is_prelu = False is_prelu = False
continue continue
if len(in_nodes2[1].outputs) != 1 or len( if len(in_nodes2[1].outputs) != 1 or len(in_nodes2[0]
in_nodes2[0].outputs) != 1: .outputs) != 1:
is_prelu = False is_prelu = False
continue continue
...@@ -804,8 +804,8 @@ class TFOptimizer(object): ...@@ -804,8 +804,8 @@ class TFOptimizer(object):
if exist_act(in_nodes3[1]): if exist_act(in_nodes3[1]):
is_prelu = False is_prelu = False
continue continue
if len(in_nodes3[0].outputs) != 1 or len( if len(in_nodes3[0].outputs) != 1 or len(in_nodes3[1]
in_nodes3[1].outputs) != 1: .outputs) != 1:
is_prelu = False is_prelu = False
continue continue
...@@ -857,12 +857,12 @@ class TFOptimizer(object): ...@@ -857,12 +857,12 @@ class TFOptimizer(object):
mode = "element" mode = "element"
elif len(in_nodes3[0].value.shape) == 0: elif len(in_nodes3[0].value.shape) == 0:
mode = "all" mode = "all"
elif len(in_nodes3[0].value.shape elif len(in_nodes3[0].value.shape) == 1 and in_nodes3[
) == 1 and in_nodes3[0].value.shape[0] == 1: 0].value.shape[0] == 1:
mode = "all" mode = "all"
elif len(in_shape) == 4 and len( elif len(in_shape) == 4 and len(in_nodes3[
in_nodes3[0].value.shape 0].value.shape) == 1 and in_nodes3[0].value.shape[
) == 1 and in_nodes3[0].value.shape[0] == in_shape[-1]: 0] == in_shape[-1]:
mode = "channel" mode = "channel"
weight = self.op_mapper.weights[in_nodes3[0].layer_name] weight = self.op_mapper.weights[in_nodes3[0].layer_name]
weight = numpy.expand_dims(weight, 0) weight = numpy.expand_dims(weight, 0)
...@@ -917,14 +917,15 @@ class TFOptimizer(object): ...@@ -917,14 +917,15 @@ class TFOptimizer(object):
self.graph.get_node(in_name) for in_name in node.inputs self.graph.get_node(in_name) for in_name in node.inputs
] ]
if in_nodes0[0].layer_type != "Mul" or in_nodes0[ if in_nodes0[0].layer_type != "Mul" or in_nodes0[
1].layer_type != "Const" or in_nodes0[1].value.size != 1: 1].layer_type != "Const" or in_nodes0[
1].value.size != 1:
is_scale = False is_scale = False
continue continue
if exist_act(in_nodes0[0]): if exist_act(in_nodes0[0]):
is_scale = False is_scale = False
continue continue
if len(in_nodes0[0].outputs) != 1 or len( if len(in_nodes0[0].outputs) != 1 or len(in_nodes0[1]
in_nodes0[1].outputs) != 1: .outputs) != 1:
is_scale = False is_scale = False
continue continue
...@@ -940,8 +941,8 @@ class TFOptimizer(object): ...@@ -940,8 +941,8 @@ class TFOptimizer(object):
if exist_act(in_nodes1[1]): if exist_act(in_nodes1[1]):
is_scale = False is_scale = False
continue continue
if len(in_nodes1[0].outputs) != 1 or len( if len(in_nodes1[0].outputs) != 1 or len(in_nodes1[1]
in_nodes1[1].outputs) != 1: .outputs) != 1:
is_scale = False is_scale = False
continue continue
...@@ -963,8 +964,8 @@ class TFOptimizer(object): ...@@ -963,8 +964,8 @@ class TFOptimizer(object):
scale = 1.0 / in_nodes2[1].value * in_nodes1[0].value scale = 1.0 / in_nodes2[1].value * in_nodes1[0].value
act = None act = None
if node.fluid_code.layers[0].param_attr is not None: if node.fluid_code.layers[0].param_attr is not None:
act = node.fluid_code.layers[0].param_attr.get( act = node.fluid_code.layers[0].param_attr.get("act",
"act", None) None)
node.fluid_code.clear() node.fluid_code.clear()
attr = { attr = {
...@@ -1003,17 +1004,17 @@ class TFOptimizer(object): ...@@ -1003,17 +1004,17 @@ class TFOptimizer(object):
if exist_act(in_nodes0[0]): if exist_act(in_nodes0[0]):
is_affine_channel = False is_affine_channel = False
continue continue
if len(in_nodes0[0].outputs) != 1 or len( if len(in_nodes0[0].outputs) != 1 or len(in_nodes0[1]
in_nodes0[1].outputs) != 1: .outputs) != 1:
is_affine_channel = False is_affine_channel = False
continue continue
in_nodes1 = [ in_nodes1 = [
self.graph.get_node(in_name) self.graph.get_node(in_name)
for in_name in in_nodes0[0].inputs for in_name in in_nodes0[0].inputs
] ]
if len(in_nodes1[0].out_shapes[0] if len(in_nodes1[0].out_shapes[0]) != 4 or in_nodes1[
) != 4 or in_nodes1[1].layer_type != "Const" or len( 1].layer_type != "Const" or len(in_nodes1[1]
in_nodes1[1].value.shape) != 3: .value.shape) != 3:
is_affine_channel = False is_affine_channel = False
continue continue
if len(in_nodes1[1].outputs) != 1: if len(in_nodes1[1].outputs) != 1:
...@@ -1036,8 +1037,8 @@ class TFOptimizer(object): ...@@ -1036,8 +1037,8 @@ class TFOptimizer(object):
node.layer_type = "AffineChannel" node.layer_type = "AffineChannel"
node.inputs = [in_node.layer_name] node.inputs = [in_node.layer_name]
scale = 1.0 / in_nodes0[1].value.flatten() scale = 1.0 / in_nodes0[1].value.flatten()
bias = in_nodes1[1].value.flatten( bias = in_nodes1[1].value.flatten() / in_nodes0[
) / in_nodes0[1].value.flatten() 1].value.flatten()
if not bias_add: if not bias_add:
bias *= -1.0 bias *= -1.0
self.op_mapper.weights[node.layer_name + "_scale"] = scale self.op_mapper.weights[node.layer_name + "_scale"] = scale
...@@ -1045,8 +1046,8 @@ class TFOptimizer(object): ...@@ -1045,8 +1046,8 @@ class TFOptimizer(object):
act = None act = None
if node.fluid_code.layers[0].param_attr is not None: if node.fluid_code.layers[0].param_attr is not None:
act = node.fluid_code.layers[0].param_attr.get( act = node.fluid_code.layers[0].param_attr.get("act",
"act", None) None)
node.fluid_code.clear() node.fluid_code.clear()
attr = { attr = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册