Unverified commit ff7f911b, authored by Zhaolong Xing, committed by GitHub

add quant_dequant_moving_avg_max_abs op (#17480)

* add quant_dequant_moving_avg_max_abs op
test=develop

* add more note for quantdequant op
test=develop
Parent 306eadcd
paddle/fluid/operators/fake_quantize_op.cc
@@ -68,6 +68,23 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
 template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
+
+template <typename T>
+struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& scale,
+                  const int bin_cnt, framework::Tensor* out) {
+    T s = scale.data<T>()[0];
+    platform::Transform<platform::CPUDeviceContext> trans;
+    trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
+          out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
+    auto out_e = framework::EigenVector<T>::Flatten(*out);
+    out_e.device(*ctx.eigen_device()) =
+        (s / bin_cnt) * (bin_cnt / s * out_e).round();
+  }
+};
+template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,
+                                               float>;
 
 template <typename T>
 struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& ctx,
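For reference, the math this functor encodes, clip to [-s, s], snap to one of bin_cnt uniform levels per side, then map back to floats, fits in a few lines of NumPy. This is an illustrative sketch only, not code from the commit, and `fake_quant_dequant` is a hypothetical helper name:

import numpy as np

def fake_quant_dequant(x, s, bin_cnt):
    # clip to [-s, s], quantize to integer steps of s / bin_cnt, then
    # rescale back: the result is a float tensor with the quantization
    # error baked in, which is what the fused quant-dequant op emits
    v = np.clip(x, -s, s)
    return np.round(bin_cnt / s * v) * s / bin_cnt

x = np.array([-1.5, -0.30, 0.27, 0.90], dtype=np.float32)
print(fake_quant_dequant(x, s=1.0, bin_cnt=127))  # multiples of 1/127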
@@ -324,24 +341,26 @@ $$Out = round(X/scale * range)$$
   }
 };
 
-class FakeQuantizeMovingAverageAbsMaxOp : public framework::OperatorWithKernel {
+class FakeQuantOrWithDequantMovingAverageAbsMaxOp
+    : public framework::OperatorWithKernel {
  public:
-  FakeQuantizeMovingAverageAbsMaxOp(const std::string& type,
-                                    const framework::VariableNameMap& inputs,
-                                    const framework::VariableNameMap& outputs,
-                                    const framework::AttributeMap& attrs)
+  FakeQuantOrWithDequantMovingAverageAbsMaxOp(
+      const std::string& type, const framework::VariableNameMap& inputs,
+      const framework::VariableNameMap& outputs,
+      const framework::AttributeMap& attrs)
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(
-        ctx->HasInput("X"),
-        "Input(X) of FakeQuantizeMovingAverageAbsMaxOp should not be null.");
-    PADDLE_ENFORCE(
-        ctx->HasOutput("Out"),
-        "Output(Out) of FakeQuantizeMovingAverageAbsMaxOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
-                   "Output(OutScale) of FakeQuantizeMovingAverageAbsMaxOp "
-                   "should not be null");
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of FakeQuantOrWithDequantMovingAverageAbsMaxOp "
+                   "should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of FakeQuantOrWithDequantMovingAverageAbsMaxOp "
+                   "should not be null.");
+    PADDLE_ENFORCE(
+        ctx->HasOutput("OutScale"),
+        "Output(OutScale) of FakeQuantOrWithDequantMovingAverageAbsMaxOp "
+        "should not be null");
     if (ctx->HasOutput("OutState")) {
       ctx->SetOutputDim("OutState", {1});
     }
@@ -361,7 +380,7 @@ class FakeQuantizeMovingAverageAbsMaxOp : public framework::OperatorWithKernel {
   }
 };
 
-class FakeQuantizeMovingAverageAbsMaxOpMaker
+class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
     : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -386,12 +405,19 @@ class FakeQuantizeMovingAverageAbsMaxOpMaker
              "for training. Some layers may run faster when this is true.")
         .SetDefault(false);
     AddComment(R"DOC(
-FakeQuantize operator is used in static quantization.
+This is a base op that supports both FakeQuantMovingAverageAbsMaxOp and
+FakeQuantDequantMovingAverageAbsMaxOp.
+
+The FakeQuantMovingAverageAbsMaxOp operator is used in static quantization:
 
 $$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
 $$range = 2^{bit\_length - 1} - 1$$
 $$Out = round(X/scale * range)$$
 
+The FakeQuantDequantMovingAverageAbsMaxOp operator applies the same
+moving_average_abs_max quantization and then dequantizes the result:
+
+$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
+$$range = 2^{bit\_length - 1} - 1$$
+$$Out = round(X/scale * range) * scale / range$$
+
 )DOC");
   }
 };
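As a worked example of the formulas above (an illustrative sketch, not part of the commit; the starting values mirror the defaults used by the unit tests later in this diff):

import numpy as np

moving_rate, accum, state = 0.9, 1.0, 1.0       # test defaults below
x = np.array([0.2, -0.7, 0.5], dtype=np.float32)

accum = moving_rate * accum + np.abs(x).max()   # 0.9*1 + 0.7 = 1.6
state = moving_rate * state + 1                 # 0.9*1 + 1   = 1.9
scale = accum / state                           # ~0.8421

bit_length = 8
quant_range = (1 << (bit_length - 1)) - 1       # range = 2^7 - 1 = 127

out_quant = np.round(x / scale * quant_range)          # fake_quantize_... output
out_quant_dequant = out_quant * scale / quant_range    # ..._dequantize_... output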
@@ -477,11 +503,21 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
                        ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);
 
 REGISTER_OPERATOR(fake_quantize_moving_average_abs_max,
-                  ops::FakeQuantizeMovingAverageAbsMaxOp,
-                  ops::FakeQuantizeMovingAverageAbsMaxOpMaker,
+                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
+                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
                   paddle::framework::EmptyGradOpMaker);
 REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
                        ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
+
+REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
+                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
+                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
+                  paddle::framework::EmptyGradOpMaker);
+REGISTER_OP_CPU_KERNEL(
+    fake_quantize_dequantize_moving_average_abs_max,
+    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
 
 REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
                   ops::FakeChannelWiseQuantizeAbsMaxOp,
                   ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
...
paddle/fluid/operators/fake_quantize_op.cu
@@ -129,6 +129,23 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
   }
 }
+
+template <typename T>
+__global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
+                                          const int bin_cnt, const int n,
+                                          T* out) {
+  int bid = threadIdx.x + blockIdx.x * blockDim.x;
+  int tid = threadIdx.x;
+
+  T s = scale[0];
+  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
+    T x = in[i];
+    T v = x > s ? s : x;
+    v = v < -s ? -s : v;
+    v = bin_cnt / s * v;
+    out[i] = round(v) * s / bin_cnt;
+  }
+}
 
 template <typename T>
 struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& ctx,
@@ -149,6 +166,27 @@ struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
 template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
+
+template <typename T>
+struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& scale,
+                  const int bin_cnt, framework::Tensor* out) {
+    int num = in.numel();
+    int block = 1024;
+    int grid = (block - 1 + num) / block;
+
+    const T* in_data = in.data<T>();
+    const T* scale_data = scale.data<T>();
+    T* out_data = out->mutable_data<T>(ctx.GetPlace());
+
+    ClipAndQuantDequantKernel<T><<<grid, block, 0, ctx.stream()>>>(
+        in_data, scale_data, bin_cnt, num, out_data);
+  }
+};
+
+template struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext,
+                                               float>;
 
 template <typename T>
 __global__ void ChannelClipAndQuantKernel(const T* in, const T* scale,
                                           const int bin_cnt, const int n,
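The new kernel uses a standard grid-stride loop, and the functor sizes the launch with the ceiling division `grid = (block - 1 + num) / block`. A small Python model of that index pattern, illustrative only and not part of the commit, shows that the threads cover all n elements exactly once:

def grid_stride_indices(block_dim, grid_dim, n):
    # emulate one CUDA thread per (block, thread) pair: each starts at its
    # global id and strides by blockDim.x * gridDim.x, as in the kernel above
    total = block_dim * grid_dim
    for bid in range(total):
        i = bid
        while i < n:
            yield i
            i += total

num, block = 5000, 1024
grid = (block - 1 + num) // block       # ceiling division, as in the functor
assert grid == 5                        # 5 * 1024 = 5120 threads >= 5000
assert sorted(grid_stride_indices(block, grid, num)) == list(range(num))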
@@ -302,3 +340,6 @@ REGISTER_OP_CUDA_KERNEL(
     ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);
 REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
                         ops::MovingAverageAbsMaxScaleKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(
+    fake_quantize_dequantize_moving_average_abs_max,
+    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);
...

paddle/fluid/operators/fake_quantize_op.h
@@ -35,6 +35,13 @@ struct ClipAndFakeQuantFunctor {
                   framework::Tensor* out);
 };
 
+template <typename DeviceContext, typename T>
+struct ClipAndFakeQuantDequantFunctor {
+  void operator()(const DeviceContext& ctx, const framework::Tensor& in,
+                  const framework::Tensor& scale, const int bin_cnt,
+                  framework::Tensor* out);
+};
 
 template <typename DeviceContext, typename T>
 struct FindRangeAbsMaxFunctor {
   void operator()(const DeviceContext& ctx, const framework::Tensor& cur_scale,
@@ -150,8 +157,13 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
 };
 
 template <typename DeviceContext, typename T>
-class FakeQuantizeMovingAverageAbsMaxKernel : public framework::OpKernel<T> {
+class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
  public:
+  ~FakeMovingAverageAbsMaxKernelBase() {}
+
+  virtual void RunClipFunctor(const DeviceContext& dev_ctx,
+                              const framework::Tensor& in,
+                              const framework::Tensor& in_scale, int bin_cnt,
+                              framework::Tensor* out) const = 0;
+
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<framework::Tensor>("X");
     auto* in_scale = context.Input<framework::Tensor>("InScale");
@@ -165,8 +177,7 @@ class FakeQuantizeMovingAverageAbsMaxKernel : public framework::OpKernel<T> {
     // testing
     if (is_test) {
-      ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale,
-                                                  bin_cnt, out);
+      RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, out);
       return;
     }
@@ -193,8 +204,31 @@ class FakeQuantizeMovingAverageAbsMaxKernel : public framework::OpKernel<T> {
         dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state,
         out_accum, out_scale);
 
-    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
-                                                bin_cnt, out);
+    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class FakeQuantizeMovingAverageAbsMaxKernel
+    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
+ public:
+  void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
+                      const framework::Tensor& in_scale, int bin_cnt,
+                      framework::Tensor* out) const override {
+    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, in, in_scale, bin_cnt,
+                                                out);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
+    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
+ public:
+  void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
+                      const framework::Tensor& in_scale, int bin_cnt,
+                      framework::Tensor* out) const override {
+    ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(dev_ctx, in, in_scale,
+                                                       bin_cnt, out);
   }
 };
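The refactor above follows the template-method pattern: FakeMovingAverageAbsMaxKernelBase::Compute owns the shared moving-average bookkeeping and defers only the final clip step to the pure virtual RunClipFunctor, so the quantize and quantize-dequantize kernels differ in a single override. A minimal Python sketch of the same shape, with hypothetical class names and not code from the commit:

import numpy as np

class MovingAverageAbsMaxBase:
    """Shared scale bookkeeping; subclasses supply only the clip step."""

    def run_clip(self, x, scale, bin_cnt):
        raise NotImplementedError

    def compute(self, x, accum, state, moving_rate=0.9, bin_cnt=127):
        # the moving-average update is identical for both kernels
        accum = moving_rate * accum + np.abs(x).max()
        state = moving_rate * state + 1
        scale = accum / state
        return self.run_clip(x, scale, bin_cnt), accum, state, scale

class QuantKernel(MovingAverageAbsMaxBase):
    def run_clip(self, x, scale, bin_cnt):
        return np.round(np.clip(x, -scale, scale) / scale * bin_cnt)

class QuantDequantKernel(MovingAverageAbsMaxBase):
    def run_clip(self, x, scale, bin_cnt):
        q = np.round(np.clip(x, -scale, scale) / scale * bin_cnt)
        return q * scale / bin_cnt

x = np.random.uniform(-1, 1, 8).astype("float32")
y, accum, state, scale = QuantDequantKernel().compute(x, accum=1.0, state=1.0)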
...

python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
@@ -90,46 +90,6 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
         self.check_output()
-
-class TestFakeQuantizeMovingOp(OpTest):
-    def setUp(self):
-        self.op_type = "fake_quantize_moving_average_abs_max"
-        self.attrs = {
-            'bit_length': int(5),
-            'moving_rate': float(0.9),
-            'is_test': False
-        }
-        accum = np.zeros(1).astype("float32")
-        accum[0] = 1
-        state = np.zeros(1).astype("float32")
-        state[0] = 1
-        scale = np.zeros(1).astype("float32")
-        scale[0] = 0.001
-        self.inputs = {
-            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
-            'InScale': scale,
-            'InAccum': accum,
-            'InState': state,
-        }
-
-        out_accum = np.zeros(1).astype("float32")
-        out_state = np.zeros(1).astype("float32")
-        out_scale = np.zeros(1).astype("float32")
-        out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max(
-            np.abs(self.inputs['X'])).astype("float32")
-        out_state[0] = self.attrs['moving_rate'] * state[0] + 1
-        out_scale = out_accum / out_state
-        self.outputs = {
-            'Out': np.round(self.inputs['X'] / out_scale * (
-                (1 << (self.attrs['bit_length'] - 1)) - 1)),
-            'OutAccum': out_accum,
-            'OutState': out_state,
-            'OutScale': out_scale,
-        }
-
-    def test_check_output(self):
-        self.check_output()
-
 class TestMovingAverageAbsMaxScaleOp(OpTest):
     def setUp(self):
         self.op_type = "moving_average_abs_max_scale"
@@ -193,5 +153,62 @@ class TestFakeQuantizeRangeAbsMaxOp2(OpTest):
         self.check_output(no_check_set=set(['OutScale', 'OutScales']))
+
+class TestMovingOpBase(OpTest):
+    def setUp(self):
+        self.init_type()
+        self.attrs = {
+            'bit_length': int(5),
+            'moving_rate': float(0.9),
+            'is_test': False
+        }
+        accum = np.zeros(1).astype("float32")
+        accum[0] = 1
+        state = np.zeros(1).astype("float32")
+        state[0] = 1
+        scale = np.zeros(1).astype("float32")
+        scale[0] = 0.001
+        self.inputs = {
+            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
+            'InScale': scale,
+            'InAccum': accum,
+            'InState': state,
+        }
+
+        out_accum = np.zeros(1).astype("float32")
+        out_state = np.zeros(1).astype("float32")
+        out_scale = np.zeros(1).astype("float32")
+        out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max(
+            np.abs(self.inputs['X'])).astype("float32")
+        out_state[0] = self.attrs['moving_rate'] * state[0] + 1
+        out_scale = out_accum / out_state
+        out_data = self.calc_output(out_scale)
+        self.outputs = {
+            'Out': out_data,
+            'OutAccum': out_accum,
+            'OutState': out_state,
+            'OutScale': out_scale,
+        }
+
+    def init_type(self):
+        self.op_type = "fake_quantize_moving_average_abs_max"
+
+    def calc_output(self, out_scale):
+        return np.round(self.inputs['X'] / out_scale * (
+            (1 << (self.attrs['bit_length'] - 1)) - 1))
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestFakeQuantDequantMovingOp(TestMovingOpBase):
+    def init_type(self):
+        self.op_type = "fake_quantize_dequantize_moving_average_abs_max"
+
+    def calc_output(self, out_scale):
+        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
+        return np.round(self.inputs['X'] / out_scale *
+                        range_v) * out_scale / range_v
+
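As a sanity check on calc_output (an illustrative snippet, not part of the test file): the quant-dequant result should differ from any in-range input by at most half a quantization step, scale / range_v / 2, which is why Out can be compared against a reference computed the same way:

import numpy as np

scale, bits = 0.8, 5
range_v = (1 << (bits - 1)) - 1                  # 15 for bit_length = 5
x = np.random.uniform(-scale, scale, 1000).astype("float32")

out = np.round(x / scale * range_v) * scale / range_v
assert np.abs(out - x).max() <= scale / range_v / 2 + 1e-6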
 
 if __name__ == "__main__":
     unittest.main()