diff --git a/python/tools/tf_converter_lib.py b/python/tools/tf_converter_lib.py index feb0e4bcc230cda93daea0fea200fd483cda70cb..ac1acf64276a1dc270cfd9aa452297d3f5b896dc 100644 --- a/python/tools/tf_converter_lib.py +++ b/python/tools/tf_converter_lib.py @@ -216,6 +216,9 @@ class TFConverter(object): and self.tf_graph[final_op.name][0].type == 'Relu': relu_op = self.tf_graph[final_op.name][0] op_def.type = "FusedConv2D" + fused_relu_arg = op_def.arg.add() + fused_relu_arg.name = 'activation' + fused_relu_arg.s = "RELU" final_op = relu_op self.resolved_ops[relu_op.name] = 1 @@ -273,8 +276,8 @@ class TFConverter(object): relu_op = self.tf_graph[op.name][0] final_op = relu_op fused_relu_arg = op_def.arg.add() - fused_relu_arg.name = 'fused_relu' - fused_relu_arg.i = 1 + fused_relu_arg.name = 'activation' + fused_relu_arg.s = "RELU" self.resolved_ops[relu_op.name] = 1 op_def.output.extend([final_op.outputs[0].name]) @@ -362,16 +365,34 @@ class TFConverter(object): data_format_arg.s = 'NHWC' self.resolved_ops[op.name] = 1 + def convert_relu(self, op): + op_def = self.net_def.op.add() + arg = op_def.arg.add() + arg.name = 'T' + arg.i = self.dt + op_def.name = op.name + op_def.type = 'Activation' + activation_arg = op_def.arg.add() + activation_arg.name = 'activation' + activation_arg.s = "RELU" + op_def.input.extend([input.name for input in op.inputs]) + op_def.output.extend([output.name for output in op.outputs]) + self.add_output_shape(op.outputs, op_def) + self.resolved_ops[op.name] = 1 + def convert_relu6(self, op): op_def = self.net_def.op.add() arg = op_def.arg.add() arg.name = 'T' arg.i = self.dt op_def.name = op.name - op_def.type = 'Relu' + op_def.type = 'Activation' op_def.input.extend([input.name for input in op.inputs]) op_def.output.extend([output.name for output in op.outputs]) self.add_output_shape(op.outputs, op_def) + activation_arg = op_def.arg.add() + activation_arg.name = 'activation' + activation_arg.s = "RELUX" max_limit_arg = op_def.arg.add() max_limit_arg.name = 'max_limit' max_limit_arg.f = 6 @@ -531,6 +552,9 @@ class TFConverter(object): and self.tf_graph[final_op.name][0].type == 'Relu': relu_op = self.tf_graph[final_op.name][0] op_def.type = "FusedConv2D" + fused_relu_arg = op_def.arg.add() + fused_relu_arg.name = 'activation' + fused_relu_arg.s = "RELU" final_op = relu_op self.resolved_ops[relu_op.name] = 1 @@ -602,6 +626,8 @@ class TFConverter(object): self.convert_batchnorm(op) elif op.type == 'AvgPool' or op.type == 'MaxPool': self.convert_pooling(op) + elif op.type == 'Relu': + self.convert_relu(op) elif op.type == 'Relu6': self.convert_relu6(op) elif op.type == 'Add': @@ -618,8 +644,8 @@ class TFConverter(object): self.convert_space_to_batch(op, True) elif self.is_softmax(op): self.convert_softmax(op) - elif op.type in ['Relu']: - self.convert_normal_op(op) + #elif op.type in ['']: + # self.convert_normal_op(op) else: raise Exception('Unknown Op: %s, type: %s' % (op.name, op.type))