diff --git a/python/tools/tf_converter_lib.py b/python/tools/tf_converter_lib.py index 807a1d59e5f7e2ae3690a9e371f7e13844cb5e34..c35c26b675db74f177be3365c083a49571bc3660 100644 --- a/python/tools/tf_converter_lib.py +++ b/python/tools/tf_converter_lib.py @@ -346,7 +346,7 @@ class TFConverter(object): final_op = bias_add_op self.resolved_ops[bias_add_op.name] = 1 - if len(self.tf_graph[final_op.name]) == 1 \ + if len(self.tf_graph.get(final_op.name, [])) == 1 \ and self.tf_graph[final_op.name][0].type in activation_name_map: activation_op = self.tf_graph[final_op.name][0] op_def.type = "FusedConv2D" @@ -508,6 +508,33 @@ class TFConverter(object): data_format_arg.s = 'NHWC' self.resolved_ops[op.name] = 1 + def convert_global_avg_pooling(self, op): + op_def = self.net_def.op.add() + arg = op_def.arg.add() + arg.name = 'T' + arg.i = self.dt + op_def.name = op.name + op_def.type = 'Pooling' + op_def.input.extend([op.inputs[0].name]) + op_def.output.extend([output.name for output in op.outputs]) + self.add_output_shape(op.outputs, op_def) + pooling_type_arg = op_def.arg.add() + pooling_type_arg.name = 'pooling_type' + pooling_type_arg.i = pooling_type_mode['AvgPool'] + padding_arg = op_def.arg.add() + padding_arg.name = 'padding' + padding_arg.i = padding_mode['VALID'] + strides_arg = op_def.arg.add() + strides_arg.name = 'strides' + strides_arg.ints.extend([1, 1]) + kernels_arg = op_def.arg.add() + kernels_arg.name = 'kernels' + kernels_arg.ints.extend(op.inputs[0].shape.as_list()[1:3]) + data_format_arg = op_def.arg.add() + data_format_arg.name = 'data_format' + data_format_arg.s = 'NHWC' + self.resolved_ops[op.name] = 1 + def convert_activation(self, op): op_def = self.net_def.op.add() arg = op_def.arg.add() @@ -824,6 +851,14 @@ class TFConverter(object): # FIXME: hardcode for inception_v3 elif op.type in ['Squeeze', 'Shape']: self.resolved_ops[op.name] = 1 + elif op.type == 'Mean': + # Global avg pooling + reduce_dims = op.inputs[1].eval() + if reduce_dims[0] == 1 and reduce_dims[1] == 2: + self.convert_global_avg_pooling(op) + self.unused_tensor.add(op.inputs[1].name) + else: + raise Exception('Unknown Op: %s, type: %s' % (op.name, op.type)) #elif op.type in ['']: # self.convert_normal_op(op) else: