Commit c5f92ebf authored by liutuo

add tf.sign op and reorder tf ops in tf_converter

Parent d7533c48
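
Not part of the diff: a minimal, self-contained C++ sketch of the branchless sign computation that the new Sign helper and the SIGN eltwise type rely on. It maps every input to -1, 0, or +1, matching the semantics of TensorFlow's tf.sign for real-valued inputs; the assert-based driver and its sample values below are illustrative only.

#include <cassert>

// Branchless sign: (0 < val) and (val < 0) evaluate to 0 or 1, so their
// difference is +1 for positive, -1 for negative, and 0 for zero input.
template <typename T>
int Sign(T val) {
  return (T(0) < val) - (val < T(0));
}

int main() {
  assert(Sign(3.5f) == 1);
  assert(Sign(0.0f) == 0);
  assert(Sign(-2.0f) == -1);
  assert(Sign(7) == 1);  // also works for integral types
  return 0;
}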
......@@ -32,7 +32,8 @@ enum EltwiseType {
EQUAL = 10,
FLOOR_DIV = 11,
CLIP = 12,
NONE = 13,
SIGN = 13,
NONE = 14,
};
} // namespace ops
......
......@@ -385,6 +385,14 @@ inline void TensorBroadcastEltwise(const OpContext *context,
}
}
break;
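// SIGN is unary: only input0 is read; input1 and coeff are ignored.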
case SIGN:
for (index_t d = start0; d < end0; d += step0) {
for (index_t i = start1; i < end1; i += step1) {
output[i + d * common_size] =
Sign(input0[i + d * common_size]);
}
}
break;
default: LOG(FATAL) << "Eltwise op does not support type " << type;
}
}, 0, diff_size, 1, 0, common_size, 1);
......@@ -410,7 +418,6 @@ inline void TensorEltwise(const OpContext *context,
for (index_t i = start; i < end; i += step) {
output[i] = input0[i] + input1[i];
}
} else {
std::vector<float> coeff_copy = coeff;
if (swapped) {
......@@ -426,7 +433,6 @@ inline void TensorEltwise(const OpContext *context,
for (index_t i = start; i < end; i += step) {
output[i] = input0[i] - input1[i];
}
} else {
for (index_t i = start; i < end; i += step) {
output[i] = input1[i] - input0[i];
......@@ -437,7 +443,6 @@ inline void TensorEltwise(const OpContext *context,
for (index_t i = start; i < end; i += step) {
output[i] = input0[i] * input1[i];
}
break;
case DIV:
if (!swapped) {
......@@ -466,19 +471,16 @@ inline void TensorEltwise(const OpContext *context,
for (index_t i = start; i < end; i += step) {
output[i] = std::min(input0[i], input1[i]);
}
break;
case MAX:
for (index_t i = start; i < end; i += step) {
output[i] = std::max(input0[i], input1[i]);
}
break;
case SQR_DIFF:
for (index_t i = start; i < end; i += step) {
output[i] = std::pow(input0[i] - input1[i], 2.f);
}
break;
case POW:
if (!swapped) {
......@@ -511,6 +513,11 @@ inline void TensorEltwise(const OpContext *context,
output[i] = std::fmaxf(coeff[0], std::fminf(coeff[1], input0[i]));
}
break;
case SIGN:
for (index_t i = start; i < end; i += step) {
output[i] = Sign(input0[i]);
}
break;
default: LOG(FATAL) << "Eltwise op does not support type " << type;
}
}, 0, size, 1);
......@@ -563,7 +570,6 @@ inline void TensorScalarEltwise(const OpContext *context,
for (index_t i = start; i < end; i += step) {
output[i] = input0[i] * input1;
}
break;
case DIV:
if (!swapped) {
......@@ -637,6 +643,11 @@ inline void TensorScalarEltwise(const OpContext *context,
output[i] = std::fmaxf(coeff[0], std::fminf(coeff[1], input0[i]));
}
break;
case SIGN:
for (index_t i = start; i < end; i += step) {
output[i] = Sign(input0[i]);
}
break;
default: LOG(FATAL) << "Eltwise op does not support type " << type;
}
}, 0, size, 1);
......@@ -869,6 +880,15 @@ inline void TensorEltwisePerChannel(const OpContext *context,
}
}
break;
case SIGN:
for (index_t b = start0; b < end0; b += step0) {
for (index_t c = start1; c < end1; c += step1) {
// Index into the current (batch, channel) image plane, as the other cases do.
const index_t offset = (b * channel + c) * image_size;
for (index_t i = 0; i < image_size; ++i) {
output[offset + i] = Sign(input0[offset + i]);
}
}
}
break;
default: LOG(FATAL) << "Eltwise op does not support type " << type;
}
}, 0, batch0, 1, 0, channel, 1);
......
......@@ -22,6 +22,10 @@ namespace ops {
inline bool IsLogicalType(EltwiseType type) { return type == EQUAL; }

template <typename T>
int Sign(T val) {
return (T(0) < val) - (val < T(0));
}
} // namespace ops
} // namespace mace
......
......@@ -89,13 +89,16 @@ __kernel void eltwise(OUT_OF_RANGE_PARAMS
#endif
#elif ELTWISE_TYPE == 12
out = fmax(coeff0, fmin(coeff1, in0));
#elif ELTWISE_TYPE == 13
out = sign(in0);
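// ELTWISE_TYPE 13 (SIGN): OpenCL's built-in sign() yields -1.0f, 0.0f or 1.0f per element.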
#endif
#if defined(NOT_DIVISIBLE_FOUR) && \
((ELTWISE_TYPE == 3 || ELTWISE_TYPE == 9 || ELTWISE_TYPE == 11) \
|| ((defined(INPUT_SCALAR) || defined(INPUT_TENSOR_BC_CHAN)) && \
(ELTWISE_TYPE == 0 || ELTWISE_TYPE == 1 || ELTWISE_TYPE == 4 || \
ELTWISE_TYPE == 5 || ELTWISE_TYPE == 8 || ELTWISE_TYPE == 12)))
ELTWISE_TYPE == 5 || ELTWISE_TYPE == 8 || ELTWISE_TYPE == 12 || \
ELTWISE_TYPE == 13)))
const int remain_channel = channel - 4 * chan_idx;
if (remain_channel < 4) {
switch (remain_channel) {
......
......@@ -243,6 +243,8 @@ TEST_F(EltwiseOpTest, CPUSimpleScalarScalar) {
ops::EltwiseType::NEG, 1, 2, -1);
SimpleScalarScalar<DeviceType::CPU, float, float>(
ops::EltwiseType::ABS, -1, 3, 1);
SimpleScalarScalar<DeviceType::CPU, float, float>(
ops::EltwiseType::SIGN, -2, 3, -1);
SimpleScalarScalar<DeviceType::CPU, int32_t, int32_t>(
ops::EltwiseType::EQUAL, 1, 3, 0);
SimpleScalarScalar<DeviceType::CPU, int32_t, int32_t>(
......@@ -285,6 +287,9 @@ TEST_F(EltwiseOpTest, CPUSimpleTensorScalar) {
SimpleTensorScalar<DeviceType::CPU, int32_t, int32_t>(
ops::EltwiseType::EQUAL, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 3,
{0, 0, 1, 0, 0, 0});
SimpleTensorScalar<DeviceType::CPU, float, float>(
ops::EltwiseType::SIGN, {1, 1, 2, 3}, {1, 2, -3, 0, -5, -6}, 3,
{1, 1, -1, 0, -1, -1});
}
TEST_F(EltwiseOpTest, GPUSimpleTensorScalar) {
......@@ -320,6 +325,9 @@ TEST_F(EltwiseOpTest, GPUSimpleTensorScalar) {
SimpleTensorScalar<DeviceType::GPU, float, float>(
ops::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 1,
{0, 1, 4, 9, 16, 25});
SimpleTensorScalar<DeviceType::GPU, float, float>(
ops::EltwiseType::SIGN, {1, 1, 2, 3}, {-1, 2, 3, 0, -5, -6}, 3,
{-1, 1, 1, 0, -1, -1});
}
TEST_F(EltwiseOpTest, CPUSimpleTensorVector) {
......@@ -845,6 +853,7 @@ TEST_F(EltwiseOpTest, RandomTensorScalarFloat) {
RandomTensorScalar<float>(ops::EltwiseType::NEG, {1, 32, 32, 32});
RandomTensorScalar<float>(ops::EltwiseType::ABS, {3, 31, 37, 17});
RandomTensorScalar<float>(ops::EltwiseType::SQR_DIFF, {3, 31, 37, 17});
RandomTensorScalar<float>(ops::EltwiseType::SIGN, {3, 31, 37, 17});
}
TEST_F(EltwiseOpTest, RandomTensorScalarHalf) {
......@@ -857,6 +866,7 @@ TEST_F(EltwiseOpTest, RandomTensorScalarHalf) {
RandomTensorScalar<half>(ops::EltwiseType::NEG, {1, 32, 32, 32});
RandomTensorScalar<half>(ops::EltwiseType::ABS, {3, 31, 37, 17});
RandomTensorScalar<half>(ops::EltwiseType::SQR_DIFF, {3, 31, 37, 17});
RandomTensorScalar<half>(ops::EltwiseType::SIGN, {3, 31, 37, 17});
}
TEST_F(EltwiseOpTest, RandomTensorVecFloat) {
......
......@@ -65,6 +65,7 @@ class EltwiseType(Enum):
EQUAL = 10
FLOOR_DIV = 11
CLIP = 12
SIGN = 13
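# Keep the numbering in sync with the C++ EltwiseType enum and the OpenCL ELTWISE_TYPE macro (SIGN = 13 in both).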
class ReduceType(Enum):
......
......@@ -50,80 +50,82 @@ tf_block_size = 'block_size'
tf_squeeze_dims = 'squeeze_dims'
tf_axis = 'axis'
# Keep in lexicographical order
TFSupportedOps = [
'Abs',
'Add',
'ArgMax',
'AvgPool',
'BatchMatMul',
'BatchToSpaceND',
'BiasAdd',
'Cast',
'ConcatV2',
'Const',
'Conv2D',
'DepthwiseConv2dNative',
'Conv2DBackpropInput',
'BiasAdd',
'Add',
'Sub',
'Mul',
'Cumsum',
'DepthwiseConv2dNative',
'DepthToSpace',
'Div',
'Min',
'Minimum',
'Equal',
'ExpandDims',
'FakeQuantWithMinMaxVars',
'FakeQuantWithMinMaxArgs',
'Fill',
'FloorDiv',
'FusedBatchNorm',
'Gather',
'GatherV2',
'Identity',
'LeakyRelu',
'MatMul',
'Max',
'Maximum',
'MaxPool',
'Mean',
'Min',
'Minimum',
'MirrorPad',
'Mul',
'Neg',
'Abs',
'OneHot',
'Pack',
'Pad',
'PadV2',
'Placeholder',
'Pow',
'Prod',
'RealDiv',
'Square',
'SquaredDifference',
'Rsqrt',
'Sum',
'Equal',
'Relu',
'LeakyRelu',
'Relu6',
'Tanh',
'Sigmoid',
'Fill',
'FusedBatchNorm',
'AvgPool',
'MaxPool',
'ExpandDims',
'Squeeze',
'MatMul',
'BatchMatMul',
'Identity',
'Reshape',
'Shape',
'Transpose',
'Softmax',
'ResizeBicubic',
'ResizeBilinear',
'ResizeNearestNeighbor',
'Placeholder',
'ReverseV2',
'Rsqrt',
'Shape',
'Sigmoid',
'Sign',
'Slice',
'Softmax',
'SpaceToBatchND',
'BatchToSpaceND',
'DepthToSpace',
'SpaceToDepth',
'Pad',
'PadV2',
'ConcatV2',
'Mean',
'Prod',
'Const',
'Gather',
'GatherV2',
'StridedSlice',
'Slice',
'ReverseV2',
'Stack',
'Pack',
'Unstack',
'Unpack',
'Cast',
'ArgMax',
'Split',
'FakeQuantWithMinMaxVars',
'FakeQuantWithMinMaxArgs',
'FloorDiv',
'Sqrt',
'MirrorPad',
'Cumsum',
'OneHot',
'Square',
'SquaredDifference',
'Squeeze',
'Stack',
'StridedSlice',
'Sub',
'Sum',
'Tanh',
'Tile',
'Transpose',
'Unpack',
'Unstack',
]
TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
......@@ -176,6 +178,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Rsqrt.name: EltwiseType.POW,
TFOpType.Sqrt.name: EltwiseType.POW,
TFOpType.Equal.name: EltwiseType.EQUAL,
TFOpType.Sign.name: EltwiseType.SIGN
}
activation_type = {
......@@ -201,80 +204,83 @@ class TensorflowConverter(base_converter.ConverterInterface):
}
def __init__(self, option, src_model_file):
# Keep in lexicographical order
self._op_converters = {
TFOpType.Abs.name: self.convert_elementwise,
TFOpType.Add.name: self.convert_add,
TFOpType.ArgMax.name: self.convert_argmax,
TFOpType.AvgPool.name: self.convert_pooling,
TFOpType.BatchMatMul.name: self.convert_matmul,
TFOpType.BatchToSpaceND.name: self.convert_space_batch,
TFOpType.BiasAdd.name: self.convert_biasadd,
TFOpType.Cast.name: self.convert_cast,
TFOpType.ConcatV2.name: self.convert_concat,
TFOpType.Const.name: self.convert_nop,
TFOpType.Conv2D.name: self.convert_conv2d,
TFOpType.DepthwiseConv2dNative.name: self.convert_conv2d,
TFOpType.Conv2DBackpropInput.name: self.convert_conv2d,
TFOpType.BiasAdd.name: self.convert_biasadd,
TFOpType.Add.name: self.convert_add,
TFOpType.Sub.name: self.convert_elementwise,
TFOpType.Mul.name: self.convert_elementwise,
TFOpType.Cumsum.name: self.convert_cumsum,
TFOpType.DepthwiseConv2dNative.name: self.convert_conv2d,
TFOpType.DepthToSpace.name: self.convert_space_depth,
TFOpType.Div.name: self.convert_elementwise,
TFOpType.Minimum.name: self.convert_elementwise,
TFOpType.Equal.name: self.convert_elementwise,
TFOpType.ExpandDims.name: self.convert_expand_dims,
TFOpType.FakeQuantWithMinMaxVars.name: self.convert_fake_quantize,
TFOpType.FakeQuantWithMinMaxArgs.name: self.convert_fake_quantize,
TFOpType.Fill.name: self.convert_fill,
TFOpType.FloorDiv.name: self.convert_elementwise,
TFOpType.FusedBatchNorm.name: self.convert_fused_batchnorm,
TFOpType.Gather.name: self.convert_gather,
TFOpType.GatherV2.name: self.convert_gather,
TFOpType.Identity.name: self.convert_identity,
TFOpType.LeakyRelu.name: self.convert_activation,
TFOpType.MatMul.name: self.convert_matmul,
TFOpType.Max.name: self.convert_reduce,
TFOpType.Maximum.name: self.convert_elementwise,
TFOpType.MaxPool.name: self.convert_pooling,
TFOpType.Mean.name: self.convert_reduce,
TFOpType.Min.name: self.convert_reduce,
TFOpType.Minimum.name: self.convert_elementwise,
TFOpType.MirrorPad.name: self.convert_pad,
TFOpType.Mul.name: self.convert_elementwise,
TFOpType.Neg.name: self.convert_elementwise,
TFOpType.Abs.name: self.convert_elementwise,
TFOpType.OneHot.name: self.convert_one_hot,
TFOpType.Pack.name: self.convert_stack,
TFOpType.Pad.name: self.convert_pad,
TFOpType.PadV2.name: self.convert_pad,
TFOpType.Placeholder.name: self.convert_nop,
TFOpType.Pow.name: self.convert_elementwise,
TFOpType.Prod.name: self.convert_reduce,
TFOpType.Sub.name: self.convert_elementwise,
TFOpType.RealDiv.name: self.convert_elementwise,
TFOpType.SquaredDifference.name: self.convert_elementwise,
TFOpType.Square.name: self.convert_elementwise,
TFOpType.Rsqrt.name: self.convert_elementwise,
TFOpType.Equal.name: self.convert_elementwise,
TFOpType.Min.name: self.convert_reduce,
TFOpType.Max.name: self.convert_reduce,
TFOpType.Mean.name: self.convert_reduce,
TFOpType.Prod.name: self.convert_reduce,
TFOpType.Relu.name: self.convert_activation,
TFOpType.LeakyRelu.name: self.convert_activation,
TFOpType.Relu6.name: self.convert_activation,
TFOpType.Tanh.name: self.convert_activation,
TFOpType.Sigmoid.name: self.convert_activation,
TFOpType.Fill.name: self.convert_fill,
TFOpType.FusedBatchNorm.name: self.convert_fused_batchnorm,
TFOpType.AvgPool.name: self.convert_pooling,
TFOpType.MaxPool.name: self.convert_pooling,
TFOpType.MatMul.name: self.convert_matmul,
TFOpType.BatchMatMul.name: self.convert_matmul,
TFOpType.Identity.name: self.convert_identity,
TFOpType.Reshape.name: self.convert_reshape,
TFOpType.Shape.name: self.convert_shape,
TFOpType.ExpandDims.name: self.convert_expand_dims,
TFOpType.Squeeze.name: self.convert_squeeze,
TFOpType.Transpose.name: self.convert_transpose,
TFOpType.Softmax.name: self.convert_softmax,
TFOpType.ResizeBicubic.name: self.convert_resize_bicubic,
TFOpType.ResizeBilinear.name: self.convert_resize_bilinear,
TFOpType.ResizeNearestNeighbor.name: self.convert_resize_nearest_neighbor, # noqa
TFOpType.Placeholder.name: self.convert_nop,
TFOpType.ReverseV2.name: self.convert_reverse,
TFOpType.Shape.name: self.convert_shape,
TFOpType.Sigmoid.name: self.convert_activation,
TFOpType.Sign.name: self.convert_elementwise,
TFOpType.Slice.name: self.convert_slice,
TFOpType.Softmax.name: self.convert_softmax,
TFOpType.SpaceToBatchND.name: self.convert_space_batch,
TFOpType.BatchToSpaceND.name: self.convert_space_batch,
TFOpType.DepthToSpace.name: self.convert_space_depth,
TFOpType.SpaceToDepth.name: self.convert_space_depth,
TFOpType.Pad.name: self.convert_pad,
TFOpType.PadV2.name: self.convert_pad,
TFOpType.ConcatV2.name: self.convert_concat,
TFOpType.Const.name: self.convert_nop,
TFOpType.Gather.name: self.convert_gather,
TFOpType.GatherV2.name: self.convert_gather,
TFOpType.StridedSlice.name: self.convert_stridedslice,
TFOpType.Slice.name: self.convert_slice,
TFOpType.ReverseV2.name: self.convert_reverse,
TFOpType.Pack.name: self.convert_stack,
TFOpType.Stack.name: self.convert_stack,
TFOpType.Unpack.name: self.convert_unstack,
TFOpType.Unstack.name: self.convert_unstack,
TFOpType.Cast.name: self.convert_cast,
TFOpType.ArgMax.name: self.convert_argmax,
TFOpType.Split.name: self.convert_split,
TFOpType.FakeQuantWithMinMaxVars.name: self.convert_fake_quantize,
TFOpType.FakeQuantWithMinMaxArgs.name: self.convert_fake_quantize,
TFOpType.FloorDiv.name: self.convert_elementwise,
TFOpType.Sqrt.name: self.convert_elementwise,
TFOpType.MirrorPad.name: self.convert_pad,
TFOpType.Cumsum.name: self.convert_cumsum,
TFOpType.OneHot.name: self.convert_one_hot,
TFOpType.Squeeze.name: self.convert_squeeze,
TFOpType.Stack.name: self.convert_stack,
TFOpType.StridedSlice.name: self.convert_stridedslice,
TFOpType.Sum.name: self.convert_reduce,
TFOpType.Tile.name: self.convert_tile,
TFOpType.Transpose.name: self.convert_transpose,
TFOpType.Unpack.name: self.convert_unstack,
TFOpType.Unstack.name: self.convert_unstack,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
......