Commit a68dc2f3, authored by 李寅

Fix TensorFlow DSP converter

Parent 915d38c2
@@ -122,6 +122,28 @@ MACE now supports models from TensorFlow and Caffe (more frameworks will be supported)
             strip_unused_nodes
             sort_by_execution_order'

+Usage for DSP:
+
+.. code:: bash
+
+    # DSP:
+    ./transform_graph \
+        --in_graph=/path/to/your/tf_model.pb \
+        --out_graph=/path/to/your/output/tf_model_opt.pb \
+        --inputs='input node name' \
+        --outputs='output node name' \
+        --transforms='strip_unused_nodes(type=float, shape="1,64,64,3")
+            remove_nodes(op=Identity, op=CheckNumerics)
+            fold_constants(ignore_errors=true)
+            fold_batch_norms
+            fold_old_batch_norms
+            backport_concatv2
+            quantize_weights(minimum_size=2)
+            quantize_nodes
+            strip_unused_nodes
+            sort_by_execution_order'
+
 - Caffe

   Caffe 1.0+ models are supported in MACE converter tool.
@@ -3,7 +3,7 @@ py_library(
     srcs = [
         "convert_util.py",
         "graph_util.py",
-        "tf_dsp_converter_lib.py",
+        "converter_tool/tf_dsp_converter.py",
         "converter_tool/base_converter.py",
         "converter_tool/shape_inference.py",
         "converter_tool/tensorflow_converter.py",
@@ -96,16 +96,6 @@ def main(unused_args):
         print ("runtime %s is not supported." % FLAGS.runtime)
         sys.exit(-1)

-    if FLAGS.runtime == 'dsp':
-        if FLAGS.platform == 'tensorflow':
-            from mace.python.tools import tf_dsp_converter_lib
-            output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
-                FLAGS.model_file, FLAGS.input_node, FLAGS.output_node,
-                FLAGS.dsp_mode)
-        else:
-            print("%s does not support dsp runtime yet." % FLAGS.platform)
-            sys.exit(-1)
-    else:
     if FLAGS.graph_optimize_options:
         option = cvt.ConverterOption(
             FLAGS.graph_optimize_options.split(','))
@@ -129,6 +119,15 @@ def main(unused_args):
             output_node.name = output_node_names[i]
             option.add_output_node(output_node)

+    print("Transform model to one that can better run on device")
+    if FLAGS.runtime == 'dsp':
+        mace_check(FLAGS.platform == 'tensorflow',
+                   'DSP only supports tensorflow')
+        from mace.python.tools.converter_tool import tf_dsp_converter
+        converter = tf_dsp_converter.TensorflowDspConverter(
+            option, FLAGS.model_file)
+        output_graph_def = converter.run()
+    else:
         if FLAGS.platform == 'tensorflow':
             from mace.python.tools.converter_tool import tensorflow_converter
             converter = tensorflow_converter.TensorflowConverter(

@@ -144,7 +143,6 @@ def main(unused_args):
         output_graph_def = converter.run()

-    print("Transform model to one that can better run on device")
     if FLAGS.runtime == 'cpu+gpu':
         cpu_graph_def = copy.deepcopy(output_graph_def)
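The net effect of the converter.py changes above: the DSP path no longer bypasses the ConverterOption pipeline, and dispatch happens only after the input/output nodes have been registered. A condensed sketch of the new flow, as a standalone function rather than `main()` (flag parsing and option population elided; the `TensorflowConverter` argument list is truncated in the diff and assumed here to mirror the DSP one):

```python
# A paraphrase of the dispatch main() now performs, not the full file.
from mace.python.tools.converter_tool import base_converter as cvt  # alias assumed
from mace.python.tools.convert_util import mace_check


def convert(platform, runtime, model_file, option):
    # `option` is a cvt.ConverterOption already populated with the
    # --input_node/--output_node entries, as in the diff above.
    print("Transform model to one that can better run on device")
    if runtime == 'dsp':
        mace_check(platform == 'tensorflow', 'DSP only supports tensorflow')
        from mace.python.tools.converter_tool import tf_dsp_converter
        converter = tf_dsp_converter.TensorflowDspConverter(option, model_file)
    else:
        # Non-DSP runtimes keep their per-platform converters, e.g.:
        from mace.python.tools.converter_tool import tensorflow_converter
        converter = tensorflow_converter.TensorflowConverter(
            option, model_file)  # assumed signature; truncated in the diff
    return converter.run()
```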
@@ -14,16 +14,80 @@

 from mace.proto import mace_pb2
+from mace.python.tools.converter_tool import base_converter
+from mace.python.tools import graph_util
+from mace.python.tools.convert_util import mace_check
 import tensorflow as tf
-from tensorflow import gfile
+from tensorflow.core.framework import tensor_shape_pb2
 from operator import mul
-from dsp_ops import DspOps
-from mace.python.tools import graph_util
+import numpy as np
+
+
+class DspOps(object):
+    def __init__(self):
+        self.dsp_ops = {
+            'INPUT': 'INPUT"',
+            'OUTPUT': 'OUTPUT',
+            'NoOp': 'Nop',
+            'FLATTEN': 'Flatten',
+            'Identity': 'Nop',
+            'Placeholder': 'INPUT',
+            'Const': 'Const',
+            'QuantizedConv2D': 'QuantizedConv2d_8x8to32',
+            'QuantizedMatMul': 'QuantizedMatMul_8x8to32',
+            'QuantizeDownAndShrinkRange': 'QuantizeDownAndShrinkRange_32to8',
+            'QuantizedRelu': 'QuantizedRelu_8',
+            'QuantizedReluX': 'QuantizedReluX_8',
+            'QuantizedMaxPool': 'QuantizedMaxPool_8',
+            'QuantizedAvgPool': 'QuantizedAvgPool_8',
+            'QuantizedConcat': 'QuantizedConcat_8',
+            'QuantizedBiasAdd': 'QuantizedBiasAdd_8p8to32',
+            'QuantizedResizeBilinear': 'QuantizedResizeBilinear_8',
+            'QuantizedSpaceToBatchND': 'QuantizedSpaceToBatchND_8',
+            'QuantizedBatchToSpaceND': 'QuantizedBatchToSpaceND_8',
+            'QuantizedSoftmax': 'QuantizedSoftmax_8',
+            'QuantizedTanh': 'QuantizedTanh_8',
+            'Min': 'Min_f',
+            'Max': 'Max_f',
+            'QuantizeV2': 'Quantize',
+            'Dequantize': 'Dequantize',
+            'Softmax': 'Softmax_f',
+            'Reshape': 'Reshape',
+            'QuantizedReshape': 'QuantizedReshape',
+            'Sigmoid': 'Sigmoid_f',
+            'Slice': 'Slice_f',
+            'Add': 'Add_f',
+            'Mul': 'Mul_f',
+            'Requantize': 'Requantize_32to8',
+            'RequantizationRange': 'RequantizationRange_32',
+            'Sub': 'Sub_f',
+            'Pack': 'Pack_int32',
+            'StridedSlice': 'StridedSlice_f',
+            'ExpandDims': 'ExpandDims_f',
+            'QuantizedMul': 'QuantizedMul_8x8to32',
+            'QuantizedAdd': 'QuantizedAdd_8p8to32',
+            'Pad': 'Pad_f',
+            'SpaceToBatchND': 'SpaceToBatchND_f',
+            'BatchToSpaceND': 'BatchToSpaceND_f',
+            'ResizeBilinear': 'ResizeBilinear_f',
+            'ConcatV2': 'ConcatV2_f',
+            'Conv2DBackpropInput': 'Deconv_f',
+            'Tanh': 'Tanh_f',
+            'Split': 'Split_f',
+            'Transpose': 'Transpose_f',
+            'Concat': 'Concat_f',
+            'AddN': 'AddN_f',
+        }
+
+    def has_op(self, tf_op):
+        return tf_op in self.dsp_ops
+
+    def map_nn_op(self, tf_op):
+        if tf_op not in self.dsp_ops:
+            raise Exception('Could not map nn op for: ', tf_op)
+        return self.dsp_ops[tf_op]

-# converter --input ../libcv/quantized_model.pb \
-#           --output quantized_model_dsp.pb \
-#           --runtime dsp --input_node input_node \
-#           --output_node output_node

 TF_DTYPE_2_MACE_DTYPE_MAP = {
     tf.float32: mace_pb2.DT_FLOAT,
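A minimal usage sketch of the relocated DspOps table, behaving exactly as defined in the hunk above: `has_op` guards the lookup, and `map_nn_op` translates a TF op type to its nnlib kernel name, raising on unknown types.

```python
ops = DspOps()  # the class moved into this file above

if ops.has_op('QuantizedConv2D'):
    print(ops.map_nn_op('QuantizedConv2D'))  # -> 'QuantizedConv2d_8x8to32'
print(ops.map_nn_op('Identity'))             # -> 'Nop'
# ops.map_nn_op('SomeUnsupportedOp') would raise Exception
```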
@@ -101,7 +165,6 @@ def get_input_tensor(op, index):

 def add_shape_const_node(net_def, op, values, name):
-    print('Add const node: ', op.name + '/' + name)
     tensor = net_def.tensors.add()
     node_name = op.name + '/' + name
     tensor.name = node_name + ':0'
@@ -128,7 +191,7 @@ def convert_op_outputs(mace_op_def, tf_op):
     mace_op_def.output_shape.extend(output_shapes)

-def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
+def convert_ops(unresolved_ops, resolved_ops, net_def, dsp_ops):
     first_op = unresolved_ops[0]
     print('Op: ', first_op.name, first_op.type, first_op.outputs[0].shape)
@@ -152,7 +215,8 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
                 first_op.outputs[0].dtype == tf.quint8 or \
                 first_op.outputs[0].dtype == tf.quint16:
             tensor.int32_data.extend(tf_tensor.astype(int).flat)
+    elif first_op.type == 'Shape':
+        resolved_ops.add(first_op.name)
     else:
         op_def = net_def.op.add()
         op_def.name = first_op.name
@@ -162,7 +226,7 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
         if len(first_op.outputs) > 0 and first_op.type == 'Dequantize' \
                 and len(first_op.outputs[0].consumers()) > 0 \
                 and (first_op.outputs[0].consumers()[0].type == 'SpaceToBatchND' or
-                     first_op.outputs[0].consumers()[0].type == 'BatchToSpaceND'):
+                     first_op.outputs[0].consumers()[0].type == 'BatchToSpaceND'):  # noqa
             input_tensor = first_op.inputs[0]
             min_tensor = first_op.inputs[1]
             max_tensor = first_op.inputs[2]
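All of the fusion branches in `convert_ops` rely on the same consumer-chasing idiom over the imported TF graph. In isolation it looks like the sketch below, written against the TF 1.x graph API with a hypothetical op name:

```python
import tensorflow as tf

graph = tf.get_default_graph()
op = graph.get_operation_by_name('some/Dequantize')  # hypothetical name

# Each tf.Operation exposes its output tensors; each tensor knows the
# downstream operations that consume it.
if len(op.outputs) > 0 and len(op.outputs[0].consumers()) > 0:
    consumer = op.outputs[0].consumers()[0]  # a tf.Operation
    if consumer.type in ('SpaceToBatchND', 'BatchToSpaceND'):
        pass  # the pattern the Dequantize branch above fuses away
```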
@@ -183,14 +247,12 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
             op_def.input.extend([t.name for t in s2b_op.inputs[1:]])
             op_def.input.extend([min_tensor.name, max_tensor.name])
             convert_op_outputs(op_def, quantize_op)
-    elif len(first_op.outputs) > 0 and \
-            first_op.type == 'QuantizedReshape' and \
-            len(first_op.outputs[0].consumers()) > 0 and \
-            first_op.outputs[0].consumers()[0].type == 'Dequantize' and \
-            len(first_op.outputs[0].consumers()[0].outputs[0].consumers()) \
-            > 0 and \
-            first_op.outputs[0].consumers()[0].outputs[0].consumers()[0].type \
-            == 'Softmax':
+    elif (len(first_op.outputs) > 0 and
+          first_op.type == 'QuantizedReshape' and
+          len(first_op.outputs[0].consumers()) > 0 and
+          first_op.outputs[0].consumers()[0].type == 'Dequantize' and
+          len(first_op.outputs[0].consumers()[0].outputs[0].consumers()) > 0 and  # noqa
+          first_op.outputs[0].consumers()[0].outputs[0].consumers()[0].type == 'Softmax'):  # noqa
         input_tensor = first_op.inputs[0]
         min_tensor = first_op.inputs[2]
         max_tensor = first_op.inputs[3]
@@ -216,17 +278,17 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
             [input_tensor.name, min_tensor.name, max_tensor.name])
         convert_op_outputs(op_def, quantize_reshape_op)
     # remove Squeeze
-    elif len(first_op.outputs) > 0 and \
-            first_op.type == 'Requantize' and \
-            len(first_op.outputs[0].consumers()) > 0 and \
-            first_op.outputs[0].consumers()[0].type == 'Dequantize' and \
-            len(first_op.outputs[0].consumers()[0].outputs[0].consumers()) \
-            > 0 and \
-            first_op.outputs[0].consumers()[0].outputs[0].consumers()[0].type \
-            == 'Squeeze':
+    elif (len(first_op.outputs) > 0 and
+          first_op.type == 'Requantize' and
+          len(first_op.outputs[0].consumers()) > 0 and
+          first_op.outputs[0].consumers()[0].type == 'Dequantize' and
+          len(first_op.outputs[0].consumers()[0].outputs[0].consumers()) > 0 and  # noqa
+          first_op.outputs[0].consumers()[0].outputs[0].consumers()[0].type == 'Squeeze'):  # noqa
         dequantize_op = first_op.outputs[0].consumers()[0]
         squeeze_op = dequantize_op.outputs[0].consumers()[0]
         reshape_op = squeeze_op.outputs[0].consumers()[0]
+        if reshape_op.type == 'Shape':
+            reshape_op = squeeze_op.outputs[0].consumers()[1]
         min_op = reshape_op.outputs[0].consumers()[0]
         max_op = reshape_op.outputs[0].consumers()[1]
         quantize_op = min_op.outputs[0].consumers()[0]
@@ -249,7 +311,7 @@ def convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops):
             if next_op and len(next_op.outputs) > 0 and \
             next_op.type == 'QuantizedReshape' and \
             len(next_op.outputs[0].consumers()) > 0 else None
-        softmax_op = dequantize_op.outputs[0].consumers()[0]\
+        softmax_op = dequantize_op.outputs[0].consumers()[0] \
             if dequantize_op and len(dequantize_op.outputs) > 0 and \
             dequantize_op.type == 'Dequantize' and \
             len(dequantize_op.outputs[0].consumers()) > 0 else None
@@ -446,11 +508,9 @@ def reverse_batch_to_space_and_biasadd(net_def):
                 new_follow_op.CopyFrom(follow_op)
                 for i in xrange(len(follow_op.input)):
                     for k in xrange(3):
-                        if new_follow_op.input[
-                                i] == get_tensor_name_from_op(
-                                biasadd_requantize_op.name, k):
-                            new_follow_op.input[
-                                i] = get_tensor_name_from_op(
-                                b2s_op.name, k)
+                        if new_follow_op.input[i] == get_tensor_name_from_op(  # noqa
+                                biasadd_requantize_op.name, k):
+                            new_follow_op.input[i] = get_tensor_name_from_op(  # noqa
+                                b2s_op.name, k)
                 new_ops.append(new_follow_op)
                 skip_ops.add(follow_op.name)
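For readers outside this file: `get_tensor_name_from_op` is not part of this diff. Judging only from its call sites above, it maps an op name and output port to a TF-style tensor name, roughly as in this sketch:

```python
def get_tensor_name_from_op(op_name, port):
    # Presumed behavior (the helper itself lies outside this diff):
    # TensorFlow tensor names have the form '<op_name>:<output_index>'.
    return op_name + ':' + str(port)
```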
@@ -518,7 +578,7 @@ def add_input_output_info(net_def, input_node, output_node, graph, dtype):
     return net_def

-def fuse_quantize(net_def, input_node, output_node):
+def fuse_quantize(net_def):
     tensor_map = {}
     for tensor in net_def.tensors:
         tensor_map[tensor.name] = tensor
@@ -567,51 +627,71 @@ def fuse_quantize(net_def, input_node, output_node):

     return new_net_def

-def convert_to_mace_pb(model_file, input_node, output_node, dsp_mode):
-    """
-    nnlib does not have batch norm, so use tensorflow optimizer to fold
-    batch norm with convolution. The fold optimization reorders ops, so
-    we sort ops first by topology.
-    """
-    input_graph_def = tf.GraphDef()
-    with gfile.Open(model_file, "rb") as f:
-        data = f.read()
-        input_graph_def.ParseFromString(data)
-    input_graph_def = graph_util.sort_tf_graph(input_graph_def)
-    net_def = mace_pb2.NetDef()
-
-    with tf.Session() as session:
-        with session.graph.as_default() as graph:
-            tf.import_graph_def(input_graph_def, name="")
-            ops = graph.get_operations()
-    dsp_ops = DspOps()
-    resolved_ops = set()
-    # convert const node
-    unresolved_ops = [op for op in ops if op.type == 'Const']
-    with tf.Session() as session:
-        while len(unresolved_ops) > 0:
-            convert_ops(unresolved_ops, resolved_ops, net_def, output_node,
-                        dsp_ops)
-
-        # convert op node
-        unresolved_ops = [op for op in ops if op.type != 'Const']
-        while len(unresolved_ops) > 0:
-            convert_ops(unresolved_ops, resolved_ops, net_def, output_node,
-                        dsp_ops)
-
-        add_output_node(net_def, output_node)
-        net_def = reverse_batch_to_space_and_biasadd(net_def)
-        net_def = fuse_quantize(net_def, input_node, output_node)
-
-        sorted_net_def = graph_util.sort_mace_graph(net_def, '__output__')
-        net_def_with_node_id = add_node_id(sorted_net_def)
-
-        dtype = mace_pb2.DT_FLOAT
-        final_net_def = add_input_output_info(
-            net_def_with_node_id, input_node, output_node, graph, dtype)
-
-        arg = final_net_def.arg.add()
-        arg.name = 'dsp_mode'
-        arg.i = dsp_mode
-
-        return final_net_def
+class TensorflowDspConverter(base_converter.ConverterInterface):
+    def __init__(self, option, src_model_file):
+        self._option = option
+        self._mace_net_def = mace_pb2.NetDef()
+
+        # import tensorflow graph
+        tf_graph_def = tf.GraphDef()
+        with tf.gfile.Open(src_model_file, 'rb') as f:
+            tf_graph_def.ParseFromString(f.read())
+
+        self._placeholders = {}
+        self.add_shape_info(tf_graph_def)
+
+        with tf.Session() as session:
+            with session.graph.as_default() as graph:
+                tf.import_graph_def(tf_graph_def, name='')
+                self._tf_graph = graph
+
+    def run(self):
+        ops = self._tf_graph.get_operations()
+        dsp_ops = DspOps()
+        resolved_ops = set()
+        mace_check(len(self._option.input_nodes) == 1
+                   and len(self._option.output_nodes) == 1,
+                   'dsp only support single input and output')
+        input_node = self._option.input_nodes.values()[0].name
+        output_node = self._option.output_nodes.values()[0].name
+
+        # convert const node
+        unresolved_ops = [op for op in ops if op.type == 'Const']
+        while len(unresolved_ops) > 0:
+            convert_ops(unresolved_ops, resolved_ops, self._mace_net_def,
+                        dsp_ops)
+
+        # convert op node
+        unresolved_ops = [op for op in ops if op.type != 'Const']
+        while len(unresolved_ops) > 0:
+            convert_ops(unresolved_ops, resolved_ops, self._mace_net_def,
+                        dsp_ops)
+
+        add_output_node(self._mace_net_def, output_node)
+        net_def = reverse_batch_to_space_and_biasadd(self._mace_net_def)
+        net_def = fuse_quantize(net_def)
+
+        sorted_net_def = graph_util.sort_mace_graph(net_def, '__output__')
+        net_def_with_node_id = add_node_id(sorted_net_def)
+
+        dtype = mace_pb2.DT_FLOAT
+        final_net_def = add_input_output_info(
+            net_def_with_node_id, input_node, output_node,
+            self._tf_graph, dtype)
+
+        return final_net_def
+
+    def add_shape_info(self, tf_graph_def):
+        for node in tf_graph_def.node:
+            for input_node in self._option.input_nodes.values():
+                if node.name == input_node.name or \
+                        node.name + ':0' == input_node.name:
+                    del node.attr['shape'].shape.dim[:]
+                    node.attr['shape'].shape.dim.extend([
+                        tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in
+                        input_node.shape
+                    ])
+                    self._placeholders[node.name + ':0'] = \
+                        np.zeros(shape=input_node.shape, dtype=float)
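Putting the new class to work outside the CLI, a hypothetical driver sketch. `ConverterOption` and the NodeInfo-style objects are assumed to live in `base_converter`, as the converter.py flag handling suggests; `NodeInfo` and `add_input_node` are named by analogy with the `add_output_node` call seen above and are not confirmed by this diff:

```python
from mace.python.tools.converter_tool import base_converter as cvt
from mace.python.tools.converter_tool import tf_dsp_converter

option = cvt.ConverterOption()

input_node = cvt.NodeInfo()        # assumed class name
input_node.name = 'input_node'
input_node.shape = [1, 64, 64, 3]
option.add_input_node(input_node)  # assumed to mirror add_output_node

output_node = cvt.NodeInfo()
output_node.name = 'output_node'
option.add_output_node(output_node)

converter = tf_dsp_converter.TensorflowDspConverter(
    option, '/path/to/quantized_model.pb')
net_def = converter.run()  # a mace_pb2.NetDef ready for serialization
```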
-# Copyright 2018 Xiaomi, Inc. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-class DspOps(object):
-    def __init__(self):
-        self.dsp_ops = {
-            'INPUT': 'INPUT"',
-            'OUTPUT': 'OUTPUT',
-            'NoOp': 'Nop',
-            'FLATTEN': 'Flatten',
-            'Identity': 'Nop',
-            'Placeholder': 'INPUT',
-            'Const': 'Const',
-            'QuantizedConv2D': 'QuantizedConv2d_8x8to32',
-            'QuantizedMatMul': 'QuantizedMatMul_8x8to32',
-            'QuantizeDownAndShrinkRange': 'QuantizeDownAndShrinkRange_32to8',
-            'QuantizedRelu': 'QuantizedRelu_8',
-            'QuantizedReluX': 'QuantizedReluX_8',
-            'QuantizedMaxPool': 'QuantizedMaxPool_8',
-            'QuantizedAvgPool': 'QuantizedAvgPool_8',
-            'QuantizedConcat': 'QuantizedConcat_8',
-            'QuantizedBiasAdd': 'QuantizedBiasAdd_8p8to32',
-            'QuantizedResizeBilinear': 'QuantizedResizeBilinear_8',
-            'QuantizedSpaceToBatchND': 'QuantizedSpaceToBatchND_8',
-            'QuantizedBatchToSpaceND': 'QuantizedBatchToSpaceND_8',
-            'QuantizedSoftmax': 'QuantizedSoftmax_8',
-            'QuantizedTanh': 'QuantizedTanh_8',
-            'Min': 'Min_f',
-            'Max': 'Max_f',
-            'QuantizeV2': 'Quantize',
-            'Dequantize': 'Dequantize',
-            'Softmax': 'Softmax_f',
-            'Reshape': 'Reshape',
-            'QuantizedReshape': 'QuantizedReshape',
-            'Sigmoid': 'Sigmoid_f',
-            'Slice': 'Slice_f',
-            'Add': 'Add_f',
-            'Mul': 'Mul_f',
-            'Requantize': 'Requantize_32to8',
-            'RequantizationRange': 'RequantizationRange_32',
-            'Sub': 'Sub_f',
-            'Pack': 'Pack_int32',
-            'StridedSlice': 'StridedSlice_f',
-            'ExpandDims': 'ExpandDims_f',
-            'QuantizedMul': 'QuantizedMul_8x8to32',
-            'QuantizedAdd': 'QuantizedAdd_8p8to32',
-            'Pad': 'Pad_f',
-            'SpaceToBatchND': 'SpaceToBatchND_f',
-            'BatchToSpaceND': 'BatchToSpaceND_f',
-            'ResizeBilinear': 'ResizeBilinear_f',
-            'ConcatV2': 'ConcatV2_f',
-            'Conv2DBackpropInput': 'Deconv_f',
-            'Tanh': 'Tanh_f',
-            'Split': 'Split_f',
-            'Transpose': 'Transpose_f',
-            'Concat': 'Concat_f',
-            'AddN': 'AddN_f',
-        }
-
-    def has_op(self, tf_op):
-        return tf_op in self.dsp_ops
-
-    def map_nn_op(self, tf_op):
-        if tf_op not in self.dsp_ops:
-            raise Exception('Could not map nn op for: ', tf_op)
-        return self.dsp_ops[tf_op]