提交 875ed7a5 编写于 作者: L liuqi

Caffe converter support depthwise convolution.

上级 a1c6ba92
...@@ -333,8 +333,18 @@ class CaffeConverter(object): ...@@ -333,8 +333,18 @@ class CaffeConverter(object):
return pad, stride, kernel return pad, stride, kernel
def convert_conv2d(self, op): def convert_conv2d(self, op):
op_def = self.CommonConvert(op, 'Conv2D')
param = op.layer.convolution_param param = op.layer.convolution_param
is_depthwise = False
if param.HasField('group'):
if param.group == op.data[0].shape[0] and op.data[0].shape[1] == 1:
is_depthwise = True
else:
raise Exception("Mace do not support group convolution yet")
if is_depthwise:
op_def = self.CommonConvert(op, 'DepthwiseConv2d')
else:
op_def = self.CommonConvert(op, 'Conv2D')
# Add filter # Add filter
weight_tensor_name = op.name + '_weight:0' weight_tensor_name = op.name + '_weight:0'
...@@ -342,7 +352,7 @@ class CaffeConverter(object): ...@@ -342,7 +352,7 @@ class CaffeConverter(object):
self.add_tensor(weight_tensor_name, weight_data) self.add_tensor(weight_tensor_name, weight_data)
if self.device == 'gpu': if self.device == 'gpu':
buffer_type = "CONV2D_FILTER" buffer_type = "DW_CONV2D_FILTER" if is_depthwise else "CONV2D_FILTER"
output_name = self.add_buffer_to_image(weight_tensor_name, buffer_type) output_name = self.add_buffer_to_image(weight_tensor_name, buffer_type)
op_def.input.extend([output_name]) op_def.input.extend([output_name])
else: else:
...@@ -381,6 +391,7 @@ class CaffeConverter(object): ...@@ -381,6 +391,7 @@ class CaffeConverter(object):
if len(self.ops_map[final_op.name].children) == 1 \ if len(self.ops_map[final_op.name].children) == 1 \
and self.ops_map[final_op.name].children[0].type in activation_name_map: and self.ops_map[final_op.name].children[0].type in activation_name_map:
activation_op = self.ops_map[final_op.name].children[0] activation_op = self.ops_map[final_op.name].children[0]
if not is_depthwise:
op_def.type = "FusedConv2D" op_def.type = "FusedConv2D"
fused_act_arg = op_def.arg.add() fused_act_arg = op_def.arg.add()
fused_act_arg.name = 'activation' fused_act_arg.name = 'activation'
...@@ -412,7 +423,7 @@ class CaffeConverter(object): ...@@ -412,7 +423,7 @@ class CaffeConverter(object):
width = output_shape[0] * ((output_shape[1] + 1)/2) * ((output_shape[2]+1)/2) width = output_shape[0] * ((output_shape[1] + 1)/2) * ((output_shape[2]+1)/2)
return self.winograd and self.device == 'gpu' and \ return self.winograd and self.device == 'gpu' and \
filter_shape[0] == 3 and (filter_shape[0] == filter_shape[1]) and \ filter_shape[0] == 3 and (filter_shape[0] == filter_shape[1]) and \
dilations[0] == 1 and (dilations[0] == dilations[1]) and\ dilations[0] == 1 and (dilations[0] == dilations[1]) and \
(strides[0] == 1) and (strides[0] == strides[1]) and \ (strides[0] == 1) and (strides[0] == strides[1]) and \
(16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \ (16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \
(16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \ (16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \
...@@ -966,3 +977,4 @@ def convert_to_mace_pb(model_file, weight_file, input_node_str, input_shape_str, ...@@ -966,3 +977,4 @@ def convert_to_mace_pb(model_file, weight_file, input_node_str, input_shape_str,
print "Memory optimization done." print "Memory optimization done."
return net_def return net_def
...@@ -362,6 +362,7 @@ class TFConverter(object): ...@@ -362,6 +362,7 @@ class TFConverter(object):
if len(self.tf_graph.get(final_op.name, [])) == 1 \ if len(self.tf_graph.get(final_op.name, [])) == 1 \
and self.tf_graph[final_op.name][0].type in activation_name_map: and self.tf_graph[final_op.name][0].type in activation_name_map:
activation_op = self.tf_graph[final_op.name][0] activation_op = self.tf_graph[final_op.name][0]
if op_def.type == "Conv2D":
op_def.type = "FusedConv2D" op_def.type = "FusedConv2D"
fused_act_arg = op_def.arg.add() fused_act_arg = op_def.arg.add()
fused_act_arg.name = 'activation' fused_act_arg.name = 'activation'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册