提交 abef71cf 编写于 作者: L liuqi

Add shape information to input node.

上级 08314882
...@@ -34,8 +34,10 @@ def main(unused_args): ...@@ -34,8 +34,10 @@ def main(unused_args):
output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb( output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode) input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.dsp_mode)
else: else:
input_shape = [int(x) for x in FLAGS.input_shape.split(',')]
output_graph_def = tf_converter_lib.convert_to_mace_pb( output_graph_def = tf_converter_lib.convert_to_mace_pb(
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime, FLAGS.winograd) input_graph_def, FLAGS.input_node, input_shape, FLAGS.output_node,
FLAGS.data_type, FLAGS.runtime, FLAGS.winograd)
if FLAGS.output_type == 'source': if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate, source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate,
...@@ -124,6 +126,11 @@ def parse_args(): ...@@ -124,6 +126,11 @@ def parse_args():
type=int, type=int,
default=0, default=0,
help="dsp run mode, defalut=0") help="dsp run mode, defalut=0")
parser.add_argument(
"--input_shape",
type=str,
default="1,512,512,3",
help="input shape.")
return parser.parse_known_args() return parser.parse_known_args()
......
...@@ -2,7 +2,12 @@ from lib.proto import mace_pb2 ...@@ -2,7 +2,12 @@ from lib.proto import mace_pb2
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
import math import math
import copy
from lib.python.tools import memory_optimizer from lib.python.tools import memory_optimizer
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import node_def_pb2
# TODO: support NCHW formt, now only support NHWC. # TODO: support NCHW formt, now only support NHWC.
padding_mode = { padding_mode = {
...@@ -903,10 +908,27 @@ class Optimizer: ...@@ -903,10 +908,27 @@ class Optimizer:
new_net = self.fold_batch_norm() new_net = self.fold_batch_norm()
return new_net return new_net
def convert_to_mace_pb(input_graph_def, input_node, output_node, data_type, device, winograd): def add_shape_info(input_graph_def, input_node, input_shape):
inputs_replaced_graph = graph_pb2.GraphDef()
for node in input_graph_def.node:
if node.name == input_node:
placeholder_node = copy.deepcopy(node)
placeholder_node.attr.clear()
placeholder_node.attr['shape'].shape.dim.extend([
tensor_shape_pb2.TensorShapeProto.Dim(size=i) for i in input_shape
])
placeholder_node.attr['dtype'].CopyFrom(node.attr['dtype'])
inputs_replaced_graph.node.extend([placeholder_node])
else:
inputs_replaced_graph.node.extend([copy.deepcopy(node)])
return inputs_replaced_graph
def convert_to_mace_pb(input_graph_def, input_node, input_shape, output_node, data_type, device, winograd):
net_def = mace_pb2.NetDef() net_def = mace_pb2.NetDef()
dt = data_type_map[data_type] dt = data_type_map[data_type]
input_graph_def = add_shape_info(input_graph_def, input_node, input_shape)
with tf.Session() as session: with tf.Session() as session:
with session.graph.as_default() as graph: with session.graph.as_default() as graph:
tf.import_graph_def(input_graph_def, name="") tf.import_graph_def(input_graph_def, name="")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册