未验证 提交 0353eddb 编写于 作者: Q qingqing01 提交者: GitHub

Improve fake_dequantize_op. (#12877)

* Improve fake_dequantize_op.
* Follow comments.
上级 11e01d9b
......@@ -18,15 +18,32 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
struct DequantizeFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor* scale,
T max_range, framework::Tensor* out) {
auto in_e = framework::EigenVector<T>::Flatten(*in);
const T* scale_factor = scale->data<T>();
auto out_e = framework::EigenVector<T>::Flatten(*out);
auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = (scale_factor[0] / max_range) * in_e;
}
};
template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
public:
FakeDequantizeMaxAbsOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
FakeDequantizeMaxAbsOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeDequantizeMaxAbsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -42,21 +59,17 @@ class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X",
"(Tensor) The input with float-32/64 type is the "
"low precision tensor.");
AddInput("Scale", "(float) The scale in quantization stage.");
AddOutput("Out",
"(Tensor) The output is the dequantized high "
"precision tensor.");
AddAttr<int>("num_bits",
"(int) `num_bits` is the quantization level bits, "
"such as 2, 5, 8.");
AddAttr<float>("scale",
"(float) The maximum absolute value of low precision tensor."
"It is usually calculated by the fake_quantize_max_abs_op.");
AddAttr<float>("max_range", "(float) The max range in quantization stage.");
AddComment(R"DOC(
FakeDequantizeMaxAbsOp operator.
This calculation is an opposite operation of FakeQuantizeMaxAbsOp:
$$Out = \frac{scale*X}{2^{num_bits} - 1}$$
$$Out = \frac{scale*X}{ max_range }$$
)DOC");
}
......
......@@ -14,6 +14,42 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_dequantize_op.h"
namespace paddle {
namespace operators {
template <typename T>
__global__ void KeDequantize(const T* in, const T* scale, T max_range, int num,
T* out) {
const int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < num) {
out[idx] = in[idx] * scale[0] / max_range;
}
}
template <typename T>
struct DequantizeFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor* scale,
T max_range, framework::Tensor* out) {
const T* in_data = in->data<T>();
const T* scale_factor = scale->data<T>();
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
int num = in->numel();
int block = 512;
int grid = (num + block - 1) / block;
KeDequantize<T><<<grid, block, 0, dev_ctx.stream()>>>(
in_data, scale_factor, max_range, num, out_data);
}
};
template struct DequantizeFunctor<platform::CUDADeviceContext, float>;
template struct DequantizeFunctor<platform::CUDADeviceContext, double>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
......
......@@ -19,22 +19,29 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
struct DequantizeFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
const framework::Tensor* scale, T max_range,
framework::Tensor* out);
};
template <typename DeviceContext, typename T>
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::Tensor>("X");
auto* scale = ctx.Input<framework::Tensor>("Scale");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(in->place());
int num_bits = ctx.Attr<int>("num_bits");
T scale = static_cast<T>(ctx.Attr<float>("scale"));
int range = std::pow(2, num_bits) - 1;
float max_range = ctx.Attr<float>("max_range");
auto& dev_ctx = ctx.template device_context<DeviceContext>();
out->mutable_data<T>(dev_ctx.GetPlace());
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
eigen_out.device(dev) = (scale / range) * eigen_in;
DequantizeFunctor<DeviceContext, T>()(dev_ctx, in, scale,
static_cast<T>(max_range), out);
}
};
......
......@@ -20,41 +20,50 @@ import math
from op_test import OpTest
def quantize_max_abs(x, num_bits):
range = math.pow(2, num_bits) - 1
def quantize_max_abs(x, max_range):
scale = np.max(np.abs(x).flatten())
y = np.round(x / scale * range)
y = np.round(x / scale * max_range)
return y, scale
def dequantize_max_abs(x, num_bits, scale):
range = math.pow(2, num_bits) - 1
y = (scale / range) * x
def dequantize_max_abs(x, scale, max_range):
y = (scale / max_range) * x
return y
class TestFakeDequantizeMaxAbsOp(OpTest):
def set_args(self):
self.num_bits = 8
self.max_range = math.pow(2, self.num_bits - 1) - 1
self.data_type = "float32"
def setUp(self):
self.set_args()
self.op_type = "fake_dequantize_max_abs"
x = np.random.randn(31, 65).astype("float32")
yq, scale = quantize_max_abs(x, self.num_bits)
ydq = dequantize_max_abs(yq, self.num_bits, scale)
x = np.random.randn(31, 65).astype(self.data_type)
yq, scale = quantize_max_abs(x, self.max_range)
ydq = dequantize_max_abs(yq, scale, self.max_range)
self.inputs = {'X': yq}
self.attrs = {'num_bits': self.num_bits, 'scale': float(scale)}
self.inputs = {'X': yq, 'Scale': np.array(scale).astype(self.data_type)}
self.attrs = {'max_range': self.max_range}
self.outputs = {'Out': ydq}
def test_check_output(self):
self.check_output()
class TestFakeDequantizeMaxAbsOp5Bits(OpTest):
class TestFakeDequantizeMaxAbsOpDouble(TestFakeDequantizeMaxAbsOp):
def set_args(self):
self.num_bits = 8
self.max_range = math.pow(2, self.num_bits - 1) - 1
self.data_type = "float64"
class TestFakeDequantizeMaxAbsOp5Bits(TestFakeDequantizeMaxAbsOp):
def set_args(self):
self.num_bits = 5
self.max_range = math.pow(2, self.num_bits - 1) - 1
self.data_type = "float32"
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册