diff --git a/mace/kernels/arm/conv_2d_neon_1x1.cc b/mace/kernels/arm/conv_2d_neon_1x1.cc
index b4c4b828542d18eb91cd61151adce4e002f2ba3f..28aa6e4f824342fa4bf6e5d624718eca19428a45 100644
--- a/mace/kernels/arm/conv_2d_neon_1x1.cc
+++ b/mace/kernels/arm/conv_2d_neon_1x1.cc
@@ -12,10 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#if defined(MACE_ENABLE_NEON) && defined(__aarch64__)
-#include <arm_neon.h>
-#endif
-
 #include "mace/kernels/arm/conv_2d_neon.h"
 #include "mace/kernels/gemm.h"
 
diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h
index 7b88c1251e2de100d353fd9ed7b5fb3d7ce1039c..4e472b7fd4169a53de41b76150499db70c9a2c3f 100644
--- a/mace/kernels/conv_2d.h
+++ b/mace/kernels/conv_2d.h
@@ -32,6 +32,7 @@
 #include "mace/kernels/arm/conv_2d_neon.h"
 #include "mace/kernels/arm/conv_winograd.h"
 #include "mace/kernels/gemmlowp_util.h"
+#include "mace/kernels/quantize.h"
 #include "mace/utils/utils.h"
 
 #ifdef MACE_ENABLE_OPENCL
@@ -826,18 +827,11 @@
                        int32_t *quantized_multiplier,
                        int *right_shift) {
     float real_multiplier = lhs_scale * rhs_scale / output_scale;
     MACE_CHECK(real_multiplier > 0.f && real_multiplier < 1.f, real_multiplier);
+    int exponent;
-    const double significand = std::frexp(real_multiplier, &exponent);
+    QuantizeMultiplier(real_multiplier, quantized_multiplier, &exponent);
     *right_shift = -exponent;
-    int64_t q = static_cast<int64_t>(std::round(significand * (1ll << 31)));
-    MACE_CHECK(q <= (1ll << 31));
-    if (q == (1ll << 31)) {
-      q /= 2;
-      (*right_shift)--;
-    }
     MACE_CHECK(*right_shift >= 0);
-    MACE_CHECK(q <= std::numeric_limits<int32_t>::max());
-    *quantized_multiplier = static_cast<int32_t>(q);
   }
 
   typedef gemmlowp::VectorMap<const int32_t, gemmlowp::VectorShape::Col>
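
Note: the refactor above preserves the exact decomposition the removed lines computed inline. As a standalone illustration (not part of the patch; the multiplier value is made up), the frexp trick works like this: a requantization multiplier in (0, 1) is split into a Q0.31 significand and a non-negative right shift.

```cpp
#include <cmath>
#include <cstdint>
#include <iostream>

// Illustrative sketch only. For real_multiplier in (0, 1):
//   real_multiplier == significand * 2^exponent, significand in [0.5, 1),
// so q = round(significand * 2^31) fits in int32 (after the overflow guard)
// and right_shift = -exponent is non-negative.
int main() {
  const double real_multiplier = 0.00392157;  // e.g. lhs * rhs / output scale
  int exponent = 0;
  const double significand = std::frexp(real_multiplier, &exponent);
  int64_t q = static_cast<int64_t>(std::round(significand * (1ll << 31)));
  int right_shift = -exponent;
  if (q == (1ll << 31)) {  // rounding can hit 2^31, which overflows int32
    q /= 2;
    --right_shift;
  }
  // (x * q) >> (31 + right_shift) then approximates x * real_multiplier.
  std::cout << "q=" << q << " right_shift=" << right_shift << "\n";
  return 0;
}
```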
diff --git a/mace/kernels/fixpoint.h b/mace/kernels/fixpoint.h
new file mode 100644
index 0000000000000000000000000000000000000000..47f0a8d89e54f47ab86bfe5c2cb97b0b1ccfd8d6
--- /dev/null
+++ b/mace/kernels/fixpoint.h
@@ -0,0 +1,64 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_KERNELS_FIXPOINT_H_
+#define MACE_KERNELS_FIXPOINT_H_
+
+#if defined(MACE_ENABLE_NEON)
+#include <arm_neon.h>
+#endif
+
+#include <algorithm>
+
+#include "mace/core/types.h"
+
+namespace mace {
+namespace kernels {
+
+inline uint8_t FindMax(const uint8_t *xs, const index_t size) {
+  uint8_t max_value = 0;
+  index_t i = 0;
+#if defined(MACE_ENABLE_NEON)
+  uint8x16_t vmax16_0 = vdupq_n_u8(0);
+  uint8x16_t vmax16_1 = vdupq_n_u8(0);
+  for (; i <= size - 32; i += 32) {
+    vmax16_0 = vmaxq_u8(vmax16_0, vld1q_u8(xs + i + 0));
+    vmax16_1 = vmaxq_u8(vmax16_1, vld1q_u8(xs + i + 16));
+  }
+  uint8x16_t vmax16 = vmaxq_u8(vmax16_0, vmax16_1);
+  if (i <= size - 16) {
+    vmax16 = vmaxq_u8(vmax16, vld1q_u8(xs + i));
+    i += 16;
+  }
+  uint8x8_t vmax8 = vmax_u8(vget_low_u8(vmax16), vget_high_u8(vmax16));
+  if (i <= size - 8) {
+    vmax8 = vmax_u8(vmax8, vld1_u8(xs + i));
+    i += 8;
+  }
+  uint8x8_t vmax4 = vmax_u8(vmax8, vext_u8(vmax8, vmax8, 4));
+  uint8x8_t vmax2 = vmax_u8(vmax4, vext_u8(vmax4, vmax4, 2));
+  uint8x8_t vmax1 = vpmax_u8(vmax2, vmax2);
+  max_value = vget_lane_u8(vmax1, 0);
+#endif
+  for (; i < size; ++i) {
+    max_value = std::max(max_value, xs[i]);
+  }
+  return max_value;
+}
+
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_FIXPOINT_H_
+
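For context, FindMax consumes 32 bytes per iteration into two 16-lane accumulators, then folds 16 → 8 → 4 → 2 → 1 lanes with vmax/vext/vpmax; any tail (and sizes too small for the vector path) falls through to the scalar loop. A minimal sanity-check sketch of that contract, assuming the header above is on the include path:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

#include "mace/kernels/fixpoint.h"

// FindMax must agree with std::max_element regardless of whether size is a
// multiple of the 32-byte vector stride; size 7 exercises the scalar tail.
int main() {
  std::vector<uint8_t> xs = {3, 250, 7, 42, 0, 199, 11};
  const uint8_t m = mace::kernels::FindMax(xs.data(), xs.size());
  assert(m == *std::max_element(xs.begin(), xs.end()));
  return 0;
}
```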
diff --git a/mace/kernels/fixpoint_test.cc b/mace/kernels/fixpoint_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..8b926cd9d76c1cf26df9f8b995c52ad5cddd4c38
--- /dev/null
+++ b/mace/kernels/fixpoint_test.cc
@@ -0,0 +1,54 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include <algorithm>
+#include <ctime>
+#include <vector>
+
+#include "mace/kernels/fixpoint.h"
+
+namespace mace {
+namespace kernels {
+namespace test {
+
+namespace {
+void TestFindMax(int test_count) {
+  static unsigned int seed = time(NULL);
+  std::vector<uint8_t> input(test_count);
+  uint8_t expected_max = 0;
+  for (int i = 0; i < test_count; ++i) {
+    input[i] = rand_r(&seed) % 255;
+    expected_max = std::max(expected_max, input[i]);
+  }
+
+  uint8_t actual_max = FindMax(input.data(), input.size());
+  EXPECT_EQ(expected_max, actual_max);
+}
+}  // namespace
+
+TEST(FixpointTest, FindMax) {
+  TestFindMax(1);
+  TestFindMax(2);
+  TestFindMax(4);
+  TestFindMax(8);
+  TestFindMax(32);
+  TestFindMax(33);
+  TestFindMax(127);
+}
+
+}  // namespace test
+}  // namespace kernels
+}  // namespace mace
+
diff --git a/mace/kernels/gemmlowp_util.h b/mace/kernels/gemmlowp_util.h
index ef6efe918022d366912d5e1533f81f315f363aa1..9b8e400be9d71e3a5c70aa7b2abef3458f448e37 100644
--- a/mace/kernels/gemmlowp_util.h
+++ b/mace/kernels/gemmlowp_util.h
@@ -15,7 +15,6 @@
 #ifndef MACE_KERNELS_GEMMLOWP_UTIL_H_
 #define MACE_KERNELS_GEMMLOWP_UTIL_H_
 
-#include
 #include "public/gemmlowp.h"
 
 namespace mace {
diff --git a/mace/kernels/quantize.h b/mace/kernels/quantize.h
index 975d7dbbff780bc4deae5f278aace65c0c720f8f..3dc4c48ee94416fe51e4055f8aa955f6644271b6 100644
--- a/mace/kernels/quantize.h
+++ b/mace/kernels/quantize.h
@@ -137,6 +137,24 @@ inline void Dequantize(const T *input,
   }
 }
 
+inline void QuantizeMultiplier(double multiplier,
+                               int32_t* output_multiplier,
+                               int32_t* shift) {
+  if (multiplier == 0.f) {
+    *output_multiplier = 0;
+    *shift = 0;
+    return;
+  }
+  const double q = std::frexp(multiplier, shift);
+  auto qint = static_cast<int64_t>(roundl(q * (1ll << 31)));
+  if (qint == (1ll << 31)) {
+    qint /= 2;
+    ++*shift;
+  }
+  *output_multiplier = static_cast<int32_t>(qint);
+  MACE_CHECK(*output_multiplier <= std::numeric_limits<int32_t>::max());
+}
+
 template <DeviceType D, typename T>
 struct QuantizeFunctor;
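Unlike the old conv_2d.h helper, QuantizeMultiplier hands back the frexp exponent itself; callers needing a right shift negate it (as conv_2d.h does above). A worked example of that contract, with an arbitrary multiplier of 0.2 (illustrative only):

```cpp
#include <cmath>
#include <cstdint>
#include <iostream>

// 0.2 == 0.8 * 2^-2, so output_multiplier = round(0.8 * 2^31) = 1717986918
// and shift = -2; the caller reconstructs 0.2 as qint * 2^(shift - 31).
int main() {
  const double multiplier = 0.2;
  int shift = 0;
  const double q = std::frexp(multiplier, &shift);  // q = 0.8, shift = -2
  auto qint = static_cast<int64_t>(std::llround(q * (1ll << 31)));
  if (qint == (1ll << 31)) {  // guard against rounding up to 2^31
    qint /= 2;
    ++shift;
  }
  const double reconstructed =
      static_cast<double>(qint) * std::ldexp(1.0, shift - 31);
  std::cout << qint << " * 2^" << (shift - 31) << " = " << reconstructed
            << "\n";  // prints ~0.2
  return 0;
}
```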
diff --git a/mace/kernels/softmax.h b/mace/kernels/softmax.h
index 406f87afc579c7c0528454a94164f03cc8da2825..62e089c5c3000ebe0113454c7f72b2accedf91a0 100644
--- a/mace/kernels/softmax.h
+++ b/mace/kernels/softmax.h
@@ -25,6 +25,9 @@
 #include "mace/core/tensor.h"
 #include "mace/public/mace.h"
 #include "mace/utils/utils.h"
+#include "mace/kernels/fixpoint.h"
+#include "mace/kernels/gemmlowp_util.h"
+#include "mace/kernels/quantize.h"
 
 #ifdef MACE_ENABLE_OPENCL
 #include "mace/core/runtime/opencl/cl2_header.h"
@@ -120,6 +123,235 @@
   }
 };
 
+static const int kInputDeltaIntBits = 5;
+static const int kSumExpIntBits = 12;
+
+template<>
+struct SoftmaxFunctor<DeviceType::CPU, uint8_t> {
+  MaceStatus operator()(const Tensor *input,
+                        Tensor *output,
+                        StatsFuture *future) {
+    MACE_UNUSED(future);
+    // Ignore the range stats and fix the output range to [0, 1]. For large
+    // depth, each softmax output may be too small (<< 1), which causes a
+    // precision issue, but that is fine for classification inference.
+    output->SetScale(1.f / 255);
+    output->SetZeroPoint(0);
+
+    using FixPointInputDelta =
+        gemmlowp::FixedPoint<int32_t, kInputDeltaIntBits>;
+    using FixPointSumExp = gemmlowp::FixedPoint<int32_t, kSumExpIntBits>;
+    using FixPoint0 = gemmlowp::FixedPoint<int32_t, 0>;
+
+    MACE_CHECK(input->dim_size() == 2 || input->dim_size() == 4,
+               "Softmax does not support dim size: ",
+               input->dim_size());
+    index_t batch;
+    index_t depth;
+
+    if (input->dim_size() == 2) {
+      batch = input->dim(0);
+      depth = input->dim(1);
+    } else {
+      batch = input->dim(0) * input->dim(1) * input->dim(2);
+      depth = input->dim(3);
+    }
+
+    Tensor::MappingGuard input_guard(input);
+    Tensor::MappingGuard output_guard(output);
+    const uint8_t *input_data = input->data<uint8_t>();
+    float input_scale = input->scale();
+    uint8_t *output_data = output->mutable_data<uint8_t>();
+
+    // If depth is short, compute in float32. Float computation should not
+    // live here, but since this already runs on the CPU, it is acceptable.
+    if (depth < 32) {
+#pragma omp parallel for
+      for (index_t b = 0; b < batch; ++b) {
+        const uint8_t *input_ptr = input_data + b * depth;
+        uint8_t *output_ptr = output_data + b * depth;
+
+        uint8_t max_value = FindMax(input_ptr, depth);
+        float sum = 0;
+        std::vector<float> depth_cache(depth);
+        for (index_t d = 0; d < depth; ++d) {
+          float exp_value = ::exp((static_cast<float>(input_ptr[d]) - max_value)
+                                      * input_scale);
+          sum += exp_value;
+          depth_cache[d] = exp_value;
+        }
+
+        sum = std::max(sum, std::numeric_limits<float>::min());
+        for (index_t d = 0; d < depth; ++d) {
+          double output_f = depth_cache[d] / sum;
+          output_ptr[d] = static_cast<uint8_t>(output_f * 255);
+        }
+      }
+      return MACE_SUCCESS;
+    }
+
+    int32_t scale_q = static_cast<int32_t>(std::min(
+        static_cast<double>(input_scale) * (1 << (31 - kInputDeltaIntBits)),
+        (1ll << 31) - 1.0));
+    int32_t input_delta_limit = -((1ll << 31) - 1) / scale_q;
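+    // Note: scale_q stores input_scale as a Q5.26 raw value
+    // (kInputDeltaIntBits = 5 integer bits), and input_delta_limit is the
+    // most negative (input - max) for which scale_q * input_delta still
+    // fits in int32; smaller deltas would underflow exp() to zero anyway.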
+
+#pragma omp parallel for
+    for (index_t b = 0; b < batch; ++b) {
+      const uint8_t *input_ptr = input_data + b * depth;
+      uint8_t *output_ptr = output_data + b * depth;
+
+      FixPointSumExp sum = FixPointSumExp::Zero();
+      uint8_t max_value = FindMax(input_ptr, depth);
+      index_t d = 0;
+
+      // The NEON path has not shown a benefit in our benchmarks so far.
+      // Enable it when a case proves it useful.
+#if 0 && defined(MACE_ENABLE_NEON)
+      using FixPointInputDeltaInt32x4 =
+          gemmlowp::FixedPoint<int32x4_t, kInputDeltaIntBits>;
+      using FixPointSumExpInt32x4 =
+          gemmlowp::FixedPoint<int32x4_t, kSumExpIntBits>;
+      using FixPoint0Int32x4 = gemmlowp::FixedPoint<int32x4_t, 0>;
+
+      int16x8_t vmax_value_s16 = vdupq_n_s16(max_value);
+      int32x4_t vinput_delta_limit_s32 = vdupq_n_s32(input_delta_limit);
+
+      FixPointSumExpInt32x4 vsum_s32_fp_0 = FixPointSumExpInt32x4::Zero();
+      FixPointSumExpInt32x4 vsum_s32_fp_1 = FixPointSumExpInt32x4::Zero();
+      FixPointSumExpInt32x4 vzero_s32_fp = FixPointSumExpInt32x4::Zero();
+
+      int32_t scale_q_multiplier, scale_q_shift;
+      QuantizeMultiplier(scale_q, &scale_q_multiplier, &scale_q_shift);
+      FixPointInputDeltaInt32x4 vscale_s32_fp =
+          FixPointInputDeltaInt32x4::FromScalarRaw(scale_q);
+      FixPoint0Int32x4 vscale_s32_fp_multiplier =
+          FixPoint0Int32x4::FromScalarRaw(scale_q_multiplier);
+
+      for (; d <= depth - 8; d += 8) {
+        uint16x8_t vinput_u16 = vmovl_u8(vld1_u8(input_ptr + d));
+        int16x8_t vinput_delta_s16 =
+            vsubq_s16(vreinterpretq_s16_u16(vinput_u16), vmax_value_s16);
+        int32x4_t input_delta_s32_0 =
+            vmovl_s16(vget_low_s16(vinput_delta_s16));
+        int32x4_t input_delta_s32_1 =
+            vmovl_s16(vget_high_s16(vinput_delta_s16));
+        int32x4_t vmask_s32_0 = gemmlowp::MaskIfGreaterThanOrEqual(
+            input_delta_s32_0, vinput_delta_limit_s32);
+        int32x4_t vmask_s32_1 = gemmlowp::MaskIfGreaterThanOrEqual(
+            input_delta_s32_1, vinput_delta_limit_s32);
+        FixPointInputDeltaInt32x4 vscaled_input_delta_s32_fp_0 =
+            vscale_s32_fp_multiplier *
+            FixPointInputDeltaInt32x4::FromRaw(
+                gemmlowp::ShiftLeft(input_delta_s32_0, scale_q_shift));
+        FixPointInputDeltaInt32x4 vscaled_input_delta_s32_fp_1 =
+            vscale_s32_fp_multiplier *
+            FixPointInputDeltaInt32x4::FromRaw(
+                gemmlowp::ShiftLeft(input_delta_s32_1, scale_q_shift));
+        FixPointSumExpInt32x4 vexp_s32_fp_0 =
+            gemmlowp::Rescale<kSumExpIntBits>(
+                exp_on_negative_values(vscaled_input_delta_s32_fp_0));
+        FixPointSumExpInt32x4 vexp_s32_fp_1 =
+            gemmlowp::Rescale<kSumExpIntBits>(
+                exp_on_negative_values(vscaled_input_delta_s32_fp_1));
+        FixPointSumExpInt32x4 vmasked_exp_s32_fp_0 =
+            SelectUsingMask(vmask_s32_0, vexp_s32_fp_0, vzero_s32_fp);
+        FixPointSumExpInt32x4 vmasked_exp_s32_fp_1 =
+            SelectUsingMask(vmask_s32_1, vexp_s32_fp_1, vzero_s32_fp);
+        vsum_s32_fp_0 = vsum_s32_fp_0 + vmasked_exp_s32_fp_0;
+        vsum_s32_fp_1 = vsum_s32_fp_1 + vmasked_exp_s32_fp_1;
+      }
+      int32x4_t vsum_s32 = (vsum_s32_fp_0 + vsum_s32_fp_1).raw();
+      int32x2_t vsum_reduced_2_s32 =
+          vadd_s32(vget_low_s32(vsum_s32), vget_high_s32(vsum_s32));
+      int32x2_t vsum_reduced_1_s32 =
+          vpadd_s32(vsum_reduced_2_s32, vsum_reduced_2_s32);
+      sum = FixPointSumExp::FromRaw(vget_lane_s32(vsum_reduced_1_s32, 0));
+#endif
+      for (; d < depth; ++d) {
+        int32_t input_delta = static_cast<int32_t>(input_ptr[d]) - max_value;
+        if (input_delta >= input_delta_limit) {
+          int32_t scaled_input_delta_q = scale_q * input_delta;
+          FixPointInputDelta scaled_input_delta_fp =
+              FixPointInputDelta::FromRaw(scaled_input_delta_q);
+          sum = sum + gemmlowp::Rescale<kSumExpIntBits>(
+              exp_on_negative_values(scaled_input_delta_fp));
+        }
+      }
+
+      int32_t sum_q = sum.raw();
+      int left_zero_count = __builtin_clz(static_cast<uint32_t>(sum_q));
+      int tail_count = kSumExpIntBits - left_zero_count;
+      int32_t fractional_q0 = static_cast<int32_t>(
+          (static_cast<uint32_t>(sum_q) << left_zero_count) -
+          (static_cast<uint32_t>(1) << 31));
+      FixPoint0 recip_sum_q0 = gemmlowp::one_over_one_plus_x_for_x_in_0_1(
+          FixPoint0::FromRaw(fractional_q0));
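+      // recip_sum_q0 holds 1 / (1 + x) in Q0.31, where sum == (1 + x) *
+      // 2^tail_count; the leftover 2^-tail_count factor is folded into the
+      // RoundingDivideByPOT shift below.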
+
+      d = 0;
+
+      // The NEON path has not shown a benefit in our benchmarks so far.
+      // Enable it when a case proves it useful.
+#if 0 && defined(MACE_ENABLE_NEON)
+      FixPoint0Int32x4 vrecip_sum_q0_s32_fp =
+          FixPoint0Int32x4::FromScalarRaw(recip_sum_q0.raw());
+      int16x8_t vinput_delta_limit_s16 = vdupq_n_s16(input_delta_limit);
+      for (; d <= depth - 8; d += 8) {
+        uint16x8_t vinput_u16 = vmovl_u8(vld1_u8(input_ptr + d));
+        int16x8_t vinput_delta_s16 =
+            vsubq_s16(vreinterpretq_s16_u16(vinput_u16), vmax_value_s16);
+        int32x4_t input_delta_s32_0 =
+            vmovl_s16(vget_low_s16(vinput_delta_s16));
+        int32x4_t input_delta_s32_1 =
+            vmovl_s16(vget_high_s16(vinput_delta_s16));
+        int16x8_t vmask_s16 = gemmlowp::MaskIfGreaterThanOrEqual(
+            vinput_delta_s16, vinput_delta_limit_s16);
+        FixPointInputDeltaInt32x4 vscaled_input_delta_s32_fp_0 =
+            vscale_s32_fp_multiplier *
+            FixPointInputDeltaInt32x4::FromRaw(
+                gemmlowp::ShiftLeft(input_delta_s32_0, scale_q_shift));
+        FixPointInputDeltaInt32x4 vscaled_input_delta_s32_fp_1 =
+            vscale_s32_fp_multiplier *
+            FixPointInputDeltaInt32x4::FromRaw(
+                gemmlowp::ShiftLeft(input_delta_s32_1, scale_q_shift));
+        FixPoint0Int32x4 vexp_s32_fp_0 =
+            exp_on_negative_values(vscaled_input_delta_s32_fp_0);
+        FixPoint0Int32x4 vexp_s32_fp_1 =
+            exp_on_negative_values(vscaled_input_delta_s32_fp_1);
+        int32x4_t voutput_data_s32_0 = gemmlowp::RoundingDivideByPOT(
+            (vrecip_sum_q0_s32_fp * vexp_s32_fp_0).raw(), tail_count + 31 - 8);
+        int32x4_t voutput_data_s32_1 = gemmlowp::RoundingDivideByPOT(
+            (vrecip_sum_q0_s32_fp * vexp_s32_fp_1).raw(), tail_count + 31 - 8);
+        int16x8_t voutput_data_s16 =
+            vcombine_s16(vqmovn_s32(voutput_data_s32_0),
+                         vqmovn_s32(voutput_data_s32_1));
+        int16x8_t masked_voutput_data_s16 = gemmlowp::SelectUsingMask(
+            vmask_s16, voutput_data_s16, vdupq_n_s16(0));
+        uint8x8_t voutput_u8 = vqmovun_s16(masked_voutput_data_s16);
+        vst1_u8(output_ptr + d, voutput_u8);
+      }
+#endif
+      for (; d < depth; ++d) {
+        int32_t input_delta = static_cast<int32_t>(input_ptr[d]) - max_value;
+        if (input_delta >= input_delta_limit) {
+          int32_t scaled_input_delta_q = scale_q * input_delta;
+          FixPointInputDelta scaled_input_delta_fp =
+              FixPointInputDelta::FromRaw(scaled_input_delta_q);
+
+          FixPoint0 exp = exp_on_negative_values(scaled_input_delta_fp);
+          int32_t output_data = gemmlowp::RoundingDivideByPOT(
+              (recip_sum_q0 * exp).raw(), tail_count + 31 - 8);
+
+          output_ptr[d] = std::max(std::min(output_data, 255), 0);
+        } else {
+          // exp() underflows to zero here; write it explicitly so the output
+          // buffer is never left uninitialized (the disabled NEON path gets
+          // the same effect from SelectUsingMask).
+          output_ptr[d] = 0;
+        }
+      }
+    }
+
+    return MACE_SUCCESS;
+  }
+};
+
 #ifdef MACE_ENABLE_OPENCL
 template <typename T>
 struct SoftmaxFunctor<DeviceType::GPU, T> {
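
The reciprocal step in the scalar path above relies on an identity worth spelling out: with kSumExpIntBits = 12, normalizing the raw sum so its most significant bit is set rewrites sum as (1 + x) * 2^tail_count with x in [0, 1), exactly the domain gemmlowp's one_over_one_plus_x_for_x_in_0_1 accepts. A gemmlowp-free numeric check of that bookkeeping (illustrative only):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// With kSumExpIntBits = 12, the raw sum encodes S = sum_q * 2^-(31 - 12).
// Normalizing the raw value so its MSB is set gives S = (1 + x) * 2^tail_count
// with x in [0, 1), hence 1/S = (1 / (1 + x)) * 2^-tail_count, which is what
// the RoundingDivideByPOT shift of (tail_count + 31 - 8) accounts for when
// producing the uint8 output.
int main() {
  const int32_t sum_q = 5 << 19;  // S = 5.0 in Q12.19
  const int left_zero_count = __builtin_clz(static_cast<uint32_t>(sum_q));
  const int tail_count = 12 - left_zero_count;
  const uint32_t normalized = static_cast<uint32_t>(sum_q) << left_zero_count;
  const double x =
      std::ldexp(static_cast<double>(normalized - (1u << 31)), -31);
  const double S = std::ldexp(static_cast<double>(sum_q), -(31 - 12));
  assert(S == std::ldexp(1.0 + x, tail_count));  // 5.0 == 1.25 * 2^2
  return 0;
}
```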
OpDefBuilder("Softmax", "SoftmaxBM") + .Input("Input") + .Output("Output") + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + + net.Setup(DeviceType::CPU); + + Tensor *output = net.GetTensor("Output"); + output->SetScale(0); + output->SetZeroPoint(1); + + // Warm-up + for (int i = 0; i < 2; ++i) { + net.Run(); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.Run(); + } + net.Sync(); +} } // namespace #define MACE_BM_SOFTMAX_MACRO(N, C, H, W, TYPE, DEVICE) \ @@ -82,6 +118,7 @@ void SoftmaxBenchmark( #define MACE_BM_SOFTMAX(N, C, H, W) \ MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, CPU); \ + MACE_BM_SOFTMAX_MACRO(N, C, H, W, uint8_t, CPU); \ MACE_BM_SOFTMAX_MACRO(N, C, H, W, float, GPU); \ MACE_BM_SOFTMAX_MACRO(N, C, H, W, half, GPU); diff --git a/mace/ops/softmax_test.cc b/mace/ops/softmax_test.cc index 62f7f32f8a1ff4667dbe3660b0b516e2897d7fa5..827067f4ce093b42539cc388fefb13ffa691b905 100644 --- a/mace/ops/softmax_test.cc +++ b/mace/ops/softmax_test.cc @@ -155,6 +155,56 @@ TEST_F(SoftmaxOpTest, OPENCLAlignedRank2) { Complex({3, 1001}); } +namespace { + +void TestQuantizedSoftmax(const std::vector &input_shape) { + OpsTestNet net; + net.AddRandomInput("Input", input_shape, false, true); + + OpDefBuilder("Softmax", "SoftmaxTest") + .Input("Input") + .Output("Output") + .Finalize(net.NewOperatorDef()); + net.RunOp(); + OpDefBuilder("Quantize", "QuantizeInput") + .Input("Input") + .Output("QuantizedInput") + .OutputType({DT_UINT8}) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + OpDefBuilder("Softmax", "SoftmaxQuantizeTest") + .Input("QuantizedInput") + .Output("QuantizedOutput") + .OutputType({DT_UINT8}) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + net.Setup(DeviceType::CPU); + Tensor *q_output = net.GetTensor("QuantizedOutput"); + q_output->SetScale(1.0f / 255); + q_output->SetZeroPoint(0); + net.Run(); + OpDefBuilder("Dequantize", "DeQuantizeTest") + .Input("QuantizedOutput") + .Output("DequantizedOutput") + .OutputType({DT_FLOAT}) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + // Check + ExpectTensorSimilar(*net.GetOutput("Output"), + *net.GetTensor("DequantizedOutput"), 0.1); +} + +} // namespace + +TEST_F(SoftmaxOpTest, QuantizeTest) { + TestQuantizedSoftmax({5, 10}); + TestQuantizedSoftmax({50, 100}); + TestQuantizedSoftmax({1, 31}); +} + } // namespace test } // namespace ops } // namespace mace