提交 8673a20e 编写于 作者: L liuqi

Feature: Support caffe Interp layer.

上级 04db6237
......@@ -183,6 +183,7 @@ class CaffeConverter(base_converter.ConverterInterface):
'Slice': self.convert_slice,
'Softmax': self.convert_softmax,
'InnerProduct': self.convert_fully_connected,
'Interp': self.convert_interp,
'BatchNorm': self.convert_folded_batchnorm,
'Crop': self.convert_crop,
'Scale': self.convert_scale,
......@@ -555,7 +556,7 @@ class CaffeConverter(base_converter.ConverterInterface):
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = 2
if param.HasField('axis'):
if param.HasField(MaceKeyword.mace_axis_str):
axis_arg.i = param.axis
axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i
offset_arg = op.arg.add()
......@@ -573,7 +574,7 @@ class CaffeConverter(base_converter.ConverterInterface):
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = 1
if param.HasField('axis'):
if param.HasField(MaceKeyword.mace_axis_str):
axis_arg.i = param.axis
elif param.HasField('concat_dim'):
axis_arg.i = param.concat_dim
......@@ -593,6 +594,18 @@ class CaffeConverter(base_converter.ConverterInterface):
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = 1
def convert_interp(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.interp_param
mace_check(param.HasField("height") and param.HasField("width"),
'Only support bilinear interp with height and width')
op.type = MaceOp.ResizeBilinear.name
size_arg = op.arg.add()
size_arg.name = MaceKeyword.mace_resize_size_str
size_value = np.array([param.height, param.width], dtype=np.int32)
size_arg.ints.extend(size_value)
def convert_fully_connected(self, caffe_op):
op = self.convert_general_op(caffe_op)
param = caffe_op.layer.inner_product_param
......
......@@ -52,6 +52,7 @@ class ShapeInference(object):
MaceOp.Transpose.name: self.infer_shape_permute,
MaceOp.PriorBox.name: self.infer_shape_prior_box,
MaceOp.Reshape.name: self.infer_shape_reshape,
MaceOp.ResizeBilinear.name: self.infer_shape_resize_bilinear,
}
self._net = net
......@@ -289,3 +290,17 @@ class ShapeInference(object):
output_shape.append(self._output_shape_cache[op.input[0]][i])
output_shape[axis] = dim
self.add_output_shape(op, [output_shape])
def infer_shape_resize_bilinear(self, op):
input_shape = self._output_shape_cache[op.input[0]]
size = ConverterUtil.get_arg(
op, MaceKeyword.mace_resize_size_str).ints
if ConverterUtil.data_format(op) == DataFormat.NCHW:
output_shape = [input_shape[0], input_shape[1], size[0], size[1]]
elif ConverterUtil.data_format(op) == DataFormat.NHWC:
output_shape = [input_shape[0], size[0], size[1], input_shape[3]]
else:
output_shape = []
mace_check(False, "format %s is not supported"
% ConverterUtil.data_format(op))
self.add_output_shape(op, [output_shape])
......@@ -515,6 +515,7 @@ message LayerParameter {
optional InfogainLossParameter infogain_loss_param = 116;
optional InnerProductParameter inner_product_param = 117;
optional InputParameter input_param = 143;
optional InterpParameter interp_param = 147;
optional LogParameter log_param = 134;
optional LRNParameter lrn_param = 118;
optional MemoryDataParameter memory_data_param = 119;
......@@ -1207,6 +1208,15 @@ message InputParameter {
repeated BlobShape shape = 1;
}
message InterpParameter {
optional int32 height = 1 [default = 0]; // Height of output
optional int32 width = 2 [default = 0]; // Width of output
optional int32 zoom_factor = 3 [default = 1]; // zoom factor
optional int32 shrink_factor = 4 [default = 1]; // shrink factor
optional int32 pad_beg = 5 [default = 0]; // padding at begin of input
optional int32 pad_end = 6 [default = 0]; // padding at end of input
}
// Message that stores parameters used by LogLayer
message LogParameter {
// LogLayer computes outputs y = log_base(shift + scale * x), for base > 0.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册