提交 99261b26 编写于 作者: 李寅

Merge branch 'atrous_depthwise' into 'master'

Support atrous depthwise convolution converter.

See merge request !438
...@@ -946,7 +946,8 @@ class TFConverter(object): ...@@ -946,7 +946,8 @@ class TFConverter(object):
def is_atrous_conv2d(self, op): def is_atrous_conv2d(self, op):
return op.type == 'SpaceToBatchND' and \ return op.type == 'SpaceToBatchND' and \
len(self.tf_graph[op.name]) == 1 and \ len(self.tf_graph[op.name]) == 1 and \
self.tf_graph[op.name][0].type == 'Conv2D' (self.tf_graph[op.name][0].type == 'Conv2D'
or self.tf_graph[op.name][0].type == 'DepthwiseConv2dNative')
def convert_atrous_conv2d(self, op): def convert_atrous_conv2d(self, op):
op_def = mace_pb2.OperatorDef() op_def = mace_pb2.OperatorDef()
...@@ -955,18 +956,25 @@ class TFConverter(object): ...@@ -955,18 +956,25 @@ class TFConverter(object):
arg.i = self.dt arg.i = self.dt
conv_op = self.tf_graph[op.name][0] conv_op = self.tf_graph[op.name][0]
op_def.name = conv_op.name op_def.name = conv_op.name
op_def.type = conv_op.type if conv_op.type == 'DepthwiseConv2dNative':
op_def.type = 'DepthwiseConv2d'
else:
op_def.type = conv_op.type
if self.device == 'gpu': if self.device == 'gpu':
self.transpose_filter_tensor[
get_input_tensor(conv_op, 1).name] = (0, 1, 3, 2)
op_def.input.extend([op.inputs[0].name]) op_def.input.extend([op.inputs[0].name])
if op_def.type == 'DepthwiseConv2d':
buffer_type = "DW_CONV2D_FILTER"
else:
self.transpose_filter_tensor[get_input_tensor(
conv_op, 1).name] = (0, 1, 3, 2)
buffer_type = "CONV2D_FILTER"
output_name = self.add_buffer_to_image( output_name = self.add_buffer_to_image(
get_input_tensor(conv_op, 1).name, "CONV2D_FILTER") get_input_tensor(conv_op, 1).name, buffer_type)
op_def.input.extend([output_name]) op_def.input.extend([output_name])
else: else:
self.transpose_filter_tensor[ self.transpose_filter_tensor[get_input_tensor(
get_input_tensor(conv_op, 1).name] = (3, 2, 0, 1) conv_op, 1).name] = (3, 2, 0, 1)
op_def.input.extend([get_input_tensor(op, 0).name]) op_def.input.extend([get_input_tensor(op, 0).name])
op_def.input.extend([get_input_tensor(conv_op, 1).name]) op_def.input.extend([get_input_tensor(conv_op, 1).name])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册