Commit fc7a469c authored by 叶剑武

Merge branch 'activation' into 'master'

Add activations support

See merge request !3
@@ -216,6 +216,9 @@ class TFConverter(object):
         and self.tf_graph[final_op.name][0].type == 'Relu':
       relu_op = self.tf_graph[final_op.name][0]
       op_def.type = "FusedConv2D"
+      fused_relu_arg = op_def.arg.add()
+      fused_relu_arg.name = 'activation'
+      fused_relu_arg.s = "RELU"
       final_op = relu_op
       self.resolved_ops[relu_op.name] = 1
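
For context, the fusion in this hunk can be sketched outside the converter: when the convolution's output feeds exactly one consumer and that consumer is a Relu, the op is retyped to FusedConv2D, an 'activation' string argument is attached, and the Relu node is marked resolved so it is not converted again. The sketch below is a minimal stand-alone illustration; the OpDef class and the fuse_trailing_relu helper are hypothetical stand-ins, not the MACE protobuf API.

    class OpDef(object):
        """Hypothetical stand-in for the converter's protobuf op definition."""
        def __init__(self, name, op_type):
            self.name = name
            self.type = op_type
            self.args = {}  # name -> value, instead of repeated Argument messages

    def fuse_trailing_relu(op_def, final_op, tf_graph, resolved_ops):
        """tf_graph maps an op name to the ops that consume its output, as in the diff."""
        consumers = tf_graph.get(final_op.name, [])
        if len(consumers) == 1 and consumers[0].type == 'Relu':
            relu_op = consumers[0]
            op_def.type = 'FusedConv2D'
            op_def.args['activation'] = 'RELU'  # the string the converter writes to arg.s
            resolved_ops[relu_op.name] = 1      # the Relu will not be converted on its own
            return relu_op                      # the fused subgraph now ends at the Relu
        return final_op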
@@ -273,8 +276,8 @@ class TFConverter(object):
       relu_op = self.tf_graph[op.name][0]
       final_op = relu_op
       fused_relu_arg = op_def.arg.add()
-      fused_relu_arg.name = 'fused_relu'
-      fused_relu_arg.i = 1
+      fused_relu_arg.name = 'activation'
+      fused_relu_arg.s = "RELU"
       self.resolved_ops[relu_op.name] = 1
     op_def.output.extend([final_op.outputs[0].name])
@@ -362,16 +365,34 @@ class TFConverter(object):
     data_format_arg.s = 'NHWC'
     self.resolved_ops[op.name] = 1
 
+  def convert_relu(self, op):
+    op_def = self.net_def.op.add()
+    arg = op_def.arg.add()
+    arg.name = 'T'
+    arg.i = self.dt
+    op_def.name = op.name
+    op_def.type = 'Activation'
+    activation_arg = op_def.arg.add()
+    activation_arg.name = 'activation'
+    activation_arg.s = "RELU"
+    op_def.input.extend([input.name for input in op.inputs])
+    op_def.output.extend([output.name for output in op.outputs])
+    self.add_output_shape(op.outputs, op_def)
+    self.resolved_ops[op.name] = 1
+
   def convert_relu6(self, op):
     op_def = self.net_def.op.add()
     arg = op_def.arg.add()
     arg.name = 'T'
     arg.i = self.dt
     op_def.name = op.name
-    op_def.type = 'Relu'
+    op_def.type = 'Activation'
     op_def.input.extend([input.name for input in op.inputs])
     op_def.output.extend([output.name for output in op.outputs])
     self.add_output_shape(op.outputs, op_def)
+    activation_arg = op_def.arg.add()
+    activation_arg.name = 'activation'
+    activation_arg.s = "RELUX"
+    max_limit_arg = op_def.arg.add()
+    max_limit_arg.name = 'max_limit'
+    max_limit_arg.f = 6
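
To make the new encoding concrete: convert_relu emits an Activation op with activation='RELU', and convert_relu6 emits activation='RELUX' together with max_limit=6. The following is a rough sketch of the semantics those arguments imply for whoever consumes the converted graph; numpy and the apply_activation helper are illustrative and not part of this merge.

    import numpy as np

    def apply_activation(x, activation, max_limit=None):
        """Interpret the 'activation' argument written by convert_relu / convert_relu6."""
        if activation == 'RELU':
            return np.maximum(x, 0.0)
        if activation == 'RELUX':  # Relu6 is converted to RELUX with max_limit = 6
            return np.clip(x, 0.0, max_limit)
        raise ValueError('unsupported activation: %s' % activation)

    # Using the arguments convert_relu6 produces:
    print(apply_activation(np.array([-2.0, 3.0, 9.0]), 'RELUX', max_limit=6.0))  # [0. 3. 6.]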
@@ -531,6 +552,9 @@ class TFConverter(object):
         and self.tf_graph[final_op.name][0].type == 'Relu':
       relu_op = self.tf_graph[final_op.name][0]
       op_def.type = "FusedConv2D"
+      fused_relu_arg = op_def.arg.add()
+      fused_relu_arg.name = 'activation'
+      fused_relu_arg.s = "RELU"
       final_op = relu_op
       self.resolved_ops[relu_op.name] = 1
@@ -602,6 +626,8 @@ class TFConverter(object):
         self.convert_batchnorm(op)
       elif op.type == 'AvgPool' or op.type == 'MaxPool':
         self.convert_pooling(op)
+      elif op.type == 'Relu':
+        self.convert_relu(op)
       elif op.type == 'Relu6':
         self.convert_relu6(op)
       elif op.type == 'Add':
@@ -618,8 +644,8 @@ class TFConverter(object):
         self.convert_space_to_batch(op, True)
       elif self.is_softmax(op):
         self.convert_softmax(op)
-      elif op.type in ['Relu']:
-        self.convert_normal_op(op)
+      #elif op.type in ['']:
+      #  self.convert_normal_op(op)
       else:
         raise Exception('Unknown Op: %s, type: %s' % (op.name, op.type))
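
With the dispatch changes above, standalone Relu ops are routed to convert_relu instead of the generic convert_normal_op fallback, which is why that branch is commented out. Below is a condensed, hypothetical view of the resulting routing; the handler names follow the diff, but ACTIVATION_ROUTES and route are illustrative only.

    # Hypothetical routing table summarizing how activation ops are handled after this merge.
    ACTIVATION_ROUTES = {
        'Relu': 'convert_relu',    # new in this merge: standalone Relu -> Activation (RELU)
        'Relu6': 'convert_relu6',  # now also an Activation op (RELUX, max_limit = 6)
    }

    def route(converter, op):
        """Dispatch an op to its converter method, mirroring the elif chain in the converter."""
        handler = ACTIVATION_ROUTES.get(op.type)
        if handler is None:
            raise Exception('Unknown Op: %s, type: %s' % (op.name, op.type))
        return getattr(converter, handler)(op)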