Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9bd933d3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9bd933d3
编写于
9月 03, 2018
作者:
Q
qingqing01
提交者:
GitHub
9月 03, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Improve and fix fake_quantize_op (#13092)
* Improve and fix fake_quantize_op.
上级
b4d43030
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
369 addition
and
376 deletion
+369
-376
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+3
-0
paddle/fluid/operators/fake_quantize_op.cc
paddle/fluid/operators/fake_quantize_op.cc
+171
-53
paddle/fluid/operators/fake_quantize_op.cu
paddle/fluid/operators/fake_quantize_op.cu
+99
-194
paddle/fluid/operators/fake_quantize_op.h
paddle/fluid/operators/fake_quantize_op.h
+67
-113
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+29
-16
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
9bd933d3
...
@@ -178,6 +178,8 @@ function(op_library TARGET)
...
@@ -178,6 +178,8 @@ function(op_library TARGET)
file
(
APPEND
${
pybind_file
}
"USE_OP(relu);
\n
"
)
file
(
APPEND
${
pybind_file
}
"USE_OP(relu);
\n
"
)
elseif
(
${
TARGET
}
STREQUAL
"fake_dequantize"
)
elseif
(
${
TARGET
}
STREQUAL
"fake_dequantize"
)
file
(
APPEND
${
pybind_file
}
"USE_OP(fake_dequantize_max_abs);
\n
"
)
file
(
APPEND
${
pybind_file
}
"USE_OP(fake_dequantize_max_abs);
\n
"
)
elseif
(
${
TARGET
}
STREQUAL
"fake_quantize"
)
file
(
APPEND
${
pybind_file
}
"USE_OP(fake_quantize_abs_max);
\n
"
)
elseif
(
${
TARGET
}
STREQUAL
"tensorrt_engine_op"
)
elseif
(
${
TARGET
}
STREQUAL
"tensorrt_engine_op"
)
message
(
STATUS
"Pybind skips [tensorrt_engine_op], for this OP is only used in inference"
)
message
(
STATUS
"Pybind skips [tensorrt_engine_op], for this OP is only used in inference"
)
elseif
(
${
TARGET
}
STREQUAL
"fc"
)
elseif
(
${
TARGET
}
STREQUAL
"fc"
)
...
@@ -293,6 +295,7 @@ op_library(extract_rows_op DEPS memory)
...
@@ -293,6 +295,7 @@ op_library(extract_rows_op DEPS memory)
op_library
(
flatten_op DEPS reshape_op
)
op_library
(
flatten_op DEPS reshape_op
)
op_library
(
sequence_pad_op DEPS sequence_padding
)
op_library
(
sequence_pad_op DEPS sequence_padding
)
op_library
(
unstack_op DEPS stack_op
)
op_library
(
unstack_op DEPS stack_op
)
op_library
(
fake_quantize_op DEPS memory
)
if
(
WITH_GPU
)
if
(
WITH_GPU
)
op_library
(
conv_op DEPS vol2col depthwise_conv im2col
)
op_library
(
conv_op DEPS vol2col depthwise_conv im2col
)
...
...
paddle/fluid/operators/fake_quantize_op.cc
浏览文件 @
9bd933d3
...
@@ -14,86 +14,198 @@ limitations under the License. */
...
@@ -14,86 +14,198 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include <string>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/platform/transform.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
class
FakeQuantizeOp
:
public
framework
::
OperatorWithKernel
{
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenVectorArrayMap
=
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
T
,
1
,
MajorType
,
IndexType
>>
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
ConstEigenVectorArrayMap
=
Eigen
::
TensorMap
<
const
Eigen
::
Tensor
<
T
,
1
,
MajorType
,
IndexType
>>
;
template
<
typename
T
>
struct
FindAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
T
*
in
,
const
int
num
,
T
*
out
)
{
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
1
>
idim
(
num
);
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
1
>
odim
(
1
);
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
const
T
,
1
,
Eigen
::
RowMajor
>>
in_e
(
in
,
idim
);
Eigen
::
TensorMap
<
Eigen
::
Tensor
<
T
,
1
,
Eigen
::
RowMajor
>>
out_e
(
out
,
odim
);
out_e
=
in_e
.
abs
().
maximum
();
}
};
template
struct
FindAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
<
typename
T
>
struct
ClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
framework
::
Tensor
*
out
)
{
T
s
=
scale
.
data
<
T
>
()[
0
];
platform
::
Transform
<
platform
::
CPUDeviceContext
>
trans
;
trans
(
ctx
,
in
.
data
<
T
>
(),
in
.
data
<
T
>
()
+
in
.
numel
(),
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
ClipFunctor
<
T
>
(
-
s
,
s
));
auto
in_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
in
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
out_e
.
device
(
*
ctx
.
eigen_device
())
=
(
bin_cnt
/
s
*
in_e
).
round
();
}
};
template
struct
ClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
<
typename
T
>
struct
FindRangeAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
cur_scale
,
const
framework
::
Tensor
&
last_scale
,
const
framework
::
Tensor
&
iter
,
const
int
window_size
,
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
)
{
T
*
scale_arr
=
scales_arr
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
it
=
iter
.
data
<
int64_t
>
()[
0
];
int
idx
=
it
%
window_size
;
T
removed
=
scale_arr
[
idx
];
T
cur
=
cur_scale
.
data
<
T
>
()[
0
];
scale_arr
[
idx
]
=
cur
;
T
max
=
last_scale
.
data
<
T
>
()[
0
];
if
(
max
<
cur
)
{
max
=
cur
;
}
else
if
(
fabs
(
removed
-
max
)
<
1e-6
)
{
int
size
=
(
it
>
window_size
)
?
window_size
:
it
;
FindAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
()(
ctx
,
scale_arr
,
size
,
&
max
);
}
out_scale
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())[
0
]
=
max
;
}
};
template
struct
FindRangeAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
float
>;
class
FakeQuantizeAbsMaxOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
FakeQuantize
Op
(
const
std
::
string
&
type
,
FakeQuantize
AbsMaxOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FakeQuantizeOp should not be null."
);
"Input(X) of FakeQuantizeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FakeQuantizeOp should not be null."
);
"Output(Out) of FakeQuantizeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutMovingScale"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutScale"
),
"OutMovingScale(Out) of FakeQuantizeOp should not be null"
);
"Output(Scale) of FakeQuantizeOp should not be null."
);
// if (ctx->HasInput("InMovingScale")) {
ctx
->
SetOutputDim
(
"OutMovingScale"
,
ctx
->
GetInputDim
(
"InMovingScale"
));
//}
// if (ctx->HasInput("InScales")) {
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutScales"
),
"OutScales(Out) of FakeQuantizeOp should not be null"
);
ctx
->
SetOutputDim
(
"OutScales"
,
ctx
->
GetInputDim
(
"InScales"
));
// PADDLE_ENFORCE_EQ(ctx->Inputs("InScales")[0],
// ctx->Outputs("OutScales")[0],
// "Mean and MeanOut should share the same memory");
//}
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"OutScale"
,
{
1
});
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
};
class
FakeQuantizeOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
FakeQuantize
AbsMax
OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
override
{
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) Input tensor of scale operator."
);
AddInput
(
"X"
,
"(Tensor) Input is float data type."
);
AddInput
(
"InScales"
,
"(Tensor) scale buffer, used in static quantization."
)
AddOutput
(
"Out"
,
.
AsDispensable
();
"(Tensor) Output of quantized low level tensor, "
AddInput
(
"InMovingScale"
,
"Last scale, used in static quantization."
)
"but also saved as float data type."
);
.
AsDispensable
();
AddOutput
(
"OutScale"
,
"(Tensor) Current scale"
);
AddInput
(
"InCurrentIter"
,
"Last iteration number, used in static quantization."
)
.
AsDispensable
();
AddOutput
(
"Out"
,
"(Tensor) Output of quantized low level tensor."
);
AddOutput
(
"OutScales"
,
"(Tensor) scale buffer, used in static quantization."
)
.
AsDispensable
();
AddOutput
(
"OutMovingScale"
,
" Current scale"
);
AddOutput
(
"OutCurrentIter"
,
"Current iteration number."
).
AsDispensable
();
AddAttr
<
std
::
string
>
(
"quantize_type"
,
"(string, default abs_max)"
"The scaling tpe of the quantize operator."
)
.
SetDefault
(
"abs_max"
);
AddAttr
<
int
>
(
"window_size"
,
"(int, default 10000)"
).
SetDefault
(
10000
);
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
.
SetDefault
(
8
)
.
SetDefault
(
8
)
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
PADDLE_ENFORCE
(
bit_length
>=
1
&&
bit_length
<=
16
,
PADDLE_ENFORCE
(
bit_length
>=
1
&&
bit_length
<=
16
,
"'bit_length' should be between 1 and 16."
);
"'bit_length' should be between 1 and 16."
);
});
});
AddAttr
<
bool
>
(
"is_test"
,
""
).
SetDefault
(
false
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
FakeQuantize operator
FakeQuantize operator
quantize_type = abs_max:
$$scale = max(abs(X))$$
$$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$
$$scale = max(abs(x))$$
)DOC"
);
}
};
quantize_type = range_abs_max:
class
FakeQuantizeRangeAbsMaxOp
:
public
framework
::
OperatorWithKernel
{
public:
FakeQuantizeRangeAbsMaxOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
$$scale = max(max(abs(x)), history_abs_max)$$
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FakeQuantizeRangeAbsMaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FakeQuantizeRangeAbsMaxOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutScale"
),
"Output(OutScale) of FakeQuantizeRangeAbsMaxOp should not be null"
);
if
(
ctx
->
HasOutput
(
"OutScales"
))
{
int
window_size
=
ctx
->
Attrs
().
Get
<
int
>
(
"window_size"
);
ctx
->
SetOutputDim
(
"OutScales"
,
{
window_size
});
}
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"OutScale"
,
{
1
});
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
quantize_type = moving_average_abs_max:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
framework
::
ToDataType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
()),
ctx
.
device_context
());
}
};
$$scale = 0.1*scale+0.9*new_abs_max)$$
class
FakeQuantizeRangeAbsMaxOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) Input is float data type."
);
AddInput
(
"InScale"
,
"Last scale."
);
AddInput
(
"Iter"
,
"Global step iteration."
).
AsDispensable
();
AddOutput
(
"Out"
,
"(Tensor) Output of quantized low level tensor."
);
AddOutput
(
"OutScale"
,
" Current scale"
);
AddOutput
(
"OutScales"
,
"(Tensor) scale buffer."
).
AsDispensable
();
AddAttr
<
int
>
(
"window_size"
,
"(int, default 10000) window range size."
)
.
SetDefault
(
10000
);
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8), quantization bit number."
)
.
SetDefault
(
8
)
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
PADDLE_ENFORCE
(
bit_length
>=
1
&&
bit_length
<=
16
,
"'bit_length' should be between 1 and 16."
);
});
AddAttr
<
bool
>
(
"is_test"
,
""
).
SetDefault
(
false
);
AddComment
(
R"DOC(
FakeQuantize operator is used in static quantization.
$$Out = scale*X$$
$$scale = max(max(abs(x)), history_abs_max)$$
$$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$
)DOC"
);
)DOC"
);
}
}
...
@@ -103,10 +215,16 @@ $$Out = scale*X$$
...
@@ -103,10 +215,16 @@ $$Out = scale*X$$
}
// namespace paddle
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
using
CPU
=
paddle
::
platform
::
CPUDeviceContext
;
REGISTER_OPERATOR
(
fake_quantize_abs_max
,
ops
::
FakeQuantizeAbsMaxOp
,
ops
::
FakeQuantizeAbsMaxOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
fake_quantize_abs_max
,
ops
::
FakeQuantizeAbsMaxKernel
<
CPU
,
float
>
);
REGISTER_OPERATOR
(
fake_quantize
,
ops
::
FakeQuantizeOp
,
ops
::
FakeQuantizeOpMaker
,
REGISTER_OPERATOR
(
fake_quantize_range_abs_max
,
ops
::
FakeQuantizeRangeAbsMaxOp
,
ops
::
FakeQuantizeRangeAbsMaxOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
fake_quantize_range_abs_max
,
fake_quantize
,
ops
::
FakeQuantizeRangeAbsMaxKernel
<
CPU
,
float
>
);
ops
::
FakeQuantizeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
FakeQuantizeKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/fake_quantize_op.cu
浏览文件 @
9bd933d3
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <string>
#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/cuda_primitives.h"
...
@@ -20,7 +21,7 @@ namespace paddle {
...
@@ -20,7 +21,7 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
template
<
typename
T
>
__global__
void
FindAbsMaxKernel
(
const
int
n
,
const
T
*
i
n
,
T
*
out
)
{
__global__
void
FindAbsMaxKernel
(
const
T
*
in
,
const
int
n
,
T
*
out
)
{
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
...
@@ -43,7 +44,7 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
...
@@ -43,7 +44,7 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
__syncthreads
();
__syncthreads
();
for
(
int
i
=
blockDim
.
x
/
2
;
i
>
0
;
i
>>=
1
)
{
for
(
int
i
=
blockDim
.
x
/
2
;
i
>
0
;
i
>>=
1
)
{
if
(
tid
<
i
&&
shared_max_data
[
tid
]
<
shared_max_data
[
tid
+
i
]
)
{
if
(
tid
<
i
&&
(
shared_max_data
[
tid
]
<
shared_max_data
[
tid
+
i
])
)
{
shared_max_data
[
tid
]
=
shared_max_data
[
tid
+
i
];
shared_max_data
[
tid
]
=
shared_max_data
[
tid
+
i
];
}
}
__syncthreads
();
__syncthreads
();
...
@@ -53,220 +54,124 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
...
@@ -53,220 +54,124 @@ __global__ void FindAbsMaxKernel(const int n, const T* in, T* out) {
}
}
}
}
float
FindAbsMaxGpu
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
float
*
array
,
template
<
typename
T
>
int
length
)
{
struct
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
float
host_max
;
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
T
*
in
,
int
kNumTheads
=
1024
;
const
int
num
,
T
*
out
)
{
int
gridDimx
=
(
kNumTheads
-
1
+
length
)
/
kNumTheads
;
int
block
=
1024
;
gridDimx
=
(
gridDimx
>
kNumTheads
)
?
kNumTheads
:
gridDimx
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
framework
::
Tensor
t
;
grid
=
(
grid
>
block
)
?
block
:
grid
;
float
*
device_max
=
t
.
mutable_data
<
float
>
(
framework
::
make_ddim
({
gridDimx
}),
platform
::
CUDAPlace
())
;
framework
::
Tensor
max
;
FindAbsMaxKernel
<
float
><<<
gridDimx
,
kNumTheads
,
kNumTheads
*
sizeof
(
float
),
T
*
max_data
=
ctx
.
stream
()
>>>
(
length
,
array
,
device_max
);
max
.
mutable_data
<
T
>
(
framework
::
make_ddim
({
grid
}),
ctx
.
GetPlace
()
);
FindAbsMaxKernel
<
FindAbsMaxKernel
<
T
><<<
grid
,
block
,
1024
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
float
><<<
1
,
kNumTheads
,
kNumTheads
*
sizeof
(
float
),
ctx
.
stream
()
>>>
(
in
,
num
,
max_data
);
gridDimx
,
device_max
,
device_max
);
FindAbsMaxKernel
<
T
><<<
1
,
block
,
1024
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
PADDLE_ENFORCE_EQ
(
max_data
,
grid
,
out
);
cudaMemcpy
(
&
host_max
,
device_max
,
sizeof
(
float
),
cudaMemcpyDeviceToHost
),
}
cudaSuccess
,
"cudaMemcpy failed"
)
;
}
;
return
host_max
;
}
template
struct
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
template
<
typename
T
>
__global__
void
ApplySaturateKernel
(
const
int
n
,
const
T
*
in
,
T
*
out
,
__global__
void
ClipAndQuantKernel
(
const
T
*
in
,
const
T
*
scale
,
int
*
num_saturate
,
const
T
min
,
const
int
bin_cnt
,
const
int
n
,
T
*
out
)
{
const
T
max
)
{
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
extern
__shared__
int
shared_count
[];
T
s
=
scale
[
0
];
shared_count
[
tid
]
=
0
;
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
if
(
in
[
i
]
>
max
)
{
T
x
=
in
[
bid
];
out
[
i
]
=
max
;
T
v
=
x
>
s
?
s
:
x
;
shared_count
[
tid
]
+=
1
;
v
=
v
<
-
s
?
-
s
:
v
;
}
else
if
(
in
[
i
]
<
min
)
{
v
=
bin_cnt
/
s
*
v
;
out
[
i
]
=
min
;
out
[
bid
]
=
round
(
v
);
shared_count
[
tid
]
+=
1
;
}
else
{
out
[
i
]
=
in
[
i
];
}
}
__syncthreads
();
for
(
int
i
=
blockDim
.
x
/
2
;
i
>
0
;
i
>>=
1
)
{
if
(
tid
<
i
)
{
shared_count
[
tid
]
+=
shared_count
[
tid
+
i
];
}
__syncthreads
();
}
if
(
tid
==
0
)
{
num_saturate
[
blockIdx
.
x
]
=
shared_count
[
0
];
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
ReduceKernel
(
const
int
n
,
const
T
*
in
,
T
*
out
)
{
__global__
void
FindRangeAbsMaxAndFillArray
(
const
T
*
cur_scale
,
int
tid
=
threadIdx
.
x
;
const
T
*
last_scale
,
extern
__shared__
T
shared_sum
[];
const
int64_t
*
iter
,
if
(
tid
<
n
)
{
const
int
window_size
,
T
*
scale_arr
,
shared_sum
[
tid
]
=
in
[
tid
];
T
*
out_scale
,
int
*
need_find_max
,
int
*
out_size
)
{
int
it
=
iter
[
0
];
int
idx
=
it
%
window_size
;
T
removed
=
scale_arr
[
idx
];
T
cur
=
cur_scale
[
0
];
scale_arr
[
idx
]
=
cur
;
T
max
=
last_scale
[
0
];
out_scale
[
0
]
=
max
<
cur
?
cur
:
max
;
if
(
fabs
(
removed
-
max
)
<
1e-6
)
{
need_find_max
[
0
]
=
1
;
out_size
[
0
]
=
it
>
window_size
?
window_size
:
it
;
}
else
{
}
else
{
shared_sum
[
tid
]
=
T
(
0
);
need_find_max
[
0
]
=
0
;
}
__syncthreads
();
// blockDim.x must >= n
for
(
int
i
=
(
n
+
1
)
/
2
;
i
>
0
;
i
>>=
1
)
{
if
(
tid
<
i
)
{
shared_sum
[
tid
]
+=
shared_sum
[
tid
+
i
];
}
__syncthreads
();
}
if
(
tid
==
0
)
{
out
[
0
]
=
shared_sum
[
0
];
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
int
ApplySaturateGpu
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
int
n
,
struct
FindRangeAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
const
T
*
in
,
T
*
out
,
const
T
min
,
const
T
max
)
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
int
host_num_saturate
;
const
framework
::
Tensor
&
cur_scale
,
int
kNumTheads
=
1024
;
const
framework
::
Tensor
&
last_scale
,
int
gridDimx
=
(
n
+
kNumTheads
-
1
)
/
kNumTheads
;
const
framework
::
Tensor
&
iter
,
const
int
window_size
,
gridDimx
=
(
gridDimx
>
kNumTheads
)
?
kNumTheads
:
gridDimx
;
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
)
{
framework
::
Tensor
t
;
auto
&
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx
.
GetPlace
());
int
*
device_num_saturate
=
t
.
mutable_data
<
int
>
(
T
*
scale_arr
=
scales_arr
->
mutable_data
<
T
>
(
gpu_place
);
framework
::
make_ddim
({
gridDimx
}),
platform
::
CUDAPlace
());
T
*
out_scale_data
=
out_scale
->
mutable_data
<
T
>
(
gpu_place
);
ApplySaturateKernel
<
T
><<<
gridDimx
,
kNumTheads
,
kNumTheads
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
framework
::
Tensor
need_find_max
,
out_size
;
n
,
in
,
out
,
device_num_saturate
,
min
,
max
);
int
*
find_max
=
need_find_max
.
mutable_data
<
int
>
(
gpu_place
);
ReduceKernel
<
int
><<<
1
,
kNumTheads
,
kNumTheads
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
int
*
out_size_data
=
out_size
.
mutable_data
<
int
>
(
gpu_place
);
gridDimx
,
device_num_saturate
,
device_num_saturate
);
PADDLE_ENFORCE_EQ
(
cudaSuccess
,
FindRangeAbsMaxAndFillArray
<
T
><<<
1
,
1
,
0
,
ctx
.
stream
()
>>>
(
cudaMemcpy
(
&
host_num_saturate
,
device_num_saturate
,
cur_scale
.
data
<
T
>
(),
last_scale
.
data
<
T
>
(),
iter
.
data
<
int64_t
>
(),
sizeof
(
int
),
cudaMemcpyDeviceToHost
),
window_size
,
scale_arr
,
out_scale_data
,
find_max
,
out_size_data
);
"cudaMemcpy failed"
);
return
host_num_saturate
;
int
g_find_max
;
}
memory
::
Copy
(
platform
::
CPUPlace
(),
&
g_find_max
,
gpu_place
,
find_max
,
sizeof
(
int
),
0
);
template
<
typename
DeviceContext
,
typename
T
>
if
(
g_find_max
)
{
class
FakeQuantizeCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
int
len
;
public:
memory
::
Copy
(
platform
::
CPUPlace
(),
&
len
,
gpu_place
,
out_size_data
,
T
FindRangeAbsMax
(
const
platform
::
CUDADeviceContext
&
ctx
,
sizeof
(
int
),
0
);
framework
::
Tensor
*
scale_list
,
framework
::
Tensor
*
out_scale
,
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
ctx
,
scale_arr
,
len
,
const
T
&
cur_scale
,
int
window_size
,
out_scale_data
);
int
current_iter
)
const
{
T
*
sl
=
scale_list
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
remove_tmp
=
sl
[
current_iter
];
sl
[
current_iter
]
=
cur_scale
;
T
&
max_scale
=
out_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
if
(
max_scale
<
cur_scale
)
{
max_scale
=
cur_scale
;
}
else
if
(
fabs
(
remove_tmp
-
max_scale
)
<
1e-6
)
{
int
size
=
(
current_iter
>
window_size
)
?
window_size
:
current_iter
;
max_scale
=
T
(
FindAbsMaxGpu
(
ctx
,
scale_list
->
data
<
float
>
(),
size
));
}
}
return
max_scale
;
}
T
FindMovingAverageAbsMmax
(
framework
::
Tensor
*
in_scale
,
framework
::
Tensor
*
out_scale
,
const
T
&
cur_scale
)
const
{
T
*
ins
=
in_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
*
outs
=
out_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
outs
[
0
]
=
0.9
*
cur_scale
+
0.1
*
ins
[
0
];
return
T
(
outs
[
0
]);
}
}
};
virtual
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
{
template
struct
FindRangeAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
float
>;
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
context
.
GetPlace
()),
"This kernel only runs on GPU device."
);
auto
&
device_ctx
=
context
.
cuda_device_context
();
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
tensor
->
mutable_data
<
T
>
(
in
->
place
());
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
)
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
)
->
place
());
auto
quantize_type
=
static_cast
<
std
::
string
>
(
context
.
Attr
<
std
::
string
>
(
"quantize_type"
));
if
(
quantize_type
==
std
::
string
(
"range_abs_max"
))
{
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
)
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InScales"
)
->
place
());
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
)
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InCurrentIter"
)
->
place
());
}
T
scale
=
T
(
1
);
int
window_size
=
context
.
Attr
<
int
>
(
"window_size"
);
T
bin_cnt
=
(
T
)((
1
<<
(
context
.
Attr
<
int
>
(
"bit_length"
)
-
1
))
-
1
);
if
(
quantize_type
==
std
::
string
(
"abs_max"
))
{
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
scale
=
(
T
)
FindAbsMaxGpu
(
device_ctx
,
in
->
data
<
float
>
(),
in
->
numel
());
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
]
=
scale
;
auto
&
device_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
*
scale_list
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
math
::
SetConstant
<
DeviceContext
,
T
>
scalar
;
scale_list
->
mutable_data
<
T
>
(
context
.
GetPlace
());
scalar
(
device_ctx
,
scale_list
,
static_cast
<
T
>
(
0
));
auto
*
iter
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
iter
->
mutable_data
<
T
>
(
context
.
GetPlace
());
scalar
(
device_ctx
,
iter
,
static_cast
<
T
>
(
0
));
}
else
if
(
quantize_type
==
std
::
string
(
"range_abs_max"
))
{
auto
*
moving_scale
=
const_cast
<
framework
::
Tensor
*>
(
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
));
if
(
is_test
)
{
scale
=
moving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
}
else
{
auto
*
it
=
const_cast
<
framework
::
Tensor
*>
(
context
.
Input
<
framework
::
Tensor
>
(
"InCurrentIter"
));
auto
*
iter
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
int
*
last_iter
=
it
->
mutable_data
<
int
>
(
platform
::
CPUPlace
());
int
*
current_iter
=
iter
->
mutable_data
<
int
>
(
platform
::
CPUPlace
());
auto
*
scale_list
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
scale
=
(
T
)
FindAbsMaxGpu
(
device_ctx
,
in
->
data
<
float
>
(),
in
->
numel
());
scale
=
FindRangeAbsMax
(
device_ctx
,
scale_list
,
saving_scale
,
scale
,
window_size
,
current_iter
[
0
]);
(
*
current_iter
)
=
(
*
last_iter
)
+
1
;
}
}
else
if
(
quantize_type
==
std
::
string
(
"moving_average_abs_max"
))
{
auto
*
moving_scale
=
const_cast
<
framework
::
Tensor
*>
(
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
));
if
(
is_test
)
{
scale
=
moving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
}
else
{
scale
=
(
T
)
FindAbsMaxGpu
(
device_ctx
,
in
->
data
<
float
>
(),
in
->
numel
());
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
scale
=
FindMovingAverageAbsMmax
(
const_cast
<
framework
::
Tensor
*>
(
moving_scale
),
saving_scale
,
scale
);
}
}
ApplySaturateGpu
<
T
>
(
device_ctx
,
in
->
numel
(),
in
->
data
<
T
>
(),
tensor
->
mutable_data
<
T
>
(
in
->
place
()),
-
scale
,
scale
);
scale
=
bin_cnt
/
scale
;
auto
&
dev
=
template
<
typename
T
>
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
auto
eigen_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
tensor
);
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
auto
eigen_in
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
tensor
);
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
eigen_out
.
device
(
dev
)
=
(
scale
*
eigen_in
).
round
();
const
int
bin_cnt
,
framework
::
Tensor
*
out
)
{
int
num
=
in
.
numel
();
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ClipAndQuantKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
num
,
out_data
);
}
}
};
};
template
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
float
>;
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
fake_quantize
,
namespace
ops
=
paddle
::
operators
;
paddle
::
operators
::
FakeQuantizeCUDAKernel
<
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_abs_max
,
paddle
::
operators
::
FakeQuantizeCUDAKernel
<
ops
::
FakeQuantizeAbsMaxKernel
<
CUDA
,
float
>
);
paddle
::
platform
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_range_abs_max
,
ops
::
FakeQuantizeRangeAbsMaxKernel
<
CUDA
,
float
>
);
paddle/fluid/operators/fake_quantize_op.h
浏览文件 @
9bd933d3
...
@@ -17,137 +17,91 @@ limitations under the License. */
...
@@ -17,137 +17,91 @@ limitations under the License. */
#include <string>
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/transform.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
platform
::
Transform
;
template
<
typename
DeviceContext
,
typename
T
>
struct
FindAbsMaxFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
T
*
in
,
const
int
num
,
T
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
FakeQuantizeKernel
:
public
framework
::
OpKernel
<
T
>
{
struct
ClipAndFakeQuantFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
framework
::
Tensor
*
out
);
};
template
<
typename
DeviceContext
,
typename
T
>
struct
FindRangeAbsMaxFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
cur_scale
,
const
framework
::
Tensor
&
last_scale
,
const
framework
::
Tensor
&
iter
,
const
int
window_size
,
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
FakeQuantizeAbsMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
T
FindAbsMax
(
framework
::
Tensor
*
in
,
int
n
)
const
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
T
*
p
=
in
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
T
abs_max
=
(
T
)
0.00000001
;
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
T
tmp
=
fabs
(
p
[
i
]);
if
(
tmp
>
abs_max
)
abs_max
=
tmp
;
}
return
T
(
abs_max
);
}
T
FindRangeAbsMax
(
framework
::
Tensor
*
scale_list
,
framework
::
Tensor
*
out_scale
,
const
T
&
cur_scale
,
int
window_size
,
int
current_iter
)
const
{
T
*
sl
=
scale_list
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
remove_tmp
=
sl
[
current_iter
];
sl
[
current_iter
]
=
cur_scale
;
T
&
max_scale
=
out_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
if
(
max_scale
<
cur_scale
)
{
max_scale
=
cur_scale
;
}
else
if
(
fabs
(
remove_tmp
-
max_scale
)
<
1e-6
)
{
int
size
=
(
current_iter
>
window_size
)
?
window_size
:
current_iter
;
max_scale
=
T
(
FindAbsMax
(
scale_list
,
size
));
}
return
max_scale
;
}
T
FindMovingAverageAbsMmax
(
framework
::
Tensor
*
in_scale
,
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
framework
::
Tensor
*
out_scale
,
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
const
T
&
cur_scale
)
const
{
T
*
out_s
=
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
ins
=
in_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
T
*
outs
=
out_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
());
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
outs
[
0
]
=
0.9
*
cur_scale
+
0.1
*
ins
[
0
];
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
return
T
(
outs
[
0
]);
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
const
T
*
in_data
=
in
->
data
<
T
>
();
FindAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in_data
,
in
->
numel
(),
out_s
);
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
out
);
}
}
};
virtual
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
{
template
<
typename
DeviceContext
,
typename
T
>
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
class
FakeQuantizeRangeAbsMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
const
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
auto
*
in_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"InScale"
);
tensor
->
mutable_data
<
T
>
(
in
->
place
());
auto
*
oms_tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
oms_tensor
->
mutable_data
<
T
>
(
in
->
place
());
auto
quantize_type
=
static_cast
<
std
::
string
>
(
context
.
Attr
<
std
::
string
>
(
"quantize_type"
));
if
(
quantize_type
==
std
::
string
(
"range_abs_max"
))
{
auto
*
oss_tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
oss_tensor
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InScales"
)
->
place
());
auto
*
oci_tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
oci_tensor
->
mutable_data
<
T
>
(
context
.
Input
<
framework
::
Tensor
>
(
"InCurrentIter"
)
->
place
());
}
T
scale
=
static_cast
<
T
>
(
1
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
int
window_size
=
context
.
Attr
<
int
>
(
"window_size"
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev
=
// testing
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
if
(
is_test
)
{
auto
raw_in
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
in
);
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
*
in_scale
,
if
(
quantize_type
==
std
::
string
(
"abs_max"
))
{
bin_cnt
,
out
);
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
return
;
auto
scale_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
saving_scale
);
scale_out
.
device
(
dev
)
=
raw_in
.
abs
().
maximum
();
scale
=
scale_out
(
0
);
auto
&
device_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
*
scale_list
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
math
::
SetConstant
<
DeviceContext
,
T
>
scalar
;
scale_list
->
mutable_data
<
T
>
(
context
.
GetPlace
());
scalar
(
device_ctx
,
scale_list
,
static_cast
<
T
>
(
0
));
auto
*
iter
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
iter
->
mutable_data
<
T
>
(
context
.
GetPlace
());
scalar
(
device_ctx
,
iter
,
static_cast
<
T
>
(
0
));
}
else
if
(
quantize_type
==
std
::
string
(
"range_abs_max"
))
{
auto
*
moving_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
);
if
(
is_test
)
{
scale
=
moving_scale
->
data
<
T
>
()[
0
];
}
else
{
auto
*
it
=
context
.
Input
<
framework
::
Tensor
>
(
"InCurrentIter"
);
auto
*
iter
=
context
.
Output
<
framework
::
Tensor
>
(
"OutCurrentIter"
);
const
int
*
last_iter
=
it
->
data
<
int
>
();
int
*
current_iter
=
iter
->
mutable_data
<
int
>
(
platform
::
CPUPlace
());
auto
*
scale_list
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
auto
scale_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
saving_scale
);
scale_out
.
device
(
dev
)
=
raw_in
.
abs
().
maximum
();
scale
=
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
scale
=
FindRangeAbsMax
(
scale_list
,
saving_scale
,
scale
,
window_size
,
current_iter
[
0
]);
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
]
=
scale
;
(
*
current_iter
)
=
(
*
last_iter
)
+
1
;
}
}
else
if
(
quantize_type
==
std
::
string
(
"moving_average_abs_max"
))
{
auto
*
moving_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"InMovingScale"
);
if
(
is_test
)
{
scale
=
moving_scale
->
data
<
T
>
()[
0
];
}
else
{
auto
*
saving_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutMovingScale"
);
auto
scale_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
saving_scale
);
scale_out
.
device
(
dev
)
=
raw_in
.
abs
().
maximum
();
scale
=
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
];
scale
=
FindMovingAverageAbsMmax
(
const_cast
<
framework
::
Tensor
*>
(
moving_scale
),
saving_scale
,
scale
);
saving_scale
->
mutable_data
<
T
>
(
platform
::
CPUPlace
())[
0
]
=
scale
;
}
}
}
Transform
<
DeviceContext
>
trans
;
// training
trans
(
context
.
template
device_context
<
DeviceContext
>(),
in
->
data
<
T
>
(),
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
in
->
data
<
T
>
()
+
in
->
numel
(),
tensor
->
mutable_data
<
T
>
(
in
->
place
()),
auto
*
out_scales
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
ClipFunctor
<
T
>
(
-
scale
,
scale
));
auto
*
iter
=
context
.
Input
<
framework
::
Tensor
>
(
"Iter"
);
auto
eigen_out
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
tensor
);
auto
eigen_in
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
tensor
);
int
window_size
=
context
.
Attr
<
int
>
(
"window_size"
);
eigen_out
.
device
(
dev
)
=
(
bin_cnt
/
scale
*
eigen_in
).
round
();
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
framework
::
Tensor
cur_scale
;
T
*
cur_scale_data
=
cur_scale
.
mutable_data
<
T
>
({
1
},
context
.
GetPlace
());
FindAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
->
data
<
T
>
(),
in
->
numel
(),
cur_scale_data
);
FindRangeAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
cur_scale
,
*
in_scale
,
*
iter
,
window_size
,
out_scales
,
out_scale
);
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
out
);
}
}
};
};
...
...
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
浏览文件 @
9bd933d3
...
@@ -21,28 +21,41 @@ from op_test import OpTest
...
@@ -21,28 +21,41 @@ from op_test import OpTest
class
TestFakeQuantizeOp
(
OpTest
):
class
TestFakeQuantizeOp
(
OpTest
):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
op_type
=
"fake_quantize"
self
.
op_type
=
"fake_quantize_abs_max"
self
.
attrs
=
{
'bit_length'
:
8
}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
124
,
240
)).
astype
(
"float32"
),
}
scale
=
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
"float32"
)
self
.
outputs
=
{
'Out'
:
np
.
round
(
self
.
inputs
[
'X'
]
/
scale
*
(
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
)),
'OutScale'
:
np
.
array
(
scale
).
astype
(
"float32"
),
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFakeQuantizeOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"fake_quantize_range_abs_max"
self
.
attrs
=
{
self
.
attrs
=
{
'bit_length'
:
8
,
'bit_length'
:
int
(
5
)
,
'
quantize_type'
:
'abs_max'
,
'
window_size'
:
int
(
1
)
,
'
window_size'
:
10000
'
is_test'
:
False
}
}
self
.
inputs
=
{
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
10
,
10
)).
astype
(
"float32"
),
'X'
:
np
.
random
.
random
((
8
,
16
,
7
,
7
)).
astype
(
"float32"
),
'InScales'
:
np
.
zeros
(
self
.
attrs
[
'window_size'
]).
astype
(
"float32"
),
'Iter'
:
np
.
zeros
(
1
).
astype
(
"int64"
),
'InCurrentIter'
:
np
.
zeros
(
1
).
astype
(
"float32"
),
'InScale'
:
np
.
zeros
(
1
).
astype
(
"float32"
)
'InMovingScale'
:
np
.
zeros
(
1
).
astype
(
"float32"
)
}
self
.
scale
=
{
'abs_max'
:
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
"float32"
)
}
}
scale
=
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
"float32"
)
out_scales
=
np
.
zeros
(
self
.
attrs
[
'window_size'
]).
astype
(
"float32"
)
out_scales
[
0
]
=
scale
self
.
outputs
=
{
self
.
outputs
=
{
'Out'
:
np
.
round
(
self
.
inputs
[
'X'
]
/
s
elf
.
scale
[
'abs_max'
]
*
(
'Out'
:
np
.
round
(
self
.
inputs
[
'X'
]
/
s
cale
*
(
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
)),
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
)),
'OutScales'
:
np
.
zeros
(
self
.
attrs
[
'window_size'
]).
astype
(
"float32"
),
'OutScale'
:
scale
,
'OutMovingScale'
:
'OutScales'
:
out_scales
,
np
.
array
([
self
.
scale
[
'abs_max'
]]).
astype
(
"float32"
),
'OutCurrentIter'
:
np
.
zeros
(
1
).
astype
(
"float32"
)
}
}
def
test_check_output
(
self
):
def
test_check_output
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录