未验证 提交 a914d9b1 编写于 作者: Z Zhen Wang 提交者: GitHub

Quant output scale (#17215)

* Add MovingAverageAbsMaxScale operator which is only used for calculating the quantization scale.

* test=develop

* change the output into inplace. test=develop

* Revert "test=develop"

This reverts commit 696cf626.

* Revert "change the output into inplace. test=develop"

This reverts commit a19acd20.

* test=develop.

* update the MovingAverageAbsMaxScaleOp test. test=develop
上级 32b62c25
......@@ -388,14 +388,76 @@ class FakeQuantizeMovingAverageAbsMaxOpMaker
AddComment(R"DOC(
FakeQuantize operator is used in static quantization.
$$scale = (0.9*max(abs(x))+accum)/(0.9*state+1)$$
$$range = 2^{bit_length - 1} - 1$$
$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out = round(X/scale * range)$$
)DOC");
}
};
class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(
ctx->HasInput("X"),
"Input(X) of MovingAverageAbsMaxScaleOp should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("Out"),
"Output(Out) of MovingAverageAbsMaxScaleOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("OutScale"),
"Output(OutScale) of MovingAverageAbsMaxScaleOp"
"should not be null");
if (ctx->HasOutput("OutState")) {
ctx->SetOutputDim("OutState", {1});
}
if (ctx->HasOutput("OutAccum")) {
ctx->SetOutputDim("OutAccum", {1});
}
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("OutScale", {1});
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.GetPlace());
}
};
class MovingAverageAbsMaxScaleOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input is float data type.");
AddInput("InAccum", "Last accum.").AsDispensable();
AddInput("InState", "Last state.").AsDispensable();
AddOutput("Out",
"(Tensor) Output tensor is just equivalent to the input tensor.");
AddOutput("OutScale", " Current scale");
AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
AddAttr<float>("moving_rate", "(float, default 0.9) moving rate.")
.SetDefault(0.9);
AddAttr<bool>("is_test",
"(bool, default false) Set true for inference only and false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddComment(R"DOC(
MovingAverageAbsMaxScale operator is only used for calculating the quantization scale.
And it will not quantize the input tensor.
$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$Out = X$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
......@@ -426,3 +488,9 @@ REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp,
ops::MovingAverageAbsMaxScaleOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
ops::MovingAverageAbsMaxScaleKernel<CPU, float>);
......@@ -300,3 +300,5 @@ REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
REGISTER_OP_CUDA_KERNEL(
fake_quantize_moving_average_abs_max,
ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
ops::MovingAverageAbsMaxScaleKernel<CUDA, float>);
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
......@@ -197,5 +198,46 @@ class FakeQuantizeMovingAverageAbsMaxKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
bool is_test = context.Attr<bool>("is_test");
// testing
if (is_test) {
return;
}
// training
auto* in_accum = context.Input<framework::Tensor>("InAccum");
auto* in_state = context.Input<framework::Tensor>("InState");
auto& allocator =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx);
auto cur_scale = allocator.Allocate(1 * sizeof(T));
T* cur_scale_data = static_cast<T*>(cur_scale->ptr());
FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(), in->numel(),
cur_scale_data);
auto* out_state = context.Output<framework::Tensor>("OutState");
auto* out_accum = context.Output<framework::Tensor>("OutAccum");
auto* out_scale = context.Output<framework::Tensor>("OutScale");
out_state->mutable_data<T>(context.GetPlace());
out_accum->mutable_data<T>(context.GetPlace());
out_scale->mutable_data<T>(context.GetPlace());
float moving_rate = context.Attr<float>("moving_rate");
FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(
dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state,
out_accum, out_scale);
}
};
} // namespace operators
} // namespace paddle
......@@ -130,6 +130,38 @@ class TestFakeQuantizeMovingOp(OpTest):
self.check_output()
class TestMovingAverageAbsMaxScaleOp(OpTest):
def setUp(self):
self.op_type = "moving_average_abs_max_scale"
self.attrs = {'moving_rate': float(0.9), 'is_test': False}
accum = np.zeros(1).astype("float32")
accum[0] = 1
state = np.zeros(1).astype("float32")
state[0] = 1
self.inputs = {
'X': np.random.random((8, 16, 7, 7)).astype("float32"),
'InAccum': accum,
'InState': state,
}
out_accum = np.zeros(1).astype("float32")
out_state = np.zeros(1).astype("float32")
out_scale = np.zeros(1).astype("float32")
out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max(
np.abs(self.inputs['X'])).astype("float32")
out_state[0] = self.attrs['moving_rate'] * state[0] + 1
out_scale = out_accum / out_state
self.outputs = {
'Out': self.inputs['X'],
'OutAccum': out_accum,
'OutState': out_state,
'OutScale': out_scale,
}
def test_check_output(self):
self.check_output()
class TestFakeQuantizeRangeAbsMaxOp2(OpTest):
def setUp(self):
self.op_type = "fake_quantize_range_abs_max"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册