From e2b1d7abc9f71c027f28ba19ec6d34bfcc58f811 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 20 Oct 2016 09:14:33 -0800 Subject: [PATCH] Add Requantize op. This is a step in splitting QuantizeDownAndShrinkRange into separate ops for finding the min/max and doing the requantization step. Change: 136727927 --- tensorflow/contrib/makefile/tf_op_files.txt | 1 + tensorflow/core/kernels/BUILD | 21 +++- tensorflow/core/kernels/requantize.cc | 92 ++++++++++++++++++ tensorflow/core/kernels/requantize_op_test.cc | 97 +++++++++++++++++++ tensorflow/core/ops/math_ops.cc | 42 ++++++++ tensorflow/core/ops/math_ops_test.cc | 13 +++ 6 files changed, 263 insertions(+), 3 deletions(-) create mode 100644 tensorflow/core/kernels/requantize.cc create mode 100644 tensorflow/core/kernels/requantize_op_test.cc diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index f77b98e724b..1573c2d356e 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -152,6 +152,7 @@ tensorflow/core/kernels/quantized_concat_op.cc tensorflow/core/kernels/quantized_conv_ops.cc tensorflow/core/kernels/quantized_matmul_op.cc tensorflow/core/kernels/quantized_pooling_ops.cc +tensorflow/core/kernels/requantize.cc tensorflow/core/ops/training_ops.cc tensorflow/core/ops/string_ops.cc tensorflow/core/ops/state_ops.cc diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 27719faaa67..add489a293f 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2385,6 +2385,7 @@ filegroup( "quantized_matmul_op.cc", "quantized_pooling_ops.cc", "reference_gemm.h", + "requantize.cc", ], visibility = ["//visibility:public"], ) @@ -2480,6 +2481,7 @@ tf_kernel_library( "quantized_conv_ops.cc", "quantized_matmul_op.cc", "quantized_pooling_ops.cc", + "requantize.cc", ], hdrs = [ "quantization_utils.h", @@ -2508,10 +2510,23 @@ tf_cc_test( srcs = 
["quantize_down_and_shrink_range_op_test.cc"], deps = [ ":quantized_ops", - "//tensorflow/core:array_ops_op_lib", "//tensorflow/core:framework", - "//tensorflow/core:math_ops_op_lib", - "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:ops_util", + ], +) + +tf_cc_test( + name = "requantize_op_test", + size = "small", + srcs = ["requantize_op_test.cc"], + deps = [ + ":quantized_ops", + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/core/kernels/requantize.cc b/tensorflow/core/kernels/requantize.cc new file mode 100644 index 00000000000..865970a99e0 --- /dev/null +++ b/tensorflow/core/kernels/requantize.cc @@ -0,0 +1,92 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/math_ops.cc. 
+ +#define EIGEN_USE_THREADS + +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/quantization_utils.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +template +class RequantizeOp : public OpKernel { + public: + explicit RequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& input = ctx->input(0); + const float input_min_float = ctx->input(1).flat()(0); + const float input_max_float = ctx->input(2).flat()(0); + const float requested_output_min_float = ctx->input(3).flat()(0); + const float requested_output_max_float = ctx->input(4).flat()(0); + + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output)); + Tensor* output_min = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &output_min)); + Tensor* output_max = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(2, TensorShape({}), &output_max)); + + OP_REQUIRES( + ctx, requested_output_min_float <= 0.0f, + errors::InvalidArgument("requested_output_min must be <= 0, but got ", + requested_output_min_float)); + OP_REQUIRES( + ctx, requested_output_max_float >= 0.0f, + errors::InvalidArgument("requested_output_max must be >= 0, but got ", + requested_output_max_float)); + + auto input_array = input.flat(); + +#if 0 + // This is the reference, non-eigen implementation: + auto output_array = output->flat(); + RequantizeManyInNewRange( + input_array.data(), input_array.size(), + input_min_float, input_max_float, + requested_output_min_float, requested_output_max_float, + output_array.data()); +#endif + + if (input_array.size() > 0) { + RequantizeManyInNewRangeUsingEigen( 
ctx->eigen_device(), input, input_min_float, + input_max_float, requested_output_min_float, + requested_output_max_float, output); + } + + output_min->flat().setConstant(requested_output_min_float); + output_max->flat().setConstant(requested_output_max_float); + } +}; + +REGISTER_KERNEL_BUILDER(Name("Requantize") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("out_type"), + RequantizeOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/requantize_op_test.cc b/tensorflow/core/kernels/requantize_op_test.cc new file mode 100644 index 00000000000..e7674eb2946 --- /dev/null +++ b/tensorflow/core/kernels/requantize_op_test.cc @@ -0,0 +1,97 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class RequantizeTest : public OpsTestBase { + protected: + void ConfigureRequantize() { + TF_ASSERT_OK(NodeDefBuilder("requantize", "Requantize") + .Input(FakeInput(DT_QINT32)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_FLOAT)) + .Attr("Tinput", DataTypeToEnum::v()) + .Attr("out_type", DataTypeToEnum::v()) + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + } +}; + +// Runs a manually generated array through the operator, and makes sure that the +// results match the expected hand-calculated values. +TEST_F(RequantizeTest, HandCraftedRequantize) { + ConfigureRequantize(); + const int value_count = 3; + + // Requantize to -1 to 1. 
+ AddInputFromArray(TensorShape({value_count}), + {-(1 << 23), 0, (1 << 23)}); + AddInputFromArray(TensorShape({1}), {-256.0f}); + AddInputFromArray(TensorShape({1}), {256.0f}); + AddInputFromArray(TensorShape({1}), {-1.0f}); + AddInputFromArray(TensorShape({1}), {1.0f}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_QUINT8, TensorShape({value_count})); + test::FillValues(&expected, {0, 128, 255}); + test::ExpectTensorEqual(expected, *GetOutput(0)); + test::ExpectTensorEqual(test::AsScalar(-1.0f), *GetOutput(1)); + test::ExpectTensorEqual(test::AsScalar(1.0f), *GetOutput(2)); +} + +TEST_F(RequantizeTest, InvalidOutputMin) { + ConfigureRequantize(); + const int value_count = 3; + + AddInputFromArray(TensorShape({value_count}), + {-(1 << 23), 0, (1 << 23)}); + AddInputFromArray(TensorShape({1}), {-256.0f}); + AddInputFromArray(TensorShape({1}), {256.0f}); + AddInputFromArray(TensorShape({1}), {0.01f}); + AddInputFromArray(TensorShape({1}), {1.0f}); + EXPECT_EQ("requested_output_min must be <= 0, but got 0.01", + RunOpKernel().error_message()); +} + +TEST_F(RequantizeTest, InvalidOutputMax) { + ConfigureRequantize(); + const int value_count = 3; + + AddInputFromArray(TensorShape({value_count}), + {-(1 << 23), 0, (1 << 23)}); + AddInputFromArray(TensorShape({1}), {-256.0f}); + AddInputFromArray(TensorShape({1}), {256.0f}); + AddInputFromArray(TensorShape({1}), {-1.0f}); + AddInputFromArray(TensorShape({1}), {-0.001f}); + EXPECT_EQ("requested_output_max must be >= 0, but got -0.001", + RunOpKernel().error_message()); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 115fdae393b..732f6e7c1ed 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -2236,6 +2236,48 @@ out_type: The type of the output. Should be a lower bit depth than Tinput. 
)doc"); +REGISTER_OP("Requantize") + .Input("input: Tinput") + .Input("input_min: float") + .Input("input_max: float") + .Input("requested_output_min: float") + .Input("requested_output_max: float") + .Output("output: out_type") + .Output("output_min: float") + .Output("output_max: float") + .Attr("Tinput: quantizedtype") + .Attr("out_type: quantizedtype") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Convert the quantized 'input' tensor into a lower-precision 'output', using the +output range specified with 'requested_output_min' and 'requested_output_max'. + +[input_min, input_max] are scalar floats that specify the range for the float +interpretation of the 'input' data. For example, if input_min is -1.0f and +input_max is 1.0f, and we are dealing with quint16 quantized data, then a 0 +value in the 16-bit data should be interpreted as -1.0f, and a 65535 means 1.0f. + +input_min: The float value that the minimum quantized input value represents. +input_max: The float value that the maximum quantized input value represents. +Tinput: The type of the input. +requested_output_min: The float value that the minimum quantized output value represents. +requested_output_max: The float value that the maximum quantized output value represents. +output_min: The requested_output_min value is copied into this output. +output_max: The requested_output_max value is copied into this output. +out_type: The type of the output. Should be a lower bit depth than Tinput. 
+ +)doc"); + // Deprecated ops: REGISTER_OP("BatchFFT") .Input("input: complex64") diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index 5d0b75c5794..69771ab1a02 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -446,4 +446,17 @@ TEST(MathOpsTest, Betainc_ShapeFn) { INFER_ERROR("must be equal", op, "[1,2];[];[1,2,3]"); } +TEST(MathOpsTest, Requantize_ShapeFn) { + ShapeInferenceTestOp op("Requantize"); + + INFER_OK(op, "?;?;?;?;?", "in0;[];[]"); + INFER_OK(op, "?;[];[];[];[]", "in0;[];[]"); + + // Rank checks on input scalars. + INFER_ERROR("must be rank 0", op, "?;[1];?;?;?"); + INFER_ERROR("must be rank 0", op, "?;?;[2];?;?"); + INFER_ERROR("must be rank 0", op, "?;?;?;[3];?"); + INFER_ERROR("must be rank 0", op, "?;?;?;?;[4]"); +} + } // end namespace tensorflow -- GitLab