diff --git a/python/tools/tf_converter_lib.py b/python/tools/tf_converter_lib.py
index bf714231efd4a0a491af41e34a615254489b0e3e..f388a9f15f3d4b29ec7c5b3cb6a88a98c49bb9a1 100644
--- a/python/tools/tf_converter_lib.py
+++ b/python/tools/tf_converter_lib.py
@@ -398,8 +398,8 @@ class TFConverter(object):
     self.add_tensor(input_names[1], gamma_value.shape,
                     gamma_tensor.dtype, offset_value)
 
+    op_def.input.extend([op.inputs[0].name])
     if self.device == 'gpu':
-      op_def.input.extend([op.inputs[0].name])
       for name in input_names:
         output_name = self.add_buffer_to_image(name, "ARGUMENT")
         op_def.input.extend([output_name])
@@ -746,6 +746,29 @@ class TFConverter(object):
     self.add_output_shape(op.outputs, op_def)
     self.resolved_ops[op.name] = 1
 
+  def replace_in_out_name(self, input_name, output_name):
+    """Rename the graph's input/output tensors to the MACE node names.
+
+    Any op input equal to input_name + ":0" is rewritten to
+    MACE_INPUT_NODE_NAME + ":0", and any op output equal to
+    output_name + ":0" is rewritten to MACE_OUTPUT_NODE_NAME + ":0".
+
+    Args:
+      input_name: name of the graph's input node, without the ":0" suffix.
+      output_name: name of the graph's output node, without the ":0" suffix.
+    """
+    input_name = input_name + ":0"
+    output_name = output_name + ":0"
+    for op in self.net_def.op:
+      # Scan every slot rather than only index 0, so an op that takes the
+      # graph input as a later operand is renamed as well.
+      for i, name in enumerate(op.input):
+        if name == input_name:
+          op.input[i] = MACE_INPUT_NODE_NAME + ":0"
+      for i, name in enumerate(op.output):
+        if name == output_name:
+          op.output[i] = MACE_OUTPUT_NODE_NAME + ":0"
+
   def convert(self, input_node, output_node):
     if self.device == 'gpu':
       self.add_input_transform(input_node)
@@ -807,6 +830,9 @@ class TFConverter(object):
     if self.device == 'gpu':
       self.add_output_transform(output_node)
 
+    if self.device == 'cpu':
+      self.replace_in_out_name(input_node, output_node)
+
     for key in self.resolved_ops:
       if self.resolved_ops[key] != 1:
         print 'Unresolve Op: %s' % key
@@ -935,9 +961,11 @@ def convert_to_mace_pb(input_graph_def, input_node, input_shape, output_node, da
       converter.convert(input_node, output_node)
       optimizer = Optimizer(net_def, device)
       net_def = optimizer.optimize()
-      print "PB Converted, start optimize memory."
-      mem_optimizer = memory_optimizer.MemoryOptimizer(net_def)
-      mem_optimizer.optimize()
-      print "Memory optimization done."
+      print "PB Converted."
+      if device == 'gpu':
+        print "start optimize memory."
+        mem_optimizer = memory_optimizer.MemoryOptimizer(net_def)
+        mem_optimizer.optimize()
+        print "Memory optimization done."
 
   return net_def