Commit 545247d7 authored by Zhen Wang

Add channel-wise quantize op.

Parent 187cffd0
......@@ -134,6 +134,61 @@ $$Out = round(X/scale * range)$$
}
};
class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(
        ctx->HasInput("X"),
        "Input(X) of FakeChannelWiseQuantizeAbsMaxOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("Out"),
        "Output(Out) of FakeChannelWiseQuantizeAbsMaxOp should not be null.");
    PADDLE_ENFORCE(
        ctx->HasOutput("OutScales"),
        "Output(OutScales) of FakeChannelWiseQuantizeAbsMaxOp should not be "
        "null.");
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    // One scale per channel; dim 0 of X is treated as the channel axis.
    ctx->SetOutputDim("OutScales", {ctx->GetInputDim("X")[0]});
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
                                   ctx.GetPlace());
  }
};

class FakeChannelWiseQuantizeAbsMaxOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input is float data type.");
    AddOutput("Out",
              "(Tensor) Output of quantized low level tensor, "
              "but also saved as float data type.");
    AddOutput("OutScales", "(Tensor) Current channel-wise scale");
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
        .AddCustomChecker([](const int& bit_length) {
          PADDLE_ENFORCE(bit_length >= 1 && bit_length <= 16,
                         "'bit_length' should be between 1 and 16.");
        });
    AddComment(R"DOC(
The scale of the FakeChannelWiseQuantizeAbsMax operator is a vector:
each channel of the input X has its own scale value.

$$scale_c = max(abs(X_c))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out_c = round(X_c / scale_c * range)$$

In the above three formulas, c ranges over the channels of X:
$$0 \leq c < \ the\ channel\ number\ of\ X$$
)DOC");
  }
};
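
For concreteness, the formulas above reduce to the following minimal NumPy sketch (illustrative only, not part of the diff; it assumes dim 0 of `x` is the channel axis, as the kernel below does):

    import numpy as np

    def channel_wise_fake_quant(x, bit_length=8):
        # range = 2^(bit_length - 1) - 1, e.g. 127 for the default bit_length of 8.
        bin_cnt = (1 << (bit_length - 1)) - 1
        # scale_c = max(abs(X_c)): the abs-max of each channel (dim 0 of x).
        scales = np.abs(x).reshape(x.shape[0], -1).max(axis=1)
        # Out_c = round(X_c / scale_c * range), broadcasting one scale per channel.
        out = np.round(x / scales.reshape((-1,) + (1,) * (x.ndim - 1)) * bin_cnt)
        return out, scales

    x = np.random.random((4, 3, 64, 64)).astype("float32")
    out, scales = channel_wise_fake_quant(x)
    assert out.shape == x.shape and scales.shape == (4,)
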
class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
 public:
  FakeQuantizeRangeAbsMaxOp(const std::string& type,
......@@ -218,3 +273,10 @@ REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
                  paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
                       ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);

REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
                  ops::FakeChannelWiseQuantizeAbsMaxOp,
                  ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
                  paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
                       ops::FakeChannelWiseQuantizeAbsMaxKernel<CPU, float>);
......@@ -174,5 +174,7 @@ namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
                        ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max,
                        ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
                        ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
......@@ -63,6 +63,39 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<framework::Tensor>("X");
    auto* out = context.Output<framework::Tensor>("Out");
    auto* out_scales = context.Output<framework::Tensor>("OutScales");
    T* out_scales_data = out_scales->mutable_data<T>(context.GetPlace());
    out->mutable_data<T>(context.GetPlace());

    int bit_length = context.Attr<int>("bit_length");
    // range = 2^(bit_length - 1) - 1, e.g. 127 when bit_length is 8.
    int bin_cnt = std::pow(2, bit_length - 1) - 1;

    auto& dev_ctx = context.template device_context<DeviceContext>();
    // First pass: find the abs-max scale of each channel (dim 0 of X).
    auto find_abs_max = FindAbsMaxFunctor<DeviceContext, T>();
    for (int64_t i = 0; i < in->dims()[0]; i++) {
      framework::Tensor one_channel = in->Slice(i, i + 1);
      const T* one_channel_data = one_channel.data<T>();
      find_abs_max(dev_ctx, one_channel_data, one_channel.numel(),
                   &out_scales_data[i]);
    }
    // Second pass: clip and fake-quantize each channel with its own scale.
    auto clip_quant = ClipAndFakeQuantFunctor<DeviceContext, T>();
    for (int64_t i = 0; i < in->dims()[0]; i++) {
      framework::Tensor one_channel_in = in->Slice(i, i + 1);
      framework::Tensor one_channel_out = out->Slice(i, i + 1);
      framework::Tensor one_channel_scale = out_scales->Slice(i, i + 1);
      clip_quant(dev_ctx, one_channel_in, one_channel_scale, bin_cnt,
                 &one_channel_out);
    }
  }
};
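
For intuition, each `clip_quant` call above corresponds to roughly the following per-channel computation (a NumPy sketch; `ClipAndFakeQuantFunctor` is defined elsewhere in this file, so clip-then-round is an assumption here, not a quote of its implementation):

    import numpy as np

    def clip_and_fake_quant(x, scale, bin_cnt):
        # Assumed semantics: clip to [-scale, scale], then round onto the
        # integer grid [-bin_cnt, bin_cnt], kept in float storage.
        clipped = np.clip(x, -scale, scale)
        return np.round(clipped / scale * bin_cnt)
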
template <typename DeviceContext, typename T>
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
 public:
......
......@@ -35,6 +35,30 @@ class TestFakeQuantizeOp(OpTest):
        self.check_output()
class TestFakeChannelWiseQuantizeOp(OpTest):
    def setUp(self):
        self.op_type = "fake_channel_wise_quantize_abs_max"
        self.attrs = {'bit_length': 8}
        self.inputs = {
            'X': np.random.random((4, 3, 64, 64)).astype("float32"),
        }
        # One abs-max scale per channel (dim 0 of X).
        scales = []
        for i in range(self.inputs['X'].shape[0]):
            scales.append(np.max(np.abs(self.inputs['X'][i])).astype("float32"))
        # Out_c = round(X_c / scale_c * (2^(bit_length - 1) - 1))
        outputs = self.inputs['X'].copy()
        for i, scale in enumerate(scales):
            outputs[i] = np.round(outputs[i] / scale * (
                (1 << (self.attrs['bit_length'] - 1)) - 1))
        self.outputs = {
            'Out': outputs,
            'OutScales': np.array(scales).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()
class TestFakeQuantizeRangeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = "fake_quantize_range_abs_max"
......