From 8673a20ed7e5034c915b3e6b011d0b1f710ff03d Mon Sep 17 00:00:00 2001
From: liuqi
Date: Thu, 21 Mar 2019 19:03:12 +0800
Subject: [PATCH] Feature: Support caffe Interp layer.

---
NOTE(review): the line structure and indentation of this patch were
reconstructed from a whitespace-collapsed copy, using Python/proto syntax
and the hunk line counts. Leading whitespace of context lines is assumed,
not verified - confirm against the target tree (or apply with
`git apply --ignore-whitespace`) before relying on it.

 .../tools/converter_tool/caffe_converter.py     | 17 +++++++++++++++--
 .../tools/converter_tool/shape_inference.py     | 15 +++++++++++++++
 third_party/caffe/caffe.proto                   | 10 ++++++++++
 3 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/mace/python/tools/converter_tool/caffe_converter.py b/mace/python/tools/converter_tool/caffe_converter.py
index 3231ea9f..b5f7eb45 100644
--- a/mace/python/tools/converter_tool/caffe_converter.py
+++ b/mace/python/tools/converter_tool/caffe_converter.py
@@ -183,6 +183,7 @@ class CaffeConverter(base_converter.ConverterInterface):
             'Slice': self.convert_slice,
             'Softmax': self.convert_softmax,
             'InnerProduct': self.convert_fully_connected,
+            'Interp': self.convert_interp,
             'BatchNorm': self.convert_folded_batchnorm,
             'Crop': self.convert_crop,
             'Scale': self.convert_scale,
@@ -555,7 +556,7 @@ class CaffeConverter(base_converter.ConverterInterface):
         axis_arg = op.arg.add()
         axis_arg.name = MaceKeyword.mace_axis_str
         axis_arg.i = 2
-        if param.HasField('axis'):
+        if param.HasField(MaceKeyword.mace_axis_str):
             axis_arg.i = param.axis
         axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i
         offset_arg = op.arg.add()
@@ -573,7 +574,7 @@ class CaffeConverter(base_converter.ConverterInterface):
         axis_arg = op.arg.add()
         axis_arg.name = MaceKeyword.mace_axis_str
         axis_arg.i = 1
-        if param.HasField('axis'):
+        if param.HasField(MaceKeyword.mace_axis_str):
             axis_arg.i = param.axis
         elif param.HasField('concat_dim'):
             axis_arg.i = param.concat_dim
@@ -593,6 +594,18 @@ class CaffeConverter(base_converter.ConverterInterface):
         axis_arg.name = MaceKeyword.mace_axis_str
         axis_arg.i = 1
 
+    def convert_interp(self, caffe_op):
+        op = self.convert_general_op(caffe_op)
+        param = caffe_op.layer.interp_param
+        mace_check(param.HasField("height") and param.HasField("width"),
+                   'Only support bilinear interp with height and width')
+        op.type = MaceOp.ResizeBilinear.name
+
+        size_arg = op.arg.add()
+        size_arg.name = MaceKeyword.mace_resize_size_str
+        size_value = np.array([param.height, param.width], dtype=np.int32)
+        size_arg.ints.extend(size_value)
+
     def convert_fully_connected(self, caffe_op):
         op = self.convert_general_op(caffe_op)
         param = caffe_op.layer.inner_product_param
diff --git a/mace/python/tools/converter_tool/shape_inference.py b/mace/python/tools/converter_tool/shape_inference.py
index 3e472216..0c0bfc49 100644
--- a/mace/python/tools/converter_tool/shape_inference.py
+++ b/mace/python/tools/converter_tool/shape_inference.py
@@ -52,6 +52,7 @@ class ShapeInference(object):
             MaceOp.Transpose.name: self.infer_shape_permute,
             MaceOp.PriorBox.name: self.infer_shape_prior_box,
             MaceOp.Reshape.name: self.infer_shape_reshape,
+            MaceOp.ResizeBilinear.name: self.infer_shape_resize_bilinear,
         }
 
         self._net = net
@@ -289,3 +290,17 @@ class ShapeInference(object):
                 output_shape.append(self._output_shape_cache[op.input[0]][i])
             output_shape[axis] = dim
             self.add_output_shape(op, [output_shape])
+
+    def infer_shape_resize_bilinear(self, op):
+        input_shape = self._output_shape_cache[op.input[0]]
+        size = ConverterUtil.get_arg(
+            op, MaceKeyword.mace_resize_size_str).ints
+        if ConverterUtil.data_format(op) == DataFormat.NCHW:
+            output_shape = [input_shape[0], input_shape[1], size[0], size[1]]
+        elif ConverterUtil.data_format(op) == DataFormat.NHWC:
+            output_shape = [input_shape[0], size[0], size[1], input_shape[3]]
+        else:
+            output_shape = []
+            mace_check(False, "format %s is not supported"
+                       % ConverterUtil.data_format(op))
+        self.add_output_shape(op, [output_shape])
diff --git a/third_party/caffe/caffe.proto b/third_party/caffe/caffe.proto
index b2d56b98..c972c9f6 100644
--- a/third_party/caffe/caffe.proto
+++ b/third_party/caffe/caffe.proto
@@ -515,6 +515,7 @@ message LayerParameter {
   optional InfogainLossParameter infogain_loss_param = 116;
   optional InnerProductParameter inner_product_param = 117;
   optional InputParameter input_param = 143;
+  optional InterpParameter interp_param = 147;
   optional LogParameter log_param = 134;
   optional LRNParameter lrn_param = 118;
   optional MemoryDataParameter memory_data_param = 119;
@@ -1207,6 +1208,15 @@ message InputParameter {
   repeated BlobShape shape = 1;
 }
 
+message InterpParameter {
+  optional int32 height = 1 [default = 0]; // Height of output
+  optional int32 width = 2 [default = 0]; // Width of output
+  optional int32 zoom_factor = 3 [default = 1]; // zoom factor
+  optional int32 shrink_factor = 4 [default = 1]; // shrink factor
+  optional int32 pad_beg = 5 [default = 0]; // padding at begin of input
+  optional int32 pad_end = 6 [default = 0]; // padding at end of input
+}
+
 // Message that stores parameters used by LogLayer
 message LogParameter {
   // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0.
-- 
GitLab