BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit bb45af02 (unverified)
Authored Jul 09, 2020 by Zhen Wang; committed by GitHub on Jul 09, 2020

add the c++ part of Imperative QAT. test=develop (#25446)

Parent: 090a331d

Showing 7 changed files with 185 additions and 37 deletions (+185 -37)
paddle/fluid/operators/fake_dequantize_op.cc                   +3   -3
paddle/fluid/operators/fake_quantize_op.cc                     +82  -23
paddle/fluid/operators/fake_quantize_op.cu                     +5   -1
paddle/fluid/operators/fake_quantize_op.h                      +63  -10
paddle/fluid/platform/dynload/cusolver.h                       +1   -0
paddle/fluid/pybind/op_function_generator.cc                   +1   -0
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py   +30  -0
paddle/fluid/operators/fake_dequantize_op.cc

@@ -29,7 +29,7 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
     auto out_e = framework::EigenVector<T>::Flatten(*out);
     auto& dev = *dev_ctx.eigen_device();
-    out_e.device(dev) = scale_factor[0] * in_e / max_range;
+    out_e.device(dev) = in_e * scale_factor[0] / max_range;
   }
 };

@@ -48,7 +48,7 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
       auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
       auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
       auto& dev = *dev_ctx.eigen_device();
-      out_e.device(dev) = s * in_e / max_range;
+      out_e.device(dev) = in_e * s / max_range;
     }
   } else if (scale_num == 2) {
     int batch_size = in->dims()[0];

@@ -67,7 +67,7 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
         auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
         auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
         auto& dev = *dev_ctx.eigen_device();
-        out_e.device(dev) = (s * scale_two[0]) * in_e / max_range;
+        out_e.device(dev) = in_e * s * scale_two[0] / max_range;
       }
     }
   }
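Note: all three changed lines compute the same dequantization, out = in * scale / max_range; the commit only reorders the operands so the Eigen tensor expression comes first (likely for expression-template reasons). A minimal NumPy sketch of the math these functors implement (names here are illustrative, not Paddle's API):

    import numpy as np

    def dequantize(q, scale, max_range):
        # Mirrors DequantizeFunctor: out = in * scale[0] / max_range.
        return q * scale[0] / max_range

    q = np.array([-127.0, 0.0, 127.0], dtype=np.float32)
    scale = np.array([0.5], dtype=np.float32)
    print(dequantize(q, scale, 127.0))  # -> [-0.5  0.   0.5]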
paddle/fluid/operators/fake_quantize_op.cc

@@ -82,7 +82,7 @@ struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
                  out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
     auto out_e = framework::EigenVector<T>::Flatten(*out);
     out_e.device(*ctx.eigen_device()) =
-        (s / bin_cnt) * (bin_cnt * inv_s * out_e).round();
+        (bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
   }
 };
 template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,

@@ -171,20 +171,21 @@ struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext, T> {
 template struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext,
                                                float>;
 
-class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
+class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
  public:
-  FakeQuantizeAbsMaxOp(const std::string& type,
-                       const framework::VariableNameMap& inputs,
-                       const framework::VariableNameMap& outputs,
-                       const framework::AttributeMap& attrs)
+  FakeQuantOrWithDequantAbsMaxOp(const std::string& type,
+                                 const framework::VariableNameMap& inputs,
+                                 const framework::VariableNameMap& outputs,
+                                 const framework::AttributeMap& attrs)
       : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeAbsMax");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
+                   "FakeQuantOrWithDequantAbsMaxOp");
     OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
-                   "FakeQuantizeAbsMax");
+                   "FakeQuantOrWithDequantAbsMaxOp");
     OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
-                   "FakeQuantizeAbsMax");
+                   "FakeQuantOrWithDequantAbsMaxOp");
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
     ctx->SetOutputDim("OutScale", {1});
     ctx->ShareLoD("X", /*->*/ "Out");

@@ -199,7 +200,8 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
   }
 };
 
-class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
+class FakeQuantOrWithDequantAbsMaxOpMaker
+    : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X", "(Tensor) Input is float data type.");

@@ -217,12 +219,19 @@ class FakeQuantizeAbsMaxOpMaker : public framework::OpProtoAndCheckerMaker {
                           bit_length));
         });
     AddComment(R"DOC(
-FakeQuantize operator
+This is a Base Op which support FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker.
+FakeQuantAbsMaxOp operator is used in the dynamic quantization.
 
 $$scale = max(abs(X))$$
 $$range = 2^{bit_length - 1} - 1$$
 $$Out = round(X/scale * range)$$
 
+FakeQuantDequantAbsMaxOp operator do the abs_max quant and then dequant.
+
+$$scale = max(abs(X))$$
+$$range = 2^{bit\_length - 1} - 1$$
+$$Out = round(X/scale * range) * scale / range$$
+
 )DOC");
   }
 };

@@ -414,14 +423,14 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
               "for training. Some layers may run faster when this is true.")
         .SetDefault(false);
     AddComment(R"DOC(
-This is a Base Op which support FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp
-FakeQuantMovingAverageAbsMaxOp operator is used in static quantization.
+This is a Base Op which support FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp.
+FakeQuantMovingAverageAbsMaxOp operator is used in the static quantization.
 
 $$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
 $$range = 2^{bit\_length - 1} - 1$$
 $$Out = round(X/scale * range)$$
 
-FakeQuantDequantMovingAverageAbsMaxOp operator do the moving_average_abs_max op quant and then dequant.
+FakeQuantDequantMovingAverageAbsMaxOp operator do the moving_average_abs_max quant and then dequant.
 
 $$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
 $$range = 2^{bit\_length - 1} - 1$$

@@ -490,6 +499,46 @@ $$Out = X$$
   }
 };
 
+class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    auto out_grad_name = framework::GradVarName("Out");
+    OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
+                   "FakeQuantDequantGradOp");
+    auto x_grad_name = framework::GradVarName("X");
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput(x_grad_name), true,
+        platform::errors::PreconditionNotMet(
+            "FakeQuantDequantGradOp doesn't have the output named %s.",
+            x_grad_name));
+    ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name));
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class FakeQuantDequantGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> grad_op) const override {
+    grad_op->SetType("fake_quantize_dequantize_grad");
+    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+    grad_op->SetAttrMap(this->Attrs());
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

@@ -497,13 +546,21 @@ namespace ops = paddle::operators;
 using CPU = paddle::platform::CPUDeviceContext;
 
 REGISTER_OPERATOR(
-    fake_quantize_abs_max, ops::FakeQuantizeAbsMaxOp,
-    ops::FakeQuantizeAbsMaxOpMaker,
+    fake_quantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp,
+    ops::FakeQuantOrWithDequantAbsMaxOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
                        ops::FakeQuantizeAbsMaxKernel<CPU, float>);
 
+REGISTER_OPERATOR(
+    fake_quantize_dequantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp,
+    ops::FakeQuantOrWithDequantAbsMaxOpMaker,
+    ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
+    ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
+REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
+                       ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);
+
 REGISTER_OPERATOR(
     fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
     ops::FakeQuantizeRangeAbsMaxOpMaker,

@@ -518,16 +575,14 @@ REGISTER_OPERATOR(
     ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
                        ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
 
 REGISTER_OPERATOR(
     fake_quantize_dequantize_moving_average_abs_max,
     ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
     ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+    ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
+    ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(
     fake_quantize_dequantize_moving_average_abs_max,
     ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);

@@ -547,3 +602,7 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
                        ops::MovingAverageAbsMaxScaleKernel<CPU, float>);
+
+REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
+REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
+                       ops::FakeQuantDequantGradKernel<CPU, float>);
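Put together, the new fake_quantize_dequantize_abs_max op implements the DOC formulas above, and its gradient op (fake_quantize_dequantize_grad) simply copies the output gradient to the input gradient, in effect a straight-through estimator for the round(). A NumPy sketch of the forward/backward pair (illustrative names, not Paddle's API):

    import numpy as np

    def fake_quant_dequant_abs_max(x, bit_length=8):
        # scale = max(abs(X)); range = 2^(bit_length - 1) - 1
        scale = np.max(np.abs(x))
        range_v = (1 << (bit_length - 1)) - 1
        # Out = round(X/scale * range) * scale / range
        return np.round(x / scale * range_v) * scale / range_v, scale

    def fake_quant_dequant_grad(d_out):
        # The grad kernel copies d_out into d_x unchanged
        # (straight-through estimator).
        return d_out.copy()

    x = np.random.random((4, 4)).astype("float32")
    out, scale = fake_quant_dequant_abs_max(x)
    d_x = fake_quant_dequant_grad(np.ones_like(out) / out.size)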
paddle/fluid/operators/fake_quantize_op.cu

@@ -138,9 +138,9 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
   int tid = threadIdx.x;
 
   T s = scale[0];
+  T inv_s = inverse(s);
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
     T x = in[i];
-    T inv_s = inverse(s);
     T v = x > s ? s : x;
     v = v < -s ? -s : v;
     v = bin_cnt * inv_s * v;

@@ -335,6 +335,8 @@ namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
 REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
                         ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_abs_max,
+                        ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float>);
 REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max,
                         ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);
 REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,

@@ -347,3 +349,5 @@ REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale,
 REGISTER_OP_CUDA_KERNEL(
     fake_quantize_dequantize_moving_average_abs_max,
     ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);
+REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_grad,
+                        ops::FakeQuantDequantGradKernel<CUDA, float>);
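The only functional change inside ClipAndQuantDequantKernel is hoisting T inv_s = inverse(s); out of the grid-stride loop: s is read once from scale[0] and is constant across iterations, so the reciprocal only needs to be computed once per thread instead of once per element.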
paddle/fluid/operators/fake_quantize_op.h

@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/platform/hostdevice.h"
+#include "paddle/fluid/platform/transform.h"
 
 namespace paddle {
 namespace operators {

@@ -81,7 +82,7 @@ struct FindMovingAverageAbsMaxFunctor {
 };
 
 template <typename DeviceContext, typename T>
-class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
+class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<framework::Tensor>("X");

@@ -95,8 +96,38 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
     auto& dev_ctx = context.template device_context<DeviceContext>();
     const T* in_data = in->data<T>();
     FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s);
-    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
-                                                bin_cnt, out);
+    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out);
   }
+
+  virtual ~FakeAbsMaxKernelBase() = default;
+
+ protected:
+  virtual void RunClipFunctor(const DeviceContext& dev_ctx,
+                              const framework::Tensor& in,
+                              const framework::Tensor& scale, int bin_cnt,
+                              framework::Tensor* out) const = 0;
+};
+
+template <typename DeviceContext, typename T>
+class FakeQuantizeAbsMaxKernel
+    : public FakeAbsMaxKernelBase<DeviceContext, T> {
+ protected:
+  void RunClipFunctor(const DeviceContext& dev_ctx,
+                      const framework::Tensor& in,
+                      const framework::Tensor& scale, int bin_cnt,
+                      framework::Tensor* out) const override {
+    ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, in, scale, bin_cnt,
+                                                out);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class FakeQuantizeDequantizeAbsMaxKernel
+    : public FakeAbsMaxKernelBase<DeviceContext, T> {
+ protected:
+  void RunClipFunctor(const DeviceContext& dev_ctx,
+                      const framework::Tensor& in,
+                      const framework::Tensor& scale, int bin_cnt,
+                      framework::Tensor* out) const override {
+    ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(dev_ctx, in, scale,
+                                                       bin_cnt, out);
+  }
 };

@@ -167,11 +198,6 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
 template <typename DeviceContext, typename T>
 class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
  public:
-  ~FakeMovingAverageAbsMaxKernelBase() {}
-  virtual void RunClipFunctor(const DeviceContext& dev_ctx,
-                              const framework::Tensor& in,
-                              const framework::Tensor& in_scale, int bin_cnt,
-                              framework::Tensor* out) const = 0;
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<framework::Tensor>("X");
     auto* in_scale = context.Input<framework::Tensor>("InScale");

@@ -212,12 +238,20 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
     RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out);
   }
+
+  virtual ~FakeMovingAverageAbsMaxKernelBase() = default;
+
+ protected:
+  virtual void RunClipFunctor(const DeviceContext& dev_ctx,
+                              const framework::Tensor& in,
+                              const framework::Tensor& in_scale, int bin_cnt,
+                              framework::Tensor* out) const = 0;
 };
 
 template <typename DeviceContext, typename T>
 class FakeQuantizeMovingAverageAbsMaxKernel
     : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
- public:
+ protected:
   void RunClipFunctor(const DeviceContext& dev_ctx,
                       const framework::Tensor& in,
                       const framework::Tensor& in_scale, int bin_cnt,
                       framework::Tensor* out) const override {

@@ -229,7 +263,7 @@ class FakeQuantizeMovingAverageAbsMaxKernel
 template <typename DeviceContext, typename T>
 class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
     : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
- public:
+ protected:
   void RunClipFunctor(const DeviceContext& dev_ctx,
                       const framework::Tensor& in,
                       const framework::Tensor& in_scale, int bin_cnt,
                       framework::Tensor* out) const override {

@@ -277,5 +311,24 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class FakeQuantDequantGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* d_out =
+        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto x_grad_name = framework::GradVarName("X");
+    auto* d_x = context.Output<framework::LoDTensor>(x_grad_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        d_x, platform::errors::PreconditionNotMet(
+                 "FakeQuantDequantGradOp doesn't have the output named %s.",
+                 x_grad_name));
+
+    // Initialize dx as same as d_out
+    d_x->mutable_data<T>(context.GetPlace());
+    framework::TensorCopy(*d_out, context.GetPlace(), d_x);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
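This header refactor applies the same template-method pattern in two places: a base kernel (FakeAbsMaxKernelBase, FakeMovingAverageAbsMaxKernelBase) keeps the shared Compute() and defers only the final clip step to a protected pure-virtual RunClipFunctor(), so the quantize and quantize-dequantize variants each differ in a single override. A rough Python analogue of the abs_max half, just to show the shape of the refactor (hypothetical helper names, not Paddle's API):

    import numpy as np

    def clip_and_fake_quant(x, scale, bin_cnt):
        return np.round(np.clip(x, -scale, scale) * bin_cnt / scale)

    def clip_and_fake_quant_dequant(x, scale, bin_cnt):
        return clip_and_fake_quant(x, scale, bin_cnt) * scale / bin_cnt

    class FakeAbsMaxKernelBase:
        def compute(self, x, bin_cnt):
            scale = np.max(np.abs(x))        # FindAbsMaxFunctor
            return self.run_clip_functor(x, scale, bin_cnt)

        def run_clip_functor(self, x, scale, bin_cnt):
            raise NotImplementedError        # pure virtual in the C++ version

    class FakeQuantizeAbsMaxKernel(FakeAbsMaxKernelBase):
        def run_clip_functor(self, x, scale, bin_cnt):
            return clip_and_fake_quant(x, scale, bin_cnt)

    class FakeQuantizeDequantizeAbsMaxKernel(FakeAbsMaxKernelBase):
        def run_clip_functor(self, x, scale, bin_cnt):
            return clip_and_fake_quant_dequant(x, scale, bin_cnt)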
paddle/fluid/platform/dynload/cusolver.h

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <cuda.h>
 #include <cusolverDn.h>
 #include <mutex>  // NOLINT
paddle/fluid/pybind/op_function_generator.cc

@@ -80,6 +80,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
     {"matmul", {"Out"}},
     {"fake_quantize_dequantize_moving_average_abs_max",
      {"Out", "OutScale", "OutAccum", "OutState"}},
+    {"fake_quantize_dequantize_abs_max", {"Out", "OutScale"}},
     {"amp_check_finite_and_scale", {"Out", "FoundInfinite"}},
 };
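The one-line addition mirrors the existing fake_quantize_dequantize_moving_average_abs_max entry: listing Out and OutScale in op_passing_outs_map appears to let the generated imperative (dygraph) binding accept those outputs as caller-provided variables rather than allocating fresh ones, which the imperative QAT flow relies on.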
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py

@@ -242,6 +242,36 @@ class TestFakeQuantDequantMovingOp(TestMovingOpBase):
         return np.round(self.inputs['X'] / out_scale *
                         range_v) * out_scale / range_v
 
+    def test_check_grad(self):
+        x = self.inputs["X"]
+        gradient = [np.ones(x.shape) / np.product(x.shape)]
+        self.check_grad(["X"], "Out", user_defined_grads=gradient)
+
+
+class TestFakeQuantDequantAbsOp(OpTest):
+    def setUp(self):
+        self.op_type = "fake_quantize_dequantize_abs_max"
+        self.attrs = {'bit_length': 8}
+        self.inputs = {'X': np.random.random((124, 240)).astype("float32"), }
+        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
+        out_data = self.calc_output(scale)
+        self.outputs = {
+            'Out': out_data,
+            'OutScale': np.array(scale).astype("float32"),
+        }
+
+    def calc_output(self, scale):
+        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
+        return np.round(self.inputs['X'] / scale * range_v) * scale / range_v
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        x = self.inputs["X"]
+        gradient = [np.ones(x.shape) / np.product(x.shape)]
+        self.check_grad(["X"], "Out", user_defined_grads=gradient)
+
+
 if __name__ == "__main__":
     unittest.main()
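The user_defined_grads value in both new test_check_grad methods is consistent with the straight-through backward: check_grad appears to drive the op with a mean-style scalar loss, so every element of d_Out is 1/numel(X), and fake_quantize_dequantize_grad copies that straight into d_X; hence gradient = np.ones(x.shape) / np.product(x.shape).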