提交 4ef1aae0 编写于 作者: L liuqi

Support atrous depthwise convolution converter.

上级 948db3c7
......@@ -946,7 +946,8 @@ class TFConverter(object):
def is_atrous_conv2d(self, op):
return op.type == 'SpaceToBatchND' and \
len(self.tf_graph[op.name]) == 1 and \
self.tf_graph[op.name][0].type == 'Conv2D'
(self.tf_graph[op.name][0].type == 'Conv2D'
or self.tf_graph[op.name][0].type == 'DepthwiseConv2dNative')
def convert_atrous_conv2d(self, op):
op_def = mace_pb2.OperatorDef()
......@@ -955,18 +956,25 @@ class TFConverter(object):
arg.i = self.dt
conv_op = self.tf_graph[op.name][0]
op_def.name = conv_op.name
op_def.type = conv_op.type
if conv_op.type == 'DepthwiseConv2dNative':
op_def.type = 'DepthwiseConv2d'
else:
op_def.type = conv_op.type
if self.device == 'gpu':
self.transpose_filter_tensor[
get_input_tensor(conv_op, 1).name] = (0, 1, 3, 2)
op_def.input.extend([op.inputs[0].name])
if op_def.type == 'DepthwiseConv2d':
buffer_type = "DW_CONV2D_FILTER"
else:
self.transpose_filter_tensor[get_input_tensor(
conv_op, 1).name] = (0, 1, 3, 2)
buffer_type = "CONV2D_FILTER"
output_name = self.add_buffer_to_image(
get_input_tensor(conv_op, 1).name, "CONV2D_FILTER")
get_input_tensor(conv_op, 1).name, buffer_type)
op_def.input.extend([output_name])
else:
self.transpose_filter_tensor[
get_input_tensor(conv_op, 1).name] = (3, 2, 0, 1)
self.transpose_filter_tensor[get_input_tensor(
conv_op, 1).name] = (3, 2, 0, 1)
op_def.input.extend([get_input_tensor(op, 0).name])
op_def.input.extend([get_input_tensor(conv_op, 1).name])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册