Commit 3890a7c5 authored by liutuo

add reduce sum

Parent 3eb10f41
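In summary, this change adds SUM as a new ReduceType across the CPU and OpenCL Reduce kernels and maps TensorFlow's Sum op onto the existing Reduce converter path. A minimal NumPy sketch of the behaviour the new branches implement (axis normalization and keepdims handling live elsewhere in the op and are only assumed here):

```python
import numpy as np

def reduce_sum_reference(x, axes, keepdims=False):
    """Reference for ReduceType::SUM: plain summation over the given axes,
    with no mean-style division at the end."""
    return np.sum(x, axis=tuple(axes), keepdims=keepdims)

# A (2, 3, 4) tensor reduced over axes 0 and 2 yields shape (3,).
print(reduce_sum_reference(np.arange(24).reshape(2, 3, 4), axes=[0, 2]).shape)
```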
...@@ -18,11 +18,11 @@
 namespace mace {
 enum ReduceType {
-  // SUM = 0,
   MEAN = 0,
   MIN = 1,
   MAX = 2,
   PROD = 3,
+  SUM = 4,
   // SUM_SQR = 4,
   // SQR_MEAN = 5,
 };
......
...@@ -62,7 +62,7 @@ __kernel void reduce(OUT_OF_RANGE_PARAMS
 // PROD
 #elif REDUCE_TYPE == 3
         part_result = part_result * in;
-// MEAN
+// MEAN or SUM
 #else
         part_result = part_result + in;
 #endif
......
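In the OpenCL kernel, SUM falls into the same `part_result + in` branch that MEAN already used, so only the final scaling step (which MEAN presumably applies after accumulation) distinguishes the two. A small Python sketch of that shared accumulate-then-finalize split, not the kernel itself:

```python
def accumulate_then_finalize(values, reduce_type):
    # MEAN and SUM share the additive pass (the `part_result + in` step);
    # only MEAN divides by the reduced element count at the end.
    acc = 0.0
    for v in values:
        acc = acc + v
    if reduce_type == "MEAN":
        acc /= len(values)
    return acc

assert accumulate_then_finalize([1.0, 2.0, 3.0], "SUM") == 6.0
assert accumulate_then_finalize([1.0, 2.0, 3.0], "MEAN") == 2.0
```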
...@@ -167,6 +167,12 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
           tmp = tmp * input[i];
         }
         output[0] = tmp;
+      } else if (type == ReduceType::SUM) {
+        T tmp = 0;
+        for (int i = 0; i < data_reshape_[0]; ++i) {
+          tmp = tmp + input[i];
+        }
+        output[0] = tmp;
       } else {
         MACE_NOT_IMPLEMENTED;
       }
...@@ -216,6 +222,14 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
             }
             output[i] = tmp;
           }
+        } else if (type == ReduceType::SUM) {
+          for (index_t i = start; i < end; i += step) {
+            T tmp = 0;
+            for (int j = 0; j < data_reshape_[0]; ++j) {
+              tmp += input[j * data_reshape_[1] + i];
+            }
+            output[i] = tmp;
+          }
         } else {
           MACE_NOT_IMPLEMENTED;
         }
...@@ -254,6 +268,14 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
             }
             output[i] = tmp;
           }
+        } else if (type == ReduceType::SUM) {
+          for (index_t i = start; i < end; i += step) {
+            T tmp = 0;
+            for (int j = 0; j < data_reshape_[1]; ++j) {
+              tmp += input[i * data_reshape_[1] + j];
+            }
+            output[i] = tmp;
+          }
         } else {
           MACE_NOT_IMPLEMENTED;
         }
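The two 2-D SUM branches above are plain column and row sums over the collapsed (data_reshape_[0], data_reshape_[1]) view. A NumPy cross-check of the flat-index arithmetic, with d0 and d1 standing in for data_reshape_:

```python
import numpy as np

d0, d1 = 3, 5
x = np.random.rand(d0, d1).astype(np.float32)
flat = x.reshape(-1)

# reduce_first_axis_: output[i] = sum_j flat[j * d1 + i]  ->  column sums
col = np.array([sum(flat[j * d1 + i] for j in range(d0)) for i in range(d1)])
assert np.allclose(col, x.sum(axis=0))

# otherwise: output[i] = sum_j flat[i * d1 + j]  ->  row sums
row = np.array([sum(flat[i * d1 + j] for j in range(d1)) for i in range(d0)])
assert np.allclose(row, x.sum(axis=1))
```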
...@@ -319,6 +341,16 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
            }
            output[i] = tmp;
          }
+        } else if (type == ReduceType::SUM) {
+          for (index_t i = start; i < end; i += step) {
+            for (int j = 0; j < data_reshape_[2]; ++j) {
+              for (int k = 0; k < data_reshape_[0]; ++k) {
+                output[i] +=
+                    input[(k * data_reshape_[1] + i) * data_reshape_[2]
+                        + j];
+              }
+            }
+          }
         } else {
           MACE_NOT_IMPLEMENTED;
         }
...@@ -371,6 +403,16 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
              output[i * data_reshape_[2] + j] = tmp;
            }
          }
+        } else if (type == ReduceType::SUM) {
+          for (index_t i = start; i < end; i += step) {
+            for (int j = 0; j < data_reshape_[2]; ++j) {
+              for (int k = 0; k < data_reshape_[1]; ++k) {
+                output[i * data_reshape_[2] + j] +=
+                    input[(i * data_reshape_[1] + k) * data_reshape_[2]
+                        + j];
+              }
+            }
+          }
         } else {
           MACE_NOT_IMPLEMENTED;
         }
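Similarly, the two 3-D branches reduce the collapsed (d0, d1, d2) view over the outer pair of axes and over the middle axis respectively; the index expressions match these NumPy reductions (dims are placeholders for data_reshape_):

```python
import numpy as np

d0, d1, d2 = 2, 3, 4
x = np.random.rand(d0, d1, d2).astype(np.float32)
flat = x.reshape(-1)

# reduce_first_axis_: output[i] += flat[(k * d1 + i) * d2 + j]  ->  sum over axes 0 and 2
out_a = np.zeros(d1, dtype=np.float32)
for i in range(d1):
    for j in range(d2):
        for k in range(d0):
            out_a[i] += flat[(k * d1 + i) * d2 + j]
assert np.allclose(out_a, x.sum(axis=(0, 2)))

# otherwise: output[i * d2 + j] += flat[(i * d1 + k) * d2 + j]  ->  sum over the middle axis
out_b = np.zeros(d0 * d2, dtype=np.float32)
for i in range(d0):
    for j in range(d2):
        for k in range(d1):
            out_b[i * d2 + j] += flat[(i * d1 + k) * d2 + j]
assert np.allclose(out_b.reshape(d0, d2), x.sum(axis=1))
```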
...@@ -445,6 +487,18 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
              output[i * data_reshape_[3] + j] = tmp;
            }
          }
+        } else if (type == ReduceType::SUM) {
+          for (index_t i = start0; i < end0; i += step0) {
+            for (index_t j = start1; j < end1; j += step1) {
+              for (int k = 0; k < data_reshape_[2]; ++k) {
+                for (int t = 0; t < data_reshape_[0]; ++t) {
+                  output[i * data_reshape_[3] + j] +=
+                      input[((t * data_reshape_[1] + i) *
+                          data_reshape_[2] + k) * data_reshape_[3] + j];
+                }
+              }
+            }
+          }
         } else {
           MACE_NOT_IMPLEMENTED;
         }
...@@ -513,6 +567,18 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
              output[i * data_reshape_[2] + j] = tmp;
            }
          }
+        } else if (type == ReduceType::SUM) {
+          for (index_t i = start0; i < end0; i += step0) {
+            for (index_t j = start1; j < end1; j += step1) {
+              for (int k = 0; k < data_reshape_[1]; ++k) {
+                for (int t = 0; t < data_reshape_[3]; ++t) {
+                  output[i * data_reshape_[2] + j] +=
+                      input[((i * data_reshape_[1] + k) *
+                          data_reshape_[2] + j) * data_reshape_[3] + t];
+                }
+              }
+            }
+          }
         } else {
           MACE_NOT_IMPLEMENTED;
         }
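The 4-D variants extend the same pattern: the first branch sums the (d0, d1, d2, d3) view over axes 0 and 2 into a (d1, d3) output, and the second over axes 1 and 3 into a (d0, d2) output. A compact NumPy equivalence with placeholder dims:

```python
import numpy as np

d0, d1, d2, d3 = 2, 3, 4, 5
x = np.random.rand(d0, d1, d2, d3).astype(np.float32)

# First branch:  output[i * d3 + j] += x[t, i, k, j]  (t over d0, k over d2)
# Second branch: output[i * d2 + j] += x[i, k, j, t]  (k over d1, t over d3)
assert x.sum(axis=(0, 2)).shape == (d1, d3)
assert x.sum(axis=(1, 3)).shape == (d0, d2)
```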
...@@ -574,6 +640,12 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce1Dims(
         tmp = std::max<uint8_t>(tmp, input[i]);
       }
       output[0] = tmp;
+    } else if (type == ReduceType::SUM) {
+      uint32_t tmp = 0;
+      for (int i = 0; i < data_reshape_[0]; ++i) {
+        tmp = tmp + input[i];
+      }
+      output[0] = static_cast<uint8_t>(tmp + data_reshape_[0] / 2);
     } else {
       MACE_NOT_IMPLEMENTED;
     }
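The quantized specializations (this 1-D case and the ones that follow) reuse the same index arithmetic but accumulate in uint32_t and add half the reduced element count before the narrowing cast back to uint8_t; any rescaling between input and output quantization parameters is not shown in these hunks. A sketch of just that arithmetic, with `& 0xFF` emulating the narrowing static_cast:

```python
def quantized_sum_1d(vals):
    # Accumulate in a wide integer, add len(vals) // 2 as in the diff,
    # then keep the low 8 bits, as the static_cast<uint8_t> does.
    tmp = 0
    for v in vals:
        tmp += int(v)
    return (tmp + len(vals) // 2) & 0xFF

print(quantized_sum_1d([10, 20, 30]))  # 61 == 60 + 3 // 2
```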
...@@ -616,6 +688,14 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce2Dims(
           }
           output[i] = tmp;
         }
+      } else if (type == ReduceType::SUM) {
+        for (index_t i = start; i < end; i += step) {
+          uint32_t tmp = 0;
+          for (int j = 0; j < data_reshape_[0]; ++j) {
+            tmp += input[j * data_reshape_[1] + i];
+          }
+          output[i] = static_cast<uint8_t>(tmp + data_reshape_[0] / 2);
+        }
       } else {
         MACE_NOT_IMPLEMENTED;
       }
...@@ -647,6 +727,14 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce2Dims(
           }
           output[i] = tmp;
         }
+      } else if (type == ReduceType::SUM) {
+        for (index_t i = start; i < end; i += step) {
+          uint32_t tmp = 0;
+          for (int j = 0; j < data_reshape_[1]; ++j) {
+            tmp += input[i * data_reshape_[1] + j];
+          }
+          output[i] = static_cast<uint8_t>(tmp + data_reshape_[1] / 2);
+        }
       } else {
         MACE_NOT_IMPLEMENTED;
       }
...@@ -699,6 +787,17 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce3Dims(
           }
           output[i] = tmp;
         }
+      } else if (type == ReduceType::SUM) {
+        for (index_t i = start; i < end; i += step) {
+          uint32_t tmp = 0;
+          for (int j = 0; j < data_reshape_[2]; ++j) {
+            for (int k = 0; k < data_reshape_[0]; ++k) {
+              tmp += input[(k * data_reshape_[1] + i) * data_reshape_[2] + j];
+            }
+          }
+          index_t dim = data_reshape_[0] * data_reshape_[2];
+          output[i] = static_cast<uint8_t>(tmp + dim / 2);
+        }
       } else {
         MACE_NOT_IMPLEMENTED;
       }
...@@ -742,6 +841,17 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce3Dims(
            output[i * data_reshape_[2] + j] = tmp;
          }
        }
+      } else if (type == ReduceType::SUM) {
+        for (index_t i = start0; i < end0; i += step0) {
+          for (index_t j = start1; j < end1; j += step1) {
+            uint32_t tmp = 0;
+            for (int k = 0; k < data_reshape_[1]; ++k) {
+              tmp += input[(i * data_reshape_[1] + k) * data_reshape_[2] + j];
+            }
+            output[i * data_reshape_[2] + j] =
+                static_cast<uint8_t>(tmp + data_reshape_[1] / 2);
+          }
+        }
       } else {
         MACE_NOT_IMPLEMENTED;
       }
...@@ -804,6 +914,21 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce4Dims(
            output[i * data_reshape_[3] + j] = tmp;
          }
        }
+      } else if (type == ReduceType::SUM) {
+        for (index_t i = start0; i < end0; i += step0) {
+          for (index_t j = start1; j < end1; j += step1) {
+            uint32_t tmp = 0;
+            for (int k = 0; k < data_reshape_[2]; ++k) {
+              for (int t = 0; t < data_reshape_[0]; ++t) {
+                tmp += input[((t * data_reshape_[1] + i) *
+                    data_reshape_[2] + k) * data_reshape_[3] + j];
+              }
+            }
+            index_t dim = data_reshape_[0] * data_reshape_[2];
+            output[i * data_reshape_[3] + j] =
+                static_cast<uint8_t>(tmp + dim / 2);
+          }
+        }
       } else {
         MACE_NOT_IMPLEMENTED;
       }
...@@ -858,6 +983,21 @@ void ReduceOp<DeviceType::CPU, uint8_t>::Reduce4Dims(
            output[i * data_reshape_[2] + j] = tmp;
          }
        }
+      } else if (type == ReduceType::SUM) {
+        for (index_t i = start0; i < end0; i += step0) {
+          for (index_t j = start1; j < end1; j += step1) {
+            uint32_t tmp = 0;
+            for (int k = 0; k < data_reshape_[1]; ++k) {
+              for (int t = 0; t < data_reshape_[3]; ++t) {
+                tmp += input[((i * data_reshape_[1] + k) *
+                    data_reshape_[2] + j) * data_reshape_[3] + t];
+              }
+            }
+            index_t dim = data_reshape_[1] * data_reshape_[3];
+            output[i * data_reshape_[2] + j] =
+                static_cast<uint8_t>(tmp + dim / 2);
+          }
+        }
       } else {
         MACE_NOT_IMPLEMENTED;
       }
......
...@@ -88,6 +88,7 @@ class ReduceType(Enum):
     MIN = 1
     MAX = 2
     PROD = 3
+    SUM = 4


 class PadType(Enum):
......
...@@ -70,6 +70,7 @@ TFSupportedOps = [
     'Square',
     'SquaredDifference',
     'Rsqrt',
+    'Sum',
     'Equal',
     'Relu',
     'LeakyRelu',
...@@ -188,6 +189,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
             TFOpType.Max.name: ReduceType.MAX,
             TFOpType.Mean.name: ReduceType.MEAN,
             TFOpType.Prod.name: ReduceType.PROD,
+            TFOpType.Sum.name: ReduceType.SUM,
         }
         pad_type = {
...@@ -268,6 +270,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
             TFOpType.MirrorPad.name: self.convert_pad,
             TFOpType.Cumsum.name: self.convert_cumsum,
             TFOpType.OneHot.name: self.convert_one_hot,
+            TFOpType.Sum.name: self.convert_reduce,
         }
         self._option = option
         self._mace_net_def = mace_pb2.NetDef()
...@@ -909,7 +912,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
             reduce_dims = tf_op.get_attr('reduction_indices')
         except ValueError:
             reduce_dims = []
-        axis_arg.ints.extend(reduce_dims)
+        if isinstance(reduce_dims, list):
+            axis_arg.ints.extend(reduce_dims)
+        else:
+            axis_arg.ints.append(reduce_dims)
         keep_dims_arg = op.arg.add()
         keep_dims_arg.name = MaceKeyword.mace_keepdims_str
         try:
......
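On the converter side, `get_attr('reduction_indices')` can evidently return a single integer rather than a list, which is why the hunk above branches on isinstance before filling axis_arg.ints. A minimal sketch of that normalization, using a hypothetical helper name:

```python
def normalize_reduce_dims(reduce_dims):
    """Return reduction axes as a list whether the attribute came back as a
    scalar int or as a list (hypothetical helper, not part of MACE)."""
    if isinstance(reduce_dims, list):
        return list(reduce_dims)
    return [reduce_dims]

assert normalize_reduce_dims([0, 2]) == [0, 2]
assert normalize_reduce_dims(1) == [1]
```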
...@@ -1205,12 +1205,13 @@ class Transformer(base_converter.ConverterInterface):
                 if op.output[0] in self._consumers:
                     consumer = self._consumers[op.output[0]][0]
                     # if there is a shape op, remove it too
-                    if (consumer.input[1] in self._producer
-                            and self._producer[consumer.input[1]].type
-                            == 'Shape'):
-                        self.safe_remove_node(
-                            self._producer[consumer.input[1]], None,
-                            remove_input_tensor=True)
+                    if len(consumer.input) > 1:
+                        if (consumer.input[1] in self._producer
+                                and self._producer[consumer.input[1]].type
+                                == 'Shape'):
+                            self.safe_remove_node(
+                                self._producer[consumer.input[1]], None,
+                                remove_input_tensor=True)
                     # remove consumer reshape
                     self.safe_remove_node(consumer, op,
                                           remove_input_tensor=True)
......