未验证 提交 3a29821b 编写于 作者: Q qingqing01 提交者: GitHub

Develop a fake dequantized op for fixed-point quantization training framework. (#10965)

* Develop a fake dequantized op for fixed-point quantization training framework.

* Add the missing file.
上级 66ec827a
...@@ -168,6 +168,8 @@ function(op_library TARGET) ...@@ -168,6 +168,8 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(relu);\n") file(APPEND ${pybind_file} "USE_OP(relu);\n")
elseif(${TARGET} STREQUAL "reduce") elseif(${TARGET} STREQUAL "reduce")
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n") file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
elseif(${TARGET} STREQUAL "fake_dequantize")
file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
else() else()
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n") file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
endif() endif()
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fake_dequantize_op.h"
#include <string>
namespace paddle {
namespace operators {
class FakeDequantizeMaxAbsOp : public framework::OperatorWithKernel {
public:
FakeDequantizeMaxAbsOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FakeDequantizeMaxAbsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of FakeDequantizeMaxAbsOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class FakeDequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor) The input with float-32/64 type is the "
"low precision tensor.");
AddOutput("Out",
"(Tensor) The output is the dequantized high "
"precision tensor.");
AddAttr<int>("num_bits",
"(int) `num_bits` is the quantization level bits, "
"such as 2, 5, 8.");
AddAttr<float>("scale",
"(float) The maximum absolute value of low precision tensor."
"It is usually calculated by the fake_quantize_max_abs_op.");
AddComment(R"DOC(
FakeDequantizeMaxAbsOp operator.
This calculation is an opposite operation of FakeQuantizeMaxAbsOp:
$$Out = \frac{scale*X}{2^{num_bits} - 1}$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsOp,
ops::FakeDequantizeMaxAbsOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CPU, float>,
ops::FakeDequantizeMaxAbsKernel<CPU, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fake_dequantize_op.h"
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
ops::FakeDequantizeMaxAbsKernel<CUDA, float>,
ops::FakeDequantizeMaxAbsKernel<CUDA, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(in->place());
int num_bits = ctx.Attr<int>("num_bits");
T scale = static_cast<T>(ctx.Attr<float>("scale"));
int range = std::pow(2, num_bits) - 1;
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
eigen_out.device(dev) = (scale / range) * eigen_in;
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import math
from op_test import OpTest
def quantize_max_abs(x, num_bits):
range = math.pow(2, num_bits) - 1
scale = np.max(np.abs(x).flatten())
y = np.round(x / scale * range)
return y, scale
def dequantize_max_abs(x, num_bits, scale):
range = math.pow(2, num_bits) - 1
y = (scale / range) * x
return y
class TestFakeDequantizeMaxAbsOp(OpTest):
def set_args(self):
self.num_bits = 8
def setUp(self):
self.set_args()
self.op_type = "fake_dequantize_max_abs"
x = np.random.randn(31, 65).astype("float32")
yq, scale = quantize_max_abs(x, self.num_bits)
print 'scale ', scale
ydq = dequantize_max_abs(yq, self.num_bits, scale)
self.inputs = {'X': yq}
self.attrs = {'num_bits': self.num_bits, 'scale': float(scale)}
self.outputs = {'Out': ydq}
def test_check_output(self):
self.check_output()
class TestFakeDequantizeMaxAbsOp5Bits(OpTest):
def set_args(self):
self.num_bits = 5
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册