提交 a5f2fdc0 编写于 作者: 叶剑武

Merge branch 'winograd' into 'master'

Support winograd convolution type conversion.

See merge request !8
......@@ -32,6 +32,16 @@ class MemoryOptimizer(object):
def is_buffer_image_op(self, op):
return op.type == 'BufferToImage' or op.type == 'ImageToBuffer'
def get_mem_size(self, op_type, output_shape):
mem_size = [0, 0]
if op_type == 'WinogradTransform' or op_type == 'GEMM':
mem_size[0] = output_shape[2] * output_shape[3]
mem_size[1] = output_shape[0] * int((output_shape[1]+3)/4)
else:
mem_size[0] = output_shape[2] * int((output_shape[3]+3)/4)
mem_size[1] = output_shape[0] * output_shape[1]
return mem_size
def optimize(self):
for op in self.net_def.op:
if self.is_buffer_image_op(op):
......@@ -52,8 +62,9 @@ class MemoryOptimizer(object):
if mem_id not in self.mem_block:
self.mem_block[mem_id] = [0, 0]
mem_size = self.mem_block[mem_id]
mem_size[1] = max(mem_size[1], op.output_shape[0].dims[0] * op.output_shape[0].dims[1])
mem_size[0] = max(mem_size[0], op.output_shape[0].dims[2] * int((op.output_shape[0].dims[3]+3)/4))
op_mem_size = self.get_mem_size(op.type, op.output_shape[0].dims)
mem_size[0] = max(mem_size[0], op_mem_size[0])
mem_size[1] = max(mem_size[1], op_mem_size[1])
# de-ref input tensor mem
for ipt in op.input:
......
......@@ -35,7 +35,7 @@ def main(unused_args):
input_graph_def, FLAGS.input_node, FLAGS.output_node)
else:
output_graph_def = tf_converter_lib.convert_to_mace_pb(
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime)
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime, FLAGS.winograd)
if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(output_graph_def, mode_pb_checksum, FLAGS.template, FLAGS.obfuscate,
......@@ -112,6 +112,13 @@ def parse_args():
type=str,
default="",
help="model tag for generated function and namespace")
parser.add_argument(
"--winograd",
type=str2bool,
nargs='?',
const=False,
default=False,
help="open winograd convolution or not")
return parser.parse_known_args()
......
......@@ -17,8 +17,11 @@ pooling_type_mode = {
buffer_type_map = {
'FILTER' : 0,
'IN_OUT' : 1,
'IN_OUT_CHANNEL' : 1,
'ARGUMENT' : 2,
'IN_OUT_HEIGHT' : 3,
'IN_OUT_WIDTH' : 4,
'WINOGRAD_FILTER' : 5,
}
data_type_map = {
......@@ -31,6 +34,8 @@ BATCH_NORM_ORDER = ["Add", "Rsqrt", "Mul", "Mul", "Mul", "Sub", "Add"]
MACE_INPUT_NODE_NAME = "mace_input_node"
MACE_OUTPUT_NODE_NAME = "mace_output_node"
OPENCL_IMAGE_MAX_SIZE = 16384
def get_input_tensor(op, index):
input_tensor = op.inputs[index]
if input_tensor.op.type == 'Reshape':
......@@ -38,15 +43,17 @@ def get_input_tensor(op, index):
return input_tensor
class TFConverter(object):
def __init__(self, tf_ops, net_def, dt, device):
def __init__(self, tf_ops, net_def, dt, device, winograd):
self.net_def = net_def
self.tf_ops = tf_ops
self.dt = dt
self.device = device
self.winograd = winograd
self.tf_graph = {}
self.tf_parents = {}
self.resolved_ops = {}
self.unused_tensor = set()
self.transpose_filter_tensor = set()
self.ops = {}
for op in tf_ops:
......@@ -109,7 +116,7 @@ class TFConverter(object):
epsilon_arg = op_def.arg.add()
epsilon_arg.name = 'buffer_type'
epsilon_arg.i = buffer_type_map['IN_OUT']
epsilon_arg.i = buffer_type_map['IN_OUT_CHANNEL']
arg = op_def.arg.add()
arg.name = 'T'
......@@ -125,7 +132,7 @@ class TFConverter(object):
epsilon_arg = op_def.arg.add()
epsilon_arg.name = 'buffer_type'
epsilon_arg.i = buffer_type_map['IN_OUT']
epsilon_arg.i = buffer_type_map['IN_OUT_CHANNEL']
@staticmethod
def add_output_shape(outputs, op):
......@@ -157,6 +164,8 @@ class TFConverter(object):
if op.outputs[0].name not in self.unused_tensor:
tensor = self.net_def.tensors.add()
tf_tensor = op.outputs[0].eval()
if op.outputs[0].name in self.transpose_filter_tensor:
tf_tensor = tf_tensor.transpose(3, 2, 0, 1)
tensor.name = op.outputs[0].name
shape = list(tf_tensor.shape)
......@@ -173,6 +182,100 @@ class TFConverter(object):
raise Exception("Not supported tensor type: " + tf_dt.name)
self.resolved_ops[op.name] = 1
def check_winograd_conv(self, op):
filter_shape = get_input_tensor(op, 1).shape.as_list()
strides = op.get_attr('strides')[1:3]
output_shape = op.outputs[0].shape.as_list()
width = output_shape[0] * ((output_shape[1] + 1)/2) * ((output_shape[2]+1)/2)
return self.winograd and op.type != 'DepthwiseConv2dNative' and self.device == 'gpu' and \
filter_shape[0] == 3 and (filter_shape[0] == filter_shape[1]) and \
(strides[0] == 1) and (strides[0] == strides[1]) and \
(16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \
(16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \
(width < OPENCL_IMAGE_MAX_SIZE)
def convert_winograd_conv(self, op):
filter_tensor = get_input_tensor(op, 1)
filter_shape = filter_tensor.shape.as_list()
output_shape = op.outputs[0].shape.as_list()
self.transpose_filter_tensor.add(filter_tensor.name)
filter_name = self.add_buffer_to_image(op.inputs[1].name, "WINOGRAD_FILTER")
# Input transform
wt_op = mace_pb2.OperatorDef()
arg = wt_op.arg.add()
arg.name = 'T'
arg.i = self.dt
padding_arg = wt_op.arg.add()
padding_arg.name = 'padding'
padding_arg.i = padding_mode[op.get_attr('padding')]
wt_op.name = op.name + '_input_transform'
wt_op.type = 'WinogradTransform'
wt_op.input.extend([op.inputs[0].name])
wt_output_name = wt_op.name + ":0"
wt_op.output.extend([wt_output_name])
wt_output_shape = mace_pb2.OutputShape()
wt_output_width = output_shape[0] * ((output_shape[1] + 1)/2) * ((output_shape[2]+1)/2)
wt_output_shape.dims.extend([16, filter_shape[2], wt_output_width, 1])
wt_op.output_shape.extend([wt_output_shape])
# GEMM
gemm_op = mace_pb2.OperatorDef()
arg = gemm_op.arg.add()
arg.name = 'T'
arg.i = self.dt
gemm_op.name = op.name + '_gemm'
gemm_op.type = 'GEMM'
gemm_op.input.extend([filter_name, wt_output_name])
gemm_output_name = gemm_op.name + ":0"
gemm_op.output.extend([gemm_output_name])
gemm_output_shape = mace_pb2.OutputShape()
gemm_output_shape.dims.extend([16, filter_shape[3], wt_output_width, 1])
gemm_op.output_shape.extend([gemm_output_shape])
# Inverse transform
iwt_op = mace_pb2.OperatorDef()
arg = iwt_op.arg.add()
arg.name = 'T'
arg.i = self.dt
batch_arg = iwt_op.arg.add()
batch_arg.name = 'batch'
batch_arg.i = output_shape[0]
height_arg = iwt_op.arg.add()
height_arg.name = 'height'
height_arg.i = output_shape[1]
width_arg = iwt_op.arg.add()
width_arg.name = 'width'
width_arg.i = output_shape[2]
iwt_op.name = op.name + '_inverse_transform'
iwt_op.type = 'WinogradInverseTransform'
iwt_op.input.extend([gemm_output_name])
final_op = op
self.resolved_ops[op.name] = 1
if len(self.tf_graph[op.name]) == 1 and self.tf_graph[op.name][0].type == 'BiasAdd' :
bias_add_op = self.tf_graph[op.name][0]
output_name = self.add_buffer_to_image(bias_add_op.inputs[1].name, "ARGUMENT")
iwt_op.input.extend([output_name])
final_op = bias_add_op
self.resolved_ops[bias_add_op.name] = 1
if len(self.tf_graph[final_op.name]) == 1 \
and self.tf_graph[final_op.name][0].type == 'Relu':
relu_op = self.tf_graph[final_op.name][0]
fused_relu_arg = iwt_op.arg.add()
fused_relu_arg.name = 'activation'
fused_relu_arg.s = "RELU"
final_op = relu_op
self.resolved_ops[relu_op.name] = 1
iwt_op.output.extend([output.name for output in final_op.outputs])
self.add_output_shape(final_op.outputs, iwt_op)
self.net_def.op.extend([wt_op, gemm_op, iwt_op])
def convert_conv2d(self, op):
op_def = mace_pb2.OperatorDef()
arg = op_def.arg.add()
......@@ -267,7 +370,7 @@ class TFConverter(object):
output_name = self.add_buffer_to_image(name, "ARGUMENT")
op_def.input.extend([output_name])
else:
op_def.input.extend([input.name for input in input_names])
op_def.input.extend([name for name in input_names])
self.resolved_ops[op.name] = 1
final_op = op
......@@ -619,7 +722,10 @@ class TFConverter(object):
elif self.is_atrous_conv2d(op):
self.convert_atrous_conv2d(op)
elif op.type == 'Conv2D' or op.type == 'DepthwiseConv2dNative':
self.convert_conv2d(op)
if self.check_winograd_conv(op):
self.convert_winograd_conv(op)
else:
self.convert_conv2d(op)
elif op.type == 'FusedBatchNorm':
self.convert_fused_batchnorm(op)
elif op.type == 'Add' and op.name.endswith('batchnorm/add'):
......@@ -664,7 +770,7 @@ class TFConverter(object):
if self.resolved_ops[key] != 1:
print 'Unresolve Op: %s' % key
def convert_to_mace_pb(input_graph_def, input_node, output_node, data_type, device):
def convert_to_mace_pb(input_graph_def, input_node, output_node, data_type, device, winograd):
net_def = mace_pb2.NetDef()
dt = data_type_map[data_type]
......@@ -672,7 +778,7 @@ def convert_to_mace_pb(input_graph_def, input_node, output_node, data_type, devi
with session.graph.as_default() as graph:
tf.import_graph_def(input_graph_def, name="")
ops = graph.get_operations()
converter = TFConverter(ops, net_def, dt, device)
converter = TFConverter(ops, net_def, dt, device, winograd)
converter.convert(input_node, output_node)
print "PB Converted, start optimize memory."
mem_optimizer = memory_optimizer.MemoryOptimizer(net_def)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册