diff --git a/WORKSPACE b/WORKSPACE index 1899e0437f5c97216a04a661bbaf58d1eeb8aeeb..325cd25668c3dede8cd78366573183202d79a332 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -85,6 +85,15 @@ http_archive( ], ) +http_archive( + name = "tflite", + sha256 = "c886d46ad8c91fcafed2d910ad9e7bc5aeb29856c387bdf9b6b4903cc16e6e60", + strip_prefix = "tensorflow-mace-ffc8cc7e8c9d1894753509e88b17e251bc6255e3", + urls = [ + "https://cnbj1.fds.api.xiaomi.com/mace/third-party/tflite/tensorflow-mace-ffc8cc7e8c9d1894753509e88b17e251bc6255e3.zip", + ], +) + new_http_archive( name = "six_archive", build_file = "third_party/six/six.BUILD", diff --git a/mace/kernels/BUILD b/mace/kernels/BUILD index 1035b54bed1b69c2da05b4c3f2c52d6222bac95f..5706f94e8357f7a835bbfdea0601a8a6d132e785 100644 --- a/mace/kernels/BUILD +++ b/mace/kernels/BUILD @@ -73,6 +73,7 @@ cc_library( "//mace/core", "//mace/utils", "@gemmlowp", + "@tflite", ], ) diff --git a/mace/kernels/conv_2d.h b/mace/kernels/conv_2d.h index 4e472b7fd4169a53de41b76150499db70c9a2c3f..53ccb310fdf184f6d590b3c9712051bd9bf640ec 100644 --- a/mace/kernels/conv_2d.h +++ b/mace/kernels/conv_2d.h @@ -260,11 +260,11 @@ struct Conv2dFunctor : Conv2dFunctorBase { } // b } - MaceStatus operator()(const Tensor *input, - const Tensor *filter, - const Tensor *bias, - Tensor *output, - StatsFuture *future) { + MaceStatus operator()(const Tensor *input, // NCHW + const Tensor *filter, // OIHW + const Tensor *bias, + Tensor *output, // NCHW + StatsFuture *future) { MACE_UNUSED(future); MACE_CHECK_NOTNULL(input); MACE_CHECK_NOTNULL(filter); @@ -822,18 +822,6 @@ struct Conv2dFunctor : Conv2dFunctorBase { } } - inline void GetOutputMultiplierAndShift( - const float lhs_scale, const float rhs_scale, const float output_scale, - int32_t *quantized_multiplier, int *right_shift) { - float real_multiplier = lhs_scale * rhs_scale / output_scale; - MACE_CHECK(real_multiplier > 0.f && real_multiplier < 1.f, real_multiplier); - - int exponent; - QuantizeMultiplier(real_multiplier, quantized_multiplier, &exponent); - *right_shift = -exponent; - MACE_CHECK(*right_shift >= 0); - } - typedef gemmlowp::VectorMap ColVectorMap; typedef std::tuple< diff --git a/mace/kernels/depthwise_conv2d.h b/mace/kernels/depthwise_conv2d.h index 14c83042f01bd5ab9f12de519c207b20daab68d2..9304b14f711f184616d42228cca0713b487f7511 100644 --- a/mace/kernels/depthwise_conv2d.h +++ b/mace/kernels/depthwise_conv2d.h @@ -22,10 +22,12 @@ #include #include +#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h" #include "mace/core/future.h" #include "mace/kernels/conv_pool_2d_util.h" #include "mace/kernels/activation.h" #include "mace/kernels/arm/depthwise_conv2d_neon.h" +#include "mace/kernels/quantize.h" #include "mace/public/mace.h" #ifdef MACE_ENABLE_OPENCL @@ -113,8 +115,7 @@ struct DepthwiseConv2dFunctor ((b * in_channels + c) * in_height + ih) * in_width + iw; index_t filter_offset = (((o * in_channels) + c) * filter_height + kh) - * filter_width - + kw; + * filter_width + kw; sum += input[in_offset] * filter[filter_offset]; } @@ -127,11 +128,11 @@ struct DepthwiseConv2dFunctor } } - MaceStatus operator()(const Tensor *input, - const Tensor *filter, - const Tensor *bias, - Tensor *output, - StatsFuture *future) { + MaceStatus operator()(const Tensor *input, // NCHW + const Tensor *filter, // OIHW + const Tensor *bias, + Tensor *output, // NCHW + StatsFuture *future) { MACE_UNUSED(future); MACE_CHECK_NOTNULL(input); MACE_CHECK_NOTNULL(filter); @@ -284,6 +285,212 @@ struct DepthwiseConv2dFunctor } }; 
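Editor's note: the uint8_t specialization that follows uses the usual affine-quantization convention, real_value = scale * (quantized - zero_point). As a rough, illustrative sketch of the per-output arithmetic that the new DepthwiseConv2dGeneral fallback performs (the tflite fast path does the same rescaling with a fixed-point multiplier from GetOutputMultiplierAndShift instead of a float multiply), here is a minimal standalone example; QuantizedDepthwiseOutput and its parameter names are hypothetical and not part of the MACE API.

// Illustrative sketch only (hypothetical helper): requantize one depthwise-conv
// output under affine quantization q = round(r / scale) + zero_point.
#include <algorithm>
#include <cmath>
#include <cstdint>

inline uint8_t QuantizedDepthwiseOutput(int32_t acc,          // sum of (input - input_zero) * (filter - filter_zero)
                                        int32_t bias,         // bias quantized with scale input_scale * filter_scale, zero point 0
                                        float input_scale,
                                        float filter_scale,
                                        float output_scale,
                                        int32_t output_zero) {
  // Rescale the int32 accumulator into output units, then re-add the output
  // zero point and saturate to the uint8 range, as DepthwiseConv2dGeneral does.
  const float output_multiplier = input_scale * filter_scale / output_scale;
  int32_t out = static_cast<int32_t>(std::round((acc + bias) * output_multiplier));
  out += output_zero;
  return static_cast<uint8_t>(std::min(255, std::max(0, out)));
}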
+template<>
+struct DepthwiseConv2dFunctor<DeviceType::CPU, uint8_t>
+    : public DepthwiseConv2dFunctorBase {
+  DepthwiseConv2dFunctor(const int *strides,
+                         const Padding padding_type,
+                         const std::vector<int> &paddings,
+                         const int *dilations,
+                         const ActivationType activation,
+                         const float relux_max_limit)
+      : DepthwiseConv2dFunctorBase(strides,
+                                   padding_type,
+                                   paddings,
+                                   dilations,
+                                   activation,
+                                   relux_max_limit) {}
+
+  void DepthwiseConv2dGeneral(const uint8_t *input,
+                              const uint8_t *filter,
+                              const int32_t *bias,
+                              const index_t *in_shape,
+                              const index_t *out_shape,
+                              const index_t *filter_shape,
+                              const int32_t input_zero,
+                              const int32_t filter_zero,
+                              const int32_t output_zero,
+                              const float output_multiplier,
+                              const int *stride_hw,
+                              const int *dilation_hw,
+                              const int *pad_hw,
+                              uint8_t *output) {
+#pragma omp parallel for collapse(2)
+    for (index_t b = 0; b < out_shape[0]; ++b) {
+      for (index_t h = 0; h < out_shape[1]; ++h) {
+        for (index_t w = 0; w < out_shape[2]; ++w) {
+          for (index_t m = 0; m < out_shape[3]; ++m) {
+            const index_t filter_height = filter_shape[0];
+            const index_t filter_width = filter_shape[1];
+            const index_t in_channels = filter_shape[2];
+            const index_t depth_multiplier = filter_shape[3];
+            const index_t in_height = in_shape[1];
+            const index_t in_width = in_shape[2];
+            const index_t out_height = out_shape[1];
+            const index_t out_width = out_shape[2];
+            const index_t out_channels = out_shape[3];
+            index_t out_offset =
+                ((b * out_height + h) * out_width + w) * out_channels + m;
+            index_t c = m / depth_multiplier;
+            index_t o = m % depth_multiplier;
+            index_t ih_base = h * stride_hw[0] - pad_hw[0];
+            index_t iw_base = w * stride_hw[1] - pad_hw[1];
+            int32_t sum = 0;
+            for (index_t kh = 0; kh < filter_height; ++kh) {
+              const index_t ih = ih_base + kh * dilation_hw[0];
+              for (index_t kw = 0; kw < filter_width; ++kw) {
+                const index_t iw = iw_base + kw * dilation_hw[1];
+                if (ih >= 0 && ih < in_height && iw >= 0 && iw < in_width) {
+                  index_t in_offset =
+                      ((b * in_height + ih) * in_width + iw) * in_channels + c;
+                  index_t filter_offset =
+                      ((kh * filter_width + kw) * in_channels + c)
+                          * depth_multiplier + o;
+
+                  sum += (input[in_offset] - input_zero) *
+                      (filter[filter_offset] - filter_zero);
+                }
+              }
+            }
+            if (bias) {
+              sum += bias[m];
+            }
+            sum = static_cast<int32_t>(std::round(sum * output_multiplier));
+            sum += output_zero;
+            output[out_offset] =
+                static_cast<uint8_t>(std::min(255, std::max(0, sum)));
+          }
+        }
+      }
+    }
+  }
+
+  inline tflite::Dims<4> ShapeToTfliteDims(const std::vector<index_t> &shape) {
+    tflite::Dims<4> d;
+    for (int i = 0; i < 4; ++i) {
+      int src = static_cast<int>(shape.size() - i - 1);
+      if (src >= 0) {
+        d.sizes[i] = shape[src];
+      } else {
+        d.sizes[i] = 1;
+      }
+    }
+    d.strides[0] = 1;
+    for (int i = 1; i < 4; i++) {
+      d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
+    }
+    return d;
+  }
+
+  MaceStatus operator()(const Tensor *input,   // NHWC
+                        const Tensor *filter,  // HWIM
+                        const Tensor *bias,
+                        Tensor *output,        // NHWC
+                        StatsFuture *future) {
+    MACE_UNUSED(future);
+    MACE_CHECK_NOTNULL(input);
+    MACE_CHECK_NOTNULL(filter);
+    MACE_CHECK_NOTNULL(output);
+
+    std::vector<index_t> output_shape(4);
+    std::vector<int> paddings(2);
+
+    // reuse OHWI format, only for calculating output
+    std::vector<index_t> ohwi_shape{
+        filter->dim(2) * filter->dim(3), filter->dim(0), filter->dim(1), 1};
+    if (paddings_.empty()) {
+      CalcPaddingAndOutputSize(input->shape().data(),
+                               NHWC,
+                               ohwi_shape.data(),
+                               OHWI,
+                               dilations_,
+                               strides_,
+                               padding_type_,
+                               output_shape.data(),
+                               paddings.data());
+    } else {
+      paddings = paddings_;
+      CalcOutputSize(input->shape().data(),
+                     NHWC,
+                     ohwi_shape.data(),
+                     OHWI,
+                     paddings_.data(),
+                     dilations_,
+                     strides_,
+                     RoundType::FLOOR,
+                     output_shape.data());
+    }
+
+    MACE_RETURN_IF_ERROR(output->Resize(output_shape));
+    output->Clear();
+
+    MACE_CHECK(output->dim(0) == input->dim(0),
+               "Input/Output batch size mismatch");
+    MACE_CHECK(filter->dim(2) == input->dim(3), filter->dim(2), " != ",
+               input->dim(3));
+
+    index_t out_channels = output_shape[3];
+    index_t stride_h = strides_[0];
+    index_t stride_w = strides_[1];
+    index_t dilation_h = dilations_[0];
+    index_t dilation_w = dilations_[1];
+    int pad_top = paddings[0] >> 1;
+    int pad_left = paddings[1] >> 1;
+
+    Tensor::MappingGuard input_guard(input);
+    Tensor::MappingGuard filter_guard(filter);
+    Tensor::MappingGuard bias_guard(bias);
+    Tensor::MappingGuard output_guard(output);
+    auto input_data = input->data<uint8_t>();
+    auto filter_data = filter->data<uint8_t>();
+    auto output_data = output->mutable_data<uint8_t>();
+
+    if (dilation_h == 1 && dilation_w == 1) {
+      std::vector<index_t> bias_shape{out_channels};
+      std::unique_ptr<Tensor> zero_bias;
+      const int32_t *bias_data = nullptr;
+      if (bias == nullptr) {
+        zero_bias.reset(
+            new Tensor(GetDeviceAllocator(DeviceType::CPU), DT_INT32));
+        zero_bias->Resize(bias_shape);
+        zero_bias->Clear();
+        bias_data = zero_bias->data<int32_t>();
+      } else {
+        bias_data = bias->data<int32_t>();
+      }
+
+      int32_t quantized_multiplier;
+      int32_t right_shift;
+      GetOutputMultiplierAndShift(input->scale(), filter->scale(),
+                                  output->scale(), &quantized_multiplier,
+                                  &right_shift);
+      // 1HWO
+      std::vector<index_t> filter_shape{
+          1, filter->dim(0), filter->dim(1), filter->dim(2) * filter->dim(3)};
+
+      tflite::optimized_ops::DepthwiseConv(
+          input_data, ShapeToTfliteDims(input->shape()), -input->zero_point(),
+          filter_data, ShapeToTfliteDims(filter_shape), -filter->zero_point(),
+          bias_data, ShapeToTfliteDims(bias_shape), stride_w, stride_h,
+          pad_left, pad_top, filter->dim(3), output->zero_point(),
+          quantized_multiplier, right_shift, 0, 255, output_data,
+          ShapeToTfliteDims(output->shape()));
+    } else {
+      auto bias_data = bias == nullptr ?
nullptr : bias->data(); + float output_multiplier = + input->scale() * filter->scale() / output->scale(); + const int pad_hw[2] = {pad_top, pad_left}; + DepthwiseConv2dGeneral( + input_data, filter_data, bias_data, input->shape().data(), + output_shape.data(), filter->shape().data(), input->zero_point(), + filter->zero_point(), output->zero_point(), output_multiplier, + strides_, dilations_, pad_hw, output_data); + } + + return MACE_SUCCESS; + } +}; + #ifdef MACE_ENABLE_OPENCL template struct DepthwiseConv2dFunctor diff --git a/mace/kernels/quantize.h b/mace/kernels/quantize.h index 3dc4c48ee94416fe51e4055f8aa955f6644271b6..7030d79fd788048893ae8ef65f467fe55bdf3fcd 100644 --- a/mace/kernels/quantize.h +++ b/mace/kernels/quantize.h @@ -155,6 +155,18 @@ inline void QuantizeMultiplier(double multiplier, MACE_CHECK(*output_multiplier <= std::numeric_limits::max()); } +inline void GetOutputMultiplierAndShift( + const float lhs_scale, const float rhs_scale, const float output_scale, + int32_t *quantized_multiplier, int *right_shift) { + float real_multiplier = lhs_scale * rhs_scale / output_scale; + MACE_CHECK(real_multiplier > 0.f && real_multiplier < 1.f, real_multiplier); + + int exponent; + QuantizeMultiplier(real_multiplier, quantized_multiplier, &exponent); + *right_shift = -exponent; + MACE_CHECK(*right_shift >= 0); +} + template struct QuantizeFunctor; diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index 6616b8badcbf3a1496810fc4f36daab77624e4a2..313cd35bbcc4253c029535666cfba30b3ba1fdd2 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -122,14 +122,10 @@ void Conv2d(int iters, // Add input data net.AddRandomInput( "Input", {batch, height, width, channels}); - Tensor *input = net.GetTensor("Input"); - input->SetScale(0.00705); - input->SetZeroPoint(114); + net.GetTensor("Input")->SetScale(0.1); net.AddRandomInput( "Filter", {output_channels, kernel_h, kernel_w, channels}); - Tensor *filter = net.GetTensor("Filter"); - filter->SetScale(0.0066); - filter->SetZeroPoint(113); + net.GetTensor("Filter")->SetScale(0.1); net.AddRandomInput("Bias", {output_channels}); OpDefBuilder("Conv2D", "Conv2dTest") .Input("Input") @@ -144,9 +140,7 @@ void Conv2d(int iters, net.Setup(DeviceType::CPU); - Tensor *output = net.GetTensor("Output"); - output->SetScale(0.0107); - output->SetZeroPoint(118); + net.GetTensor("Output")->SetScale(0.1); // Warm-up for (int i = 0; i < 2; ++i) { diff --git a/mace/ops/depthwise_conv2d.cc b/mace/ops/depthwise_conv2d.cc index 66396f6002b80219280b98015910eedab51ef0a6..61f87e5f98be2e5e7466e0f1ad5c16608f52a73b 100644 --- a/mace/ops/depthwise_conv2d.cc +++ b/mace/ops/depthwise_conv2d.cc @@ -24,6 +24,12 @@ void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry) { .Build(), DepthwiseConv2dOp); + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthwiseConv2d") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + DepthwiseConv2dOp); + #ifdef MACE_ENABLE_OPENCL MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthwiseConv2d") .Device(DeviceType::GPU) diff --git a/mace/ops/depthwise_conv2d_benchmark.cc b/mace/ops/depthwise_conv2d_benchmark.cc index ea847fd103f6c5af97ee7cea7a2cac7762c2a3c7..60abfaf3bb99b3dc62f008983de7bec110b618a2 100644 --- a/mace/ops/depthwise_conv2d_benchmark.cc +++ b/mace/ops/depthwise_conv2d_benchmark.cc @@ -41,17 +41,32 @@ void DepthwiseConv2d(int iters, // Add input data if (D == DeviceType::CPU) { - net.AddRandomInput("Input", - {batch, input_channels, height, width}); + if 
(DataTypeToEnum::value != DT_UINT8) { + net.AddRandomInput( + "Input", {batch, input_channels, height, width}); + } else { + net.AddRandomInput( + "Input", {batch, height, width, input_channels}); + net.GetTensor("Input")->SetScale(0.1); + } + } else if (D == DeviceType::GPU) { - net.AddRandomInput("Input", - {batch, height, width, input_channels}); + net.AddRandomInput( + "Input", {batch, height, width, input_channels}); } else { MACE_NOT_IMPLEMENTED; } - net.AddRandomInput( - "Filter", {multiplier, input_channels, kernel_h, kernel_w}); - net.AddRandomInput("Bias", {input_channels * multiplier}); + if (DataTypeToEnum::value != DT_UINT8) { + net.AddRandomInput( + "Filter", {multiplier, input_channels, kernel_h, kernel_w}); + net.AddRandomInput("Bias", {input_channels * multiplier}); + } else { + net.AddRandomInput( + "Filter", {kernel_h, kernel_w, input_channels, multiplier}); + net.GetTensor("Filter")->SetScale(0.1); + net.AddRandomInput( + "Bias", {input_channels * multiplier}); + } if (D == DeviceType::CPU) { OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest") @@ -87,6 +102,10 @@ void DepthwiseConv2d(int iters, net.Setup(D); + if (DataTypeToEnum::value == DT_UINT8) { + net.GetTensor("Output")->SetScale(0.1); + } + // Warm-up for (int i = 0; i < 2; ++i) { net.Run(); @@ -132,7 +151,8 @@ void DepthwiseConv2d(int iters, #define MACE_BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, M) \ MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, CPU); \ MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, float, GPU); \ - MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, half, GPU); + MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, half, GPU); \ + MACE_BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, M, uint8_t, CPU); MACE_BM_DEPTHWISE_CONV_2D(1, 32, 112, 112, 3, 3, 1, SAME, 1); MACE_BM_DEPTHWISE_CONV_2D(1, 32, 56, 56, 3, 3, 2, VALID, 1); @@ -156,7 +176,14 @@ MACE_BM_DEPTHWISE_CONV_2D(1, 3, 512, 512, 3, 3, 2, SAME, 1); MACE_BM_DEPTHWISE_CONV_2D(1, 3, 112, 112, 3, 3, 2, VALID, 1); MACE_BM_DEPTHWISE_CONV_2D(1, 3, 224, 224, 3, 3, 2, SAME, 1); MACE_BM_DEPTHWISE_CONV_2D(1, 8, 224, 224, 3, 3, 2, SAME, 1); - +MACE_BM_DEPTHWISE_CONV_2D(1, 128, 56, 56, 3, 3, 1, SAME, 1); +MACE_BM_DEPTHWISE_CONV_2D(1, 128, 56, 56, 3, 3, 2, SAME, 1); +MACE_BM_DEPTHWISE_CONV_2D(1, 256, 28, 28, 3, 3, 1, SAME, 1); +MACE_BM_DEPTHWISE_CONV_2D(1, 256, 28, 28, 3, 3, 2, SAME, 1); +MACE_BM_DEPTHWISE_CONV_2D(1, 512, 14, 14, 3, 3, 1, SAME, 1); +MACE_BM_DEPTHWISE_CONV_2D(1, 512, 14, 14, 3, 3, 2, SAME, 1); +MACE_BM_DEPTHWISE_CONV_2D(1, 1024, 7, 7, 3, 3, 1, SAME, 1); +MACE_BM_DEPTHWISE_CONV_2D(1, 1024, 7, 7, 3, 3, 2, SAME, 1); } // namespace test } // namespace ops diff --git a/mace/ops/depthwise_conv2d_test.cc b/mace/ops/depthwise_conv2d_test.cc index 254fb4c2cecfcbc3da1a55f8fd67974c4544e6e5..0da8041d9b69a71a3148aba97ec748217398517e 100644 --- a/mace/ops/depthwise_conv2d_test.cc +++ b/mace/ops/depthwise_conv2d_test.cc @@ -351,6 +351,164 @@ TEST_F(DepthwiseConv2dOpTest, OpenCLUnalignedNxNS12Half) { TestNxNS12(107, 113); } +namespace { + +void QuantSimpleValidTest() { + testing::internal::LogToStderr(); + // Construct graph + OpsTestNet net; + + // Add input data + net.AddInputFromArray( + "Input", {1, 3, 3, 2}, + {31, 98, 1, 54, 197, 172, 70, 146, 255, 71, 24, 182, 28, 78, 85, 96, 180, + 59}, 0.00735299, 86); + net.AddInputFromArray( + "Filter", {3, 3, 2, 1}, + {212, 239, 110, 170, 216, 91, 162, 161, 255, 2, 10, 120, 183, 101, 100, + 33, 137, 51}, 0.0137587, 120); + 
net.AddInputFromArray("Bias", {2}, {2, 2}); + OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") + .Input("Input") + .Input("Filter") + .Input("Bias") + .Output("Output") + .AddIntsArg("strides", {1, 1}) + .AddIntArg("padding", Padding::VALID) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", static_cast(DT_UINT8)) + .Finalize(net.NewOperatorDef()); + + net.Setup(CPU); + Tensor *output = net.GetTensor("Output"); + output->SetScale(0.013241); + output->SetZeroPoint(0); + // Run + net.Run(); + + // Check + auto expected = CreateTensor({1, 1, 1, 2}, {255, 21}); + + ExpectTensorNear(*expected, *net.GetOutput("Output")); +} + +void TestQuant(const index_t batch, + const index_t multiplier, + const index_t in_channels, + const index_t in_height, + const index_t in_width, + const index_t k_height, + const index_t k_width, + enum Padding padding_type, + const std::vector &strides) { + OpsTestNet net; + const index_t out_channels = multiplier * in_channels; + net.AddRandomInput( + "Input", {batch, in_height, in_width, in_channels}, false); + net.AddRandomInput( + "Filter", {k_height, k_width, in_channels, multiplier}, false); + net.AddRandomInput("Bias", {out_channels}); + net.TransformDataFormat( + "Input", NHWC, "InputNCHW", NCHW); + net.TransformDataFormat( + "Filter", HWIO, "FilterOIHW", OIHW); + + OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") + .Input("InputNCHW") + .Input("FilterOIHW") + .Input("Bias") + .Output("OutputNCHW") + .AddIntsArg("strides", strides) + .AddIntArg("padding", padding_type) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", static_cast(DT_FLOAT)) + .Finalize(net.NewOperatorDef()); + net.RunOp(CPU); + net.TransformDataFormat( + "OutputNCHW", NCHW, "Output", NHWC); + + OpDefBuilder("Quantize", "QuantizeFilter") + .Input("Filter") + .Output("QuantizedFilter") + .OutputType({DT_UINT8}) + .AddIntArg("T", DT_UINT8) + .AddIntArg("non_zero", true) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + OpDefBuilder("Quantize", "QuantizeInput") + .Input("Input") + .Output("QuantizedInput") + .OutputType({DT_UINT8}) + .AddIntArg("T", DT_UINT8) + .AddIntArg("non_zero", true) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + OpDefBuilder("Quantize", "QuantizeOutput") + .Input("Output") + .Output("ExpectedQuantizedOutput") + .OutputType({DT_UINT8}) + .AddIntArg("T", DT_UINT8) + .AddIntArg("non_zero", true) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + Tensor *q_filter = net.GetTensor("QuantizedFilter"); + Tensor *q_input = net.GetTensor("QuantizedInput"); + Tensor *bias = net.GetTensor("Bias"); + auto bias_data = bias->data(); + std::vector q_bias(bias->size()); + kernels::QuantizeWithScaleAndZeropoint( + bias_data, bias->size(), q_input->scale() * q_filter->scale(), 0, + q_bias.data()); + net.AddInputFromArray( + "QuantizedBias", {out_channels}, q_bias); + OpDefBuilder("DepthwiseConv2d", "QuantizedDepthwiseConv2DTest") + .Input("QuantizedInput") + .Input("QuantizedFilter") + .Input("QuantizedBias") + .Output("QuantizedOutput") + .AddIntsArg("strides", strides) + .AddIntArg("padding", padding_type) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", static_cast(DT_UINT8)) + .Finalize(net.NewOperatorDef()); + net.Setup(DeviceType::CPU); + Tensor *eq_output = net.GetTensor("ExpectedQuantizedOutput"); + Tensor *q_output = net.GetTensor("QuantizedOutput"); + q_output->SetScale(eq_output->scale()); + q_output->SetZeroPoint(eq_output->zero_point()); + net.Run(); + + OpDefBuilder("Dequantize", "DeQuantizeTest") + .Input("QuantizedOutput") + 
.Output("DequantizedOutput") + .OutputType({DT_FLOAT}) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + net.RunOp(); + + // Check + ExpectTensorSimilar(*net.GetOutput("Output"), + *net.GetTensor("DequantizedOutput"), 0.01); +} +} // namespace + +TEST_F(DepthwiseConv2dOpTest, Quant) { + QuantSimpleValidTest(); + TestQuant(1, 1, 2, 3, 3, 3, 3, VALID, {1, 1}); + TestQuant(1, 1, 2, 3, 3, 3, 3, SAME, {1, 1}); + TestQuant(1, 1, 2, 3, 3, 3, 3, FULL, {1, 1}); + TestQuant(1, 2, 2, 3, 3, 3, 3, SAME, {1, 1}); + TestQuant(1, 2, 2, 3, 3, 3, 3, SAME, {2, 2}); + TestQuant(1, 1, 512, 14, 14, 3, 3, SAME, {1, 1}); + TestQuant(1, 1, 512, 14, 13, 5, 5, SAME, {2, 2}); + TestQuant(1, 1, 256, 28, 28, 3, 3, SAME, {1, 1}); + TestQuant(1, 1, 128, 56, 56, 3, 3, SAME, {2, 2}); + TestQuant(3, 1, 128, 56, 56, 3, 3, SAME, {2, 2}); +} + } // namespace test } // namespace ops } // namespace mace diff --git a/third_party/tflite/LICENSE b/third_party/tflite/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4862420c0234f7542d4fe8f3520516b484a64aed --- /dev/null +++ b/third_party/tflite/LICENSE @@ -0,0 +1,203 @@ +Copyright 2018 The TensorFlow Authors. All rights reserved. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2017, The TensorFlow Authors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.
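Editor's note on the fixed-point fast path above: GetOutputMultiplierAndShift (moved from conv_2d.h into mace/kernels/quantize.h by this change) hands tflite::optimized_ops::DepthwiseConv a 31-bit fixed-point multiplier plus a right shift in place of the float output_multiplier used by the general fallback. Below is a minimal sketch of that decomposition, assuming the standard frexp-based scheme that QuantizeMultiplier implements; DecomposeMultiplier and the printed float check are illustrative only, and the example scales are taken from QuantSimpleValidTest above.

// Sketch (hypothetical helper): represent real_multiplier in (0, 1) as
// quantized_multiplier * 2^-31 * 2^-right_shift for integer requantization.
#include <cmath>
#include <cstdint>
#include <cstdio>

void DecomposeMultiplier(float real_multiplier,
                         int32_t *quantized_multiplier, int *right_shift) {
  int exponent = 0;
  // frexp returns a significand in [0.5, 1) and the matching power-of-two exponent.
  const double significand = std::frexp(real_multiplier, &exponent);
  int64_t fixed = static_cast<int64_t>(std::round(significand * (1ll << 31)));
  if (fixed == (1ll << 31)) {  // significand rounded up to exactly 1.0
    fixed /= 2;
    ++exponent;
  }
  *quantized_multiplier = static_cast<int32_t>(fixed);
  *right_shift = -exponent;  // exponent <= 0 for multipliers below 1
}

int main() {
  // Scales from QuantSimpleValidTest: input 0.00735299, filter 0.0137587,
  // output 0.013241, so real_multiplier is roughly 0.0076.
  int32_t m = 0;
  int s = 0;
  DecomposeMultiplier(0.00735299f * 0.0137587f / 0.013241f, &m, &s);
  std::printf("multiplier=%d right_shift=%d approx=%g\n",
              static_cast<int>(m), s, m / 2147483648.0 / (1 << s));
  return 0;
}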