Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5420cf95
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5420cf95
编写于
3月 12, 2019
作者:
Z
Zhen Wang
提交者:
GitHub
3月 12, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16070 from wzzju/channel_wise_quant_op
Add channel wise quant op and channel wise dequant op.
上级
4e8c03bd
8063b31e
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
302 addition
and
0 deletion
+302
-0
paddle/fluid/operators/fake_dequantize_op.cc
paddle/fluid/operators/fake_dequantize_op.cc
+66
-0
paddle/fluid/operators/fake_dequantize_op.cu
paddle/fluid/operators/fake_dequantize_op.cu
+4
-0
paddle/fluid/operators/fake_dequantize_op.h
paddle/fluid/operators/fake_dequantize_op.h
+38
-0
paddle/fluid/operators/fake_quantize_op.cc
paddle/fluid/operators/fake_quantize_op.cc
+61
-0
paddle/fluid/operators/fake_quantize_op.cu
paddle/fluid/operators/fake_quantize_op.cu
+2
-0
paddle/fluid/operators/fake_quantize_op.h
paddle/fluid/operators/fake_quantize_op.h
+33
-0
python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
...n/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
+74
-0
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+24
-0
未找到文件。
paddle/fluid/operators/fake_dequantize_op.cc
浏览文件 @
5420cf95
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fake_dequantize_op.h"
#include <string>
#include <vector>
namespace
paddle
{
namespace
operators
{
...
...
@@ -76,6 +77,63 @@ $$Out = \frac{scale*X}{ max_range }$$
}
};
class
FakeChannelWiseDequantizeMaxAbsOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FakeChannelWiseDequantizeMaxAbsOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInputs
(
"Scales"
),
"Input(Scales) of FakeChannelWiseDequantizeMaxAbsOp "
"should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FakeChannelWiseDequantizeMaxAbsOp should not be null."
);
ctx
->
ShareDim
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
};
class
FakeChannelWiseDequantizeMaxAbsOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) The input with float-32/64 type is the "
"low precision tensor."
);
AddInput
(
"Scales"
,
"(Tensors) The scales in quantization stage. "
"Now, `Scales` is a vector with at most two tensors. "
"If Scales has two elements, the second tensor should only have "
"one value."
)
.
AsDuplicable
();
AddOutput
(
"Out"
,
"(Tensor) The output is the dequantized high "
"precision tensor."
);
AddAttr
<
std
::
vector
<
int
>>
(
"quant_bits"
,
"Quantization bit numbers in quantization stage. "
"The size of `quant_bits` should be equal to the size of `Scales`."
)
.
SetDefault
({
8
});
AddComment
(
R"DOC(
FakeChannelWiseDequantizeMaxAbsOp operator.
This calculation is an opposite operation of FakeChannelWiseQuantizeMaxAbsOp:
$$Out_c = \frac{X_c\prod_{i=1}^{n}Scales_{ic}}{\prod_{i=1}^{n}(2^{quant\_bits_i-1}-1)}$$
In the above formula, the range value of $c$ can be represented as $0 \leq c \lt \ the\ channel\ number\ of\ X$.
Besides, the size of $quant\_bits$ should be equal to the size of $Scales$, and it is called $n$ in the formula.
Notes: In general, the per-channel quantization is only applied to weights and the activations use per-layer quantization.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
...
...
@@ -88,3 +146,11 @@ REGISTER_OPERATOR(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsOp,
REGISTER_OP_CPU_KERNEL
(
fake_dequantize_max_abs
,
ops
::
FakeDequantizeMaxAbsKernel
<
CPU
,
float
>
,
ops
::
FakeDequantizeMaxAbsKernel
<
CPU
,
double
>
);
REGISTER_OPERATOR
(
fake_channel_wise_dequantize_max_abs
,
ops
::
FakeChannelWiseDequantizeMaxAbsOp
,
ops
::
FakeChannelWiseDequantizeMaxAbsOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
fake_channel_wise_dequantize_max_abs
,
ops
::
FakeChannelWiseDequantizeMaxAbsKernel
<
CPU
,
float
>
,
ops
::
FakeChannelWiseDequantizeMaxAbsKernel
<
CPU
,
double
>
);
paddle/fluid/operators/fake_dequantize_op.cu
浏览文件 @
5420cf95
...
...
@@ -55,3 +55,7 @@ using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL
(
fake_dequantize_max_abs
,
ops
::
FakeDequantizeMaxAbsKernel
<
CUDA
,
float
>
,
ops
::
FakeDequantizeMaxAbsKernel
<
CUDA
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_channel_wise_dequantize_max_abs
,
ops
::
FakeChannelWiseDequantizeMaxAbsKernel
<
CUDA
,
float
>
,
ops
::
FakeChannelWiseDequantizeMaxAbsKernel
<
CUDA
,
double
>
);
paddle/fluid/operators/fake_dequantize_op.h
浏览文件 @
5420cf95
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -45,5 +46,42 @@ class FakeDequantizeMaxAbsKernel : public framework::OpKernel<T> {
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
FakeChannelWiseDequantizeMaxAbsKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
virtual
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
*
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
scales
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"Scales"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
PADDLE_ENFORCE_EQ
(
scales
[
0
]
->
numel
(),
in
->
dims
()[
0
],
"The number of first scale values must be the same with "
"first dimension value of Input(X)."
);
auto
quant_bits
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"quant_bits"
);
int
max_range
=
std
::
pow
(
2
,
quant_bits
[
0
]
-
1
)
-
1
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
auto
dequant
=
DequantizeFunctor
<
DeviceContext
,
T
>
();
for
(
int64_t
i
=
0
;
i
<
in
->
dims
()[
0
];
i
++
)
{
framework
::
Tensor
one_channel_in
=
in
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_scale
=
scales
[
0
]
->
Slice
(
i
,
i
+
1
);
dequant
(
dev_ctx
,
&
one_channel_in
,
&
one_channel_scale
,
static_cast
<
T
>
(
max_range
),
&
one_channel_out
);
}
if
(
scales
.
size
()
==
2
)
{
PADDLE_ENFORCE_EQ
(
scales
[
1
]
->
numel
(),
1
,
"The second scale tensor should only have one value at now."
);
max_range
=
std
::
pow
(
2
,
quant_bits
[
1
]
-
1
)
-
1
;
dequant
(
dev_ctx
,
out
,
scales
[
1
],
static_cast
<
T
>
(
max_range
),
out
);
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/fake_quantize_op.cc
浏览文件 @
5420cf95
...
...
@@ -134,6 +134,60 @@ $$Out = round(X/scale * range)$$
}
};
class
FakeChannelWiseQuantizeAbsMaxOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of FakeChannelWiseQuantizeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FakeChannelWiseQuantizeOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"OutScales"
),
"Output(Scales) of FakeChannelWiseQuantizeOp should not be null."
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"OutScales"
,
{
ctx
->
GetInputDim
(
"X"
)[
0
]});
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
type
(),
ctx
.
GetPlace
());
}
};
class
FakeChannelWiseQuantizeAbsMaxOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) Input is float data type."
);
AddOutput
(
"Out"
,
"(Tensor) Output of quantized low level tensor, "
"but also saved as float data type."
);
AddOutput
(
"OutScales"
,
"(Tensor) Current channel wise scale"
);
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
.
SetDefault
(
8
)
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
PADDLE_ENFORCE
(
bit_length
>=
1
&&
bit_length
<=
16
,
"'bit_length' should be between 1 and 16."
);
});
AddComment
(
R"DOC(
The scale of FakeChannelWiseQuantize operator is a vector.
In detail, each channel of the input X has a scale value.
$$scale_c = max(abs(X_c))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out_c = round(\frac{X_c * range} {scale_c})$$
In above three formulas, the range value of c is as follow:
$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
)DOC"
);
}
};
class
FakeQuantizeRangeAbsMaxOp
:
public
framework
::
OperatorWithKernel
{
public:
FakeQuantizeRangeAbsMaxOp
(
const
std
::
string
&
type
,
...
...
@@ -218,3 +272,10 @@ REGISTER_OPERATOR(fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
fake_quantize_range_abs_max
,
ops
::
FakeQuantizeRangeAbsMaxKernel
<
CPU
,
float
>
);
REGISTER_OPERATOR
(
fake_channel_wise_quantize_abs_max
,
ops
::
FakeChannelWiseQuantizeAbsMaxOp
,
ops
::
FakeChannelWiseQuantizeAbsMaxOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
fake_channel_wise_quantize_abs_max
,
ops
::
FakeChannelWiseQuantizeAbsMaxKernel
<
CPU
,
float
>
);
paddle/fluid/operators/fake_quantize_op.cu
浏览文件 @
5420cf95
...
...
@@ -174,5 +174,7 @@ namespace ops = paddle::operators;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_abs_max
,
ops
::
FakeQuantizeAbsMaxKernel
<
CUDA
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_channel_wise_quantize_abs_max
,
ops
::
FakeChannelWiseQuantizeAbsMaxKernel
<
CUDA
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_range_abs_max
,
ops
::
FakeQuantizeRangeAbsMaxKernel
<
CUDA
,
float
>
);
paddle/fluid/operators/fake_quantize_op.h
浏览文件 @
5420cf95
...
...
@@ -63,6 +63,39 @@ class FakeQuantizeAbsMaxKernel : public framework::OpKernel<T> {
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
FakeChannelWiseQuantizeAbsMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out_scales
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
T
*
out_scales_data
=
out_scales
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
find_abs_max
=
FindAbsMaxFunctor
<
DeviceContext
,
T
>
();
for
(
int64_t
i
=
0
;
i
<
in
->
dims
()[
0
];
i
++
)
{
framework
::
Tensor
one_channel
=
in
->
Slice
(
i
,
i
+
1
);
const
T
*
one_channel_data
=
one_channel
.
data
<
T
>
();
find_abs_max
(
dev_ctx
,
one_channel_data
,
one_channel
.
numel
(),
&
out_scales_data
[
i
]);
}
auto
clip_quant
=
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
();
for
(
int64_t
i
=
0
;
i
<
in
->
dims
()[
0
];
i
++
)
{
framework
::
Tensor
one_channel_in
=
in
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_scale
=
out_scales
->
Slice
(
i
,
i
+
1
);
clip_quant
(
dev_ctx
,
one_channel_in
,
one_channel_scale
,
bin_cnt
,
&
one_channel_out
);
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
FakeQuantizeRangeAbsMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
浏览文件 @
5420cf95
...
...
@@ -31,6 +31,80 @@ def dequantize_max_abs(x, scale, max_range):
return
y
def
channel_wise_quantize_max_abs
(
x
,
quant_bit
=
8
):
scales
=
[]
for
i
in
range
(
x
.
shape
[
0
]):
scales
.
append
(
np
.
max
(
np
.
abs
(
x
[
i
])).
astype
(
"float32"
))
y
=
x
.
copy
()
max_range
=
math
.
pow
(
2
,
quant_bit
-
1
)
-
1
for
i
,
scale
in
enumerate
(
scales
):
y
[
i
]
=
np
.
round
(
y
[
i
]
/
scale
*
max_range
)
return
y
,
scales
def
channel_wise_dequantize_max_abs
(
x
,
scales
,
quant_bits
,
activation_scale
=
None
):
y
=
x
.
copy
()
for
i
in
range
(
x
.
shape
[
0
]):
y
[
i
]
=
(
scales
[
i
]
/
(
math
.
pow
(
2
,
quant_bits
[
0
]
-
1
)
-
1
))
*
y
[
i
]
if
activation_scale
is
not
None
:
y
*=
activation_scale
/
(
math
.
pow
(
2
,
quant_bits
[
1
]
-
1
)
-
1
)
return
y
class
TestFakeChannelWiseDequantizeMaxAbsOpTwoScales
(
OpTest
):
def
set_args
(
self
):
self
.
quant_bits
=
[
8
,
8
]
self
.
data_type
=
"float32"
self
.
activation_scale
=
0.7861
def
setUp
(
self
):
self
.
set_args
()
self
.
op_type
=
"fake_channel_wise_dequantize_max_abs"
x
=
np
.
random
.
randn
(
4
,
3
,
64
,
64
).
astype
(
self
.
data_type
)
yq
,
scales
=
channel_wise_quantize_max_abs
(
x
,
self
.
quant_bits
[
0
])
ydq
=
channel_wise_dequantize_max_abs
(
yq
,
scales
,
self
.
quant_bits
,
self
.
activation_scale
)
self
.
inputs
=
{
'X'
:
yq
,
'Scales'
:
[(
"scales0"
,
np
.
array
(
scales
).
astype
(
self
.
data_type
)),
(
"scales1"
,
np
.
array
(
[
self
.
activation_scale
]).
astype
(
self
.
data_type
))]
}
self
.
attrs
=
{
'quant_bits'
:
self
.
quant_bits
}
self
.
outputs
=
{
'Out'
:
ydq
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFakeChannelWiseDequantizeMaxAbsOpOneScale
(
OpTest
):
def
set_args
(
self
):
self
.
quant_bits
=
[
8
]
self
.
data_type
=
"float32"
def
setUp
(
self
):
self
.
set_args
()
self
.
op_type
=
"fake_channel_wise_dequantize_max_abs"
x
=
np
.
random
.
randn
(
4
,
3
,
64
,
64
).
astype
(
self
.
data_type
)
yq
,
scales
=
channel_wise_quantize_max_abs
(
x
,
self
.
quant_bits
[
0
])
ydq
=
channel_wise_dequantize_max_abs
(
yq
,
scales
,
self
.
quant_bits
)
self
.
inputs
=
{
'X'
:
yq
,
'Scales'
:
[(
"scales0"
,
np
.
array
(
scales
).
astype
(
self
.
data_type
))]
}
self
.
attrs
=
{
'quant_bits'
:
self
.
quant_bits
}
self
.
outputs
=
{
'Out'
:
ydq
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFakeDequantizeMaxAbsOp
(
OpTest
):
def
set_args
(
self
):
self
.
num_bits
=
8
...
...
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
浏览文件 @
5420cf95
...
...
@@ -35,6 +35,30 @@ class TestFakeQuantizeOp(OpTest):
self
.
check_output
()
class
TestFakeChannelWiseQuantizeOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"fake_channel_wise_quantize_abs_max"
self
.
attrs
=
{
'bit_length'
:
8
}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
4
,
3
,
64
,
64
)).
astype
(
"float32"
),
}
scales
=
[]
for
i
in
range
(
self
.
inputs
[
'X'
].
shape
[
0
]):
scales
.
append
(
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
][
i
])).
astype
(
"float32"
))
outputs
=
self
.
inputs
[
'X'
].
copy
()
for
i
,
scale
in
enumerate
(
scales
):
outputs
[
i
]
=
np
.
round
(
outputs
[
i
]
/
scale
*
(
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
))
self
.
outputs
=
{
'Out'
:
outputs
,
'OutScales'
:
np
.
array
(
scales
).
astype
(
"float32"
),
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFakeQuantizeRangeAbsMaxOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"fake_quantize_range_abs_max"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录