提交 8e4b225f 编写于 作者: 视言's avatar 视言 提交者: qingqing01

Add fake_quantize_op. (#11359)

* Add a fake_quantize_op, which quantize an input tensor to a tensor with lower bits.
上级 79d797fd
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fake_quantize_op.h"
#include <string>
namespace paddle {
namespace operators {
class FakeQuantizeOp : public framework::OperatorWithKernel {
public:
FakeQuantizeOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeQuantizeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FakeQuantizeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutMovingScale"),
"OutMovingScale(Out) of FakeQuantizeOp should not be null");
// if (ctx->HasInput("InMovingScale")) {
ctx->SetOutputDim("OutMovingScale", ctx->GetInputDim("InMovingScale"));
//}
// if (ctx->HasInput("InScales")) {
PADDLE_ENFORCE(ctx->HasOutput("OutScales"),
"OutScales(Out) of FakeQuantizeOp should not be null");
ctx->SetOutputDim("OutScales", ctx->GetInputDim("InScales"));
// PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
// ctx->Outputs("OutScales")[0],
// "Mean and MeanOut should share the same memory");
//}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class FakeQuantizeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor of scale operator.");
AddInput("InScales", "(Tensor) scale buffer, used in static quantization.")
.AsDispensable();
AddInput("InMovingScale", "Last scale, used in static quantization.")
.AsDispensable();
AddInput("InCurrentIter",
"Last iteration number, used in static quantization.")
.AsDispensable();
AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
AddOutput("OutScales",
"(Tensor) scale buffer, used in static quantization.")
.AsDispensable();
AddOutput("OutMovingScale", " Current scale");
AddOutput("OutCurrentIter", "Current iteration number.").AsDispensable();
AddAttr<std::string>("quantize_type",
"(string, default abs_max)"
"The scaling tpe of the quantize operator.")
.SetDefault("abs_max");
AddAttr<int>("window_size", "(int, default 10000)").SetDefault(10000);
AddAttr<int>("bit_length", "(int, default 8)")
.SetDefault(8)
.AddCustomChecker([](const int &bit_length) {
PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
"'bit_length' should be between 1 and 16.");
});
AddAttr<bool>("is_test", "").SetDefault(false);
AddComment(R"DOC(
FakeQuantize operator
quantize_type = abs_max:
$$scale = max(abs(x))$$
quantize_type = range_abs_max:
$$scale = max(max(abs(x)), history_abs_max)$$
quantize_type = moving_average_abs_max:
$$scale = 0.1*scale+0.9*new_abs_max)$$
$$Out = scale*X$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fake_quantize, ops::FakeQuantizeOp, ops::FakeQuantizeOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
fake_quantize,
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, float>,
ops::FakeQuantizeKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x;
extern __shared__ T shared_max_data[];
if (gridDim.x > 1) {
shared_max_data[tid] = T(0);
for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
T tmp = fabs(in[i]);
if (tmp > shared_max_data[tid]) {
shared_max_data[tid] = tmp;
}
}
} else {
if (bid < n) {
shared_max_data[tid] = fabs(in[bid]);
} else {
shared_max_data[tid] = T(0);
}
}
__syncthreads();
for (int i = blockDim.x / 2; i > 0; i >>= 1) {
if (tid < i && shared_max_data[tid] < shared_max_data[tid + i]) {
shared_max_data[tid] = shared_max_data[tid + i];
}
__syncthreads();
}
if (tid == 0) {
out[blockIdx.x] = shared_max_data[0];
}
}
float FindAbsMaxGpu(const platform::CUDADeviceContext& ctx, const float* array,
int length) {
float host_max;
int kNumTheads = 1024;
int gridDimx = (kNumTheads - 1 + length) / kNumTheads;
gridDimx = (gridDimx > kNumTheads) ? kNumTheads : gridDimx;
framework::Tensor t;
float* device_max = t.mutable_data<float>(framework::make_ddim({gridDimx}),
platform::CUDAPlace());
FindAbsMaxKernel<float><<<gridDimx, kNumTheads, kNumTheads * sizeof(float),
ctx.stream()>>>(length, array, device_max);
FindAbsMaxKernel<
float><<<1, kNumTheads, kNumTheads * sizeof(float), ctx.stream()>>>(
gridDimx, device_max, device_max);
PADDLE_ENFORCE_EQ(
cudaMemcpy(&host_max, device_max, sizeof(float), cudaMemcpyDeviceToHost),
cudaSuccess, "cudaMemcpy failed");
return host_max;
}
template <typename T>
__global__ void ApplySaturateKernel(const int n, const T* in, T* out,
int* num_saturate, const T min,
const T max) {
int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x;
extern __shared__ int shared_count[];
shared_count[tid] = 0;
for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
if (in[i] > max) {
out[i] = max;
shared_count[tid] += 1;
} else if (in[i] < min) {
out[i] = min;
shared_count[tid] += 1;
} else {
out[i] = in[i];
}
}
__syncthreads();
for (int i = blockDim.x / 2; i > 0; i >>= 1) {
if (tid < i) {
shared_count[tid] += shared_count[tid + i];
}
__syncthreads();
}
if (tid == 0) {
num_saturate[blockIdx.x] = shared_count[0];
}
}
template <typename T>
__global__ void ReduceKernel(const int n, const T* in, T* out) {
int tid = threadIdx.x;
extern __shared__ T shared_sum[];
if (tid < n) {
shared_sum[tid] = in[tid];
} else {
shared_sum[tid] = T(0);
}
__syncthreads();
// blockDim.x must >= n
for (int i = (n + 1) / 2; i > 0; i >>= 1) {
if (tid < i) {
shared_sum[tid] += shared_sum[tid + i];
}
__syncthreads();
}
if (tid == 0) {
out[0] = shared_sum[0];
}
}
template <typename T>
int ApplySaturateGpu(const platform::CUDADeviceContext& ctx, const int n,
const T* in, T* out, const T min, const T max) {
int host_num_saturate;
int kNumTheads = 1024;
int gridDimx = (n + kNumTheads - 1) / kNumTheads;
gridDimx = (gridDimx > kNumTheads) ? kNumTheads : gridDimx;
framework::Tensor t;
int* device_num_saturate = t.mutable_data<int>(
framework::make_ddim({gridDimx}), platform::CUDAPlace());
ApplySaturateKernel<
T><<<gridDimx, kNumTheads, kNumTheads * sizeof(T), ctx.stream()>>>(
n, in, out, device_num_saturate, min, max);
ReduceKernel<int><<<1, kNumTheads, kNumTheads * sizeof(T), ctx.stream()>>>(
gridDimx, device_num_saturate, device_num_saturate);
PADDLE_ENFORCE_EQ(cudaSuccess,
cudaMemcpy(&host_num_saturate, device_num_saturate,
sizeof(int), cudaMemcpyDeviceToHost),
"cudaMemcpy failed");
return host_num_saturate;
}
template <typename DeviceContext, typename T>
class FakeQuantizeCUDAKernel : public framework::OpKernel<T> {
public:
T FindRangeAbsMax(const platform::CUDADeviceContext& ctx,
framework::Tensor* scale_list, framework::Tensor* out_scale,
const T& cur_scale, int window_size,
int current_iter) const {
T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
T remove_tmp = sl[current_iter];
sl[current_iter] = cur_scale;
T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
if (max_scale < cur_scale) {
max_scale = cur_scale;
} else if (fabs(remove_tmp - max_scale) < 1e-6) {
int size = (current_iter > window_size) ? window_size : current_iter;
max_scale = T(FindAbsMaxGpu(ctx, scale_list->data<float>(), size));
}
return max_scale;
}
T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
framework::Tensor* out_scale,
const T& cur_scale) const {
T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
return T(outs[0]);
}
virtual void Compute(const framework::ExecutionContext& context) const {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
"This kernel only runs on GPU device.");
auto& device_ctx = context.cuda_device_context();
auto* tensor = context.Output<framework::Tensor>("Out");
auto* in = context.Input<framework::Tensor>("X");
const bool is_test = context.Attr<bool>("is_test");
tensor->mutable_data<T>(in->place());
context.Output<framework::Tensor>("OutMovingScale")
->mutable_data<T>(
context.Input<framework::Tensor>("InMovingScale")->place());
auto quantize_type =
static_cast<std::string>(context.Attr<std::string>("quantize_type"));
if (quantize_type == std::string("range_abs_max")) {
context.Output<framework::Tensor>("OutScales")
->mutable_data<T>(
context.Input<framework::Tensor>("InScales")->place());
context.Output<framework::Tensor>("OutCurrentIter")
->mutable_data<T>(
context.Input<framework::Tensor>("InCurrentIter")->place());
}
T scale = T(1);
int window_size = context.Attr<int>("window_size");
T bin_cnt = (T)((1 << (context.Attr<int>("bit_length") - 1)) - 1);
if (quantize_type == std::string("abs_max")) {
auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
auto& device_ctx = context.template device_context<DeviceContext>();
auto* scale_list = context.Output<framework::Tensor>("OutScales");
math::SetConstant<DeviceContext, T> scalar;
scale_list->mutable_data<T>(context.GetPlace());
scalar(device_ctx, scale_list, static_cast<T>(0));
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
iter->mutable_data<T>(context.GetPlace());
scalar(device_ctx, iter, static_cast<T>(0));
} else if (quantize_type == std::string("range_abs_max")) {
auto* moving_scale = const_cast<framework::Tensor*>(
context.Input<framework::Tensor>("InMovingScale"));
if (is_test) {
scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
} else {
auto* it = const_cast<framework::Tensor*>(
context.Input<framework::Tensor>("InCurrentIter"));
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
int* last_iter = it->mutable_data<int>(platform::CPUPlace());
int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
auto* scale_list = context.Output<framework::Tensor>("OutScales");
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
scale = FindRangeAbsMax(device_ctx, scale_list, saving_scale, scale,
window_size, current_iter[0]);
(*current_iter) = (*last_iter) + 1;
}
} else if (quantize_type == std::string("moving_average_abs_max")) {
auto* moving_scale = const_cast<framework::Tensor*>(
context.Input<framework::Tensor>("InMovingScale"));
if (is_test) {
scale = moving_scale->mutable_data<T>(platform::CPUPlace())[0];
} else {
scale = (T)FindAbsMaxGpu(device_ctx, in->data<float>(), in->numel());
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
scale = FindMovingAverageAbsMmax(
const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
}
}
ApplySaturateGpu<T>(device_ctx, in->numel(), in->data<T>(),
tensor->mutable_data<T>(in->place()), -scale, scale);
scale = bin_cnt / scale;
auto& dev =
*context.template device_context<DeviceContext>().eigen_device();
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
eigen_out.device(dev) = (scale * eigen_in).round();
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(fake_quantize,
paddle::operators::FakeQuantizeCUDAKernel<
paddle::platform::CUDADeviceContext, float>,
paddle::operators::FakeQuantizeCUDAKernel<
paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
using platform::Transform;
template <typename DeviceContext, typename T>
class FakeQuantizeKernel : public framework::OpKernel<T> {
public:
T FindAbsMax(framework::Tensor* in, int n) const {
T* p = in->mutable_data<T>(platform::CPUPlace());
T abs_max = (T)0.00000001;
for (int i = 0; i < n; i++) {
T tmp = fabs(p[i]);
if (tmp > abs_max) abs_max = tmp;
}
return T(abs_max);
}
T FindRangeAbsMax(framework::Tensor* scale_list, framework::Tensor* out_scale,
const T& cur_scale, int window_size,
int current_iter) const {
T* sl = scale_list->mutable_data<T>(platform::CPUPlace());
T remove_tmp = sl[current_iter];
sl[current_iter] = cur_scale;
T& max_scale = out_scale->mutable_data<T>(platform::CPUPlace())[0];
if (max_scale < cur_scale) {
max_scale = cur_scale;
} else if (fabs(remove_tmp - max_scale) < 1e-6) {
int size = (current_iter > window_size) ? window_size : current_iter;
max_scale = T(FindAbsMax(scale_list, size));
}
return max_scale;
}
T FindMovingAverageAbsMmax(framework::Tensor* in_scale,
framework::Tensor* out_scale,
const T& cur_scale) const {
T* ins = in_scale->mutable_data<T>(platform::CPUPlace());
T* outs = out_scale->mutable_data<T>(platform::CPUPlace());
outs[0] = 0.9 * cur_scale + 0.1 * ins[0];
return T(outs[0]);
}
virtual void Compute(const framework::ExecutionContext& context) const {
auto* tensor = context.Output<framework::Tensor>("Out");
auto* in = context.Input<framework::Tensor>("X");
const bool is_test = context.Attr<bool>("is_test");
tensor->mutable_data<T>(in->place());
auto* oms_tensor = context.Output<framework::Tensor>("OutMovingScale");
oms_tensor->mutable_data<T>(in->place());
auto quantize_type =
static_cast<std::string>(context.Attr<std::string>("quantize_type"));
if (quantize_type == std::string("range_abs_max")) {
auto* oss_tensor = context.Output<framework::Tensor>("OutScales");
oss_tensor->mutable_data<T>(
context.Input<framework::Tensor>("InScales")->place());
auto* oci_tensor = context.Output<framework::Tensor>("OutCurrentIter");
oci_tensor->mutable_data<T>(
context.Input<framework::Tensor>("InCurrentIter")->place());
}
T scale = static_cast<T>(1);
int window_size = context.Attr<int>("window_size");
int bit_length = context.Attr<int>("bit_length");
int bin_cnt = std::pow(2, bit_length - 1) - 1;
auto& dev =
*context.template device_context<DeviceContext>().eigen_device();
auto raw_in = framework::EigenVector<T>::Flatten(*in);
if (quantize_type == std::string("abs_max")) {
auto* saving_scale = context.Output<framework::Tensor>("OutMovingScale");
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
scale_out.device(dev) = raw_in.abs().maximum();
scale = scale_out(0);
auto& device_ctx = context.template device_context<DeviceContext>();
auto* scale_list = context.Output<framework::Tensor>("OutScales");
math::SetConstant<DeviceContext, T> scalar;
scale_list->mutable_data<T>(context.GetPlace());
scalar(device_ctx, scale_list, static_cast<T>(0));
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
iter->mutable_data<T>(context.GetPlace());
scalar(device_ctx, iter, static_cast<T>(0));
} else if (quantize_type == std::string("range_abs_max")) {
auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
if (is_test) {
scale = moving_scale->data<T>()[0];
} else {
auto* it = context.Input<framework::Tensor>("InCurrentIter");
auto* iter = context.Output<framework::Tensor>("OutCurrentIter");
const int* last_iter = it->data<int>();
int* current_iter = iter->mutable_data<int>(platform::CPUPlace());
auto* scale_list = context.Output<framework::Tensor>("OutScales");
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
scale_out.device(dev) = raw_in.abs().maximum();
scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
scale = FindRangeAbsMax(scale_list, saving_scale, scale, window_size,
current_iter[0]);
saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
(*current_iter) = (*last_iter) + 1;
}
} else if (quantize_type == std::string("moving_average_abs_max")) {
auto* moving_scale = context.Input<framework::Tensor>("InMovingScale");
if (is_test) {
scale = moving_scale->data<T>()[0];
} else {
auto* saving_scale =
context.Output<framework::Tensor>("OutMovingScale");
auto scale_out = framework::EigenVector<T>::Flatten(*saving_scale);
scale_out.device(dev) = raw_in.abs().maximum();
scale = saving_scale->mutable_data<T>(platform::CPUPlace())[0];
scale = FindMovingAverageAbsMmax(
const_cast<framework::Tensor*>(moving_scale), saving_scale, scale);
saving_scale->mutable_data<T>(platform::CPUPlace())[0] = scale;
}
}
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), in->data<T>(),
in->data<T>() + in->numel(), tensor->mutable_data<T>(in->place()),
ClipFunctor<T>(-scale, scale));
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*tensor);
eigen_out.device(dev) = (bin_cnt / scale * eigen_in).round();
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestFakeQuantizeOp(OpTest):
def setUp(self):
self.op_type = "fake_quantize"
self.attrs = {
'bit_length': 8,
'quantize_type': 'abs_max',
'window_size': 10000
}
self.inputs = {
'X': np.random.random((10, 10)).astype("float32"),
'InScales': np.zeros(self.attrs['window_size']).astype("float32"),
'InCurrentIter': np.zeros(1).astype("float32"),
'InMovingScale': np.zeros(1).astype("float32")
}
self.scale = {
'abs_max': np.max(np.abs(self.inputs['X'])).astype("float32")
}
self.outputs = {
'Out': np.round(self.inputs['X'] / self.scale['abs_max'] * (
(1 << (self.attrs['bit_length'] - 1)) - 1)),
'OutScales': np.zeros(self.attrs['window_size']).astype("float32"),
'OutMovingScale':
np.array([self.scale['abs_max']]).astype("float32"),
'OutCurrentIter': np.zeros(1).astype("float32")
}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册