From 2e1d191cc4a368b14341cb3fc7182f8e224df15f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?=
Date: Thu, 19 Apr 2018 13:36:44 +0800
Subject: [PATCH] Implement quantize related ops

---
 mace/core/operator.cc       |   6 +
 mace/core/operator.h        |  17 ++-
 mace/kernels/quantize.h     | 195 +++++++++++++++++++++++++++++++
 mace/ops/batch_norm_test.cc |   2 +-
 mace/ops/ops_test_util.h    |  37 +++++-
 mace/ops/quantize.cc        |  60 ++++++++++
 mace/ops/quantize.h         | 144 +++++++++++++++++++++++
 mace/ops/quantize_test.cc   | 224 ++++++++++++++++++++++++++++++++++++
 8 files changed, 676 insertions(+), 9 deletions(-)
 create mode 100644 mace/kernels/quantize.h
 create mode 100644 mace/ops/quantize.cc
 create mode 100644 mace/ops/quantize.h
 create mode 100644 mace/ops/quantize_test.cc

diff --git a/mace/core/operator.cc b/mace/core/operator.cc
index b40c2979..1aedbe70 100644
--- a/mace/core/operator.cc
+++ b/mace/core/operator.cc
@@ -86,6 +86,7 @@ extern void Register_Conv2D(OperatorRegistry *op_registry);
 extern void Register_CWise(OperatorRegistry *op_registry);
 extern void Register_DepthToSpace(OperatorRegistry *op_registry);
 extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry);
+extern void Register_Dequantize(OperatorRegistry *op_registry);
 extern void Register_Eltwise(OperatorRegistry *op_registry);
 extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
 extern void Register_FullyConnected(OperatorRegistry *op_registry);
@@ -98,7 +99,9 @@ extern void Register_Pad(OperatorRegistry *op_registry);
 extern void Register_Pooling(OperatorRegistry *op_registry);
 extern void Register_Proposal(OperatorRegistry *op_registry);
 extern void Register_PSROIAlign(OperatorRegistry *op_registry);
+extern void Register_Quantize(OperatorRegistry *op_registry);
 extern void Register_ReOrganize(OperatorRegistry *op_registry);
+extern void Register_Requantize(OperatorRegistry *op_registry);
 extern void Register_Reshape(OperatorRegistry *op_registry);
 extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
 extern void Register_Slice(OperatorRegistry *op_registry);
@@ -124,6 +127,7 @@ OperatorRegistry::OperatorRegistry() {
   ops::Register_CWise(this);
   ops::Register_DepthToSpace(this);
   ops::Register_DepthwiseConv2d(this);
+  ops::Register_Dequantize(this);
   ops::Register_Eltwise(this);
   ops::Register_FoldedBatchNorm(this);
   ops::Register_FullyConnected(this);
@@ -136,6 +140,8 @@ OperatorRegistry::OperatorRegistry() {
   ops::Register_Pooling(this);
   ops::Register_Proposal(this);
   ops::Register_PSROIAlign(this);
+  ops::Register_Quantize(this);
+  ops::Register_Requantize(this);
   ops::Register_ReOrganize(this);
   ops::Register_Reshape(this);
   ops::Register_ResizeBilinear(this);
diff --git a/mace/core/operator.h b/mace/core/operator.h
index 387a41ef..037aa1e0 100644
--- a/mace/core/operator.h
+++ b/mace/core/operator.h
@@ -108,12 +108,25 @@ class Operator : public OperatorBase {
       inputs_.push_back(tensor);
     }
 
-    for (const std::string &output_str : operator_def.output()) {
+    for (size_t i = 0; i < operator_def.output().size(); ++i) {
+      const std::string output_str = operator_def.output()[i];
       if (ws->HasTensor(output_str)) {
         outputs_.push_back(ws->GetTensor(output_str));
       } else {
+        MACE_CHECK(
+            operator_def.output_type().size() == 0
+            || operator_def.output().size() == operator_def.output_type().size(),
+            "operator output size != operator output type size",
+            operator_def.output().size(),
+            operator_def.output_type().size());
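+        // Use the data type specified in output_type for this output when
+        // present; otherwise fall back to the operator's type argument T.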
+        DataType output_type;
+        if (i < operator_def.output_type().size()) {
+          output_type = operator_def.output_type()[i];
+        } else {
+          output_type = DataTypeToEnum<T>::v();
+        }
         outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
-            output_str, GetDeviceAllocator(D), DataTypeToEnum<T>::v())));
+            output_str, GetDeviceAllocator(D), output_type)));
       }
     }
   }
diff --git a/mace/kernels/quantize.h b/mace/kernels/quantize.h
new file mode 100644
index 00000000..1ffab488
--- /dev/null
+++ b/mace/kernels/quantize.h
@@ -0,0 +1,195 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_KERNELS_QUANTIZE_H_
+#define MACE_KERNELS_QUANTIZE_H_
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+
+#include "mace/core/future.h"
+#include "mace/core/tensor.h"
+
+namespace mace {
+namespace kernels {
+
+template <typename T>
+inline void AdjustRange(const float in_min_data,
+                        const float in_max_data,
+                        float *out_min_data,
+                        float *out_max_data) {
+  // re-range to make range include zero float and
+  // make zero float as integer u8
+  const float quantized_max = std::numeric_limits<T>::max();
+  float out_min = fminf(0.f, in_min_data);
+  float out_max = fmaxf(0.f, in_max_data);
+  if (out_min < 0.f) {
+    float stepsize = (in_max_data - in_min_data) / quantized_max;
+    float quantized_zero = -in_min_data / stepsize;
+    float quantized_zero_near_int = roundf(quantized_zero);
+    if (fabs(quantized_zero - quantized_zero_near_int) > 1e-6) {
+      if (quantized_zero < quantized_zero_near_int) {
+        // keep out_max fixed, and move out_min
+        stepsize = out_max / (quantized_max - quantized_zero_near_int);
+        out_min = out_max - quantized_max * stepsize;
+      } else {
+        // keep out_min fixed, and move out_max
+        stepsize = -out_min / quantized_zero_near_int;
+        out_max = out_min + quantized_max * stepsize;
+      }
+    }
+  }
+  *out_min_data = out_min;
+  *out_max_data = out_max;
+}
+
+template <typename T>
+inline T Saturate(float value) {
+  int rounded_value = static_cast<int>(value);
+  if (rounded_value <= std::numeric_limits<T>::lowest()) {
+    return std::numeric_limits<T>::lowest();
+  } else if (rounded_value >= std::numeric_limits<T>::max()) {
+    return std::numeric_limits<T>::max();
+  } else {
+    return static_cast<T>(rounded_value);
+  }
+}
+
+template <DeviceType D, typename T>
+struct QuantizeFunctor;
+
+template<>
+struct QuantizeFunctor<DeviceType::CPU, uint8_t> {
+  QuantizeFunctor() {}
+
+  void operator()(const Tensor *input,
+                  const Tensor *in_min,
+                  const Tensor *in_max,
+                  Tensor *output,
+                  Tensor *out_min,
+                  Tensor *out_max,
+                  StatsFuture *future) {
+    const float *input_data = input->data<float>();
+    const float in_min_data = in_min->data<float>()[0];
+    const float in_max_data = in_max->data<float>()[0];
+    uint8_t *output_data = output->mutable_data<uint8_t>();
+    float *out_min_data = out_min->mutable_data<float>();
+    float *out_max_data = out_max->mutable_data<float>();
+
+    AdjustRange<uint8_t>(in_min_data, in_max_data, out_min_data, out_max_data);
+    float recip_stepsize = 255.f / (out_max_data[0] - out_min_data[0]);
+    for (int i = 0; i < input->size(); ++i) {
+      output_data[i] = Saturate<uint8_t>(roundf(
+          (input_data[i] - in_min_data) * recip_stepsize));
+    }
+  }
+};
+
+template <DeviceType D, typename T>
+struct DequantizeFunctor;
+
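+// Dequantization maps a quantized code q back to float as
+// f = in_min + q * (in_max - in_min) / 255.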
+template<>
+struct DequantizeFunctor<DeviceType::CPU, uint8_t> {
+  DequantizeFunctor() {}
+
+  void operator()(const Tensor *input,
+                  const Tensor *in_min,
+                  const Tensor *in_max,
+                  Tensor *output,
+                  StatsFuture *future) {
+    const uint8_t *input_data = input->data<uint8_t>();
+    const float in_min_data = in_min->data<float>()[0];
+    const float in_max_data = in_max->data<float>()[0];
+    float *output_data = output->mutable_data<float>();
+
+    float stepsize = (in_max_data - in_min_data) / 255.0;
+    for (int i = 0; i < input->size(); ++i) {
+      output_data[i] = in_min_data + stepsize * input_data[i];
+    }
+  }
+};
+
+template <DeviceType D, typename T>
+struct RequantizeFunctor;
+
+template<>
+struct RequantizeFunctor<DeviceType::CPU, uint8_t> {
+  RequantizeFunctor() {}
+
+  void operator()(const Tensor *input,
+                  const Tensor *in_min,
+                  const Tensor *in_max,
+                  const Tensor *rerange_min,
+                  const Tensor *rerange_max,
+                  Tensor *output,
+                  Tensor *out_min,
+                  Tensor *out_max,
+                  StatsFuture *future) {
+    const int *input_data = input->data<int>();
+    const float in_min_data = in_min->data<float>()[0];
+    const float in_max_data = in_max->data<float>()[0];
+
+    float rerange_min_data;
+    float rerange_max_data;
+    int min_val = std::numeric_limits<int>::max();
+    int max_val = std::numeric_limits<int>::lowest();
+    double
+        si = (in_max_data - in_min_data) / std::numeric_limits<int>::max();
+    if (rerange_min == nullptr && rerange_max == nullptr) {
+      for (int i = 0; i < input->size(); ++i) {
+        min_val = std::min(min_val, input_data[i]);
+        max_val = std::max(max_val, input_data[i]);
+      }
+      rerange_min_data = min_val * si;
+      rerange_max_data = max_val * si;
+    } else {
+      rerange_min_data = rerange_min->data<float>()[0];
+      rerange_max_data = rerange_max->data<float>()[0];
+    }
+
+    uint8_t *output_data = output->mutable_data<uint8_t>();
+    float *out_min_data = out_min->mutable_data<float>();
+    float *out_max_data = out_max->mutable_data<float>();
+
+    AdjustRange<uint8_t>(rerange_min_data,
+                         rerange_max_data,
+                         out_min_data,
+                         out_max_data);
+    /**
+     * f = qi * si = min_o + qo * so
+     * => qo = (qi * si - min_o) / so
+     *       = qi * (si/so) - min_o / so
+     *       = qi * (si / so) + zo
+     *
+     * zo = -min_o / so
+     *
+     */
+    float so =
+        (out_max_data[0] - out_min_data[0]) / std::numeric_limits<uint8_t>::max();
+    double step_ratio = si / so;
+    float quantized_out_zero = -out_min_data[0] / so;
+
+    for (int i = 0; i < output->size(); ++i) {
+      output_data[i] =
+          Saturate<uint8_t>(roundf(
+              quantized_out_zero + input_data[i] * step_ratio));
+    }
+  }
+};
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_QUANTIZE_H_
diff --git a/mace/ops/batch_norm_test.cc b/mace/ops/batch_norm_test.cc
index c1f0ca02..7f1bb037 100644
--- a/mace/ops/batch_norm_test.cc
+++ b/mace/ops/batch_norm_test.cc
@@ -429,7 +429,7 @@ TEST_F(BatchNormOpTest, NEONTest) {
 
   ExpectTensorNear<float>(*net.GetOutput("OutputExptected"),
                           *net.GetOutput("OutputNeon"),
-                          1e-5);
+                          1e-5, 1e-4);
 }
 
 }  // namespace test
diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h
index a8f72f58..09b2fa10 100644
--- a/mace/ops/ops_test_util.h
+++ b/mace/ops/ops_test_util.h
@@ -52,6 +52,11 @@ class OpDefBuilder {
     return *this;
   }
 
+  OpDefBuilder &OutputType(const std::vector<DataType> &output_type) {
+    op_def_.set_output_type(output_type);
+    return *this;
+  }
+
   OpDefBuilder AddIntArg(const std::string &name, const int value) {
     auto arg = op_def_.add_arg();
     arg->set_name(name);
@@ -461,7 +466,7 @@ struct Expector<EXP_TYPE, RES_TYPE, true> {
     auto a = x.data<EXP_TYPE>();
     auto b = y.data<RES_TYPE>();
     for (int i = 0; i < x.size(); ++i) {
-      ExpectEqual(a(i), b(i));
+      ExpectEqual(a[i], b[i]);
     }
   }
 
@@ -499,12 +504,35 @@ struct Expector<EXP_TYPE, RES_TYPE, true> {
   }
 };
 
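+// Expector specialization for non-floating-point element types (e.g. uint8_t):
+// Near() does not apply a tolerance and simply falls back to exact equality.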
+template <typename EXP_TYPE, typename RES_TYPE>
+struct Expector<EXP_TYPE, RES_TYPE, false> {
+  static void Equal(const EXP_TYPE &a, const RES_TYPE &b) { ExpectEqual(a, b); }
+
+  static void Equal(const Tensor &x, const Tensor &y) {
+    ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
+    ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
+    AssertSameDims(x, y);
+    Tensor::MappingGuard x_mapper(&x);
+    Tensor::MappingGuard y_mapper(&y);
+    auto a = x.data<EXP_TYPE>();
+    auto b = y.data<RES_TYPE>();
+    for (int i = 0; i < x.size(); ++i) {
+      ExpectEqual(a[i], b[i]);
+    }
+  }
+
+  static void Near(const Tensor &x, const Tensor &y,
+                   const double rel_err,
+                   const double abs_err) {
+    Equal(x, y);
+  }
+};
+
+
 template <typename T>
 void ExpectTensorNear(const Tensor &x, const Tensor &y,
                       const double rel_err = 1e-5,
                       const double abs_err = 1e-8) {
-  static_assert(is_floating_point_type<T>::value,
-                "T is not a floating point type");
   Expector<T, T>::Near(x, y, rel_err, abs_err);
 }
 
@@ -512,9 +540,6 @@ template <typename EXP_TYPE, typename RES_TYPE>
 void ExpectTensorNear(const Tensor &x, const Tensor &y,
                       const double rel_err = 1e-5,
                       const double abs_err = 1e-8) {
-  static_assert(is_floating_point_type<EXP_TYPE>::value &&
-                    is_floating_point_type<RES_TYPE>::value,
-                "T is not a floating point type");
   Expector<EXP_TYPE, RES_TYPE>::Near(x, y, rel_err, abs_err);
 }
 
diff --git a/mace/ops/quantize.cc b/mace/ops/quantize.cc
new file mode 100644
index 00000000..49695fde
--- /dev/null
+++ b/mace/ops/quantize.cc
@@ -0,0 +1,60 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/ops/quantize.h"
+
+namespace mace {
+namespace ops {
+
+void Register_Quantize(OperatorRegistry *op_registry) {
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Quantize")
+                        .Device(DeviceType::CPU)
+                        .TypeConstraint<uint8_t>("T")
+                        .Build(),
+                    QuantizeOp<DeviceType::CPU, uint8_t>);
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Quantize")
+                        .Device(DeviceType::NEON)
+                        .TypeConstraint<uint8_t>("T")
+                        .Build(),
+                    QuantizeOp<DeviceType::NEON, uint8_t>);
+}
+
+void Register_Dequantize(OperatorRegistry *op_registry) {
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Dequantize")
+                        .Device(DeviceType::CPU)
+                        .TypeConstraint<uint8_t>("T")
+                        .Build(),
+                    DequantizeOp<DeviceType::CPU, uint8_t>);
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Dequantize")
+                        .Device(DeviceType::NEON)
+                        .TypeConstraint<uint8_t>("T")
+                        .Build(),
+                    DequantizeOp<DeviceType::NEON, uint8_t>);
+}
+
+void Register_Requantize(OperatorRegistry *op_registry) {
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Requantize")
+                        .Device(DeviceType::CPU)
+                        .TypeConstraint<uint8_t>("T")
+                        .Build(),
+                    RequantizeOp<DeviceType::CPU, uint8_t>);
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Requantize")
+                        .Device(DeviceType::NEON)
+                        .TypeConstraint<uint8_t>("T")
+                        .Build(),
+                    RequantizeOp<DeviceType::NEON, uint8_t>);
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/quantize.h b/mace/ops/quantize.h
new file mode 100644
index 00000000..cee215f1
--- /dev/null
+++ b/mace/ops/quantize.h
@@ -0,0 +1,144 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_OPS_QUANTIZE_H_
+#define MACE_OPS_QUANTIZE_H_
+
+#include "mace/core/operator.h"
+#include "mace/kernels/quantize.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, typename T>
+class QuantizeOp : public Operator<D, T> {
+ public:
+  QuantizeOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws) {
+  }
+
+  bool Run(StatsFuture *future) override {
+    const Tensor *input = this->Input(INPUT);
+    const Tensor *in_min = this->Input(IN_MIN);
+    const Tensor *in_max = this->Input(IN_MAX);
+
+    MACE_CHECK(in_min->size() == 1, "min val tensor has more than 1 value");
+    MACE_CHECK(in_max->size() == 1, "max val tensor has more than 1 value");
+
+    Tensor *output = this->Output(OUTPUT);
+    Tensor *out_min = this->Output(OUT_MIN);
+    Tensor *out_max = this->Output(OUT_MAX);
+    output->ResizeLike(input);
+    out_min->ResizeLike(in_min);
+    out_max->ResizeLike(in_max);
+
+    functor_(input, in_min, in_max, output, out_min, out_max, future);
+    return true;
+  }
+
+ private:
+  kernels::QuantizeFunctor<D, T> functor_;
+
+ protected:
+  OP_INPUT_TAGS(INPUT, IN_MIN, IN_MAX);
+  OP_OUTPUT_TAGS(OUTPUT, OUT_MIN, OUT_MAX);
+};
+
+template <DeviceType D, typename T>
+class DequantizeOp : public Operator<D, T> {
+ public:
+  DequantizeOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws) {
+  }
+
+  bool Run(StatsFuture *future) override {
+    const Tensor *input = this->Input(INPUT);
+    const Tensor *in_min = this->Input(IN_MIN);
+    const Tensor *in_max = this->Input(IN_MAX);
+
+    MACE_CHECK(in_min->size() == 1, "min val tensor has more than 1 value");
+    MACE_CHECK(in_max->size() == 1, "max val tensor has more than 1 value");
+
+    Tensor *output = this->Output(OUTPUT);
+    output->ResizeLike(input);
+
+    functor_(input, in_min, in_max, output, future);
+    return true;
+  }
+
+ private:
+  kernels::DequantizeFunctor<D, T> functor_;
+
+ protected:
+  OP_INPUT_TAGS(INPUT, IN_MIN, IN_MAX);
+  OP_OUTPUT_TAGS(OUTPUT);
+};
+
+template <DeviceType D, typename T>
+class RequantizeOp : public Operator<D, T> {
+ public:
+  RequantizeOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws) {
+  }
+
+  bool Run(StatsFuture *future) override {
+    const Tensor *input = this->Input(INPUT);
+    const Tensor *in_min = this->Input(IN_MIN);
+    const Tensor *in_max = this->Input(IN_MAX);
+    const Tensor *rerange_min = nullptr;
+    const Tensor *rerange_max = nullptr;
+
+    MACE_CHECK(in_min->size() == 1, "min val tensor has more than 1 value");
+    MACE_CHECK(in_max->size() == 1, "max val tensor has more than 1 value");
+
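+    // RERANGE_MIN and RERANGE_MAX are optional inputs; when they are not
+    // given, the kernel derives the output range from the int32 input data.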
+    if (this->InputSize() >= 5) {
+      rerange_min = this->Input(RERANGE_MIN);
+      rerange_max = this->Input(RERANGE_MAX);
+      MACE_CHECK(rerange_min->size() == 1,
+                 "rerange min val tensor has more than 1 value");
+      MACE_CHECK(rerange_max->size() == 1,
+                 "rerange max val tensor has more than 1 value");
+    }
+
+    Tensor *output = this->Output(OUTPUT);
+    Tensor *out_min = this->Output(OUT_MIN);
+    Tensor *out_max = this->Output(OUT_MAX);
+    output->ResizeLike(input);
+    out_min->ResizeLike(in_min);
+    out_max->ResizeLike(in_max);
+
+    functor_(input,
+             in_min,
+             in_max,
+             rerange_min,
+             rerange_max,
+             output,
+             out_min,
+             out_max,
+             future);
+    return true;
+  }
+
+ private:
+  kernels::RequantizeFunctor<D, T> functor_;
+
+ protected:
+  OP_INPUT_TAGS(INPUT, IN_MIN, IN_MAX, RERANGE_MIN, RERANGE_MAX);
+  OP_OUTPUT_TAGS(OUTPUT, OUT_MIN, OUT_MAX);
+};
+
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_QUANTIZE_H_
diff --git a/mace/ops/quantize_test.cc b/mace/ops/quantize_test.cc
new file mode 100644
index 00000000..1672ac53
--- /dev/null
+++ b/mace/ops/quantize_test.cc
@@ -0,0 +1,224 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mace/core/operator.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class QuantizeTest : public OpsTestBase {};
+
+TEST_F(QuantizeTest, TestQuantize) {
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 3, 1}, {
+      -2, -1, 1, 2, 3, 4
+  });
+  net.AddInputFromArray<DeviceType::CPU, float>("InputMin", {1}, {-3});
+  net.AddInputFromArray<DeviceType::CPU, float>("InputMax", {1}, {5});
+
+  OpDefBuilder("Quantize", "QuantizeTest")
+      .Input("Input")
+      .Input("InputMin")
+      .Input("InputMax")
+      .Output("Output")
+      .Output("OutputMin")
+      .Output("OutputMax")
+      .OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT})
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+
+  // Run
+  net.RunOp();
+
+  auto output = net.GetTensor("Output");
+  auto output_min = net.GetTensor("OutputMin");
+  auto output_max = net.GetTensor("OutputMax");
+
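+  // AdjustRange nudges the range [-3, 5] to [-3.01887, 5] so that 0.0f maps
+  // to an integer code; e.g. -2 -> round((-2 + 3) * 255 / 8.01887) = 32.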
+  auto expected_output = CreateTensor<uint8_t>({1, 2, 3, 1},
+                                               {
+                                                   32, 64, 127, 159, 191, 223
+                                               });
+  auto expected_min = CreateTensor<float>({1}, {-3.01887});
+  auto expected_max = CreateTensor<float>({1}, {5});
+
+  ExpectTensorNear<uint8_t>(*expected_output, *output);
+  ExpectTensorNear<float>(*expected_min, *output_min);
+  ExpectTensorNear<float>(*expected_max, *output_max);
+}
+
+TEST_F(QuantizeTest, TestQuantizeTrend) {
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput<DeviceType::CPU, float>("Input", {100});
+  const float *input_data = net.GetTensor("Input")->data<float>();
+  net.AddInputFromArray<DeviceType::CPU, float>("InputMin",
+                                                {1},
+                                                {*std::min_element(input_data,
+                                                                   input_data
+                                                    + net.GetTensor("Input")->size())});
+  net.AddInputFromArray<DeviceType::CPU, float>("InputMax",
+                                                {1},
+                                                {*std::max_element(input_data,
+                                                                   input_data
+                                                    + net.GetTensor("Input")->size())});
+
+  OpDefBuilder("Quantize", "QuantizeTest")
+      .Input("Input")
+      .Input("InputMin")
+      .Input("InputMax")
+      .Output("Output")
+      .Output("OutputMin")
+      .Output("OutputMax")
+      .OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT})
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+
+  // Run
+  net.RunOp();
+
+  auto output = net.GetTensor("Output");
+  auto output_min = net.GetTensor("OutputMin");
+  auto output_max = net.GetTensor("OutputMax");
+
+  const uint8_t *output_data = net.GetTensor("Output")->data<uint8_t>();
+  for (int i = 1; i < output->size(); ++i) {
+    if (input_data[i] > input_data[i - 1]) {
+      EXPECT_GE(output_data[i], output_data[i - 1]);
+    } else if (input_data[i] == input_data[i - 1]) {
+      EXPECT_EQ(output_data[i], output_data[i - 1]);
+    } else {
+      EXPECT_LE(output_data[i], output_data[i - 1]);
+    }
+  }
+}
+
+TEST_F(QuantizeTest, TestDequantize) {
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  net.AddInputFromArray<DeviceType::CPU, uint8_t>("Input", {1, 2, 3, 1}, {
+      32, 64, 127, 159, 191, 223
+  });
+  net.AddInputFromArray<DeviceType::CPU, float>("InputMin", {1}, {-3.01887});
+  net.AddInputFromArray<DeviceType::CPU, float>("InputMax", {1}, {5});
+
+  OpDefBuilder("Dequantize", "DequantizeTest")
+      .Input("Input")
+      .Input("InputMin")
+      .Input("InputMax")
+      .Output("Output")
+      .OutputType({DT_FLOAT})
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+
+  // Run
+  net.RunOp();
+
+  auto output = net.GetTensor("Output");
+  auto expected_output = CreateTensor<float>({1, 2, 3, 1},
+                                             {
+                                                 -2, -1, 1, 2, 3, 4
+                                             });
+  auto expected_min = CreateTensor<float>({1}, {-3.01887});
+  auto expected_max = CreateTensor<float>({1}, {5});
+
+  ExpectTensorNear<float>(*expected_output, *output, 0.1, 0.01);
+}
+
+TEST_F(QuantizeTest, TestRequantizeWithMinMax) {
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  net.AddInputFromArray<DeviceType::CPU, int>("Input", {1, 2, 3, 1}, {
+      -1073741824, -536870912, 536870912, 1073741824, 1610612736, 2147483647
+  });
+  net.AddInputFromArray<DeviceType::CPU, float>("InputMin", {1}, {-3});
+  net.AddInputFromArray<DeviceType::CPU, float>("InputMax", {1}, {5});
+  net.AddInputFromArray<DeviceType::CPU, float>("RerangeMin", {1}, {-3.01887});
+  net.AddInputFromArray<DeviceType::CPU, float>("RerangeMax", {1}, {5});
+
+  OpDefBuilder("Requantize", "RequantizeTest")
+      .Input("Input")
+      .Input("InputMin")
+      .Input("InputMax")
+      .Input("RerangeMin")
+      .Input("RerangeMax")
+      .Output("Output")
+      .Output("OutputMin")
+      .Output("OutputMax")
+      .OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT})
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+
+  // Run
+  net.RunOp();
+
+  auto output = net.GetTensor("Output");
+  auto expected_output = CreateTensor<uint8_t>({1, 2, 3, 1},
+                                               {
+                                                   32, 64, 128, 160, 191, 223
+                                               });
+  auto expected_min = CreateTensor<float>({1}, {-3.01887});
+  auto expected_max = CreateTensor<float>({1}, {5});
+
+  ExpectTensorNear<uint8_t>(*expected_output, *output);
+}
+
+TEST_F(QuantizeTest, TestRequantizeWithoutMinMax) {
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  net.AddInputFromArray<DeviceType::CPU, int>("Input", {1, 2, 3, 1}, {
+      -1073741824, -536870912, 536870912, 1073741824, 1610612736, 2147483647
+  });
+  net.AddInputFromArray<DeviceType::CPU, float>("InputMin", {1}, {-3});
+  net.AddInputFromArray<DeviceType::CPU, float>("InputMax", {1}, {5});
+
+  OpDefBuilder("Requantize", "RequantizeTest")
+      .Input("Input")
+      .Input("InputMin")
+      .Input("InputMax")
+      .Output("Output")
+      .Output("OutputMin")
+      .Output("OutputMax")
+      .OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT})
+      .AddIntArg("T", DT_UINT8)
+      .Finalize(net.NewOperatorDef());
+
+  // Run
+  net.RunOp();
+
+  auto output = net.GetTensor("Output");
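+  // Without explicit rerange inputs the kernel derives the output range from
+  // the int32 data itself, so the smallest value maps to 0 and the largest to 255.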
+  auto expected_output = CreateTensor<uint8_t>({1, 2, 3, 1},
+                                               {
+                                                   0, 43, 128, 170, 213, 255
+                                               });
+  auto expected_min = CreateTensor<float>({1}, {-3.01887});
+  auto expected_max = CreateTensor<float>({1}, {5});
+  ExpectTensorNear<uint8_t>(*expected_output, *output);
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
-- 
GitLab