Commit 2e1d191c authored by 李寅

Implement quantize related ops

Parent c384a6e2
@@ -86,6 +86,7 @@ extern void Register_Conv2D(OperatorRegistry *op_registry);
 extern void Register_CWise(OperatorRegistry *op_registry);
 extern void Register_DepthToSpace(OperatorRegistry *op_registry);
 extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry);
+extern void Register_Dequantize(OperatorRegistry *op_registry);
 extern void Register_Eltwise(OperatorRegistry *op_registry);
 extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
 extern void Register_FullyConnected(OperatorRegistry *op_registry);
@@ -98,7 +99,9 @@ extern void Register_Pad(OperatorRegistry *op_registry);
 extern void Register_Pooling(OperatorRegistry *op_registry);
 extern void Register_Proposal(OperatorRegistry *op_registry);
 extern void Register_PSROIAlign(OperatorRegistry *op_registry);
+extern void Register_Quantize(OperatorRegistry *op_registry);
 extern void Register_ReOrganize(OperatorRegistry *op_registry);
+extern void Register_Requantize(OperatorRegistry *op_registry);
 extern void Register_Reshape(OperatorRegistry *op_registry);
 extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
 extern void Register_Slice(OperatorRegistry *op_registry);
@@ -124,6 +127,7 @@ OperatorRegistry::OperatorRegistry() {
   ops::Register_CWise(this);
   ops::Register_DepthToSpace(this);
   ops::Register_DepthwiseConv2d(this);
+  ops::Register_Dequantize(this);
   ops::Register_Eltwise(this);
   ops::Register_FoldedBatchNorm(this);
   ops::Register_FullyConnected(this);
@@ -136,6 +140,8 @@ OperatorRegistry::OperatorRegistry() {
   ops::Register_Pooling(this);
   ops::Register_Proposal(this);
   ops::Register_PSROIAlign(this);
+  ops::Register_Quantize(this);
+  ops::Register_Requantize(this);
   ops::Register_ReOrganize(this);
   ops::Register_Reshape(this);
   ops::Register_ResizeBilinear(this);
...
@@ -108,12 +108,25 @@ class Operator : public OperatorBase {
       inputs_.push_back(tensor);
     }
-    for (const std::string &output_str : operator_def.output()) {
+    for (size_t i = 0; i < operator_def.output().size(); ++i) {
+      const std::string output_str = operator_def.output()[i];
       if (ws->HasTensor(output_str)) {
         outputs_.push_back(ws->GetTensor(output_str));
       } else {
+        MACE_CHECK(
+            operator_def.output_type().size() == 0
+                || operator_def.output().size()
+                    == operator_def.output_type().size(),
+            "operator output size != operator output type size",
+            operator_def.output().size(),
+            operator_def.output_type().size());
+        DataType output_type;
+        if (i < operator_def.output_type().size()) {
+          output_type = operator_def.output_type()[i];
+        } else {
+          output_type = DataTypeToEnum<T>::v();
+        }
         outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
-            output_str, GetDeviceAllocator(D), DataTypeToEnum<T>::v())));
+            output_str, GetDeviceAllocator(D), output_type)));
       }
     }
   }
...
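The reworked constructor above resolves a data type per output: an OperatorDef's output_type list must either be empty or exactly match the number of outputs, and any output without an explicit entry falls back to the op's template type T. A minimal standalone sketch of that resolution rule (ResolveOutputType and the trimmed DataType enum here are illustrative, not MACE API):

#include <cassert>
#include <cstddef>
#include <vector>

enum class DataType { DT_FLOAT, DT_UINT8 };

// Hypothetical helper mirroring the rule in Operator's constructor:
// output_type must be empty or match the output count; outputs without
// an explicit entry fall back to the op's type.
DataType ResolveOutputType(const std::vector<DataType> &output_type,
                           std::size_t output_count,
                           std::size_t i,
                           DataType fallback) {
  assert(output_type.empty() || output_type.size() == output_count);
  return i < output_type.size() ? output_type[i] : fallback;
}

int main() {
  // A Quantize-like op: one uint8 data output plus two float min/max outputs.
  std::vector<DataType> types = {DataType::DT_UINT8, DataType::DT_FLOAT,
                                 DataType::DT_FLOAT};
  assert(ResolveOutputType(types, 3, 0, DataType::DT_FLOAT) ==
         DataType::DT_UINT8);
  // With no output_type list, every output gets the fallback type.
  assert(ResolveOutputType({}, 3, 1, DataType::DT_FLOAT) ==
         DataType::DT_FLOAT);
  return 0;
}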
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MACE_KERNELS_QUANTIZE_H_
#define MACE_KERNELS_QUANTIZE_H_

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

#include "mace/core/future.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {

template<typename T>
inline void AdjustRange(const float in_min_data,
                        const float in_max_data,
                        float *out_min_data,
                        float *out_max_data) {
  // Re-range so that the range contains 0.0f and 0.0f maps exactly to an
  // integer uint8 value.
  const float quantized_max = std::numeric_limits<uint8_t>::max();
  float out_min = fminf(0.f, in_min_data);
  float out_max = fmaxf(0.f, in_max_data);
  if (out_min < 0.f) {
    float stepsize = (in_max_data - in_min_data) / quantized_max;
    float quantized_zero = -in_min_data / stepsize;
    float quantized_zero_near_int = roundf(quantized_zero);
    if (fabs(quantized_zero - quantized_zero_near_int) > 1e-6) {
      if (quantized_zero < quantized_zero_near_int) {
        // Keep out_max fixed and move out_min.
        stepsize = out_max / (quantized_max - quantized_zero_near_int);
        out_min = out_max - quantized_max * stepsize;
      } else {
        // Keep out_min fixed and move out_max.
        stepsize = -out_min / quantized_zero_near_int;
        out_max = out_min + quantized_max * stepsize;
      }
    }
  }
  *out_min_data = out_min;
  *out_max_data = out_max;
}
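
// Worked example (cf. quantize_op_test.cc): for in_min = -3 and in_max = 5,
// stepsize = 8 / 255 ~= 0.031373 and quantized_zero = 3 / stepsize ~= 95.625,
// which rounds to 96. Since 95.625 < 96, out_max stays at 5 and out_min is
// recomputed: stepsize = 5 / (255 - 96) ~= 0.031446, giving
// out_min = 5 - 255 * 0.031446 ~= -3.01887.

// Saturate assumes its argument has already been rounded (callers pass the
// result of roundf) and clamps it to the representable range of T.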
template<typename T>
inline T Saturate(float value) {
  int rounded_value = static_cast<int>(value);
  if (rounded_value <= std::numeric_limits<T>::lowest()) {
    return std::numeric_limits<T>::lowest();
  } else if (rounded_value >= std::numeric_limits<T>::max()) {
    return std::numeric_limits<T>::max();
  } else {
    return static_cast<T>(rounded_value);
  }
}

template<DeviceType D, typename T>
struct QuantizeFunctor;

template<>
struct QuantizeFunctor<CPU, uint8_t> {
  QuantizeFunctor() {}

  void operator()(const Tensor *input,
                  const Tensor *in_min,
                  const Tensor *in_max,
                  Tensor *output,
                  Tensor *out_min,
                  Tensor *out_max,
                  StatsFuture *future) {
    const float *input_data = input->data<float>();
    const float in_min_data = in_min->data<float>()[0];
    const float in_max_data = in_max->data<float>()[0];
    uint8_t *output_data = output->mutable_data<uint8_t>();
    float *out_min_data = out_min->mutable_data<float>();
    float *out_max_data = out_max->mutable_data<float>();

    AdjustRange<uint8_t>(in_min_data, in_max_data, out_min_data, out_max_data);
    float recip_stepsize = 255.f / (out_max_data[0] - out_min_data[0]);
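    // Offset by in_min (so in_min quantizes to 0) and scale by the
    // reciprocal of the adjusted step size.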
    for (int i = 0; i < input->size(); ++i) {
      output_data[i] = Saturate<uint8_t>(roundf(
          (input_data[i] - in_min_data) * recip_stepsize));
    }
  }
};

template<DeviceType D, typename T>
struct DequantizeFunctor;

template<>
struct DequantizeFunctor<CPU, uint8_t> {
  DequantizeFunctor() {}

  void operator()(const Tensor *input,
                  const Tensor *in_min,
                  const Tensor *in_max,
                  Tensor *output,
                  StatsFuture *future) {
    const uint8_t *input_data = input->data<uint8_t>();
    const float in_min_data = in_min->data<float>()[0];
    const float in_max_data = in_max->data<float>()[0];
    float *output_data = output->mutable_data<float>();

    float stepsize = (in_max_data - in_min_data) / 255.0;
    for (int i = 0; i < input->size(); ++i) {
      output_data[i] = in_min_data + stepsize * input_data[i];
    }
  }
};

template<DeviceType D, typename T>
struct RequantizeFunctor;

template<>
struct RequantizeFunctor<CPU, uint8_t> {
  RequantizeFunctor() {}

  void operator()(const Tensor *input,
                  const Tensor *in_min,
                  const Tensor *in_max,
                  const Tensor *rerange_min,
                  const Tensor *rerange_max,
                  Tensor *output,
                  Tensor *out_min,
                  Tensor *out_max,
                  StatsFuture *future) {
    const int *input_data = input->data<int>();
    const float in_min_data = in_min->data<float>()[0];
    const float in_max_data = in_max->data<float>()[0];
    float rerange_min_data;
    float rerange_max_data;
    int min_val = std::numeric_limits<int>::max();
    int max_val = std::numeric_limits<int>::lowest();
    double si =
        (in_max_data - in_min_data) / std::numeric_limits<uint32_t>::max();
    if (rerange_min == nullptr && rerange_max == nullptr) {
      for (int i = 0; i < input->size(); ++i) {
        min_val = std::min(min_val, input_data[i]);
        max_val = std::max(max_val, input_data[i]);
      }
      rerange_min_data = min_val * si;
      rerange_max_data = max_val * si;
    } else {
      rerange_min_data = rerange_min->data<float>()[0];
      rerange_max_data = rerange_max->data<float>()[0];
    }
    uint8_t *output_data = output->mutable_data<uint8_t>();
    float *out_min_data = out_min->mutable_data<float>();
    float *out_max_data = out_max->mutable_data<float>();
    AdjustRange<uint8_t>(rerange_min_data,
                         rerange_max_data,
                         out_min_data,
                         out_max_data);
    /**
     * f = qi * si = min_o + qo * so
     *   => qo = (qi * si - min_o) / so
     *         = qi * (si / so) - min_o / so
     *         = qi * (si / so) + zo
     *
     * where zo = -min_o / so
     */
    float so = (out_max_data[0] - out_min_data[0])
        / std::numeric_limits<uint8_t>::max();
    double step_ratio = si / so;
    float quantized_out_zero = -out_min_data[0] / so;
    for (int i = 0; i < output->size(); ++i) {
      output_data[i] = Saturate<uint8_t>(roundf(
          quantized_out_zero + input_data[i] * step_ratio));
    }
  }
};
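
// Worked example (cf. TestRequantizeWithMinMax): with input range [-3, 5],
// si = 8 / (2^32 - 1) ~= 1.8626e-9. The rerange [-3.01887, 5] already
// satisfies AdjustRange, so so = 8.01887 / 255 ~= 0.031446 and
// quantized_out_zero = 3.01887 / so ~= 96. Input -1073741824 (-2^30) maps to
// round(96 + -1073741824 * (si / so)) = round(96 - 63.6) = 32.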

}  // namespace kernels
}  // namespace mace

#endif  // MACE_KERNELS_QUANTIZE_H_
@@ -429,7 +429,7 @@ TEST_F(BatchNormOpTest, NEONTest) {
   ExpectTensorNear<float>(*net.GetOutput("OutputExptected"),
                           *net.GetOutput("OutputNeon"),
-                          1e-5);
+                          1e-5, 1e-4);
 }
 }  // namespace test
...
@@ -52,6 +52,11 @@ class OpDefBuilder {
     return *this;
   }

+  OpDefBuilder &OutputType(const std::vector<DataType> &output_type) {
+    op_def_.set_output_type(output_type);
+    return *this;
+  }
+
   OpDefBuilder AddIntArg(const std::string &name, const int value) {
     auto arg = op_def_.add_arg();
     arg->set_name(name);
@@ -461,7 +466,7 @@ struct Expector<EXP_TYPE, RES_TYPE, true> {
     auto a = x.data<EXP_TYPE>();
     auto b = y.data<RES_TYPE>();
     for (int i = 0; i < x.size(); ++i) {
-      ExpectEqual(a(i), b(i));
+      ExpectEqual(a[i], b[i]);
     }
   }
@@ -499,12 +504,35 @@ struct Expector<EXP_TYPE, RES_TYPE, true> {
   }
 };

+template<typename EXP_TYPE, typename RES_TYPE>
+struct Expector<EXP_TYPE, RES_TYPE, false> {
+  static void Equal(const EXP_TYPE &a, const RES_TYPE &b) { ExpectEqual(a, b); }
+
+  static void Equal(const Tensor &x, const Tensor &y) {
+    ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
+    ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
+    AssertSameDims(x, y);
+    Tensor::MappingGuard x_mapper(&x);
+    Tensor::MappingGuard y_mapper(&y);
+    auto a = x.data<EXP_TYPE>();
+    auto b = y.data<RES_TYPE>();
+    for (int i = 0; i < x.size(); ++i) {
+      ExpectEqual(a[i], b[i]);
+    }
+  }
+
+  static void Near(const Tensor &x, const Tensor &y,
+                   const double rel_err,
+                   const double abs_err) {
+    Equal(x, y);
+  }
+};
+
 template<typename T>
 void ExpectTensorNear(const Tensor &x, const Tensor &y,
                       const double rel_err = 1e-5,
                       const double abs_err = 1e-8) {
+  static_assert(is_floating_point_type<T>::value,
+                "T is not a floating point type");
   Expector<T, T>::Near(x, y, rel_err, abs_err);
 }
@@ -512,9 +540,6 @@ template<typename EXP_TYPE, typename RES_TYPE>
 void ExpectTensorNear(const Tensor &x, const Tensor &y,
                       const double rel_err = 1e-5,
                       const double abs_err = 1e-8) {
-  static_assert(is_floating_point_type<EXP_TYPE>::value &&
-                    is_floating_point_type<RES_TYPE>::value,
-                "T is not a floating point type");
   Expector<EXP_TYPE, RES_TYPE>::Near(x, y, rel_err, abs_err);
 }
...
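With the new Expector<EXP_TYPE, RES_TYPE, false> specialization, Near on non-floating-point tensors degenerates to exact element-wise equality, which is what lets the quantize tests below call ExpectTensorNear<uint8_t>. A minimal sketch of the idea (standalone and simplified; the runtime branch here stands in for MACE's actual compile-time specialization):

#include <cassert>
#include <cstdint>
#include <type_traits>

// Floating-point types get a tolerance-based comparison; all other types
// an exact one, with the tolerance ignored.
template <typename T>
bool Near(T a, T b, double abs_err) {
  if (std::is_floating_point<T>::value) {
    double diff = static_cast<double>(a) - static_cast<double>(b);
    return (diff < 0 ? -diff : diff) <= abs_err;
  }
  return a == b;  // non-floating-point: exact equality
}

int main() {
  assert(Near<float>(1.0f, 1.0f + 1e-9f, 1e-8));
  assert(!Near<std::uint8_t>(32, 33, 100.0));  // off by one still fails
  return 0;
}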
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mace/ops/quantize.h"

namespace mace {
namespace ops {
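
// Note: the NEON keys below are registered with the CPU implementations;
// this commit adds no NEON-specific quantize kernels.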
void Register_Quantize(OperatorRegistry *op_registry) {
  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Quantize")
                        .Device(DeviceType::CPU)
                        .TypeConstraint<uint8_t>("T")
                        .Build(),
                    QuantizeOp<DeviceType::CPU, uint8_t>);

  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Quantize")
                        .Device(DeviceType::NEON)
                        .TypeConstraint<uint8_t>("T")
                        .Build(),
                    QuantizeOp<DeviceType::CPU, uint8_t>);
}

void Register_Dequantize(OperatorRegistry *op_registry) {
  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Dequantize")
                        .Device(DeviceType::CPU)
                        .TypeConstraint<uint8_t>("T")
                        .Build(),
                    DequantizeOp<DeviceType::CPU, uint8_t>);

  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Dequantize")
                        .Device(DeviceType::NEON)
                        .TypeConstraint<uint8_t>("T")
                        .Build(),
                    DequantizeOp<DeviceType::CPU, uint8_t>);
}

void Register_Requantize(OperatorRegistry *op_registry) {
  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Requantize")
                        .Device(DeviceType::CPU)
                        .TypeConstraint<uint8_t>("T")
                        .Build(),
                    RequantizeOp<DeviceType::CPU, uint8_t>);

  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Requantize")
                        .Device(DeviceType::NEON)
                        .TypeConstraint<uint8_t>("T")
                        .Build(),
                    RequantizeOp<DeviceType::CPU, uint8_t>);
}

}  // namespace ops
}  // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef MACE_OPS_QUANTIZE_H_
#define MACE_OPS_QUANTIZE_H_

#include "mace/core/operator.h"
#include "mace/kernels/quantize.h"
namespace mace {
namespace ops {

template<DeviceType D, class T>
class QuantizeOp : public Operator<D, T> {
 public:
  QuantizeOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws) {}

  bool Run(StatsFuture *future) override {
    const Tensor *input = this->Input(INPUT);
    const Tensor *in_min = this->Input(IN_MIN);
    const Tensor *in_max = this->Input(IN_MAX);
    MACE_CHECK(in_min->size() == 1, "min val tensor has more than 1 value");
    MACE_CHECK(in_max->size() == 1, "max val tensor has more than 1 value");

    Tensor *output = this->Output(OUTPUT);
    Tensor *out_min = this->Output(OUT_MIN);
    Tensor *out_max = this->Output(OUT_MAX);
    output->ResizeLike(input);
    out_min->ResizeLike(in_min);
    out_max->ResizeLike(in_max);

    functor_(input, in_min, in_max, output, out_min, out_max, future);
    return true;
  }

 private:
  kernels::QuantizeFunctor<D, T> functor_;

 protected:
  OP_INPUT_TAGS(INPUT, IN_MIN, IN_MAX);
  OP_OUTPUT_TAGS(OUTPUT, OUT_MIN, OUT_MAX);
};

template<DeviceType D, class T>
class DequantizeOp : public Operator<D, T> {
 public:
  DequantizeOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws) {}

  bool Run(StatsFuture *future) override {
    const Tensor *input = this->Input(INPUT);
    const Tensor *in_min = this->Input(IN_MIN);
    const Tensor *in_max = this->Input(IN_MAX);
    MACE_CHECK(in_min->size() == 1, "min val tensor has more than 1 value");
    MACE_CHECK(in_max->size() == 1, "max val tensor has more than 1 value");

    Tensor *output = this->Output(OUTPUT);
    output->ResizeLike(input);

    functor_(input, in_min, in_max, output, future);
    return true;
  }

 private:
  kernels::DequantizeFunctor<D, T> functor_;

 protected:
  OP_INPUT_TAGS(INPUT, IN_MIN, IN_MAX);
  OP_OUTPUT_TAGS(OUTPUT);
};

template<DeviceType D, class T>
class RequantizeOp : public Operator<D, T> {
 public:
  RequantizeOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws) {}

  bool Run(StatsFuture *future) override {
    const Tensor *input = this->Input(INPUT);
    const Tensor *in_min = this->Input(IN_MIN);
    const Tensor *in_max = this->Input(IN_MAX);
    const Tensor *rerange_min = nullptr;
    const Tensor *rerange_max = nullptr;
    MACE_CHECK(in_min->size() == 1, "min val tensor has more than 1 value");
    MACE_CHECK(in_max->size() == 1, "max val tensor has more than 1 value");
    if (this->InputSize() >= 5) {
      rerange_min = this->Input(RERANGE_MIN);
      rerange_max = this->Input(RERANGE_MAX);
      MACE_CHECK(rerange_min->size() == 1,
                 "rerange min val tensor has more than 1 value");
      MACE_CHECK(rerange_max->size() == 1,
                 "rerange max val tensor has more than 1 value");
    }

    Tensor *output = this->Output(OUTPUT);
    Tensor *out_min = this->Output(OUT_MIN);
    Tensor *out_max = this->Output(OUT_MAX);
    output->ResizeLike(input);
    out_min->ResizeLike(in_min);
    out_max->ResizeLike(in_max);

    functor_(input,
             in_min,
             in_max,
             rerange_min,
             rerange_max,
             output,
             out_min,
             out_max,
             future);
    return true;
  }

 private:
  kernels::RequantizeFunctor<D, T> functor_;

 protected:
  OP_INPUT_TAGS(INPUT, IN_MIN, IN_MAX, RERANGE_MIN, RERANGE_MAX);
  OP_OUTPUT_TAGS(OUTPUT, OUT_MIN, OUT_MAX);
};

}  // namespace ops
}  // namespace mace

#endif  // MACE_OPS_QUANTIZE_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"

namespace mace {
namespace ops {
namespace test {

class QuantizeTest : public OpsTestBase {};

TEST_F(QuantizeTest, TestQuantize) {
  // Construct graph
  OpsTestNet net;

  // Add input data
  net.AddInputFromArray<CPU, float>("Input", {1, 2, 3, 1},
                                    {-2, -1, 1, 2, 3, 4});
  net.AddInputFromArray<CPU, float>("InputMin", {1}, {-3});
  net.AddInputFromArray<CPU, float>("InputMax", {1}, {5});

  OpDefBuilder("Quantize", "QuantizeTest")
      .Input("Input")
      .Input("InputMin")
      .Input("InputMax")
      .Output("Output")
      .Output("OutputMin")
      .Output("OutputMax")
      .OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT})
      .AddIntArg("T", DT_UINT8)
      .Finalize(net.NewOperatorDef());

  // Run
  net.RunOp();

  auto output = net.GetTensor("Output");
  auto output_min = net.GetTensor("OutputMin");
  auto output_max = net.GetTensor("OutputMax");

  auto expected_output =
      CreateTensor<uint8_t>({1, 2, 3, 1}, {32, 64, 127, 159, 191, 223});
  auto expected_min = CreateTensor<float>({1}, {-3.01887});
  auto expected_max = CreateTensor<float>({1}, {5});
  ExpectTensorNear<uint8_t>(*expected_output, *output);
  ExpectTensorNear<float>(*expected_min, *output_min);
  ExpectTensorNear<float>(*expected_max, *output_max);
}

TEST_F(QuantizeTest, TestQuantizeTrend) {
  // Construct graph
  OpsTestNet net;

  // Add input data
  net.AddRandomInput<CPU, float>("Input", {100});
  const float *input_data = net.GetTensor("Input")->data<float>();
  net.AddInputFromArray<CPU, float>(
      "InputMin", {1},
      {*std::min_element(input_data,
                         input_data + net.GetTensor("Input")->size())});
  net.AddInputFromArray<CPU, float>(
      "InputMax", {1},
      {*std::max_element(input_data,
                         input_data + net.GetTensor("Input")->size())});

  OpDefBuilder("Quantize", "QuantizeTest")
      .Input("Input")
      .Input("InputMin")
      .Input("InputMax")
      .Output("Output")
      .Output("OutputMin")
      .Output("OutputMax")
      .OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT})
      .AddIntArg("T", DT_UINT8)
      .Finalize(net.NewOperatorDef());

  // Run
  net.RunOp();

  auto output = net.GetTensor("Output");
  auto output_min = net.GetTensor("OutputMin");
  auto output_max = net.GetTensor("OutputMax");
  const uint8_t *output_data = net.GetTensor("Output")->data<uint8_t>();
  for (int i = 1; i < output->size(); ++i) {
    if (input_data[i] > input_data[i - 1]) {
      EXPECT_GE(output_data[i], output_data[i - 1]);
    } else if (input_data[i] == input_data[i - 1]) {
      EXPECT_EQ(output_data[i], output_data[i - 1]);
    } else {
      EXPECT_LE(output_data[i], output_data[i - 1]);
    }
  }
}

TEST_F(QuantizeTest, TestDequantize) {
  // Construct graph
  OpsTestNet net;

  // Add input data
  net.AddInputFromArray<CPU, uint8_t>("Input", {1, 2, 3, 1},
                                      {32, 64, 127, 159, 191, 223});
  net.AddInputFromArray<CPU, float>("InputMin", {1}, {-3.01887});
  net.AddInputFromArray<CPU, float>("InputMax", {1}, {5});

  OpDefBuilder("Dequantize", "DequantizeTest")
      .Input("Input")
      .Input("InputMin")
      .Input("InputMax")
      .Output("Output")
      .OutputType({DT_FLOAT})
      .AddIntArg("T", DT_UINT8)
      .Finalize(net.NewOperatorDef());

  // Run
  net.RunOp();

  auto output = net.GetTensor("Output");
  auto expected_output =
      CreateTensor<float>({1, 2, 3, 1}, {-2, -1, 1, 2, 3, 4});
  auto expected_min = CreateTensor<float>({1}, {-3.01887});
  auto expected_max = CreateTensor<float>({1}, {5});
  ExpectTensorNear<float>(*expected_output, *output, 0.1, 0.01);
}

TEST_F(QuantizeTest, TestRequantizeWithMinMax) {
  // Construct graph
  OpsTestNet net;

  // Add input data
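  // The int32 inputs are -2^30, -2^29, 2^29, 2^30, 3 * 2^29 and 2^31 - 1,
  // i.e. quarter-steps across the int32 range.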
net.AddInputFromArray<CPU, int>("Input", {1, 2, 3, 1}, {
-1073741824, -536870912, 536870912, 1073741824, 1610612736, 2147483647
});
net.AddInputFromArray<CPU, float>("InputMin", {1}, {-3});
net.AddInputFromArray<CPU, float>("InputMax", {1}, {5});
net.AddInputFromArray<CPU, float>("RerangeMin", {1}, {-3.01887});
net.AddInputFromArray<CPU, float>("RerangeMax", {1}, {5});
OpDefBuilder("Requantize", "RequantizeTest")
.Input("Input")
.Input("InputMin")
.Input("InputMax")
.Input("RerangeMin")
.Input("RerangeMax")
.Output("Output")
.Output("OutputMin")
.Output("OutputMax")
.OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT})
.AddIntArg("T", DT_UINT8)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
auto output = net.GetTensor("Output");
auto expected_output = CreateTensor<uint8_t>({1, 2, 3, 1},
{
32, 64, 128, 160, 191, 223
});
auto expected_min = CreateTensor<float>({1}, {-3.01887});
auto expected_max = CreateTensor<float>({1}, {5});
ExpectTensorNear<uint8_t>(*expected_output, *output);
}

TEST_F(QuantizeTest, TestRequantizeWithoutMinMax) {
  // Construct graph
  OpsTestNet net;

  // Add input data
  net.AddInputFromArray<CPU, int>(
      "Input", {1, 2, 3, 1},
      {-1073741824, -536870912, 536870912, 1073741824, 1610612736,
       2147483647});
  net.AddInputFromArray<CPU, float>("InputMin", {1}, {-3});
  net.AddInputFromArray<CPU, float>("InputMax", {1}, {5});

  OpDefBuilder("Requantize", "RequantizeTest")
      .Input("Input")
      .Input("InputMin")
      .Input("InputMax")
      .Output("Output")
      .Output("OutputMin")
      .Output("OutputMax")
      .OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT})
      .AddIntArg("T", DT_UINT8)
      .Finalize(net.NewOperatorDef());

  // Run
  net.RunOp();

  auto output = net.GetTensor("Output");
  auto expected_output =
      CreateTensor<uint8_t>({1, 2, 3, 1}, {0, 43, 128, 170, 213, 255});
  auto expected_min = CreateTensor<float>({1}, {-3.01887});
  auto expected_max = CreateTensor<float>({1}, {5});
  ExpectTensorNear<uint8_t>(*expected_output, *output);
}

}  // namespace test
}  // namespace ops
}  // namespace mace