Commit 6d768ae0 authored by liutuo

Add Unsqueeze and Constant ops for ONNX and fix a Concat bug

Parent 0792637f
......@@ -94,8 +94,14 @@ class ConcatOp<DeviceType::CPU, T> : public ConcatOpBase {
}
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
T *output_ptr = output->mutable_data<T>();
Tensor::MappingGuard output_guard(output);
std::vector<Tensor::MappingGuard> mappers;
for (size_t i = 0; i < inputs_count; ++i) {
mappers.emplace_back(Tensor::MappingGuard(inputs[i]));
}
T *output_ptr = output->mutable_data<T>();
std::vector<const T *> input_ptrs(inputs.size(), nullptr);
for (size_t i = 0; i < inputs_count; ++i) {
input_ptrs[i] = inputs[i]->data<T>();
......
......@@ -76,6 +76,7 @@ extern void RegisterSumGroup(OpRegistryBase *op_registry);
extern void RegisterTargetRMSNorm(OpRegistryBase *op_registry);
extern void RegisterTranspose(OpRegistryBase *op_registry);
extern void RegisterUnstack(OpRegistryBase *op_registry);
extern void RegisterUnsqueeze(OpRegistryBase *op_registry);
#ifdef MACE_ENABLE_QUANTIZE
extern void RegisterDequantize(OpRegistryBase *op_registry);
......@@ -149,6 +150,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterTargetRMSNorm(this);
ops::RegisterTranspose(this);
ops::RegisterUnstack(this);
ops::RegisterUnsqueeze(this);
#ifdef MACE_ENABLE_QUANTIZE
ops::RegisterDequantize(this);
......
......@@ -27,10 +27,7 @@ namespace mace {
namespace ops {
template<DeviceType D, typename T>
class TransposeOp;
template<DeviceType D>
class TransposeOp<D, float> : public Operation {
class TransposeOp : public Operation {
public:
explicit TransposeOp(OpConstructContext *context)
: Operation(context),
......@@ -52,8 +49,8 @@ class TransposeOp<D, float> : public Operation {
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const float *input_data = input->data<float>();
float *output_data = output->mutable_data<float>();
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
return Transpose(&context->device()->cpu_runtime()->thread_pool(),
input_data, input->shape(), dims_, output_data);
......@@ -66,6 +63,8 @@ class TransposeOp<D, float> : public Operation {
void RegisterTranspose(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Transpose", TransposeOp,
DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Transpose", TransposeOp,
DeviceType::CPU, half);
}
} // namespace ops
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstring>
#include <functional>
#include <numeric>
#include <unordered_set>
#include <vector>
#include "mace/core/operator.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class UnsqueezeOp : public Operation {
public:
explicit UnsqueezeOp(OpConstructContext *context)
: Operation(context),
axis_(Operation::GetRepeatedArgs<int>("axis", {})) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(0);
MACE_CHECK(!axis_.empty(), "Unsqueeze op should have axis values.");
std::vector<index_t> output_shape = input->shape();
for (size_t i = 0; i < axis_.size(); ++i) {
MACE_CHECK(axis_[i] >= 0, "axis's value should be non-negative.");
output_shape.insert(output_shape.begin() + axis_[i], 1);
}
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_data = input->data<T>();
T *output_data = output->mutable_data<T>();
const index_t data_size =
std::accumulate(input->shape().begin(), input->shape().end(), 1,
std::multiplies<index_t>());
memcpy(output_data, input_data, data_size * sizeof(T));
return MaceStatus::MACE_SUCCESS;
}
private:
std::vector<int> axis_;
private:
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
void RegisterUnsqueeze(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Unsqueeze", UnsqueezeOp,
DeviceType::CPU, float);
MACE_REGISTER_OP(op_registry, "Unsqueeze", UnsqueezeOp,
DeviceType::CPU, int32_t);
}
} // namespace ops
} // namespace mace
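
The UnsqueezeOp added above is purely a metadata change: it inserts size-1 dimensions at the requested axes and copies the flat buffer unchanged. A minimal NumPy sketch of the same behaviour (illustration only, not MACE code):

import numpy as np

def unsqueeze(data, axes):
    # Insert a size-1 dimension at each axis, in order, as the op does.
    shape = list(data.shape)
    for ax in axes:
        assert ax >= 0, "axis values should be non-negative"
        shape.insert(ax, 1)
    # Element order is unchanged, so a reshape is equivalent to the memcpy.
    return data.reshape(shape)

x = np.arange(6, dtype=np.float32).reshape(2, 3)
print(unsqueeze(x, [0, 3]).shape)  # (1, 2, 3, 1)
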
......@@ -156,6 +156,7 @@ MaceSupportedOps = [
'Squeeze',
'Stack',
'Unstack',
'Unsqueeze',
'StridedSlice',
'Softmax',
'SpaceToBatchND',
......
......@@ -72,7 +72,7 @@ OnnxSupportedOps = [
'Clip',
# 'Compress',
'Concat',
# 'Constant',
'Constant',
# 'ConstantLike',
'Conv',
'ConvTranspose',
......@@ -179,7 +179,7 @@ OnnxSupportedOps = [
# 'Tile',
# 'TopK',
'Transpose',
# 'Unsqueeze',
'Unsqueeze',
# 'Upsample',
# 'Xor',
]
......@@ -336,6 +336,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Concat.name: self.convert_concat,
OnnxOpType.Conv.name: self.convert_conv2d,
OnnxOpType.ConvTranspose.name: self.convert_deconv,
OnnxOpType.Constant.name: self.convert_constant,
OnnxOpType.DepthToSpace.name: self.convert_depth_space,
OnnxOpType.Dropout.name: self.convert_dropout,
OnnxOpType.DimRange.name: self.convert_dim_range,
......@@ -371,6 +372,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Reciprocal.name: self.convert_eltwise,
OnnxOpType.ReduceMean.name: self.convert_reduce,
OnnxOpType.Scale.name: self.convert_eltwise,
OnnxOpType.Shape.name: self.convert_shape,
OnnxOpType.Sigmoid.name: self.convert_activation,
OnnxOpType.Slice.name: self.convert_slice,
OnnxOpType.Softmax.name: self.convert_softmax,
......@@ -385,6 +387,7 @@ class OnnxConverter(base_converter.ConverterInterface):
OnnxOpType.Tanh.name: self.convert_activation,
OnnxOpType.TargetRMSNorm: self.convert_target_rms_norm,
OnnxOpType.Transpose.name: self.convert_transpose,
OnnxOpType.Unsqueeze.name: self.convert_unsqueeze,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
......@@ -513,6 +516,13 @@ class OnnxConverter(base_converter.ConverterInterface):
new_shape = shape
return new_shape
@staticmethod
def unsqueeze_shape(shape, axis):
new_shape = [n for n in shape]
for n in axis:
new_shape.insert(n, 1)
return new_shape
@staticmethod
def transpose_const(tensor):
shape = tensor.dims
......@@ -663,14 +673,34 @@ class OnnxConverter(base_converter.ConverterInterface):
mace_check('axis' in node.attrs,
'Concat op should have axis attribute.')
axis_value = node.attrs['axis']
mace_check(axis_value == 1 or axis_value == -3,
"only support concat at channel dimension")
else:
axis_value = -1
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = axis_value
def convert_constant(self, node):
output_name = node.outputs[0]
tensor = self._mace_net_def.tensors.add()
tensor.name = output_name
onnx_tensor = node.attrs['value']
tensor_value = numpy_helper.to_array(onnx_tensor)
tensor.dims.extend(list(onnx_tensor.dims))
data_type = onnx_dtype(onnx_tensor.data_type)
if data_type == np.float32 or data_type == np.float64:
tensor.data_type = mace_pb2.DT_FLOAT
tensor.float_data.extend(
tensor_value.astype(np.float32).flat)
elif data_type == np.int32 or data_type == np.int64:
tensor.data_type = mace_pb2.DT_INT32
tensor.int32_data.extend(
tensor_value.astype(np.int32).flat)
else:
mace_check(False,
"Not supported tensor type: %s" % data_type)
self._consts[tensor.name] = tensor
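
The convert_constant handler above turns the node's value attribute into a MACE tensor, narrowing float64/int64 values down to DT_FLOAT/DT_INT32. A small standalone sketch of that mapping, assuming the onnx Python package is available (the node and tensor names here are made up for illustration):

import numpy as np
from onnx import helper, numpy_helper

# An ONNX Constant node holding an int64 tensor; the converter stores the
# values as int32 in tensor.int32_data with data_type DT_INT32.
value = numpy_helper.from_array(np.array([1, 2, 3], dtype=np.int64), name='const_0')
node = helper.make_node('Constant', inputs=[], outputs=['const_0'], value=value)

arr = numpy_helper.to_array(node.attribute[0].t)  # array([1, 2, 3]), dtype int64
print(arr.astype(np.int32))                       # what ends up in int32_data
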
def convert_conv2d(self, node):
op = self.convert_general_op(node)
self.add_stride_pad_kernel_arg(node.attrs, op)
......@@ -1079,55 +1109,64 @@ class OnnxConverter(base_converter.ConverterInterface):
if self._isKaldi:
self.convert_affine(node)
return
# only supports FullyConnected Style Gemm for now.
mace_check(len(node.inputs) >= 2,
"Gemm should have at least two inputs.")
if 'alpha' in node.attrs:
alpha = node.attrs['alpha']
if alpha != 1.0 and node.inputs[1] in self._consts:
weights = self._consts[node.inputs[1]]
for idx in six.moves.range(self.get_tensor_len(weights)):
weights.float_data[idx] *= alpha
if 'beta' in node.attrs:
beta = node.attrs['beta']
if beta != 1.0 and len(node.inputs) == 3 and\
node.inputs[2] in self._consts:
bias = self._consts[node.inputs[2]]
for idx in six.moves.range(self.get_tensor_len(bias)):
bias.float_data[idx] *= beta
trans_a = node.attrs['transA'] if 'transA' in node.attrs else 0
trans_b = node.attrs['transB'] if 'transB' in node.attrs else 0
shape_a = self._graph_shapes_dict[node.inputs[0]]
shape_b = self._graph_shapes_dict[node.inputs[1]]
mace_check(trans_a == 0 and trans_b == 1,
"Do not support non-default transpose")
mace_check(len(shape_a) == 4,
"Unexpected fc input ndim.")
mace_check(node.inputs[1] in self._consts, "unexpect fc weight.")
if len(shape_b) == 4:
mace_check(list(shape_b[2:]) == [1, 1],
"Only support 4D weight with shape [*, *, 1, 1]")
elif len(shape_b) == 2:
tensor_b = self._consts[node.inputs[1]]
tensor_data = np.array(tensor_b.float_data).reshape(
shape_b[0], shape_b[1], 1, 1)
tensor_b.float_data[:] = tensor_data.flat
tensor_b.dims[:] = tensor_data.shape
is_fc = False
if trans_a == 0 and trans_b == 1 and\
node.inputs[0] in self._graph_shapes_dict and\
node.inputs[1] in self._graph_shapes_dict and \
node.inputs[1] in self._consts:
shape_a = self._graph_shapes_dict[node.inputs[0]]
shape_b = self._graph_shapes_dict[node.inputs[1]]
if len(shape_a) == 4 and len(shape_b) == 2:
tensor_b = self._consts[node.inputs[1]]
tensor_data = np.array(tensor_b.float_data).reshape(
shape_b[0], shape_b[1], 1, 1)
tensor_b.float_data[:] = tensor_data.flat
tensor_b.dims[:] = tensor_data.shape
is_fc = True
elif len(shape_a) == 4 and\
len(shape_b) == 4 and list(shape_b[2:]) == [1, 1]:
is_fc = True
if is_fc:
op = self.convert_general_op(node, with_shape=False)
op.type = MaceOp.FullyConnected.name
for output in node.outputs:
output_shape = op.output_shape.add()
shape_info = self._graph_shapes_dict[output]
mace_check(len(shape_info) in [2, 4],
"gemm output shape should be 2 or 4 dims.")
if len(shape_info) == 4:
mace_check(list(shape_info[2:]) == [1, 1],
"gemm's output shape should be [*, * , 1, 1]")
else:
shape_info = [shape_info[0], shape_info[1], 1, 1]
output_shape.dims.extend(shape_info)
else:
mace_check(False, "Unexpected fc weigth ndim.")
op = self._mace_net_def.op.add()
op.name = node.name
op.type = MaceOp.FullyConnected.name
data_type_arg = op.arg.add()
data_type_arg.name = 'T'
data_type_arg.i = self._option.data_type
framework_type_arg = op.arg.add()
framework_type_arg.name = MaceKeyword.mace_framework_type_str
framework_type_arg.i = FrameworkType.ONNX.value
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
for input in node.inputs:
op.input.append(input)
for output in node.outputs:
op.output.append(output)
output_shape = op.output_shape.add()
shape_info = self._graph_shapes_dict[output]
mace_check(len(shape_info) in [2, 4],
"gemm output shape should be 2 or 4 dims.")
if len(shape_info) == 4:
mace_check(shape_info[2] == 1 and shape_info[3] == 1,
"gemm's 4-dim output shape should be [*, * , 1, 1]")
else:
shape_info = [shape_info[0], shape_info[1], 1, 1]
output_shape.dims.extend(shape_info)
op = self.convert_general_op(node)
op.type = MaceOp.MatMul.name
trans_a_arg = op.arg.add()
trans_a_arg.name = MaceKeyword.mace_transpose_a_str
trans_a_arg.i = trans_a
trans_b_arg = op.arg.add()
trans_b_arg.name = MaceKeyword.mace_transpose_b_str
trans_b_arg.i = trans_b
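
In short, the reworked Gemm conversion first folds any non-default alpha/beta into the constant weight and bias tensors, then emits FullyConnected only when the FC pattern holds and otherwise falls back to MatMul with the transA/transB arguments. A condensed sketch of that decision (hypothetical helper, not the converter itself):

def pick_gemm_op(trans_a, trans_b, shape_a, shape_b, b_is_const):
    # FullyConnected needs A untransposed, B transposed, a 4-D input and a
    # constant weight that is either 2-D or 4-D with trailing [1, 1].
    if trans_a == 0 and trans_b == 1 and b_is_const and len(shape_a) == 4:
        if len(shape_b) == 2:
            return 'FullyConnected'  # weight gets reshaped to [out, in, 1, 1]
        if len(shape_b) == 4 and list(shape_b[2:]) == [1, 1]:
            return 'FullyConnected'
    return 'MatMul'                  # general case keeps transA/transB args

print(pick_gemm_op(0, 1, [1, 256, 1, 1], [10, 256], True))  # FullyConnected
print(pick_gemm_op(0, 0, [1, 256], [256, 10], True))        # MatMul
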
def convert_identity(self, node):
op = self.convert_general_op(node)
......@@ -1279,6 +1318,11 @@ class OnnxConverter(base_converter.ConverterInterface):
op = self.convert_general_op(node)
op.type = MaceOp.Reshape.name
def convert_shape(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Shape.name
op.output_type.extend([mace_pb2.DT_INT32])
def convert_slice(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Slice.name
......@@ -1357,6 +1401,24 @@ class OnnxConverter(base_converter.ConverterInterface):
axis_value = []
axis_arg.ints.extend(axis_value)
def convert_unsqueeze(self, node):
mace_check('axes' in node.attrs,
"Unsqueeze op should have 'axes' attribute.")
axis_value = node.attrs['axes']
if node.inputs[0] in self._consts:
tensor = self._consts[node.inputs[0]]
shape = tensor.dims
new_shape = self.unsqueeze_shape(shape, axis_value)
del tensor.dims[:]
tensor.dims.extend(new_shape)
self.remove_node(node)
else:
op = self.convert_general_op(node)
op.type = MaceOp.Unsqueeze.name
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.ints.extend(axis_value)
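
convert_unsqueeze therefore has two paths: when the input is a constant already captured in self._consts, the new shape is written straight into the tensor's dims and the node is removed; otherwise a real Unsqueeze op is emitted with the axes as an arg. A rough sketch of the constant-folding path, using a hypothetical dims list:

def fold_unsqueeze_into_const(tensor_dims, axes):
    # Mirror unsqueeze_shape: insert a 1 at each axis, in order.
    new_shape = list(tensor_dims)
    for ax in axes:
        new_shape.insert(ax, 1)
    return new_shape

# A constant of shape [256] consumed through Unsqueeze(axes=[0, 2, 3])
# simply becomes a [1, 256, 1, 1] constant; no runtime op is needed.
print(fold_unsqueeze_into_const([256], [0, 2, 3]))  # [1, 256, 1, 1]
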
def convert_sum_group(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.SumGroup.name
......