Commit 89832e0d authored by lichao18

Fix reshape format bugs

Parent 9d3e2cc5
@@ -75,8 +75,17 @@ class ReshapeOp : public Operation {
           << "Input size not match reshaped tensor size";
       out_shape[unknown_idx] = missing;
     }
     Tensor *output = this->Output(OUTPUT);
+    // NHWC -> NCHW
+    auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>(
+        "data_format", DataFormat::DF_NONE));
+    if (df == DataFormat::NHWC && D == DeviceType::CPU
+        && out_shape.size() == 4 && shape->is_weight()) {
+      std::vector<int> dst_dims = {0, 3, 1, 2};
+      std::vector<index_t> out_shape_gpu = TransposeShape<index_t, index_t>(
+          out_shape, dst_dims);
+      out_shape = out_shape_gpu;
+    }
     output->ReuseTensorBuffer(*input);
     output->Reshape(out_shape);
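The added branch fires when the graph's data format is NHWC but the op executes on the NCHW CPU runtime and the 4-D target shape comes from a constant (weight) tensor: the target shape is permuted before the in-place reshape. A minimal Python sketch of what TransposeShape<index_t, index_t>(out_shape, {0, 3, 1, 2}) computes (transpose_shape below is a hypothetical stand-in for illustration, not part of MACE):

def transpose_shape(shape, dst_dims):
    # Element i of the result takes the value shape[dst_dims[i]].
    return [shape[d] for d in dst_dims]

# An NHWC target shape becomes NCHW for the CPU runtime:
print(transpose_shape([1, 224, 224, 3], [0, 3, 1, 2]))  # [1, 3, 224, 224]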
@@ -276,18 +276,19 @@ class ShapeInference(object):
                 output_shape[idx] = input_size / product
             self.add_output_shape(op, [output_shape])
         else:
-            output_shape = list(self._output_shape_cache[op.input[0]])
+            output_shape = []
             axis = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str).i
             end_axis = ConverterUtil.get_arg(op, MaceKeyword.mace_end_axis_str).i  # noqa
-            if end_axis < 0:
-                end_axis = len(output_shape) + end_axis
+            end_axis = end_axis if end_axis > 0 else end_axis + len(
+                list(self._output_shape_cache[op.input[0]]))
             dim = 1
             for i in range(0, axis):
-                output_shape[i] = self._output_shape_cache[op.input[0]][i]
+                output_shape.append(self._output_shape_cache[op.input[0]][i])
             for i in range(axis, end_axis + 1):
                 dim *= self._output_shape_cache[op.input[0]][i]
-                output_shape[i] = 1
-            for i in range(end_axis + 1, len(output_shape)):
-                output_shape[i] = self._output_shape_cache[op.input[0]][i]
+            output_shape.append(-1)
+            for i in range(end_axis + 1, len(
+                    list(self._output_shape_cache[op.input[0]]))):
+                output_shape.append(self._output_shape_cache[op.input[0]][i])
             output_shape[axis] = dim
             self.add_output_shape(op, [output_shape])
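The old code wrote into a full-rank copy of the input shape, so flattening left ones in place of the collapsed axes instead of removing them; the rewrite builds the output shape by appending, yielding the reduced rank. A standalone sketch of the fixed branch with assumed example shapes (infer_flatten_shape is hypothetical, not MACE API):

def infer_flatten_shape(input_shape, axis, end_axis):
    # Collapse dimensions [axis, end_axis] into a single dimension.
    end_axis = end_axis if end_axis > 0 else end_axis + len(input_shape)
    output_shape = []
    dim = 1
    for i in range(0, axis):
        output_shape.append(input_shape[i])
    for i in range(axis, end_axis + 1):
        dim *= input_shape[i]
    output_shape.append(-1)  # placeholder at index `axis`, replaced below
    for i in range(end_axis + 1, len(input_shape)):
        output_shape.append(input_shape[i])
    output_shape[axis] = dim
    return output_shape

print(infer_flatten_shape([1, 3, 4, 5], axis=1, end_axis=-1))  # [1, 60]
print(infer_flatten_shape([1, 3, 4, 5], axis=1, end_axis=2))   # [1, 12, 5]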
@@ -1790,31 +1790,35 @@ class Transformer(base_converter.ConverterInterface):
             if op.type == MaceOp.Reshape.name and \
                     len(op.input) == 1:
                 print("Transform Caffe Reshape")
-                if op.arg[3].name == 'dim':
-                    shape_tensor = net.tensors.add()
-                    shape_tensor.name = op.name + '_shape'
-                    shape_tensor.dims.append(len(op.output_shape[0].dims))
-                    shape_tensor.data_type = mace_pb2.DT_INT32
-                    shape_tensor.int32_data.extend(op.arg[3].ints)
-                    op.input.append(shape_tensor.name)
-                else:
-                    axis = op.arg[3].i
-                    dims = [1] * len(op.output_shape[0].dims)
-                    end_axis = op.arg[4].i
-                    end_axis = end_axis if end_axis >= 0 else end_axis + len(dims)  # noqa
-                    for i in range(0, axis):
-                        dims[i] = 0
-                    for i in range(axis + 1, end_axis + 1):
-                        dims[i] = 1
-                    for i in range(end_axis + 1, len(dims)):
-                        dims[i] = 0
-                    dims[axis] = -1
-                    shape_tensor = net.tensors.add()
-                    shape_tensor.name = op.name + '_shape'
-                    shape_tensor.dims.append(len(dims))
-                    shape_tensor.data_type = mace_pb2.DT_INT32
-                    shape_tensor.int32_data.extend(dims)
-                    op.input.append(shape_tensor.name)
+                dims = []
+                dim_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_dim_str)
+                axis_arg = ConverterUtil.get_arg(op, MaceKeyword.mace_axis_str)
+                # transform caffe reshape op
+                if dim_arg:
+                    dims = dim_arg.ints
+                # transform caffe flatten op
+                elif axis_arg is not None:
+                    axis = axis_arg.i
+                    for i in range(0, axis):
+                        dims.append(0)
+                    dims.append(-1)
+                    for i in range(axis + 1, len(op.output_shape[0].dims)):
+                        dims.append(0)
+                else:
+                    mace_check(False, "Only support reshape and flatten")
+                # NCHW -> NHWC
+                if len(dims) == 4:
+                    self.transpose_shape(dims, [0, 2, 3, 1])
+                shape_tensor = net.tensors.add()
+                shape_tensor.name = op.name + '_shape'
+                shape_tensor.dims.append(len(dims))
+                shape_tensor.data_type = mace_pb2.DT_INT32
+                shape_tensor.int32_data.extend(dims)
+                op.input.append(shape_tensor.name)

     def fold_fc_reshape(self):
         net = self._model
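The rewritten transform looks up the converter's dim/axis arguments instead of hard-coded op.arg indices: a Caffe Reshape copies its dim list directly, a Caffe Flatten is expanded into Caffe reshape semantics (0 keeps the corresponding input dimension, -1 absorbs the flattened range), anything else fails the mace_check, and a 4-D target shape is transposed from NCHW to NHWC to match the converted graph's layout. A short sketch with assumed example values:

# Flatten with axis=1 and a rank-2 output expands to [0, -1]:
axis, out_rank = 1, 2
dims = [0] * axis + [-1] + [0] * (out_rank - axis - 1)
assert dims == [0, -1]

# A 4-D Reshape target such as [1, 8, 14, 14] (NCHW) is transposed
# to NHWC before being stored in the '_shape' tensor:
dims = [1, 8, 14, 14]
dims = [dims[p] for p in [0, 2, 3, 1]]
assert dims == [1, 14, 14, 8]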