From 8e4b225fe4b740d6f82e0ba03c19b1030e8157ea Mon Sep 17 00:00:00 2001 From: achao2013 Date: Wed, 11 Jul 2018 13:24:59 +0800 Subject: [PATCH] Add fake_quantize_op. (#11359) * Add a fake_quantize_op, which quantizes an input tensor to a tensor with lower bits. --- paddle/fluid/operators/fake_quantize_op.cc | 112 ++++++++ paddle/fluid/operators/fake_quantize_op.cu | 272 ++++++++++++++++++ paddle/fluid/operators/fake_quantize_op.h | 155 ++++++++++ .../multi_thread/convert_protobin.sh | 2 +- .../sequence/convert_protobin.sh | 2 +- .../sparse_binary/convert_protobin.sh | 2 +- .../tests/unittests/test_fake_quantize_op.py | 51 ++++ 7 files changed, 593 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/operators/fake_quantize_op.cc create mode 100644 paddle/fluid/operators/fake_quantize_op.cu create mode 100644 paddle/fluid/operators/fake_quantize_op.h mode change 120000 => 100644 paddle/legacy/capi/examples/model_inference/multi_thread/convert_protobin.sh mode change 120000 => 100644 paddle/legacy/capi/examples/model_inference/sequence/convert_protobin.sh mode change 120000 => 100644 paddle/legacy/capi/examples/model_inference/sparse_binary/convert_protobin.sh create mode 100644 python/paddle/fluid/tests/unittests/test_fake_quantize_op.py diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc new file mode 100644 index 0000000000..a91e0f520e --- /dev/null +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -0,0 +1,112 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fake_quantize_op.h" +#include + +namespace paddle { +namespace operators { + +class FakeQuantizeOp : public framework::OperatorWithKernel { + public: + FakeQuantizeOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of FakeQuantizeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of FakeQuantizeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"), + "OutMovingScale(Out) of FakeQuantizeOp should not be null"); + // if (ctx->HasInput("InMovingScale")) { + ctx->SetOutputDim("OutMovingScale", ctx->GetInputDim("InMovingScale")); + //} + // if (ctx->HasInput("InScales")) { + PADDLE_ENFORCE(ctx->HasOutput("OutScales"), + "OutScales(Out) of FakeQuantizeOp should not be null"); + ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales")); + // PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0], + // ctx->Outputs("OutScales")[0], + // "Mean and MeanOut should share the same memory"); + //} + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class FakeQuantizeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor) Input tensor of scale operator."); + AddInput("InScales", "(Tensor) scale buffer, used in static quantization.") + .AsDispensable(); + AddInput("InMovingScale", "Last scale, used in static quantization.") + .AsDispensable(); + AddInput("InCurrentIter", + "Last iteration number, used in static quantization.") + .AsDispensable(); + AddOutput("Out", "(Tensor) Output of quantized low level tensor."); + 
AddOutput("OutScales", + "(Tensor) scale buffer, used in static quantization.") + .AsDispensable(); + AddOutput("OutMovingScale", " Current scale"); + AddOutput("OutCurrentIter", "Current iteration number.").AsDispensable(); + AddAttr("quantize_type", + "(string, default abs_max)" + "The scaling tpe of the quantize operator.") + .SetDefault("abs_max"); + AddAttr("window_size", "(int, default 10000)").SetDefault(10000); + AddAttr("bit_length", "(int, default 8)") + .SetDefault(8) + .AddCustomChecker([](const int &bit_length) { + PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16, + "'bit_length' should be between 1 and 16."); + }); + AddAttr("is_test", "").SetDefault(false); + AddComment(R"DOC( +FakeQuantize operator + +quantize_type = abs_max: + + $$scale = max(abs(x))$$ + +quantize_type = range_abs_max: + + $$scale = max(max(abs(x)), history_abs_max)$$ + +quantize_type = moving_average_abs_max: + + $$scale = 0.1*scale+0.9*new_abs_max)$$ + +$$Out = scale*X$$ + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(fake_quantize, ops::FakeQuantizeOp, ops::FakeQuantizeOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + fake_quantize, + ops::FakeQuantizeKernel, + ops::FakeQuantizeKernel); diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu new file mode 100644 index 0000000000..be0c6730a5 --- /dev/null +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -0,0 +1,272 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <string>
+#include "paddle/fluid/operators/fake_quantize_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+
+namespace paddle {
+namespace operators {
+
+// Per-block max(|in[i]|) for i in [0, n).
+// Every thread folds its strided slice of the input into one shared-memory
+// slot, the block tree-reduces the slots, and thread 0 stores the block
+// result in out[blockIdx.x].
+// Launch requirements: blockDim.x is a power of two and the dynamic shared
+// memory size is blockDim.x * sizeof(T).
+template <typename T>
+__global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
+  int bid = threadIdx.x + blockIdx.x * blockDim.x;
+  int tid = threadIdx.x;
+
+  // A dependent-typed `extern __shared__ T buf[]` cannot be instantiated for
+  // both float and double in one translation unit, so view a raw buffer.
+  extern __shared__ __align__(8) unsigned char shared_raw[];
+  T* shared_max_data = reinterpret_cast<T*>(shared_raw);
+
+  // The strided loop also covers the single-block launch, for any n.
+  T local_max = T(0);
+  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
+    T tmp = fabs(in[i]);
+    if (tmp > local_max) {
+      local_max = tmp;
+    }
+  }
+  shared_max_data[tid] = local_max;
+  __syncthreads();
+
+  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
+    if (tid < i && shared_max_data[tid] < shared_max_data[tid + i]) {
+      shared_max_data[tid] = shared_max_data[tid + i];
+    }
+    __syncthreads();
+  }
+  if (tid == 0) {
+    out[blockIdx.x] = shared_max_data[0];
+  }
+}
+
+// Returns max(|array[i]|) over `length` device-resident values, or 0 when
+// `length` is not positive. Runs two reduction passes on ctx.stream() and
+// copies the single result back to the host.
+// Templated on T so the double kernel no longer reinterprets its data as
+// float; existing calls with float pointers deduce T = float.
+template <typename T>
+T FindAbsMaxGpu(const platform::CUDADeviceContext& ctx, const T* array,
+                int length) {
+  if (length <= 0) {
+    return T(0);  // a zero-block launch is an invalid configuration
+  }
+  T host_max;
+  int kNumThreads = 1024;
+  int gridDimx = (kNumThreads - 1 + length) / kNumThreads;
+  // At most kNumThreads partial results so the second pass fits one block.
+  gridDimx = (gridDimx > kNumThreads) ? kNumThreads : gridDimx;
+  framework::Tensor t;
+  // Allocate on the context's device rather than on default device 0.
+  T* device_max =
+      t.mutable_data<T>(framework::make_ddim({gridDimx}), ctx.GetPlace());
+  FindAbsMaxKernel<T><<<gridDimx, kNumThreads, kNumThreads * sizeof(T),
+                        ctx.stream()>>>(length, array, device_max);
+  FindAbsMaxKernel<T><<<1, kNumThreads, kNumThreads * sizeof(T),
+                        ctx.stream()>>>(gridDimx, device_max, device_max);
+  PADDLE_ENFORCE_EQ(
+      cudaMemcpy(&host_max, device_max, sizeof(T), cudaMemcpyDeviceToHost),
+      cudaSuccess, "cudaMemcpy failed");
+  return host_max;
+}
+
+// Copies in[0..n) to out[0..n) clamped to [min, max] and stores the number
+// of clamped elements seen by each block in num_saturate[blockIdx.x].
+// Launch requirements: blockDim.x is a power of two and the dynamic shared
+// memory size is blockDim.x * sizeof(int).
+template <typename T>
+__global__ void ApplySaturateKernel(const int n, const T* in, T* out,
+                                    int* num_saturate, const T min,
+                                    const T max) {
+  int bid = threadIdx.x + blockIdx.x * blockDim.x;
+  int tid = threadIdx.x;
+
+  extern __shared__ int shared_count[];
+  shared_count[tid] = 0;
+  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
+    if (in[i] > max) {
+      out[i] = max;
+      shared_count[tid] += 1;
+    } else if (in[i] < min) {
+      out[i] = min;
+      shared_count[tid] += 1;
+    } else {
+      out[i] = in[i];
+    }
+  }
+  __syncthreads();
+
+  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
+    if (tid < i) {
+      shared_count[tid] += shared_count[tid + i];
+    }
+    __syncthreads();
+  }
+  if (tid == 0) {
+    num_saturate[blockIdx.x] = shared_count[0];
+  }
+}
+
+// Single-block sum of in[0..n) into out[0].
+// Launch requirements: one block, blockDim.x a power of two and >= n,
+// dynamic shared memory of blockDim.x * sizeof(T).
+template <typename T>
+__global__ void ReduceKernel(const int n, const T* in, T* out) {
+  int tid = threadIdx.x;
+  extern __shared__ T shared_sum[];
+  if (tid < n) {
+    shared_sum[tid] = in[tid];
+  } else {
+    shared_sum[tid] = T(0);
+  }
+  __syncthreads();
+  // Reduce over the whole zero-padded block. Starting at (n + 1) / 2 and
+  // halving skipped slots whenever n was not a power of two (for n = 5 the
+  // value in slot 2 was never added).
+  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
+    if (tid < i) {
+      shared_sum[tid] += shared_sum[tid + i];
+    }
+    __syncthreads();
+  }
+  if (tid == 0) {
+    out[0] = shared_sum[0];
+  }
+}
+
+// Clamps n device values from `in` into `out` (range [min, max]) on
+// ctx.stream() and returns how many elements had to be clamped.
+template <typename T>
+int ApplySaturateGpu(const platform::CUDADeviceContext& ctx, const int n,
+                     const T* in, T* out, const T min, const T max) {
+  if (n <= 0) {
+    return 0;  // nothing to clamp; avoid a zero-block launch
+  }
+  int host_num_saturate;
+  int kNumThreads = 1024;
+  int gridDimx = (n + kNumThreads - 1) / kNumThreads;
+  gridDimx = (gridDimx > kNumThreads) ? kNumThreads : gridDimx;
+  framework::Tensor t;
+  int* device_num_saturate =
+      t.mutable_data<int>(framework::make_ddim({gridDimx}), ctx.GetPlace());
+  ApplySaturateKernel<T><<<gridDimx, kNumThreads, kNumThreads * sizeof(int),
+                           ctx.stream()>>>(n, in, out, device_num_saturate,
+                                           min, max);
+  // The counters are ints, so size the shared buffer by int, not by T.
+  ReduceKernel<int><<<1, kNumThreads, kNumThreads * sizeof(int),
+                      ctx.stream()>>>(gridDimx, device_num_saturate,
+                                      device_num_saturate);
+  PADDLE_ENFORCE_EQ(cudaSuccess,
+                    cudaMemcpy(&host_num_saturate, device_num_saturate,
+                               sizeof(int), cudaMemcpyDeviceToHost),
+                    "cudaMemcpy failed");
+  return host_num_saturate;
+}
+
+// GPU kernel of the fake_quantize operator. X and Out are processed on the
+// device; the scale bookkeeping below is done through host pointers obtained
+// with mutable_data(CPUPlace).
+// NOTE(review): that only preserves state if the scale/iteration tensors
+// already live on the host - confirm against the callers.
+template <typename DeviceContext, typename T>
+class FakeQuantizeCUDAKernel : public framework::OpKernel<T> {
+ public:
+  // Records cur_scale in the ring buffer `scale_list` and updates the
+  // running maximum stored in out_scale[0].
+  // On entry out_scale[0] must hold the previous maximum. `current_iter` is
+  // the number of updates done so far; it is wrapped into the buffer, which
+  // holds min(window_size, scale_list->numel()) entries. `ctx` is kept for
+  // interface compatibility only.
+  T FindRangeAbsMax(const platform::CUDADeviceContext& ctx,
+                    framework::Tensor* scale_list, framework::Tensor* out_scale,
+                    const T& cur_scale, int window_size,
+                    int current_iter) const {
+    T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
+    T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
+    const int numel = static_cast<int>(scale_list->numel());
+    const int capacity = (window_size < numel) ? window_size : numel;
+    PADDLE_ENFORCE(capacity > 0 && current_iter >= 0,
+                   "The scale window must be non-empty and the iteration "
+                   "number must not be negative.");
+    // Wrap the write position; the raw iteration number ran off the buffer.
+    const int slot = current_iter % capacity;
+    T remove_tmp = sl[slot];
+    sl[slot] = cur_scale;
+    if (max_scale < cur_scale) {
+      max_scale = cur_scale;
+    } else if (fabs(remove_tmp - max_scale) < 1e-6) {
+      // The evicted entry was the maximum: rescan the filled part of the
+      // window. `sl` is a host pointer, so it must not be handed to a
+      // device kernel as before.
+      const int size =
+          (current_iter + 1 < capacity) ? current_iter + 1 : capacity;
+      T window_max = T(0);
+      for (int i = 0; i < size; ++i) {
+        T tmp = fabs(sl[i]);
+        if (tmp > window_max) {
+          window_max = tmp;
+        }
+      }
+      max_scale = window_max;
+    }
+    return max_scale;
+  }
+
+  // Exponential moving average: out_scale[0] = 0.9 * cur_scale +
+  // 0.1 * in_scale[0]; returns the new value.
+  T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
+                             framework::Tensor* out_scale,
+                             const T& cur_scale) const {
+    T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
+    T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
+    outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
+    return T(outs[0]);
+  }
+
+  // Picks the scale according to `quantize_type`, clamps X to
+  // [-scale, scale] and writes round(X / scale * (2^(bit_length-1) - 1))
+  // to Out.
+  virtual void Compute(const framework::ExecutionContext& context) const {
+    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
+                   "This kernel only runs on GPU device.");
+    auto& device_ctx = context.cuda_device_context();
+    auto* tensor = context.Output<framework::Tensor>("Out");
+    auto* in = context.Input<framework::Tensor>("X");
+    const bool is_test = context.Attr<bool>("is_test");
+    tensor->mutable_data<T>(in->place());
+    context.Output<framework::Tensor>("OutMovingScale")
+        ->mutable_data<T>(
+            context.Input<framework::Tensor>("InMovingScale")->place());
+    auto quantize_type =
+        static_cast<std::string>(context.Attr<std::string>("quantize_type"));
+    if (quantize_type == std::string("range_abs_max")) {
+      context.Output<framework::Tensor>("OutScales")
+          ->mutable_data<T>(
+              context.Input<framework::Tensor>("InScales")->place());
+      // The counter is accessed as int below, so allocate it as int.
+      context.Output<framework::Tensor>("OutCurrentIter")
+          ->mutable_data<int>(
+              context.Input<framework::Tensor>("InCurrentIter")->place());
+    }
+
+    T scale = T(1);
+    int window_size = context.Attr<int>("window_size");
+    T bin_cnt =
+        static_cast<T>((1 << (context.Attr<int>("bit_length") - 1)) - 1);
+    if (quantize_type == std::string("abs_max")) {
+      auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
+      scale = FindAbsMaxGpu<T>(device_ctx, in->data<T>(), in->numel());
+      saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
+
+      // abs_max keeps no history: reset the window and the counter.
+      auto& dev_ctx = context.template device_context<DeviceContext>();
+      auto* scale_list = context.Output<framework::Tensor>("OutScales");
+      math::SetConstant<DeviceContext, T> scalar;
+      scale_list->mutable_data<T>(context.GetPlace());
+      scalar(dev_ctx, scale_list, static_cast<T>(0));
+      auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
+      iter->mutable_data<T>(context.GetPlace());
+      scalar(dev_ctx, iter, static_cast<T>(0));
+    } else if (quantize_type == std::string("range_abs_max")) {
+      auto* moving_scale = const_cast<framework::Tensor*>(
+          context.Input<framework::Tensor>("InMovingScale"));
+      if (is_test) {
+        scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
+      } else {
+        auto* it = const_cast<framework::Tensor*>(
+            context.Input<framework::Tensor>("InCurrentIter"));
+        auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
+        // Read both inputs before any output is written: the outputs may
+        // share storage with them.
+        const int last_iter =
+            it->mutable_data<int>(platform::CPUPlace())[0];
+        const T last_scale =
+            moving_scale->mutable_data<T>(platform::CPUPlace())[0];
+        int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
+        auto* scale_list = context.Output<framework::Tensor>("OutScales");
+        auto* saving_scale =
+            context.Output<framework::Tensor>("OutMovingScale");
+        // FindRangeAbsMax expects the previous maximum in the output slot.
+        saving_scale->mutable_data<T>(platform::CPUPlace())[0] = last_scale;
+        scale = FindAbsMaxGpu<T>(device_ctx, in->data<T>(), in->numel());
+        // Index with the incoming counter; the output counter has not been
+        // written yet.
+        scale = FindRangeAbsMax(device_ctx, scale_list, saving_scale, scale,
+                                window_size, last_iter);
+        (*current_iter) = last_iter + 1;
+      }
+    } else if (quantize_type == std::string("moving_average_abs_max")) {
+      auto* moving_scale = const_cast<framework::Tensor*>(
+          context.Input<framework::Tensor>("InMovingScale"));
+      if (is_test) {
+        scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
+      } else {
+        scale = FindAbsMaxGpu<T>(device_ctx, in->data<T>(), in->numel());
+        auto* saving_scale =
+            context.Output<framework::Tensor>("OutMovingScale");
+        scale = FindMovingAverageAbsMmax(
+            const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
+      }
+    }
+
+    // An all-zero input (or a zero stored scale) would otherwise produce
+    // 0 * (bin_cnt / 0) = NaN below.
+    if (!(scale > T(0))) {
+      scale = static_cast<T>(0.00000001);
+    }
+
+    ApplySaturateGpu<T>(device_ctx, in->numel(), in->data<T>(),
+                        tensor->mutable_data<T>(in->place()), -scale, scale);
+    scale = bin_cnt / scale;
+
+    auto& dev =
+        *context.template device_context<DeviceContext>().eigen_device();
+    auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
+    auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
+    eigen_out.device(dev) = (scale * eigen_in).round();
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OP_CUDA_KERNEL(fake_quantize,
+                        paddle::operators::FakeQuantizeCUDAKernel<
+                            paddle::platform::CUDADeviceContext, float>,
+                        paddle::operators::FakeQuantizeCUDAKernel<
+                            paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h
new file mode 100644
index
0000000000..80f71d85dd --- /dev/null +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -0,0 +1,155 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/clip_op.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/transform.h" + +namespace paddle { +namespace operators { + +using platform::Transform; + +template +class FakeQuantizeKernel : public framework::OpKernel { + public: + T FindAbsMax(framework::Tensor* in, int n) const { + T* p = in->mutable_data(platform::CPUPlace()); + T abs_max = (T)0.00000001; + for (int i = 0; i < n; i++) { + T tmp = fabs(p[i]); + if (tmp > abs_max) abs_max = tmp; + } + return T(abs_max); + } + T FindRangeAbsMax(framework::Tensor* scale_list, framework::Tensor* out_scale, + const T& cur_scale, int window_size, + int current_iter) const { + T* sl = scale_list->mutable_data(platform::CPUPlace()); + T remove_tmp = sl[current_iter]; + sl[current_iter] = cur_scale; + T& max_scale = out_scale->mutable_data(platform::CPUPlace())[0]; + if (max_scale < cur_scale) { + max_scale = cur_scale; + } else if (fabs(remove_tmp - max_scale) < 1e-6) { + int size = (current_iter > window_size) ? 
window_size : current_iter; + max_scale = T(FindAbsMax(scale_list, size)); + } + return max_scale; + } + + T FindMovingAverageAbsMmax(framework::Tensor* in_scale, + framework::Tensor* out_scale, + const T& cur_scale) const { + T* ins = in_scale->mutable_data(platform::CPUPlace()); + T* outs = out_scale->mutable_data(platform::CPUPlace()); + outs[0] = 0.9 * cur_scale + 0.1 * ins[0]; + return T(outs[0]); + } + + virtual void Compute(const framework::ExecutionContext& context) const { + auto* tensor = context.Output("Out"); + auto* in = context.Input("X"); + const bool is_test = context.Attr("is_test"); + tensor->mutable_data(in->place()); + + auto* oms_tensor = context.Output("OutMovingScale"); + oms_tensor->mutable_data(in->place()); + + auto quantize_type = + static_cast(context.Attr("quantize_type")); + if (quantize_type == std::string("range_abs_max")) { + auto* oss_tensor = context.Output("OutScales"); + oss_tensor->mutable_data( + context.Input("InScales")->place()); + auto* oci_tensor = context.Output("OutCurrentIter"); + oci_tensor->mutable_data( + context.Input("InCurrentIter")->place()); + } + + T scale = static_cast(1); + int window_size = context.Attr("window_size"); + int bit_length = context.Attr("bit_length"); + int bin_cnt = std::pow(2, bit_length - 1) - 1; + + auto& dev = + *context.template device_context().eigen_device(); + auto raw_in = framework::EigenVector::Flatten(*in); + if (quantize_type == std::string("abs_max")) { + auto* saving_scale = context.Output("OutMovingScale"); + auto scale_out = framework::EigenVector::Flatten(*saving_scale); + scale_out.device(dev) = raw_in.abs().maximum(); + scale = scale_out(0); + + auto& device_ctx = context.template device_context(); + auto* scale_list = context.Output("OutScales"); + math::SetConstant scalar; + scale_list->mutable_data(context.GetPlace()); + scalar(device_ctx, scale_list, static_cast(0)); + auto* iter = context.Output("OutCurrentIter"); + iter->mutable_data(context.GetPlace()); + 
scalar(device_ctx, iter, static_cast(0)); + } else if (quantize_type == std::string("range_abs_max")) { + auto* moving_scale = context.Input("InMovingScale"); + if (is_test) { + scale = moving_scale->data()[0]; + } else { + auto* it = context.Input("InCurrentIter"); + auto* iter = context.Output("OutCurrentIter"); + const int* last_iter = it->data(); + int* current_iter = iter->mutable_data(platform::CPUPlace()); + auto* scale_list = context.Output("OutScales"); + auto* saving_scale = + context.Output("OutMovingScale"); + auto scale_out = framework::EigenVector::Flatten(*saving_scale); + scale_out.device(dev) = raw_in.abs().maximum(); + scale = saving_scale->mutable_data(platform::CPUPlace())[0]; + scale = FindRangeAbsMax(scale_list, saving_scale, scale, window_size, + current_iter[0]); + saving_scale->mutable_data(platform::CPUPlace())[0] = scale; + (*current_iter) = (*last_iter) + 1; + } + } else if (quantize_type == std::string("moving_average_abs_max")) { + auto* moving_scale = context.Input("InMovingScale"); + if (is_test) { + scale = moving_scale->data()[0]; + } else { + auto* saving_scale = + context.Output("OutMovingScale"); + auto scale_out = framework::EigenVector::Flatten(*saving_scale); + scale_out.device(dev) = raw_in.abs().maximum(); + scale = saving_scale->mutable_data(platform::CPUPlace())[0]; + scale = FindMovingAverageAbsMmax( + const_cast(moving_scale), saving_scale, scale); + saving_scale->mutable_data(platform::CPUPlace())[0] = scale; + } + } + + Transform trans; + trans(context.template device_context(), in->data(), + in->data() + in->numel(), tensor->mutable_data(in->place()), + ClipFunctor(-scale, scale)); + auto eigen_out = framework::EigenVector::Flatten(*tensor); + auto eigen_in = framework::EigenVector::Flatten(*tensor); + eigen_out.device(dev) = (bin_cnt / scale * eigen_in).round(); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/legacy/capi/examples/model_inference/multi_thread/convert_protobin.sh 
b/paddle/legacy/capi/examples/model_inference/multi_thread/convert_protobin.sh deleted file mode 120000 index 3c1b353352..0000000000 --- a/paddle/legacy/capi/examples/model_inference/multi_thread/convert_protobin.sh +++ /dev/null @@ -1 +0,0 @@ -../dense/convert_protobin.sh \ No newline at end of file diff --git a/paddle/legacy/capi/examples/model_inference/multi_thread/convert_protobin.sh b/paddle/legacy/capi/examples/model_inference/multi_thread/convert_protobin.sh new file mode 100644 index 0000000000..b29f2cd214 --- /dev/null +++ b/paddle/legacy/capi/examples/model_inference/multi_thread/convert_protobin.sh @@ -0,0 +1 @@ +../dense/convert_protobin.sh diff --git a/paddle/legacy/capi/examples/model_inference/sequence/convert_protobin.sh b/paddle/legacy/capi/examples/model_inference/sequence/convert_protobin.sh deleted file mode 120000 index 3c1b353352..0000000000 --- a/paddle/legacy/capi/examples/model_inference/sequence/convert_protobin.sh +++ /dev/null @@ -1 +0,0 @@ -../dense/convert_protobin.sh \ No newline at end of file diff --git a/paddle/legacy/capi/examples/model_inference/sequence/convert_protobin.sh b/paddle/legacy/capi/examples/model_inference/sequence/convert_protobin.sh new file mode 100644 index 0000000000..b29f2cd214 --- /dev/null +++ b/paddle/legacy/capi/examples/model_inference/sequence/convert_protobin.sh @@ -0,0 +1 @@ +../dense/convert_protobin.sh diff --git a/paddle/legacy/capi/examples/model_inference/sparse_binary/convert_protobin.sh b/paddle/legacy/capi/examples/model_inference/sparse_binary/convert_protobin.sh deleted file mode 120000 index 3c1b353352..0000000000 --- a/paddle/legacy/capi/examples/model_inference/sparse_binary/convert_protobin.sh +++ /dev/null @@ -1 +0,0 @@ -../dense/convert_protobin.sh \ No newline at end of file diff --git a/paddle/legacy/capi/examples/model_inference/sparse_binary/convert_protobin.sh b/paddle/legacy/capi/examples/model_inference/sparse_binary/convert_protobin.sh new file mode 100644 index 
0000000000..b29f2cd214 --- /dev/null +++ b/paddle/legacy/capi/examples/model_inference/sparse_binary/convert_protobin.sh @@ -0,0 +1 @@ +../dense/convert_protobin.sh diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py new file mode 100644 index 0000000000..6c6aa9d3bb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py @@ -0,0 +1,51 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+# Output check for the `fake_quantize` operator in its `abs_max` mode with an
+# 8-bit quantization range.
+class TestFakeQuantizeOp(OpTest):
+    # Builds random inputs in [0, 1) and the NumPy reference outputs.
+    def setUp(self):
+        self.op_type = "fake_quantize"
+        self.attrs = {
+            'bit_length': 8,
+            'quantize_type': 'abs_max',
+            'window_size': 10000
+        }
+        # The state inputs (scale window, iteration counter, previous scale)
+        # all start at zero.
+        self.inputs = {
+            'X': np.random.random((10, 10)).astype("float32"),
+            'InScales': np.zeros(self.attrs['window_size']).astype("float32"),
+            'InCurrentIter': np.zeros(1).astype("float32"),
+            'InMovingScale': np.zeros(1).astype("float32")
+        }
+        # Reference scale: the largest absolute value in X.
+        self.scale = {
+            'abs_max': np.max(np.abs(self.inputs['X'])).astype("float32")
+        }
+        # Expected Out = round(X / scale * (2^(bit_length - 1) - 1)); the
+        # window and the counter are expected to stay zero and the output
+        # scale to equal the reference scale.
+        self.outputs = {
+            'Out': np.round(self.inputs['X'] / self.scale['abs_max'] * (
+                (1 << (self.attrs['bit_length'] - 1)) - 1)),
+            'OutScales': np.zeros(self.attrs['window_size']).astype("float32"),
+            'OutMovingScale':
+            np.array([self.scale['abs_max']]).astype("float32"),
+            'OutCurrentIter': np.zeros(1).astype("float32")
+        }
+
+    # Runs the operator and compares every output with self.outputs.
+    def test_check_output(self):
+        self.check_output()
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab