提交 1b86bc21 编写于 作者: W wuchenghui

fix benchmark & support global_avg_pool convert

上级 336a8f20
...@@ -346,7 +346,7 @@ class TFConverter(object): ...@@ -346,7 +346,7 @@ class TFConverter(object):
final_op = bias_add_op final_op = bias_add_op
self.resolved_ops[bias_add_op.name] = 1 self.resolved_ops[bias_add_op.name] = 1
if len(self.tf_graph[final_op.name]) == 1 \ if len(self.tf_graph.get(final_op.name, [])) == 1 \
and self.tf_graph[final_op.name][0].type in activation_name_map: and self.tf_graph[final_op.name][0].type in activation_name_map:
activation_op = self.tf_graph[final_op.name][0] activation_op = self.tf_graph[final_op.name][0]
op_def.type = "FusedConv2D" op_def.type = "FusedConv2D"
...@@ -508,6 +508,33 @@ class TFConverter(object): ...@@ -508,6 +508,33 @@ class TFConverter(object):
data_format_arg.s = 'NHWC' data_format_arg.s = 'NHWC'
self.resolved_ops[op.name] = 1 self.resolved_ops[op.name] = 1
def convert_global_avg_pooling(self, op):
op_def = self.net_def.op.add()
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self.dt
op_def.name = op.name
op_def.type = 'Pooling'
op_def.input.extend([op.inputs[0].name])
op_def.output.extend([output.name for output in op.outputs])
self.add_output_shape(op.outputs, op_def)
pooling_type_arg = op_def.arg.add()
pooling_type_arg.name = 'pooling_type'
pooling_type_arg.i = pooling_type_mode['AvgPool']
padding_arg = op_def.arg.add()
padding_arg.name = 'padding'
padding_arg.i = padding_mode['VALID']
strides_arg = op_def.arg.add()
strides_arg.name = 'strides'
strides_arg.ints.extend([1, 1])
kernels_arg = op_def.arg.add()
kernels_arg.name = 'kernels'
kernels_arg.ints.extend(op.inputs[0].shape.as_list()[1:3])
data_format_arg = op_def.arg.add()
data_format_arg.name = 'data_format'
data_format_arg.s = 'NHWC'
self.resolved_ops[op.name] = 1
def convert_activation(self, op): def convert_activation(self, op):
op_def = self.net_def.op.add() op_def = self.net_def.op.add()
arg = op_def.arg.add() arg = op_def.arg.add()
...@@ -824,6 +851,14 @@ class TFConverter(object): ...@@ -824,6 +851,14 @@ class TFConverter(object):
# FIXME: hardcode for inception_v3 # FIXME: hardcode for inception_v3
elif op.type in ['Squeeze', 'Shape']: elif op.type in ['Squeeze', 'Shape']:
self.resolved_ops[op.name] = 1 self.resolved_ops[op.name] = 1
elif op.type == 'Mean':
# Global avg pooling
reduce_dims = op.inputs[1].eval()
if reduce_dims[0] == 1 and reduce_dims[1] == 2:
self.convert_global_avg_pooling(op)
self.unused_tensor.add(op.inputs[1].name)
else:
raise Exception('Unknown Op: %s, type: %s' % (op.name, op.type))
#elif op.type in ['']: #elif op.type in ['']:
# self.convert_normal_op(op) # self.convert_normal_op(op)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册