Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
9bd933d3
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9bd933d3
编写于
9月 03, 2018
作者:
Q
qingqing01
提交者:
GitHub
9月 03, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Improve and fix fake_quantize_op (#13092)
* Improve and fix fake_quantize_op.
上级
b4d43030
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
369 addition
and
376 deletion
+369
-376
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+3
-0
paddle/fluid/operators/fake_quantize_op.cc
paddle/fluid/operators/fake_quantize_op.cc
+171
-53
paddle/fluid/operators/fake_quantize_op.cu
paddle/fluid/operators/fake_quantize_op.cu
+99
-194
paddle/fluid/operators/fake_quantize_op.h
paddle/fluid/operators/fake_quantize_op.h
+67
-113
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+29
-16
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
9bd933d3
...
...
@@ -178,6 +178,8 @@ function(op_library TARGET)
file
(
APPEND
${
pybind_file
}
"USE_OP(relu);
\n
"
)
elseif
(
${
TARGET
}
STREQUAL
"fake_dequantize"
)
file
(
APPEND
${
pybind_file
}
"USE_OP(fake_dequantize_max_abs);
\n
"
)
elseif
(
${
TARGET
}
STREQUAL
"fake_quantize"
)
file
(
APPEND
${
pybind_file
}
"USE_OP(fake_quantize_abs_max);
\n
"
)
elseif
(
${
TARGET
}
STREQUAL
"tensorrt_engine_op"
)
message
(
STATUS
"Pybind skips [tensorrt_engine_op], for this OP is only used in inference"
)
elseif
(
${
TARGET
}
STREQUAL
"fc"
)
...
...
@@ -293,6 +295,7 @@ op_library(extract_rows_op DEPS memory)
op_library
(
flatten_op DEPS reshape_op
)
op_library
(
sequence_pad_op DEPS sequence_padding
)
op_library
(
unstack_op DEPS stack_op
)
op_library
(
fake_quantize_op DEPS memory
)
if
(
WITH_GPU
)
op_library
(
conv_op DEPS vol2col depthwise_conv im2col
)
...
...
paddle/fluid/operators/fake_quantize_op.cc
浏览文件 @
9bd933d3
...
...
@@ -14,86 +14,198 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_quantize_op.h"
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/platform/transform.h"
namespace
paddle
{
namespace
operators
{
class
FakeQuantizeOp
:
public
framework
::
OperatorWithKernel
{
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenVectorArrayMap
=
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
T
,
1
,
MajorType
,
IndexType
>>
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
ConstEigenVectorArrayMap
=
Eigen
::
TensorMap
<
const
Eigen
::
Tensor
<
T
,
1
,
MajorType
,
IndexType
>>
;
template
<
typename
T
>
struct
FindAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
T
*
in
,
const
int
num
,
T
*
out
)
{
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
1
>
idim
(
num
);
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
1
>
odim
(
1
);
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
const
T
,
1
,
Eigen
::
RowMajor
>>
in_e
(
in
,
idim
);
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
T
,
1
,
Eigen
::
RowMajor
>>
out_e
(
out
,
odim
);
out_e
=
in_e
.
abs
().
maximum
();
}
};
template
struct
FindAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
<
typename
T
>
struct
ClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
framework
::
Tensor
*
out
)
{
T
s
=
scale
.
data
<
T
>
()[
0
];
platform
::
Transform
<
platform
::
CPUDeviceContext
>
trans
;
trans
(
ctx
,
in
.
data
<
T
>
(),
in
.
data
<
T
>
()
+
in
.
numel
(),
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
ClipFunctor
<
T
>
(
-
s
,
s
));
auto
in_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
in
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
out_e
.
device
(
*
ctx
.
eigen_device
())
=
(
bin_cnt
/
s
*
in_e
).
round
();
}
};
template
struct
ClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
<
typename
T
>
struct
FindRangeAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
cur_scale
,
const
framework
::
Tensor
&
last_scale
,
const
framework
::
Tensor
&
iter
,
const
int
window_size
,
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
)
{
T
*
scale_arr
=
scales_arr
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
it
=
iter
.
data
<
int64_t
>
()[
0
];
int
idx
=
it
%
window_size
;
T
removed
=
scale_arr
[
idx
];
T
cur
=
cur_scale
.
data
<
T
>
()[
0
];
scale_arr
[
idx
]
=
cur
;
T
max
=
last_scale
.
data
<
T
>
()[
0
];
if
(
max
<
cur
)
{
max
=
cur
;
}
else
if
(
fabs
(
removed
-
max
)
<
1e-6
)
{
int
size
=
(
it
>
window_size
)
?
window_size
:
it
;
FindAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
()(
ctx
,
scale_arr
,
size
,
&
max
);
}
out_scale
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())[
0
]
=
max
;
}
};
template
struct
FindRangeAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
float
>;
class
FakeQuantizeAbsMaxOp
:
public
framework
::
OperatorWithKernel
{
public:
FakeQuantize
Op
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
FakeQuantize
AbsMaxOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FakeQuantizeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FakeQuantizeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutMovingScale"
),
"OutMovingScale(Out) of FakeQuantizeOp should not be null"
);
// if (ctx->HasInput("InMovingScale")) {
ctx
->
SetOutputDim
(
"OutMovingScale"
,
ctx
->
GetInputDim
(
"InMovingScale"
));
//}
// if (ctx->HasInput("InScales")) {
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutScales"
),
"OutScales(Out) of FakeQuantizeOp should not be null"
);
ctx
->
SetOutputDim
(
"OutScales"
,
ctx
->
GetInputDim
(
"InScales"
));
// PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
// ctx->Outputs("OutScales")[0],
// "Mean and MeanOut should share the same memory");
//}
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutScale"
),
"Output(Scale) of FakeQuantizeOp should not be null."
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"OutScale"
,
{
1
});
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
class
FakeQuantizeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
FakeQuantize
AbsMax
OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) Input tensor of scale operator."
);
AddInput
(
"InScales"
,
"(Tensor) scale buffer, used in static quantization."
)
.
AsDispensable
();
AddInput
(
"InMovingScale"
,
"Last scale, used in static quantization."
)
.
AsDispensable
();
AddInput
(
"InCurrentIter"
,
"Last iteration number, used in static quantization."
)
.
AsDispensable
();
AddOutput
(
"Out"
,
"(Tensor) Output of quantized low level tensor."
);
AddOutput
(
"OutScales"
,
"(Tensor) scale buffer, used in static quantization."
)
.
AsDispensable
();
AddOutput
(
"OutMovingScale"
,
" Current scale"
);
AddOutput
(
"OutCurrentIter"
,
"Current iteration number."
).
AsDispensable
();
AddAttr
<
std
::
string
>
(
"quantize_type"
,
"(string, default abs_max)"
"The scaling tpe of the quantize operator."
)
.
SetDefault
(
"abs_max"
);
AddAttr
<
int
>
(
"window_size"
,
"(int, default 10000)"
).
SetDefault
(
10000
);
AddInput
(
"X"
,
"(Tensor) Input is float data type."
);
AddOutput
(
"Out"
,
"(Tensor) Output of quantized low level tensor, "
"but also saved as float data type."
);
AddOutput
(
"OutScale"
,
"(Tensor) Current scale"
);
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
.
SetDefault
(
8
)
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
PADDLE_ENFORCE
(
bit_length
>=
1
&&
bit_length
<=
16
,
"'bit_length' should be between 1 and 16."
);
});
AddAttr
<
bool
>
(
"is_test"
,
""
).
SetDefault
(
false
);
AddComment
(
R"DOC(
FakeQuantize operator
quantize_type = abs_max:
$$scale = max(abs(X))$$
$$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$
$$scale = max(abs(x))$$
)DOC"
);
}
};
quantize_type = range_abs_max:
class
FakeQuantizeRangeAbsMaxOp
:
public
framework
::
OperatorWithKernel
{
public:
FakeQuantizeRangeAbsMaxOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
$$scale = max(max(abs(x)), history_abs_max)$$
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FakeQuantizeRangeAbsMaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FakeQuantizeRangeAbsMaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutScale"
),
"Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null"
);
if
(
ctx
->
HasOutput
(
"OutScales"
))
{
int
window_size
=
ctx
->
Attrs
().
Get
<
int
>
(
"window_size"
);
ctx
->
SetOutputDim
(
"OutScales"
,
{
window_size
});
}
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"OutScale"
,
{
1
});
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
quantize_type = moving_average_abs_max:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
$$scale = 0.1*scale+0.9*new_abs_max)$$
class
FakeQuantizeRangeAbsMaxOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) Input is float data type."
);
AddInput
(
"InScale"
,
"Last scale."
);
AddInput
(
"Iter"
,
"Global step iteration."
).
AsDispensable
();
AddOutput
(
"Out"
,
"(Tensor) Output of quantized low level tensor."
);
AddOutput
(
"OutScale"
,
" Current scale"
);
AddOutput
(
"OutScales"
,
"(Tensor) scale buffer."
).
AsDispensable
();
AddAttr
<
int
>
(
"window_size"
,
"(int, default 10000) window range size."
)
.
SetDefault
(
10000
);
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8), quantization bit number."
)
.
SetDefault
(
8
)
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
PADDLE_ENFORCE
(
bit_length
>=
1
&&
bit_length
<=
16
,
"'bit_length' should be between 1 and 16."
);
});
AddAttr
<
bool
>
(
"is_test"
,
""
).
SetDefault
(
false
);
AddComment
(
R"DOC(
FakeQuantize operator is used in static quantization.
$$Out = scale*X$$
$$scale = max(max(abs(x)), history_abs_max)$$
$$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$
)DOC"
);
}
...
...
@@ -103,10 +215,16 @@ $$Out = scale*X$$
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
using
CPU
=
paddle
::
platform
::
CPUDeviceContext
;
REGISTER_OPERATOR
(
fake_quantize_abs_max
,
ops
::
FakeQuantizeAbsMaxOp
,
ops
::
FakeQuantizeAbsMaxOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
fake_quantize_abs_max
,
ops
::
FakeQuantizeAbsMaxKernel
<
CPU
,
float
>
);
REGISTER_OPERATOR
(
fake_quantize
,
ops
::
FakeQuantizeOp
,
ops
::
FakeQuantizeOpMaker
,
REGISTER_OPERATOR
(
fake_quantize_range_abs_max
,
ops
::
FakeQuantizeRangeAbsMaxOp
,
ops
::
FakeQuantizeRangeAbsMaxOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
fake_quantize
,
ops
::
FakeQuantizeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
FakeQuantizeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
fake_quantize_range_abs_max
,
ops
::
FakeQuantizeRangeAbsMaxKernel
<
CPU
,
float
>
);
paddle/fluid/operators/fake_quantize_op.cu
浏览文件 @
9bd933d3
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
...
@@ -20,7 +21,7 @@ namespace paddle {
namespace
operators
{
template
<
typename
T
>
__global__
void
FindAbsMaxKernel
(
const
int
n
,
const
T
*
i
n
,
T
*
out
)
{
__global__
void
FindAbsMaxKernel
(
const
T
*
in
,
const
int
n
,
T
*
out
)
{
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
;
...
...
@@ -43,7 +44,7 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
__syncthreads
();
for
(
int
i
=
blockDim
.
x
/
2
;
i
>
0
;
i
>>=
1
)
{
if
(
tid
<
i
&&
shared_max_data
[
tid
]
<
shared_max_data
[
tid
+
i
]
)
{
if
(
tid
<
i
&&
(
shared_max_data
[
tid
]
<
shared_max_data
[
tid
+
i
])
)
{
shared_max_data
[
tid
]
=
shared_max_data
[
tid
+
i
];
}
__syncthreads
();
...
...
@@ -53,220 +54,124 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
}
}
float
FindAbsMaxGpu
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
float
*
array
,
int
length
)
{
float
host_max
;
int
kNumTheads
=
1024
;
int
gridDimx
=
(
kNumTheads
-
1
+
length
)
/
kNumTheads
;
gridDimx
=
(
gridDimx
>
kNumTheads
)
?
kNumTheads
:
gridDimx
;
framework
::
Tensor
t
;
float
*
device_max
=
t
.
mutable_data
<
float
>
(
framework
::
make_ddim
({
gridDimx
}),
platform
::
CUDAPlace
())
;
FindAbsMaxKernel
<
float
><<<
gridDimx
,
kNumTheads
,
kNumTheads
*
sizeof
(
float
),
ctx
.
stream
()
>>>
(
length
,
array
,
device_max
);
FindAbsMaxKernel
<
float
><<<
1
,
kNumTheads
,
kNumTheads
*
sizeof
(
float
),
ctx
.
stream
()
>>>
(
gridDimx
,
device_max
,
device_max
);
PADDLE_ENFORCE_EQ
(
cudaMemcpy
(
&
host_max
,
device_max
,
sizeof
(
float
),
cudaMemcpyDeviceToHost
),
cudaSuccess
,
"cudaMemcpy failed"
)
;
return
host_max
;
}
template
<
typename
T
>
struct
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
T
*
in
,
const
int
num
,
T
*
out
)
{
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
grid
=
(
grid
>
block
)
?
block
:
grid
;
framework
::
Tensor
max
;
T
*
max_data
=
max
.
mutable_data
<
T
>
(
framework
::
make_ddim
({
grid
}),
ctx
.
GetPlace
()
);
FindAbsMaxKernel
<
T
><<<
grid
,
block
,
1024
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
in
,
num
,
max_data
);
FindAbsMaxKernel
<
T
><<<
1
,
block
,
1024
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
max_data
,
grid
,
out
);
}
}
;
template
struct
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
__global__
void
ApplySaturateKernel
(
const
int
n
,
const
T
*
in
,
T
*
out
,
int
*
num_saturate
,
const
T
min
,
const
T
max
)
{
__global__
void
ClipAndQuantKernel
(
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int
n
,
T
*
out
)
{
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
;
extern
__shared__
int
shared_count
[];
shared_count
[
tid
]
=
0
;
T
s
=
scale
[
0
];
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
if
(
in
[
i
]
>
max
)
{
out
[
i
]
=
max
;
shared_count
[
tid
]
+=
1
;
}
else
if
(
in
[
i
]
<
min
)
{
out
[
i
]
=
min
;
shared_count
[
tid
]
+=
1
;
}
else
{
out
[
i
]
=
in
[
i
];
}
}
__syncthreads
();
for
(
int
i
=
blockDim
.
x
/
2
;
i
>
0
;
i
>>=
1
)
{
if
(
tid
<
i
)
{
shared_count
[
tid
]
+=
shared_count
[
tid
+
i
];
}
__syncthreads
();
}
if
(
tid
==
0
)
{
num_saturate
[
blockIdx
.
x
]
=
shared_count
[
0
];
T
x
=
in
[
bid
];
T
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt
/
s
*
v
;
out
[
bid
]
=
round
(
v
);
}
}
template
<
typename
T
>
__global__
void
ReduceKernel
(
const
int
n
,
const
T
*
in
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
extern
__shared__
T
shared_sum
[];
if
(
tid
<
n
)
{
shared_sum
[
tid
]
=
in
[
tid
];
__global__
void
FindRangeAbsMaxAndFillArray
(
const
T
*
cur_scale
,
const
T
*
last_scale
,
const
int64_t
*
iter
,
const
int
window_size
,
T
*
scale_arr
,
T
*
out_scale
,
int
*
need_find_max
,
int
*
out_size
)
{
int
it
=
iter
[
0
];
int
idx
=
it
%
window_size
;
T
removed
=
scale_arr
[
idx
];
T
cur
=
cur_scale
[
0
];
scale_arr
[
idx
]
=
cur
;
T
max
=
last_scale
[
0
];
out_scale
[
0
]
=
max
<
cur
?
cur
:
max
;
if
(
fabs
(
removed
-
max
)
<
1e-6
)
{
need_find_max
[
0
]
=
1
;
out_size
[
0
]
=
it
>
window_size
?
window_size
:
it
;
}
else
{
shared_sum
[
tid
]
=
T
(
0
);
}
__syncthreads
();
// blockDim.x must >= n
for
(
int
i
=
(
n
+
1
)
/
2
;
i
>
0
;
i
>>=
1
)
{
if
(
tid
<
i
)
{
shared_sum
[
tid
]
+=
shared_sum
[
tid
+
i
];
}
__syncthreads
();
}
if
(
tid
==
0
)
{
out
[
0
]
=
shared_sum
[
0
];
need_find_max
[
0
]
=
0
;
}
}
template
<
typename
T
>
int
ApplySaturateGpu
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
int
n
,
const
T
*
in
,
T
*
out
,
const
T
min
,
const
T
max
)
{
int
host_num_saturate
;
int
kNumTheads
=
1024
;
int
gridDimx
=
(
n
+
kNumTheads
-
1
)
/
kNumTheads
;
gridDimx
=
(
gridDimx
>
kNumTheads
)
?
kNumTheads
:
gridDimx
;
framework
::
Tensor
t
;
int
*
device_num_saturate
=
t
.
mutable_data
<
int
>
(
framework
::
make_ddim
({
gridDimx
}),
platform
::
CUDAPlace
());
ApplySaturateKernel
<
T
><<<
gridDimx
,
kNumTheads
,
kNumTheads
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
n
,
in
,
out
,
device_num_saturate
,
min
,
max
);
ReduceKernel
<
int
><<<
1
,
kNumTheads
,
kNumTheads
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
gridDimx
,
device_num_saturate
,
device_num_saturate
);
PADDLE_ENFORCE_EQ
(
cudaSuccess
,
cudaMemcpy
(
&
host_num_saturate
,
device_num_saturate
,
sizeof
(
int
),
cudaMemcpyDeviceToHost
),
"cudaMemcpy failed"
);
return
host_num_saturate
;
}
template
<
typename
DeviceContext
,
typename
T
>
class
FakeQuantizeCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
T
FindRangeAbsMax
(
const
platform
::
CUDADeviceContext
&
ctx
,
framework
::
Tensor
*
scale_list
,
framework
::
Tensor
*
out_scale
,
const
T
&
cur_scale
,
int
window_size
,
int
current_iter
)
const
{
T
*
sl
=
scale_list
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
remove_tmp
=
sl
[
current_iter
];
sl
[
current_iter
]
=
cur_scale
;
T
&
max_scale
=
out_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
if
(
max_scale
<
cur_scale
)
{
max_scale
=
cur_scale
;
}
else
if
(
fabs
(
remove_tmp
-
max_scale
)
<
1e-6
)
{
int
size
=
(
current_iter
>
window_size
)
?
window_size
:
current_iter
;
max_scale
=
T
(
FindAbsMaxGpu
(
ctx
,
scale_list
->
data
<
float
>
(),
size
));
struct
FindRangeAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
cur_scale
,
const
framework
::
Tensor
&
last_scale
,
const
framework
::
Tensor
&
iter
,
const
int
window_size
,
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
)
{
auto
&
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx
.
GetPlace
());
T
*
scale_arr
=
scales_arr
->
mutable_data
<
T
>
(
gpu_place
);
T
*
out_scale_data
=
out_scale
->
mutable_data
<
T
>
(
gpu_place
);
framework
::
Tensor
need_find_max
,
out_size
;
int
*
find_max
=
need_find_max
.
mutable_data
<
int
>
(
gpu_place
);
int
*
out_size_data
=
out_size
.
mutable_data
<
int
>
(
gpu_place
);
FindRangeAbsMaxAndFillArray
<
T
><<<
1
,
1
,
0
,
ctx
.
stream
()
>>>
(
cur_scale
.
data
<
T
>
(),
last_scale
.
data
<
T
>
(),
iter
.
data
<
int64_t
>
(),
window_size
,
scale_arr
,
out_scale_data
,
find_max
,
out_size_data
);
int
g_find_max
;
memory
::
Copy
(
platform
::
CPUPlace
(),
&
g_find_max
,
gpu_place
,
find_max
,
sizeof
(
int
),
0
);
if
(
g_find_max
)
{
int
len
;
memory
::
Copy
(
platform
::
CPUPlace
(),
&
len
,
gpu_place
,
out_size_data
,
sizeof
(
int
),
0
);
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
ctx
,
scale_arr
,
len
,
out_scale_data
);
}
return
max_scale
;
}
T
FindMovingAverageAbsMmax
(
framework
::
Tensor
*
in_scale
,
framework
::
Tensor
*
out_scale
,
const
T
&
cur_scale
)
const
{
T
*
ins
=
in_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
*
outs
=
out_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
outs
[
0
]
=
0.9
*
cur_scale
+
0.1
*
ins
[
0
];
return
T
(
outs
[
0
]);
}
};
virtual
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
context
.
GetPlace
()),
"This kernel only runs on GPU device."
);
auto
&
device_ctx
=
context
.
cuda_device_context
();
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
tensor
->
mutable_data
<
T
>
(
in
->
place
());
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
)
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
)
->
place
());
auto
quantize_type
=
static_cast
<
std
::
string
>
(
context
.
Attr
<
std
::
string
>
(
"quantize_type"
));
if
(
quantize_type
==
std
::
string
(
"range_abs_max"
))
{
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
)
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InScales"
)
->
place
());
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
)
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InCurrentIter"
)
->
place
());
}
T
scale
=
T
(
1
);
int
window_size
=
context
.
Attr
<
int
>
(
"window_size"
);
T
bin_cnt
=
(
T
)((
1
<<
(
context
.
Attr
<
int
>
(
"bit_length"
)
-
1
))
-
1
);
if
(
quantize_type
==
std
::
string
(
"abs_max"
))
{
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
scale
=
(
T
)
FindAbsMaxGpu
(
device_ctx
,
in
->
data
<
float
>
(),
in
->
numel
());
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
]
=
scale
;
auto
&
device_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
*
scale_list
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
math
::
SetConstant
<
DeviceContext
,
T
>
scalar
;
scale_list
->
mutable_data
<
T
>
(
context
.
GetPlace
());
scalar
(
device_ctx
,
scale_list
,
static_cast
<
T
>
(
0
));
auto
*
iter
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
iter
->
mutable_data
<
T
>
(
context
.
GetPlace
());
scalar
(
device_ctx
,
iter
,
static_cast
<
T
>
(
0
));
}
else
if
(
quantize_type
==
std
::
string
(
"range_abs_max"
))
{
auto
*
moving_scale
=
const_cast
<
framework
::
Tensor
*>
(
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
));
if
(
is_test
)
{
scale
=
moving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
}
else
{
auto
*
it
=
const_cast
<
framework
::
Tensor
*>
(
context
.
Input
<
framework
::
Tensor
>
(
"InCurrentIter"
));
auto
*
iter
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
int
*
last_iter
=
it
->
mutable_data
<
int
>
(
platform
::
CPUPlace
());
int
*
current_iter
=
iter
->
mutable_data
<
int
>
(
platform
::
CPUPlace
());
auto
*
scale_list
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
scale
=
(
T
)
FindAbsMaxGpu
(
device_ctx
,
in
->
data
<
float
>
(),
in
->
numel
());
scale
=
FindRangeAbsMax
(
device_ctx
,
scale_list
,
saving_scale
,
scale
,
window_size
,
current_iter
[
0
]);
(
*
current_iter
)
=
(
*
last_iter
)
+
1
;
}
}
else
if
(
quantize_type
==
std
::
string
(
"moving_average_abs_max"
))
{
auto
*
moving_scale
=
const_cast
<
framework
::
Tensor
*>
(
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
));
if
(
is_test
)
{
scale
=
moving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
}
else
{
scale
=
(
T
)
FindAbsMaxGpu
(
device_ctx
,
in
->
data
<
float
>
(),
in
->
numel
());
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
scale
=
FindMovingAverageAbsMmax
(
const_cast
<
framework
::
Tensor
*>
(
moving_scale
),
saving_scale
,
scale
);
}
}
ApplySaturateGpu
<
T
>
(
device_ctx
,
in
->
numel
(),
in
->
data
<
T
>
(),
tensor
->
mutable_data
<
T
>
(
in
->
place
()),
-
scale
,
scale
);
scale
=
bin_cnt
/
scale
;
template
struct
FindRangeAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
float
>;
auto
&
dev
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
eigen_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
tensor
);
auto
eigen_in
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
tensor
);
eigen_out
.
device
(
dev
)
=
(
scale
*
eigen_in
).
round
();
template
<
typename
T
>
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
framework
::
Tensor
*
out
)
{
int
num
=
in
.
numel
();
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ClipAndQuantKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
num
,
out_data
);
}
};
template
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
float
>;
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
fake_quantize
,
paddle
::
operators
::
FakeQuantizeCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
paddle
::
operators
::
FakeQuantizeCUDAKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
namespace
ops
=
paddle
::
operators
;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_abs_max
,
ops
::
FakeQuantizeAbsMaxKernel
<
CUDA
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_range_abs_max
,
ops
::
FakeQuantizeRangeAbsMaxKernel
<
CUDA
,
float
>
);
paddle/fluid/operators/fake_quantize_op.h
浏览文件 @
9bd933d3
...
...
@@ -17,137 +17,91 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/transform.h"
namespace
paddle
{
namespace
operators
{
using
platform
::
Transform
;
template
<
typename
DeviceContext
,
typename
T
>
struct
FindAbsMaxFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
T
*
in
,
const
int
num
,
T
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
FakeQuantizeKernel
:
public
framework
::
OpKernel
<
T
>
{
struct
ClipAndFakeQuantFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
framework
::
Tensor
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
struct
FindRangeAbsMaxFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
cur_scale
,
const
framework
::
Tensor
&
last_scale
,
const
framework
::
Tensor
&
iter
,
const
int
window_size
,
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
FakeQuantizeAbsMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
T
FindAbsMax
(
framework
::
Tensor
*
in
,
int
n
)
const
{
T
*
p
=
in
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
abs_max
=
(
T
)
0.00000001
;
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
T
tmp
=
fabs
(
p
[
i
]);
if
(
tmp
>
abs_max
)
abs_max
=
tmp
;
}
return
T
(
abs_max
);
}
T
FindRangeAbsMax
(
framework
::
Tensor
*
scale_list
,
framework
::
Tensor
*
out_scale
,
const
T
&
cur_scale
,
int
window_size
,
int
current_iter
)
const
{
T
*
sl
=
scale_list
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
remove_tmp
=
sl
[
current_iter
];
sl
[
current_iter
]
=
cur_scale
;
T
&
max_scale
=
out_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
if
(
max_scale
<
cur_scale
)
{
max_scale
=
cur_scale
;
}
else
if
(
fabs
(
remove_tmp
-
max_scale
)
<
1e-6
)
{
int
size
=
(
current_iter
>
window_size
)
?
window_size
:
current_iter
;
max_scale
=
T
(
FindAbsMax
(
scale_list
,
size
));
}
return
max_scale
;
}
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
T
FindMovingAverageAbsMmax
(
framework
::
Tensor
*
in_scale
,
framework
::
Tensor
*
out_scale
,
const
T
&
cur_scale
)
const
{
T
*
ins
=
in_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
*
outs
=
out_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
outs
[
0
]
=
0.9
*
cur_scale
+
0.1
*
ins
[
0
];
return
T
(
outs
[
0
]);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
T
*
out_s
=
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
const
T
*
in_data
=
in
->
data
<
T
>
();
FindAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in_data
,
in
->
numel
(),
out_s
);
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
out
);
}
};
virtual
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
{
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
template
<
typename
DeviceContext
,
typename
T
>
class
FakeQuantizeRangeAbsMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
tensor
->
mutable_data
<
T
>
(
in
->
place
());
auto
*
oms_tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
oms_tensor
->
mutable_data
<
T
>
(
in
->
place
());
auto
quantize_type
=
static_cast
<
std
::
string
>
(
context
.
Attr
<
std
::
string
>
(
"quantize_type"
));
if
(
quantize_type
==
std
::
string
(
"range_abs_max"
))
{
auto
*
oss_tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
oss_tensor
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InScales"
)
->
place
());
auto
*
oci_tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
oci_tensor
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InCurrentIter"
)
->
place
());
}
auto
*
in_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"InScale"
);
T
scale
=
static_cast
<
T
>
(
1
);
int
window_size
=
context
.
Attr
<
int
>
(
"window_size"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
raw_in
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
in
);
if
(
quantize_type
==
std
::
string
(
"abs_max"
))
{
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
auto
scale_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
saving_scale
);
scale_out
.
device
(
dev
)
=
raw_in
.
abs
().
maximum
();
scale
=
scale_out
(
0
);
auto
&
device_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
*
scale_list
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
math
::
SetConstant
<
DeviceContext
,
T
>
scalar
;
scale_list
->
mutable_data
<
T
>
(
context
.
GetPlace
());
scalar
(
device_ctx
,
scale_list
,
static_cast
<
T
>
(
0
));
auto
*
iter
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
iter
->
mutable_data
<
T
>
(
context
.
GetPlace
());
scalar
(
device_ctx
,
iter
,
static_cast
<
T
>
(
0
));
}
else
if
(
quantize_type
==
std
::
string
(
"range_abs_max"
))
{
auto
*
moving_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
);
if
(
is_test
)
{
scale
=
moving_scale
->
data
<
T
>
()[
0
];
}
else
{
auto
*
it
=
context
.
Input
<
framework
::
Tensor
>
(
"InCurrentIter"
);
auto
*
iter
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
const
int
*
last_iter
=
it
->
data
<
int
>
();
int
*
current_iter
=
iter
->
mutable_data
<
int
>
(
platform
::
CPUPlace
());
auto
*
scale_list
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
auto
scale_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
saving_scale
);
scale_out
.
device
(
dev
)
=
raw_in
.
abs
().
maximum
();
scale
=
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
scale
=
FindRangeAbsMax
(
scale_list
,
saving_scale
,
scale
,
window_size
,
current_iter
[
0
]);
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
]
=
scale
;
(
*
current_iter
)
=
(
*
last_iter
)
+
1
;
}
}
else
if
(
quantize_type
==
std
::
string
(
"moving_average_abs_max"
))
{
auto
*
moving_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
);
if
(
is_test
)
{
scale
=
moving_scale
->
data
<
T
>
()[
0
];
}
else
{
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
auto
scale_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
saving_scale
);
scale_out
.
device
(
dev
)
=
raw_in
.
abs
().
maximum
();
scale
=
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
scale
=
FindMovingAverageAbsMmax
(
const_cast
<
framework
::
Tensor
*>
(
moving_scale
),
saving_scale
,
scale
);
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
]
=
scale
;
}
// testing
if
(
is_test
)
{
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
*
in_scale
,
bin_cnt
,
out
);
return
;
}
Transform
<
DeviceContext
>
trans
;
trans
(
context
.
template
device_context
<
DeviceContext
>(),
in
->
data
<
T
>
(),
in
->
data
<
T
>
()
+
in
->
numel
(),
tensor
->
mutable_data
<
T
>
(
in
->
place
()),
ClipFunctor
<
T
>
(
-
scale
,
scale
));
auto
eigen_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
tensor
);
auto
eigen_in
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
tensor
);
eigen_out
.
device
(
dev
)
=
(
bin_cnt
/
scale
*
eigen_in
).
round
();
// training
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
auto
*
out_scales
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
auto
*
iter
=
context
.
Input
<
framework
::
Tensor
>
(
"Iter"
);
int
window_size
=
context
.
Attr
<
int
>
(
"window_size"
);
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
framework
::
Tensor
cur_scale
;
T
*
cur_scale_data
=
cur_scale
.
mutable_data
<
T
>
({
1
},
context
.
GetPlace
());
FindAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
->
data
<
T
>
(),
in
->
numel
(),
cur_scale_data
);
FindRangeAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
cur_scale
,
*
in_scale
,
*
iter
,
window_size
,
out_scales
,
out_scale
);
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
out
);
}
};
...
...
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
浏览文件 @
9bd933d3
...
...
@@ -21,28 +21,41 @@ from op_test import OpTest
class
TestFakeQuantizeOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"fake_quantize"
self
.
op_type
=
"fake_quantize_abs_max"
self
.
attrs
=
{
'bit_length'
:
8
}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
124
,
240
)).
astype
(
"float32"
),
}
scale
=
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
"float32"
)
self
.
outputs
=
{
'Out'
:
np
.
round
(
self
.
inputs
[
'X'
]
/
scale
*
(
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
)),
'OutScale'
:
np
.
array
(
scale
).
astype
(
"float32"
),
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFakeQuantizeOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"fake_quantize_range_abs_max"
self
.
attrs
=
{
'bit_length'
:
8
,
'
quantize_type'
:
'abs_max'
,
'
window_size'
:
10000
'bit_length'
:
int
(
5
)
,
'
window_size'
:
int
(
1
)
,
'
is_test'
:
False
}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
10
,
10
)).
astype
(
"float32"
),
'InScales'
:
np
.
zeros
(
self
.
attrs
[
'window_size'
]).
astype
(
"float32"
),
'InCurrentIter'
:
np
.
zeros
(
1
).
astype
(
"float32"
),
'InMovingScale'
:
np
.
zeros
(
1
).
astype
(
"float32"
)
}
self
.
scale
=
{
'abs_max'
:
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
"float32"
)
'X'
:
np
.
random
.
random
((
8
,
16
,
7
,
7
)).
astype
(
"float32"
),
'Iter'
:
np
.
zeros
(
1
).
astype
(
"int64"
),
'InScale'
:
np
.
zeros
(
1
).
astype
(
"float32"
)
}
scale
=
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
"float32"
)
out_scales
=
np
.
zeros
(
self
.
attrs
[
'window_size'
]).
astype
(
"float32"
)
out_scales
[
0
]
=
scale
self
.
outputs
=
{
'Out'
:
np
.
round
(
self
.
inputs
[
'X'
]
/
s
elf
.
scale
[
'abs_max'
]
*
(
'Out'
:
np
.
round
(
self
.
inputs
[
'X'
]
/
s
cale
*
(
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
)),
'OutScales'
:
np
.
zeros
(
self
.
attrs
[
'window_size'
]).
astype
(
"float32"
),
'OutMovingScale'
:
np
.
array
([
self
.
scale
[
'abs_max'
]]).
astype
(
"float32"
),
'OutCurrentIter'
:
np
.
zeros
(
1
).
astype
(
"float32"
)
'OutScale'
:
scale
,
'OutScales'
:
out_scales
,
}
def
test_check_output
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录