Commit 3890a7c5 authored by liutuo

add reduce sum

Parent 3eb10f41
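
Note: this commit threads a new SUM reduce type through the CPU reduce op, the OpenCL kernel, and the TensorFlow converter. For orientation, here is a minimal standalone sketch, not MACE code (the function and variable names below are made up for illustration), of the pattern the CPU SUM branches follow: view the input as a reshaped [outer, inner] row-major buffer and accumulate along the reduced axis.

// Standalone illustration, not MACE code: sum over axis 0 of an
// [outer, inner] row-major view, i.e. output[j] = sum over i of input[i * inner + j].
#include <vector>

std::vector<float> ReduceAxis0Sum(const std::vector<float> &input,
                                  int outer, int inner) {
  std::vector<float> output(inner, 0.0f);
  for (int i = 0; i < outer; ++i) {
    for (int j = 0; j < inner; ++j) {
      output[j] += input[i * inner + j];
    }
  }
  return output;
}

int main() {
  // A 2x3 buffer {1,2,3; 4,5,6}; summing over axis 0 gives {5, 7, 9}.
  std::vector<float> in = {1, 2, 3, 4, 5, 6};
  std::vector<float> out = ReduceAxis0Sum(in, 2, 3);
  return (out[0] == 5.0f && out[1] == 7.0f && out[2] == 9.0f) ? 0 : 1;
}
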
......@@ -18,11 +18,11 @@
namespace mace {
enum ReduceType {
// SUM = 0,
MEAN = 0,
MIN = 1,
MAX = 2,
PROD = 3,
SUM = 4,
// SUM_SQR = 4,
// SQR_MEAN = 5,
};
......
......@@ -62,7 +62,7 @@ __kernel void reduce(OUT_OF_RANGE_PARAMS
// PROD
#elif REDUCE_TYPE == 3
part_result = part_result * in;
// MEAN
// MEAN or SUM
#else
part_result = part_result + in;
#endif
......
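
The kernel comment changes from "// MEAN" to "// MEAN or SUM" because both reduce types take the same additive accumulation path; presumably only the finalization differs (MEAN divides the accumulated value by the element count), and that step is outside this hunk. A hedged plain-C++ sketch of that split, with illustrative names that are not part of the MACE kernel:

// Illustrative only, not the MACE OpenCL kernel: MEAN and SUM share the
// accumulation loop and differ only in the final step.
float ReduceAddThenFinalize(const float *in, int count, bool is_mean) {
  float acc = 0.0f;
  for (int i = 0; i < count; ++i) {
    acc += in[i];                        // shared accumulation for MEAN and SUM
  }
  return is_mean ? acc / count : acc;    // only MEAN divides by the count
}
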
......@@ -167,6 +167,12 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
tmp = tmp * input[i];
}
output[0] = tmp;
} else if (type == ReduceType::SUM) {
T tmp = 0;
for (int i = 0; i < data_reshape_[0]; ++i) {
tmp = tmp + input[i];
}
output[0] = tmp;
} else {
MACE_NOT_IMPLEMENTED;
}
......@@ -216,6 +222,14 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
}
output[i] = tmp;
}
} else if (type == ReduceType::SUM) {
for (index_t i = start; i < end; i += step) {
T tmp = 0;
for (int j = 0; j < data_reshape_[0]; ++j) {
tmp += input[j * data_reshape_[1] + i];
}
output[i] = tmp;
}
} else {
MACE_NOT_IMPLEMENTED;
}
......@@ -254,6 +268,14 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
}
output[i] = tmp;
}
} else if (type == ReduceType::SUM) {
for (index_t i = start; i < end; i += step) {
T tmp = 0;
for (int j = 0; j < data_reshape_[1]; ++j) {
tmp += input[i * data_reshape_[1] + j];
}
output[i] = tmp;
}
} else {
MACE_NOT_IMPLEMENTED;
}
......@@ -319,6 +341,16 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
}
output[i] = tmp;
}
} else if (type == ReduceType::SUM) {
for (index_t i = start; i < end; i += step) {
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[0]; ++k) {
output[i] +=
input[(k * data_reshape_[1] + i) * data_reshape_[2]
+ j];
}
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
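
In this 3-D case the reshaped dimensions are [d0, d1, d2] with axes 0 and 2 reduced and axis 1 kept, so element (k, i, j) of the row-major buffer sits at offset (k * d1 + i) * d2 + j. A self-contained sketch of the same indexing, with illustrative names rather than MACE code:

// Illustrative only, not MACE code: reduce axes 0 and 2 of a row-major
// [d0, d1, d2] buffer by summation, keeping axis 1.
#include <vector>

std::vector<float> SumAxes0And2(const std::vector<float> &input,
                                int d0, int d1, int d2) {
  std::vector<float> output(d1, 0.0f);
  for (int i = 0; i < d1; ++i) {        // kept axis
    for (int k = 0; k < d0; ++k) {      // reduced axis 0
      for (int j = 0; j < d2; ++j) {    // reduced axis 2
        output[i] += input[(k * d1 + i) * d2 + j];
      }
    }
  }
  return output;
}
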
......@@ -371,6 +403,16 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
output[i * data_reshape_[2] + j] = tmp;
}
}
} else if (type == ReduceType::SUM) {
for (index_t i = start; i < end; i += step) {
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[1]; ++k) {
output[i * data_reshape_[2] + j] +=
input[(i * data_reshape_[1] + k) * data_reshape_[2]
+ j];
}
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
......@@ -445,6 +487,18 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
output[i * data_reshape_[3] + j] = tmp;
}
}
} else if (type == ReduceType::SUM) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t j = start1; j < end1; j += step1) {
for (int k = 0; k < data_reshape_[2]; ++k) {
for (int t = 0; t < data_reshape_[0]; ++t) {
output[i * data_reshape_[3] + j] +=
input[((t * data_reshape_[1] + i) *
data_reshape_[2] + k) * data_reshape_[3] + j];
}
}
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
......@@ -513,6 +567,18 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
output[i * data_reshape_[2] + j] = tmp;
}
}
} else if (type == ReduceType::SUM) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t j = start1; j < end1; j += step1) {
for (int k = 0; k < data_reshape_[1]; ++k) {
for (int t = 0; t < data_reshape_[3]; ++t) {
output[i * data_reshape_[2] + j] +=
input[((i * data_reshape_[1] + k) *
data_reshape_[2] + j) * data_reshape_[3] + t];
}
}
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
......@@ -574,6 +640,12 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce1Dims(
tmp = std::max<uint8_t>(tmp, input[i]);
}
output[0] = tmp;
} else if (type == ReduceType::SUM) {
uint32_t tmp = 0;
for (int i = 0; i < data_reshape_[0]; ++i) {
tmp = tmp + input[i];
}
output[0] = static_cast<uint8_t>(tmp + data_reshape_[0] / 2);
} else {
MACE_NOT_IMPLEMENTED;
}
......@@ -616,6 +688,14 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce2Dims(
}
output[i] = tmp;
}
} else if (type == ReduceType::SUM) {
for (index_t i = start; i < end; i += step) {
uint32_t tmp = 0;
for (int j = 0; j < data_reshape_[0]; ++j) {
tmp += input[j * data_reshape_[1] + i];
}
output[i] = static_cast<uint8_t>(tmp + data_reshape_[0] / 2);
}
} else {
MACE_NOT_IMPLEMENTED;
}
......@@ -647,6 +727,14 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce2Dims(
}
output[i] = tmp;
}
} else if (type == ReduceType::SUM) {
for (index_t i = start; i < end; i += step) {
uint32_t tmp = 0;
for (int j = 0; j < data_reshape_[1]; ++j) {
tmp += input[i * data_reshape_[1] + j];
}
output[i] = static_cast<uint8_t>(tmp + data_reshape_[1] / 2);
}
} else {
MACE_NOT_IMPLEMENTED;
}
......@@ -699,6 +787,17 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce3Dims(
}
output[i] = tmp;
}
} else if (type == ReduceType::SUM) {
for (index_t i = start; i < end; i += step) {
uint32_t tmp = 0;
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[0]; ++k) {
tmp += input[(k * data_reshape_[1] + i) * data_reshape_[2] + j];
}
}
index_t dim = data_reshape_[0] * data_reshape_[2];
output[i] = static_cast<uint8_t>(tmp + dim / 2);
}
} else {
MACE_NOT_IMPLEMENTED;
}
......@@ -742,6 +841,17 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce3Dims(
output[i * data_reshape_[2] + j] = tmp;
}
}
} else if (type == ReduceType::SUM) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t j = start1; j < end1; j += step1) {
uint32_t tmp = 0;
for (int k = 0; k < data_reshape_[1]; ++k) {
tmp += input[(i * data_reshape_[1] + k) * data_reshape_[2] + j];
}
output[i * data_reshape_[2] + j] =
static_cast<uint8_t>(tmp + data_reshape_[1] / 2);
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
......@@ -804,6 +914,21 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce4Dims(
output[i * data_reshape_[3] + j] = tmp;
}
}
} else if (type == ReduceType::SUM) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t j = start1; j < end1; j += step1) {
uint32_t tmp = 0;
for (int k = 0; k < data_reshape_[2]; ++k) {
for (int t = 0; t < data_reshape_[0]; ++t) {
tmp += input[((t * data_reshape_[1] + i) *
data_reshape_[2] + k) * data_reshape_[3] + j];
}
}
index_t dim = data_reshape_[0] * data_reshape_[2];
output[i * data_reshape_[3] + j] =
static_cast<uint8_t>(tmp + dim / 2);
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
......@@ -858,6 +983,21 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce4Dims(
output[i * data_reshape_[2] + j] = tmp;
}
}
} else if (type == ReduceType::SUM) {
for (index_t i = start0; i < end0; i += step0) {
for (index_t j = start1; j < end1; j += step1) {
uint32_t tmp = 0;
for (int k = 0; k < data_reshape_[1]; ++k) {
for (int t = 0; t < data_reshape_[3]; ++t) {
tmp += input[((i * data_reshape_[1] + k) *
data_reshape_[2] + j) * data_reshape_[3] + t];
}
}
index_t dim = data_reshape_[1] * data_reshape_[3];
output[i * data_reshape_[2] + j] =
static_cast<uint8_t>(tmp + dim / 2);
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
......
......@@ -88,6 +88,7 @@ class ReduceType(Enum):
MIN = 1
MAX = 2
PROD = 3
SUM = 4
class PadType(Enum):
......
......@@ -70,6 +70,7 @@ TFSupportedOps = [
'Square',
'SquaredDifference',
'Rsqrt',
'Sum',
'Equal',
'Relu',
'LeakyRelu',
......@@ -188,6 +189,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Max.name: ReduceType.MAX,
TFOpType.Mean.name: ReduceType.MEAN,
TFOpType.Prod.name: ReduceType.PROD,
TFOpType.Sum.name: ReduceType.SUM,
}
pad_type = {
......@@ -268,6 +270,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.MirrorPad.name: self.convert_pad,
TFOpType.Cumsum.name: self.convert_cumsum,
TFOpType.OneHot.name: self.convert_one_hot,
TFOpType.Sum.name: self.convert_reduce,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
......@@ -909,7 +912,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
reduce_dims = tf_op.get_attr('reduction_indices')
except ValueError:
reduce_dims = []
axis_arg.ints.extend(reduce_dims)
if isinstance(reduce_dims, list):
axis_arg.ints.extend(reduce_dims)
else:
axis_arg.ints.append(reduce_dims)
keep_dims_arg = op.arg.add()
keep_dims_arg.name = MaceKeyword.mace_keepdims_str
try:
......
......@@ -1205,12 +1205,13 @@ class Transformer(base_converter.ConverterInterface):
if op.output[0] in self._consumers:
consumer = self._consumers[op.output[0]][0]
# if there is a shape op, remove it too
if (consumer.input[1] in self._producer
and self._producer[consumer.input[1]].type
== 'Shape'):
self.safe_remove_node(
self._producer[consumer.input[1]], None,
remove_input_tensor=True)
if len(consumer.input) > 1:
if (consumer.input[1] in self._producer
and self._producer[consumer.input[1]].type
== 'Shape'):
self.safe_remove_node(
self._producer[consumer.input[1]], None,
remove_input_tensor=True)
# remove consumer reshape
self.safe_remove_node(consumer, op,
remove_input_tensor=True)
......