From 72bf66a9698083d233569b400d774ee3e3e1ca9b Mon Sep 17 00:00:00 2001
From: Liangliang He
Date: Thu, 18 Jan 2018 18:12:51 +0800
Subject: [PATCH] Add activations support

---
 python/tools/tf_converter_lib.py | 36 +++++++++++++++++++++++++++-----
 1 file changed, 31 insertions(+), 5 deletions(-)

diff --git a/python/tools/tf_converter_lib.py b/python/tools/tf_converter_lib.py
index feb0e4bc..ac1acf64 100644
--- a/python/tools/tf_converter_lib.py
+++ b/python/tools/tf_converter_lib.py
@@ -216,6 +216,9 @@ class TFConverter(object):
         and self.tf_graph[final_op.name][0].type == 'Relu':
       relu_op = self.tf_graph[final_op.name][0]
       op_def.type = "FusedConv2D"
+      fused_relu_arg = op_def.arg.add()
+      fused_relu_arg.name = 'activation'
+      fused_relu_arg.s = "RELU"
       final_op = relu_op
       self.resolved_ops[relu_op.name] = 1
 
@@ -273,8 +276,8 @@ class TFConverter(object):
       relu_op = self.tf_graph[op.name][0]
       final_op = relu_op
       fused_relu_arg = op_def.arg.add()
-      fused_relu_arg.name = 'fused_relu'
-      fused_relu_arg.i = 1
+      fused_relu_arg.name = 'activation'
+      fused_relu_arg.s = "RELU"
       self.resolved_ops[relu_op.name] = 1
 
     op_def.output.extend([final_op.outputs[0].name])
@@ -362,16 +365,34 @@ class TFConverter(object):
     data_format_arg.s = 'NHWC'
     self.resolved_ops[op.name] = 1
 
+  def convert_relu(self, op):
+    op_def = self.net_def.op.add()
+    arg = op_def.arg.add()
+    arg.name = 'T'
+    arg.i = self.dt
+    op_def.name = op.name
+    op_def.type = 'Activation'
+    activation_arg = op_def.arg.add()
+    activation_arg.name = 'activation'
+    activation_arg.s = "RELU"
+    op_def.input.extend([input.name for input in op.inputs])
+    op_def.output.extend([output.name for output in op.outputs])
+    self.add_output_shape(op.outputs, op_def)
+    self.resolved_ops[op.name] = 1
+
   def convert_relu6(self, op):
     op_def = self.net_def.op.add()
     arg = op_def.arg.add()
     arg.name = 'T'
     arg.i = self.dt
     op_def.name = op.name
-    op_def.type = 'Relu'
+    op_def.type = 'Activation'
     op_def.input.extend([input.name for input in op.inputs])
     op_def.output.extend([output.name for output in op.outputs])
     self.add_output_shape(op.outputs, op_def)
+    activation_arg = op_def.arg.add()
+    activation_arg.name = 'activation'
+    activation_arg.s = "RELUX"
     max_limit_arg = op_def.arg.add()
     max_limit_arg.name = 'max_limit'
     max_limit_arg.f = 6
@@ -531,6 +552,9 @@ class TFConverter(object):
         and self.tf_graph[final_op.name][0].type == 'Relu':
       relu_op = self.tf_graph[final_op.name][0]
       op_def.type = "FusedConv2D"
+      fused_relu_arg = op_def.arg.add()
+      fused_relu_arg.name = 'activation'
+      fused_relu_arg.s = "RELU"
       final_op = relu_op
       self.resolved_ops[relu_op.name] = 1
 
@@ -602,6 +626,8 @@ class TFConverter(object):
         self.convert_batchnorm(op)
       elif op.type == 'AvgPool' or op.type == 'MaxPool':
         self.convert_pooling(op)
+      elif op.type == 'Relu':
+        self.convert_relu(op)
       elif op.type == 'Relu6':
         self.convert_relu6(op)
       elif op.type == 'Add':
@@ -618,8 +644,8 @@ class TFConverter(object):
         self.convert_space_to_batch(op, True)
       elif self.is_softmax(op):
         self.convert_softmax(op)
-      elif op.type in ['Relu']:
-        self.convert_normal_op(op)
+      #elif op.type in ['']:
+      #  self.convert_normal_op(op)
       else:
         raise Exception('Unknown Op: %s, type: %s' % (op.name, op.type))
 
-- 
GitLab
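
Note (not part of the patch): the sketch below illustrates the mapping this commit introduces, where standalone Relu/Relu6 nodes become an 'Activation' op whose 'activation' string argument selects the kernel ("RELU" or "RELUX"), with 'max_limit' carrying the clamp value for Relu6. The `Arg`/`OpDef` classes and the `convert_relu_like` helper are simplified, hypothetical stand-ins for MACE's protobuf OperatorDef/Argument messages and the converter methods; they are not the real `mace_pb2` definitions.

```python
# Illustrative only: minimal stand-ins for the protobuf Argument/OperatorDef
# messages that tf_converter_lib.py populates via op_def.arg.add().
class Arg(object):
    def __init__(self):
        self.name, self.s, self.i, self.f = None, None, None, None

class OpDef(object):
    def __init__(self):
        self.name, self.type, self.arg = None, None, []

    def add_arg(self):
        a = Arg()
        self.arg.append(a)
        return a

def convert_relu_like(tf_op_type, op_name):
    """Mirror the patch: Relu -> Activation/RELU, Relu6 -> Activation/RELUX + max_limit."""
    op_def = OpDef()
    op_def.name = op_name
    op_def.type = 'Activation'
    activation_arg = op_def.add_arg()
    activation_arg.name = 'activation'
    activation_arg.s = 'RELU' if tf_op_type == 'Relu' else 'RELUX'
    if tf_op_type == 'Relu6':
        max_limit_arg = op_def.add_arg()
        max_limit_arg.name = 'max_limit'
        max_limit_arg.f = 6
    return op_def

if __name__ == '__main__':
    op = convert_relu_like('Relu6', 'block1/relu6')
    print(op.type, [(a.name, a.s if a.s is not None else a.f) for a in op.arg])
    # Prints: Activation [('activation', 'RELUX'), ('max_limit', 6)]
```

The same 'activation' string argument is what the fused convolution path now attaches to "FusedConv2D", so downstream kernels read one consistent key instead of the old boolean 'fused_relu' flag.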