Unverified commit bb45af02, authored by Zhen Wang, committed by GitHub

add the c++ part of Imperative QAT. test=develop (#25446)

Parent: 090a331d
@@ -29,7 +29,7 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
    auto out_e = framework::EigenVector<T>::Flatten(*out);
    auto& dev = *dev_ctx.eigen_device();
    out_e.device(dev) = in_e * scale_factor[0] / max_range;
  }
};
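For readers skimming the diff, the dequantization above is just an element-wise rescale. A minimal standalone sketch of the same math in plain C++, independent of the Paddle/Eigen types in this file (the function name is illustrative only):

#include <vector>

// Abs-max dequantization: map quantized values back to floats by multiplying
// with scale / max_range, mirroring in_e * scale_factor[0] / max_range above.
std::vector<float> DequantizeAbsMax(const std::vector<float>& in, float scale,
                                    float max_range) {
  std::vector<float> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    out[i] = in[i] * scale / max_range;
  }
  return out;
}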
@@ -48,7 +48,7 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
      auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
      auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
      auto& dev = *dev_ctx.eigen_device();
      out_e.device(dev) = in_e * s / max_range;
    }
  } else if (scale_num == 2) {
    int batch_size = in->dims()[0];
@@ -67,7 +67,7 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
        auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
        auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
        auto& dev = *dev_ctx.eigen_device();
        out_e.device(dev) = in_e * s * scale_two[0] / max_range;
      }
    }
  }
......
@@ -82,7 +82,7 @@ struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
        out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
    auto out_e = framework::EigenVector<T>::Flatten(*out);
    out_e.device(*ctx.eigen_device()) =
        (bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
  }
};
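The functor above implements the fake quantize-then-dequantize round trip. A minimal sketch of the same math without Eigen (names are illustrative, not part of the Paddle API):

#include <algorithm>
#include <cmath>
#include <vector>

// Clip x to [-s, s], quantize to bin_cnt integer levels, then map back to float.
std::vector<float> ClipAndFakeQuantDequant(const std::vector<float>& in, float s,
                                           int bin_cnt) {
  const float inv_s = 1.0f / s;
  std::vector<float> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    float v = std::max(-s, std::min(s, in[i]));              // ClipFunctor(-s, s)
    out[i] = std::round(bin_cnt * inv_s * v) * s / bin_cnt;  // quantize, then dequantize
  }
  return out;
}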
template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,

@@ -171,20 +171,21 @@ struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext, T> {
template struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext,
                                               float>;
class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
 public:
  FakeQuantOrWithDequantAbsMaxOp(const std::string& type,
                                 const framework::VariableNameMap& inputs,
                                 const framework::VariableNameMap& outputs,
                                 const framework::AttributeMap& attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
                   "FakeQuantOrWithDequantAbsMaxOp");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
                   "FakeQuantOrWithDequantAbsMaxOp");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
                   "FakeQuantOrWithDequantAbsMaxOp");
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->SetOutputDim("OutScale", {1});
    ctx->ShareLoD("X", /*->*/ "Out");

@@ -199,7 +200,8 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
  }
};
class FakeQuantOrWithDequantAbsMaxOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input is float data type.");

@@ -217,12 +219,19 @@ class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
                          bit_length));
    });
AddComment(R"DOC( AddComment(R"DOC(
FakeQuantize operator This is a Base Op which support FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker.
FakeQuantAbsMaxOp operator is used in the dynamic quantization.
$$scale = max(abs(X))$$ $$scale = max(abs(X))$$
$$range = 2^{bit_length - 1} - 1$$ $$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$ $$Out = round(X/scale * range)$$
FakeQuantDequantAbsMaxOp operator do the abs_max quant and then dequant.
$$scale = max(abs(X))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out = round(X/scale * range) * scale / range$$
)DOC"); )DOC");
} }
}; };
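As a concrete worked example of the formulas above (assuming bit_length = 8, so range = 127, and an input whose largest absolute value is 2.0):

$$scale = 2.0, \quad range = 2^{7} - 1 = 127$$
$$x = 1.0: \quad round(1.0 / 2.0 \times 127) = 64, \quad Out = 64 \times 2.0 / 127 \approx 1.008$$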
@@ -414,14 +423,14 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
                 "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddComment(R"DOC(
This is a base Op which supports FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp.

FakeQuantMovingAverageAbsMaxOp is used in static quantization.

$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out = round(X/scale * range)$$

FakeQuantDequantMovingAverageAbsMaxOp performs the moving_average_abs_max quantization and then dequantizes the result.

$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$range = 2^{bit\_length - 1} - 1$$

@@ -490,6 +499,46 @@ $$Out = X$$
  }
};
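A small standalone sketch of the moving-average abs-max update described by the formula above (plain C++; the function name and the pointer-based running statistics are illustrative, not Paddle's actual signature):

#include <algorithm>
#include <cmath>
#include <vector>

// Update running statistics and derive the quantization scale:
//   state <- moving_rate * state + 1
//   accum <- moving_rate * accum + max(abs(x))
//   scale  = accum / state
void UpdateMovingAverageAbsMax(const std::vector<float>& x, float moving_rate,
                               float* state, float* accum, float* scale) {
  float cur_max = 0.0f;
  for (float v : x) cur_max = std::max(cur_max, std::fabs(v));
  *state = moving_rate * (*state) + 1.0f;
  *accum = moving_rate * (*accum) + cur_max;
  *scale = (*accum) / (*state);
}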
class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
"FakeQuantDequantGradOp");
auto x_grad_name = framework::GradVarName("X");
PADDLE_ENFORCE_EQ(
ctx->HasOutput(x_grad_name), true,
platform::errors::PreconditionNotMet(
"FakeQuantDequantGradOp doesn't have the output named %s.",
x_grad_name));
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
class FakeQuantDequantGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("fake_quantize_dequantize_grad");
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
}  // namespace operators
}  // namespace paddle
@@ -497,13 +546,21 @@ namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

REGISTER_OPERATOR(
    fake_quantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp,
    ops::FakeQuantOrWithDequantAbsMaxOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
                       ops::FakeQuantizeAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(fake_quantize_dequantize_abs_max,
ops::FakeQuantOrWithDequantAbsMaxOp,
ops::FakeQuantOrWithDequantAbsMaxOpMaker,
ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(
    fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
    ops::FakeQuantizeRangeAbsMaxOpMaker,

@@ -518,16 +575,14 @@ REGISTER_OPERATOR(
    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
                       ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
    fake_quantize_dequantize_moving_average_abs_max,
    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);
@@ -547,3 +602,7 @@ REGISTER_OPERATOR(
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
                       ops::MovingAverageAbsMaxScaleKernel<CPU, float>);
REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
ops::FakeQuantDequantGradKernel<CPU, float>);
@@ -138,9 +138,9 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
  int tid = threadIdx.x;

  T s = scale[0];
  T inv_s = inverse(s);
  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
    T x = in[i];
    T v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt * inv_s * v;
@@ -335,6 +335,8 @@ namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
                        ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_abs_max,
                        ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max,
                        ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
@@ -347,3 +349,5 @@ REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
REGISTER_OP_CUDA_KERNEL(
    fake_quantize_dequantize_moving_average_abs_max,
    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_grad,
ops::FakeQuantDequantGradKernel<CUDA, float>);
@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
namespace operators {
@@ -81,7 +82,7 @@ struct FindMovingAverageAbsMaxFunctor {
};

template <typename DeviceContext, typename T>
class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<framework::Tensor>("X");

@@ -95,8 +96,38 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
    auto& dev_ctx = context.template device_context<DeviceContext>();
    const T* in_data = in->data<T>();
    FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s);
    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out);
  }
virtual ~FakeAbsMaxKernelBase() = default;
protected:
virtual void RunClipFunctor(const DeviceContext& dev_ctx,
const framework::Tensor& in,
const framework::Tensor& scale, int bin_cnt,
framework::Tensor* out) const = 0;
};
template <typename DeviceContext, typename T>
class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> {
protected:
void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
const framework::Tensor& scale, int bin_cnt,
framework::Tensor* out) const override {
ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, in, scale, bin_cnt,
out);
}
};
template <typename DeviceContext, typename T>
class FakeQuantizeDequantizeAbsMaxKernel
: public FakeAbsMaxKernelBase<DeviceContext, T> {
protected:
void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
const framework::Tensor& scale, int bin_cnt,
framework::Tensor* out) const override {
ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(dev_ctx, in, scale,
bin_cnt, out);
  }
};
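The refactor above turns the abs-max kernel into a template-method base class: FakeAbsMaxKernelBase computes the scale once and delegates the clip/quantize step to RunClipFunctor, so the quantize-only and quantize-dequantize kernels differ only in that hook. A stripped-down sketch of the same pattern outside Paddle (all names here are illustrative):

#include <algorithm>
#include <cmath>
#include <vector>

class AbsMaxKernelBase {
 public:
  virtual ~AbsMaxKernelBase() = default;

  void Compute(const std::vector<float>& in, std::vector<float>* out) const {
    float scale = 0.0f;
    for (float v : in) scale = std::max(scale, std::fabs(v));  // FindAbsMaxFunctor
    RunClipFunctor(in, scale, /*bin_cnt=*/127, out);           // hook for derived kernels
  }

 protected:
  virtual void RunClipFunctor(const std::vector<float>& in, float scale,
                              int bin_cnt, std::vector<float>* out) const = 0;
};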
@@ -167,11 +198,6 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
template <typename DeviceContext, typename T>
class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<framework::Tensor>("X");
    auto* in_scale = context.Input<framework::Tensor>("InScale");

@@ -212,12 +238,20 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out);
  }

  virtual ~FakeMovingAverageAbsMaxKernelBase() = default;

 protected:
  virtual void RunClipFunctor(const DeviceContext& dev_ctx,
                              const framework::Tensor& in,
                              const framework::Tensor& in_scale, int bin_cnt,
                              framework::Tensor* out) const = 0;
};
template <typename DeviceContext, typename T>
class FakeQuantizeMovingAverageAbsMaxKernel
    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
                      const framework::Tensor& in_scale, int bin_cnt,
                      framework::Tensor* out) const override {

@@ -229,7 +263,7 @@ class FakeQuantizeMovingAverageAbsMaxKernel
template <typename DeviceContext, typename T>
class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
                      const framework::Tensor& in_scale, int bin_cnt,
                      framework::Tensor* out) const override {

@@ -277,5 +311,24 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
  }
};
template <typename DeviceContext, typename T>
class FakeQuantDequantGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto x_grad_name = framework::GradVarName("X");
auto* d_x = context.Output<framework::LoDTensor>(x_grad_name);
PADDLE_ENFORCE_NOT_NULL(
d_x, platform::errors::PreconditionNotMet(
"FakeQuantDequantGradOp doesn't have the output named %s.",
x_grad_name));
// Initialize dx as same as d_out
d_x->mutable_data<T>(context.GetPlace());
framework::TensorCopy(*d_out, context.GetPlace(), d_x);
}
};
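Because the rounding step has a zero derivative almost everywhere, the gradient kernel above simply passes the upstream gradient through unchanged (a straight-through estimator): d_x is allocated and then filled with a copy of d_out. A minimal illustration of that behavior (the function name is illustrative only):

#include <vector>

// Straight-through gradient of fake quantize-dequantize: dX = dOut,
// the same effect as the TensorCopy call in FakeQuantDequantGradKernel.
std::vector<float> FakeQuantDequantGrad(const std::vector<float>& d_out) {
  return d_out;
}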
}  // namespace operators
}  // namespace paddle
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cuda.h>
#include <cusolverDn.h>
#include <mutex>  // NOLINT
......
@@ -80,6 +80,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
    {"matmul", {"Out"}},
    {"fake_quantize_dequantize_moving_average_abs_max",
     {"Out", "OutScale", "OutAccum", "OutState"}},
    {"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}},
    {"amp_check_finite_and_scale", {"Out", "FoundInfinite"}},
};
......
@@ -242,6 +242,36 @@ class TestFakeQuantDequantMovingOp(TestMovingOpBase):
        return np.round(self.inputs['X'] / out_scale *
                        range_v) * out_scale / range_v
    def test_check_grad(self):
        x = self.inputs["X"]
        gradient = [np.ones(x.shape) / np.product(x.shape)]
        self.check_grad(["X"], "Out", user_defined_grads=gradient)


class TestFakeQuantDequantAbsOp(OpTest):
    def setUp(self):
        self.op_type = "fake_quantize_dequantize_abs_max"
        self.attrs = {'bit_length': 8}
        self.inputs = {'X': np.random.random((124, 240)).astype("float32"), }
        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
        out_data = self.calc_output(scale)
        self.outputs = {
            'Out': out_data,
            'OutScale': np.array(scale).astype("float32"),
        }

    def calc_output(self, scale):
        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
        return np.round(self.inputs['X'] / scale * range_v) * scale / range_v

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        x = self.inputs["X"]
        gradient = [np.ones(x.shape) / np.product(x.shape)]
        self.check_grad(["X"], "Out", user_defined_grads=gradient)
if __name__ == "__main__":
    unittest.main()